mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-01-18 15:07:16 +01:00
Added loader methods that use the engine's FileAccess instead of ifstreams.
This commit is contained in:
parent
3a67c5873b
commit
9403f8efe2
@ -6,9 +6,13 @@
|
||||
//
|
||||
|
||||
#include "data.h"
|
||||
|
||||
#include "core/os/file_access.h"
|
||||
|
||||
#include "../lin_alg/lin_alg.h"
|
||||
#include "../softmax_net/softmax_net.h"
|
||||
#include "../stat/stat.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
@ -16,6 +20,185 @@
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
|
||||
void MLPPDataESimple::_bind_methods() {
|
||||
}
|
||||
|
||||
void MLPPDataSimple::_bind_methods() {
|
||||
}
|
||||
|
||||
void MLPPDataComplex::_bind_methods() {
|
||||
}
|
||||
|
||||
// Loading Datasets
|
||||
Ref<MLPPDataSimple> MLPPData::load_breast_cancer(const String &path) {
|
||||
const int BREAST_CANCER_SIZE = 30; // k = 30
|
||||
|
||||
Ref<MLPPDataSimple> data;
|
||||
data.instance();
|
||||
|
||||
set_data_supervised(BREAST_CANCER_SIZE, path, data->input, data->output);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
Ref<MLPPDataSimple> MLPPData::load_breast_cancer_svc(const String &path) {
|
||||
const int BREAST_CANCER_SIZE = 30; // k = 30
|
||||
|
||||
Ref<MLPPDataSimple> data;
|
||||
data.instance();
|
||||
|
||||
set_data_supervised(BREAST_CANCER_SIZE, path, data->input, data->output);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
Ref<MLPPDataComplex> MLPPData::load_iris(const String &path) {
|
||||
const int IRIS_SIZE = 4;
|
||||
const int ONE_HOT_NUM = 3;
|
||||
|
||||
std::vector<double> tempOutputSet;
|
||||
|
||||
Ref<MLPPDataComplex> data;
|
||||
data.instance();
|
||||
|
||||
set_data_supervised(IRIS_SIZE, path, data->input, tempOutputSet);
|
||||
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
Ref<MLPPDataComplex> MLPPData::load_wine(const String &path) {
|
||||
const int WINE_SIZE = 4;
|
||||
const int ONE_HOT_NUM = 3;
|
||||
|
||||
std::vector<double> tempOutputSet;
|
||||
|
||||
Ref<MLPPDataComplex> data;
|
||||
data.instance();
|
||||
|
||||
set_data_supervised(WINE_SIZE, path, data->input, tempOutputSet);
|
||||
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
Ref<MLPPDataComplex> MLPPData::load_mnist_train(const String &path) {
|
||||
const int MNIST_SIZE = 784;
|
||||
const int ONE_HOT_NUM = 10;
|
||||
|
||||
std::vector<std::vector<double>> inputSet;
|
||||
std::vector<double> tempOutputSet;
|
||||
|
||||
Ref<MLPPDataComplex> data;
|
||||
data.instance();
|
||||
|
||||
set_data_supervised(MNIST_SIZE, path, data->input, tempOutputSet);
|
||||
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
Ref<MLPPDataComplex> MLPPData::load_mnist_test(const String &path) {
|
||||
const int MNIST_SIZE = 784;
|
||||
const int ONE_HOT_NUM = 10;
|
||||
std::vector<std::vector<double>> inputSet;
|
||||
std::vector<double> tempOutputSet;
|
||||
|
||||
Ref<MLPPDataComplex> data;
|
||||
data.instance();
|
||||
|
||||
set_data_supervised(MNIST_SIZE, path, data->input, tempOutputSet);
|
||||
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
Ref<MLPPDataSimple> MLPPData::load_california_housing(const String &path) {
|
||||
const int CALIFORNIA_HOUSING_SIZE = 13; // k = 30
|
||||
|
||||
Ref<MLPPDataSimple> data;
|
||||
data.instance();
|
||||
|
||||
set_data_supervised(CALIFORNIA_HOUSING_SIZE, path, data->input, data->output);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
Ref<MLPPDataESimple> MLPPData::load_fires_and_crime(const String &path) {
|
||||
// k is implicitly 1.
|
||||
|
||||
Ref<MLPPDataESimple> data;
|
||||
data.instance();
|
||||
|
||||
set_data_simple(path, data->input, data->output);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
// MULTIVARIATE SUPERVISED
|
||||
|
||||
void MLPPData::set_data_supervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet, std::vector<double> &outputSet) {
|
||||
MLPPLinAlg alg;
|
||||
|
||||
inputSet.resize(k);
|
||||
|
||||
FileAccess *file = FileAccess::open(file_name, FileAccess::READ);
|
||||
|
||||
ERR_FAIL_COND(!file);
|
||||
|
||||
while (!file->eof_reached()) {
|
||||
Vector<String> ll = file->get_csv_line();
|
||||
|
||||
for (int i = 0; i < k; ++i) {
|
||||
inputSet[i].push_back(ll[i].to_double());
|
||||
}
|
||||
|
||||
outputSet.push_back(ll[k].to_double());
|
||||
}
|
||||
|
||||
inputSet = alg.transpose(inputSet);
|
||||
|
||||
memdelete(file);
|
||||
}
|
||||
|
||||
void MLPPData::set_data_unsupervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet) {
|
||||
MLPPLinAlg alg;
|
||||
|
||||
inputSet.resize(k);
|
||||
|
||||
FileAccess *file = FileAccess::open(file_name, FileAccess::READ);
|
||||
|
||||
ERR_FAIL_COND(!file);
|
||||
|
||||
while (!file->eof_reached()) {
|
||||
Vector<String> ll = file->get_csv_line();
|
||||
|
||||
for (int i = 0; i < k; ++i) {
|
||||
inputSet[i].push_back(ll[i].to_double());
|
||||
}
|
||||
}
|
||||
|
||||
inputSet = alg.transpose(inputSet);
|
||||
|
||||
memdelete(file);
|
||||
}
|
||||
|
||||
void MLPPData::set_data_simple(const String &file_name, std::vector<double> &inputSet, std::vector<double> &outputSet) {
|
||||
FileAccess *file = FileAccess::open(file_name, FileAccess::READ);
|
||||
|
||||
ERR_FAIL_COND(!file);
|
||||
|
||||
while (!file->eof_reached()) {
|
||||
Vector<String> ll = file->get_csv_line();
|
||||
|
||||
for (int i = 0; i < ll.size(); i += 2) {
|
||||
inputSet.push_back(ll[i].to_double());
|
||||
outputSet.push_back(ll[i + 1].to_double());
|
||||
}
|
||||
}
|
||||
|
||||
memdelete(file);
|
||||
}
|
||||
|
||||
// Loading Datasets
|
||||
std::tuple<std::vector<std::vector<double>>, std::vector<double>> MLPPData::loadBreastCancer() {
|
||||
@ -754,3 +937,13 @@ std::vector<double> MLPPData::reverseOneHot(std::vector<std::vector<double>> tem
|
||||
return outputSet;
|
||||
}
|
||||
|
||||
void MLPPData::_bind_methods() {
|
||||
ClassDB::bind_method(D_METHOD("load_breast_cancer", "path"), &MLPPData::load_breast_cancer);
|
||||
ClassDB::bind_method(D_METHOD("load_breast_cancer_svc", "path"), &MLPPData::load_breast_cancer_svc);
|
||||
ClassDB::bind_method(D_METHOD("load_iris", "path"), &MLPPData::load_iris);
|
||||
ClassDB::bind_method(D_METHOD("load_wine", "path"), &MLPPData::load_wine);
|
||||
ClassDB::bind_method(D_METHOD("load_mnist_train", "path"), &MLPPData::load_mnist_train);
|
||||
ClassDB::bind_method(D_METHOD("load_mnist_test", "path"), &MLPPData::load_mnist_test);
|
||||
ClassDB::bind_method(D_METHOD("load_california_housing", "path"), &MLPPData::load_california_housing);
|
||||
ClassDB::bind_method(D_METHOD("load_fires_and_crime", "path"), &MLPPData::load_fires_and_crime);
|
||||
}
|
||||
|
@ -9,13 +9,65 @@
|
||||
// Created by Marc Melikyan on 11/4/20.
|
||||
//
|
||||
|
||||
#include "core/string/ustring.h"
|
||||
|
||||
#include "core/object/reference.h"
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
class MLPPDataESimple : public Reference {
|
||||
GDCLASS(MLPPDataESimple, Reference);
|
||||
|
||||
class MLPPData {
|
||||
public:
|
||||
std::vector<double> input;
|
||||
std::vector<double> output;
|
||||
|
||||
protected:
|
||||
static void _bind_methods();
|
||||
};
|
||||
|
||||
class MLPPDataSimple : public Reference {
|
||||
GDCLASS(MLPPDataSimple, Reference);
|
||||
|
||||
public:
|
||||
std::vector<std::vector<double>> input;
|
||||
std::vector<double> output;
|
||||
|
||||
protected:
|
||||
static void _bind_methods();
|
||||
};
|
||||
|
||||
class MLPPDataComplex : public Reference {
|
||||
GDCLASS(MLPPDataComplex, Reference);
|
||||
|
||||
public:
|
||||
std::vector<std::vector<double>> input;
|
||||
std::vector<std::vector<double>> output;
|
||||
|
||||
protected:
|
||||
static void _bind_methods();
|
||||
};
|
||||
|
||||
class MLPPData : public Reference {
|
||||
GDCLASS(MLPPData, Reference);
|
||||
|
||||
public:
|
||||
// Load Datasets
|
||||
Ref<MLPPDataSimple> load_breast_cancer(const String &path);
|
||||
Ref<MLPPDataSimple> load_breast_cancer_svc(const String &path);
|
||||
Ref<MLPPDataComplex> load_iris(const String &path);
|
||||
Ref<MLPPDataComplex> load_wine(const String &path);
|
||||
Ref<MLPPDataComplex> load_mnist_train(const String &path);
|
||||
Ref<MLPPDataComplex> load_mnist_test(const String &path);
|
||||
Ref<MLPPDataSimple> load_california_housing(const String &path);
|
||||
Ref<MLPPDataESimple> load_fires_and_crime(const String &path);
|
||||
|
||||
void set_data_supervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet, std::vector<double> &outputSet);
|
||||
void set_data_unsupervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet);
|
||||
void set_data_simple(const String &file_name, std::vector<double> &inputSet, std::vector<double> &outputSet);
|
||||
|
||||
// Load Datasets
|
||||
std::tuple<std::vector<std::vector<double>>, std::vector<double>> loadBreastCancer();
|
||||
std::tuple<std::vector<std::vector<double>>, std::vector<double>> loadBreastCancerSVC();
|
||||
@ -92,8 +144,8 @@ public:
|
||||
return setInputSet;
|
||||
}
|
||||
|
||||
private:
|
||||
protected:
|
||||
static void _bind_methods();
|
||||
};
|
||||
|
||||
|
||||
#endif /* Data_hpp */
|
||||
|
@ -23,10 +23,17 @@ SOFTWARE.
|
||||
|
||||
#include "register_types.h"
|
||||
|
||||
#include "mlpp/data/data.h"
|
||||
|
||||
#include "test/mlpp_tests.h"
|
||||
|
||||
void register_pmlpp_types(ModuleRegistrationLevel p_level) {
|
||||
if (p_level == MODULE_REGISTRATION_LEVEL_SCENE) {
|
||||
ClassDB::register_class<MLPPDataESimple>();
|
||||
ClassDB::register_class<MLPPDataSimple>();
|
||||
ClassDB::register_class<MLPPDataComplex>();
|
||||
ClassDB::register_class<MLPPData>();
|
||||
|
||||
ClassDB::register_class<MLPPTests>();
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user