mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-01-08 17:29:36 +01:00
Port multivariate linear regression tests.
This commit is contained in:
parent
c79bd2050d
commit
203932973b
@ -184,6 +184,119 @@ void MLPPTests::test_univariate_linear_regression() {
|
||||
is_approx_equals_dvec(dstd_vec_to_vec(model.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)");
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
|
||||
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
|
||||
|
||||
model.gradientDescent(0.001, 30, ui);
|
||||
alg.printVector(model.modelSetTest(ds->input));
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_sgd(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
|
||||
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
|
||||
|
||||
model.SGD(0.00000001, 300000, ui);
|
||||
alg.printVector(model.modelSetTest(ds->input));
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_mbgd(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
|
||||
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
|
||||
|
||||
model.MBGD(0.001, 10000, 2, ui);
|
||||
alg.printVector(model.modelSetTest(ds->input));
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_normal_equation(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
|
||||
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
|
||||
|
||||
model.normalEquation();
|
||||
alg.printVector(model.modelSetTest(ds->input));
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_adam() {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
|
||||
MLPPLinReg adamModel(alg.transpose(ds->input), ds->output);
|
||||
alg.printVector(adamModel.modelSetTest(ds->input));
|
||||
std::cout << "ACCURACY: " << 100 * adamModel.score() << "%" << std::endl;
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_score_sgd_adam(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
|
||||
const int TRIAL_NUM = 1000;
|
||||
|
||||
double scoreSGD = 0;
|
||||
double scoreADAM = 0;
|
||||
for (int i = 0; i < TRIAL_NUM; i++) {
|
||||
MLPPLinReg modelf(alg.transpose(ds->input), ds->output);
|
||||
modelf.MBGD(0.001, 5, 1, ui);
|
||||
scoreSGD += modelf.score();
|
||||
|
||||
MLPPLinReg adamModelf(alg.transpose(ds->input), ds->output);
|
||||
adamModelf.Adam(0.1, 5, 1, 0.9, 0.999, 1e-8, ui); // Change batch size = sgd, bgd
|
||||
scoreADAM += adamModelf.score();
|
||||
}
|
||||
|
||||
std::cout << "ACCURACY, AVG, SGD: " << 100 * scoreSGD / TRIAL_NUM << "%" << std::endl;
|
||||
std::cout << std::endl;
|
||||
std::cout << "ACCURACY, AVG, ADAM: " << 100 * scoreADAM / TRIAL_NUM << "%" << std::endl;
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_epochs_gradient_descent(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
|
||||
std::cout << "Total epoch num: 300" << std::endl;
|
||||
std::cout << "Method: 1st Order w/ Jacobians" << std::endl;
|
||||
|
||||
MLPPLinReg model3(alg.transpose(ds->input), ds->output); // Can use Lasso, Ridge, ElasticNet Reg
|
||||
model3.gradientDescent(0.001, 300, ui);
|
||||
alg.printVector(model3.modelSetTest(ds->input));
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_newton_raphson(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
|
||||
std::cout << "--------------------------------------------" << std::endl;
|
||||
std::cout << "Total epoch num: 300" << std::endl;
|
||||
std::cout << "Method: Newtonian 2nd Order w/ Hessians" << std::endl;
|
||||
MLPPLinReg model2(alg.transpose(ds->input), ds->output);
|
||||
|
||||
model2.NewtonRaphson(1.5, 300, ui);
|
||||
alg.printVector(model2.modelSetTest(ds->input));
|
||||
}
|
||||
|
||||
void MLPPTests::is_approx_equalsd(double a, double b, const String &str) {
|
||||
if (!Math::is_equal_approx(a, b)) {
|
||||
ERR_PRINT("TEST FAILED: " + str + " Got: " + String::num(a) + " Should be: " + String::num(b));
|
||||
@ -285,6 +398,7 @@ IAEDMAT_FAILED:
|
||||
|
||||
MLPPTests::MLPPTests() {
|
||||
_load_fires_and_crime_data_path = "res://datasets/FiresAndCrime.csv";
|
||||
_load_california_housing_data_path = "res://datasets/CaliforniaHousing.csv";
|
||||
}
|
||||
|
||||
MLPPTests::~MLPPTests() {
|
||||
@ -294,4 +408,13 @@ void MLPPTests::_bind_methods() {
|
||||
ClassDB::bind_method(D_METHOD("test_statistics"), &MLPPTests::test_statistics);
|
||||
ClassDB::bind_method(D_METHOD("test_linear_algebra"), &MLPPTests::test_linear_algebra);
|
||||
ClassDB::bind_method(D_METHOD("test_univariate_linear_regression"), &MLPPTests::test_univariate_linear_regression);
|
||||
|
||||
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_gradient_descent", "ui"), &MLPPTests::test_multivariate_linear_regression_gradient_descent, false);
|
||||
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_sgd", "ui"), &MLPPTests::test_multivariate_linear_regression_sgd, false);
|
||||
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_mbgd", "ui"), &MLPPTests::test_multivariate_linear_regression_mbgd, false);
|
||||
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_normal_equation", "ui"), &MLPPTests::test_multivariate_linear_regression_normal_equation, false);
|
||||
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_adam"), &MLPPTests::test_multivariate_linear_regression_adam);
|
||||
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_score_sgd_adam", "ui"), &MLPPTests::test_multivariate_linear_regression_score_sgd_adam, false);
|
||||
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_epochs_gradient_descent", "ui"), &MLPPTests::test_multivariate_linear_regression_epochs_gradient_descent, false);
|
||||
ClassDB::bind_method(D_METHOD("test_multivariate_linear_regression_newton_raphson", "ui"), &MLPPTests::test_multivariate_linear_regression_newton_raphson, false);
|
||||
}
|
||||
|
@ -18,6 +18,15 @@ public:
|
||||
void test_linear_algebra();
|
||||
void test_univariate_linear_regression();
|
||||
|
||||
void test_multivariate_linear_regression_gradient_descent(bool ui = false);
|
||||
void test_multivariate_linear_regression_sgd(bool ui = false);
|
||||
void test_multivariate_linear_regression_mbgd(bool ui = false);
|
||||
void test_multivariate_linear_regression_normal_equation(bool ui = false);
|
||||
void test_multivariate_linear_regression_adam();
|
||||
void test_multivariate_linear_regression_score_sgd_adam(bool ui = false);
|
||||
void test_multivariate_linear_regression_epochs_gradient_descent(bool ui = false);
|
||||
void test_multivariate_linear_regression_newton_raphson(bool ui = false);
|
||||
|
||||
void is_approx_equalsd(double a, double b, const String &str);
|
||||
void is_approx_equals_dvec(const Vector<double> &a, const Vector<double> &b, const String &str);
|
||||
void is_approx_equals_dmat(const Vector<Vector<double>> &a, const Vector<Vector<double>> &b, const String &str);
|
||||
@ -29,6 +38,7 @@ protected:
|
||||
static void _bind_methods();
|
||||
|
||||
String _load_fires_and_crime_data_path;
|
||||
String _load_california_housing_data_path;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user