diff --git a/mlpp/bernoulli_nb/bernoulli_nb.cpp b/mlpp/bernoulli_nb/bernoulli_nb.cpp index 784564e..4f8a691 100644 --- a/mlpp/bernoulli_nb/bernoulli_nb.cpp +++ b/mlpp/bernoulli_nb/bernoulli_nb.cpp @@ -187,3 +187,6 @@ void MLPPBernoulliNB::evaluate() { } } } + +void MLPPBernoulliNB::_bind_methods() { +} diff --git a/mlpp/bernoulli_nb/bernoulli_nb.h b/mlpp/bernoulli_nb/bernoulli_nb.h index b91c073..ec79e55 100644 --- a/mlpp/bernoulli_nb/bernoulli_nb.h +++ b/mlpp/bernoulli_nb/bernoulli_nb.h @@ -10,10 +10,14 @@ #include "core/math/math_defs.h" +#include "core/object/reference.h" + #include #include -class MLPPBernoulliNB { +class MLPPBernoulliNB : public Reference { + GDCLASS(MLPPBernoulliNB, Reference); + public: std::vector model_set_test(std::vector> X); real_t model_test(std::vector x); @@ -25,11 +29,13 @@ public: MLPPBernoulliNB(); ~MLPPBernoulliNB(); -private: +protected: void compute_vocab(); void compute_theta(); void evaluate(); + static void _bind_methods(); + // Model Params real_t _prior_1; real_t _prior_0; diff --git a/register_types.cpp b/register_types.cpp index 913dc44..2c56d4d 100644 --- a/register_types.cpp +++ b/register_types.cpp @@ -64,6 +64,7 @@ SOFTWARE. #include "mlpp/uni_lin_reg/uni_lin_reg.h" #include "mlpp/wgan/wgan.h" #include "mlpp/c_log_log_reg/c_log_log_reg.h" +#include "mlpp/bernoulli_nb/bernoulli_nb.h" #include "test/mlpp_tests.h" @@ -110,6 +111,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) { ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class();