mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-03-12 22:38:51 +01:00
Added test for univariate linear regression.
This commit is contained in:
parent
9403f8efe2
commit
76d22d9e58
@ -79,7 +79,7 @@ void MLPPTests::test_statistics() {
|
||||
is_approx_equalsd(stat.mean(x), 5.5, "Arithmetic Mean");
|
||||
is_approx_equalsd(stat.mean(x), 5.5, "Median");
|
||||
|
||||
is_approx_equals_dvec(dstd_vec_to_vec(x), dstd_vec_to_vec(stat.mode(x)), "stat.mode(x)");
|
||||
is_approx_equals_dvec(dstd_vec_to_vec(stat.mode(x)), dstd_vec_to_vec(x), "stat.mode(x)");
|
||||
|
||||
is_approx_equalsd(stat.range(x), 9, "Range");
|
||||
is_approx_equalsd(stat.midrange(x), 4.5, "Midrange");
|
||||
@ -166,6 +166,24 @@ void MLPPTests::test_linear_algebra() {
|
||||
is_approx_equals_dmat(dstd_mat_to_mat(alg.identity(10)), dstd_mat_to_mat(id_10_res), "alg.identity(10)");
|
||||
}
|
||||
|
||||
void MLPPTests::test_univariate_linear_regression() {
|
||||
// Univariate, simple linear regression, case where k = 1
|
||||
MLPPData data;
|
||||
|
||||
Ref<MLPPDataESimple> ds = data.load_fires_and_crime(_load_fires_and_crime_data_path);
|
||||
|
||||
MLPPUniLinReg model(ds->input, ds->output);
|
||||
|
||||
std::vector<double> slr_res = {
|
||||
24.1095, 28.4829, 29.8082, 26.0974, 27.2902, 61.0851, 30.4709, 25.0372, 25.5673, 35.9046,
|
||||
54.4587, 18.8083, 23.4468, 18.5432, 19.2059, 21.1938, 23.0492, 18.8083, 25.4348, 35.9046,
|
||||
37.76, 40.278, 63.8683, 68.5068, 40.4106, 46.772, 32.0612, 23.3143, 44.784, 44.519,
|
||||
27.8203, 20.6637, 22.5191, 53.796, 38.9527, 30.8685, 20.3986
|
||||
};
|
||||
|
||||
is_approx_equals_dvec(dstd_vec_to_vec(model.modelSetTest(ds->input)), dstd_vec_to_vec(slr_res), "stat.mode(x)");
|
||||
}
|
||||
|
||||
void MLPPTests::is_approx_equalsd(double a, double b, const String &str) {
|
||||
if (!Math::is_equal_approx(a, b)) {
|
||||
ERR_PRINT("TEST FAILED: " + str + " Got: " + String::num(a) + " Should be: " + String::num(b));
|
||||
@ -266,6 +284,7 @@ IAEDMAT_FAILED:
|
||||
}
|
||||
|
||||
MLPPTests::MLPPTests() {
|
||||
_load_fires_and_crime_data_path = "res://datasets/FiresAndCrime.csv";
|
||||
}
|
||||
|
||||
MLPPTests::~MLPPTests() {
|
||||
@ -274,4 +293,5 @@ MLPPTests::~MLPPTests() {
|
||||
void MLPPTests::_bind_methods() {
|
||||
ClassDB::bind_method(D_METHOD("test_statistics"), &MLPPTests::test_statistics);
|
||||
ClassDB::bind_method(D_METHOD("test_linear_algebra"), &MLPPTests::test_linear_algebra);
|
||||
ClassDB::bind_method(D_METHOD("test_univariate_linear_regression"), &MLPPTests::test_univariate_linear_regression);
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ class MLPPTests : public Reference {
|
||||
public:
|
||||
void test_statistics();
|
||||
void test_linear_algebra();
|
||||
void test_univariate_linear_regression();
|
||||
|
||||
void is_approx_equalsd(double a, double b, const String &str);
|
||||
void is_approx_equals_dvec(const Vector<double> &a, const Vector<double> &b, const String &str);
|
||||
@ -26,6 +27,8 @@ public:
|
||||
|
||||
protected:
|
||||
static void _bind_methods();
|
||||
|
||||
String _load_fires_and_crime_data_path;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user