diff --git a/mlpp/regularization/reg.cpp b/mlpp/regularization/reg.cpp index 8c725f2..ad9e984 100644 --- a/mlpp/regularization/reg.cpp +++ b/mlpp/regularization/reg.cpp @@ -139,6 +139,27 @@ Ref MLPPReg::reg_deriv_termm(const Ref &weights, real_t return reg_driv; } +MLPPReg::MLPPReg() { +} +MLPPReg::~MLPPReg() { +} + +void MLPPReg::_bind_methods() { + ClassDB::bind_method(D_METHOD("reg_termv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_termv); + ClassDB::bind_method(D_METHOD("reg_termm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_termm); + + ClassDB::bind_method(D_METHOD("reg_weightsv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_weightsv); + ClassDB::bind_method(D_METHOD("reg_weightsm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_weightsm); + + ClassDB::bind_method(D_METHOD("reg_deriv_termv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_deriv_termv); + ClassDB::bind_method(D_METHOD("reg_deriv_termm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_deriv_termm); + + BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_RIDGE); + BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_LASSO); + BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_ELASTIC_NET); + BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_WEIGHT_CLIPPING); +} + real_t MLPPReg::reg_deriv_termvr(const Ref &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg, int j) { MLPPActivation act; diff --git a/mlpp/regularization/reg.h b/mlpp/regularization/reg.h index e27827d..e7e1418 100644 --- a/mlpp/regularization/reg.h +++ b/mlpp/regularization/reg.h @@ -39,6 +39,12 @@ public: Ref reg_deriv_termv(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg); Ref reg_deriv_termm(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg); + MLPPReg(); + ~MLPPReg(); + +protected: + static void _bind_methods(); + private: real_t reg_deriv_termvr(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg, int j); real_t reg_deriv_termmr(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg, int i, int j); @@ -60,4 +66,6 @@ private: real_t regDerivTerm(std::vector> weights, real_t lambda, real_t alpha, std::string reg, int i, int j); }; +VARIANT_ENUM_CAST(MLPPReg::RegularizationType); + #endif /* Reg_hpp */ diff --git a/register_types.cpp b/register_types.cpp index 47b0042..19f9db5 100644 --- a/register_types.cpp +++ b/register_types.cpp @@ -27,6 +27,8 @@ SOFTWARE. #include "mlpp/lin_alg/mlpp_matrix.h" #include "mlpp/lin_alg/mlpp_vector.h" +#include "mlpp/regularization/reg.h" + #include "mlpp/activation/activation.h" #include "mlpp/kmeans/kmeans.h" @@ -39,6 +41,8 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) { ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class();