Also register MLPPExpReg to the ClassDB.

This commit is contained in:
Relintai 2023-02-12 16:30:22 +01:00
parent 7659f8c4ea
commit 61793e4a4d
3 changed files with 21 additions and 2 deletions

View File

@ -229,6 +229,11 @@ MLPPExpReg::MLPPExpReg(std::vector<std::vector<real_t>> p_input_set, std::vector
_bias = MLPPUtilities::biasInitialization(); _bias = MLPPUtilities::biasInitialization();
} }
MLPPExpReg::MLPPExpReg() {
}
MLPPExpReg::~MLPPExpReg() {
}
real_t MLPPExpReg::cost(std::vector<real_t> y_hat, std::vector<real_t> y) { real_t MLPPExpReg::cost(std::vector<real_t> y_hat, std::vector<real_t> y) {
MLPPReg regularization; MLPPReg regularization;
MLPPCost mlpp_cost; MLPPCost mlpp_cost;
@ -265,3 +270,6 @@ std::vector<real_t> MLPPExpReg::evaluatem(std::vector<std::vector<real_t>> X) {
void MLPPExpReg::forward_pass() { void MLPPExpReg::forward_pass() {
_y_hat = evaluatem(_input_set); _y_hat = evaluatem(_input_set);
} }
void MLPPExpReg::_bind_methods() {
}

View File

@ -10,10 +10,14 @@
#include "core/math/math_defs.h" #include "core/math/math_defs.h"
#include "core/object/reference.h"
#include <string> #include <string>
#include <vector> #include <vector>
class MLPPExpReg { class MLPPExpReg : public Reference {
GDCLASS(MLPPExpReg, Reference);
public: public:
std::vector<real_t> model_set_test(std::vector<std::vector<real_t>> X); std::vector<real_t> model_set_test(std::vector<std::vector<real_t>> X);
real_t model_test(std::vector<real_t> x); real_t model_test(std::vector<real_t> x);
@ -28,7 +32,10 @@ public:
MLPPExpReg(std::vector<std::vector<real_t>> p_input_set, std::vector<real_t> p_output_set, std::string p_reg = "None", real_t p_lambda = 0.5, real_t p_alpha = 0.5); MLPPExpReg(std::vector<std::vector<real_t>> p_input_set, std::vector<real_t> p_output_set, std::string p_reg = "None", real_t p_lambda = 0.5, real_t p_alpha = 0.5);
private: MLPPExpReg();
~MLPPExpReg();
protected:
real_t cost(std::vector<real_t> y_hat, std::vector<real_t> y); real_t cost(std::vector<real_t> y_hat, std::vector<real_t> y);
real_t evaluatev(std::vector<real_t> x); real_t evaluatev(std::vector<real_t> x);
@ -36,6 +43,8 @@ private:
void forward_pass(); void forward_pass();
static void _bind_methods();
std::vector<std::vector<real_t>> _input_set; std::vector<std::vector<real_t>> _input_set;
std::vector<real_t> _output_set; std::vector<real_t> _output_set;
std::vector<real_t> _y_hat; std::vector<real_t> _y_hat;

View File

@ -60,6 +60,7 @@ SOFTWARE.
#include "mlpp/lin_reg/lin_reg.h" #include "mlpp/lin_reg/lin_reg.h"
#include "mlpp/gaussian_nb/gaussian_nb.h" #include "mlpp/gaussian_nb/gaussian_nb.h"
#include "mlpp/gan/gan.h" #include "mlpp/gan/gan.h"
#include "mlpp/exp_reg/exp_reg.h"
#include "test/mlpp_tests.h" #include "test/mlpp_tests.h"
@ -102,6 +103,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
ClassDB::register_class<MLPPLinReg>(); ClassDB::register_class<MLPPLinReg>();
ClassDB::register_class<MLPPGaussianNB>(); ClassDB::register_class<MLPPGaussianNB>();
ClassDB::register_class<MLPPGAN>(); ClassDB::register_class<MLPPGAN>();
ClassDB::register_class<MLPPExpReg>();
ClassDB::register_class<MLPPDataESimple>(); ClassDB::register_class<MLPPDataESimple>();
ClassDB::register_class<MLPPDataSimple>(); ClassDB::register_class<MLPPDataSimple>();