From 61793e4a4d9eb8fd2e6c5e8c0b662af276aa9f0f Mon Sep 17 00:00:00 2001 From: Relintai Date: Sun, 12 Feb 2023 16:30:22 +0100 Subject: [PATCH] Also register MLPPExpReg to the ClassDB. --- mlpp/exp_reg/exp_reg.cpp | 8 ++++++++ mlpp/exp_reg/exp_reg.h | 13 +++++++++++-- register_types.cpp | 2 ++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/mlpp/exp_reg/exp_reg.cpp b/mlpp/exp_reg/exp_reg.cpp index 44d379d..91da395 100644 --- a/mlpp/exp_reg/exp_reg.cpp +++ b/mlpp/exp_reg/exp_reg.cpp @@ -229,6 +229,11 @@ MLPPExpReg::MLPPExpReg(std::vector> p_input_set, std::vector _bias = MLPPUtilities::biasInitialization(); } +MLPPExpReg::MLPPExpReg() { +} +MLPPExpReg::~MLPPExpReg() { +} + real_t MLPPExpReg::cost(std::vector y_hat, std::vector y) { MLPPReg regularization; MLPPCost mlpp_cost; @@ -265,3 +270,6 @@ std::vector MLPPExpReg::evaluatem(std::vector> X) { void MLPPExpReg::forward_pass() { _y_hat = evaluatem(_input_set); } + +void MLPPExpReg::_bind_methods() { +} diff --git a/mlpp/exp_reg/exp_reg.h b/mlpp/exp_reg/exp_reg.h index 0b475e7..494b45d 100644 --- a/mlpp/exp_reg/exp_reg.h +++ b/mlpp/exp_reg/exp_reg.h @@ -10,10 +10,14 @@ #include "core/math/math_defs.h" +#include "core/object/reference.h" + #include #include -class MLPPExpReg { +class MLPPExpReg : public Reference { + GDCLASS(MLPPExpReg, Reference); + public: std::vector model_set_test(std::vector> X); real_t model_test(std::vector x); @@ -28,7 +32,10 @@ public: MLPPExpReg(std::vector> p_input_set, std::vector p_output_set, std::string p_reg = "None", real_t p_lambda = 0.5, real_t p_alpha = 0.5); -private: + MLPPExpReg(); + ~MLPPExpReg(); + +protected: real_t cost(std::vector y_hat, std::vector y); real_t evaluatev(std::vector x); @@ -36,6 +43,8 @@ private: void forward_pass(); + static void _bind_methods(); + std::vector> _input_set; std::vector _output_set; std::vector _y_hat; diff --git a/register_types.cpp b/register_types.cpp index 4cff229..44f9e83 100644 --- a/register_types.cpp +++ b/register_types.cpp @@ -60,6 +60,7 @@ SOFTWARE. #include "mlpp/lin_reg/lin_reg.h" #include "mlpp/gaussian_nb/gaussian_nb.h" #include "mlpp/gan/gan.h" +#include "mlpp/exp_reg/exp_reg.h" #include "test/mlpp_tests.h" @@ -102,6 +103,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) { ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class();