From 81744cc4601ff7cbcfd33f4ffafc9d104211cbaf Mon Sep 17 00:00:00 2001 From: Relintai Date: Tue, 25 Apr 2023 14:05:40 +0200 Subject: [PATCH] Simplified hadamard_product in MLPPMatrix. --- mlpp/lin_alg/mlpp_matrix.cpp | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/mlpp/lin_alg/mlpp_matrix.cpp b/mlpp/lin_alg/mlpp_matrix.cpp index 61fbc29..2833827 100644 --- a/mlpp/lin_alg/mlpp_matrix.cpp +++ b/mlpp/lin_alg/mlpp_matrix.cpp @@ -250,20 +250,21 @@ void MLPPMatrix::hadamard_product(const Ref &B) { ERR_FAIL_COND(!B.is_valid()); ERR_FAIL_COND(_size != B->size()); + int ds = data_size(); + const real_t *b_ptr = B->ptr(); real_t *c_ptr = ptrw(); - for (int i = 0; i < _size.y; i++) { - for (int j = 0; j < _size.x; j++) { - int ind_i_j = calculate_index(i, j); - c_ptr[ind_i_j] = c_ptr[ind_i_j] * b_ptr[ind_i_j]; - } + for (int i = 0; i < ds; i++) { + c_ptr[i] = c_ptr[i] * b_ptr[i]; } } Ref MLPPMatrix::hadamard_productn(const Ref &B) const { ERR_FAIL_COND_V(!B.is_valid(), Ref()); ERR_FAIL_COND_V(_size != B->size(), Ref()); + int ds = data_size(); + Ref C; C.instance(); C->resize(_size); @@ -272,11 +273,8 @@ Ref MLPPMatrix::hadamard_productn(const Ref &B) const { const real_t *b_ptr = B->ptr(); real_t *c_ptr = C->ptrw(); - for (int i = 0; i < _size.y; i++) { - for (int j = 0; j < _size.x; j++) { - int ind_i_j = calculate_index(i, j); - c_ptr[ind_i_j] = a_ptr[ind_i_j] * b_ptr[ind_i_j]; - } + for (int i = 0; i < ds; i++) { + c_ptr[i] = a_ptr[i] * b_ptr[i]; } return C; @@ -290,15 +288,14 @@ void MLPPMatrix::hadamard_productb(const Ref &A, const Refptr(); const real_t *b_ptr = B->ptr(); real_t *c_ptr = ptrw(); - for (int i = 0; i < a_size.y; i++) { - for (int j = 0; j < a_size.x; j++) { - int ind_i_j = A->calculate_index(i, j); - c_ptr[ind_i_j] = a_ptr[ind_i_j] * b_ptr[ind_i_j]; - } + for (int i = 0; i < ds; i++) { + c_ptr[i] = a_ptr[i] * b_ptr[i]; } }