Added test for univariate linear regression.

This commit is contained in:
Relintai 2023-01-25 18:36:07 +01:00
parent 9403f8efe2
commit 76d22d9e58
2 changed files with 24 additions and 1 deletions

View File

@ -79,7 +79,7 @@ void MLPPTests::test_statistics() {
is_approx_equalsd(stat.mean(x), 5.5, "Arithmetic Mean");
is_approx_equalsd(stat.mean(x), 5.5, "Median");
is_approx_equals_dvec(dstd_vec_to_vec(x), dstd_vec_to_vec(stat.mode(x)), "stat.mode(x)");
is_approx_equals_dvec(dstd_vec_to_vec(stat.mode(x)), dstd_vec_to_vec(x), "stat.mode(x)");
is_approx_equalsd(stat.range(x), 9, "Range");
is_approx_equalsd(stat.midrange(x), 4.5, "Midrange");
@ -166,6 +166,24 @@ void MLPPTests::test_linear_algebra() {
is_approx_equals_dmat(dstd_mat_to_mat(alg.identity(10)), dstd_mat_to_mat(id_10_res), "alg.identity(10)");
}
void MLPPTests::test_univariate_linear_regression() {
// Univariate, simple linear regression, case where k = 1
MLPPData data;
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_load_fires_and_crime_data_path);
MLPPUniLinReg model(ds->input, ds->output);
std::vector<double> slr_res = {
24.1095, 28.4829, 29.8082, 26.0974, 27.2902, 61.0851, 30.4709, 25.0372, 25.5673, 35.9046,
54.4587, 18.8083, 23.4468, 18.5432, 19.2059, 21.1938, 23.0492, 18.8083, 25.4348, 35.9046,
37.76, 40.278, 63.8683, 68.5068, 40.4106, 46.772, 32.0612, 23.3143, 44.784, 44.519,
27.8203, 20.6637, 22.5191, 53.796, 38.9527, 30.8685, 20.3986
};
is_approx_equals_dvec(dstd_vec_to_vec(model.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)");
}
void MLPPTests::is_approx_equalsd(double a, double b, const String &str) {
if (!Math::is_equal_approx(a, b)) {
ERR_PRINT("TEST FAILED: " + str + " Got: " + String::num(a) + " Should be: " + String::num(b));
@ -266,6 +284,7 @@ IAEDMAT_FAILED:
}
MLPPTests::MLPPTests() {
_load_fires_and_crime_data_path = "res://datasets/FiresAndCrime.csv";
}
MLPPTests::~MLPPTests() {
@ -274,4 +293,5 @@ MLPPTests::~MLPPTests() {
void MLPPTests::_bind_methods() {
ClassDB::bind_method(D_METHOD("test_statistics"), &MLPPTests::test_statistics);
ClassDB::bind_method(D_METHOD("test_linear_algebra"), &MLPPTests::test_linear_algebra);
ClassDB::bind_method(D_METHOD("test_univariate_linear_regression"), &MLPPTests::test_univariate_linear_regression);
}

View File

@ -16,6 +16,7 @@ class MLPPTests : public Reference {
public:
void test_statistics();
void test_linear_algebra();
void test_univariate_linear_regression();
void is_approx_equalsd(double a, double b, const String &str);
void is_approx_equals_dvec(const Vector<double> &a, const Vector<double> &b, const String &str);
@ -26,6 +27,8 @@ public:
protected:
static void _bind_methods();
String _load_fires_and_crime_data_path;
};
#endif