mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-01-17 14:57:19 +01:00
Added matrix multiplication tests, and fixed mult in MLPPMatrix.
This commit is contained in:
parent
1497a2c1b0
commit
e0b813eacf
@ -665,28 +665,36 @@ void MLPPMatrix::subb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
|
||||
}
|
||||
|
||||
void MLPPMatrix::mult(const Ref<MLPPMatrix> &B) {
|
||||
ERR_FAIL_MSG("TODO");
|
||||
|
||||
ERR_FAIL_COND(!B.is_valid());
|
||||
|
||||
Size2i b_size = B->size();
|
||||
|
||||
ERR_FAIL_COND(_size.x != b_size.y || _size.y != b_size.x);
|
||||
|
||||
//TODO need to make a copy of this, resize, and use that to get results into this
|
||||
Ref<MLPPMatrix> A = duplicate();
|
||||
Size2i a_size = A->size();
|
||||
|
||||
Size2i rs = Size2i(b_size.x, a_size.y);
|
||||
|
||||
if (_size != rs) {
|
||||
resize(rs);
|
||||
}
|
||||
|
||||
fill(0);
|
||||
|
||||
const real_t *a_ptr = A->ptr();
|
||||
const real_t *b_ptr = B->ptr();
|
||||
real_t *c_ptr = ptrw();
|
||||
|
||||
for (int ay = 0; ay < _size.y; ay++) {
|
||||
for (int ay = 0; ay < a_size.y; ay++) {
|
||||
for (int by = 0; by < b_size.y; by++) {
|
||||
int ind_ay_by = calculate_index(ay, by);
|
||||
int ind_ay_by = A->calculate_index(ay, by);
|
||||
|
||||
for (int bx = 0; bx < b_size.x; bx++) {
|
||||
int ind_ay_bx = calculate_index(ay, bx);
|
||||
int ind_by_bx = B->calculate_index(by, bx);
|
||||
|
||||
c_ptr[ind_ay_bx] += c_ptr[ind_ay_by] * b_ptr[ind_by_bx];
|
||||
c_ptr[ind_ay_bx] += a_ptr[ind_ay_by] * b_ptr[ind_by_bx];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -717,8 +725,6 @@ Ref<MLPPMatrix> MLPPMatrix::multn(const Ref<MLPPMatrix> &B) const {
|
||||
int ind_k_j = B->calculate_index(k, j);
|
||||
|
||||
c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j];
|
||||
|
||||
//C->set_element(i, j, C->get_element(i, j) + get_element(i, k) * B->get_element(k, j
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -752,8 +758,6 @@ void MLPPMatrix::multb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
|
||||
int ind_k_j = B->calculate_index(k, j);
|
||||
|
||||
c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j];
|
||||
|
||||
//C->set_element(i, j, C->get_element(i, j) + A->get_element(i, k) * B->get_element(k, j
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1353,6 +1353,50 @@ void MLPPTests::test_mlpp_matrix() {
|
||||
is_approx_equals_mat(rmat, rmat2, "re-set_from_std_vectors test.");
|
||||
}
|
||||
|
||||
void MLPPTests::test_mlpp_matrix_mul() {
|
||||
std::vector<std::vector<real_t>> A = {
|
||||
{ 1, 2 },
|
||||
{ 3, 4 },
|
||||
{ 5, 6 },
|
||||
{ 7, 8 }
|
||||
};
|
||||
|
||||
std::vector<std::vector<real_t>> B = {
|
||||
{ 1, 2, 3, 4 },
|
||||
{ 5, 6, 7, 8 }
|
||||
};
|
||||
|
||||
std::vector<std::vector<real_t>> C = {
|
||||
{ 11, 14, 17, 20 },
|
||||
{ 23, 30, 37, 44 },
|
||||
{ 35, 46, 57, 68 },
|
||||
{ 47, 62, 77, 92 }
|
||||
};
|
||||
|
||||
Ref<MLPPMatrix> rmata;
|
||||
rmata.instance();
|
||||
rmata->set_from_std_vectors(A);
|
||||
|
||||
Ref<MLPPMatrix> rmatb;
|
||||
rmatb.instance();
|
||||
rmatb->set_from_std_vectors(B);
|
||||
|
||||
Ref<MLPPMatrix> rmatc;
|
||||
rmatc.instance();
|
||||
rmatc->set_from_std_vectors(C);
|
||||
|
||||
Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);
|
||||
is_approx_equals_mat(rmatr1, rmatc, "Ref<MLPPMatrix> rmatr1 = rmata->multn(rmatb);");
|
||||
|
||||
Ref<MLPPMatrix> rmatr2;
|
||||
rmatr2.instance();
|
||||
rmatr2->multb(rmata, rmatb);
|
||||
is_approx_equals_mat(rmatr2, rmatc, "rmatr2->multb(rmata, rmatb);");
|
||||
|
||||
rmata->mult(rmatb);
|
||||
is_approx_equals_mat(rmata, rmatc, "rmata->mult(rmatb);");
|
||||
}
|
||||
|
||||
void MLPPTests::is_approx_equalsd(real_t a, real_t b, const String &str) {
|
||||
if (!Math::is_equal_approx(a, b)) {
|
||||
ERR_PRINT("TEST FAILED: " + str + " Got: " + String::num(a) + " Should be: " + String::num(b));
|
||||
@ -1580,4 +1624,5 @@ void MLPPTests::_bind_methods() {
|
||||
|
||||
ClassDB::bind_method(D_METHOD("test_mlpp_vector"), &MLPPTests::test_mlpp_vector);
|
||||
ClassDB::bind_method(D_METHOD("test_mlpp_matrix"), &MLPPTests::test_mlpp_matrix);
|
||||
ClassDB::bind_method(D_METHOD("test_mlpp_matrix_mul"), &MLPPTests::test_mlpp_matrix_mul);
|
||||
}
|
||||
|
@ -66,6 +66,7 @@ public:
|
||||
|
||||
void test_mlpp_vector();
|
||||
void test_mlpp_matrix();
|
||||
void test_mlpp_matrix_mul();
|
||||
|
||||
void is_approx_equalsd(real_t a, real_t b, const String &str);
|
||||
void is_approx_equals_dvec(const Vector<real_t> &a, const Vector<real_t> &b, const String &str);
|
||||
|
Loading…
Reference in New Issue
Block a user