mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-02-01 17:07:02 +01:00
Fix matmultm in LinAlg, and smaller improvements.
This commit is contained in:
parent
9142592077
commit
2b33f8a5ed
@ -132,15 +132,15 @@ Ref<MLPPMatrix> MLPPLinAlg::subtractionm(const Ref<MLPPMatrix> &A, const Ref<MLP
|
||||
}
|
||||
Ref<MLPPMatrix> MLPPLinAlg::matmultm(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
|
||||
ERR_FAIL_COND_V(!A.is_valid() || !B.is_valid(), Ref<MLPPMatrix>());
|
||||
|
||||
Size2i a_size = A->size();
|
||||
Size2i b_size = B->size();
|
||||
|
||||
//TODO double check the formula for this
|
||||
ERR_FAIL_COND_V(a_size.y != b_size.x || a_size.x != b_size.y, Ref<MLPPMatrix>());
|
||||
ERR_FAIL_COND_V(a_size.x != b_size.y, Ref<MLPPMatrix>());
|
||||
|
||||
Ref<MLPPMatrix> C;
|
||||
C.instance();
|
||||
C->resize(a_size);
|
||||
C->resize(Size2i(b_size.x, a_size.y));
|
||||
C->fill(0);
|
||||
|
||||
const real_t *a_ptr = A->ptr();
|
||||
@ -148,14 +148,16 @@ Ref<MLPPMatrix> MLPPLinAlg::matmultm(const Ref<MLPPMatrix> &A, const Ref<MLPPMat
|
||||
real_t *c_ptr = C->ptrw();
|
||||
|
||||
for (int i = 0; i < a_size.y; i++) {
|
||||
for (int k = 0; k < a_size.y; k++) {
|
||||
for (int k = 0; k < b_size.y; k++) {
|
||||
int ind_i_k = A->calculate_index(i, k);
|
||||
|
||||
for (int j = 0; j < a_size.x; j++) {
|
||||
for (int j = 0; j < b_size.x; j++) {
|
||||
int ind_i_j = C->calculate_index(i, j);
|
||||
int ind_k_j = B->calculate_index(k, j);
|
||||
|
||||
c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j];
|
||||
|
||||
//C->set_element(i, j, C->get_element(i, j) + A->get_element(i, k) * B->get_element(k, j
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2269,11 +2271,16 @@ std::vector<real_t> MLPPLinAlg::mat_vec_mult(std::vector<std::vector<real_t>> A,
|
||||
}
|
||||
|
||||
Ref<MLPPMatrix> MLPPLinAlg::mat_vec_addv(const Ref<MLPPMatrix> &A, const Ref<MLPPVector> &b) {
|
||||
Ref<MLPPMatrix> ret;
|
||||
ret.instance();
|
||||
ret->resize(A->size());
|
||||
ERR_FAIL_COND_V(!A.is_valid() || !b.is_valid(), Ref<MLPPMatrix>());
|
||||
|
||||
Size2i a_size = A->size();
|
||||
|
||||
ERR_FAIL_COND_V(a_size.x != b->size(), Ref<MLPPMatrix>());
|
||||
|
||||
Ref<MLPPMatrix> ret;
|
||||
ret.instance();
|
||||
ret->resize(a_size);
|
||||
|
||||
const real_t *a_ptr = A->ptr();
|
||||
const real_t *b_ptr = b->ptr();
|
||||
real_t *ret_ptr = ret->ptrw();
|
||||
@ -2289,13 +2296,17 @@ Ref<MLPPMatrix> MLPPLinAlg::mat_vec_addv(const Ref<MLPPMatrix> &A, const Ref<MLP
|
||||
return ret;
|
||||
}
|
||||
Ref<MLPPVector> MLPPLinAlg::mat_vec_multv(const Ref<MLPPMatrix> &A, const Ref<MLPPVector> &b) {
|
||||
Ref<MLPPVector> c;
|
||||
c.instance();
|
||||
ERR_FAIL_COND_V(!A.is_valid() || !b.is_valid(), Ref<MLPPMatrix>());
|
||||
|
||||
Size2i a_size = A->size();
|
||||
int b_size = b->size();
|
||||
|
||||
ERR_FAIL_COND_V(a_size.x < b->size(), Ref<MLPPMatrix>());
|
||||
|
||||
Ref<MLPPVector> c;
|
||||
c.instance();
|
||||
c->resize(a_size.y);
|
||||
c->fill(0);
|
||||
|
||||
const real_t *a_ptr = A->ptr();
|
||||
const real_t *b_ptr = b->ptr();
|
||||
@ -2305,7 +2316,7 @@ Ref<MLPPVector> MLPPLinAlg::mat_vec_multv(const Ref<MLPPMatrix> &A, const Ref<ML
|
||||
for (int k = 0; k < b_size; ++k) {
|
||||
int mat_index = A->calculate_index(i, k);
|
||||
|
||||
c_ptr[i] = a_ptr[mat_index] * b_ptr[k];
|
||||
c_ptr[i] += a_ptr[mat_index] * b_ptr[k];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -397,11 +397,6 @@ void MLPPMLP::forward_pass() {
|
||||
MLPPLinAlg alg;
|
||||
MLPPActivation avn;
|
||||
|
||||
//TODO
|
||||
|
||||
//ERR_PRINT(Variant(input_set->size()).operator String());
|
||||
//ERR_PRINT(Variant(weights1->size()).operator String());
|
||||
|
||||
z2 = alg.mat_vec_addv(alg.matmultm(input_set, weights1), bias1);
|
||||
a2 = avn.sigmoid_normv(z2);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user