From cd35ebfd794b468604c261522c6f565876902552 Mon Sep 17 00:00:00 2001 From: Relintai Date: Wed, 27 Dec 2023 11:12:37 +0100 Subject: [PATCH] Reworked MLPPTests::test_univariate_linear_regression(). --- test/mlpp_tests.cpp | 34 +++++++++++++--------------------- test/mlpp_tests_old.cpp | 15 --------------- 2 files changed, 13 insertions(+), 36 deletions(-) diff --git a/test/mlpp_tests.cpp b/test/mlpp_tests.cpp index 4bc4286..81d7ba4 100644 --- a/test/mlpp_tests.cpp +++ b/test/mlpp_tests.cpp @@ -195,32 +195,24 @@ void MLPPTests::test_linear_algebra() { } void MLPPTests::test_univariate_linear_regression() { - // Univariate, simple linear regression, case where k = 1 - MLPPData data; - - Ref ds = data.load_fires_and_crime(_fires_and_crime_data_path); - - MLPPUniLinReg model(ds->get_input(), ds->get_output()); - - std::vector slr_res_n = { - 24.109467, 28.482935, 29.808228, 26.097408, 27.290173, 61.085152, 30.470875, 25.037172, 25.567291, - 35.904579, 54.458687, 18.808294, 23.446819, 18.543236, 19.205883, 21.193821, 23.049232, 18.808294, - 25.434761, 35.904579, 37.759987, 40.278046, 63.868271, 68.50679, 40.410576, 46.77198, 32.061226, - 23.314291, 44.784042, 44.518982, 27.82029, 20.663704, 22.519115, 53.796036, 38.952751, - 30.868464, 20.398645 + const real_t slr_res_n_arr[] = { + 24.109467, 28.482935, 29.808228, 26.097408, 27.290173, 61.085152, 30.470875, 25.037172, 25.567291, // + 35.904579, 54.458687, 18.808294, 23.446819, 18.543236, 19.205883, 21.193821, 23.049232, 18.808294, // + 25.434761, 35.904579, 37.759987, 40.278046, 63.868271, 68.50679, 40.410576, 46.77198, 32.061226, // + 23.314291, 44.784042, 44.518982, 27.82029, 20.663704, 22.519115, 53.796036, 38.952751, // + 30.868464, 20.398645 // }; - Ref slr_res_v; - slr_res_v.instance(); - slr_res_v->set_from_std_vector(slr_res_n); + Ref slr_res_v(memnew(MLPPVector(slr_res_n_arr, 37))); + + // Univariate, simple linear regression, case where k = 1 + MLPPData data; + Ref ds = data.load_fires_and_crime(_fires_and_crime_data_path); + MLPPUniLinReg model(ds->get_input(), ds->get_output()); Ref res = model.model_set_test(ds->get_input()); - if (!slr_res_v->is_equal_approx(res)) { - ERR_PRINT("!slr_res_v->is_equal_approx(res)"); - ERR_PRINT(res->to_string()); - ERR_PRINT(slr_res_v->to_string()); - } + is_approx_equals_vec(res, slr_res_v, "test_univariate_linear_regression()"); } void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) { diff --git a/test/mlpp_tests_old.cpp b/test/mlpp_tests_old.cpp index 9421a01..973b902 100644 --- a/test/mlpp_tests_old.cpp +++ b/test/mlpp_tests_old.cpp @@ -77,21 +77,6 @@ void MLPPTestsOld::test_linear_algebra() { } void MLPPTestsOld::test_univariate_linear_regression() { - // Univariate, simple linear regression, case where k = 1 - MLPPData data; - - Ref ds = data.load_fires_and_crime(_fires_and_crime_data_path); - - MLPPUniLinRegOld model_old(ds->get_input()->to_std_vector(), ds->get_output()->to_std_vector()); - - std::vector slr_res = { - 24.1095, 28.4829, 29.8082, 26.0974, 27.2902, 61.0851, 30.4709, 25.0372, 25.5673, 35.9046, - 54.4587, 18.8083, 23.4468, 18.5432, 19.2059, 21.1938, 23.0492, 18.8083, 25.4348, 35.9046, - 37.76, 40.278, 63.8683, 68.5068, 40.4106, 46.772, 32.0612, 23.3143, 44.784, 44.519, - 27.8203, 20.6637, 22.5191, 53.796, 38.9527, 30.8685, 20.3986 - }; - - is_approx_equals_dvec(dstd_vec_to_vec_old(model_old.modelSetTest(ds->get_input()->to_std_vector())), dstd_vec_to_vec_old(slr_res), "stat.mode(x)"); } void MLPPTestsOld::test_multivariate_linear_regression_gradient_descent(bool ui) {