Cleaned up MLPPUniLinReg.

This commit is contained in:
Relintai 2023-04-28 18:48:47 +02:00
parent 2c0e20dd8b
commit aaa236b14c
2 changed files with 29 additions and 17 deletions

View File

@ -15,28 +15,35 @@
// Univariate Linear Regression Model
// ŷ = b0 + b1x1
Ref<MLPPVector> MLPPUniLinReg::get_input_set() {
Ref<MLPPVector> MLPPUniLinReg::get_input_set() const {
return _input_set;
}
void MLPPUniLinReg::set_input_set(const Ref<MLPPVector> &val) {
_input_set = val;
}
Ref<MLPPVector> MLPPUniLinReg::get_output_set() {
Ref<MLPPVector> MLPPUniLinReg::get_output_set() const {
return _output_set;
}
void MLPPUniLinReg::set_output_set(const Ref<MLPPVector> &val) {
_output_set = val;
}
real_t MLPPUniLinReg::get_b0() {
real_t MLPPUniLinReg::get_b0() const {
return _b0;
}
real_t MLPPUniLinReg::get_b1() {
return _b1;
void MLPPUniLinReg::set_b0(const real_t val) {
_b0 = val;
}
void MLPPUniLinReg::initialize() {
real_t MLPPUniLinReg::get_b1() const {
return _b1;
}
void MLPPUniLinReg::set_b1(const real_t val) {
_b1 = val;
}
void MLPPUniLinReg::fit() {
ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid());
MLPPStat estimator;
@ -59,10 +66,7 @@ MLPPUniLinReg::MLPPUniLinReg(const Ref<MLPPVector> &p_input_set, const Ref<MLPPV
_input_set = p_input_set;
_output_set = p_output_set;
MLPPStat estimator;
_b1 = estimator.b1_estimation(_input_set, _output_set);
_b0 = estimator.b0_estimation(_input_set, _output_set);
fit();
}
MLPPUniLinReg::MLPPUniLinReg() {
@ -82,9 +86,14 @@ void MLPPUniLinReg::_bind_methods() {
ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "output_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "set_output_set", "get_output_set");
ClassDB::bind_method(D_METHOD("get_b0"), &MLPPUniLinReg::get_b0);
ClassDB::bind_method(D_METHOD("get_b1"), &MLPPUniLinReg::get_b1);
ClassDB::bind_method(D_METHOD("set_b0", "val"), &MLPPUniLinReg::set_b0);
ADD_PROPERTY(PropertyInfo(Variant::REAL, "b0"), "set_b0", "get_b0");
ClassDB::bind_method(D_METHOD("initialize"), &MLPPUniLinReg::initialize);
ClassDB::bind_method(D_METHOD("get_b1"), &MLPPUniLinReg::get_b1);
ClassDB::bind_method(D_METHOD("set_b1", "val"), &MLPPUniLinReg::set_b1);
ADD_PROPERTY(PropertyInfo(Variant::REAL, "b1"), "set_b1", "get_b1");
ClassDB::bind_method(D_METHOD("fit"), &MLPPUniLinReg::fit);
ClassDB::bind_method(D_METHOD("model_set_test", "x"), &MLPPUniLinReg::model_set_test);
ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPUniLinReg::model_test);

View File

@ -19,16 +19,19 @@ class MLPPUniLinReg : public Reference {
GDCLASS(MLPPUniLinReg, Reference);
public:
Ref<MLPPVector> get_input_set();
Ref<MLPPVector> get_input_set() const;
void set_input_set(const Ref<MLPPVector> &val);
Ref<MLPPVector> get_output_set();
Ref<MLPPVector> get_output_set() const;
void set_output_set(const Ref<MLPPVector> &val);
real_t get_b0();
real_t get_b1();
real_t get_b0() const;
void set_b0(const real_t val);
void initialize();
real_t get_b1() const;
void set_b1(const real_t val);
void fit();
Ref<MLPPVector> model_set_test(const Ref<MLPPVector> &x);
real_t model_test(real_t x);