Cleaned up UniLinReg.

This commit is contained in:
Relintai 2023-02-09 02:27:04 +01:00
parent d4409491d7
commit 8fe8070fce
6 changed files with 160 additions and 22 deletions

View File

@ -22,6 +22,13 @@ real_t MLPPStat::b1Estimation(const std::vector<real_t> &x, const std::vector<re
return covariance(x, y) / variance(x);
}
// Intercept estimate for a univariate linear fit: mean(y) - b1 * mean(x),
// where b1 is the slope returned by b1_estimation(x, y).
real_t MLPPStat::b0_estimation(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y) {
	const real_t y_mean = meanv(y);
	const real_t slope = b1_estimation(x, y);
	const real_t x_mean = meanv(x);

	return y_mean - slope * x_mean;
}
// Slope estimate for a univariate linear fit: covariance(x, y) / variance(x).
real_t MLPPStat::b1_estimation(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y) {
	const real_t cov_xy = covariancev(x, y);
	const real_t var_x = variancev(x);

	return cov_xy / var_x;
}
real_t MLPPStat::mean(const std::vector<real_t> &x) {
real_t sum = 0;
for (int i = 0; i < x.size(); i++) {
@ -126,6 +133,21 @@ real_t MLPPStat::meanv(const Ref<MLPPVector> &x) {
return sum / x_size;
}
// Sample variance of x: the sum of squared deviations from the mean,
// divided by (n - 1).
//
// Fails with an error and returns 0 when x is not a valid reference or holds
// fewer than two elements, because the (n - 1) divisor would then be zero
// or negative.
real_t MLPPStat::variancev(const Ref<MLPPVector> &x) {
	ERR_FAIL_COND_V(!x.is_valid(), 0);

	const int x_size = x->size();
	ERR_FAIL_COND_V(x_size < 2, 0);

	const real_t x_mean = meanv(x);
	const real_t *x_ptr = x->ptr();

	real_t sum = 0;
	for (int i = 0; i < x_size; ++i) {
		const real_t deviation = x_ptr[i] - x_mean;
		sum += deviation * deviation;
	}

	return sum / (x_size - 1);
}
real_t MLPPStat::covariancev(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y) {
ERR_FAIL_COND_V(x->size() != y->size(), 0);

View File

@ -21,6 +21,9 @@ public:
real_t b0Estimation(const std::vector<real_t> &x, const std::vector<real_t> &y);
real_t b1Estimation(const std::vector<real_t> &x, const std::vector<real_t> &y);
real_t b0_estimation(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y);
real_t b1_estimation(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y);
// Statistical Functions
real_t mean(const std::vector<real_t> &x);
real_t median(std::vector<real_t> x);
@ -36,6 +39,7 @@ public:
real_t chebyshevIneq(const real_t k);
real_t meanv(const Ref<MLPPVector> &x);
real_t variancev(const Ref<MLPPVector> &x);
real_t covariancev(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y);
// Extras

View File

@ -9,26 +9,83 @@
#include "../lin_alg/lin_alg.h"
#include "../stat/stat.h"
#include <iostream>
// General Multivariate Linear Regression Model
// ŷ = b0 + b1x1 + b2x2 + ... + bkxk
// Univariate Linear Regression Model
// ŷ = b0 + b1x1
MLPPUniLinReg::MLPPUniLinReg(std::vector<real_t> x, std::vector<real_t> y) :
inputSet(x), outputSet(y) {
// Accessor for the stored input (x) data set reference.
Ref<MLPPVector> MLPPUniLinReg::get_input_set() {
	return this->_input_set;
}
void MLPPUniLinReg::set_input_set(const Ref<MLPPVector> &val) {
_input_set = val;
}
// Accessor for the stored output (y) data set reference.
Ref<MLPPVector> MLPPUniLinReg::get_output_set() {
	return this->_output_set;
}
void MLPPUniLinReg::set_output_set(const Ref<MLPPVector> &val) {
_output_set = val;
}
// Returns the stored intercept coefficient (_b0).
real_t MLPPUniLinReg::get_b0() {
	return this->_b0;
}
// Returns the stored slope coefficient (_b1).
real_t MLPPUniLinReg::get_b1() {
	return this->_b1;
}
// Fits the model: estimates the slope (_b1) and intercept (_b0) from the
// stored input and output sets using MLPPStat.
void MLPPUniLinReg::initialize() {
// Refuse to fit unless both data sets hold a valid reference.
ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid());
MLPPStat estimator;
// NOTE(review): the next two assignments use the std::vector based
// b1Estimation / b0Estimation API and the un-prefixed inputSet / outputSet /
// b0 / b1 names; they look like the pre-change lines of this diff shown
// without their removal markers — confirm before relying on them.
b1 = estimator.b1Estimation(inputSet, outputSet);
b0 = estimator.b0Estimation(inputSet, outputSet);
// Ref<MLPPVector> based estimation: slope first, then intercept.
_b1 = estimator.b1_estimation(_input_set, _output_set);
_b0 = estimator.b0_estimation(_input_set, _output_set);
}
// NOTE(review): the std::vector signature on the next line and the
// scalarAdd / scalarMultiply return further down look like the pre-change
// lines of this diff shown without their removal markers — confirm.
std::vector<real_t> MLPPUniLinReg::modelSetTest(std::vector<real_t> x) {
// Evaluates the fitted line for a whole vector of inputs: multiplies x by _b1
// via MLPPLinAlg::scalar_multiplynv, then adds _b0 via MLPPLinAlg::scalar_addnv.
// Presumably both helpers work element-wise and return a new vector — verify
// against MLPPLinAlg.
Ref<MLPPVector> MLPPUniLinReg::model_set_test(const Ref<MLPPVector> &x) {
MLPPLinAlg alg;
return alg.scalarAdd(b0, alg.scalarMultiply(b1, x));
return alg.scalar_addnv(_b0, alg.scalar_multiplynv(_b1, x));
}
// NOTE(review): the modelTest signature on the next line and the
// `b0 + b1 * input` return look like the pre-change lines of this diff shown
// without their removal markers — confirm.
real_t MLPPUniLinReg::modelTest(real_t input) {
return b0 + b1 * input;
// Evaluates the fitted line for a single input: _b0 + _b1 * x.
real_t MLPPUniLinReg::model_test(real_t x) {
return _b0 + _b1 * x;
}
// Builds a model from the given data sets and fits it immediately.
//
// p_input_set  - reference stored as the input (x) data set.
// p_output_set - reference stored as the output (y) data set.
//
// Fitting is delegated to initialize() so that construction and explicit
// (re-)initialization share one code path, including its validity check on
// both references.
MLPPUniLinReg::MLPPUniLinReg(const Ref<MLPPVector> &p_input_set, const Ref<MLPPVector> &p_output_set) {
	_input_set = p_input_set;
	_output_set = p_output_set;

	// Give the coefficients defined values in case initialize() refuses to
	// fit because one of the references is invalid.
	_b0 = 0;
	_b1 = 0;

	initialize();
}
// Default construction: no data sets attached, both coefficients zeroed.
MLPPUniLinReg::MLPPUniLinReg() :
		_b0(0),
		_b1(0) {
}
// Nothing to release explicitly; the destructor body is intentionally empty.
MLPPUniLinReg::~MLPPUniLinReg() {}
// Registers the scripting API of MLPPUniLinReg with ClassDB: the two data-set
// properties, the read-only coefficient getters, and the fit / predict
// methods.
void MLPPUniLinReg::_bind_methods() {
	// Both data sets are Ref<MLPPVector>, so the resource type hint must name
	// MLPPVector (it previously named MLPPMatrix).
	ClassDB::bind_method(D_METHOD("get_input_set"), &MLPPUniLinReg::get_input_set);
	ClassDB::bind_method(D_METHOD("set_input_set", "val"), &MLPPUniLinReg::set_input_set);
	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "input_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPVector"), "set_input_set", "get_input_set");

	ClassDB::bind_method(D_METHOD("get_output_set"), &MLPPUniLinReg::get_output_set);
	ClassDB::bind_method(D_METHOD("set_output_set", "val"), &MLPPUniLinReg::set_output_set);
	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "output_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPVector"), "set_output_set", "get_output_set");

	// The coefficients are exposed through getters only; no setters are bound.
	ClassDB::bind_method(D_METHOD("get_b0"), &MLPPUniLinReg::get_b0);
	ClassDB::bind_method(D_METHOD("get_b1"), &MLPPUniLinReg::get_b1);

	ClassDB::bind_method(D_METHOD("initialize"), &MLPPUniLinReg::initialize);

	ClassDB::bind_method(D_METHOD("model_set_test", "x"), &MLPPUniLinReg::model_set_test);
	ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPUniLinReg::model_test);
}

View File

@ -10,20 +10,42 @@
#include "core/math/math_defs.h"
#include <vector>
#include "core/object/reference.h"
#include "../lin_alg/mlpp_matrix.h"
#include "../lin_alg/mlpp_vector.h"
// Univariate linear regression model (y = b0 + b1 * x), declared as a
// Reference subclass registered through GDCLASS.
// NOTE(review): this span interleaves the pre-change declarations (plain
// class header, std::vector constructor, modelSetTest / modelTest, inputSet /
// outputSet, b0 / b1, and the `private:` label) with the new ones, shown
// without diff markers — confirm which lines belong to the current header.
class MLPPUniLinReg : public Reference {
GDCLASS(MLPPUniLinReg, Reference);
class MLPPUniLinReg {
public:
MLPPUniLinReg(std::vector<real_t> x, std::vector<real_t> y);
std::vector<real_t> modelSetTest(std::vector<real_t> x);
real_t modelTest(real_t x);
// Input (x) data set accessors.
Ref<MLPPVector> get_input_set();
void set_input_set(const Ref<MLPPVector> &val);
private:
std::vector<real_t> inputSet;
std::vector<real_t> outputSet;
// Output (y) data set accessors.
Ref<MLPPVector> get_output_set();
void set_output_set(const Ref<MLPPVector> &val);
real_t b0;
real_t b1;
// Fitted coefficients: intercept (b0) and slope (b1).
real_t get_b0();
real_t get_b1();
// Estimates the coefficients from the stored input and output sets.
void initialize();
// Evaluate the fitted line for a vector of inputs / for a single input.
Ref<MLPPVector> model_set_test(const Ref<MLPPVector> &x);
real_t model_test(real_t x);
// Stores the two data sets and estimates the coefficients from them.
MLPPUniLinReg(const Ref<MLPPVector> &p_input_set, const Ref<MLPPVector> &p_output_set);
// Leaves the data sets unset and zeroes both coefficients.
MLPPUniLinReg();
~MLPPUniLinReg();
protected:
// Registers the methods and properties of this class with ClassDB.
static void _bind_methods();
// Stored data sets.
Ref<MLPPVector> _input_set;
Ref<MLPPVector> _output_set;
// Intercept and slope of the fitted line.
real_t _b0;
real_t _b1;
};
#endif /* UniLinReg_hpp */

View File

@ -39,6 +39,7 @@ SOFTWARE.
#include "mlpp/kmeans/kmeans.h"
#include "mlpp/knn/knn.h"
#include "mlpp/pca/pca.h"
#include "mlpp/uni_lin_reg/uni_lin_reg.h"
#include "mlpp/wgan/wgan.h"
#include "mlpp/mlp/mlp.h"
@ -65,6 +66,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
ClassDB::register_class<MLPPMLP>();
ClassDB::register_class<MLPPWGAN>();
ClassDB::register_class<MLPPPCA>();
ClassDB::register_class<MLPPUniLinReg>();
ClassDB::register_class<MLPPDataESimple>();
ClassDB::register_class<MLPPDataSimple>();

View File

@ -49,6 +49,7 @@
#include "../mlpp/mlp/mlp_old.h"
#include "../mlpp/pca/pca_old.h"
#include "../mlpp/uni_lin_reg/uni_lin_reg_old.h"
#include "../mlpp/wgan/wgan_old.h"
Vector<real_t> dstd_vec_to_vec(const std::vector<real_t> &in) {
@ -181,7 +182,7 @@ void MLPPTests::test_univariate_linear_regression() {
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_fires_and_crime_data_path);
MLPPUniLinReg model(ds->input, ds->output);
MLPPUniLinRegOld model_old(ds->input, ds->output);
std::vector<real_t> slr_res = {
24.1095, 28.4829, 29.8082, 26.0974, 27.2902, 61.0851, 30.4709, 25.0372, 25.5673, 35.9046,
@ -190,7 +191,37 @@ void MLPPTests::test_univariate_linear_regression() {
27.8203, 20.6637, 22.5191, 53.796, 38.9527, 30.8685, 20.3986
};
is_approx_equals_dvec(dstd_vec_to_vec(model.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)");
is_approx_equals_dvec(dstd_vec_to_vec(model_old.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)");
Ref<MLPPVector> input;
input.instance();
input->set_from_std_vector(ds->input);
Ref<MLPPVector> output;
output.instance();
output->set_from_std_vector(ds->output);
MLPPUniLinReg model(input, output);
std::vector<real_t> slr_res_n = {
24.109467, 28.482935, 29.808228, 26.097408, 27.290173, 61.085152, 30.470875, 25.037172, 25.567291,
35.904579, 54.458687, 18.808294, 23.446819, 18.543236, 19.205883, 21.193821, 23.049232, 18.808294,
25.434761, 35.904579, 37.759987, 40.278046, 63.868271, 68.50679, 40.410576, 46.77198, 32.061226,
23.314291, 44.784042, 44.518982, 27.82029, 20.663704, 22.519115, 53.796036, 38.952751,
30.868464, 20.398645
};
Ref<MLPPVector> slr_res_v;
slr_res_v.instance();
slr_res_v->set_from_std_vector(slr_res_n);
Ref<MLPPVector> res = model.model_set_test(input);
if (!slr_res_v->is_equal_approx(res)) {
ERR_PRINT("!slr_res_v->is_equal_approx(res)");
ERR_PRINT(res->to_string());
ERR_PRINT(slr_res_v->to_string());
}
}
void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) {