diff --git a/mlpp/lin_alg/mlpp_matrix.cpp b/mlpp/lin_alg/mlpp_matrix.cpp index 8713d98..83504cf 100644 --- a/mlpp/lin_alg/mlpp_matrix.cpp +++ b/mlpp/lin_alg/mlpp_matrix.cpp @@ -665,28 +665,36 @@ void MLPPMatrix::subb(const Ref &A, const Ref &B) { } void MLPPMatrix::mult(const Ref &B) { - ERR_FAIL_MSG("TODO"); - ERR_FAIL_COND(!B.is_valid()); Size2i b_size = B->size(); ERR_FAIL_COND(_size.x != b_size.y || _size.y != b_size.x); - //TODO need to make a copy of this, resize, and use that to get results into this + Ref A = duplicate(); + Size2i a_size = A->size(); + Size2i rs = Size2i(b_size.x, a_size.y); + + if (_size != rs) { + resize(rs); + } + + fill(0); + + const real_t *a_ptr = A->ptr(); const real_t *b_ptr = B->ptr(); real_t *c_ptr = ptrw(); - for (int ay = 0; ay < _size.y; ay++) { + for (int ay = 0; ay < a_size.y; ay++) { for (int by = 0; by < b_size.y; by++) { - int ind_ay_by = calculate_index(ay, by); + int ind_ay_by = A->calculate_index(ay, by); for (int bx = 0; bx < b_size.x; bx++) { int ind_ay_bx = calculate_index(ay, bx); int ind_by_bx = B->calculate_index(by, bx); - c_ptr[ind_ay_bx] += c_ptr[ind_ay_by] * b_ptr[ind_by_bx]; + c_ptr[ind_ay_bx] += a_ptr[ind_ay_by] * b_ptr[ind_by_bx]; } } } @@ -717,8 +725,6 @@ Ref MLPPMatrix::multn(const Ref &B) const { int ind_k_j = B->calculate_index(k, j); c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j]; - - //C->set_element(i, j, C->get_element(i, j) + get_element(i, k) * B->get_element(k, j } } } @@ -752,8 +758,6 @@ void MLPPMatrix::multb(const Ref &A, const Ref &B) { int ind_k_j = B->calculate_index(k, j); c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j]; - - //C->set_element(i, j, C->get_element(i, j) + A->get_element(i, k) * B->get_element(k, j } } } diff --git a/test/mlpp_tests.cpp b/test/mlpp_tests.cpp index 60abbea..ac9174e 100644 --- a/test/mlpp_tests.cpp +++ b/test/mlpp_tests.cpp @@ -1353,6 +1353,50 @@ void MLPPTests::test_mlpp_matrix() { is_approx_equals_mat(rmat, rmat2, "re-set_from_std_vectors test."); } +void MLPPTests::test_mlpp_matrix_mul() { + std::vector> A = { + { 1, 2 }, + { 3, 4 }, + { 5, 6 }, + { 7, 8 } + }; + + std::vector> B = { + { 1, 2, 3, 4 }, + { 5, 6, 7, 8 } + }; + + std::vector> C = { + { 11, 14, 17, 20 }, + { 23, 30, 37, 44 }, + { 35, 46, 57, 68 }, + { 47, 62, 77, 92 } + }; + + Ref rmata; + rmata.instance(); + rmata->set_from_std_vectors(A); + + Ref rmatb; + rmatb.instance(); + rmatb->set_from_std_vectors(B); + + Ref rmatc; + rmatc.instance(); + rmatc->set_from_std_vectors(C); + + Ref rmatr1 = rmata->multn(rmatb); + is_approx_equals_mat(rmatr1, rmatc, "Ref rmatr1 = rmata->multn(rmatb);"); + + Ref rmatr2; + rmatr2.instance(); + rmatr2->multb(rmata, rmatb); + is_approx_equals_mat(rmatr2, rmatc, "rmatr2->multb(rmata, rmatb);"); + + rmata->mult(rmatb); + is_approx_equals_mat(rmata, rmatc, "rmata->mult(rmatb);"); +} + void MLPPTests::is_approx_equalsd(real_t a, real_t b, const String &str) { if (!Math::is_equal_approx(a, b)) { ERR_PRINT("TEST FAILED: " + str + " Got: " + String::num(a) + " Should be: " + String::num(b)); @@ -1580,4 +1624,5 @@ void MLPPTests::_bind_methods() { ClassDB::bind_method(D_METHOD("test_mlpp_vector"), &MLPPTests::test_mlpp_vector); ClassDB::bind_method(D_METHOD("test_mlpp_matrix"), &MLPPTests::test_mlpp_matrix); + ClassDB::bind_method(D_METHOD("test_mlpp_matrix_mul"), &MLPPTests::test_mlpp_matrix_mul); } diff --git a/test/mlpp_tests.h b/test/mlpp_tests.h index a88d1b0..e993274 100644 --- a/test/mlpp_tests.h +++ b/test/mlpp_tests.h @@ -66,6 +66,7 @@ public: void test_mlpp_vector(); void test_mlpp_matrix(); + void test_mlpp_matrix_mul(); void is_approx_equalsd(real_t a, real_t b, const String &str); void is_approx_equals_dvec(const Vector &a, const Vector &b, const String &str);