mirror of
https://github.com/Relintai/pmlpp.git
synced 2024-12-31 16:17:10 +01:00
Cleaned up UniLinReg.
This commit is contained in:
parent
d4409491d7
commit
8fe8070fce
@ -22,6 +22,13 @@ real_t MLPPStat::b1Estimation(const std::vector<real_t> &x, const std::vector<re
|
||||
return covariance(x, y) / variance(x);
|
||||
}
|
||||
|
||||
real_t MLPPStat::b0_estimation(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y) {
|
||||
return meanv(y) - b1_estimation(x, y) * meanv(x);
|
||||
}
|
||||
real_t MLPPStat::b1_estimation(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y) {
|
||||
return covariancev(x, y) / variancev(x);
|
||||
}
|
||||
|
||||
real_t MLPPStat::mean(const std::vector<real_t> &x) {
|
||||
real_t sum = 0;
|
||||
for (int i = 0; i < x.size(); i++) {
|
||||
@ -126,6 +133,21 @@ real_t MLPPStat::meanv(const Ref<MLPPVector> &x) {
|
||||
return sum / x_size;
|
||||
}
|
||||
|
||||
real_t MLPPStat::variancev(const Ref<MLPPVector> &x) {
|
||||
real_t x_mean = meanv(x);
|
||||
|
||||
int x_size = x->size();
|
||||
const real_t *x_ptr = x->ptr();
|
||||
|
||||
real_t sum = 0;
|
||||
for (int i = 0; i < x_size; ++i) {
|
||||
real_t xi = x_ptr[i];
|
||||
|
||||
sum += (xi - x_mean) * (xi - x_mean);
|
||||
}
|
||||
return sum / (x_size - 1);
|
||||
}
|
||||
|
||||
real_t MLPPStat::covariancev(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y) {
|
||||
ERR_FAIL_COND_V(x->size() != y->size(), 0);
|
||||
|
||||
|
@ -21,6 +21,9 @@ public:
|
||||
real_t b0Estimation(const std::vector<real_t> &x, const std::vector<real_t> &y);
|
||||
real_t b1Estimation(const std::vector<real_t> &x, const std::vector<real_t> &y);
|
||||
|
||||
real_t b0_estimation(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y);
|
||||
real_t b1_estimation(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y);
|
||||
|
||||
// Statistical Functions
|
||||
real_t mean(const std::vector<real_t> &x);
|
||||
real_t median(std::vector<real_t> x);
|
||||
@ -36,6 +39,7 @@ public:
|
||||
real_t chebyshevIneq(const real_t k);
|
||||
|
||||
real_t meanv(const Ref<MLPPVector> &x);
|
||||
real_t variancev(const Ref<MLPPVector> &x);
|
||||
real_t covariancev(const Ref<MLPPVector> &x, const Ref<MLPPVector> &y);
|
||||
|
||||
// Extras
|
||||
|
@ -9,26 +9,83 @@
|
||||
#include "../lin_alg/lin_alg.h"
|
||||
#include "../stat/stat.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
// General Multivariate Linear Regression Model
|
||||
// ŷ = b0 + b1x1 + b2x2 + ... + bkxk
|
||||
|
||||
// Univariate Linear Regression Model
|
||||
// ŷ = b0 + b1x1
|
||||
|
||||
MLPPUniLinReg::MLPPUniLinReg(std::vector<real_t> x, std::vector<real_t> y) :
|
||||
inputSet(x), outputSet(y) {
|
||||
Ref<MLPPVector> MLPPUniLinReg::get_input_set() {
|
||||
return _input_set;
|
||||
}
|
||||
void MLPPUniLinReg::set_input_set(const Ref<MLPPVector> &val) {
|
||||
_input_set = val;
|
||||
}
|
||||
|
||||
Ref<MLPPVector> MLPPUniLinReg::get_output_set() {
|
||||
return _output_set;
|
||||
}
|
||||
void MLPPUniLinReg::set_output_set(const Ref<MLPPVector> &val) {
|
||||
_output_set = val;
|
||||
}
|
||||
|
||||
real_t MLPPUniLinReg::get_b0() {
|
||||
return _b0;
|
||||
}
|
||||
real_t MLPPUniLinReg::get_b1() {
|
||||
return _b1;
|
||||
}
|
||||
|
||||
void MLPPUniLinReg::initialize() {
|
||||
ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid());
|
||||
|
||||
MLPPStat estimator;
|
||||
b1 = estimator.b1Estimation(inputSet, outputSet);
|
||||
b0 = estimator.b0Estimation(inputSet, outputSet);
|
||||
|
||||
_b1 = estimator.b1_estimation(_input_set, _output_set);
|
||||
_b0 = estimator.b0_estimation(_input_set, _output_set);
|
||||
}
|
||||
|
||||
std::vector<real_t> MLPPUniLinReg::modelSetTest(std::vector<real_t> x) {
|
||||
Ref<MLPPVector> MLPPUniLinReg::model_set_test(const Ref<MLPPVector> &x) {
|
||||
MLPPLinAlg alg;
|
||||
return alg.scalarAdd(b0, alg.scalarMultiply(b1, x));
|
||||
|
||||
return alg.scalar_addnv(_b0, alg.scalar_multiplynv(_b1, x));
|
||||
}
|
||||
|
||||
real_t MLPPUniLinReg::modelTest(real_t input) {
|
||||
return b0 + b1 * input;
|
||||
real_t MLPPUniLinReg::model_test(real_t x) {
|
||||
return _b0 + _b1 * x;
|
||||
}
|
||||
|
||||
MLPPUniLinReg::MLPPUniLinReg(const Ref<MLPPVector> &p_input_set, const Ref<MLPPVector> &p_output_set) {
|
||||
_input_set = p_input_set;
|
||||
_output_set = p_output_set;
|
||||
|
||||
MLPPStat estimator;
|
||||
|
||||
_b1 = estimator.b1_estimation(_input_set, _output_set);
|
||||
_b0 = estimator.b0_estimation(_input_set, _output_set);
|
||||
}
|
||||
|
||||
MLPPUniLinReg::MLPPUniLinReg() {
|
||||
_b0 = 0;
|
||||
_b1 = 0;
|
||||
}
|
||||
MLPPUniLinReg::~MLPPUniLinReg() {
|
||||
}
|
||||
|
||||
void MLPPUniLinReg::_bind_methods() {
|
||||
ClassDB::bind_method(D_METHOD("get_input_set"), &MLPPUniLinReg::get_input_set);
|
||||
ClassDB::bind_method(D_METHOD("set_input_set", "val"), &MLPPUniLinReg::set_input_set);
|
||||
ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "input_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "set_input_set", "get_input_set");
|
||||
|
||||
ClassDB::bind_method(D_METHOD("get_output_set"), &MLPPUniLinReg::get_output_set);
|
||||
ClassDB::bind_method(D_METHOD("set_output_set", "val"), &MLPPUniLinReg::set_output_set);
|
||||
ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "output_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "set_output_set", "get_output_set");
|
||||
|
||||
ClassDB::bind_method(D_METHOD("get_b0"), &MLPPUniLinReg::get_b0);
|
||||
ClassDB::bind_method(D_METHOD("get_b1"), &MLPPUniLinReg::get_b1);
|
||||
|
||||
ClassDB::bind_method(D_METHOD("initialize"), &MLPPUniLinReg::initialize);
|
||||
|
||||
ClassDB::bind_method(D_METHOD("model_set_test", "x"), &MLPPUniLinReg::model_set_test);
|
||||
ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPUniLinReg::model_test);
|
||||
}
|
||||
|
@ -10,20 +10,42 @@
|
||||
|
||||
#include "core/math/math_defs.h"
|
||||
|
||||
#include <vector>
|
||||
#include "core/object/reference.h"
|
||||
|
||||
#include "../lin_alg/mlpp_matrix.h"
|
||||
#include "../lin_alg/mlpp_vector.h"
|
||||
|
||||
class MLPPUniLinReg : public Reference {
|
||||
GDCLASS(MLPPUniLinReg, Reference);
|
||||
|
||||
class MLPPUniLinReg {
|
||||
public:
|
||||
MLPPUniLinReg(std::vector<real_t> x, std::vector<real_t> y);
|
||||
std::vector<real_t> modelSetTest(std::vector<real_t> x);
|
||||
real_t modelTest(real_t x);
|
||||
Ref<MLPPVector> get_input_set();
|
||||
void set_input_set(const Ref<MLPPVector> &val);
|
||||
|
||||
private:
|
||||
std::vector<real_t> inputSet;
|
||||
std::vector<real_t> outputSet;
|
||||
Ref<MLPPVector> get_output_set();
|
||||
void set_output_set(const Ref<MLPPVector> &val);
|
||||
|
||||
real_t b0;
|
||||
real_t b1;
|
||||
real_t get_b0();
|
||||
real_t get_b1();
|
||||
|
||||
void initialize();
|
||||
|
||||
Ref<MLPPVector> model_set_test(const Ref<MLPPVector> &x);
|
||||
real_t model_test(real_t x);
|
||||
|
||||
MLPPUniLinReg(const Ref<MLPPVector> &p_input_set, const Ref<MLPPVector> &p_output_set);
|
||||
|
||||
MLPPUniLinReg();
|
||||
~MLPPUniLinReg();
|
||||
|
||||
protected:
|
||||
static void _bind_methods();
|
||||
|
||||
Ref<MLPPVector> _input_set;
|
||||
Ref<MLPPVector> _output_set;
|
||||
|
||||
real_t _b0;
|
||||
real_t _b1;
|
||||
};
|
||||
|
||||
#endif /* UniLinReg_hpp */
|
||||
|
@ -39,6 +39,7 @@ SOFTWARE.
|
||||
#include "mlpp/kmeans/kmeans.h"
|
||||
#include "mlpp/knn/knn.h"
|
||||
#include "mlpp/pca/pca.h"
|
||||
#include "mlpp/uni_lin_reg/uni_lin_reg.h"
|
||||
#include "mlpp/wgan/wgan.h"
|
||||
|
||||
#include "mlpp/mlp/mlp.h"
|
||||
@ -65,6 +66,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
|
||||
ClassDB::register_class<MLPPMLP>();
|
||||
ClassDB::register_class<MLPPWGAN>();
|
||||
ClassDB::register_class<MLPPPCA>();
|
||||
ClassDB::register_class<MLPPUniLinReg>();
|
||||
|
||||
ClassDB::register_class<MLPPDataESimple>();
|
||||
ClassDB::register_class<MLPPDataSimple>();
|
||||
|
@ -49,6 +49,7 @@
|
||||
|
||||
#include "../mlpp/mlp/mlp_old.h"
|
||||
#include "../mlpp/pca/pca_old.h"
|
||||
#include "../mlpp/uni_lin_reg/uni_lin_reg_old.h"
|
||||
#include "../mlpp/wgan/wgan_old.h"
|
||||
|
||||
Vector<real_t> dstd_vec_to_vec(const std::vector<real_t> &in) {
|
||||
@ -181,7 +182,7 @@ void MLPPTests::test_univariate_linear_regression() {
|
||||
|
||||
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_fires_and_crime_data_path);
|
||||
|
||||
MLPPUniLinReg model(ds->input, ds->output);
|
||||
MLPPUniLinRegOld model_old(ds->input, ds->output);
|
||||
|
||||
std::vector<real_t> slr_res = {
|
||||
24.1095, 28.4829, 29.8082, 26.0974, 27.2902, 61.0851, 30.4709, 25.0372, 25.5673, 35.9046,
|
||||
@ -190,7 +191,37 @@ void MLPPTests::test_univariate_linear_regression() {
|
||||
27.8203, 20.6637, 22.5191, 53.796, 38.9527, 30.8685, 20.3986
|
||||
};
|
||||
|
||||
is_approx_equals_dvec(dstd_vec_to_vec(model.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)");
|
||||
is_approx_equals_dvec(dstd_vec_to_vec(model_old.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)");
|
||||
|
||||
Ref<MLPPVector> input;
|
||||
input.instance();
|
||||
input->set_from_std_vector(ds->input);
|
||||
|
||||
Ref<MLPPVector> output;
|
||||
output.instance();
|
||||
output->set_from_std_vector(ds->output);
|
||||
|
||||
MLPPUniLinReg model(input, output);
|
||||
|
||||
std::vector<real_t> slr_res_n = {
|
||||
24.109467, 28.482935, 29.808228, 26.097408, 27.290173, 61.085152, 30.470875, 25.037172, 25.567291,
|
||||
35.904579, 54.458687, 18.808294, 23.446819, 18.543236, 19.205883, 21.193821, 23.049232, 18.808294,
|
||||
25.434761, 35.904579, 37.759987, 40.278046, 63.868271, 68.50679, 40.410576, 46.77198, 32.061226,
|
||||
23.314291, 44.784042, 44.518982, 27.82029, 20.663704, 22.519115, 53.796036, 38.952751,
|
||||
30.868464, 20.398645
|
||||
};
|
||||
|
||||
Ref<MLPPVector> slr_res_v;
|
||||
slr_res_v.instance();
|
||||
slr_res_v->set_from_std_vector(slr_res_n);
|
||||
|
||||
Ref<MLPPVector> res = model.model_set_test(input);
|
||||
|
||||
if (!slr_res_v->is_equal_approx(res)) {
|
||||
ERR_PRINT("!slr_res_v->is_equal_approx(res)");
|
||||
ERR_PRINT(res->to_string());
|
||||
ERR_PRINT(slr_res_v->to_string());
|
||||
}
|
||||
}
|
||||
|
||||
void MLPPTests::test_multivariate_linear_regression_gradient_descent(bool ui) {
|
||||
|
Loading…
Reference in New Issue
Block a user