mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-04-08 02:41:47 +02:00
Added bindings for MLPPReg.
This commit is contained in:
parent
6a5db4e60a
commit
bda7a7aee4
@ -139,6 +139,27 @@ Ref<MLPPMatrix> MLPPReg::reg_deriv_termm(const Ref<MLPPMatrix> &weights, real_t
|
|||||||
return reg_driv;
|
return reg_driv;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MLPPReg::MLPPReg() {
|
||||||
|
}
|
||||||
|
MLPPReg::~MLPPReg() {
|
||||||
|
}
|
||||||
|
|
||||||
|
void MLPPReg::_bind_methods() {
|
||||||
|
ClassDB::bind_method(D_METHOD("reg_termv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_termv);
|
||||||
|
ClassDB::bind_method(D_METHOD("reg_termm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_termm);
|
||||||
|
|
||||||
|
ClassDB::bind_method(D_METHOD("reg_weightsv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_weightsv);
|
||||||
|
ClassDB::bind_method(D_METHOD("reg_weightsm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_weightsm);
|
||||||
|
|
||||||
|
ClassDB::bind_method(D_METHOD("reg_deriv_termv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_deriv_termv);
|
||||||
|
ClassDB::bind_method(D_METHOD("reg_deriv_termm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_deriv_termm);
|
||||||
|
|
||||||
|
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_RIDGE);
|
||||||
|
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_LASSO);
|
||||||
|
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_ELASTIC_NET);
|
||||||
|
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_WEIGHT_CLIPPING);
|
||||||
|
}
|
||||||
|
|
||||||
real_t MLPPReg::reg_deriv_termvr(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg, int j) {
|
real_t MLPPReg::reg_deriv_termvr(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg, int j) {
|
||||||
MLPPActivation act;
|
MLPPActivation act;
|
||||||
|
|
||||||
|
@ -39,6 +39,12 @@ public:
|
|||||||
Ref<MLPPVector> reg_deriv_termv(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, RegularizationType reg);
|
Ref<MLPPVector> reg_deriv_termv(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, RegularizationType reg);
|
||||||
Ref<MLPPMatrix> reg_deriv_termm(const Ref<MLPPMatrix> &weights, real_t lambda, real_t alpha, RegularizationType reg);
|
Ref<MLPPMatrix> reg_deriv_termm(const Ref<MLPPMatrix> &weights, real_t lambda, real_t alpha, RegularizationType reg);
|
||||||
|
|
||||||
|
MLPPReg();
|
||||||
|
~MLPPReg();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
static void _bind_methods();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
real_t reg_deriv_termvr(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, RegularizationType reg, int j);
|
real_t reg_deriv_termvr(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, RegularizationType reg, int j);
|
||||||
real_t reg_deriv_termmr(const Ref<MLPPMatrix> &weights, real_t lambda, real_t alpha, RegularizationType reg, int i, int j);
|
real_t reg_deriv_termmr(const Ref<MLPPMatrix> &weights, real_t lambda, real_t alpha, RegularizationType reg, int i, int j);
|
||||||
@ -60,4 +66,6 @@ private:
|
|||||||
real_t regDerivTerm(std::vector<std::vector<real_t>> weights, real_t lambda, real_t alpha, std::string reg, int i, int j);
|
real_t regDerivTerm(std::vector<std::vector<real_t>> weights, real_t lambda, real_t alpha, std::string reg, int i, int j);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
VARIANT_ENUM_CAST(MLPPReg::RegularizationType);
|
||||||
|
|
||||||
#endif /* Reg_hpp */
|
#endif /* Reg_hpp */
|
||||||
|
@ -27,6 +27,8 @@ SOFTWARE.
|
|||||||
#include "mlpp/lin_alg/mlpp_matrix.h"
|
#include "mlpp/lin_alg/mlpp_matrix.h"
|
||||||
#include "mlpp/lin_alg/mlpp_vector.h"
|
#include "mlpp/lin_alg/mlpp_vector.h"
|
||||||
|
|
||||||
|
#include "mlpp/regularization/reg.h"
|
||||||
|
|
||||||
#include "mlpp/activation/activation.h"
|
#include "mlpp/activation/activation.h"
|
||||||
|
|
||||||
#include "mlpp/kmeans/kmeans.h"
|
#include "mlpp/kmeans/kmeans.h"
|
||||||
@ -39,6 +41,8 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
|
|||||||
ClassDB::register_class<MLPPVector>();
|
ClassDB::register_class<MLPPVector>();
|
||||||
ClassDB::register_class<MLPPMatrix>();
|
ClassDB::register_class<MLPPMatrix>();
|
||||||
|
|
||||||
|
ClassDB::register_class<MLPPReg>();
|
||||||
|
|
||||||
ClassDB::register_class<MLPPActivation>();
|
ClassDB::register_class<MLPPActivation>();
|
||||||
|
|
||||||
ClassDB::register_class<MLPPKNN>();
|
ClassDB::register_class<MLPPKNN>();
|
||||||
|
Loading…
Reference in New Issue
Block a user