Simplified element_wise_division in MLPPMatrix.

This commit is contained in:
Relintai 2023-04-25 13:47:18 +02:00
parent 134d7966c0
commit 9df031e10c

View File

@ -439,20 +439,21 @@ void MLPPMatrix::element_wise_division(const Ref<MLPPMatrix> &B) {
ERR_FAIL_COND(!B.is_valid()); ERR_FAIL_COND(!B.is_valid());
ERR_FAIL_COND(_size != B->size()); ERR_FAIL_COND(_size != B->size());
int ds = data_size();
const real_t *b_ptr = B->ptr(); const real_t *b_ptr = B->ptr();
real_t *c_ptr = ptrw(); real_t *c_ptr = ptrw();
for (int i = 0; i < _size.y; i++) { for (int i = 0; i < ds; i++) {
for (int j = 0; j < _size.x; j++) { c_ptr[i] /= b_ptr[i];
int ind_i_j = calculate_index(i, j);
c_ptr[ind_i_j] /= b_ptr[ind_i_j];
}
} }
} }
Ref<MLPPMatrix> MLPPMatrix::element_wise_divisionn(const Ref<MLPPMatrix> &B) const { Ref<MLPPMatrix> MLPPMatrix::element_wise_divisionn(const Ref<MLPPMatrix> &B) const {
ERR_FAIL_COND_V(!B.is_valid(), Ref<MLPPMatrix>()); ERR_FAIL_COND_V(!B.is_valid(), Ref<MLPPMatrix>());
ERR_FAIL_COND_V(_size != B->size(), Ref<MLPPMatrix>()); ERR_FAIL_COND_V(_size != B->size(), Ref<MLPPMatrix>());
int ds = data_size();
Ref<MLPPMatrix> C; Ref<MLPPMatrix> C;
C.instance(); C.instance();
C->resize(_size); C->resize(_size);
@ -461,11 +462,8 @@ Ref<MLPPMatrix> MLPPMatrix::element_wise_divisionn(const Ref<MLPPMatrix> &B) con
const real_t *b_ptr = B->ptr(); const real_t *b_ptr = B->ptr();
real_t *c_ptr = C->ptrw(); real_t *c_ptr = C->ptrw();
for (int i = 0; i < _size.y; i++) { for (int i = 0; i < ds; i++) {
for (int j = 0; j < _size.x; j++) { c_ptr[i] = a_ptr[i] / b_ptr[i];
int ind_i_j = calculate_index(i, j);
c_ptr[ind_i_j] = a_ptr[ind_i_j] / b_ptr[ind_i_j];
}
} }
return C; return C;
@ -479,15 +477,14 @@ void MLPPMatrix::element_wise_divisionb(const Ref<MLPPMatrix> &A, const Ref<MLPP
resize(a_size); resize(a_size);
} }
int ds = data_size();
const real_t *a_ptr = A->ptr(); const real_t *a_ptr = A->ptr();
const real_t *b_ptr = B->ptr(); const real_t *b_ptr = B->ptr();
real_t *c_ptr = ptrw(); real_t *c_ptr = ptrw();
for (int i = 0; i < a_size.y; i++) { for (int i = 0; i < ds; i++) {
for (int j = 0; j < a_size.x; j++) { c_ptr[i] = a_ptr[i] / b_ptr[i];
int ind_i_j = A->calculate_index(i, j);
c_ptr[ind_i_j] = a_ptr[ind_i_j] / b_ptr[ind_i_j];
}
} }
} }