From f1c7d4a22e8029301cea5855a16c8a9c6f4300f5 Mon Sep 17 00:00:00 2001 From: Relintai Date: Wed, 27 Dec 2023 12:47:04 +0100 Subject: [PATCH] Improvements to MLPPTests::test_multivariate_linear_regression_gradient_descent() and to MLPPTests::test_multivariate_linear_regression_sgd(). --- test/mlpp_tests.cpp | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/test/mlpp_tests.cpp b/test/mlpp_tests.cpp index 81d7ba4..feac2f0 100644 --- a/test/mlpp_tests.cpp +++ b/test/mlpp_tests.cpp @@ -222,8 +222,18 @@ void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) { Ref ds = data.load_california_housing(_california_housing_data_path); MLPPLinReg model(ds->get_input(), ds->get_output()); // Can use Lasso, Ridge, ElasticNet Reg - model.gradient_descent(0.001, 30, ui); - PLOG_MSG(model.model_set_test(ds->get_input())->to_string()); + model.gradient_descent(0.0000001, 30, ui); + Ref res = model.model_set_test(ds->get_input()); + + MLPPCost mlpp_cost; + + int rmse = (int)mlpp_cost.rmsev(ds->get_output(), res); + + //Lose the bottom 14 bits (This should allow for 16384 difference.) + rmse = rmse >> 14; + rmse = rmse << 14; + + is_approx_equalsd(rmse, 163840, "test_multivariate_linear_regression_gradient_descent() RMSE"); } void MLPPTests::test_multivariate_linear_regression_sgd(bool ui) { @@ -234,7 +244,17 @@ void MLPPTests::test_multivariate_linear_regression_sgd(bool ui) { MLPPLinReg model(ds->get_input(), ds->get_output()); // Can use Lasso, Ridge, ElasticNet Reg model.sgd(0.00000001, 300000, ui); - PLOG_MSG(model.model_set_test(ds->get_input())->to_string()); + Ref res = model.model_set_test(ds->get_input()); + + MLPPCost mlpp_cost; + + int rmse = (int)mlpp_cost.rmsev(ds->get_output(), res); + + //Lose the bottom 15 bits (This should allow for 2^15 difference.) + rmse = rmse >> 15; + rmse = rmse << 15; + + is_approx_equalsd(rmse, 98304, "test_multivariate_linear_regression_sgd() RMSE"); } void MLPPTests::test_multivariate_linear_regression_mbgd(bool ui) {