2023-01-23 21:13:26 +01:00
|
|
|
//
|
|
|
|
// kNN.cpp
|
|
|
|
//
|
|
|
|
// Created by Marc Melikyan on 10/2/20.
|
|
|
|
//
|
|
|
|
|
2023-01-24 18:12:23 +01:00
|
|
|
#include "knn.h"
|
|
|
|
#include "../lin_alg/lin_alg.h"
|
|
|
|
#include "../utilities/utilities.h"
|
2023-01-23 21:13:26 +01:00
|
|
|
|
2023-01-28 01:02:57 +01:00
|
|
|
#include "core/containers/hash_map.h"
|
|
|
|
#include "core/containers/vector.h"
|
2023-01-23 21:13:26 +01:00
|
|
|
|
2023-01-28 01:02:57 +01:00
|
|
|
// Accessor for the training feature matrix; each row is one stored sample.
Ref<MLPPMatrix> MLPPKNN::get_input_set() {
	return this->_input_set;
}
|
|
|
|
void MLPPKNN::set_input_set(const Ref<MLPPMatrix> &val) {
|
|
|
|
_input_set = val;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Accessor for the training class labels.
Ref<MLPPVector> MLPPKNN::get_output_set() {
	return this->_output_set;
}
|
|
|
|
void MLPPKNN::set_output_set(const Ref<MLPPVector> &val) {
|
|
|
|
_output_set = val;
|
|
|
|
}
|
2023-01-24 19:20:18 +01:00
|
|
|
|
2023-01-28 01:02:57 +01:00
|
|
|
int MLPPKNN::get_k() {
|
|
|
|
return _k;
|
|
|
|
}
|
|
|
|
void MLPPKNN::set_k(const int val) {
|
|
|
|
_k = val;
|
2023-01-24 19:00:54 +01:00
|
|
|
}
|
|
|
|
|
2023-01-28 01:02:57 +01:00
|
|
|
// Classifies every row of X with model_test() and returns one predicted
// label per row, in row order. Reports an error and returns an empty array
// when X is not a valid reference.
PoolIntArray MLPPKNN::model_set_test(const Ref<MLPPMatrix> &X) {
	ERR_FAIL_COND_V(!X.is_valid(), PoolIntArray());

	// A single scratch vector is filled with each row in turn.
	Ref<MLPPVector> row;
	row.instance();

	const int row_count = X->size().y;

	PoolIntArray predictions;
	predictions.resize(row_count);

	for (int r = 0; r < row_count; ++r) {
		X->get_row_into_mlpp_vector(r, row);

		const int label = model_test(row);
		predictions.set(r, label);
	}

	return predictions;
}
|
|
|
|
|
2023-01-28 01:02:57 +01:00
|
|
|
int MLPPKNN::model_test(const Ref<MLPPVector> &x) {
|
|
|
|
return determine_class(nearest_neighbors(x));
|
2023-01-24 19:00:54 +01:00
|
|
|
}
|
|
|
|
|
2023-01-27 13:01:16 +01:00
|
|
|
// Evaluates the model on its own training data: predicts a label for every
// row of _input_set and rates the predictions against _output_set using
// MLPPUtilities::performance_pool_int_array_vec().
real_t MLPPKNN::score() {
	MLPPUtilities util;

	const PoolIntArray predictions = model_set_test(_input_set);

	return util.performance_pool_int_array_vec(predictions, _output_set);
}
|
|
|
|
|
2023-01-28 01:02:57 +01:00
|
|
|
// Starts with k = 0; callers are expected to configure k via set_k().
MLPPKNN::MLPPKNN() :
		_k(0) {
}
|
|
|
|
|
|
|
|
// Nothing to release explicitly; the Ref<> members clean up after themselves.
MLPPKNN::~MLPPKNN() {
}
|
2023-01-23 21:13:26 +01:00
|
|
|
|
2023-01-28 01:02:57 +01:00
|
|
|
// Private Model Functions
|
|
|
|
// Finds the _k rows of _input_set closest to x by Euclidean distance.
//
// Returns their row indices (into _input_set), ordered from nearest to
// farthest. On equal distances the lower row index is chosen first. Fewer
// than _k indices are returned when the input set has fewer than _k rows.
// Reports an error and returns an empty array when _input_set or x is invalid.
PoolIntArray MLPPKNN::nearest_neighbors(const Ref<MLPPVector> &x) {
	ERR_FAIL_COND_V(!_input_set.is_valid(), PoolIntArray());
	ERR_FAIL_COND_V(!x.is_valid(), PoolIntArray());

	MLPPLinAlg alg;

	// The nearest neighbors (row indices), nearest first.
	PoolIntArray knn;

	// Rows that have already been selected; they are ignored in later passes.
	HashMap<int, bool> skip_map;

	// Scratch vector that receives each candidate row.
	Ref<MLPPVector> row;
	row.instance();

	int iuss = _input_set->size().y;

	// Perform this loop until all k nearest neighbors are found and appended:
	// every pass selects the closest row that has not been selected yet.
	for (int i = 0; i < _k; ++i) {
		// -1 marks "no candidate yet". The baseline must come from a row that is
		// still eligible: comparing against a fixed, already-selected row (such
		// as row 0) would stop the search after the first neighbor.
		int neighbor = -1;
		real_t neighbor_dist = 0;

		for (int j = 0; j < iuss; ++j) {
			if (skip_map.has(j)) {
				continue;
			}

			_input_set->get_row_into_mlpp_vector(j, row);
			real_t dist = alg.euclidean_distance(x, row);

			// Strict '<' keeps the lowest index among equally distant rows.
			if (neighbor == -1 || dist < neighbor_dist) {
				neighbor = j;
				neighbor_dist = dist;
			}
		}

		if (neighbor == -1) {
			// Every row has already been selected; no more neighbors exist.
			break;
		}

		knn.push_back(neighbor);
		skip_map.set(neighbor, true);
	}

	return knn;
}
|
2023-01-28 01:02:57 +01:00
|
|
|
|
|
|
|
int MLPPKNN::determine_class(const PoolIntArray &knn) {
|
|
|
|
ERR_FAIL_COND_V(!_output_set.is_valid(), 0);
|
|
|
|
|
|
|
|
int output_set_size = _output_set->size();
|
|
|
|
|
|
|
|
ERR_FAIL_COND_V(output_set_size == 0, 0);
|
|
|
|
|
|
|
|
const real_t *os_ptr = _output_set->ptr();
|
|
|
|
|
|
|
|
HashMap<int, int> class_nums;
|
|
|
|
|
|
|
|
for (int i = 0; i < output_set_size; ++i) {
|
|
|
|
class_nums[static_cast<int>(os_ptr[i])] = 0;
|
|
|
|
}
|
|
|
|
|
|
|
|
PoolIntArray::Read knn_r = knn.read();
|
|
|
|
const int *knn_ptr = knn_r.ptr();
|
|
|
|
int knn_size = knn.size();
|
|
|
|
|
|
|
|
for (int i = 0; i < knn_size; ++i) {
|
|
|
|
for (int j = 0; j < output_set_size; j++) {
|
|
|
|
int opj = static_cast<int>(os_ptr[j]);
|
|
|
|
if (knn_ptr[i] == opj) {
|
|
|
|
class_nums[opj]++;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
int final_class = static_cast<int>(os_ptr[0]);
|
|
|
|
int max = class_nums[final_class];
|
|
|
|
|
|
|
|
for (int i = 0; i < output_set_size; ++i) {
|
|
|
|
int opi = static_cast<int>(os_ptr[i]);
|
|
|
|
|
|
|
|
if (class_nums[opi] > max) {
|
|
|
|
max = class_nums[opi];
|
|
|
|
final_class = opi;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return final_class;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Exposes MLPPKNN's properties and prediction API to the scripting layer.
// Registration order is kept stable so the property list order is unchanged.
void MLPPKNN::_bind_methods() {
	// --- input_set: training feature matrix ---
	ClassDB::bind_method(D_METHOD("get_input_set"), &MLPPKNN::get_input_set);
	ClassDB::bind_method(D_METHOD("set_input_set", "value"), &MLPPKNN::set_input_set);
	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "input_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "set_input_set", "get_input_set");

	// --- output_set: training class labels ---
	ClassDB::bind_method(D_METHOD("get_output_set"), &MLPPKNN::get_output_set);
	ClassDB::bind_method(D_METHOD("set_output_set", "value"), &MLPPKNN::set_output_set);
	ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "output_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPVector"), "set_output_set", "get_output_set");

	// --- k: neighbor count ---
	ClassDB::bind_method(D_METHOD("get_k"), &MLPPKNN::get_k);
	ClassDB::bind_method(D_METHOD("set_k", "value"), &MLPPKNN::set_k);
	ADD_PROPERTY(PropertyInfo(Variant::INT, "k"), "set_k", "get_k");

	// --- prediction / evaluation ---
	ClassDB::bind_method(D_METHOD("model_set_test", "X"), &MLPPKNN::model_set_test);
	ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPKNN::model_test);
	ClassDB::bind_method(D_METHOD("score"), &MLPPKNN::score);
}
|