2023-12-30 00:41:59 +01:00
|
|
|
/*************************************************************************/
|
|
|
|
/* uni_lin_reg.cpp */
|
|
|
|
/*************************************************************************/
|
|
|
|
/* This file is part of: */
|
|
|
|
/* PMLPP Machine Learning Library */
|
|
|
|
/* https://github.com/Relintai/pmlpp */
|
|
|
|
/*************************************************************************/
|
2023-12-30 00:43:39 +01:00
|
|
|
/* Copyright (c) 2023-present Péter Magyar. */
|
2023-12-30 00:41:59 +01:00
|
|
|
/* Copyright (c) 2022-2023 Marc Melikyan */
|
|
|
|
/* */
|
|
|
|
/* Permission is hereby granted, free of charge, to any person obtaining */
|
|
|
|
/* a copy of this software and associated documentation files (the */
|
|
|
|
/* "Software"), to deal in the Software without restriction, including */
|
|
|
|
/* without limitation the rights to use, copy, modify, merge, publish, */
|
|
|
|
/* distribute, sublicense, and/or sell copies of the Software, and to */
|
|
|
|
/* permit persons to whom the Software is furnished to do so, subject to */
|
|
|
|
/* the following conditions: */
|
|
|
|
/* */
|
|
|
|
/* The above copyright notice and this permission notice shall be */
|
|
|
|
/* included in all copies or substantial portions of the Software. */
|
|
|
|
/* */
|
|
|
|
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
|
|
|
|
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
|
|
|
|
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
|
|
|
|
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
|
|
|
|
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
|
|
|
|
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
|
|
|
|
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
|
|
|
|
/*************************************************************************/
|
2023-01-23 21:13:26 +01:00
|
|
|
|
2023-01-24 18:12:23 +01:00
|
|
|
#include "uni_lin_reg.h"
|
|
|
|
#include "../stat/stat.h"
|
2023-02-08 12:46:56 +01:00
|
|
|
|
2023-01-23 21:13:26 +01:00
|
|
|
// General Multivariate Linear Regression Model
|
|
|
|
// ŷ = b0 + b1x1 + b2x2 + ... + bkxk
|
|
|
|
|
|
|
|
// Univariate Linear Regression Model
|
|
|
|
// ŷ = b0 + b1x1
|
|
|
|
|
2023-04-28 18:48:47 +02:00
|
|
|
// Returns the reference to the vector currently stored as the model's
// input (x) data. The reference may be null if no input set was assigned.
Ref<MLPPVector> MLPPUniLinReg::get_input_set() const {
	return _input_set;
}
|
|
|
|
// Stores `val` as the model's input (x) data.
// Only the reference is replaced; the coefficients are not recomputed here,
// so train() has to be called again for them to reflect the new data.
void MLPPUniLinReg::set_input_set(const Ref<MLPPVector> &val) {
	_input_set = val;
}
|
|
|
|
|
2023-04-28 18:48:47 +02:00
|
|
|
// Returns the reference to the vector currently stored as the model's
// output (y) data. The reference may be null if no output set was assigned.
Ref<MLPPVector> MLPPUniLinReg::get_output_set() const {
	return _output_set;
}
|
|
|
|
// Stores `val` as the model's output (y) data.
// Only the reference is replaced; the coefficients are not recomputed here,
// so train() has to be called again for them to reflect the new data.
void MLPPUniLinReg::set_output_set(const Ref<MLPPVector> &val) {
	_output_set = val;
}
|
|
|
|
|
2023-04-28 18:48:47 +02:00
|
|
|
// Returns b0, the constant term of the model ŷ = b0 + b1 * x.
real_t MLPPUniLinReg::get_b0() const {
	return _b0;
}
|
2023-04-28 18:48:47 +02:00
|
|
|
// Overrides b0, the constant term of the model ŷ = b0 + b1 * x.
// A later call to train() overwrites the value set here.
void MLPPUniLinReg::set_b0(const real_t val) {
	_b0 = val;
}
|
|
|
|
|
|
|
|
// Returns b1, the coefficient multiplying x in the model ŷ = b0 + b1 * x.
real_t MLPPUniLinReg::get_b1() const {
	return _b1;
}
|
2023-04-28 18:48:47 +02:00
|
|
|
// Overrides b1, the coefficient multiplying x in the model ŷ = b0 + b1 * x.
// A later call to train() overwrites the value set here.
void MLPPUniLinReg::set_b1(const real_t val) {
	_b1 = val;
}
|
2023-02-09 02:27:04 +01:00
|
|
|
|
2023-04-28 19:28:53 +02:00
|
|
|
void MLPPUniLinReg::train() {
|
2023-02-09 02:27:04 +01:00
|
|
|
ERR_FAIL_COND(!_input_set.is_valid() || !_output_set.is_valid());
|
|
|
|
|
2023-02-08 12:46:56 +01:00
|
|
|
MLPPStat estimator;
|
2023-02-09 02:27:04 +01:00
|
|
|
|
|
|
|
_b1 = estimator.b1_estimation(_input_set, _output_set);
|
|
|
|
_b0 = estimator.b0_estimation(_input_set, _output_set);
|
2023-01-24 19:00:54 +01:00
|
|
|
}
|
2023-01-23 21:13:26 +01:00
|
|
|
|
2023-02-09 02:27:04 +01:00
|
|
|
// Evaluates the model ŷ = b0 + b1 * x for every element of `x`.
//
// x: vector of inputs to evaluate. Must be a valid (non-null) reference.
//
// Returns the vector produced by scaling `x` by b1 and then adding b0, or a
// null reference (after reporting an error) when `x` is not valid.
Ref<MLPPVector> MLPPUniLinReg::model_set_test(const Ref<MLPPVector> &x) {
	// This method is bound to scripts, so a null argument is reachable;
	// fail gracefully instead of dereferencing a null Ref.
	ERR_FAIL_COND_V(!x.is_valid(), Ref<MLPPVector>());

	return x->scalar_multiplyn(_b1)->scalar_addn(_b0);
}
|
2023-01-23 21:13:26 +01:00
|
|
|
|
2023-02-09 02:27:04 +01:00
|
|
|
// Evaluates the model for a single input: returns b0 + b1 * x.
real_t MLPPUniLinReg::model_test(real_t x) {
	const real_t scaled_input = _b1 * x;

	return _b0 + scaled_input;
}
|
|
|
|
|
|
|
|
// Constructs a model from the given data and trains it immediately.
//
// p_input_set:  input (x) values.
// p_output_set: output (y) values.
//
// If either set is not a valid reference, train() reports an error and the
// coefficients keep their zero initial values.
MLPPUniLinReg::MLPPUniLinReg(const Ref<MLPPVector> &p_input_set, const Ref<MLPPVector> &p_output_set) {
	_input_set = p_input_set;
	_output_set = p_output_set;

	// Give the coefficients a defined value first: train() returns early on
	// invalid input, which would otherwise leave them uninitialized.
	_b0 = 0;
	_b1 = 0;

	train();
}
|
|
|
|
|
|
|
|
// Constructs an untrained model: both coefficients start at zero and no
// input/output set is assigned by this constructor.
MLPPUniLinReg::MLPPUniLinReg() {
	_b1 = 0;
	_b0 = 0;
}
|
|
|
|
// Nothing to release explicitly; the destructor body is intentionally empty.
MLPPUniLinReg::~MLPPUniLinReg() {}
|
|
|
|
|
|
|
|
// Registers the class's methods and properties with ClassDB so they are
// reachable from the scripting/reflection layer.
void MLPPUniLinReg::_bind_methods() {
	// Training data. The accessors take/return Ref<MLPPVector>, so the
	// resource type hint must name MLPPVector (it previously said MLPPMatrix).
	ClassDB::bind_method(D_METHOD("get_input_set"), &MLPPUniLinReg::get_input_set);
	ClassDB::bind_method(D_METHOD("set_input_set", "val"), &MLPPUniLinReg::set_input_set);
	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "input_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPVector"), "set_input_set", "get_input_set");

	ClassDB::bind_method(D_METHOD("get_output_set"), &MLPPUniLinReg::get_output_set);
	ClassDB::bind_method(D_METHOD("set_output_set", "val"), &MLPPUniLinReg::set_output_set);
	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "output_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPVector"), "set_output_set", "get_output_set");

	// Model coefficients of ŷ = b0 + b1 * x.
	ClassDB::bind_method(D_METHOD("get_b0"), &MLPPUniLinReg::get_b0);
	ClassDB::bind_method(D_METHOD("set_b0", "val"), &MLPPUniLinReg::set_b0);
	ADD_PROPERTY(PropertyInfo(Variant::REAL, "b0"), "set_b0", "get_b0");

	ClassDB::bind_method(D_METHOD("get_b1"), &MLPPUniLinReg::get_b1);
	ClassDB::bind_method(D_METHOD("set_b1", "val"), &MLPPUniLinReg::set_b1);
	ADD_PROPERTY(PropertyInfo(Variant::REAL, "b1"), "set_b1", "get_b1");

	// Training and evaluation entry points.
	ClassDB::bind_method(D_METHOD("train"), &MLPPUniLinReg::train);

	ClassDB::bind_method(D_METHOD("model_set_test", "x"), &MLPPUniLinReg::model_set_test);
	ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPUniLinReg::model_test);
}
|