mirror of
https://github.com/Relintai/pmlpp.git
synced 2024-11-08 13:12:09 +01:00
Improvements to MLPPTests::test_multivariate_linear_regression_gradient_descent() and to MLPPTests::test_multivariate_linear_regression_sgd().
This commit is contained in:
parent
af3895545d
commit
f1c7d4a22e
@ -222,8 +222,18 @@ void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) {
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
|
||||
|
||||
MLPPLinReg model(ds->get_input(), ds->get_output()); // Can use Lasso, Ridge, ElasticNet Reg
|
||||
model.gradient_descent(0.001, 30, ui);
|
||||
PLOG_MSG(model.model_set_test(ds->get_input())->to_string());
|
||||
model.gradient_descent(0.0000001, 30, ui);
|
||||
Ref<MLPPVector> res = model.model_set_test(ds->get_input());
|
||||
|
||||
MLPPCost mlpp_cost;
|
||||
|
||||
int rmse = (int)mlpp_cost.rmsev(ds->get_output(), res);
|
||||
|
||||
//Lose the bottom 14 bits (This should allow for 16384 difference.)
|
||||
rmse = rmse >> 14;
|
||||
rmse = rmse << 14;
|
||||
|
||||
is_approx_equalsd(rmse, 163840, "test_multivariate_linear_regression_gradient_descent() RMSE");
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_sgd(bool ui) {
|
||||
@ -234,7 +244,17 @@ void MLPPTests::test_multivariate_linear_regression_sgd(bool ui) {
|
||||
|
||||
MLPPLinReg model(ds->get_input(), ds->get_output()); // Can use Lasso, Ridge, ElasticNet Reg
|
||||
model.sgd(0.00000001, 300000, ui);
|
||||
PLOG_MSG(model.model_set_test(ds->get_input())->to_string());
|
||||
Ref<MLPPVector> res = model.model_set_test(ds->get_input());
|
||||
|
||||
MLPPCost mlpp_cost;
|
||||
|
||||
int rmse = (int)mlpp_cost.rmsev(ds->get_output(), res);
|
||||
|
||||
//Lose the bottom 15 bits (This should allow for 2^15 difference.)
|
||||
rmse = rmse >> 15;
|
||||
rmse = rmse << 15;
|
||||
|
||||
is_approx_equalsd(rmse, 98304, "test_multivariate_linear_regression_sgd() RMSE");
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_mbgd(bool ui) {
|
||||
|
Loading…
Reference in New Issue
Block a user