// Review preamble (free text ahead of the first file header of this patch).
//
// Summary of what the patch below changes:
//   * mlpp/stat/stat.h, stat.cpp: declares b0_estimation / b1_estimation overloads that take
//     Ref arguments, and adds variancev. b1_estimation returns covariancev(x, y) / variancev(x).
//     The body `meanv(y) - b1_estimation(x, y) * meanv(x)` presumably belongs to b0_estimation
//     (its signature line is damaged in this copy). variancev accumulates the squared
//     deviations from meanv(x) and divides the sum by (x_size - 1).
//   * mlpp/uni_lin_reg/uni_lin_reg.h, uni_lin_reg.cpp: MLPPUniLinReg now derives from Reference
//     (GDCLASS), keeps _input_set / _output_set / _b0 / _b1, and gains getters/setters,
//     initialize(), model_set_test(), model_test(), a default constructor that sets _b0 and _b1
//     to 0, an empty destructor and _bind_methods().
//   * register_types.cpp: includes uni_lin_reg.h and adds one ClassDB::register_class call.
//   * test/mlpp_tests.cpp: the previous check now runs against MLPPUniLinRegOld; a second check
//     builds the new MLPPUniLinReg and compares model_set_test(input) with is_equal_approx,
//     printing both vectors through ERR_PRINT on mismatch.
//
// NOTE(review): this copy has lost its line breaks and template argument lists look stripped
//   (e.g. `const std::vector &x`, `Ref input;`, `ClassDB::register_class();`), so it will not
//   apply or compile as-is -- recover the original patch before using it.
// NOTE(review): variancev divides by (x_size - 1) with no guard -- confirm callers never pass
//   fewer than two elements.
// NOTE(review): b1_estimation divides by variancev(x), which is 0 when every x is equal --
//   confirm whether that case needs handling.
// NOTE(review): both ADD_PROPERTY hints name "MLPPMatrix", while the sets are fed to the
//   vector-style helpers (meanv / variancev / covariancev) -- verify the hinted resource type.
// NOTE(review): the two-argument constructor repeats the estimator calls made in initialize()
//   but without its is_valid() check -- confirm that is intended.
// NOTE(review): the label "stat.mode(x)" passed to is_approx_equals_dvec sits inside
//   test_univariate_linear_regression -- looks like a leftover; confirm.
diff --git a/mlpp/stat/stat.cpp b/mlpp/stat/stat.cpp index 1e3a7c3..11312f3 100644 --- a/mlpp/stat/stat.cpp +++ b/mlpp/stat/stat.cpp @@ -22,6 +22,13 @@ real_t MLPPStat::b1Estimation(const std::vector &x, const std::vector &x, const Ref &y) { + return meanv(y) - b1_estimation(x, y) * meanv(x); +} +real_t MLPPStat::b1_estimation(const Ref &x, const Ref &y) { + return covariancev(x, y) / variancev(x); +} + real_t MLPPStat::mean(const std::vector &x) { real_t sum = 0; for (int i = 0; i < x.size(); i++) { @@ -126,6 +133,21 @@ real_t MLPPStat::meanv(const Ref &x) { return sum / x_size; } +real_t MLPPStat::variancev(const Ref &x) { + real_t x_mean = meanv(x); + + int x_size = x->size(); + const real_t *x_ptr = x->ptr(); + + real_t sum = 0; + for (int i = 0; i < x_size; ++i) { + real_t xi = x_ptr[i]; + + sum += (xi - x_mean) * (xi - x_mean); + } + return sum / (x_size - 1); +} + real_t MLPPStat::covariancev(const Ref &x, const Ref &y) { ERR_FAIL_COND_V(x->size() != y->size(), 0); diff --git a/mlpp/stat/stat.h b/mlpp/stat/stat.h index c9b2a2e..b19cc2c 100644 --- a/mlpp/stat/stat.h +++ b/mlpp/stat/stat.h @@ -21,6 +21,9 @@ public: real_t b0Estimation(const std::vector &x, const std::vector &y); real_t b1Estimation(const std::vector &x, const std::vector &y); + real_t b0_estimation(const Ref &x, const Ref &y); + real_t b1_estimation(const Ref &x, const Ref &y); + // Statistical Functions real_t mean(const std::vector &x); real_t median(std::vector x); @@ -36,6 +39,7 @@ public: real_t chebyshevIneq(const real_t k); real_t meanv(const Ref &x); + real_t variancev(const Ref &x); real_t covariancev(const Ref &x, const Ref &y); // Extras diff --git a/mlpp/uni_lin_reg/uni_lin_reg.cpp b/mlpp/uni_lin_reg/uni_lin_reg.cpp index 12b1928..d5abe6c 100644 --- a/mlpp/uni_lin_reg/uni_lin_reg.cpp +++ b/mlpp/uni_lin_reg/uni_lin_reg.cpp @@ -9,26 +9,83 @@ #include "../lin_alg/lin_alg.h" #include "../stat/stat.h" -#include - // General Multivariate Linear Regression Model // ŷ = b0 + b1x1 + b2x2 + 
... + bkxk // Univariate Linear Regression Model // ŷ = b0 + b1x1 -MLPPUniLinReg::MLPPUniLinReg(std::vector x, std::vector y) : - inputSet(x), outputSet(y) { +Ref MLPPUniLinReg::get_input_set() { + return _input_set; +} +void MLPPUniLinReg::set_input_set(const Ref &val) { + _input_set = val; +} + +Ref MLPPUniLinReg::get_output_set() { + return _output_set; +} +void MLPPUniLinReg::set_output_set(const Ref &val) { + _output_set = val; +} + +real_t MLPPUniLinReg::get_b0() { + return _b0; +} +real_t MLPPUniLinReg::get_b1() { + return _b1; +} + +void MLPPUniLinReg::initialize() { + ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid()); + MLPPStat estimator; - b1 = estimator.b1Estimation(inputSet, outputSet); - b0 = estimator.b0Estimation(inputSet, outputSet); + + _b1 = estimator.b1_estimation(_input_set, _output_set); + _b0 = estimator.b0_estimation(_input_set, _output_set); } -std::vector MLPPUniLinReg::modelSetTest(std::vector x) { +Ref MLPPUniLinReg::model_set_test(const Ref &x) { MLPPLinAlg alg; - return alg.scalarAdd(b0, alg.scalarMultiply(b1, x)); + + return alg.scalar_addnv(_b0, alg.scalar_multiplynv(_b1, x)); } -real_t MLPPUniLinReg::modelTest(real_t input) { - return b0 + b1 * input; +real_t MLPPUniLinReg::model_test(real_t x) { + return _b0 + _b1 * x; +} + +MLPPUniLinReg::MLPPUniLinReg(const Ref &p_input_set, const Ref &p_output_set) { + _input_set = p_input_set; + _output_set = p_output_set; + + MLPPStat estimator; + + _b1 = estimator.b1_estimation(_input_set, _output_set); + _b0 = estimator.b0_estimation(_input_set, _output_set); +} + +MLPPUniLinReg::MLPPUniLinReg() { + _b0 = 0; + _b1 = 0; +} +MLPPUniLinReg::~MLPPUniLinReg() { +} + +void MLPPUniLinReg::_bind_methods() { + ClassDB::bind_method(D_METHOD("get_input_set"), &MLPPUniLinReg::get_input_set); + ClassDB::bind_method(D_METHOD("set_input_set", "val"), &MLPPUniLinReg::set_input_set); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "input_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), 
"set_input_set", "get_input_set"); + + ClassDB::bind_method(D_METHOD("get_output_set"), &MLPPUniLinReg::get_output_set); + ClassDB::bind_method(D_METHOD("set_output_set", "val"), &MLPPUniLinReg::set_output_set); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "output_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "set_output_set", "get_output_set"); + + ClassDB::bind_method(D_METHOD("get_b0"), &MLPPUniLinReg::get_b0); + ClassDB::bind_method(D_METHOD("get_b1"), &MLPPUniLinReg::get_b1); + + ClassDB::bind_method(D_METHOD("initialize"), &MLPPUniLinReg::initialize); + + ClassDB::bind_method(D_METHOD("model_set_test", "x"), &MLPPUniLinReg::model_set_test); + ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPUniLinReg::model_test); } diff --git a/mlpp/uni_lin_reg/uni_lin_reg.h b/mlpp/uni_lin_reg/uni_lin_reg.h index 39d852b..3a293c9 100644 --- a/mlpp/uni_lin_reg/uni_lin_reg.h +++ b/mlpp/uni_lin_reg/uni_lin_reg.h @@ -10,20 +10,42 @@ #include "core/math/math_defs.h" -#include +#include "core/object/reference.h" + +#include "../lin_alg/mlpp_matrix.h" +#include "../lin_alg/mlpp_vector.h" + +class MLPPUniLinReg : public Reference { + GDCLASS(MLPPUniLinReg, Reference); -class MLPPUniLinReg { public: - MLPPUniLinReg(std::vector x, std::vector y); - std::vector modelSetTest(std::vector x); - real_t modelTest(real_t x); + Ref get_input_set(); + void set_input_set(const Ref &val); -private: - std::vector inputSet; - std::vector outputSet; + Ref get_output_set(); + void set_output_set(const Ref &val); - real_t b0; - real_t b1; + real_t get_b0(); + real_t get_b1(); + + void initialize(); + + Ref model_set_test(const Ref &x); + real_t model_test(real_t x); + + MLPPUniLinReg(const Ref &p_input_set, const Ref &p_output_set); + + MLPPUniLinReg(); + ~MLPPUniLinReg(); + +protected: + static void _bind_methods(); + + Ref _input_set; + Ref _output_set; + + real_t _b0; + real_t _b1; }; #endif /* UniLinReg_hpp */ diff --git a/register_types.cpp b/register_types.cpp index 5848de5..0406b71 
100644 --- a/register_types.cpp +++ b/register_types.cpp @@ -39,6 +39,7 @@ SOFTWARE. #include "mlpp/kmeans/kmeans.h" #include "mlpp/knn/knn.h" #include "mlpp/pca/pca.h" +#include "mlpp/uni_lin_reg/uni_lin_reg.h" #include "mlpp/wgan/wgan.h" #include "mlpp/mlp/mlp.h" @@ -65,6 +66,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) { ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); diff --git a/test/mlpp_tests.cpp b/test/mlpp_tests.cpp index 696676b..008f09f 100644 --- a/test/mlpp_tests.cpp +++ b/test/mlpp_tests.cpp @@ -49,6 +49,7 @@ #include "../mlpp/mlp/mlp_old.h" #include "../mlpp/pca/pca_old.h" +#include "../mlpp/uni_lin_reg/uni_lin_reg_old.h" #include "../mlpp/wgan/wgan_old.h" Vector dstd_vec_to_vec(const std::vector &in) { @@ -181,7 +182,7 @@ void MLPPTests::test_univariate_linear_regression() { Ref ds = data.load_fires_and_crime(_fires_and_crime_data_path); - MLPPUniLinReg model(ds->input, ds->output); + MLPPUniLinRegOld model_old(ds->input, ds->output); std::vector slr_res = { 24.1095, 28.4829, 29.8082, 26.0974, 27.2902, 61.0851, 30.4709, 25.0372, 25.5673, 35.9046, @@ -190,7 +191,37 @@ void MLPPTests::test_univariate_linear_regression() { 27.8203, 20.6637, 22.5191, 53.796, 38.9527, 30.8685, 20.3986 }; - is_approx_equals_dvec(dstd_vec_to_vec(model.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)"); + is_approx_equals_dvec(dstd_vec_to_vec(model_old.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)"); + + Ref input; + input.instance(); + input->set_from_std_vector(ds->input); + + Ref output; + output.instance(); + output->set_from_std_vector(ds->output); + + MLPPUniLinReg model(input, output); + + std::vector slr_res_n = { + 24.109467, 28.482935, 29.808228, 26.097408, 27.290173, 61.085152, 30.470875, 25.037172, 25.567291, + 35.904579, 54.458687, 18.808294, 23.446819, 18.543236, 19.205883, 21.193821, 
23.049232, 18.808294, + 25.434761, 35.904579, 37.759987, 40.278046, 63.868271, 68.50679, 40.410576, 46.77198, 32.061226, + 23.314291, 44.784042, 44.518982, 27.82029, 20.663704, 22.519115, 53.796036, 38.952751, + 30.868464, 20.398645 + }; + + Ref slr_res_v; + slr_res_v.instance(); + slr_res_v->set_from_std_vector(slr_res_n); + + Ref res = model.model_set_test(input); + + if (!slr_res_v->is_equal_approx(res)) { + ERR_PRINT("!slr_res_v->is_equal_approx(res)"); + ERR_PRINT(res->to_string()); + ERR_PRINT(slr_res_v->to_string()); + } } void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) {