From 2b33f8a5edc00537e0b397b931cc176aa8bd8c82 Mon Sep 17 00:00:00 2001 From: Relintai Date: Sun, 5 Feb 2023 09:50:39 +0100 Subject: [PATCH] Fix matmultm in LinAlg, and smaller improvements. --- mlpp/lin_alg/lin_alg.cpp | 33 ++++++++++++++++++++++----------- mlpp/mlp/mlp.cpp | 5 ----- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/mlpp/lin_alg/lin_alg.cpp b/mlpp/lin_alg/lin_alg.cpp index 5d3843e..8f8fccd 100644 --- a/mlpp/lin_alg/lin_alg.cpp +++ b/mlpp/lin_alg/lin_alg.cpp @@ -132,15 +132,15 @@ Ref MLPPLinAlg::subtractionm(const Ref &A, const Ref MLPPLinAlg::matmultm(const Ref &A, const Ref &B) { ERR_FAIL_COND_V(!A.is_valid() || !B.is_valid(), Ref()); + Size2i a_size = A->size(); Size2i b_size = B->size(); - //TODO double check the formula for this - ERR_FAIL_COND_V(a_size.y != b_size.x || a_size.x != b_size.y, Ref()); + ERR_FAIL_COND_V(a_size.x != b_size.y, Ref()); Ref C; C.instance(); - C->resize(a_size); + C->resize(Size2i(b_size.x, a_size.y)); C->fill(0); const real_t *a_ptr = A->ptr(); @@ -148,14 +148,16 @@ Ref MLPPLinAlg::matmultm(const Ref &A, const Refptrw(); for (int i = 0; i < a_size.y; i++) { - for (int k = 0; k < a_size.y; k++) { + for (int k = 0; k < b_size.y; k++) { int ind_i_k = A->calculate_index(i, k); - for (int j = 0; j < a_size.x; j++) { + for (int j = 0; j < b_size.x; j++) { int ind_i_j = C->calculate_index(i, j); int ind_k_j = B->calculate_index(k, j); c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j]; + + //C->set_element(i, j, C->get_element(i, j) + A->get_element(i, k) * B->get_element(k, j } } } @@ -2269,11 +2271,16 @@ std::vector MLPPLinAlg::mat_vec_mult(std::vector> A, } Ref MLPPLinAlg::mat_vec_addv(const Ref &A, const Ref &b) { - Ref ret; - ret.instance(); - ret->resize(A->size()); + ERR_FAIL_COND_V(!A.is_valid() || !b.is_valid(), Ref()); Size2i a_size = A->size(); + + ERR_FAIL_COND_V(a_size.x != b->size(), Ref()); + + Ref ret; + ret.instance(); + ret->resize(a_size); + const real_t *a_ptr = A->ptr(); const real_t *b_ptr = b->ptr(); real_t *ret_ptr = ret->ptrw(); @@ -2289,13 +2296,17 @@ Ref MLPPLinAlg::mat_vec_addv(const Ref &A, const Ref MLPPLinAlg::mat_vec_multv(const Ref &A, const Ref &b) { - Ref c; - c.instance(); + ERR_FAIL_COND_V(!A.is_valid() || !b.is_valid(), Ref()); Size2i a_size = A->size(); int b_size = b->size(); + ERR_FAIL_COND_V(a_size.x < b->size(), Ref()); + + Ref c; + c.instance(); c->resize(a_size.y); + c->fill(0); const real_t *a_ptr = A->ptr(); const real_t *b_ptr = b->ptr(); @@ -2305,7 +2316,7 @@ Ref MLPPLinAlg::mat_vec_multv(const Ref &A, const Refcalculate_index(i, k); - c_ptr[i] = a_ptr[mat_index] * b_ptr[k]; + c_ptr[i] += a_ptr[mat_index] * b_ptr[k]; } } diff --git a/mlpp/mlp/mlp.cpp b/mlpp/mlp/mlp.cpp index e946c4a..35e4767 100644 --- a/mlpp/mlp/mlp.cpp +++ b/mlpp/mlp/mlp.cpp @@ -397,11 +397,6 @@ void MLPPMLP::forward_pass() { MLPPLinAlg alg; MLPPActivation avn; - //TODO - - //ERR_PRINT(Variant(input_set->size()).operator String()); - //ERR_PRINT(Variant(weights1->size()).operator String()); - z2 = alg.mat_vec_addv(alg.matmultm(input_set, weights1), bias1); a2 = avn.sigmoid_normv(z2);