MLPPVector math api rework pt4.

This commit is contained in:
Relintai 2023-04-24 18:11:22 +02:00
parent 63393dd662
commit bd9c03f1f2
4 changed files with 91 additions and 46 deletions

View File

@ -1121,6 +1121,26 @@ Ref<MLPPMatrix> MLPPMatrix::mat_vec_addnm(const Ref<MLPPMatrix> &A, const Ref<ML
return ret; return ret;
} }
Ref<MLPPMatrix> MLPPMatrix::outer_product(const Ref<MLPPVector> &a, const Ref<MLPPVector> &b) {
Ref<MLPPMatrix> C;
C.instance();
Size2i size = Size2i(b->size(), a->size());
C->resize(size);
const real_t *a_ptr = a->ptr();
const real_t *b_ptr = b->ptr();
for (int i = 0; i < size.y; ++i) {
real_t curr_a = a_ptr[i];
for (int j = 0; j < size.x; ++j) {
C->set_element(i, j, curr_a * b_ptr[j]);
}
}
return C;
}
Ref<MLPPMatrix> MLPPMatrix::diagnm(const Ref<MLPPVector> &a) { Ref<MLPPMatrix> MLPPMatrix::diagnm(const Ref<MLPPVector> &a) {
int a_size = a->size(); int a_size = a->size();

View File

@ -701,6 +701,8 @@ public:
Ref<MLPPVector> mat_vec_multnv(const Ref<MLPPMatrix> &A, const Ref<MLPPVector> &b); Ref<MLPPVector> mat_vec_multnv(const Ref<MLPPMatrix> &A, const Ref<MLPPVector> &b);
Ref<MLPPMatrix> mat_vec_addnm(const Ref<MLPPMatrix> &A, const Ref<MLPPVector> &b); Ref<MLPPMatrix> mat_vec_addnm(const Ref<MLPPMatrix> &A, const Ref<MLPPVector> &b);
Ref<MLPPMatrix> outer_product(const Ref<MLPPVector> &a, const Ref<MLPPVector> &b); // This multiplies a, bT
// set_diagonal (just sets diagonal), set_as_diagonal (zeros, then sets diagonal to vec) // set_diagonal (just sets diagonal), set_as_diagonal (zeros, then sets diagonal to vec)
// Also a variant that copies // Also a variant that copies
Ref<MLPPMatrix> diagnm(const Ref<MLPPVector> &a); Ref<MLPPMatrix> diagnm(const Ref<MLPPVector> &a);

View File

@ -838,37 +838,33 @@ std::vector<std::vector<real_t>> MLPPVector::round(std::vector<std::vector<real_
} }
*/ */
real_t MLPPVector::euclidean_distance(const Ref<MLPPVector> &a, const Ref<MLPPVector> &b) { real_t MLPPVector::euclidean_distance(const Ref<MLPPVector> &b) {
ERR_FAIL_COND_V(!a.is_valid() || !b.is_valid(), 0); ERR_FAIL_COND_V(!b.is_valid(), 0);
int a_size = a->size(); ERR_FAIL_COND_V(_size != b->size(), 0);
ERR_FAIL_COND_V(a_size != b->size(), 0); const real_t *aa = ptr();
const real_t *aa = a->ptr();
const real_t *ba = b->ptr(); const real_t *ba = b->ptr();
real_t dist = 0; real_t dist = 0;
for (int i = 0; i < a_size; i++) { for (int i = 0; i < _size; i++) {
dist += (aa[i] - ba[i]) * (aa[i] - ba[i]); dist += (aa[i] - ba[i]) * (aa[i] - ba[i]);
} }
return Math::sqrt(dist); return Math::sqrt(dist);
} }
real_t MLPPVector::euclidean_distance_squared(const Ref<MLPPVector> &a, const Ref<MLPPVector> &b) { real_t MLPPVector::euclidean_distance_squared(const Ref<MLPPVector> &b) {
ERR_FAIL_COND_V(!a.is_valid() || !b.is_valid(), 0); ERR_FAIL_COND_V(!b.is_valid(), 0);
int a_size = a->size(); ERR_FAIL_COND_V(_size != b->size(), 0);
ERR_FAIL_COND_V(a_size != b->size(), 0); const real_t *aa = ptr();
const real_t *aa = a->ptr();
const real_t *ba = b->ptr(); const real_t *ba = b->ptr();
real_t dist = 0; real_t dist = 0;
for (int i = 0; i < a_size; i++) { for (int i = 0; i < _size; i++) {
dist += (aa[i] - ba[i]) * (aa[i] - ba[i]); dist += (aa[i] - ba[i]) * (aa[i] - ba[i]);
} }
@ -887,26 +883,21 @@ real_t MLPPVector::norm_2(std::vector<std::vector<real_t>> A) {
} }
*/ */
real_t MLPPVector::norm_sqv(const Ref<MLPPVector> &a) { real_t MLPPVector::norm_sq() {
ERR_FAIL_COND_V(!a.is_valid(), 0); const real_t *a_ptr = ptr();
int size = a->size();
const real_t *a_ptr = a->ptr();
real_t n_sq = 0; real_t n_sq = 0;
for (int i = 0; i < size; ++i) { for (int i = 0; i < _size; ++i) {
n_sq += a_ptr[i] * a_ptr[i]; n_sq += a_ptr[i] * a_ptr[i];
} }
return n_sq; return n_sq;
} }
real_t MLPPVector::sum_elementsv(const Ref<MLPPVector> &a) { real_t MLPPVector::sum_elements() {
int a_size = a->size(); const real_t *a_ptr = ptr();
const real_t *a_ptr = a->ptr();
real_t sum = 0; real_t sum = 0;
for (int i = 0; i < a_size; ++i) { for (int i = 0; i < _size; ++i) {
sum += a_ptr[i]; sum += a_ptr[i];
} }
return sum; return sum;
@ -918,8 +909,22 @@ real_t MLPPVector::cosineSimilarity(std::vector<real_t> a, std::vector<real_t> b
} }
*/ */
Ref<MLPPVector> MLPPVector::subtract_matrix_rowsnv(const Ref<MLPPVector> &a, const Ref<MLPPMatrix> &B) { void MLPPVector::subtract_matrix_rows(const Ref<MLPPMatrix> &B) {
Ref<MLPPVector> c = a->duplicate(); Size2i b_size = B->size();
ERR_FAIL_COND(b_size.x != size());
const real_t *b_ptr = B->ptr();
real_t *c_ptr = ptrw();
for (int i = 0; i < b_size.y; ++i) {
for (int j = 0; j < b_size.x; ++j) {
c_ptr[j] -= b_ptr[B->calculate_index(i, j)];
}
}
}
Ref<MLPPVector> MLPPVector::subtract_matrix_rowsn(const Ref<MLPPMatrix> &B) {
Ref<MLPPVector> c = duplicate();
Size2i b_size = B->size(); Size2i b_size = B->size();
@ -936,20 +941,36 @@ Ref<MLPPVector> MLPPVector::subtract_matrix_rowsnv(const Ref<MLPPVector> &a, con
return c; return c;
} }
void MLPPVector::subtract_matrix_rowsb(const Ref<MLPPVector> &a, const Ref<MLPPMatrix> &B) {
Size2i b_size = B->size();
Ref<MLPPMatrix> MLPPVector::outer_product(const Ref<MLPPVector> &a, const Ref<MLPPVector> &b) { ERR_FAIL_COND(b_size.x != a->size());
set_from_mlpp_vector(a);
const real_t *b_ptr = B->ptr();
real_t *c_ptr = ptrw();
for (int i = 0; i < b_size.y; ++i) {
for (int j = 0; j < b_size.x; ++j) {
c_ptr[j] -= b_ptr[B->calculate_index(i, j)];
}
}
}
Ref<MLPPMatrix> MLPPVector::outer_product(const Ref<MLPPVector> &b) {
Ref<MLPPMatrix> C; Ref<MLPPMatrix> C;
C.instance(); C.instance();
Size2i size = Size2i(b->size(), a->size()); Size2i sm = Size2i(b->size(), size());
C->resize(size); C->resize(sm);
const real_t *a_ptr = a->ptr(); const real_t *a_ptr = ptr();
const real_t *b_ptr = b->ptr(); const real_t *b_ptr = b->ptr();
for (int i = 0; i < size.y; ++i) { for (int i = 0; i < sm.y; ++i) {
real_t curr_a = a_ptr[i]; real_t curr_a = a_ptr[i];
for (int j = 0; j < size.x; ++j) { for (int j = 0; j < sm.x; ++j) {
C->set_element(i, j, curr_a * b_ptr[j]); C->set_element(i, j, curr_a * b_ptr[j]);
} }
} }
@ -957,19 +978,17 @@ Ref<MLPPMatrix> MLPPVector::outer_product(const Ref<MLPPVector> &a, const Ref<ML
return C; return C;
} }
Ref<MLPPMatrix> MLPPVector::diagnm(const Ref<MLPPVector> &a) { Ref<MLPPMatrix> MLPPVector::diagnm() {
int a_size = a->size();
Ref<MLPPMatrix> B; Ref<MLPPMatrix> B;
B.instance(); B.instance();
B->resize(Size2i(a_size, a_size)); B->resize(Size2i(_size, _size));
B->fill(0); B->fill(0);
const real_t *a_ptr = a->ptr(); const real_t *a_ptr = ptr();
real_t *b_ptr = B->ptrw(); real_t *b_ptr = B->ptrw();
for (int i = 0; i < a_size; ++i) { for (int i = 0; i < _size; ++i) {
b_ptr[B->calculate_index(i, i)] = a_ptr[i]; b_ptr[B->calculate_index(i, i)] = a_ptr[i];
} }

View File

@ -423,24 +423,28 @@ public:
//std::vector<real_t> round(std::vector<real_t> a); //std::vector<real_t> round(std::vector<real_t> a);
real_t euclidean_distance(const Ref<MLPPVector> &a, const Ref<MLPPVector> &b); real_t euclidean_distance(const Ref<MLPPVector> &b);
real_t euclidean_distance_squared(const Ref<MLPPVector> &a, const Ref<MLPPVector> &b); real_t euclidean_distance_squared(const Ref<MLPPVector> &b);
/* /*
real_t norm_2(std::vector<real_t> a); real_t norm_2(std::vector<real_t> a);
*/ */
real_t norm_sqv(const Ref<MLPPVector> &a); real_t norm_sq();
real_t sum_elementsv(const Ref<MLPPVector> &a); real_t sum_elements();
//real_t cosineSimilarity(std::vector<real_t> a, std::vector<real_t> b); //real_t cosineSimilarity(std::vector<real_t> a, std::vector<real_t> b);
Ref<MLPPVector> subtract_matrix_rowsnv(const Ref<MLPPVector> &a, const Ref<MLPPMatrix> &B); void subtract_matrix_rows(const Ref<MLPPMatrix> &B);
Ref<MLPPMatrix> outer_product(const Ref<MLPPVector> &a, const Ref<MLPPVector> &b); // This multiplies a, bT Ref<MLPPVector> subtract_matrix_rowsn(const Ref<MLPPMatrix> &B);
void subtract_matrix_rowsb(const Ref<MLPPVector> &a, const Ref<MLPPMatrix> &B);
// This multiplies a, bT
Ref<MLPPMatrix> outer_product(const Ref<MLPPVector> &b);
// as_diagonal_matrix / to_diagonal_matrix // as_diagonal_matrix / to_diagonal_matrix
Ref<MLPPMatrix> diagnm(const Ref<MLPPVector> &a); Ref<MLPPMatrix> diagnm();
String to_string(); String to_string();