From 1d28a330747138b6bb5a3b962505409e2bb24a66 Mon Sep 17 00:00:00 2001 From: Relintai Date: Wed, 27 Dec 2023 11:00:59 +0100 Subject: [PATCH] Fully ported MLPPTests::test_linear_algebra(). --- test/mlpp_tests.cpp | 82 +++++++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 37 deletions(-) diff --git a/test/mlpp_tests.cpp b/test/mlpp_tests.cpp index 9654921..e958d5a 100644 --- a/test/mlpp_tests.cpp +++ b/test/mlpp_tests.cpp @@ -141,53 +141,61 @@ void MLPPTests::test_linear_algebra() { Ref square_rot(memnew(MLPPMatrix(square_rot_res_arr, 4, 2))); is_approx_equals_mat(square->rotaten(Math_PI / 4), square_rot, "square->rotaten(Math_PI / 4)"); - /* - std::vector> A = { - { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, - { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, - }; - std::vector a = { 4, 3, 1, 3 }; - std::vector b = { 3, 5, 6, 1 }; - std::vector> mmtr_res = { - { 2, 4, 6, 8, 10, 12, 14, 16, 18, 20 }, - { 4, 8, 12, 16, 20, 24, 28, 32, 36, 40 }, - { 6, 12, 18, 24, 30, 36, 42, 48, 54, 60 }, - { 8, 16, 24, 32, 40, 48, 56, 64, 72, 80 }, - { 10, 20, 30, 40, 50, 60, 70, 80, 90, 100 }, - { 12, 24, 36, 48, 60, 72, 84, 96, 108, 120 }, - { 14, 28, 42, 56, 70, 84, 98, 112, 126, 140 }, - { 16, 32, 48, 64, 80, 96, 112, 128, 144, 160 }, - { 18, 36, 54, 72, 90, 108, 126, 144, 162, 180 }, - { 20, 40, 60, 80, 100, 120, 140, 160, 180, 200 } + const real_t A_arr[] = { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // + }; + const real_t a_arr[] = { 4, 3, 1, 3 }; + const real_t b_arr[] = { 3, 5, 6, 1 }; + + const real_t mmtr_res_arr[] = { + 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, // + 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, // + 6, 12, 18, 24, 30, 36, 42, 48, 54, 60, // + 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, // + 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, // + 12, 24, 36, 48, 60, 72, 84, 96, 108, 120, // + 14, 28, 42, 56, 70, 84, 98, 112, 126, 140, // + 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, // + 18, 36, 54, 72, 90, 108, 126, 144, 162, 180, // + 20, 40, 60, 80, 100, 120, 140, 160, 180, 200 // }; - is_approx_equals_dmat(dstd_mat_to_mat_old(alg.matmult(alg.transpose(A), A)), dstd_mat_to_mat_old(mmtr_res), "alg.matmult(alg.transpose(A), A)"); + Ref A(memnew(MLPPMatrix(A_arr, 2, 10))); + Ref a(memnew(MLPPVector(a_arr, 4))); + Ref b(memnew(MLPPVector(b_arr, 4))); + Ref mmtr_res(memnew(MLPPMatrix(mmtr_res_arr, 10, 10))); - is_approx_equalsd(alg.dot(a, b), 36, "alg.dot(a, b)"); + is_approx_equals_mat(alg.matmultnm(alg.transposenm(A), A), mmtr_res, "alg.matmultnm(alg.transposenm(A), A)"); - std::vector> had_prod_res = { - { 1, 4, 9, 16, 25, 36, 49, 64, 81, 100 }, - { 1, 4, 9, 16, 25, 36, 49, 64, 81, 100 } + is_approx_equalsd(alg.dotnv(a, b), 36, "alg.dotnv(a, b)"); + + const real_t had_prod_res_arr[] = { + 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, // + 1, 4, 9, 16, 25, 36, 49, 64, 81, 100 // }; - is_approx_equals_dmat(dstd_mat_to_mat_old(alg.hadamard_product(A, A)), dstd_mat_to_mat_old(had_prod_res), "alg.hadamard_product(A, A)"); + Ref had_prod_res(memnew(MLPPMatrix(had_prod_res_arr, 2, 10))); - std::vector> id_10_res = { - { 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, - { 0, 1, 0, 0, 0, 0, 0, 0, 0, 0 }, - { 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 }, - { 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 }, - { 0, 0, 0, 0, 0, 0, 1, 0, 0, 0 }, - { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 }, - { 0, 0, 0, 0, 0, 0, 0, 0, 1, 0 }, - { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, + is_approx_equals_mat(alg.hadamard_productnm(A, A), had_prod_res, "alg.hadamard_productnm(A, A)"); + + const real_t id_10_res_arr[] = { + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, // }; - is_approx_equals_dmat(dstd_mat_to_mat_old(alg.identity(10)), dstd_mat_to_mat_old(id_10_res), "alg.identity(10)"); - */ + Ref id_10_res(memnew(MLPPMatrix(id_10_res_arr, 10, 10))); + + is_approx_equals_mat(alg.identitym(10), id_10_res, "alg.identitym(10)"); } void MLPPTests::test_univariate_linear_regression() {