Improvements to MLPPTests::test_multivariate_linear_regression_gradient_descent() and to MLPPTests::test_multivariate_linear_regression_sgd().

This commit is contained in:
Relintai 2023-12-27 12:47:04 +01:00
parent af3895545d
commit f1c7d4a22e
1 changed files with 23 additions and 3 deletions

View File

@ -222,8 +222,18 @@ void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) {
Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
MLPPLinReg model(ds->get_input(), ds->get_output()); // Can use Lasso, Ridge, ElasticNet Reg
model.gradient_descent(0.001, 30, ui);
PLOG_MSG(model.model_set_test(ds->get_input())->to_string());
model.gradient_descent(0.0000001, 30, ui);
Ref<MLPPVector> res = model.model_set_test(ds->get_input());
MLPPCost mlpp_cost;
int rmse = (int)mlpp_cost.rmsev(ds->get_output(), res);
//Lose the bottom 14 bits (This should allow for 16384 difference.)
rmse = rmse >> 14;
rmse = rmse << 14;
is_approx_equalsd(rmse, 163840, "test_multivariate_linear_regression_gradient_descent() RMSE");
}
void MLPPTests::test_multivariate_linear_regression_sgd(bool ui) {
@ -234,7 +244,17 @@ void MLPPTests::test_multivariate_linear_regression_sgd(bool ui) {
MLPPLinReg model(ds->get_input(), ds->get_output()); // Can use Lasso, Ridge, ElasticNet Reg
model.sgd(0.00000001, 300000, ui);
PLOG_MSG(model.model_set_test(ds->get_input())->to_string());
Ref<MLPPVector> res = model.model_set_test(ds->get_input());
MLPPCost mlpp_cost;
int rmse = (int)mlpp_cost.rmsev(ds->get_output(), res);
//Lose the bottom 15 bits (This should allow for 2^15 difference.)
rmse = rmse >> 15;
rmse = rmse << 15;
is_approx_equalsd(rmse, 98304, "test_multivariate_linear_regression_sgd() RMSE");
}
void MLPPTests::test_multivariate_linear_regression_mbgd(bool ui) {