mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-01-10 17:49:36 +01:00
Added bindings for MLPPReg.
This commit is contained in:
parent
6a5db4e60a
commit
bda7a7aee4
@ -139,6 +139,27 @@ Ref<MLPPMatrix> MLPPReg::reg_deriv_termm(const Ref<MLPPMatrix> &weights, real_t
|
||||
return reg_driv;
|
||||
}
|
||||
|
||||
MLPPReg::MLPPReg() {
|
||||
}
|
||||
MLPPReg::~MLPPReg() {
|
||||
}
|
||||
|
||||
void MLPPReg::_bind_methods() {
|
||||
ClassDB::bind_method(D_METHOD("reg_termv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_termv);
|
||||
ClassDB::bind_method(D_METHOD("reg_termm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_termm);
|
||||
|
||||
ClassDB::bind_method(D_METHOD("reg_weightsv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_weightsv);
|
||||
ClassDB::bind_method(D_METHOD("reg_weightsm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_weightsm);
|
||||
|
||||
ClassDB::bind_method(D_METHOD("reg_deriv_termv", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_deriv_termv);
|
||||
ClassDB::bind_method(D_METHOD("reg_deriv_termm", "weights", "lambda", "alpha", "reg"), &MLPPReg::reg_deriv_termm);
|
||||
|
||||
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_RIDGE);
|
||||
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_LASSO);
|
||||
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_ELASTIC_NET);
|
||||
BIND_ENUM_CONSTANT(REGULARIZATION_TYPE_WEIGHT_CLIPPING);
|
||||
}
|
||||
|
||||
real_t MLPPReg::reg_deriv_termvr(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg, int j) {
|
||||
MLPPActivation act;
|
||||
|
||||
|
@ -39,6 +39,12 @@ public:
|
||||
Ref<MLPPVector> reg_deriv_termv(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, RegularizationType reg);
|
||||
Ref<MLPPMatrix> reg_deriv_termm(const Ref<MLPPMatrix> &weights, real_t lambda, real_t alpha, RegularizationType reg);
|
||||
|
||||
MLPPReg();
|
||||
~MLPPReg();
|
||||
|
||||
protected:
|
||||
static void _bind_methods();
|
||||
|
||||
private:
|
||||
real_t reg_deriv_termvr(const Ref<MLPPVector> &weights, real_t lambda, real_t alpha, RegularizationType reg, int j);
|
||||
real_t reg_deriv_termmr(const Ref<MLPPMatrix> &weights, real_t lambda, real_t alpha, RegularizationType reg, int i, int j);
|
||||
@ -60,4 +66,6 @@ private:
|
||||
real_t regDerivTerm(std::vector<std::vector<real_t>> weights, real_t lambda, real_t alpha, std::string reg, int i, int j);
|
||||
};
|
||||
|
||||
VARIANT_ENUM_CAST(MLPPReg::RegularizationType);
|
||||
|
||||
#endif /* Reg_hpp */
|
||||
|
@ -27,6 +27,8 @@ SOFTWARE.
|
||||
#include "mlpp/lin_alg/mlpp_matrix.h"
|
||||
#include "mlpp/lin_alg/mlpp_vector.h"
|
||||
|
||||
#include "mlpp/regularization/reg.h"
|
||||
|
||||
#include "mlpp/activation/activation.h"
|
||||
|
||||
#include "mlpp/kmeans/kmeans.h"
|
||||
@ -39,6 +41,8 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
|
||||
ClassDB::register_class<MLPPVector>();
|
||||
ClassDB::register_class<MLPPMatrix>();
|
||||
|
||||
ClassDB::register_class<MLPPReg>();
|
||||
|
||||
ClassDB::register_class<MLPPActivation>();
|
||||
|
||||
ClassDB::register_class<MLPPKNN>();
|
||||
|
Loading…
Reference in New Issue
Block a user