mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-01-02 16:29:35 +01:00
Reworked MLPPTests::test_univariate_linear_regression().
This commit is contained in:
parent
5d1c4e1d23
commit
cd35ebfd79
@ -195,32 +195,24 @@ void MLPPTests::test_linear_algebra() {
|
||||
}
|
||||
|
||||
void MLPPTests::test_univariate_linear_regression() {
|
||||
// Univariate, simple linear regression, case where k = 1
|
||||
MLPPData data;
|
||||
|
||||
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_fires_and_crime_data_path);
|
||||
|
||||
MLPPUniLinReg model(ds->get_input(), ds->get_output());
|
||||
|
||||
std::vector<real_t> slr_res_n = {
|
||||
24.109467, 28.482935, 29.808228, 26.097408, 27.290173, 61.085152, 30.470875, 25.037172, 25.567291,
|
||||
35.904579, 54.458687, 18.808294, 23.446819, 18.543236, 19.205883, 21.193821, 23.049232, 18.808294,
|
||||
25.434761, 35.904579, 37.759987, 40.278046, 63.868271, 68.50679, 40.410576, 46.77198, 32.061226,
|
||||
23.314291, 44.784042, 44.518982, 27.82029, 20.663704, 22.519115, 53.796036, 38.952751,
|
||||
30.868464, 20.398645
|
||||
const real_t slr_res_n_arr[] = {
|
||||
24.109467, 28.482935, 29.808228, 26.097408, 27.290173, 61.085152, 30.470875, 25.037172, 25.567291, //
|
||||
35.904579, 54.458687, 18.808294, 23.446819, 18.543236, 19.205883, 21.193821, 23.049232, 18.808294, //
|
||||
25.434761, 35.904579, 37.759987, 40.278046, 63.868271, 68.50679, 40.410576, 46.77198, 32.061226, //
|
||||
23.314291, 44.784042, 44.518982, 27.82029, 20.663704, 22.519115, 53.796036, 38.952751, //
|
||||
30.868464, 20.398645 //
|
||||
};
|
||||
|
||||
Ref<MLPPVector> slr_res_v;
|
||||
slr_res_v.instance();
|
||||
slr_res_v->set_from_std_vector(slr_res_n);
|
||||
Ref<MLPPVector> slr_res_v(memnew(MLPPVector(slr_res_n_arr, 37)));
|
||||
|
||||
// Univariate, simple linear regression, case where k = 1
|
||||
MLPPData data;
|
||||
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_fires_and_crime_data_path);
|
||||
MLPPUniLinReg model(ds->get_input(), ds->get_output());
|
||||
|
||||
Ref<MLPPVector> res = model.model_set_test(ds->get_input());
|
||||
|
||||
if (!slr_res_v->is_equal_approx(res)) {
|
||||
ERR_PRINT("!slr_res_v->is_equal_approx(res)");
|
||||
ERR_PRINT(res->to_string());
|
||||
ERR_PRINT(slr_res_v->to_string());
|
||||
}
|
||||
is_approx_equals_vec(res, slr_res_v, "test_univariate_linear_regression()");
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) {
|
||||
|
@ -77,21 +77,6 @@ void MLPPTestsOld::test_linear_algebra() {
|
||||
}
|
||||
|
||||
void MLPPTestsOld::test_univariate_linear_regression() {
|
||||
// Univariate, simple linear regression, case where k = 1
|
||||
MLPPData data;
|
||||
|
||||
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_fires_and_crime_data_path);
|
||||
|
||||
MLPPUniLinRegOld model_old(ds->get_input()->to_std_vector(), ds->get_output()->to_std_vector());
|
||||
|
||||
std::vector<real_t> slr_res = {
|
||||
24.1095, 28.4829, 29.8082, 26.0974, 27.2902, 61.0851, 30.4709, 25.0372, 25.5673, 35.9046,
|
||||
54.4587, 18.8083, 23.4468, 18.5432, 19.2059, 21.1938, 23.0492, 18.8083, 25.4348, 35.9046,
|
||||
37.76, 40.278, 63.8683, 68.5068, 40.4106, 46.772, 32.0612, 23.3143, 44.784, 44.519,
|
||||
27.8203, 20.6637, 22.5191, 53.796, 38.9527, 30.8685, 20.3986
|
||||
};
|
||||
|
||||
is_approx_equals_dvec(dstd_vec_to_vec_old(model_old.modelSetTest(ds->get_input()->to_std_vector())), dstd_vec_to_vec_old(slr_res), "stat.mode(x)");
|
||||
}
|
||||
|
||||
void MLPPTestsOld::test_multivariate_linear_regression_gradient_descent(bool ui) {
|
||||
|
Loading…
Reference in New Issue
Block a user