mirror of
https://github.com/Relintai/pmlpp.git
synced 2024-12-22 15:06:47 +01:00
Split the matrix tests into a new subclass.
This commit is contained in:
parent
e0b813eacf
commit
aeb3975450
1
SCsub
1
SCsub
@ -90,6 +90,7 @@ sources = [
|
||||
"mlpp/activation/activation_old.cpp",
|
||||
|
||||
"test/mlpp_tests.cpp",
|
||||
"test/mlpp_matrix_tests.cpp",
|
||||
]
|
||||
|
||||
|
||||
|
@ -58,6 +58,7 @@ def get_doc_classes():
|
||||
"MLPPData",
|
||||
|
||||
"MLPPTests",
|
||||
"MLPPMatrixTests",
|
||||
]
|
||||
|
||||
def get_doc_path():
|
||||
|
@ -70,6 +70,7 @@ SOFTWARE.
|
||||
#include "mlpp/wgan/wgan.h"
|
||||
|
||||
#include "test/mlpp_tests.h"
|
||||
#include "test/mlpp_matrix_tests.h"
|
||||
|
||||
void register_pmlpp_types(ModuleRegistrationLevel p_level) {
|
||||
if (p_level == MODULE_REGISTRATION_LEVEL_SCENE) {
|
||||
@ -125,6 +126,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
|
||||
ClassDB::register_class<MLPPData>();
|
||||
|
||||
ClassDB::register_class<MLPPTests>();
|
||||
ClassDB::register_class<MLPPMatrixTests>();
|
||||
}
|
||||
}
|
||||
|
||||
|
85
test/mlpp_matrix_tests.cpp
Normal file
85
test/mlpp_matrix_tests.cpp
Normal file
@ -0,0 +1,85 @@
|
||||
|
||||
#include "mlpp_matrix_tests.h"
|
||||
|
||||
//TODO remove
|
||||
#include <vector>
|
||||
|
||||
#include "../mlpp/lin_alg/mlpp_matrix.h"
|
||||
|
||||
void MLPPMatrixTests::test_mlpp_matrix() {
|
||||
std::vector<std::vector<real_t>> A = {
|
||||
{ 1, 0, 0, 0 },
|
||||
{ 0, 1, 0, 0 },
|
||||
{ 0, 0, 1, 0 },
|
||||
{ 0, 0, 0, 1 }
|
||||
};
|
||||
|
||||
Ref<MLPPMatrix> rmat;
|
||||
rmat.instance();
|
||||
rmat->set_from_std_vectors(A);
|
||||
|
||||
Ref<MLPPMatrix> rmat2;
|
||||
rmat2.instance();
|
||||
rmat2->set_from_std_vectors(A);
|
||||
|
||||
is_approx_equals_mat(rmat, rmat2, "set_from_std_vectors test.");
|
||||
|
||||
rmat2->set_from_std_vectors(A);
|
||||
|
||||
is_approx_equals_mat(rmat, rmat2, "re-set_from_std_vectors test.");
|
||||
}
|
||||
|
||||
void MLPPMatrixTests::test_mlpp_matrix_mul() {
|
||||
std::vector<std::vector<real_t>> A = {
|
||||
{ 1, 2 },
|
||||
{ 3, 4 },
|
||||
{ 5, 6 },
|
||||
{ 7, 8 }
|
||||
};
|
||||
|
||||
std::vector<std::vector<real_t>> B = {
|
||||
{ 1, 2, 3, 4 },
|
||||
{ 5, 6, 7, 8 }
|
||||
};
|
||||
|
||||
std::vector<std::vector<real_t>> C = {
|
||||
{ 11, 14, 17, 20 },
|
||||
{ 23, 30, 37, 44 },
|
||||
{ 35, 46, 57, 68 },
|
||||
{ 47, 62, 77, 92 }
|
||||
};
|
||||
|
||||
Ref<MLPPMatrix> rmata;
|
||||
rmata.instance();
|
||||
rmata->set_from_std_vectors(A);
|
||||
|
||||
Ref<MLPPMatrix> rmatb;
|
||||
rmatb.instance();
|
||||
rmatb->set_from_std_vectors(B);
|
||||
|
||||
Ref<MLPPMatrix> rmatc;
|
||||
rmatc.instance();
|
||||
rmatc->set_from_std_vectors(C);
|
||||
|
||||
Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);
|
||||
is_approx_equals_mat(rmatr1, rmatc, "Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);");
|
||||
|
||||
Ref<MLPPMatrix> rmatr2;
|
||||
rmatr2.instance();
|
||||
rmatr2->multb(rmata, rmatb);
|
||||
is_approx_equals_mat(rmatr2, rmatc, "rmatr2->multb(rmata, rmatb);");
|
||||
|
||||
rmata->mult(rmatb);
|
||||
is_approx_equals_mat(rmata, rmatc, "rmata->mult(rmatb);");
|
||||
}
|
||||
|
||||
MLPPMatrixTests::MLPPMatrixTests() {
|
||||
}
|
||||
|
||||
MLPPMatrixTests::~MLPPMatrixTests() {
|
||||
}
|
||||
|
||||
void MLPPMatrixTests::_bind_methods() {
|
||||
ClassDB::bind_method(D_METHOD("test_mlpp_matrix"), &MLPPMatrixTests::test_mlpp_matrix);
|
||||
ClassDB::bind_method(D_METHOD("test_mlpp_matrix_mul"), &MLPPMatrixTests::test_mlpp_matrix_mul);
|
||||
}
|
34
test/mlpp_matrix_tests.h
Normal file
34
test/mlpp_matrix_tests.h
Normal file
@ -0,0 +1,34 @@
|
||||
#ifndef MLPP_MATRIX_TESTS_H
|
||||
#define MLPP_MATRIX_TESTS_H
|
||||
|
||||
// TODO port this class to use the test module once it's working
|
||||
// Also don't forget to remove it's bindings
|
||||
|
||||
#include "core/math/math_defs.h"
|
||||
|
||||
#include "core/containers/vector.h"
|
||||
|
||||
#include "core/object/reference.h"
|
||||
|
||||
#include "core/string/ustring.h"
|
||||
|
||||
#include "mlpp_tests.h"
|
||||
|
||||
class MLPPMatrix;
|
||||
class MLPPVector;
|
||||
|
||||
class MLPPMatrixTests : public MLPPTests {
|
||||
GDCLASS(MLPPMatrixTests, MLPPTests);
|
||||
|
||||
public:
|
||||
void test_mlpp_matrix();
|
||||
void test_mlpp_matrix_mul();
|
||||
|
||||
MLPPMatrixTests();
|
||||
~MLPPMatrixTests();
|
||||
|
||||
protected:
|
||||
static void _bind_methods();
|
||||
};
|
||||
|
||||
#endif
|
@ -1330,72 +1330,6 @@ void MLPPTests::test_mlpp_vector() {
|
||||
|
||||
is_approx_equals_vec(rv, rv2, "re-set_from_std_vectors test.");
|
||||
}
|
||||
void MLPPTests::test_mlpp_matrix() {
|
||||
std::vector<std::vector<real_t>> A = {
|
||||
{ 1, 0, 0, 0 },
|
||||
{ 0, 1, 0, 0 },
|
||||
{ 0, 0, 1, 0 },
|
||||
{ 0, 0, 0, 1 }
|
||||
};
|
||||
|
||||
Ref<MLPPMatrix> rmat;
|
||||
rmat.instance();
|
||||
rmat->set_from_std_vectors(A);
|
||||
|
||||
Ref<MLPPMatrix> rmat2;
|
||||
rmat2.instance();
|
||||
rmat2->set_from_std_vectors(A);
|
||||
|
||||
is_approx_equals_mat(rmat, rmat2, "set_from_std_vectors test.");
|
||||
|
||||
rmat2->set_from_std_vectors(A);
|
||||
|
||||
is_approx_equals_mat(rmat, rmat2, "re-set_from_std_vectors test.");
|
||||
}
|
||||
|
||||
void MLPPTests::test_mlpp_matrix_mul() {
|
||||
std::vector<std::vector<real_t>> A = {
|
||||
{ 1, 2 },
|
||||
{ 3, 4 },
|
||||
{ 5, 6 },
|
||||
{ 7, 8 }
|
||||
};
|
||||
|
||||
std::vector<std::vector<real_t>> B = {
|
||||
{ 1, 2, 3, 4 },
|
||||
{ 5, 6, 7, 8 }
|
||||
};
|
||||
|
||||
std::vector<std::vector<real_t>> C = {
|
||||
{ 11, 14, 17, 20 },
|
||||
{ 23, 30, 37, 44 },
|
||||
{ 35, 46, 57, 68 },
|
||||
{ 47, 62, 77, 92 }
|
||||
};
|
||||
|
||||
Ref<MLPPMatrix> rmata;
|
||||
rmata.instance();
|
||||
rmata->set_from_std_vectors(A);
|
||||
|
||||
Ref<MLPPMatrix> rmatb;
|
||||
rmatb.instance();
|
||||
rmatb->set_from_std_vectors(B);
|
||||
|
||||
Ref<MLPPMatrix> rmatc;
|
||||
rmatc.instance();
|
||||
rmatc->set_from_std_vectors(C);
|
||||
|
||||
Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);
|
||||
is_approx_equals_mat(rmatr1, rmatc, "Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);");
|
||||
|
||||
Ref<MLPPMatrix> rmatr2;
|
||||
rmatr2.instance();
|
||||
rmatr2->multb(rmata, rmatb);
|
||||
is_approx_equals_mat(rmatr2, rmatc, "rmatr2->multb(rmata, rmatb);");
|
||||
|
||||
rmata->mult(rmatb);
|
||||
is_approx_equals_mat(rmata, rmatc, "rmata->mult(rmatb);");
|
||||
}
|
||||
|
||||
void MLPPTests::is_approx_equalsd(real_t a, real_t b, const String &str) {
|
||||
if (!Math::is_equal_approx(a, b)) {
|
||||
@ -1623,6 +1557,4 @@ void MLPPTests::_bind_methods() {
|
||||
ClassDB::bind_method(D_METHOD("test_support_vector_classification_kernel", "ui"), &MLPPTests::test_support_vector_classification_kernel, false);
|
||||
|
||||
ClassDB::bind_method(D_METHOD("test_mlpp_vector"), &MLPPTests::test_mlpp_vector);
|
||||
ClassDB::bind_method(D_METHOD("test_mlpp_matrix"), &MLPPTests::test_mlpp_matrix);
|
||||
ClassDB::bind_method(D_METHOD("test_mlpp_matrix_mul"), &MLPPTests::test_mlpp_matrix_mul);
|
||||
}
|
||||
|
@ -65,8 +65,6 @@ public:
|
||||
void test_support_vector_classification_kernel(bool ui = false);
|
||||
|
||||
void test_mlpp_vector();
|
||||
void test_mlpp_matrix();
|
||||
void test_mlpp_matrix_mul();
|
||||
|
||||
void is_approx_equalsd(real_t a, real_t b, const String &str);
|
||||
void is_approx_equals_dvec(const Vector<real_t> &a, const Vector<real_t> &b, const String &str);
|
||||
|
Loading…
Reference in New Issue
Block a user