mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-01-06 17:09:36 +01:00
Added loader methods that use the engine's FileAccess instead of ifstreams.
This commit is contained in:
parent
3a67c5873b
commit
9403f8efe2
@ -6,9 +6,13 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
#include "data.h"
|
#include "data.h"
|
||||||
|
|
||||||
|
#include "core/os/file_access.h"
|
||||||
|
|
||||||
#include "../lin_alg/lin_alg.h"
|
#include "../lin_alg/lin_alg.h"
|
||||||
#include "../softmax_net/softmax_net.h"
|
#include "../softmax_net/softmax_net.h"
|
||||||
#include "../stat/stat.h"
|
#include "../stat/stat.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
@ -16,6 +20,185 @@
|
|||||||
#include <random>
|
#include <random>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
void MLPPDataESimple::_bind_methods() {
|
||||||
|
}
|
||||||
|
|
||||||
|
void MLPPDataSimple::_bind_methods() {
|
||||||
|
}
|
||||||
|
|
||||||
|
void MLPPDataComplex::_bind_methods() {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loading Datasets
|
||||||
|
Ref<MLPPDataSimple> MLPPData::load_breast_cancer(const String &path) {
|
||||||
|
const int BREAST_CANCER_SIZE = 30; // k = 30
|
||||||
|
|
||||||
|
Ref<MLPPDataSimple> data;
|
||||||
|
data.instance();
|
||||||
|
|
||||||
|
set_data_supervised(BREAST_CANCER_SIZE, path, data->input, data->output);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref<MLPPDataSimple> MLPPData::load_breast_cancer_svc(const String &path) {
|
||||||
|
const int BREAST_CANCER_SIZE = 30; // k = 30
|
||||||
|
|
||||||
|
Ref<MLPPDataSimple> data;
|
||||||
|
data.instance();
|
||||||
|
|
||||||
|
set_data_supervised(BREAST_CANCER_SIZE, path, data->input, data->output);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref<MLPPDataComplex> MLPPData::load_iris(const String &path) {
|
||||||
|
const int IRIS_SIZE = 4;
|
||||||
|
const int ONE_HOT_NUM = 3;
|
||||||
|
|
||||||
|
std::vector<double> tempOutputSet;
|
||||||
|
|
||||||
|
Ref<MLPPDataComplex> data;
|
||||||
|
data.instance();
|
||||||
|
|
||||||
|
set_data_supervised(IRIS_SIZE, path, data->input, tempOutputSet);
|
||||||
|
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref<MLPPDataComplex> MLPPData::load_wine(const String &path) {
|
||||||
|
const int WINE_SIZE = 4;
|
||||||
|
const int ONE_HOT_NUM = 3;
|
||||||
|
|
||||||
|
std::vector<double> tempOutputSet;
|
||||||
|
|
||||||
|
Ref<MLPPDataComplex> data;
|
||||||
|
data.instance();
|
||||||
|
|
||||||
|
set_data_supervised(WINE_SIZE, path, data->input, tempOutputSet);
|
||||||
|
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref<MLPPDataComplex> MLPPData::load_mnist_train(const String &path) {
|
||||||
|
const int MNIST_SIZE = 784;
|
||||||
|
const int ONE_HOT_NUM = 10;
|
||||||
|
|
||||||
|
std::vector<std::vector<double>> inputSet;
|
||||||
|
std::vector<double> tempOutputSet;
|
||||||
|
|
||||||
|
Ref<MLPPDataComplex> data;
|
||||||
|
data.instance();
|
||||||
|
|
||||||
|
set_data_supervised(MNIST_SIZE, path, data->input, tempOutputSet);
|
||||||
|
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref<MLPPDataComplex> MLPPData::load_mnist_test(const String &path) {
|
||||||
|
const int MNIST_SIZE = 784;
|
||||||
|
const int ONE_HOT_NUM = 10;
|
||||||
|
std::vector<std::vector<double>> inputSet;
|
||||||
|
std::vector<double> tempOutputSet;
|
||||||
|
|
||||||
|
Ref<MLPPDataComplex> data;
|
||||||
|
data.instance();
|
||||||
|
|
||||||
|
set_data_supervised(MNIST_SIZE, path, data->input, tempOutputSet);
|
||||||
|
data->output = oneHotRep(tempOutputSet, ONE_HOT_NUM);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref<MLPPDataSimple> MLPPData::load_california_housing(const String &path) {
|
||||||
|
const int CALIFORNIA_HOUSING_SIZE = 13; // k = 30
|
||||||
|
|
||||||
|
Ref<MLPPDataSimple> data;
|
||||||
|
data.instance();
|
||||||
|
|
||||||
|
set_data_supervised(CALIFORNIA_HOUSING_SIZE, path, data->input, data->output);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ref<MLPPDataESimple> MLPPData::load_fires_and_crime(const String &path) {
|
||||||
|
// k is implicitly 1.
|
||||||
|
|
||||||
|
Ref<MLPPDataESimple> data;
|
||||||
|
data.instance();
|
||||||
|
|
||||||
|
set_data_simple(path, data->input, data->output);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
// MULTIVARIATE SUPERVISED
|
||||||
|
|
||||||
|
void MLPPData::set_data_supervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet, std::vector<double> &outputSet) {
|
||||||
|
MLPPLinAlg alg;
|
||||||
|
|
||||||
|
inputSet.resize(k);
|
||||||
|
|
||||||
|
FileAccess *file = FileAccess::open(file_name, FileAccess::READ);
|
||||||
|
|
||||||
|
ERR_FAIL_COND(!file);
|
||||||
|
|
||||||
|
while (!file->eof_reached()) {
|
||||||
|
Vector<String> ll = file->get_csv_line();
|
||||||
|
|
||||||
|
for (int i = 0; i < k; ++i) {
|
||||||
|
inputSet[i].push_back(ll[i].to_double());
|
||||||
|
}
|
||||||
|
|
||||||
|
outputSet.push_back(ll[k].to_double());
|
||||||
|
}
|
||||||
|
|
||||||
|
inputSet = alg.transpose(inputSet);
|
||||||
|
|
||||||
|
memdelete(file);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MLPPData::set_data_unsupervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet) {
|
||||||
|
MLPPLinAlg alg;
|
||||||
|
|
||||||
|
inputSet.resize(k);
|
||||||
|
|
||||||
|
FileAccess *file = FileAccess::open(file_name, FileAccess::READ);
|
||||||
|
|
||||||
|
ERR_FAIL_COND(!file);
|
||||||
|
|
||||||
|
while (!file->eof_reached()) {
|
||||||
|
Vector<String> ll = file->get_csv_line();
|
||||||
|
|
||||||
|
for (int i = 0; i < k; ++i) {
|
||||||
|
inputSet[i].push_back(ll[i].to_double());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputSet = alg.transpose(inputSet);
|
||||||
|
|
||||||
|
memdelete(file);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MLPPData::set_data_simple(const String &file_name, std::vector<double> &inputSet, std::vector<double> &outputSet) {
|
||||||
|
FileAccess *file = FileAccess::open(file_name, FileAccess::READ);
|
||||||
|
|
||||||
|
ERR_FAIL_COND(!file);
|
||||||
|
|
||||||
|
while (!file->eof_reached()) {
|
||||||
|
Vector<String> ll = file->get_csv_line();
|
||||||
|
|
||||||
|
for (int i = 0; i < ll.size(); i += 2) {
|
||||||
|
inputSet.push_back(ll[i].to_double());
|
||||||
|
outputSet.push_back(ll[i + 1].to_double());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
memdelete(file);
|
||||||
|
}
|
||||||
|
|
||||||
// Loading Datasets
|
// Loading Datasets
|
||||||
std::tuple<std::vector<std::vector<double>>, std::vector<double>> MLPPData::loadBreastCancer() {
|
std::tuple<std::vector<std::vector<double>>, std::vector<double>> MLPPData::loadBreastCancer() {
|
||||||
@ -754,3 +937,13 @@ std::vector<double> MLPPData::reverseOneHot(std::vector<std::vector<double>> tem
|
|||||||
return outputSet;
|
return outputSet;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MLPPData::_bind_methods() {
|
||||||
|
ClassDB::bind_method(D_METHOD("load_breast_cancer", "path"), &MLPPData::load_breast_cancer);
|
||||||
|
ClassDB::bind_method(D_METHOD("load_breast_cancer_svc", "path"), &MLPPData::load_breast_cancer_svc);
|
||||||
|
ClassDB::bind_method(D_METHOD("load_iris", "path"), &MLPPData::load_iris);
|
||||||
|
ClassDB::bind_method(D_METHOD("load_wine", "path"), &MLPPData::load_wine);
|
||||||
|
ClassDB::bind_method(D_METHOD("load_mnist_train", "path"), &MLPPData::load_mnist_train);
|
||||||
|
ClassDB::bind_method(D_METHOD("load_mnist_test", "path"), &MLPPData::load_mnist_test);
|
||||||
|
ClassDB::bind_method(D_METHOD("load_california_housing", "path"), &MLPPData::load_california_housing);
|
||||||
|
ClassDB::bind_method(D_METHOD("load_fires_and_crime", "path"), &MLPPData::load_fires_and_crime);
|
||||||
|
}
|
||||||
|
@ -9,13 +9,65 @@
|
|||||||
// Created by Marc Melikyan on 11/4/20.
|
// Created by Marc Melikyan on 11/4/20.
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include "core/string/ustring.h"
|
||||||
|
|
||||||
|
#include "core/object/reference.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
class MLPPDataESimple : public Reference {
|
||||||
|
GDCLASS(MLPPDataESimple, Reference);
|
||||||
|
|
||||||
class MLPPData {
|
|
||||||
public:
|
public:
|
||||||
|
std::vector<double> input;
|
||||||
|
std::vector<double> output;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
static void _bind_methods();
|
||||||
|
};
|
||||||
|
|
||||||
|
class MLPPDataSimple : public Reference {
|
||||||
|
GDCLASS(MLPPDataSimple, Reference);
|
||||||
|
|
||||||
|
public:
|
||||||
|
std::vector<std::vector<double>> input;
|
||||||
|
std::vector<double> output;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
static void _bind_methods();
|
||||||
|
};
|
||||||
|
|
||||||
|
class MLPPDataComplex : public Reference {
|
||||||
|
GDCLASS(MLPPDataComplex, Reference);
|
||||||
|
|
||||||
|
public:
|
||||||
|
std::vector<std::vector<double>> input;
|
||||||
|
std::vector<std::vector<double>> output;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
static void _bind_methods();
|
||||||
|
};
|
||||||
|
|
||||||
|
class MLPPData : public Reference {
|
||||||
|
GDCLASS(MLPPData, Reference);
|
||||||
|
|
||||||
|
public:
|
||||||
|
// Load Datasets
|
||||||
|
Ref<MLPPDataSimple> load_breast_cancer(const String &path);
|
||||||
|
Ref<MLPPDataSimple> load_breast_cancer_svc(const String &path);
|
||||||
|
Ref<MLPPDataComplex> load_iris(const String &path);
|
||||||
|
Ref<MLPPDataComplex> load_wine(const String &path);
|
||||||
|
Ref<MLPPDataComplex> load_mnist_train(const String &path);
|
||||||
|
Ref<MLPPDataComplex> load_mnist_test(const String &path);
|
||||||
|
Ref<MLPPDataSimple> load_california_housing(const String &path);
|
||||||
|
Ref<MLPPDataESimple> load_fires_and_crime(const String &path);
|
||||||
|
|
||||||
|
void set_data_supervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet, std::vector<double> &outputSet);
|
||||||
|
void set_data_unsupervised(int k, const String &file_name, std::vector<std::vector<double>> &inputSet);
|
||||||
|
void set_data_simple(const String &file_name, std::vector<double> &inputSet, std::vector<double> &outputSet);
|
||||||
|
|
||||||
// Load Datasets
|
// Load Datasets
|
||||||
std::tuple<std::vector<std::vector<double>>, std::vector<double>> loadBreastCancer();
|
std::tuple<std::vector<std::vector<double>>, std::vector<double>> loadBreastCancer();
|
||||||
std::tuple<std::vector<std::vector<double>>, std::vector<double>> loadBreastCancerSVC();
|
std::tuple<std::vector<std::vector<double>>, std::vector<double>> loadBreastCancerSVC();
|
||||||
@ -92,8 +144,8 @@ public:
|
|||||||
return setInputSet;
|
return setInputSet;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
|
static void _bind_methods();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
#endif /* Data_hpp */
|
#endif /* Data_hpp */
|
||||||
|
@ -23,10 +23,17 @@ SOFTWARE.
|
|||||||
|
|
||||||
#include "register_types.h"
|
#include "register_types.h"
|
||||||
|
|
||||||
|
#include "mlpp/data/data.h"
|
||||||
|
|
||||||
#include "test/mlpp_tests.h"
|
#include "test/mlpp_tests.h"
|
||||||
|
|
||||||
void register_pmlpp_types(ModuleRegistrationLevel p_level) {
|
void register_pmlpp_types(ModuleRegistrationLevel p_level) {
|
||||||
if (p_level == MODULE_REGISTRATION_LEVEL_SCENE) {
|
if (p_level == MODULE_REGISTRATION_LEVEL_SCENE) {
|
||||||
|
ClassDB::register_class<MLPPDataESimple>();
|
||||||
|
ClassDB::register_class<MLPPDataSimple>();
|
||||||
|
ClassDB::register_class<MLPPDataComplex>();
|
||||||
|
ClassDB::register_class<MLPPData>();
|
||||||
|
|
||||||
ClassDB::register_class<MLPPTests>();
|
ClassDB::register_class<MLPPTests>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user