Added bindings for MLPPReg.

This commit is contained in:
Relintai 2023-02-04 00:54:27 +01:00
parent 6a5db4e60a
commit bda7a7aee4
3 changed files with 33 additions and 0 deletions

View File

@ -139,6 +139,27 @@ Ref<MLPPMatrix> MLPPReg::reg_deriv_termm(const Ref<MLPPMatrix> &weights, real_t
return reg_driv;
}
MLPPReg::MLPPReg() {
}
MLPPReg::~MLPPReg() {
}
void MLPPReg::_bind_methods() {
ClassDB::bind_method(D_METHOD("reg_termv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_termv);
ClassDB::bind_method(D_METHOD("reg_termm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_termm);
ClassDB::bind_method(D_METHOD("reg_weightsv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_weightsv);
ClassDB::bind_method(D_METHOD("reg_weightsm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_weightsm);
ClassDB::bind_method(D_METHOD("reg_deriv_termv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_deriv_termv);
ClassDB::bind_method(D_METHOD("reg_deriv_termm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_deriv_termm);
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_RIDGE);
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_LASSO);
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_ELASTIC_NET);
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_WEIGHT_CLIPPING);
}
real_t MLPPReg::reg_deriv_termvr(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg, int j) {
MLPPActivation act;

View File

@ -39,6 +39,12 @@ public:
Ref<MLPPVector> reg_deriv_termv(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, RegularizationType reg);
Ref<MLPPMatrix> reg_deriv_termm(const Ref<MLPPMatrix> &weights, real_t lambda, real_t alpha, RegularizationType reg);
MLPPReg();
~MLPPReg();
protected:
static void _bind_methods();
private:
real_t reg_deriv_termvr(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, RegularizationType reg, int j);
real_t reg_deriv_termmr(const Ref<MLPPMatrix> &weights, real_t lambda, real_t alpha, RegularizationType reg, int i, int j);
@ -60,4 +66,6 @@ private:
real_t regDerivTerm(std::vector<std::vector<real_t>> weights, real_t lambda, real_t alpha, std::string reg, int i, int j);
};
VARIANT_ENUM_CAST(MLPPReg::RegularizationType);
#endif /* Reg_hpp */

View File

@ -27,6 +27,8 @@ SOFTWARE.
#include "mlpp/lin_alg/mlpp_matrix.h"
#include "mlpp/lin_alg/mlpp_vector.h"
#include "mlpp/regularization/reg.h"
#include "mlpp/activation/activation.h"
#include "mlpp/kmeans/kmeans.h"
@ -39,6 +41,8 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
ClassDB::register_class<MLPPVector>();
ClassDB::register_class<MLPPMatrix>();
ClassDB::register_class<MLPPReg>();
ClassDB::register_class<MLPPActivation>();
ClassDB::register_class<MLPPKNN>();