From 38dbc2d4701436cffe48651056aaf451c8d177d8 Mon Sep 17 00:00:00 2001 From: Relintai Date: Fri, 28 Apr 2023 21:23:37 +0200 Subject: [PATCH] MLPPSoftmaxReg api rework. --- mlpp/softmax_reg/softmax_reg.cpp | 186 ++++++++++++++++--------------- mlpp/softmax_reg/softmax_reg.h | 41 +++---- test/mlpp_tests.cpp | 2 +- 3 files changed, 120 insertions(+), 109 deletions(-) diff --git a/mlpp/softmax_reg/softmax_reg.cpp b/mlpp/softmax_reg/softmax_reg.cpp index 3b7f2e0..46f0b84 100644 --- a/mlpp/softmax_reg/softmax_reg.cpp +++ b/mlpp/softmax_reg/softmax_reg.cpp @@ -13,65 +13,79 @@ #include -Ref MLPPSoftmaxReg::get_input_set() { +Ref MLPPSoftmaxReg::get_input_set() const { return _input_set; } void MLPPSoftmaxReg::set_input_set(const Ref &val) { _input_set = val; - - _initialized = false; } -Ref MLPPSoftmaxReg::get_output_set() { +Ref MLPPSoftmaxReg::get_output_set() const { return _output_set; } void MLPPSoftmaxReg::set_output_set(const Ref &val) { _output_set = val; - - _initialized = false; } -MLPPReg::RegularizationType MLPPSoftmaxReg::get_reg() { +MLPPReg::RegularizationType MLPPSoftmaxReg::get_reg() const { return _reg; } void MLPPSoftmaxReg::set_reg(const MLPPReg::RegularizationType val) { _reg = val; - - _initialized = false; } -real_t MLPPSoftmaxReg::get_lambda() { +real_t MLPPSoftmaxReg::get_lambda() const { return _lambda; } void MLPPSoftmaxReg::set_lambda(const real_t val) { _lambda = val; - - _initialized = false; } -real_t MLPPSoftmaxReg::get_alpha() { +real_t MLPPSoftmaxReg::get_alpha() const { return _alpha; } void MLPPSoftmaxReg::set_alpha(const real_t val) { _alpha = val; +} - _initialized = false; +Ref MLPPSoftmaxReg::data_y_hat_get() const { + return _y_hat; +} +void MLPPSoftmaxReg::data_y_hat_set(const Ref &val) { + _y_hat = val; +} + +Ref MLPPSoftmaxReg::data_weights_get() const { + return _weights; +} +void MLPPSoftmaxReg::data_weights_set(const Ref &val) { + _weights = val; +} + +Ref MLPPSoftmaxReg::data_bias_get() const { + return _bias; +} +void MLPPSoftmaxReg::data_bias_set(const Ref &val) { + _bias = val; } Ref MLPPSoftmaxReg::model_test(const Ref &x) { - ERR_FAIL_COND_V(!_initialized, Ref()); + ERR_FAIL_COND_V(!_input_set.is_valid() || !_output_set.is_valid(), Ref()); + ERR_FAIL_COND_V(needs_init(), Ref()); return evaluatev(x); } Ref MLPPSoftmaxReg::model_set_test(const Ref &X) { - ERR_FAIL_COND_V(!_initialized, Ref()); + ERR_FAIL_COND_V(!_input_set.is_valid() || !_output_set.is_valid(), Ref()); + ERR_FAIL_COND_V(needs_init(), Ref()); return evaluatem(X); } -void MLPPSoftmaxReg::gradient_descent(real_t learning_rate, int max_epoch, bool ui) { - ERR_FAIL_COND(!_initialized); +void MLPPSoftmaxReg::train_gradient_descent(real_t learning_rate, int max_epoch, bool ui) { + ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid()); + ERR_FAIL_COND(needs_init()); MLPPReg regularization; real_t cost_prev = 0; @@ -113,17 +127,19 @@ void MLPPSoftmaxReg::gradient_descent(real_t learning_rate, int max_epoch, bool } } -void MLPPSoftmaxReg::sgd(real_t learning_rate, int max_epoch, bool ui) { - ERR_FAIL_COND(!_initialized); +void MLPPSoftmaxReg::train_sgd(real_t learning_rate, int max_epoch, bool ui) { + ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid()); + ERR_FAIL_COND(needs_init()); MLPPReg regularization; real_t cost_prev = 0; int epoch = 1; + int n = _input_set->size().y; std::random_device rd; std::default_random_engine generator(rd()); - std::uniform_int_distribution distribution(0, int(_n - 1)); + std::uniform_int_distribution distribution(0, int(n - 1)); Ref input_set_row_tmp; input_set_row_tmp.instance(); @@ -185,15 +201,17 @@ void MLPPSoftmaxReg::sgd(real_t learning_rate, int max_epoch, bool ui) { forward_pass(); } -void MLPPSoftmaxReg::mbgd(real_t learning_rate, int max_epoch, int mini_batch_size, bool ui) { - ERR_FAIL_COND(!_initialized); +void MLPPSoftmaxReg::train_mbgd(real_t learning_rate, int max_epoch, int mini_batch_size, bool ui) { + ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid()); + ERR_FAIL_COND(needs_init()); MLPPReg regularization; real_t cost_prev = 0; int epoch = 1; + int n = _input_set->size().y; // Creating the mini-batches - int n_mini_batch = _n / mini_batch_size; + int n_mini_batch = n / mini_batch_size; MLPPUtilities::CreateMiniBatchMMBatch batches = MLPPUtilities::create_mini_batchesmm(_input_set, _output_set, n_mini_batch); while (true) { @@ -234,98 +252,80 @@ void MLPPSoftmaxReg::mbgd(real_t learning_rate, int max_epoch, int mini_batch_si } real_t MLPPSoftmaxReg::score() { - ERR_FAIL_COND_V(!_initialized, 0); + ERR_FAIL_COND_V(!_input_set.is_valid() || !_output_set.is_valid(), 0); + ERR_FAIL_COND_V(needs_init(), 0); MLPPUtilities util; return util.performance_mat(_y_hat, _output_set); } -void MLPPSoftmaxReg::save(const String &file_name) { - ERR_FAIL_COND(!_initialized); - - MLPPUtilities util; - - //util.saveParameters(file_name, _weights, _bias); -} - -bool MLPPSoftmaxReg::is_initialized() { - return _initialized; -} -void MLPPSoftmaxReg::initialize() { - if (_initialized) { - return; +bool MLPPSoftmaxReg::needs_init() const { + if (!_input_set.is_valid()) { + return true; } + if (!_output_set.is_valid()) { + return true; + } + + int n = _input_set->size().y; + int k = _input_set->size().x; + int n_class = _output_set->size().x; + + if (_y_hat->size().x != n) { + return true; + } + + if (_weights->size() != Size2i(n_class, k)) { + return true; + } + + if (_bias->size() != n_class) { + return true; + } + + return false; +} +void MLPPSoftmaxReg::initialize() { ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid()); - _n = _input_set->size().y; - _k = _input_set->size().x; - _n_class = _output_set->size().x; + int n = _input_set->size().y; + int k = _input_set->size().x; + int n_class = _output_set->size().x; - _y_hat.instance(); - _y_hat->resize(Size2i(_n, 0)); + _y_hat->resize(Size2i(n, 0)); MLPPUtilities util; - _weights.instance(); - _weights->resize(Size2i(_n_class, _k)); - - _bias.instance(); - _bias->resize(_n_class); + _weights->resize(Size2i(n_class, k)); + _bias->resize(n_class); util.weight_initializationm(_weights); util.bias_initializationv(_bias); - - _initialized = true; } MLPPSoftmaxReg::MLPPSoftmaxReg(const Ref &p_input_set, const Ref &p_output_set, MLPPReg::RegularizationType p_reg, real_t p_lambda, real_t p_alpha) { _input_set = p_input_set; _output_set = p_output_set; - - _n = _input_set->size().y; - _k = _input_set->size().x; - _n_class = _output_set->size().x; - _reg = p_reg; _lambda = p_lambda; _alpha = p_alpha; - if (!_y_hat.is_valid()) { - _y_hat.instance(); - } - _y_hat->resize(Size2i(_n, 0)); - - MLPPUtilities util; - - if (!_weights.is_valid()) { - _weights.instance(); - } - _weights->resize(Size2i(_n_class, _k)); - - if (!_bias.is_valid()) { - _bias.instance(); - } - _bias->resize(_n_class); - - util.weight_initializationm(_weights); - util.bias_initializationv(_bias); - - _initialized = true; + _y_hat.instance(); + _weights.instance(); + _bias.instance(); } MLPPSoftmaxReg::MLPPSoftmaxReg() { - _n = 0; - _k = 0; - _n_class = 0; - // Regularization Params _reg = MLPPReg::REGULARIZATION_TYPE_NONE; _lambda = 0.5; _alpha = 0.5; /* This is the controlling param for Elastic Net*/ - _initialized = false; + _y_hat.instance(); + _weights.instance(); + _bias.instance(); } MLPPSoftmaxReg::~MLPPSoftmaxReg() { } @@ -376,17 +376,27 @@ void MLPPSoftmaxReg::_bind_methods() { ClassDB::bind_method(D_METHOD("set_alpha", "val"), &MLPPSoftmaxReg::set_alpha); ADD_PROPERTY(PropertyInfo(Variant::REAL, "alpha"), "set_alpha", "get_alpha"); + ClassDB::bind_method(D_METHOD("data_y_hat_get"), &MLPPSoftmaxReg::data_y_hat_get); + ClassDB::bind_method(D_METHOD("data_y_hat_set", "val"), &MLPPSoftmaxReg::data_y_hat_set); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "data_y_hat", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "data_y_hat_set", "data_y_hat_get"); + + ClassDB::bind_method(D_METHOD("data_weights_get"), &MLPPSoftmaxReg::data_weights_get); + ClassDB::bind_method(D_METHOD("data_weights_set", "val"), &MLPPSoftmaxReg::data_weights_set); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "data_weights", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "data_weights_set", "data_weights_get"); + + ClassDB::bind_method(D_METHOD("data_bias_get"), &MLPPSoftmaxReg::data_bias_get); + ClassDB::bind_method(D_METHOD("data_bias_set", "val"), &MLPPSoftmaxReg::data_bias_set); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "data_bias", PROPERTY_HINT_RESOURCE_TYPE, "MLPPVector"), "data_bias_set", "data_bias_get"); + ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPSoftmaxReg::model_test); ClassDB::bind_method(D_METHOD("model_set_test", "X"), &MLPPSoftmaxReg::model_set_test); - ClassDB::bind_method(D_METHOD("gradient_descent", "learning_rate", "max_epoch", "ui"), &MLPPSoftmaxReg::gradient_descent, false); - ClassDB::bind_method(D_METHOD("sgd", "learning_rate", "max_epoch", "ui"), &MLPPSoftmaxReg::sgd, false); - ClassDB::bind_method(D_METHOD("mbgd", "learning_rate", "max_epoch", "mini_batch_size", "ui"), &MLPPSoftmaxReg::mbgd, false); + ClassDB::bind_method(D_METHOD("train_gradient_descent", "learning_rate", "max_epoch", "ui"), &MLPPSoftmaxReg::train_gradient_descent, false); + ClassDB::bind_method(D_METHOD("train_sgd", "learning_rate", "max_epoch", "ui"), &MLPPSoftmaxReg::train_sgd, false); + ClassDB::bind_method(D_METHOD("train_mbgd", "learning_rate", "max_epoch", "mini_batch_size", "ui"), &MLPPSoftmaxReg::train_mbgd, false); ClassDB::bind_method(D_METHOD("score"), &MLPPSoftmaxReg::score); - ClassDB::bind_method(D_METHOD("save", "file_name"), &MLPPSoftmaxReg::save); - - ClassDB::bind_method(D_METHOD("is_initialized"), &MLPPSoftmaxReg::is_initialized); + ClassDB::bind_method(D_METHOD("needs_init"), &MLPPSoftmaxReg::needs_init); ClassDB::bind_method(D_METHOD("initialize"), &MLPPSoftmaxReg::initialize); } diff --git a/mlpp/softmax_reg/softmax_reg.h b/mlpp/softmax_reg/softmax_reg.h index bc82805..9677edd 100644 --- a/mlpp/softmax_reg/softmax_reg.h +++ b/mlpp/softmax_reg/softmax_reg.h @@ -10,44 +10,51 @@ #include "core/math/math_defs.h" -#include "core/object/reference.h" +#include "core/object/resource.h" #include "../lin_alg/mlpp_matrix.h" #include "../lin_alg/mlpp_vector.h" #include "../regularization/reg.h" -class MLPPSoftmaxReg : public Reference { - GDCLASS(MLPPSoftmaxReg, Reference); +class MLPPSoftmaxReg : public Resource { + GDCLASS(MLPPSoftmaxReg, Resource); public: - Ref get_input_set(); + Ref get_input_set() const; void set_input_set(const Ref &val); - Ref get_output_set(); + Ref get_output_set() const; void set_output_set(const Ref &val); - MLPPReg::RegularizationType get_reg(); + MLPPReg::RegularizationType get_reg() const; void set_reg(const MLPPReg::RegularizationType val); - real_t get_lambda(); + real_t get_lambda() const; void set_lambda(const real_t val); - real_t get_alpha(); + real_t get_alpha() const; void set_alpha(const real_t val); + Ref data_y_hat_get() const; + void data_y_hat_set(const Ref &val); + + Ref data_weights_get() const; + void data_weights_set(const Ref &val); + + Ref data_bias_get() const; + void data_bias_set(const Ref &val); + Ref model_test(const Ref &x); Ref model_set_test(const Ref &X); - void gradient_descent(real_t learning_rate, int max_epoch, bool ui = false); - void sgd(real_t learning_rate, int max_epoch, bool ui = false); - void mbgd(real_t learning_rate, int max_epoch, int mini_batch_size, bool ui = false); + void train_gradient_descent(real_t learning_rate, int max_epoch, bool ui = false); + void train_sgd(real_t learning_rate, int max_epoch, bool ui = false); + void train_mbgd(real_t learning_rate, int max_epoch, int mini_batch_size, bool ui = false); real_t score(); - void save(const String &file_name); - - bool is_initialized(); + bool needs_init() const; void initialize(); MLPPSoftmaxReg(const Ref &p_input_set, const Ref &p_output_set, MLPPReg::RegularizationType p_reg = MLPPReg::REGULARIZATION_TYPE_NONE, real_t p_lambda = 0.5, real_t p_alpha = 0.5); @@ -75,12 +82,6 @@ protected: Ref _y_hat; Ref _weights; Ref _bias; - - int _n; - int _k; - int _n_class; - - bool _initialized; }; #endif /* SoftmaxReg_hpp */ diff --git a/test/mlpp_tests.cpp b/test/mlpp_tests.cpp index d163f0e..9ac4152 100644 --- a/test/mlpp_tests.cpp +++ b/test/mlpp_tests.cpp @@ -345,7 +345,7 @@ void MLPPTests::test_softmax_regression(bool ui) { // SOFTMAX REGRESSION MLPPSoftmaxReg model(dt->get_input(), dt->get_output()); - model.sgd(0.1, 10000, ui); + model.train_sgd(0.1, 10000, ui); PLOG_MSG(model.model_set_test(dt->get_input())->to_string()); PLOG_MSG("ACCURACY: " + String::num(100 * model.score()) + "%"); }