Fix matmultm in LinAlg, and smaller improvements.

Relintai 2023-02-05 09:50:39 +01:00
parent 9142592077
commit 2b33f8a5ed
2 changed files with 22 additions and 16 deletions
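The previous matmultm demanded that A and B have transposed shapes (it also required rows(A) == cols(B)) and allocated C with A's shape, which is only right when A is square. For C = A * B the only requirement is cols(A) == rows(B), and C is rows(A) x cols(B) with C[i][j] = sum over k of A[i][k] * B[k][j]. In the Size2i convention used here (x = columns, y = rows) that is the new a_size.x != b_size.y check and the C->resize(Size2i(b_size.x, a_size.y)) call below.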

View File

@@ -132,15 +132,15 @@ Ref<MLPPMatrix> MLPPLinAlg::subtractionm(const Ref<MLPPMatrix> &A, const Ref<MLP
}
Ref<MLPPMatrix> MLPPLinAlg::matmultm(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
ERR_FAIL_COND_V(!A.is_valid() || !B.is_valid(), Ref<MLPPMatrix>());
Size2i a_size = A->size();
Size2i b_size = B->size();
-//TODO double check the formula for this
-ERR_FAIL_COND_V(a_size.y != b_size.x || a_size.x != b_size.y, Ref<MLPPMatrix>());
+ERR_FAIL_COND_V(a_size.x != b_size.y, Ref<MLPPMatrix>());
Ref<MLPPMatrix> C;
C.instance();
-C->resize(a_size);
+C->resize(Size2i(b_size.x, a_size.y));
C->fill(0);
const real_t *a_ptr = A->ptr();
@@ -148,14 +148,16 @@ Ref<MLPPMatrix> MLPPLinAlg::matmultm(const Ref<MLPPMatrix> &A, const Ref<MLPPMat
real_t *c_ptr = C->ptrw();
for (int i = 0; i < a_size.y; i++) {
-for (int k = 0; k < a_size.y; k++) {
+for (int k = 0; k < b_size.y; k++) {
int ind_i_k = A->calculate_index(i, k);
-for (int j = 0; j < a_size.x; j++) {
+for (int j = 0; j < b_size.x; j++) {
int ind_i_j = C->calculate_index(i, j);
int ind_k_j = B->calculate_index(k, j);
c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j];
//C->set_element(i, j, C->get_element(i, j) + A->get_element(i, k) * B->get_element(k, j));
}
}
}
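For reference, a minimal usage sketch of the corrected routine, assuming only the MLPPMatrix calls visible in this diff (instance(), resize(), size()); element setters are omitted since their exact names are not shown here:

MLPPLinAlg alg;

Ref<MLPPMatrix> A;
A.instance();
A->resize(Size2i(3, 2)); // Size2i(columns, rows): A is 2x3

Ref<MLPPMatrix> B;
B.instance();
B->resize(Size2i(2, 3)); // B is 3x2, so cols(A) == rows(B)

// ... fill A and B ...

Ref<MLPPMatrix> C = alg.matmultm(A, B);
// C->size() should come back as Size2i(2, 2): b_size.x columns by a_size.y rows.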
@@ -2269,11 +2271,16 @@ std::vector<real_t> MLPPLinAlg::mat_vec_mult(std::vector<std::vector<real_t>> A,
}
Ref<MLPPMatrix> MLPPLinAlg::mat_vec_addv(const Ref<MLPPMatrix> &A, const Ref<MLPPVector> &b) {
-Ref<MLPPMatrix> ret;
-ret.instance();
-ret->resize(A->size());
+ERR_FAIL_COND_V(!A.is_valid() || !b.is_valid(), Ref<MLPPMatrix>());
+Size2i a_size = A->size();
+ERR_FAIL_COND_V(a_size.x != b->size(), Ref<MLPPMatrix>());
+Ref<MLPPMatrix> ret;
+ret.instance();
+ret->resize(a_size);
const real_t *a_ptr = A->ptr();
const real_t *b_ptr = b->ptr();
real_t *ret_ptr = ret->ptrw();
@@ -2289,13 +2296,17 @@ Ref<MLPPMatrix> MLPPLinAlg::mat_vec_addv(const Ref<MLPPMatrix> &A, const Ref<MLP
return ret;
}
Ref<MLPPVector> MLPPLinAlg::mat_vec_multv(const Ref<MLPPMatrix> &A, const Ref<MLPPVector> &b) {
-Ref<MLPPVector> c;
-c.instance();
+ERR_FAIL_COND_V(!A.is_valid() || !b.is_valid(), Ref<MLPPVector>());
+Size2i a_size = A->size();
+int b_size = b->size();
+ERR_FAIL_COND_V(a_size.x < b->size(), Ref<MLPPVector>());
+Ref<MLPPVector> c;
+c.instance();
c->resize(a_size.y);
c->fill(0);
const real_t *a_ptr = A->ptr();
const real_t *b_ptr = b->ptr();
@@ -2305,7 +2316,7 @@ Ref<MLPPVector> MLPPLinAlg::mat_vec_multv(const Ref<MLPPMatrix> &A, const Ref<ML
for (int k = 0; k < b_size; ++k) {
int mat_index = A->calculate_index(i, k);
-c_ptr[i] = a_ptr[mat_index] * b_ptr[k];
+c_ptr[i] += a_ptr[mat_index] * b_ptr[k];
}
}
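The change above is the important one for mat_vec_multv: with plain assignment each pass through k overwrote c_ptr[i], so only the last product survived. With += the loop accumulates the full dot product c[i] = sum over k of A[i][k] * b[k], which is what the c->fill(0) above prepares for.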

View File

@@ -397,11 +397,6 @@ void MLPPMLP::forward_pass() {
MLPPLinAlg alg;
MLPPActivation avn;
-//TODO
-//ERR_PRINT(Variant(input_set->size()).operator String());
-//ERR_PRINT(Variant(weights1->size()).operator String());
z2 = alg.mat_vec_addv(alg.matmultm(input_set, weights1), bias1);
a2 = avn.sigmoid_normv(z2);
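With the checks added above, this forward pass is easier to reason about: matmultm(input_set, weights1) only succeeds when input_set has as many columns as weights1 has rows, and mat_vec_addv then requires bias1 to have one entry per column of that product. For example, if input_set is n x d and weights1 is d x h, then z2 and a2 come out n x h and bias1 must hold h values; avn.sigmoid_normv(z2) applies the sigmoid activation on top.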