pmlpp/mlpp/pca/pca.cpp

126 lines
3.1 KiB
C++
Raw Normal View History

//
// PCA.cpp
//
// Created by Marc Melikyan on 10/2/20.
//
2023-01-24 18:12:23 +01:00
#include "pca.h"
#include "../data/data.h"
2023-01-24 19:00:54 +01:00
#include "../lin_alg/lin_alg.h"
2023-02-08 01:26:37 +01:00
// Returns the matrix that principal_components() reads as its input.
Ref<MLPPMatrix> MLPPPCA::get_input_set() {
	return this->_input_set;
}
void MLPPPCA::set_input_set(const Ref<MLPPMatrix> &val) {
_input_set = val;
}
int MLPPPCA::get_k() {
return _k;
}
void MLPPPCA::set_k(const int val) {
_k = val;
}
// Projects the input set onto its first _k principal components.
//
// The covariance of _input_set is decomposed with SVD, the first _k columns
// of U are copied into _u_reduce, and the mean-centered input is projected as
// Z = U_reduce^T * X_normalized.
// _x_normalized, _u_reduce and _z are stored on the instance so that score()
// can reuse them.
//
// Returns the projected matrix (the cached _z member). Returns an empty Ref
// when the input set is missing, when _k is not positive, or when _k exceeds
// the number of columns of U.
Ref<MLPPMatrix> MLPPPCA::principal_components() {
	ERR_FAIL_COND_V(!_input_set.is_valid() || _k <= 0, Ref<MLPPMatrix>());

	MLPPLinAlg alg;
	MLPPData data;

	MLPPLinAlg::SVDResult svr_res = alg.svd(alg.covnm(_input_set));

	// Size2i here is (x = columns, y = rows).
	Size2i svr_res_u_size = svr_res.U->size();

	// Requesting more components than U has columns would index past U below.
	ERR_FAIL_COND_V(_k > svr_res_u_size.x, Ref<MLPPMatrix>());

	_x_normalized = data.mean_centering(_input_set);

	// _u_reduce becomes rows(U) x _k, holding the first _k columns of U.
	_u_reduce->resize(Size2i(_k, svr_res_u_size.y));
	for (int i = 0; i < _k; ++i) {
		for (int j = 0; j < svr_res_u_size.y; ++j) {
			_u_reduce->set_element(j, i, svr_res.U->get_element(j, i));
		}
	}

	_z = alg.matmultnm(alg.transposenm(_u_reduce), _x_normalized);

	return _z;
}
2023-02-08 01:26:37 +01:00
2023-01-24 19:00:54 +01:00
// Tells us the fraction (a ratio, not a percentage) of variance retained by the projection.
2023-01-27 13:01:16 +01:00
// Returns 1 - (mean squared reconstruction error) / (mean squared row norm of
// X_normalized), where the reconstruction is X_approx = U_reduce * Z.
//
// principal_components() must have run successfully first, so that
// _x_normalized, _u_reduce and _z are populated. Otherwise 0 is returned.
real_t MLPPPCA::score() {
	ERR_FAIL_COND_V(!_input_set.is_valid() || _k <= 0, 0);

	// The cached matrices are empty until principal_components() has run.
	// Without this guard, the per-row averages below would divide by zero.
	Size2i x_normalized_size = _x_normalized->size();
	ERR_FAIL_COND_V(_z->size().y == 0 || x_normalized_size.y == 0, 0);

	MLPPLinAlg alg;

	Ref<MLPPMatrix> x_approx = alg.matmultnm(_u_reduce, _z);
	// The reconstruction must line up row-for-row with the centered input.
	ERR_FAIL_COND_V(!x_approx.is_valid() || x_approx->size() != x_normalized_size, 0);

	int row_count = x_normalized_size.y;

	// Reusable row buffers, so the loop below does not allocate per row.
	Ref<MLPPVector> x_approx_row_tmp;
	x_approx_row_tmp.instance();
	x_approx_row_tmp->resize(x_approx->size().x);

	Ref<MLPPVector> x_normalized_row_tmp;
	x_normalized_row_tmp.instance();
	x_normalized_row_tmp->resize(x_normalized_size.x);

	real_t num = 0;
	real_t den = 0;

	// Single pass: each accumulator still sums rows in the same order.
	for (int i = 0; i < row_count; ++i) {
		_x_normalized->get_row_into_mlpp_vector(i, x_normalized_row_tmp);
		x_approx->get_row_into_mlpp_vector(i, x_approx_row_tmp);

		num += alg.norm_sqv(alg.subtractionnv(x_normalized_row_tmp, x_approx_row_tmp));
		den += alg.norm_sqv(x_normalized_row_tmp);
	}

	num /= row_count;
	den /= row_count;

	if (den == 0) {
		den += 1e-10; // For numerical sanity, so as not to receive a domain error
	}

	return 1 - num / den;
}
2023-01-24 19:20:18 +01:00
2023-02-08 01:26:37 +01:00
// Creates a PCA over p_input_set that keeps p_k components.
MLPPPCA::MLPPPCA(const Ref<MLPPMatrix> &p_input_set, int p_k) {
	_input_set = p_input_set;
	_k = p_k;

	// Allocate the working matrices up front so they are always valid Refs.
	// principal_components() resizes _u_reduce in place.
	_z.instance();
	_u_reduce.instance();
	_x_normalized.instance();
}
// Creates a PCA with no input set and _k = 0. The working matrices are
// allocated by the two-argument constructor this delegates to.
MLPPPCA::MLPPPCA() :
		MLPPPCA(Ref<MLPPMatrix>(), 0) {
}
// No manual cleanup is done here. The Ref<> members used in this file drop
// their references on their own.
MLPPPCA::~MLPPPCA() {}
// Registers the accessors, properties and computation methods with ClassDB.
void MLPPPCA::_bind_methods() {
	ClassDB::bind_method(D_METHOD("get_input_set"), &MLPPPCA::get_input_set);
	ClassDB::bind_method(D_METHOD("set_input_set", "val"), &MLPPPCA::set_input_set);
	// The property is named after the data ("input_set"), not after its getter,
	// in line with the "k" property below.
	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "input_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "set_input_set", "get_input_set");

	ClassDB::bind_method(D_METHOD("get_k"), &MLPPPCA::get_k);
	ClassDB::bind_method(D_METHOD("set_k", "val"), &MLPPPCA::set_k);
	ADD_PROPERTY(PropertyInfo(Variant::INT, "k"), "set_k", "get_k");

	ClassDB::bind_method(D_METHOD("principal_components"), &MLPPPCA::principal_components);
	ClassDB::bind_method(D_METHOD("score"), &MLPPPCA::score);
}