From f8f3edf6ef7e579417c594bb089844e482ccd25c Mon Sep 17 00:00:00 2001 From: Relintai Date: Mon, 24 Apr 2023 12:24:09 +0200 Subject: [PATCH] Start reworking vector's math api. Alno some notes. --- mlpp/lin_alg/mlpp_vector.cpp | 71 +++++++++++++++++++++++++++--------- mlpp/lin_alg/mlpp_vector.h | 20 +++++++--- 2 files changed, 68 insertions(+), 23 deletions(-) diff --git a/mlpp/lin_alg/mlpp_vector.cpp b/mlpp/lin_alg/mlpp_vector.cpp index aad1c33..9dc4ad1 100644 --- a/mlpp/lin_alg/mlpp_vector.cpp +++ b/mlpp/lin_alg/mlpp_vector.cpp @@ -3,7 +3,33 @@ #include "mlpp_matrix.h" -Ref MLPPVector::flattenmnv(const Vector> &A) { + +void MLPPVector::flatten_vectors(const Vector> &A) { + int vsize = 0; + for (int i = 0; i < A.size(); ++i) { + vsize += A[i]->size(); + } + + resize(vsize); + + int a_index = 0; + real_t *a_ptr = ptrw(); + + for (int i = 0; i < A.size(); ++i) { + const Ref &r = A[i]; + + int r_size = r->size(); + const real_t *r_ptr = r->ptr(); + + for (int j = 0; j < r_size; ++j) { + a_ptr[a_index] = r_ptr[j]; + ++a_index; + } + } +} + + +Ref MLPPVector::flatten_vectorsn(const Vector> &A) { Ref a; a.instance(); @@ -32,44 +58,55 @@ Ref MLPPVector::flattenmnv(const Vector> &A) { return a; } -Ref MLPPVector::hadamard_productnv(const Ref &a, const Ref &b) { - ERR_FAIL_COND_V(!a.is_valid() || !b.is_valid(), Ref()); +void MLPPVector::hadamard_product(const Ref &b) { + ERR_FAIL_COND(!b.is_valid()); + + ERR_FAIL_COND(_size != b->size()); + + const real_t *a_ptr = ptr(); + const real_t *b_ptr = b->ptr(); + real_t *out_ptr = ptrw(); + + for (int i = 0; i < _size; ++i) { + out_ptr[i] = a_ptr[i] * b_ptr[i]; + } +} +Ref MLPPVector::hadamard_productn(const Ref &b) { + ERR_FAIL_COND_V(!b.is_valid(), Ref()); Ref out; out.instance(); - int size = a->size(); + ERR_FAIL_COND_V(_size != b->size(), Ref()); - ERR_FAIL_COND_V(size != b->size(), Ref()); + out->resize(_size); - out->resize(size); - - const real_t *a_ptr = a->ptr(); + const real_t *a_ptr = ptr(); const real_t *b_ptr = b->ptr(); real_t *out_ptr = out->ptrw(); - for (int i = 0; i < size; ++i) { + for (int i = 0; i < _size; ++i) { out_ptr[i] = a_ptr[i] * b_ptr[i]; } return out; } -void MLPPVector::hadamard_productv(const Ref &a, const Ref &b, Ref out) { - ERR_FAIL_COND(!a.is_valid() || !b.is_valid() || !out.is_valid()); +void MLPPVector::hadamard_productb(const Ref &a, const Ref &b) { + ERR_FAIL_COND(!a.is_valid() || !b.is_valid()); - int size = a->size(); + int s = a->size(); - ERR_FAIL_COND(size != b->size()); + ERR_FAIL_COND(s != b->size()); - if (unlikely(out->size() != size)) { - out->resize(size); + if (unlikely(size() != s)) { + resize(s); } const real_t *a_ptr = a->ptr(); const real_t *b_ptr = b->ptr(); - real_t *out_ptr = out->ptrw(); + real_t *out_ptr = ptrw(); - for (int i = 0; i < size; ++i) { + for (int i = 0; i < s; ++i) { out_ptr[i] = a_ptr[i] * b_ptr[i]; } } diff --git a/mlpp/lin_alg/mlpp_vector.h b/mlpp/lin_alg/mlpp_vector.h index 9268471..e5ebb7e 100644 --- a/mlpp/lin_alg/mlpp_vector.h +++ b/mlpp/lin_alg/mlpp_vector.h @@ -330,14 +330,22 @@ public: } // New apis should look like this: - //Ref substract(const Ref &b); - //void substracted(const Ref &b); - //void subtraction(const Ref &a, const Ref &b); -> result is in this (subtractionv like) + //void substract(const Ref &b); <- this should be the simplest / most obvious method + //Ref substractn(const Ref &b); + //void substractb(const Ref &a, const Ref &b); -> result is in this (subtractionv like) + + // Or: + //void hadamard_product(const Ref &b); <- this should be the simplest / most obvious method + //Ref hadamard_productn(const Ref &b); <- n -> new + //void hadamard_productb(const Ref &a, const Ref &b); <- b -> between, result is stored in *this - Ref flattenmnv(const Vector> &A); + + void flatten_vectors(const Vector> &A); + Ref flatten_vectorsn(const Vector> &A); - Ref hadamard_productnv(const Ref &a, const Ref &b); - void hadamard_productv(const Ref &a, const Ref &b, Ref out); + void hadamard_product(const Ref &b); + Ref hadamard_productn(const Ref &b); + void hadamard_productb(const Ref &a, const Ref &b); Ref element_wise_divisionnv(const Ref &a, const Ref &b);