pmlpp/mlpp/uni_lin_reg/uni_lin_reg.cpp

92 lines
2.6 KiB
C++
Raw Normal View History

//
// UniLinReg.cpp
//
// Created by Marc Melikyan on 9/29/20.
//
2023-01-24 18:12:23 +01:00
#include "uni_lin_reg.h"
2023-02-08 12:46:56 +01:00
2023-01-24 18:12:23 +01:00
#include "../lin_alg/lin_alg.h"
#include "../stat/stat.h"
2023-02-08 12:46:56 +01:00
// General multivariate linear regression model:
//   ŷ = b0 + b1*x1 + b2*x2 + ... + bk*xk
// Univariate linear regression model (the one implemented in this file):
//   ŷ = b0 + b1*x
// Returns the reference currently stored as the model's input set.
// The reference is handed back as-is; it may be invalid if none was assigned.
Ref<MLPPVector> MLPPUniLinReg::get_input_set() {
	return _input_set;
}
// Stores `val` as the model's input set.
// Only the reference is replaced here; the coefficients _b0 / _b1 are left
// untouched by this setter.
void MLPPUniLinReg::set_input_set(const Ref<MLPPVector> &val) {
	_input_set = val;
}
// Returns the reference currently stored as the model's output set.
// The reference is handed back as-is; it may be invalid if none was assigned.
Ref<MLPPVector> MLPPUniLinReg::get_output_set() {
	return _output_set;
}
// Stores `val` as the model's output set.
// Only the reference is replaced here; the coefficients _b0 / _b1 are left
// untouched by this setter.
void MLPPUniLinReg::set_output_set(const Ref<MLPPVector> &val) {
	_output_set = val;
}
// Returns the stored b0 coefficient (the constant term of ŷ = b0 + b1*x).
real_t MLPPUniLinReg::get_b0() {
	return _b0;
}
// Returns the stored b1 coefficient (the factor applied to x in ŷ = b0 + b1*x).
real_t MLPPUniLinReg::get_b1() {
	return _b1;
}
void MLPPUniLinReg::initialize() {
ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid());
2023-02-08 12:46:56 +01:00
MLPPStat estimator;
2023-02-09 02:27:04 +01:00
_b1 = estimator.b1_estimation(_input_set, _output_set);
_b0 = estimator.b0_estimation(_input_set, _output_set);
2023-01-24 19:00:54 +01:00
}
2023-02-09 02:27:04 +01:00
// Evaluates the fitted model on a whole vector of inputs.
//
// x: vector of input values.
// Returns the result of scaling `x` by _b1 and then adding _b0, as computed by
// MLPPLinAlg::scalar_multiplynv / scalar_addnv. Returns a null reference
// (and reports an error) when `x` is not a valid reference.
Ref<MLPPVector> MLPPUniLinReg::model_set_test(const Ref<MLPPVector> &x) {
	// Reject a null input up front, mirroring the guard used in initialize().
	ERR_FAIL_COND_V(!x.is_valid(), Ref<MLPPVector>());

	MLPPLinAlg alg;

	return alg.scalar_addnv(_b0, alg.scalar_multiplynv(_b1, x));
}
2023-02-09 02:27:04 +01:00
// Evaluates the model for a single input value.
//
// x: the input value.
// Returns b0 + b1 * x using the currently stored coefficients.
real_t MLPPUniLinReg::model_test(real_t x) {
	return _b1 * x + _b0;
}
// Constructs a model from the given input and output sets and fits it
// immediately.
//
// p_input_set:  reference stored as the input set.
// p_output_set: reference stored as the output set.
//
// The coefficients start at 0 so they hold a defined value even when fitting
// is rejected. Fitting is delegated to initialize(), which reports an error
// and leaves the coefficients at 0 if either reference is invalid.
MLPPUniLinReg::MLPPUniLinReg(const Ref<MLPPVector> &p_input_set, const Ref<MLPPVector> &p_output_set) {
	_input_set = p_input_set;
	_output_set = p_output_set;

	_b0 = 0;
	_b1 = 0;

	initialize();
}
// Default constructor: creates a model with both coefficients set to 0.
// The input and output set references are not assigned here.
MLPPUniLinReg::MLPPUniLinReg() {
	_b1 = 0;
	_b0 = 0;
}
// Destructor: intentionally empty, no explicit cleanup is performed here.
MLPPUniLinReg::~MLPPUniLinReg() {}
// Registers the class's methods and properties with ClassDB.
//
// Exposes the input_set / output_set properties (backed by their getter and
// setter), the coefficient getters, initialize(), and the two prediction
// methods.
void MLPPUniLinReg::_bind_methods() {
	// The accessors take and return Ref<MLPPVector>, so the resource type hint
	// names MLPPVector.
	ClassDB::bind_method(D_METHOD("get_input_set"), &MLPPUniLinReg::get_input_set);
	ClassDB::bind_method(D_METHOD("set_input_set", "val"), &MLPPUniLinReg::set_input_set);
	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "input_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPVector"), "set_input_set", "get_input_set");

	ClassDB::bind_method(D_METHOD("get_output_set"), &MLPPUniLinReg::get_output_set);
	ClassDB::bind_method(D_METHOD("set_output_set", "val"), &MLPPUniLinReg::set_output_set);
	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "output_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPVector"), "set_output_set", "get_output_set");

	// Read-only access to the coefficients.
	ClassDB::bind_method(D_METHOD("get_b0"), &MLPPUniLinReg::get_b0);
	ClassDB::bind_method(D_METHOD("get_b1"), &MLPPUniLinReg::get_b1);

	ClassDB::bind_method(D_METHOD("initialize"), &MLPPUniLinReg::initialize);

	ClassDB::bind_method(D_METHOD("model_set_test", "x"), &MLPPUniLinReg::model_set_test);
	ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPUniLinReg::model_test);
}