From 34f81277cd70e936cde9d8185b17e09a19d2953b Mon Sep 17 00:00:00 2001 From: Relintai Date: Sun, 12 Feb 2023 18:35:53 +0100 Subject: [PATCH] Registered MLPPANN into the ClassDB. --- mlpp/ann/ann.cpp | 3 +++ mlpp/ann/ann.h | 10 ++++++++-- register_types.cpp | 6 ++++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/mlpp/ann/ann.cpp b/mlpp/ann/ann.cpp index e42123b..0ee290d 100644 --- a/mlpp/ann/ann.cpp +++ b/mlpp/ann/ann.cpp @@ -848,3 +848,6 @@ void MLPPANN::print_ui(int epoch, real_t cost_prev, std::vector y_hat, s } } } + +void MLPPANN::_bind_methods() { +} diff --git a/mlpp/ann/ann.h b/mlpp/ann/ann.h index ea70cc1..770cd7a 100644 --- a/mlpp/ann/ann.h +++ b/mlpp/ann/ann.h @@ -9,6 +9,8 @@ #include "core/math/math_defs.h" +#include "core/object/reference.h" + #include "../hidden_layer/hidden_layer.h" #include "../output_layer/output_layer.h" @@ -19,7 +21,9 @@ #include #include -class MLPPANN { +class MLPPANN : public Reference { + GDCLASS(MLPPANN, Reference); + public: std::vector model_set_test(std::vector> X); real_t model_test(std::vector x); @@ -49,7 +53,7 @@ public: MLPPANN(); ~MLPPANN(); -private: +protected: real_t apply_learning_rate_scheduler(real_t learningRate, real_t decayConstant, real_t epoch, real_t dropRate); real_t cost(std::vector y_hat, std::vector y); @@ -60,6 +64,8 @@ private: void print_ui(int epoch, real_t cost_prev, std::vector y_hat, std::vector outputSet); + static void _bind_methods(); + std::vector> inputSet; std::vector outputSet; std::vector y_hat; diff --git a/register_types.cpp b/register_types.cpp index 2c56d4d..0dbd630 100644 --- a/register_types.cpp +++ b/register_types.cpp @@ -42,7 +42,10 @@ SOFTWARE. #include "mlpp/multi_output_layer/multi_output_layer.h" #include "mlpp/output_layer/output_layer.h" +#include "mlpp/ann/ann.h" #include "mlpp/auto_encoder/auto_encoder.h" +#include "mlpp/bernoulli_nb/bernoulli_nb.h" +#include "mlpp/c_log_log_reg/c_log_log_reg.h" #include "mlpp/dual_svc/dual_svc.h" #include "mlpp/exp_reg/exp_reg.h" #include "mlpp/gan/gan.h" @@ -63,8 +66,6 @@ SOFTWARE. #include "mlpp/tanh_reg/tanh_reg.h" #include "mlpp/uni_lin_reg/uni_lin_reg.h" #include "mlpp/wgan/wgan.h" -#include "mlpp/c_log_log_reg/c_log_log_reg.h" -#include "mlpp/bernoulli_nb/bernoulli_nb.h" #include "test/mlpp_tests.h" @@ -112,6 +113,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) { ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class();