diff --git a/mlpp/c_log_log_reg/c_log_log_reg.cpp b/mlpp/c_log_log_reg/c_log_log_reg.cpp index 9f3e294..c932279 100644 --- a/mlpp/c_log_log_reg/c_log_log_reg.cpp +++ b/mlpp/c_log_log_reg/c_log_log_reg.cpp @@ -248,3 +248,6 @@ void MLPPCLogLogReg::forward_pass() { _z = propagatem(_input_set); _y_hat = avn.cloglog(_z); } + +void MLPPCLogLogReg::_bind_methods() { +} diff --git a/mlpp/c_log_log_reg/c_log_log_reg.h b/mlpp/c_log_log_reg/c_log_log_reg.h index 948b4f9..57f547c 100644 --- a/mlpp/c_log_log_reg/c_log_log_reg.h +++ b/mlpp/c_log_log_reg/c_log_log_reg.h @@ -10,10 +10,14 @@ #include "core/math/math_defs.h" +#include "core/object/reference.h" + #include #include -class MLPPCLogLogReg { +class MLPPCLogLogReg : public Reference { + GDCLASS(MLPPCLogLogReg, Reference); + public: std::vector model_set_test(std::vector> X); real_t model_test(std::vector x); @@ -30,7 +34,7 @@ public: MLPPCLogLogReg(); ~MLPPCLogLogReg(); -private: +protected: void weight_initialization(int k); void bias_initialization(); @@ -44,6 +48,8 @@ private: void forward_pass(); + static void _bind_methods(); + std::vector> _input_set; std::vector _output_set; std::vector _y_hat; diff --git a/register_types.cpp b/register_types.cpp index 442ba21..913dc44 100644 --- a/register_types.cpp +++ b/register_types.cpp @@ -63,6 +63,7 @@ SOFTWARE. #include "mlpp/tanh_reg/tanh_reg.h" #include "mlpp/uni_lin_reg/uni_lin_reg.h" #include "mlpp/wgan/wgan.h" +#include "mlpp/c_log_log_reg/c_log_log_reg.h" #include "test/mlpp_tests.h" @@ -108,6 +109,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) { ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class();