Tweaks to MLPPUniLinReg.

This commit is contained in:
Relintai 2023-04-28 19:28:53 +02:00
parent 28b7007bb7
commit 63d8dbf676
2 changed files with 7 additions and 7 deletions

View File

@ -43,7 +43,7 @@ void MLPPUniLinReg::set_b1(const real_t val) {
_b1 = val; _b1 = val;
} }
void MLPPUniLinReg::fit() { void MLPPUniLinReg::train() {
ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid()); ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid());
MLPPStat estimator; MLPPStat estimator;
@ -66,7 +66,7 @@ MLPPUniLinReg::MLPPUniLinReg(const Ref<MLPPVector> &p_input_set, const Ref<MLPPV
_input_set = p_input_set; _input_set = p_input_set;
_output_set = p_output_set; _output_set = p_output_set;
fit(); train();
} }
MLPPUniLinReg::MLPPUniLinReg() { MLPPUniLinReg::MLPPUniLinReg() {
@ -93,7 +93,7 @@ void MLPPUniLinReg::_bind_methods() {
ClassDB::bind_method(D_METHOD("set_b1", "val"), &MLPPUniLinReg::set_b1); ClassDB::bind_method(D_METHOD("set_b1", "val"), &MLPPUniLinReg::set_b1);
ADD_PROPERTY(PropertyInfo(Variant::REAL, "b1"), "set_b1", "get_b1"); ADD_PROPERTY(PropertyInfo(Variant::REAL, "b1"), "set_b1", "get_b1");
ClassDB::bind_method(D_METHOD("fit"), &MLPPUniLinReg::fit); ClassDB::bind_method(D_METHOD("train"), &MLPPUniLinReg::train);
ClassDB::bind_method(D_METHOD("model_set_test", "x"), &MLPPUniLinReg::model_set_test); ClassDB::bind_method(D_METHOD("model_set_test", "x"), &MLPPUniLinReg::model_set_test);
ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPUniLinReg::model_test); ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPUniLinReg::model_test);

View File

@ -10,13 +10,13 @@
#include "core/math/math_defs.h" #include "core/math/math_defs.h"
#include "core/object/reference.h" #include "core/object/resource.h"
#include "../lin_alg/mlpp_matrix.h" #include "../lin_alg/mlpp_matrix.h"
#include "../lin_alg/mlpp_vector.h" #include "../lin_alg/mlpp_vector.h"
class MLPPUniLinReg : public Reference { class MLPPUniLinReg : public Resource {
GDCLASS(MLPPUniLinReg, Reference); GDCLASS(MLPPUniLinReg, Resource);
public: public:
Ref<MLPPVector> get_input_set() const; Ref<MLPPVector> get_input_set() const;
@ -31,7 +31,7 @@ public:
real_t get_b1() const; real_t get_b1() const;
void set_b1(const real_t val); void set_b1(const real_t val);
void fit(); void train();
Ref<MLPPVector> model_set_test(const Ref<MLPPVector> &x); Ref<MLPPVector> model_set_test(const Ref<MLPPVector> &x);
real_t model_test(real_t x); real_t model_test(real_t x);