mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-01-02 16:29:35 +01:00
MLPPMatrix math api rework pt2.
This commit is contained in:
parent
a2c3e9badb
commit
70d7928cb0
@ -246,18 +246,53 @@ void MLPPMatrix::multb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
|
||||
}
|
||||
}
|
||||
|
||||
Ref<MLPPMatrix> MLPPMatrix::hadamard_productnm(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
|
||||
ERR_FAIL_COND_V(!A.is_valid() || !B.is_valid(), Ref<MLPPMatrix>());
|
||||
Size2i a_size = A->size();
|
||||
ERR_FAIL_COND_V(a_size != B->size(), Ref<MLPPMatrix>());
|
||||
void MLPPMatrix::hadamard_product(const Ref<MLPPMatrix> &B) {
|
||||
ERR_FAIL_COND(!B.is_valid());
|
||||
ERR_FAIL_COND(_size != B->size());
|
||||
|
||||
const real_t *b_ptr = B->ptr();
|
||||
real_t *c_ptr = ptrw();
|
||||
|
||||
for (int i = 0; i < _size.y; i++) {
|
||||
for (int j = 0; j < _size.x; j++) {
|
||||
int ind_i_j = calculate_index(i, j);
|
||||
c_ptr[ind_i_j] = c_ptr[ind_i_j] * b_ptr[ind_i_j];
|
||||
}
|
||||
}
|
||||
}
|
||||
Ref<MLPPMatrix> MLPPMatrix::hadamard_productn(const Ref<MLPPMatrix> &B) const {
|
||||
ERR_FAIL_COND_V(!B.is_valid(), Ref<MLPPMatrix>());
|
||||
ERR_FAIL_COND_V(_size != B->size(), Ref<MLPPMatrix>());
|
||||
|
||||
Ref<MLPPMatrix> C;
|
||||
C.instance();
|
||||
C->resize(a_size);
|
||||
C->resize(_size);
|
||||
|
||||
const real_t *a_ptr = ptr();
|
||||
const real_t *b_ptr = B->ptr();
|
||||
real_t *c_ptr = C->ptrw();
|
||||
|
||||
for (int i = 0; i < _size.y; i++) {
|
||||
for (int j = 0; j < _size.x; j++) {
|
||||
int ind_i_j = calculate_index(i, j);
|
||||
c_ptr[ind_i_j] = a_ptr[ind_i_j] * b_ptr[ind_i_j];
|
||||
}
|
||||
}
|
||||
|
||||
return C;
|
||||
}
|
||||
void MLPPMatrix::hadamard_productb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
|
||||
ERR_FAIL_COND(!A.is_valid() || !B.is_valid());
|
||||
Size2i a_size = A->size();
|
||||
ERR_FAIL_COND(a_size != B->size());
|
||||
|
||||
if (a_size != _size) {
|
||||
resize(a_size);
|
||||
}
|
||||
|
||||
const real_t *a_ptr = A->ptr();
|
||||
const real_t *b_ptr = B->ptr();
|
||||
real_t *c_ptr = C->ptrw();
|
||||
real_t *c_ptr = ptrw();
|
||||
|
||||
for (int i = 0; i < a_size.y; i++) {
|
||||
for (int j = 0; j < a_size.x; j++) {
|
||||
@ -265,10 +300,9 @@ Ref<MLPPMatrix> MLPPMatrix::hadamard_productnm(const Ref<MLPPMatrix> &A, const R
|
||||
c_ptr[ind_i_j] = a_ptr[ind_i_j] * b_ptr[ind_i_j];
|
||||
}
|
||||
}
|
||||
|
||||
return C;
|
||||
}
|
||||
Ref<MLPPMatrix> MLPPMatrix::kronecker_productnm(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
|
||||
|
||||
void MLPPMatrix::kronecker_product(const Ref<MLPPMatrix> &B) {
|
||||
// [1,1,1,1] [1,2,3,4,5]
|
||||
// [1,1,1,1] [1,2,3,4,5]
|
||||
// [1,2,3,4,5]
|
||||
@ -283,14 +317,102 @@ Ref<MLPPMatrix> MLPPMatrix::kronecker_productnm(const Ref<MLPPMatrix> &A, const
|
||||
// Resulting matrix: A.size() * B.size()
|
||||
// A[0].size() * B[0].size()
|
||||
|
||||
ERR_FAIL_COND_V(!A.is_valid() || !B.is_valid(), Ref<MLPPMatrix>());
|
||||
Size2i a_size = A->size();
|
||||
ERR_FAIL_COND(!B.is_valid());
|
||||
Size2i a_size = size();
|
||||
Size2i b_size = B->size();
|
||||
|
||||
Ref<MLPPMatrix> A = duplicate();
|
||||
|
||||
resize(Size2i(b_size.x * a_size.x, b_size.y * a_size.y));
|
||||
|
||||
const real_t *a_ptr = A->ptr();
|
||||
|
||||
Ref<MLPPVector> row_tmp;
|
||||
row_tmp.instance();
|
||||
row_tmp->resize(b_size.x);
|
||||
|
||||
for (int i = 0; i < _size.y; ++i) {
|
||||
for (int j = 0; j < b_size.y; ++j) {
|
||||
B->get_row_into_mlpp_vector(j, row_tmp);
|
||||
|
||||
Vector<Ref<MLPPVector>> row;
|
||||
for (int k = 0; k < _size.x; ++k) {
|
||||
row.push_back(row_tmp->scalar_multiplyn(a_ptr[A->calculate_index(i, k)]));
|
||||
}
|
||||
|
||||
Ref<MLPPVector> flattened_row = row_tmp->flatten_vectorsn(row);
|
||||
|
||||
set_row_mlpp_vector(i * b_size.y + j, flattened_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ref<MLPPMatrix> MLPPMatrix::kronecker_productn(const Ref<MLPPMatrix> &B) const {
|
||||
// [1,1,1,1] [1,2,3,4,5]
|
||||
// [1,1,1,1] [1,2,3,4,5]
|
||||
// [1,2,3,4,5]
|
||||
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
|
||||
// Resulting matrix: A.size() * B.size()
|
||||
// A[0].size() * B[0].size()
|
||||
|
||||
ERR_FAIL_COND_V(!B.is_valid(), Ref<MLPPMatrix>());
|
||||
Size2i a_size = size();
|
||||
Size2i b_size = B->size();
|
||||
|
||||
Ref<MLPPMatrix> C;
|
||||
C.instance();
|
||||
C->resize(Size2i(b_size.x * a_size.x, b_size.y * a_size.y));
|
||||
|
||||
const real_t *a_ptr = ptr();
|
||||
|
||||
Ref<MLPPVector> row_tmp;
|
||||
row_tmp.instance();
|
||||
row_tmp->resize(b_size.x);
|
||||
|
||||
for (int i = 0; i < a_size.y; ++i) {
|
||||
for (int j = 0; j < b_size.y; ++j) {
|
||||
B->get_row_into_mlpp_vector(j, row_tmp);
|
||||
|
||||
Vector<Ref<MLPPVector>> row;
|
||||
for (int k = 0; k < a_size.x; ++k) {
|
||||
row.push_back(row_tmp->scalar_multiplyn(a_ptr[calculate_index(i, k)]));
|
||||
}
|
||||
|
||||
Ref<MLPPVector> flattened_row = row_tmp->flatten_vectorsn(row);
|
||||
|
||||
C->set_row_mlpp_vector(i * b_size.y + j, flattened_row);
|
||||
}
|
||||
}
|
||||
|
||||
return C;
|
||||
}
|
||||
void MLPPMatrix::kronecker_productb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
|
||||
// [1,1,1,1] [1,2,3,4,5]
|
||||
// [1,1,1,1] [1,2,3,4,5]
|
||||
// [1,2,3,4,5]
|
||||
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
// [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5] [1,2,3,4,5]
|
||||
|
||||
// Resulting matrix: A.size() * B.size()
|
||||
// A[0].size() * B[0].size()
|
||||
|
||||
ERR_FAIL_COND(!A.is_valid() || !B.is_valid());
|
||||
Size2i a_size = A->size();
|
||||
Size2i b_size = B->size();
|
||||
|
||||
resize(Size2i(b_size.x * a_size.x, b_size.y * a_size.y));
|
||||
|
||||
const real_t *a_ptr = A->ptr();
|
||||
|
||||
Ref<MLPPVector> row_tmp;
|
||||
@ -308,24 +430,58 @@ Ref<MLPPMatrix> MLPPMatrix::kronecker_productnm(const Ref<MLPPMatrix> &A, const
|
||||
|
||||
Ref<MLPPVector> flattened_row = row_tmp->flatten_vectorsn(row);
|
||||
|
||||
C->set_row_mlpp_vector(i * b_size.y + j, flattened_row);
|
||||
set_row_mlpp_vector(i * b_size.y + j, flattened_row);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MLPPMatrix::element_wise_division(const Ref<MLPPMatrix> &B) {
|
||||
ERR_FAIL_COND(!B.is_valid());
|
||||
ERR_FAIL_COND(_size != B->size());
|
||||
|
||||
const real_t *b_ptr = B->ptr();
|
||||
real_t *c_ptr = ptrw();
|
||||
|
||||
for (int i = 0; i < _size.y; i++) {
|
||||
for (int j = 0; j < _size.x; j++) {
|
||||
int ind_i_j = calculate_index(i, j);
|
||||
c_ptr[ind_i_j] /= b_ptr[ind_i_j];
|
||||
}
|
||||
}
|
||||
}
|
||||
Ref<MLPPMatrix> MLPPMatrix::element_wise_divisionn(const Ref<MLPPMatrix> &B) const {
|
||||
ERR_FAIL_COND_V(!B.is_valid(), Ref<MLPPMatrix>());
|
||||
ERR_FAIL_COND_V(_size != B->size(), Ref<MLPPMatrix>());
|
||||
|
||||
Ref<MLPPMatrix> C;
|
||||
C.instance();
|
||||
C->resize(_size);
|
||||
|
||||
const real_t *a_ptr = ptr();
|
||||
const real_t *b_ptr = B->ptr();
|
||||
real_t *c_ptr = C->ptrw();
|
||||
|
||||
for (int i = 0; i < _size.y; i++) {
|
||||
for (int j = 0; j < _size.x; j++) {
|
||||
int ind_i_j = calculate_index(i, j);
|
||||
c_ptr[ind_i_j] = a_ptr[ind_i_j] / b_ptr[ind_i_j];
|
||||
}
|
||||
}
|
||||
|
||||
return C;
|
||||
}
|
||||
Ref<MLPPMatrix> MLPPMatrix::element_wise_divisionnvnm(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
|
||||
ERR_FAIL_COND_V(!A.is_valid() || !B.is_valid(), Ref<MLPPMatrix>());
|
||||
void MLPPMatrix::element_wise_divisionb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B) {
|
||||
ERR_FAIL_COND(!A.is_valid() || !B.is_valid());
|
||||
Size2i a_size = A->size();
|
||||
ERR_FAIL_COND_V(a_size != B->size(), Ref<MLPPMatrix>());
|
||||
ERR_FAIL_COND(a_size != B->size());
|
||||
|
||||
Ref<MLPPMatrix> C;
|
||||
C.instance();
|
||||
C->resize(a_size);
|
||||
if (a_size != _size) {
|
||||
resize(a_size);
|
||||
}
|
||||
|
||||
const real_t *a_ptr = A->ptr();
|
||||
const real_t *b_ptr = B->ptr();
|
||||
real_t *c_ptr = C->ptrw();
|
||||
real_t *c_ptr = ptrw();
|
||||
|
||||
for (int i = 0; i < a_size.y; i++) {
|
||||
for (int j = 0; j < a_size.x; j++) {
|
||||
@ -333,8 +489,6 @@ Ref<MLPPMatrix> MLPPMatrix::element_wise_divisionnvnm(const Ref<MLPPMatrix> &A,
|
||||
c_ptr[ind_i_j] = a_ptr[ind_i_j] / b_ptr[ind_i_j];
|
||||
}
|
||||
}
|
||||
|
||||
return C;
|
||||
}
|
||||
|
||||
Ref<MLPPMatrix> MLPPMatrix::transposenm(const Ref<MLPPMatrix> &A) {
|
||||
|
@ -601,9 +601,17 @@ public:
|
||||
Ref<MLPPMatrix> multn(const Ref<MLPPMatrix> &B) const;
|
||||
void multb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B);
|
||||
|
||||
Ref<MLPPMatrix> hadamard_productnm(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B);
|
||||
Ref<MLPPMatrix> kronecker_productnm(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B);
|
||||
Ref<MLPPMatrix> element_wise_divisionnvnm(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B);
|
||||
void hadamard_product(const Ref<MLPPMatrix> &B);
|
||||
Ref<MLPPMatrix> hadamard_productn(const Ref<MLPPMatrix> &B) const;
|
||||
void hadamard_productb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B);
|
||||
|
||||
void kronecker_product(const Ref<MLPPMatrix> &B);
|
||||
Ref<MLPPMatrix> kronecker_productn(const Ref<MLPPMatrix> &B) const;
|
||||
void kronecker_productb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B);
|
||||
|
||||
void element_wise_division(const Ref<MLPPMatrix> &B);
|
||||
Ref<MLPPMatrix> element_wise_divisionn(const Ref<MLPPMatrix> &B) const;
|
||||
void element_wise_divisionb(const Ref<MLPPMatrix> &A, const Ref<MLPPMatrix> &B);
|
||||
|
||||
Ref<MLPPMatrix> transposenm(const Ref<MLPPMatrix> &A);
|
||||
Ref<MLPPMatrix> scalar_multiplynm(real_t scalar, const Ref<MLPPMatrix> &A);
|
||||
|
Loading…
Reference in New Issue
Block a user