diff --git a/test/mlpp_tests.cpp b/test/mlpp_tests.cpp index fed2ba3..888d892 100644 --- a/test/mlpp_tests.cpp +++ b/test/mlpp_tests.cpp @@ -184,6 +184,119 @@ void MLPPTests::test_univariate_linear_regression() { is_approx_equals_dvec(dstd_vec_to_vec(model.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)"); } +void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) { + MLPPData data; + MLPPLinAlg alg; + + Ref ds = data.load_california_housing(_load_california_housing_data_path); + + MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg + + model.gradientDescent(0.001, 30, ui); + alg.printVector(model.modelSetTest(ds->input)); +} + +void MLPPTests::test_multivariate_linear_regression_sgd(bool ui) { + MLPPData data; + MLPPLinAlg alg; + + Ref ds = data.load_california_housing(_load_california_housing_data_path); + + MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg + + model.SGD(0.00000001, 300000, ui); + alg.printVector(model.modelSetTest(ds->input)); +} + +void MLPPTests::test_multivariate_linear_regression_mbgd(bool ui) { + MLPPData data; + MLPPLinAlg alg; + + Ref ds = data.load_california_housing(_load_california_housing_data_path); + + MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg + + model.MBGD(0.001, 10000, 2, ui); + alg.printVector(model.modelSetTest(ds->input)); +} + +void MLPPTests::test_multivariate_linear_regression_normal_equation(bool ui) { + MLPPData data; + MLPPLinAlg alg; + + Ref ds = data.load_california_housing(_load_california_housing_data_path); + + MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg + + model.normalEquation(); + alg.printVector(model.modelSetTest(ds->input)); +} + +void MLPPTests::test_multivariate_linear_regression_adam() { + MLPPData data; + MLPPLinAlg alg; + + Ref ds = data.load_california_housing(_load_california_housing_data_path); + + MLPPLinReg adamModel(alg.transpose(ds->input), ds->output); + alg.printVector(adamModel.modelSetTest(ds->input)); + std::cout << "ACCURACY: " << 100 * adamModel.score() << "%" << std::endl; +} + +void MLPPTests::test_multivariate_linear_regression_score_sgd_adam(bool ui) { + MLPPData data; + MLPPLinAlg alg; + + Ref ds = data.load_california_housing(_load_california_housing_data_path); + + const int TRIAL_NUM = 1000; + + double scoreSGD = 0; + double scoreADAM = 0; + for (int i = 0; i < TRIAL_NUM; i++) { + MLPPLinReg modelf(alg.transpose(ds->input), ds->output); + modelf.MBGD(0.001, 5, 1, ui); + scoreSGD += modelf.score(); + + MLPPLinReg adamModelf(alg.transpose(ds->input), ds->output); + adamModelf.Adam(0.1, 5, 1, 0.9, 0.999, 1e-8, ui); // Change batch size = sgd, bgd + scoreADAM += adamModelf.score(); + } + + std::cout << "ACCURACY, AVG, SGD: " << 100 * scoreSGD / TRIAL_NUM << "%" << std::endl; + std::cout << std::endl; + std::cout << "ACCURACY, AVG, ADAM: " << 100 * scoreADAM / TRIAL_NUM << "%" << std::endl; +} + +void MLPPTests::test_multivariate_linear_regression_epochs_gradient_descent(bool ui) { + MLPPData data; + MLPPLinAlg alg; + + Ref ds = data.load_california_housing(_load_california_housing_data_path); + + std::cout << "Total epoch num: 300" << std::endl; + std::cout << "Method: 1st Order w/ Jacobians" << std::endl; + + MLPPLinReg model3(alg.transpose(ds->input), ds->output); // Can use Lasso, Ridge, ElasticNet Reg + model3.gradientDescent(0.001, 300, ui); + alg.printVector(model3.modelSetTest(ds->input)); +} + +void MLPPTests::test_multivariate_linear_regression_newton_raphson(bool ui) { + MLPPData data; + MLPPLinAlg alg; + + Ref ds = data.load_california_housing(_load_california_housing_data_path); + + std::cout << "--------------------------------------------" << std::endl; + std::cout << "Total epoch num: 300" << std::endl; + std::cout << "Method: Newtonian 2nd Order w/ Hessians" << std::endl; + MLPPLinReg model2(alg.transpose(ds->input), ds->output); + + model2.NewtonRaphson(1.5, 300, ui); + alg.printVector(model2.modelSetTest(ds->input)); +} + void MLPPTests::is_approx_equalsd(double a, double b, const String &str) { if (!Math::is_equal_approx(a, b)) { ERR_PRINT("TEST FAILED: " + str + " Got: " + String::num(a) + " Should be: " + String::num(b)); @@ -285,6 +398,7 @@ IAEDMAT_FAILED: MLPPTests::MLPPTests() { _load_fires_and_crime_data_path = "res://datasets/FiresAndCrime.csv"; + _load_california_housing_data_path = "res://datasets/CaliforniaHousing.csv"; } MLPPTests::~MLPPTests() { @@ -294,4 +408,13 @@ void MLPPTests::_bind_methods() { ClassDB::bind_method(D_METHOD("test_statistics"), &MLPPTests::test_statistics); ClassDB::bind_method(D_METHOD("test_linear_algebra"), &MLPPTests::test_linear_algebra); ClassDB::bind_method(D_METHOD("test_univariate_linear_regression"), &MLPPTests::test_univariate_linear_regression); + + ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_gradient_descent", "ui"), &MLPPTests::test_multivariate_linear_regression_gradient_descent, false); + ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_sgd", "ui"), &MLPPTests::test_multivariate_linear_regression_sgd, false); + ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_mbgd", "ui"), &MLPPTests::test_multivariate_linear_regression_mbgd, false); + ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_normal_equation", "ui"), &MLPPTests::test_multivariate_linear_regression_normal_equation, false); + ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_adam"), &MLPPTests::test_multivariate_linear_regression_adam); + ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_score_sgd_adam", "ui"), &MLPPTests::test_multivariate_linear_regression_score_sgd_adam, false); + ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_epochs_gradient_descent", "ui"), &MLPPTests::test_multivariate_linear_regression_epochs_gradient_descent, false); + ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_newton_raphson", "ui"), &MLPPTests::test_multivariate_linear_regression_newton_raphson, false); } diff --git a/test/mlpp_tests.h b/test/mlpp_tests.h index a999013..ba9f44a 100644 --- a/test/mlpp_tests.h +++ b/test/mlpp_tests.h @@ -18,6 +18,15 @@ public: void test_linear_algebra(); void test_univariate_linear_regression(); + void test_multivariate_linear_regression_gradient_descent(bool ui = false); + void test_multivariate_linear_regression_sgd(bool ui = false); + void test_multivariate_linear_regression_mbgd(bool ui = false); + void test_multivariate_linear_regression_normal_equation(bool ui = false); + void test_multivariate_linear_regression_adam(); + void test_multivariate_linear_regression_score_sgd_adam(bool ui = false); + void test_multivariate_linear_regression_epochs_gradient_descent(bool ui = false); + void test_multivariate_linear_regression_newton_raphson(bool ui = false); + void is_approx_equalsd(double a, double b, const String &str); void is_approx_equals_dvec(const Vector &a, const Vector &b, const String &str); void is_approx_equals_dmat(const Vector> &a, const Vector> &b, const String &str); @@ -29,6 +38,7 @@ protected: static void _bind_methods(); String _load_fires_and_crime_data_path; + String _load_california_housing_data_path; }; #endif