Add variables for all the used test datasets.

This commit is contained in:
Relintai 2023-01-26 00:53:08 +01:00
parent 48f7cbe454
commit 478859374a
2 changed files with 37 additions and 18 deletions

View File

@ -170,7 +170,7 @@ void MLPPTests::test_univariate_linear_regression() {
// Univariate, simple linear regression, case where k = 1 // Univariate, simple linear regression, case where k = 1
MLPPData data; MLPPData data;
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_load_fires_and_crime_data_path); Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_fires_and_crime_data_path);
MLPPUniLinReg model(ds->input, ds->output); MLPPUniLinReg model(ds->input, ds->output);
@ -188,7 +188,7 @@ void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) {
MLPPData data; MLPPData data;
MLPPLinAlg alg; MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path); Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
@ -200,7 +200,7 @@ void MLPPTests::test_multivariate_linear_regression_sgd(bool ui) {
MLPPData data; MLPPData data;
MLPPLinAlg alg; MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path); Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
@ -212,7 +212,7 @@ void MLPPTests::test_multivariate_linear_regression_mbgd(bool ui) {
MLPPData data; MLPPData data;
MLPPLinAlg alg; MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path); Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
@ -224,7 +224,7 @@ void MLPPTests::test_multivariate_linear_regression_normal_equation(bool ui) {
MLPPData data; MLPPData data;
MLPPLinAlg alg; MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path); Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
@ -236,7 +236,7 @@ void MLPPTests::test_multivariate_linear_regression_adam() {
MLPPData data; MLPPData data;
MLPPLinAlg alg; MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path); Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
MLPPLinReg adamModel(alg.transpose(ds->input), ds->output); MLPPLinReg adamModel(alg.transpose(ds->input), ds->output);
alg.printVector(adamModel.modelSetTest(ds->input)); alg.printVector(adamModel.modelSetTest(ds->input));
@ -247,7 +247,7 @@ void MLPPTests::test_multivariate_linear_regression_score_sgd_adam(bool ui) {
MLPPData data; MLPPData data;
MLPPLinAlg alg; MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path); Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
const int TRIAL_NUM = 1000; const int TRIAL_NUM = 1000;
@ -272,7 +272,7 @@ void MLPPTests::test_multivariate_linear_regression_epochs_gradient_descent(bool
MLPPData data; MLPPData data;
MLPPLinAlg alg; MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path); Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
std::cout << "Total epoch num: 300" << std::endl; std::cout << "Total epoch num: 300" << std::endl;
std::cout << "Method: 1st Order w/ Jacobians" << std::endl; std::cout << "Method: 1st Order w/ Jacobians" << std::endl;
@ -286,7 +286,7 @@ void MLPPTests::test_multivariate_linear_regression_newton_raphson(bool ui) {
MLPPData data; MLPPData data;
MLPPLinAlg alg; MLPPLinAlg alg;
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path); Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
std::cout << "--------------------------------------------" << std::endl; std::cout << "--------------------------------------------" << std::endl;
std::cout << "Total epoch num: 300" << std::endl; std::cout << "Total epoch num: 300" << std::endl;
@ -297,15 +297,22 @@ void MLPPTests::test_multivariate_linear_regression_newton_raphson(bool ui) {
alg.printVector(model2.modelSetTest(ds->input)); alg.printVector(model2.modelSetTest(ds->input));
} }
//MLPPStat stat;
//MLPPLinAlg alg;
//MLPPActivation avn;
//MLPPCost cost;
//MLPPData data;
//MLPPConvolutions conv;
void MLPPTests::test_logistic_regression(bool ui) { void MLPPTests::test_logistic_regression(bool ui) {
//MLPPStat stat; //MLPPStat stat;
// MLPPLinAlg alg; //MLPPLinAlg alg;
//MLPPActivation avn; //MLPPActivation avn;
// MLPPCost cost; //MLPPCost cost;
// MLPPData data; //MLPPData data;
// MLPPConvolutions conv; //MLPPConvolutions conv;
// // LOGISTIC REGRESSION // LOGISTIC REGRESSION
// auto [inputSet, outputSet] = data.load rastCancer(); // auto [inputSet, outputSet] = data.load rastCancer();
// LogReg model(inputSet, outputSet); // LogReg model(inputSet, outputSet);
// model.SGD(0.001, 100000, 0); // model.SGD(0.001, 100000, 0);
@ -883,8 +890,14 @@ IAEDMAT_FAILED:
} }
MLPPTests::MLPPTests() { MLPPTests::MLPPTests() {
_load_fires_and_crime_data_path = "res://datasets/FiresAndCrime.csv"; _breast_cancer_data_path = "res://datasets/BreastCancer.csv";
_load_california_housing_data_path = "res://datasets/CaliforniaHousing.csv"; _breast_cancer_svm_data_path = "res://datasets/BreastCancerSVM.csv";
_california_housing_data_path = "res://datasets/CaliforniaHousing.csv";
_fires_and_crime_data_path = "res://datasets/FiresAndCrime.csv";
_iris_data_path = "res://datasets/Iris.csv";
_mnist_test_data_path = "res://datasets/MnistTest.csv";
_mnist_train_data_path = "res://datasets/MnistTrain.csv";
_wine_data_path = "res://datasets/Wine.csv";
} }
MLPPTests::~MLPPTests() { MLPPTests::~MLPPTests() {

View File

@ -68,8 +68,14 @@ public:
protected: protected:
static void _bind_methods(); static void _bind_methods();
String _load_fires_and_crime_data_path; String _breast_cancer_data_path;
String _load_california_housing_data_path; String _breast_cancer_svm_data_path;
String _california_housing_data_path;
String _fires_and_crime_data_path;
String _iris_data_path;
String _mnist_test_data_path;
String _mnist_train_data_path;
String _wine_data_path;
}; };
#endif #endif