Reworked MLPPTests::test_univariate_linear_regression().

This commit is contained in:
Relintai 2023-12-27 11:12:37 +01:00
parent 5d1c4e1d23
commit cd35ebfd79
2 changed files with 13 additions and 36 deletions

View File

@ -195,32 +195,24 @@ void MLPPTests::test_linear_algebra() {
} }
void MLPPTests::test_univariate_linear_regression() { void MLPPTests::test_univariate_linear_regression() {
// Univariate, simple linear regression, case where k = 1 const real_t slr_res_n_arr[] = {
MLPPData data; 24.109467, 28.482935, 29.808228, 26.097408, 27.290173, 61.085152, 30.470875, 25.037172, 25.567291, //
35.904579, 54.458687, 18.808294, 23.446819, 18.543236, 19.205883, 21.193821, 23.049232, 18.808294, //
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_fires_and_crime_data_path); 25.434761, 35.904579, 37.759987, 40.278046, 63.868271, 68.50679, 40.410576, 46.77198, 32.061226, //
23.314291, 44.784042, 44.518982, 27.82029, 20.663704, 22.519115, 53.796036, 38.952751, //
MLPPUniLinReg model(ds->get_input(), ds->get_output()); 30.868464, 20.398645 //
std::vector<real_t> slr_res_n = {
24.109467, 28.482935, 29.808228, 26.097408, 27.290173, 61.085152, 30.470875, 25.037172, 25.567291,
35.904579, 54.458687, 18.808294, 23.446819, 18.543236, 19.205883, 21.193821, 23.049232, 18.808294,
25.434761, 35.904579, 37.759987, 40.278046, 63.868271, 68.50679, 40.410576, 46.77198, 32.061226,
23.314291, 44.784042, 44.518982, 27.82029, 20.663704, 22.519115, 53.796036, 38.952751,
30.868464, 20.398645
}; };
Ref<MLPPVector> slr_res_v; Ref<MLPPVector> slr_res_v(memnew(MLPPVector(slr_res_n_arr, 37)));
slr_res_v.instance();
slr_res_v->set_from_std_vector(slr_res_n); // Univariate, simple linear regression, case where k = 1
MLPPData data;
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_fires_and_crime_data_path);
MLPPUniLinReg model(ds->get_input(), ds->get_output());
Ref<MLPPVector> res = model.model_set_test(ds->get_input()); Ref<MLPPVector> res = model.model_set_test(ds->get_input());
if (!slr_res_v->is_equal_approx(res)) { is_approx_equals_vec(res, slr_res_v, "test_univariate_linear_regression()");
ERR_PRINT("!slr_res_v->is_equal_approx(res)");
ERR_PRINT(res->to_string());
ERR_PRINT(slr_res_v->to_string());
}
} }
void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) { void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) {

View File

@ -77,21 +77,6 @@ void MLPPTestsOld::test_linear_algebra() {
} }
void MLPPTestsOld::test_univariate_linear_regression() { void MLPPTestsOld::test_univariate_linear_regression() {
// Univariate, simple linear regression, case where k = 1
MLPPData data;
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_fires_and_crime_data_path);
MLPPUniLinRegOld model_old(ds->get_input()->to_std_vector(), ds->get_output()->to_std_vector());
std::vector<real_t> slr_res = {
24.1095, 28.4829, 29.8082, 26.0974, 27.2902, 61.0851, 30.4709, 25.0372, 25.5673, 35.9046,
54.4587, 18.8083, 23.4468, 18.5432, 19.2059, 21.1938, 23.0492, 18.8083, 25.4348, 35.9046,
37.76, 40.278, 63.8683, 68.5068, 40.4106, 46.772, 32.0612, 23.3143, 44.784, 44.519,
27.8203, 20.6637, 22.5191, 53.796, 38.9527, 30.8685, 20.3986
};
is_approx_equals_dvec(dstd_vec_to_vec_old(model_old.modelSetTest(ds->get_input()->to_std_vector())), dstd_vec_to_vec_old(slr_res), "stat.mode(x)");
} }
void MLPPTestsOld::test_multivariate_linear_regression_gradient_descent(bool ui) { void MLPPTestsOld::test_multivariate_linear_regression_gradient_descent(bool ui) {