Port multivariate linear regression tests.

This commit is contained in:
Relintai 2023-01-25 23:43:45 +01:00
parent c79bd2050d
commit 203932973b
2 changed files with 133 additions and 0 deletions

View File

@ -184,6 +184,119 @@ void MLPPTests::test_univariate_linear_regression() {
is_approx_equals_dvec(dstd_vec_to_vec(model.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)"); is_approx_equals_dvec(dstd_vec_to_vec(model.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)");
} }
void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) {
MLPPData data;
MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
model.gradientDescent(0.001, 30, ui);
alg.printVector(model.modelSetTest(ds->input));
}
void MLPPTests::test_multivariate_linear_regression_sgd(bool ui) {
MLPPData data;
MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
model.SGD(0.00000001, 300000, ui);
alg.printVector(model.modelSetTest(ds->input));
}
void MLPPTests::test_multivariate_linear_regression_mbgd(bool ui) {
MLPPData data;
MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
model.MBGD(0.001, 10000, 2, ui);
alg.printVector(model.modelSetTest(ds->input));
}
void MLPPTests::test_multivariate_linear_regression_normal_equation(bool ui) {
MLPPData data;
MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
model.normalEquation();
alg.printVector(model.modelSetTest(ds->input));
}
void MLPPTests::test_multivariate_linear_regression_adam() {
MLPPData data;
MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
MLPPLinReg adamModel(alg.transpose(ds->input), ds->output);
alg.printVector(adamModel.modelSetTest(ds->input));
std::cout << "ACCURACY: " << 100 * adamModel.score() << "%" << std::endl;
}
void MLPPTests::test_multivariate_linear_regression_score_sgd_adam(bool ui) {
MLPPData data;
MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
const int TRIAL_NUM = 1000;
double scoreSGD = 0;
double scoreADAM = 0;
for (int i = 0; i < TRIAL_NUM; i++) {
MLPPLinReg modelf(alg.transpose(ds->input), ds->output);
modelf.MBGD(0.001, 5, 1, ui);
scoreSGD += modelf.score();
MLPPLinReg adamModelf(alg.transpose(ds->input), ds->output);
adamModelf.Adam(0.1, 5, 1, 0.9, 0.999, 1e-8, ui); // Change batch size = sgd, bgd
scoreADAM += adamModelf.score();
}
std::cout << "ACCURACY, AVG, SGD: " << 100 * scoreSGD / TRIAL_NUM << "%" << std::endl;
std::cout << std::endl;
std::cout << "ACCURACY, AVG, ADAM: " << 100 * scoreADAM / TRIAL_NUM << "%" << std::endl;
}
void MLPPTests::test_multivariate_linear_regression_epochs_gradient_descent(bool ui) {
MLPPData data;
MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
std::cout << "Total epoch num: 300" << std::endl;
std::cout << "Method: 1st Order w/ Jacobians" << std::endl;
MLPPLinReg model3(alg.transpose(ds->input), ds->output); // Can use Lasso, Ridge, ElasticNet Reg
model3.gradientDescent(0.001, 300, ui);
alg.printVector(model3.modelSetTest(ds->input));
}
void MLPPTests::test_multivariate_linear_regression_newton_raphson(bool ui) {
MLPPData data;
MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
std::cout << "--------------------------------------------" << std::endl;
std::cout << "Total epoch num: 300" << std::endl;
std::cout << "Method: Newtonian 2nd Order w/ Hessians" << std::endl;
MLPPLinReg model2(alg.transpose(ds->input), ds->output);
model2.NewtonRaphson(1.5, 300, ui);
alg.printVector(model2.modelSetTest(ds->input));
}
void MLPPTests::is_approx_equalsd(double a, double b, const String &str) { void MLPPTests::is_approx_equalsd(double a, double b, const String &str) {
if (!Math::is_equal_approx(a, b)) { if (!Math::is_equal_approx(a, b)) {
ERR_PRINT("TEST FAILED: " + str + " Got: " + String::num(a) + " Should be: " + String::num(b)); ERR_PRINT("TEST FAILED: " + str + " Got: " + String::num(a) + " Should be: " + String::num(b));
@ -285,6 +398,7 @@ IAEDMAT_FAILED:
MLPPTests::MLPPTests() { MLPPTests::MLPPTests() {
_load_fires_and_crime_data_path = "res://datasets/FiresAndCrime.csv"; _load_fires_and_crime_data_path = "res://datasets/FiresAndCrime.csv";
_load_california_housing_data_path = "res://datasets/CaliforniaHousing.csv";
} }
MLPPTests::~MLPPTests() { MLPPTests::~MLPPTests() {
@ -294,4 +408,13 @@ void MLPPTests::_bind_methods() {
ClassDB::bind_method(D_METHOD("test_statistics"), &MLPPTests::test_statistics); ClassDB::bind_method(D_METHOD("test_statistics"), &MLPPTests::test_statistics);
ClassDB::bind_method(D_METHOD("test_linear_algebra"), &MLPPTests::test_linear_algebra); ClassDB::bind_method(D_METHOD("test_linear_algebra"), &MLPPTests::test_linear_algebra);
ClassDB::bind_method(D_METHOD("test_univariate_linear_regression"), &MLPPTests::test_univariate_linear_regression); ClassDB::bind_method(D_METHOD("test_univariate_linear_regression"), &MLPPTests::test_univariate_linear_regression);
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_gradient_descent", "ui"), &MLPPTests::test_multivariate_linear_regression_gradient_descent, false);
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_sgd", "ui"), &MLPPTests::test_multivariate_linear_regression_sgd, false);
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_mbgd", "ui"), &MLPPTests::test_multivariate_linear_regression_mbgd, false);
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_normal_equation", "ui"), &MLPPTests::test_multivariate_linear_regression_normal_equation, false);
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_adam"), &MLPPTests::test_multivariate_linear_regression_adam);
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_score_sgd_adam", "ui"), &MLPPTests::test_multivariate_linear_regression_score_sgd_adam, false);
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_epochs_gradient_descent", "ui"), &MLPPTests::test_multivariate_linear_regression_epochs_gradient_descent, false);
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_newton_raphson", "ui"), &MLPPTests::test_multivariate_linear_regression_newton_raphson, false);
} }

View File

@ -18,6 +18,15 @@ public:
void test_linear_algebra(); void test_linear_algebra();
void test_univariate_linear_regression(); void test_univariate_linear_regression();
void test_multivariate_linear_regression_gradient_descent(bool ui = false);
void test_multivariate_linear_regression_sgd(bool ui = false);
void test_multivariate_linear_regression_mbgd(bool ui = false);
void test_multivariate_linear_regression_normal_equation(bool ui = false);
void test_multivariate_linear_regression_adam();
void test_multivariate_linear_regression_score_sgd_adam(bool ui = false);
void test_multivariate_linear_regression_epochs_gradient_descent(bool ui = false);
void test_multivariate_linear_regression_newton_raphson(bool ui = false);
void is_approx_equalsd(double a, double b, const String &str); void is_approx_equalsd(double a, double b, const String &str);
void is_approx_equals_dvec(const Vector<double> &a, const Vector<double> &b, const String &str); void is_approx_equals_dvec(const Vector<double> &a, const Vector<double> &b, const String &str);
void is_approx_equals_dmat(const Vector<Vector<double>> &a, const Vector<Vector<double>> &b, const String &str); void is_approx_equals_dmat(const Vector<Vector<double>> &a, const Vector<Vector<double>> &b, const String &str);
@ -29,6 +38,7 @@ protected:
static void _bind_methods(); static void _bind_methods();
String _load_fires_and_crime_data_path; String _load_fires_and_crime_data_path;
String _load_california_housing_data_path;
}; };
#endif #endif