mirror of
https://github.com/Relintai/pmlpp.git
synced 2024-11-09 13:22:09 +01:00
Add variables for all the used test datasets.
This commit is contained in:
parent
48f7cbe454
commit
478859374a
@ -170,7 +170,7 @@ void MLPPTests::test_univariate_linear_regression() {
|
||||
// Univariate, simple linear regression, case where k = 1
|
||||
MLPPData data;
|
||||
|
||||
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_load_fires_and_crime_data_path);
|
||||
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_fires_and_crime_data_path);
|
||||
|
||||
MLPPUniLinReg model(ds->input, ds->output);
|
||||
|
||||
@ -188,7 +188,7 @@ void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
|
||||
|
||||
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
|
||||
|
||||
@ -200,7 +200,7 @@ void MLPPTests::test_multivariate_linear_regression_sgd(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
|
||||
|
||||
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
|
||||
|
||||
@ -212,7 +212,7 @@ void MLPPTests::test_multivariate_linear_regression_mbgd(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
|
||||
|
||||
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
|
||||
|
||||
@ -224,7 +224,7 @@ void MLPPTests::test_multivariate_linear_regression_normal_equation(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
|
||||
|
||||
MLPPLinReg model(ds->input, ds->output); // Can use Lasso, Ridge, ElasticNet Reg
|
||||
|
||||
@ -236,7 +236,7 @@ void MLPPTests::test_multivariate_linear_regression_adam() {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
|
||||
|
||||
MLPPLinReg adamModel(alg.transpose(ds->input), ds->output);
|
||||
alg.printVector(adamModel.modelSetTest(ds->input));
|
||||
@ -247,7 +247,7 @@ void MLPPTests::test_multivariate_linear_regression_score_sgd_adam(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
|
||||
|
||||
const int TRIAL_NUM = 1000;
|
||||
|
||||
@ -272,7 +272,7 @@ void MLPPTests::test_multivariate_linear_regression_epochs_gradient_descent(bool
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
|
||||
|
||||
std::cout << "Total epoch num: 300" << std::endl;
|
||||
std::cout << "Method: 1st Order w/ Jacobians" << std::endl;
|
||||
@ -286,7 +286,7 @@ void MLPPTests::test_multivariate_linear_regression_newton_raphson(bool ui) {
|
||||
MLPPData data;
|
||||
MLPPLinAlg alg;
|
||||
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_load_california_housing_data_path);
|
||||
Ref<MLPPDataSimple> ds = data.load_california_housing(_california_housing_data_path);
|
||||
|
||||
std::cout << "--------------------------------------------" << std::endl;
|
||||
std::cout << "Total epoch num: 300" << std::endl;
|
||||
@ -297,15 +297,22 @@ void MLPPTests::test_multivariate_linear_regression_newton_raphson(bool ui) {
|
||||
alg.printVector(model2.modelSetTest(ds->input));
|
||||
}
|
||||
|
||||
//MLPPStat stat;
|
||||
//MLPPLinAlg alg;
|
||||
//MLPPActivation avn;
|
||||
//MLPPCost cost;
|
||||
//MLPPData data;
|
||||
//MLPPConvolutions conv;
|
||||
|
||||
void MLPPTests::test_logistic_regression(bool ui) {
|
||||
//MLPPStat stat;
|
||||
// MLPPLinAlg alg;
|
||||
//MLPPLinAlg alg;
|
||||
//MLPPActivation avn;
|
||||
// MLPPCost cost;
|
||||
// MLPPData data;
|
||||
// MLPPConvolutions conv;
|
||||
//MLPPCost cost;
|
||||
//MLPPData data;
|
||||
//MLPPConvolutions conv;
|
||||
|
||||
// // LOGISTIC REGRESSION
|
||||
// LOGISTIC REGRESSION
|
||||
// auto [inputSet, outputSet] = data.load rastCancer();
|
||||
// LogReg model(inputSet, outputSet);
|
||||
// model.SGD(0.001, 100000, 0);
|
||||
@ -883,8 +890,14 @@ IAEDMAT_FAILED:
|
||||
}
|
||||
|
||||
MLPPTests::MLPPTests() {
|
||||
_load_fires_and_crime_data_path = "res://datasets/FiresAndCrime.csv";
|
||||
_load_california_housing_data_path = "res://datasets/CaliforniaHousing.csv";
|
||||
_breast_cancer_data_path = "res://datasets/BreastCancer.csv";
|
||||
_breast_cancer_svm_data_path = "res://datasets/BreastCancerSVM.csv";
|
||||
_california_housing_data_path = "res://datasets/CaliforniaHousing.csv";
|
||||
_fires_and_crime_data_path = "res://datasets/FiresAndCrime.csv";
|
||||
_iris_data_path = "res://datasets/Iris.csv";
|
||||
_mnist_test_data_path = "res://datasets/MnistTest.csv";
|
||||
_mnist_train_data_path = "res://datasets/MnistTrain.csv";
|
||||
_wine_data_path = "res://datasets/Wine.csv";
|
||||
}
|
||||
|
||||
MLPPTests::~MLPPTests() {
|
||||
|
@ -68,8 +68,14 @@ public:
|
||||
protected:
|
||||
static void _bind_methods();
|
||||
|
||||
String _load_fires_and_crime_data_path;
|
||||
String _load_california_housing_data_path;
|
||||
String _breast_cancer_data_path;
|
||||
String _breast_cancer_svm_data_path;
|
||||
String _california_housing_data_path;
|
||||
String _fires_and_crime_data_path;
|
||||
String _iris_data_path;
|
||||
String _mnist_test_data_path;
|
||||
String _mnist_train_data_path;
|
||||
String _wine_data_path;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user