diff --git a/SCsub b/SCsub index 476d9f0..0781ea2 100644 --- a/SCsub +++ b/SCsub @@ -90,6 +90,7 @@ sources = [ "mlpp/activation/activation_old.cpp", "test/mlpp_tests.cpp", + "test/mlpp_matrix_tests.cpp", ] diff --git a/config.py b/config.py index f4cf6c2..23e9930 100644 --- a/config.py +++ b/config.py @@ -58,6 +58,7 @@ def get_doc_classes(): "MLPPData", "MLPPTests", + "MLPPMatrixTests", ] def get_doc_path(): diff --git a/register_types.cpp b/register_types.cpp index a136907..1c5712c 100644 --- a/register_types.cpp +++ b/register_types.cpp @@ -70,6 +70,7 @@ SOFTWARE. #include "mlpp/wgan/wgan.h" #include "test/mlpp_tests.h" +#include "test/mlpp_matrix_tests.h" void register_pmlpp_types(ModuleRegistrationLevel p_level) { if (p_level == MODULE_REGISTRATION_LEVEL_SCENE) { @@ -125,6 +126,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) { ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_class(); } } diff --git a/test/mlpp_matrix_tests.cpp b/test/mlpp_matrix_tests.cpp new file mode 100644 index 0000000..f7d64a6 --- /dev/null +++ b/test/mlpp_matrix_tests.cpp @@ -0,0 +1,85 @@ + +#include "mlpp_matrix_tests.h" + +//TODO remove +#include + +#include "../mlpp/lin_alg/mlpp_matrix.h" + +void MLPPMatrixTests::test_mlpp_matrix() { + std::vector> A = { + { 1, 0, 0, 0 }, + { 0, 1, 0, 0 }, + { 0, 0, 1, 0 }, + { 0, 0, 0, 1 } + }; + + Ref rmat; + rmat.instance(); + rmat->set_from_std_vectors(A); + + Ref rmat2; + rmat2.instance(); + rmat2->set_from_std_vectors(A); + + is_approx_equals_mat(rmat, rmat2, "set_from_std_vectors test."); + + rmat2->set_from_std_vectors(A); + + is_approx_equals_mat(rmat, rmat2, "re-set_from_std_vectors test."); +} + +void MLPPMatrixTests::test_mlpp_matrix_mul() { + std::vector> A = { + { 1, 2 }, + { 3, 4 }, + { 5, 6 }, + { 7, 8 } + }; + + std::vector> B = { + { 1, 2, 3, 4 }, + { 5, 6, 7, 8 } + }; + + std::vector> C = { + { 11, 14, 17, 20 }, + { 23, 30, 37, 44 }, + { 35, 46, 57, 68 }, + { 47, 62, 77, 92 } + }; + + Ref rmata; + rmata.instance(); + rmata->set_from_std_vectors(A); + + Ref rmatb; + rmatb.instance(); + rmatb->set_from_std_vectors(B); + + Ref rmatc; + rmatc.instance(); + rmatc->set_from_std_vectors(C); + + Ref rmatr1 = rmata->multn(rmatb); + is_approx_equals_mat(rmatr1, rmatc, "Ref rmatr1 = rmata->multn(rmatb);"); + + Ref rmatr2; + rmatr2.instance(); + rmatr2->multb(rmata, rmatb); + is_approx_equals_mat(rmatr2, rmatc, "rmatr2->multb(rmata, rmatb);"); + + rmata->mult(rmatb); + is_approx_equals_mat(rmata, rmatc, "rmata->mult(rmatb);"); +} + +MLPPMatrixTests::MLPPMatrixTests() { +} + +MLPPMatrixTests::~MLPPMatrixTests() { +} + +void MLPPMatrixTests::_bind_methods() { + ClassDB::bind_method(D_METHOD("test_mlpp_matrix"), &MLPPMatrixTests::test_mlpp_matrix); + ClassDB::bind_method(D_METHOD("test_mlpp_matrix_mul"), &MLPPMatrixTests::test_mlpp_matrix_mul); +} diff --git a/test/mlpp_matrix_tests.h b/test/mlpp_matrix_tests.h new file mode 100644 index 0000000..478ece2 --- /dev/null +++ b/test/mlpp_matrix_tests.h @@ -0,0 +1,34 @@ +#ifndef MLPP_MATRIX_TESTS_H +#define MLPP_MATRIX_TESTS_H + +// TODO port this class to use the test module once it's working +// Also don't forget to remove it's bindings + +#include "core/math/math_defs.h" + +#include "core/containers/vector.h" + +#include "core/object/reference.h" + +#include "core/string/ustring.h" + +#include "mlpp_tests.h" + +class MLPPMatrix; +class MLPPVector; + +class MLPPMatrixTests : public MLPPTests { + GDCLASS(MLPPMatrixTests, MLPPTests); + +public: + void test_mlpp_matrix(); + void test_mlpp_matrix_mul(); + + MLPPMatrixTests(); + ~MLPPMatrixTests(); + +protected: + static void _bind_methods(); +}; + +#endif diff --git a/test/mlpp_tests.cpp b/test/mlpp_tests.cpp index ac9174e..85c3ba4 100644 --- a/test/mlpp_tests.cpp +++ b/test/mlpp_tests.cpp @@ -1330,72 +1330,6 @@ void MLPPTests::test_mlpp_vector() { is_approx_equals_vec(rv, rv2, "re-set_from_std_vectors test."); } -void MLPPTests::test_mlpp_matrix() { - std::vector> A = { - { 1, 0, 0, 0 }, - { 0, 1, 0, 0 }, - { 0, 0, 1, 0 }, - { 0, 0, 0, 1 } - }; - - Ref rmat; - rmat.instance(); - rmat->set_from_std_vectors(A); - - Ref rmat2; - rmat2.instance(); - rmat2->set_from_std_vectors(A); - - is_approx_equals_mat(rmat, rmat2, "set_from_std_vectors test."); - - rmat2->set_from_std_vectors(A); - - is_approx_equals_mat(rmat, rmat2, "re-set_from_std_vectors test."); -} - -void MLPPTests::test_mlpp_matrix_mul() { - std::vector> A = { - { 1, 2 }, - { 3, 4 }, - { 5, 6 }, - { 7, 8 } - }; - - std::vector> B = { - { 1, 2, 3, 4 }, - { 5, 6, 7, 8 } - }; - - std::vector> C = { - { 11, 14, 17, 20 }, - { 23, 30, 37, 44 }, - { 35, 46, 57, 68 }, - { 47, 62, 77, 92 } - }; - - Ref rmata; - rmata.instance(); - rmata->set_from_std_vectors(A); - - Ref rmatb; - rmatb.instance(); - rmatb->set_from_std_vectors(B); - - Ref rmatc; - rmatc.instance(); - rmatc->set_from_std_vectors(C); - - Ref rmatr1 = rmata->multn(rmatb); - is_approx_equals_mat(rmatr1, rmatc, "Ref rmatr1 = rmata->multn(rmatb);"); - - Ref rmatr2; - rmatr2.instance(); - rmatr2->multb(rmata, rmatb); - is_approx_equals_mat(rmatr2, rmatc, "rmatr2->multb(rmata, rmatb);"); - - rmata->mult(rmatb); - is_approx_equals_mat(rmata, rmatc, "rmata->mult(rmatb);"); -} void MLPPTests::is_approx_equalsd(real_t a, real_t b, const String &str) { if (!Math::is_equal_approx(a, b)) { @@ -1623,6 +1557,4 @@ void MLPPTests::_bind_methods() { ClassDB::bind_method(D_METHOD("test_support_vector_classification_kernel", "ui"), &MLPPTests::test_support_vector_classification_kernel, false); ClassDB::bind_method(D_METHOD("test_mlpp_vector"), &MLPPTests::test_mlpp_vector); - ClassDB::bind_method(D_METHOD("test_mlpp_matrix"), &MLPPTests::test_mlpp_matrix); - ClassDB::bind_method(D_METHOD("test_mlpp_matrix_mul"), &MLPPTests::test_mlpp_matrix_mul); } diff --git a/test/mlpp_tests.h b/test/mlpp_tests.h index e993274..2e70e19 100644 --- a/test/mlpp_tests.h +++ b/test/mlpp_tests.h @@ -65,8 +65,6 @@ public: void test_support_vector_classification_kernel(bool ui = false); void test_mlpp_vector(); - void test_mlpp_matrix(); - void test_mlpp_matrix_mul(); void is_approx_equalsd(real_t a, real_t b, const String &str); void is_approx_equals_dvec(const Vector &a, const Vector &b, const String &str);