Cleaned up MLPPUniLinReg.

This commit is contained in:
Relintai 2023-04-28 18:48:47 +02:00
parent 2c0e20dd8b
commit aaa236b14c
2 changed files with 29 additions and 17 deletions

View File

@ -15,28 +15,35 @@
// Univariate Linear Regression Model // Univariate Linear Regression Model
// ŷ = b0 + b1x1 // ŷ = b0 + b1x1
Ref<MLPPVector> MLPPUniLinReg::get_input_set() { Ref<MLPPVector> MLPPUniLinReg::get_input_set() const {
return _input_set; return _input_set;
} }
void MLPPUniLinReg::set_input_set(const Ref<MLPPVector> &val) { void MLPPUniLinReg::set_input_set(const Ref<MLPPVector> &val) {
_input_set = val; _input_set = val;
} }
Ref<MLPPVector> MLPPUniLinReg::get_output_set() { Ref<MLPPVector> MLPPUniLinReg::get_output_set() const {
return _output_set; return _output_set;
} }
void MLPPUniLinReg::set_output_set(const Ref<MLPPVector> &val) { void MLPPUniLinReg::set_output_set(const Ref<MLPPVector> &val) {
_output_set = val; _output_set = val;
} }
real_t MLPPUniLinReg::get_b0() { real_t MLPPUniLinReg::get_b0() const {
return _b0; return _b0;
} }
real_t MLPPUniLinReg::get_b1() { void MLPPUniLinReg::set_b0(const real_t val) {
return _b1; _b0 = val;
} }
void MLPPUniLinReg::initialize() { real_t MLPPUniLinReg::get_b1() const {
return _b1;
}
void MLPPUniLinReg::set_b1(const real_t val) {
_b1 = val;
}
void MLPPUniLinReg::fit() {
ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid()); ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid());
MLPPStat estimator; MLPPStat estimator;
@ -59,10 +66,7 @@ MLPPUniLinReg::MLPPUniLinReg(const Ref<MLPPVector> &p_input_set, const Ref<MLPPV
_input_set = p_input_set; _input_set = p_input_set;
_output_set = p_output_set; _output_set = p_output_set;
MLPPStat estimator; fit();
_b1 = estimator.b1_estimation(_input_set, _output_set);
_b0 = estimator.b0_estimation(_input_set, _output_set);
} }
MLPPUniLinReg::MLPPUniLinReg() { MLPPUniLinReg::MLPPUniLinReg() {
@ -82,9 +86,14 @@ void MLPPUniLinReg::_bind_methods() {
ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "output_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "set_output_set", "get_output_set"); ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "output_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "set_output_set", "get_output_set");
ClassDB::bind_method(D_METHOD("get_b0"), &MLPPUniLinReg::get_b0); ClassDB::bind_method(D_METHOD("get_b0"), &MLPPUniLinReg::get_b0);
ClassDB::bind_method(D_METHOD("get_b1"), &MLPPUniLinReg::get_b1); ClassDB::bind_method(D_METHOD("set_b0", "val"), &MLPPUniLinReg::set_b0);
ADD_PROPERTY(PropertyInfo(Variant::REAL, "b0"), "set_b0", "get_b0");
ClassDB::bind_method(D_METHOD("initialize"), &MLPPUniLinReg::initialize); ClassDB::bind_method(D_METHOD("get_b1"), &MLPPUniLinReg::get_b1);
ClassDB::bind_method(D_METHOD("set_b1", "val"), &MLPPUniLinReg::set_b1);
ADD_PROPERTY(PropertyInfo(Variant::REAL, "b1"), "set_b1", "get_b1");
ClassDB::bind_method(D_METHOD("fit"), &MLPPUniLinReg::fit);
ClassDB::bind_method(D_METHOD("model_set_test", "x"), &MLPPUniLinReg::model_set_test); ClassDB::bind_method(D_METHOD("model_set_test", "x"), &MLPPUniLinReg::model_set_test);
ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPUniLinReg::model_test); ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPUniLinReg::model_test);

View File

@ -19,16 +19,19 @@ class MLPPUniLinReg : public Reference {
GDCLASS(MLPPUniLinReg, Reference); GDCLASS(MLPPUniLinReg, Reference);
public: public:
Ref<MLPPVector> get_input_set(); Ref<MLPPVector> get_input_set() const;
void set_input_set(const Ref<MLPPVector> &val); void set_input_set(const Ref<MLPPVector> &val);
Ref<MLPPVector> get_output_set(); Ref<MLPPVector> get_output_set() const;
void set_output_set(const Ref<MLPPVector> &val); void set_output_set(const Ref<MLPPVector> &val);
real_t get_b0(); real_t get_b0() const;
real_t get_b1(); void set_b0(const real_t val);
void initialize(); real_t get_b1() const;
void set_b1(const real_t val);
void fit();
Ref<MLPPVector> model_set_test(const Ref<MLPPVector> &x); Ref<MLPPVector> model_set_test(const Ref<MLPPVector> &x);
real_t model_test(real_t x); real_t model_test(real_t x);