diff --git a/mlpp/lin_alg/lin_alg.cpp b/mlpp/lin_alg/lin_alg.cpp index 4b54779..5d3843e 100644 --- a/mlpp/lin_alg/lin_alg.cpp +++ b/mlpp/lin_alg/lin_alg.cpp @@ -133,7 +133,10 @@ Ref MLPPLinAlg::subtractionm(const Ref &A, const Ref MLPPLinAlg::matmultm(const Ref &A, const Ref &B) { ERR_FAIL_COND_V(!A.is_valid() || !B.is_valid(), Ref()); Size2i a_size = A->size(); - ERR_FAIL_COND_V(a_size != B->size(), Ref()); + Size2i b_size = B->size(); + + //TODO double check the formula for this + ERR_FAIL_COND_V(a_size.y != b_size.x || a_size.x != b_size.y, Ref()); Ref C; C.instance(); @@ -149,8 +152,8 @@ Ref MLPPLinAlg::matmultm(const Ref &A, const Refcalculate_index(i, k); for (int j = 0; j < a_size.x; j++) { - int ind_i_j = A->calculate_index(i, j); - int ind_k_j = A->calculate_index(k, j); + int ind_i_j = C->calculate_index(i, j); + int ind_k_j = B->calculate_index(k, j); c_ptr[ind_i_j] += a_ptr[ind_i_k] * b_ptr[ind_k_j]; } diff --git a/mlpp/mlp/mlp.cpp b/mlpp/mlp/mlp.cpp index 35e4767..e946c4a 100644 --- a/mlpp/mlp/mlp.cpp +++ b/mlpp/mlp/mlp.cpp @@ -397,6 +397,11 @@ void MLPPMLP::forward_pass() { MLPPLinAlg alg; MLPPActivation avn; + //TODO + + //ERR_PRINT(Variant(input_set->size()).operator String()); + //ERR_PRINT(Variant(weights1->size()).operator String()); + z2 = alg.mat_vec_addv(alg.matmultm(input_set, weights1), bias1); a2 = avn.sigmoid_normv(z2);