diff --git a/mlpp/ann/ann.cpp b/mlpp/ann/ann.cpp index e42123b..0ee290d 100644 --- a/mlpp/ann/ann.cpp +++ b/mlpp/ann/ann.cpp @@ -848,3 +848,6 @@ void MLPPANN::print_ui(int epoch, real_t cost_prev, std::vector y_hat, s } } } + +void MLPPANN::_bind_methods() { +} diff --git a/mlpp/ann/ann.h b/mlpp/ann/ann.h index ea70cc1..770cd7a 100644 --- a/mlpp/ann/ann.h +++ b/mlpp/ann/ann.h @@ -9,6 +9,8 @@ #include "core/math/math_defs.h" +#include "core/object/reference.h" + #include "../hidden_layer/hidden_layer.h" #include "../output_layer/output_layer.h" @@ -19,7 +21,9 @@ #include #include -class MLPPANN { +class MLPPANN : public Reference { + GDCLASS(MLPPANN, Reference); + public: std::vector model_set_test(std::vector> X); real_t model_test(std::vector x); @@ -49,7 +53,7 @@ public: MLPPANN(); ~MLPPANN(); -private: +protected: real_t apply_learning_rate_scheduler(real_t learningRate, real_t decayConstant, real_t epoch, real_t dropRate); real_t cost(std::vector y_hat, std::vector y); @@ -60,6 +64,8 @@ private: void print_ui(int epoch, real_t cost_prev, std::vector y_hat, std::vector outputSet); + static void _bind_methods(); + std::vector> inputSet; std::vector outputSet; std::vector y_hat; diff --git a/register_types.cpp b/register_types.cpp index 2c56d4d..0dbd630 100644 --- a/register_types.cpp +++ b/register_types.cpp @@ -42,7 +42,10 @@ SOFTWARE. #include "mlpp/multi_output_layer/multi_output_layer.h" #include "mlpp/output_layer/output_layer.h" +#include "mlpp/ann/ann.h" #include "mlpp/auto_encoder/auto_encoder.h" +#include "mlpp/bernoulli_nb/bernoulli_nb.h" +#include "mlpp/c_log_log_reg/c_log_log_reg.h" #include "mlpp/dual_svc/dual_svc.h" #include "mlpp/exp_reg/exp_reg.h" #include "mlpp/gan/gan.h" @@ -63,8 +66,6 @@ SOFTWARE. #include "mlpp/tanh_reg/tanh_reg.h" #include "mlpp/uni_lin_reg/uni_lin_reg.h" #include "mlpp/wgan/wgan.h" -#include "mlpp/c_log_log_reg/c_log_log_reg.h" -#include "mlpp/bernoulli_nb/bernoulli_nb.h" #include "test/mlpp_tests.h" @@ -112,6 +113,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) { ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class();