Added loader methods that use the engine's FileAccess instead of ifstreams.

This commit is contained in:
Relintai 2023-01-25 18:27:14 +01:00
parent 3a67c5873b
commit 9403f8efe2
3 changed files with 257 additions and 5 deletions

View File

@ -6,9 +6,13 @@
//
#include "data.h"
#include "core/os/file_access.h"
#include "../lin_alg/lin_alg.h"
#include "../softmax_net/softmax_net.h"
#include "../stat/stat.h"
#include <algorithm>
#include <cmath>
#include <fstream>
@ -16,6 +20,185 @@
#include <random>
#include <sstream>
void MLPPDataESimple::_bind_methods() {
}
void MLPPDataSimple::_bind_methods() {
}
void MLPPDataComplex::_bind_methods() {
}
// Loading Datasets
Ref<MLPPDataSimple> MLPPData::load_breast_cancer(const String &path) {
const int BREAST_CANCER_SIZE = 30; // k = 30
Ref<MLPPDataSimple> data;
data.instance();
set_data_supervised(BREAST_CANCER_SIZE, path, data->input, data->output);
return data;
}
Ref<MLPPDataSimple> MLPPData::load_breast_cancer_svc(const String &path) {
const int BREAST_CANCER_SIZE = 30; // k = 30
Ref<MLPPDataSimple> data;
data.instance();
set_data_supervised(BREAST_CANCER_SIZE, path, data->input, data->output);
return data;
}
Ref<MLPPDataComplex> MLPPData::load_iris(const String &path) {
const int IRIS_SIZE = 4;
const int ONE_HOT_NUM = 3;
std::vector<double> tempOutputSet;
Ref<MLPPDataComplex> data;
data.instance();
set_data_supervised(IRIS_SIZE, path, data->input, tempOutputSet);
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
return data;
}
Ref<MLPPDataComplex> MLPPData::load_wine(const String &path) {
const int WINE_SIZE = 4;
const int ONE_HOT_NUM = 3;
std::vector<double> tempOutputSet;
Ref<MLPPDataComplex> data;
data.instance();
set_data_supervised(WINE_SIZE, path, data->input, tempOutputSet);
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
return data;
}
Ref<MLPPDataComplex> MLPPData::load_mnist_train(const String &path) {
const int MNIST_SIZE = 784;
const int ONE_HOT_NUM = 10;
std::vector<std::vector<double>> inputSet;
std::vector<double> tempOutputSet;
Ref<MLPPDataComplex> data;
data.instance();
set_data_supervised(MNIST_SIZE, path, data->input, tempOutputSet);
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
return data;
}
Ref<MLPPDataComplex> MLPPData::load_mnist_test(const String &path) {
const int MNIST_SIZE = 784;
const int ONE_HOT_NUM = 10;
std::vector<std::vector<double>> inputSet;
std::vector<double> tempOutputSet;
Ref<MLPPDataComplex> data;
data.instance();
set_data_supervised(MNIST_SIZE, path, data->input, tempOutputSet);
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
return data;
}
Ref<MLPPDataSimple> MLPPData::load_california_housing(const String &path) {
const int CALIFORNIA_HOUSING_SIZE = 13; // k = 30
Ref<MLPPDataSimple> data;
data.instance();
set_data_supervised(CALIFORNIA_HOUSING_SIZE, path, data->input, data->output);
return data;
}
Ref<MLPPDataESimple> MLPPData::load_fires_and_crime(const String &path) {
// k is implicitly 1.
Ref<MLPPDataESimple> data;
data.instance();
set_data_simple(path, data->input, data->output);
return data;
}
// MULTIVARIATE SUPERVISED
void MLPPData::set_data_supervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet, std::vector<double> &outputSet) {
MLPPLinAlg alg;
inputSet.resize(k);
FileAccess *file = FileAccess::open(file_name, FileAccess::READ);
ERR_FAIL_COND(!file);
while (!file->eof_reached()) {
Vector<String> ll = file->get_csv_line();
for (int i = 0; i < k; ++i) {
inputSet[i].push_back(ll[i].to_double());
}
outputSet.push_back(ll[k].to_double());
}
inputSet = alg.transpose(inputSet);
memdelete(file);
}
void MLPPData::set_data_unsupervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet) {
MLPPLinAlg alg;
inputSet.resize(k);
FileAccess *file = FileAccess::open(file_name, FileAccess::READ);
ERR_FAIL_COND(!file);
while (!file->eof_reached()) {
Vector<String> ll = file->get_csv_line();
for (int i = 0; i < k; ++i) {
inputSet[i].push_back(ll[i].to_double());
}
}
inputSet = alg.transpose(inputSet);
memdelete(file);
}
void MLPPData::set_data_simple(const String &file_name, std::vector<double> &inputSet, std::vector<double> &outputSet) {
FileAccess *file = FileAccess::open(file_name, FileAccess::READ);
ERR_FAIL_COND(!file);
while (!file->eof_reached()) {
Vector<String> ll = file->get_csv_line();
for (int i = 0; i < ll.size(); i += 2) {
inputSet.push_back(ll[i].to_double());
outputSet.push_back(ll[i + 1].to_double());
}
}
memdelete(file);
}
// Loading Datasets
std::tuple<std::vector<std::vector<double>>, std::vector<double>> MLPPData::loadBreastCancer() {
@ -754,3 +937,13 @@ std::vector<double> MLPPData::reverseOneHot(std::vector<std::vector<double>> tem
return outputSet;
}
void MLPPData::_bind_methods() {
ClassDB::bind_method(D_METHOD("load_breast_cancer", "path"), &MLPPData::load_breast_cancer);
ClassDB::bind_method(D_METHOD("load_breast_cancer_svc", "path"), &MLPPData::load_breast_cancer_svc);
ClassDB::bind_method(D_METHOD("load_iris", "path"), &MLPPData::load_iris);
ClassDB::bind_method(D_METHOD("load_wine", "path"), &MLPPData::load_wine);
ClassDB::bind_method(D_METHOD("load_mnist_train", "path"), &MLPPData::load_mnist_train);
ClassDB::bind_method(D_METHOD("load_mnist_test", "path"), &MLPPData::load_mnist_test);
ClassDB::bind_method(D_METHOD("load_california_housing", "path"), &MLPPData::load_california_housing);
ClassDB::bind_method(D_METHOD("load_fires_and_crime", "path"), &MLPPData::load_fires_and_crime);
}

View File

@ -9,13 +9,65 @@
// Created by Marc Melikyan on 11/4/20.
//
#include "core/string/ustring.h"
#include "core/object/reference.h"
#include <string>
#include <tuple>
#include <vector>
class MLPPDataESimple : public Reference {
GDCLASS(MLPPDataESimple, Reference);
class MLPPData {
public:
std::vector<double> input;
std::vector<double> output;
protected:
static void _bind_methods();
};
class MLPPDataSimple : public Reference {
GDCLASS(MLPPDataSimple, Reference);
public:
std::vector<std::vector<double>> input;
std::vector<double> output;
protected:
static void _bind_methods();
};
class MLPPDataComplex : public Reference {
GDCLASS(MLPPDataComplex, Reference);
public:
std::vector<std::vector<double>> input;
std::vector<std::vector<double>> output;
protected:
static void _bind_methods();
};
class MLPPData : public Reference {
GDCLASS(MLPPData, Reference);
public:
// Load Datasets
Ref<MLPPDataSimple> load_breast_cancer(const String &path);
Ref<MLPPDataSimple> load_breast_cancer_svc(const String &path);
Ref<MLPPDataComplex> load_iris(const String &path);
Ref<MLPPDataComplex> load_wine(const String &path);
Ref<MLPPDataComplex> load_mnist_train(const String &path);
Ref<MLPPDataComplex> load_mnist_test(const String &path);
Ref<MLPPDataSimple> load_california_housing(const String &path);
Ref<MLPPDataESimple> load_fires_and_crime(const String &path);
void set_data_supervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet, std::vector<double> &outputSet);
void set_data_unsupervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet);
void set_data_simple(const String &file_name, std::vector<double> &inputSet, std::vector<double> &outputSet);
// Load Datasets
std::tuple<std::vector<std::vector<double>>, std::vector<double>> loadBreastCancer();
std::tuple<std::vector<std::vector<double>>, std::vector<double>> loadBreastCancerSVC();
@ -92,8 +144,8 @@ public:
return setInputSet;
}
private:
protected:
static void _bind_methods();
};
#endif /* Data_hpp */

View File

@ -23,10 +23,17 @@ SOFTWARE.
#include "register_types.h"
#include "mlpp/data/data.h"
#include "test/mlpp_tests.h"
void register_pmlpp_types(ModuleRegistrationLevel p_level) {
if (p_level == MODULE_REGISTRATION_LEVEL_SCENE) {
ClassDB::register_class<MLPPDataESimple>();
ClassDB::register_class<MLPPDataSimple>();
ClassDB::register_class<MLPPDataComplex>();
ClassDB::register_class<MLPPData>();
ClassDB::register_class<MLPPTests>();
}
}