Split the matrix tests into a new subclass.

This commit is contained in:
Relintai 2023-04-26 16:33:20 +02:00
parent e0b813eacf
commit aeb3975450
7 changed files with 123 additions and 70 deletions

1
SCsub
View File

@ -90,6 +90,7 @@ sources = [
"mlpp/activation/activation_old.cpp", "mlpp/activation/activation_old.cpp",
"test/mlpp_tests.cpp", "test/mlpp_tests.cpp",
"test/mlpp_matrix_tests.cpp",
] ]

View File

@ -58,6 +58,7 @@ def get_doc_classes():
"MLPPData", "MLPPData",
"MLPPTests", "MLPPTests",
"MLPPMatrixTests",
] ]
def get_doc_path(): def get_doc_path():

View File

@ -70,6 +70,7 @@ SOFTWARE.
#include "mlpp/wgan/wgan.h" #include "mlpp/wgan/wgan.h"
#include "test/mlpp_tests.h" #include "test/mlpp_tests.h"
#include "test/mlpp_matrix_tests.h"
void register_pmlpp_types(ModuleRegistrationLevel p_level) { void register_pmlpp_types(ModuleRegistrationLevel p_level) {
if (p_level == MODULE_REGISTRATION_LEVEL_SCENE) { if (p_level == MODULE_REGISTRATION_LEVEL_SCENE) {
@ -125,6 +126,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
ClassDB::register_class<MLPPData>(); ClassDB::register_class<MLPPData>();
ClassDB::register_class<MLPPTests>(); ClassDB::register_class<MLPPTests>();
ClassDB::register_class<MLPPMatrixTests>();
} }
} }

View File

@ -0,0 +1,85 @@
#include "mlpp_matrix_tests.h"
//TODO remove
#include <vector>
#include "../mlpp/lin_alg/mlpp_matrix.h"
void MLPPMatrixTests::test_mlpp_matrix() {
std::vector<std::vector<real_t>> A = {
{ 1, 0, 0, 0 },
{ 0, 1, 0, 0 },
{ 0, 0, 1, 0 },
{ 0, 0, 0, 1 }
};
Ref<MLPPMatrix> rmat;
rmat.instance();
rmat->set_from_std_vectors(A);
Ref<MLPPMatrix> rmat2;
rmat2.instance();
rmat2->set_from_std_vectors(A);
is_approx_equals_mat(rmat, rmat2, "set_from_std_vectors test.");
rmat2->set_from_std_vectors(A);
is_approx_equals_mat(rmat, rmat2, "re-set_from_std_vectors test.");
}
void MLPPMatrixTests::test_mlpp_matrix_mul() {
std::vector<std::vector<real_t>> A = {
{ 1, 2 },
{ 3, 4 },
{ 5, 6 },
{ 7, 8 }
};
std::vector<std::vector<real_t>> B = {
{ 1, 2, 3, 4 },
{ 5, 6, 7, 8 }
};
std::vector<std::vector<real_t>> C = {
{ 11, 14, 17, 20 },
{ 23, 30, 37, 44 },
{ 35, 46, 57, 68 },
{ 47, 62, 77, 92 }
};
Ref<MLPPMatrix> rmata;
rmata.instance();
rmata->set_from_std_vectors(A);
Ref<MLPPMatrix> rmatb;
rmatb.instance();
rmatb->set_from_std_vectors(B);
Ref<MLPPMatrix> rmatc;
rmatc.instance();
rmatc->set_from_std_vectors(C);
Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);
is_approx_equals_mat(rmatr1, rmatc, "Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);");
Ref<MLPPMatrix> rmatr2;
rmatr2.instance();
rmatr2->multb(rmata, rmatb);
is_approx_equals_mat(rmatr2, rmatc, "rmatr2->multb(rmata, rmatb);");
rmata->mult(rmatb);
is_approx_equals_mat(rmata, rmatc, "rmata->mult(rmatb);");
}
MLPPMatrixTests::MLPPMatrixTests() {
}
MLPPMatrixTests::~MLPPMatrixTests() {
}
void MLPPMatrixTests::_bind_methods() {
ClassDB::bind_method(D_METHOD("test_mlpp_matrix"), &MLPPMatrixTests::test_mlpp_matrix);
ClassDB::bind_method(D_METHOD("test_mlpp_matrix_mul"), &MLPPMatrixTests::test_mlpp_matrix_mul);
}

34
test/mlpp_matrix_tests.h Normal file
View File

@ -0,0 +1,34 @@
#ifndef MLPP_MATRIX_TESTS_H
#define MLPP_MATRIX_TESTS_H
// TODO port this class to use the test module once it's working
// Also don't forget to remove it's bindings
#include "core/math/math_defs.h"
#include "core/containers/vector.h"
#include "core/object/reference.h"
#include "core/string/ustring.h"
#include "mlpp_tests.h"
class MLPPMatrix;
class MLPPVector;
class MLPPMatrixTests : public MLPPTests {
GDCLASS(MLPPMatrixTests, MLPPTests);
public:
void test_mlpp_matrix();
void test_mlpp_matrix_mul();
MLPPMatrixTests();
~MLPPMatrixTests();
protected:
static void _bind_methods();
};
#endif

View File

@ -1330,72 +1330,6 @@ void MLPPTests::test_mlpp_vector() {
is_approx_equals_vec(rv, rv2, "re-set_from_std_vectors test."); is_approx_equals_vec(rv, rv2, "re-set_from_std_vectors test.");
} }
void MLPPTests::test_mlpp_matrix() {
std::vector<std::vector<real_t>> A = {
{ 1, 0, 0, 0 },
{ 0, 1, 0, 0 },
{ 0, 0, 1, 0 },
{ 0, 0, 0, 1 }
};
Ref<MLPPMatrix> rmat;
rmat.instance();
rmat->set_from_std_vectors(A);
Ref<MLPPMatrix> rmat2;
rmat2.instance();
rmat2->set_from_std_vectors(A);
is_approx_equals_mat(rmat, rmat2, "set_from_std_vectors test.");
rmat2->set_from_std_vectors(A);
is_approx_equals_mat(rmat, rmat2, "re-set_from_std_vectors test.");
}
void MLPPTests::test_mlpp_matrix_mul() {
std::vector<std::vector<real_t>> A = {
{ 1, 2 },
{ 3, 4 },
{ 5, 6 },
{ 7, 8 }
};
std::vector<std::vector<real_t>> B = {
{ 1, 2, 3, 4 },
{ 5, 6, 7, 8 }
};
std::vector<std::vector<real_t>> C = {
{ 11, 14, 17, 20 },
{ 23, 30, 37, 44 },
{ 35, 46, 57, 68 },
{ 47, 62, 77, 92 }
};
Ref<MLPPMatrix> rmata;
rmata.instance();
rmata->set_from_std_vectors(A);
Ref<MLPPMatrix> rmatb;
rmatb.instance();
rmatb->set_from_std_vectors(B);
Ref<MLPPMatrix> rmatc;
rmatc.instance();
rmatc->set_from_std_vectors(C);
Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);
is_approx_equals_mat(rmatr1, rmatc, "Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);");
Ref<MLPPMatrix> rmatr2;
rmatr2.instance();
rmatr2->multb(rmata, rmatb);
is_approx_equals_mat(rmatr2, rmatc, "rmatr2->multb(rmata, rmatb);");
rmata->mult(rmatb);
is_approx_equals_mat(rmata, rmatc, "rmata->mult(rmatb);");
}
void MLPPTests::is_approx_equalsd(real_t a, real_t b, const String &str) { void MLPPTests::is_approx_equalsd(real_t a, real_t b, const String &str) {
if (!Math::is_equal_approx(a, b)) { if (!Math::is_equal_approx(a, b)) {
@ -1623,6 +1557,4 @@ void MLPPTests::_bind_methods() {
ClassDB::bind_method(D_METHOD("test_support_vector_classification_kernel", "ui"), &MLPPTests::test_support_vector_classification_kernel, false); ClassDB::bind_method(D_METHOD("test_support_vector_classification_kernel", "ui"), &MLPPTests::test_support_vector_classification_kernel, false);
ClassDB::bind_method(D_METHOD("test_mlpp_vector"), &MLPPTests::test_mlpp_vector); ClassDB::bind_method(D_METHOD("test_mlpp_vector"), &MLPPTests::test_mlpp_vector);
ClassDB::bind_method(D_METHOD("test_mlpp_matrix"), &MLPPTests::test_mlpp_matrix);
ClassDB::bind_method(D_METHOD("test_mlpp_matrix_mul"), &MLPPTests::test_mlpp_matrix_mul);
} }

View File

@ -65,8 +65,6 @@ public:
void test_support_vector_classification_kernel(bool ui = false); void test_support_vector_classification_kernel(bool ui = false);
void test_mlpp_vector(); void test_mlpp_vector();
void test_mlpp_matrix();
void test_mlpp_matrix_mul();
void is_approx_equalsd(real_t a, real_t b, const String &str); void is_approx_equalsd(real_t a, real_t b, const String &str);
void is_approx_equals_dvec(const Vector<real_t> &a, const Vector<real_t> &b, const String &str); void is_approx_equals_dvec(const Vector<real_t> &a, const Vector<real_t> &b, const String &str);