Added matrix multiplication tests, and fixed mult in MLPPMatrix.

This commit is contained in:
Relintai 2023-04-26 16:18:16 +02:00
parent 1497a2c1b0
commit e0b813eacf
3 changed files with 60 additions and 10 deletions

View File

@ -665,28 +665,36 @@ void MLPPMatrix::subb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
} }
void MLPPMatrix::mult(const Ref<MLPPMatrix> &B) { void MLPPMatrix::mult(const Ref<MLPPMatrix> &B) {
ERR_FAIL_MSG("TODO");
ERR_FAIL_COND(!B.is_valid()); ERR_FAIL_COND(!B.is_valid());
Size2i b_size = B->size(); Size2i b_size = B->size();
ERR_FAIL_COND(_size.x != b_size.y || _size.y != b_size.x); ERR_FAIL_COND(_size.x != b_size.y || _size.y != b_size.x);
//TODO need to make a copy of this, resize, and use that to get results into this Ref<MLPPMatrix> A = duplicate();
Size2i a_size = A->size();
Size2i rs = Size2i(b_size.x, a_size.y);
if (_size != rs) {
resize(rs);
}
fill(0);
const real_t *a_ptr = A->ptr();
const real_t *b_ptr = B->ptr(); const real_t *b_ptr = B->ptr();
real_t *c_ptr = ptrw(); real_t *c_ptr = ptrw();
for (int ay = 0; ay < _size.y; ay++) { for (int ay = 0; ay < a_size.y; ay++) {
for (int by = 0; by < b_size.y; by++) { for (int by = 0; by < b_size.y; by++) {
int ind_ay_by = calculate_index(ay, by); int ind_ay_by = A->calculate_index(ay, by);
for (int bx = 0; bx < b_size.x; bx++) { for (int bx = 0; bx < b_size.x; bx++) {
int ind_ay_bx = calculate_index(ay, bx); int ind_ay_bx = calculate_index(ay, bx);
int ind_by_bx = B->calculate_index(by, bx); int ind_by_bx = B->calculate_index(by, bx);
c_ptr[ind_ay_bx] += c_ptr[ind_ay_by] * b_ptr[ind_by_bx]; c_ptr[ind_ay_bx] += a_ptr[ind_ay_by] * b_ptr[ind_by_bx];
} }
} }
} }
@ -717,8 +725,6 @@ Ref<MLPPMatrix> MLPPMatrix::multn(const Ref<MLPPMatrix> &B) const {
int ind_k_j = B->calculate_index(k, j); int ind_k_j = B->calculate_index(k, j);
c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j]; c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j];
//C->set_element(i, j, C->get_element(i, j) + get_element(i, k) * B->get_element(k, j
} }
} }
} }
@ -752,8 +758,6 @@ void MLPPMatrix::multb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
int ind_k_j = B->calculate_index(k, j); int ind_k_j = B->calculate_index(k, j);
c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j]; c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j];
//C->set_element(i, j, C->get_element(i, j) + A->get_element(i, k) * B->get_element(k, j
} }
} }
} }

View File

@ -1353,6 +1353,50 @@ void MLPPTests::test_mlpp_matrix() {
is_approx_equals_mat(rmat, rmat2, "re-set_from_std_vectors test."); is_approx_equals_mat(rmat, rmat2, "re-set_from_std_vectors test.");
} }
void MLPPTests::test_mlpp_matrix_mul() {
std::vector<std::vector<real_t>> A = {
{ 1, 2 },
{ 3, 4 },
{ 5, 6 },
{ 7, 8 }
};
std::vector<std::vector<real_t>> B = {
{ 1, 2, 3, 4 },
{ 5, 6, 7, 8 }
};
std::vector<std::vector<real_t>> C = {
{ 11, 14, 17, 20 },
{ 23, 30, 37, 44 },
{ 35, 46, 57, 68 },
{ 47, 62, 77, 92 }
};
Ref<MLPPMatrix> rmata;
rmata.instance();
rmata->set_from_std_vectors(A);
Ref<MLPPMatrix> rmatb;
rmatb.instance();
rmatb->set_from_std_vectors(B);
Ref<MLPPMatrix> rmatc;
rmatc.instance();
rmatc->set_from_std_vectors(C);
Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);
is_approx_equals_mat(rmatr1, rmatc, "Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);");
Ref<MLPPMatrix> rmatr2;
rmatr2.instance();
rmatr2->multb(rmata, rmatb);
is_approx_equals_mat(rmatr2, rmatc, "rmatr2->multb(rmata, rmatb);");
rmata->mult(rmatb);
is_approx_equals_mat(rmata, rmatc, "rmata->mult(rmatb);");
}
void MLPPTests::is_approx_equalsd(real_t a, real_t b, const String &str) { void MLPPTests::is_approx_equalsd(real_t a, real_t b, const String &str) {
if (!Math::is_equal_approx(a, b)) { if (!Math::is_equal_approx(a, b)) {
ERR_PRINT("TEST FAILED: " + str + " Got: " + String::num(a) + " Should be: " + String::num(b)); ERR_PRINT("TEST FAILED: " + str + " Got: " + String::num(a) + " Should be: " + String::num(b));
@ -1580,4 +1624,5 @@ void MLPPTests::_bind_methods() {
ClassDB::bind_method(D_METHOD("test_mlpp_vector"), &MLPPTests::test_mlpp_vector); ClassDB::bind_method(D_METHOD("test_mlpp_vector"), &MLPPTests::test_mlpp_vector);
ClassDB::bind_method(D_METHOD("test_mlpp_matrix"), &MLPPTests::test_mlpp_matrix); ClassDB::bind_method(D_METHOD("test_mlpp_matrix"), &MLPPTests::test_mlpp_matrix);
ClassDB::bind_method(D_METHOD("test_mlpp_matrix_mul"), &MLPPTests::test_mlpp_matrix_mul);
} }

View File

@ -66,6 +66,7 @@ public:
void test_mlpp_vector(); void test_mlpp_vector();
void test_mlpp_matrix(); void test_mlpp_matrix();
void test_mlpp_matrix_mul();
void is_approx_equalsd(real_t a, real_t b, const String &str); void is_approx_equalsd(real_t a, real_t b, const String &str);
void is_approx_equals_dvec(const Vector<real_t> &a, const Vector<real_t> &b, const String &str); void is_approx_equals_dvec(const Vector<real_t> &a, const Vector<real_t> &b, const String &str);