pmlpp/mlpp/kmeans/kmeans.cpp

756 lines
20 KiB
C++
Raw Normal View History

2023-12-30 00:41:59 +01:00
/*************************************************************************/
/* kmeans.cpp */
/*************************************************************************/
/* This file is part of: */
/* PMLPP Machine Learning Library */
/* https://github.com/Relintai/pmlpp */
/*************************************************************************/
2023-12-30 00:43:39 +01:00
/* Copyright (c) 2023-present Péter Magyar. */
2023-12-30 00:41:59 +01:00
/* Copyright (c) 2022-2023 Marc Melikyan */
/* */
/* Permission is hereby granted, free of charge, to any person obtaining */
/* a copy of this software and associated documentation files (the */
/* "Software"), to deal in the Software without restriction, including */
/* without limitation the rights to use, copy, modify, merge, publish, */
/* distribute, sublicense, and/or sell copies of the Software, and to */
/* permit persons to whom the Software is furnished to do so, subject to */
/* the following conditions: */
/* */
/* The above copyright notice and this permission notice shall be */
/* included in all copies or substantial portions of the Software. */
/* */
/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
/*************************************************************************/
2023-01-24 18:12:23 +01:00
#include "kmeans.h"
2023-01-24 19:00:54 +01:00
#include "../utilities/utilities.h"
2023-01-29 15:46:55 +01:00
#include "core/math/random_pcg.h"
2023-01-24 19:00:54 +01:00
#include <climits>
#include <iostream>
#include <random>
#include <vector>
2023-01-28 14:35:05 +01:00
// Returns the training data matrix; each row is one sample.
Ref<MLPPMatrix> MLPPKMeans::get_input_set() {
	return _input_set;
}

// Replaces the training data. Centroids computed for the previous data no longer
// apply, so the model is marked as needing initialize() again.
void MLPPKMeans::set_input_set(const Ref<MLPPMatrix> &val) {
	_initialized = false;
	_input_set = val;
}
// Returns the configured number of clusters.
int MLPPKMeans::get_k() {
	return _k;
}

// Sets the number of clusters; the centroid matrix must be rebuilt afterwards,
// so the model is flagged as uninitialized.
void MLPPKMeans::set_k(const int val) {
	_initialized = false;
	_k = val;
}
2023-01-24 19:20:18 +01:00
2023-01-28 14:35:05 +01:00
// Returns which centroid seeding strategy initialize() will use.
MLPPKMeans::MeanType MLPPKMeans::get_mean_type() {
	return _mean_type;
}

// Selects the centroid seeding strategy; takes effect on the next initialize(),
// hence the model is flagged as uninitialized.
void MLPPKMeans::set_mean_type(const MLPPKMeans::MeanType val) {
	_initialized = false;
	_mean_type = val;
}
void MLPPKMeans::initialize() {
2023-01-29 15:46:55 +01:00
ERR_FAIL_COND(!_input_set.is_valid());
2023-01-28 14:35:05 +01:00
if (_mean_type == MEAN_TYPE_KMEANSPP) {
2023-01-29 15:46:55 +01:00
_kmeanspp_initialization();
2023-01-24 19:00:54 +01:00
} else {
2023-01-29 15:46:55 +01:00
_centroid_initialization();
2023-01-24 19:00:54 +01:00
}
2023-01-29 15:46:55 +01:00
_initialized = true;
2023-01-24 19:00:54 +01:00
}
2023-01-28 14:35:05 +01:00
// Maps every row of X to its closest centroid (Euclidean distance).
// Returns a matrix with one row per row of X; row i is a copy of the centroid
// nearest to X's row i (ties resolved toward the lowest centroid index).
// Returns a null Ref if X is invalid, the model is not initialized, there are no
// centroids, or X's column count differs from the centroids' feature count.
Ref<MLPPMatrix> MLPPKMeans::model_set_test(const Ref<MLPPMatrix> &X) {
	ERR_FAIL_COND_V(!X.is_valid(), Ref<MLPPMatrix>());
	ERR_FAIL_COND_V(!_initialized, Ref<MLPPMatrix>());

	int mu_size_y = _mu->size().y;
	int mu_size_x = _mu->size().x;

	ERR_FAIL_COND_V(mu_size_y <= 0, Ref<MLPPMatrix>());
	// Samples must live in the same feature space as the centroids.
	ERR_FAIL_COND_V(X->size().x != mu_size_x, Ref<MLPPMatrix>());

	// Iterate X's own rows (not the training set's), so any sample count works.
	int x_size_y = X->size().y;

	Ref<MLPPMatrix> closest_centroids;
	closest_centroids.instance();
	closest_centroids->resize(Size2i(mu_size_x, x_size_y));

	Ref<MLPPVector> tmp_xiv;
	tmp_xiv.instance();
	tmp_xiv->resize(X->size().x);

	Ref<MLPPVector> tmp_mujv;
	tmp_mujv.instance();
	tmp_mujv->resize(mu_size_x);

	for (int i = 0; i < x_size_y; ++i) {
		X->row_get_into_mlpp_vector(i, tmp_xiv);

		// Loop over the centroids themselves: _r is only sized by _evaluate(),
		// so it can still be empty right after initialize().
		int closest_index = 0;
		real_t closest_dist = Math_INF;

		for (int j = 0; j < mu_size_y; ++j) {
			_mu->row_get_into_mlpp_vector(j, tmp_mujv);

			real_t dist = tmp_xiv->euclidean_distance(tmp_mujv);

			if (dist < closest_dist) {
				closest_dist = dist;
				closest_index = j;
			}
		}

		_mu->row_get_into_mlpp_vector(closest_index, tmp_mujv);
		closest_centroids->row_set_mlpp_vector(i, tmp_mujv);
	}

	return closest_centroids;
}
// Returns a copy of the centroid closest (Euclidean distance) to the sample x.
// Ties are resolved toward the lowest centroid index.
// Returns a null Ref if x is invalid, the model is not initialized, there are no
// centroids, or x's length differs from the centroids' feature count.
Ref<MLPPVector> MLPPKMeans::model_test(const Ref<MLPPVector> &x) {
	ERR_FAIL_COND_V(!x.is_valid(), Ref<MLPPVector>());
	ERR_FAIL_COND_V(!_initialized, Ref<MLPPVector>());

	int mu_size_y = _mu->size().y;
	int mu_size_x = _mu->size().x;

	ERR_FAIL_COND_V(mu_size_y <= 0, Ref<MLPPVector>());
	ERR_FAIL_COND_V(x->size() != mu_size_x, Ref<MLPPVector>());

	Ref<MLPPVector> tmp_mujv;
	tmp_mujv.instance();
	tmp_mujv->resize(mu_size_x);

	int closest_index = 0;
	real_t closest_dist = Math_INF;

	for (int j = 0; j < mu_size_y; ++j) {
		_mu->row_get_into_mlpp_vector(j, tmp_mujv);

		// Measure each centroid once and remember the best distance, instead of
		// re-measuring the current best on every iteration.
		real_t dist = x->euclidean_distance(tmp_mujv);

		if (dist < closest_dist) {
			closest_dist = dist;
			closest_index = j;
		}
	}

	Ref<MLPPVector> closest_centroid;
	closest_centroid.instance();
	closest_centroid->resize(mu_size_x);
	_mu->row_get_into_mlpp_vector(closest_index, closest_centroid);

	return closest_centroid;
}
void MLPPKMeans::train(int epoch_num, bool UI) {
2023-01-29 15:46:55 +01:00
ERR_FAIL_COND(!_input_set.is_valid());
if (!_initialized) {
initialize();
}
real_t cost_prev = 0;
int epoch = 1;
_evaluate();
while (true) {
// STEPS OF THE ALGORITHM
// 1. DETERMINE r_nk
// 2. DETERMINE J
// 3. DETERMINE mu_k
// STOP IF CONVERGED, ELSE REPEAT
cost_prev = _cost();
_compute_mu();
_evaluate();
// UI PORTION
if (UI) {
MLPPUtilities::cost_info(epoch, cost_prev, _cost());
}
epoch++;
if (epoch > epoch_num) {
break;
}
}
2023-01-28 14:35:05 +01:00
}
2023-01-29 15:46:55 +01:00
2023-01-28 14:35:05 +01:00
// Returns the current k-means objective J (sum of squared distances of every
// sample to its assigned centroid); lower is better.
real_t MLPPKMeans::score() {
	const real_t objective = _cost();
	return objective;
}
2023-01-29 15:46:55 +01:00
2023-01-28 14:35:05 +01:00
// Computes the silhouette coefficient of every row of the input set under the
// current centroids. Returns a vector with one entry per input row.
//
// For row i whose nearest centroid (ties toward the lower index) defines cluster C:
//   a(i) = mean distance from i to the other rows of C,
//   b(i) = smallest mean distance from i to the rows of any other non-empty cluster,
//   s(i) = (b(i) - a(i)) / max(a(i), b(i)).
// s(i) is 0 when C has a single member, when no other non-empty cluster exists,
// or when a(i) and b(i) are both 0.
// Returns a null Ref if the model is not initialized or has no centroids.
Ref<MLPPVector> MLPPKMeans::silhouette_scores() {
	ERR_FAIL_COND_V(!_initialized, Ref<MLPPVector>());
	ERR_FAIL_COND_V(!_input_set.is_valid(), Ref<MLPPVector>());

	int input_set_size_y = _input_set->size().y;
	int input_set_size_x = _input_set->size().x;
	int mu_size_y = _mu->size().y;

	ERR_FAIL_COND_V(mu_size_y <= 0, Ref<MLPPVector>());

	Ref<MLPPVector> silhouette_scores;
	silhouette_scores.instance();
	silhouette_scores->resize(input_set_size_y);

	Ref<MLPPVector> input_set_i_tempv;
	input_set_i_tempv.instance();
	input_set_i_tempv->resize(input_set_size_x);

	Ref<MLPPVector> input_set_j_tempv;
	input_set_j_tempv.instance();
	input_set_j_tempv->resize(input_set_size_x);

	Ref<MLPPVector> mu_j_tempv;
	mu_j_tempv.instance();
	mu_j_tempv->resize(_mu->size().x);

	// Track membership by centroid index rather than by comparing centroid
	// vectors, so two identical centroids still form distinct clusters. _r is
	// not used because it is only filled in by _evaluate() during training.
	std::vector<int> assignments(input_set_size_y, 0);
	std::vector<int> cluster_sizes(mu_size_y, 0);

	for (int i = 0; i < input_set_size_y; ++i) {
		_input_set->row_get_into_mlpp_vector(i, input_set_i_tempv);

		int closest_index = 0;
		real_t closest_dist = Math_INF;

		for (int j = 0; j < mu_size_y; ++j) {
			_mu->row_get_into_mlpp_vector(j, mu_j_tempv);

			real_t dist = input_set_i_tempv->euclidean_distance(mu_j_tempv);

			if (dist < closest_dist) {
				closest_dist = dist;
				closest_index = j;
			}
		}

		assignments[i] = closest_index;
		++cluster_sizes[closest_index];
	}

	// distance_sums[c]: summed distance from the current row i to every other
	// row assigned to cluster c.
	std::vector<real_t> distance_sums;

	for (int i = 0; i < input_set_size_y; ++i) {
		int own_cluster = assignments[i];

		// a(i) is undefined for a singleton cluster; by convention s(i) = 0.
		if (cluster_sizes[own_cluster] <= 1) {
			silhouette_scores->element_set(i, 0);
			continue;
		}

		_input_set->row_get_into_mlpp_vector(i, input_set_i_tempv);

		distance_sums.assign(mu_size_y, real_t(0));

		for (int j = 0; j < input_set_size_y; ++j) {
			if (i == j) {
				continue;
			}

			_input_set->row_get_into_mlpp_vector(j, input_set_j_tempv);

			distance_sums[assignments[j]] += input_set_i_tempv->euclidean_distance(input_set_j_tempv);
		}

		// a(i): mean over the OTHER members of the own cluster (size - 1 of them).
		real_t a = distance_sums[own_cluster] / real_t(cluster_sizes[own_cluster] - 1);

		// b(i): mean distance to the closest other non-empty cluster.
		real_t b = Math_INF;

		for (int c = 0; c < mu_size_y; ++c) {
			if (c == own_cluster || cluster_sizes[c] == 0) {
				continue;
			}

			real_t mean_dist = distance_sums[c] / real_t(cluster_sizes[c]);

			if (mean_dist < b) {
				b = mean_dist;
			}
		}

		real_t denominator = fmax(a, b);

		// No other cluster to compare against, or a == b == 0 (would be 0/0).
		if (b == Math_INF || denominator == 0) {
			silhouette_scores->element_set(i, 0);
		} else {
			silhouette_scores->element_set(i, (b - a) / denominator);
		}
	}

	return silhouette_scores;
}
// Builds an untrained model: empty centroid (_mu) and assignment (_r) matrices,
// k = 0, plain random-centroid seeding, and not yet initialized.
MLPPKMeans::MLPPKMeans() {
	_mean_type = MEAN_TYPE_CENTROID;
	_initialized = false;
	_k = 0;
	_accuracy_threshold = 0;

	_r.instance();
	_mu.instance();
}
// No explicit cleanup: every member is released by its own destructor.
MLPPKMeans::~MLPPKMeans() {
	// Intentionally empty.
}
2023-01-29 15:46:55 +01:00
// This simply computes r_nk
2023-01-28 14:35:05 +01:00
void MLPPKMeans::_evaluate() {
2023-01-29 15:46:55 +01:00
ERR_FAIL_COND(!_initialized);
if (_r->size() != Size2i(_k, _input_set->size().y)) {
_r->resize(Size2i(_k, _input_set->size().y));
}
int r_size_y = _r->size().y;
int r_size_x = _r->size().x;
Ref<MLPPVector> closest_centroid;
closest_centroid.instance();
closest_centroid->resize(_mu->size().x);
Ref<MLPPVector> input_set_i_tempv;
input_set_i_tempv.instance();
input_set_i_tempv->resize(_input_set->size().x);
Ref<MLPPVector> mu_j_tempv;
mu_j_tempv.instance();
mu_j_tempv->resize(_mu->size().x);
real_t closest_centroid_current_dist = 0;
int closest_centroid_index = 0;
_r->fill(0);
for (int i = 0; i < r_size_y; ++i) {
2023-04-29 15:07:30 +02:00
_mu->row_get_into_mlpp_vector(0, closest_centroid);
_input_set->row_get_into_mlpp_vector(i, input_set_i_tempv);
2023-01-29 15:46:55 +01:00
2023-04-30 17:39:00 +02:00
closest_centroid_current_dist = input_set_i_tempv->euclidean_distance(closest_centroid);
2023-01-29 15:46:55 +01:00
for (int j = 0; j < r_size_x; ++j) {
2023-04-29 15:07:30 +02:00
_mu->row_get_into_mlpp_vector(j, mu_j_tempv);
2023-01-29 15:46:55 +01:00
2023-04-30 17:39:00 +02:00
bool is_centroid_closer = input_set_i_tempv->euclidean_distance(mu_j_tempv) < closest_centroid_current_dist;
2023-01-29 15:46:55 +01:00
if (is_centroid_closer) {
2023-04-29 15:07:30 +02:00
_mu->row_get_into_mlpp_vector(j, closest_centroid);
2023-04-30 17:39:00 +02:00
closest_centroid_current_dist = input_set_i_tempv->euclidean_distance(closest_centroid);
2023-01-29 15:46:55 +01:00
closest_centroid_index = j;
}
}
_r->element_set(i, closest_centroid_index, 1);
2023-01-29 15:46:55 +01:00
}
2023-01-28 14:35:05 +01:00
}
2023-01-29 15:46:55 +01:00
// This simply computes or re-computes mu_k
2023-01-28 14:35:05 +01:00
void MLPPKMeans::_compute_mu() {
2023-01-29 15:46:55 +01:00
int mu_size_y = _mu->size().y;
int r_size_y = _r->size().y;
Ref<MLPPVector> num;
num.instance();
num->resize(_r->size().x);
Ref<MLPPVector> input_set_j_tempv;
input_set_j_tempv.instance();
input_set_j_tempv->resize(_input_set->size().x);
Ref<MLPPVector> mat_tempv;
mat_tempv.instance();
mat_tempv->resize(_input_set->size().x);
Ref<MLPPVector> mu_tempv;
mu_tempv.instance();
mu_tempv->resize(_mu->size().x);
for (int i = 0; i < mu_size_y; ++i) {
num->fill(0);
real_t den = 0;
for (int j = 0; j < r_size_y; ++j) {
2023-04-29 15:07:30 +02:00
_input_set->row_get_into_mlpp_vector(j, input_set_j_tempv);
2023-01-29 15:46:55 +01:00
real_t r_j_i = _r->element_get(j, i);
2023-01-29 15:46:55 +01:00
2023-04-30 17:39:00 +02:00
mat_tempv->scalar_multiplyb(_r->element_get(j, i), input_set_j_tempv);
num->add(mat_tempv);
2023-01-29 15:46:55 +01:00
den += r_j_i;
}
2023-04-30 17:39:00 +02:00
mu_tempv->scalar_multiplyb(real_t(1) / real_t(den), num);
2023-01-29 15:46:55 +01:00
2023-04-29 15:07:30 +02:00
_mu->row_set_mlpp_vector(i, mu_tempv);
2023-01-29 15:46:55 +01:00
}
2023-01-28 14:35:05 +01:00
}
2023-01-29 15:46:55 +01:00
void MLPPKMeans::_centroid_initialization() {
RandomPCG rand;
rand.randomize();
Size2i mu_size = Size2i(_input_set->size().x, _k);
if (_mu->size() != mu_size) {
_mu->resize(mu_size);
}
Ref<MLPPVector> mu_tempv;
mu_tempv.instance();
mu_tempv->resize(_mu->size().x);
int input_set_size_y_rand = _input_set->size().y - 1;
for (int i = 0; i < _k; ++i) {
int indx = rand.random(0, input_set_size_y_rand);
2023-04-29 15:07:30 +02:00
_input_set->row_get_into_mlpp_vector(indx, mu_tempv);
_mu->row_set_mlpp_vector(i, mu_tempv);
2023-01-29 15:46:55 +01:00
}
2023-01-28 14:35:05 +01:00
}
2023-01-29 15:46:55 +01:00
void MLPPKMeans::_kmeanspp_initialization() {
RandomPCG rand;
rand.randomize();
Size2i mu_size = Size2i(_input_set->size().x, _k);
if (_mu->size() != mu_size) {
_mu->resize(mu_size);
}
int input_set_size_y = _input_set->size().y;
Ref<MLPPVector> mu_tempv;
mu_tempv.instance();
mu_tempv->resize(_mu->size().x);
2023-04-29 15:07:30 +02:00
_input_set->row_get_into_mlpp_vector(rand.random(0, input_set_size_y - 1), mu_tempv);
_mu->row_set_mlpp_vector(0, mu_tempv);
2023-01-29 15:46:55 +01:00
Ref<MLPPVector> input_set_j_tempv;
input_set_j_tempv.instance();
input_set_j_tempv->resize(_input_set->size().x);
Ref<MLPPVector> farthest_centroid;
farthest_centroid.instance();
farthest_centroid->resize(_input_set->size().x);
for (int i = 1; i < _k - 1; ++i) {
for (int j = 0; j < input_set_size_y; ++j) {
2023-04-29 15:07:30 +02:00
_input_set->row_get_into_mlpp_vector(j, input_set_j_tempv);
2023-01-29 15:46:55 +01:00
real_t max_dist = 0;
// SUM ALL THE SQUARED DISTANCES, CHOOSE THE ONE THAT'S FARTHEST
// AS TO SPREAD OUT THE CLUSTER CENTROIDS.
real_t sum = 0;
for (int k = 0; k < i; k++) {
2023-04-29 15:07:30 +02:00
_mu->row_get_into_mlpp_vector(k, mu_tempv);
2023-01-29 15:46:55 +01:00
2023-04-30 17:39:00 +02:00
sum += input_set_j_tempv->euclidean_distance(mu_tempv);
2023-01-29 15:46:55 +01:00
}
if (sum * sum > max_dist) {
farthest_centroid->set_from_mlpp_vector(input_set_j_tempv);
max_dist = sum * sum;
}
}
2023-04-29 15:07:30 +02:00
_mu->row_set_mlpp_vector(i, farthest_centroid);
2023-01-29 15:46:55 +01:00
}
2023-01-28 14:35:05 +01:00
}
// Computes the k-means objective J = sum_n sum_k r_nk * ||x_n - mu_k||^2 over
// the current assignment matrix _r. Returns 0 if the model is not initialized.
real_t MLPPKMeans::_cost() {
	ERR_FAIL_COND_V(!_initialized, 0);

	Ref<MLPPVector> point;
	point.instance();
	point->resize(_input_set->size().x);

	Ref<MLPPVector> centroid;
	centroid.instance();
	centroid->resize(_mu->size().x);

	Ref<MLPPVector> diff;
	diff.instance();
	diff->resize(_input_set->size().x);

	const int point_count = _r->size().y;
	const int cluster_count = _r->size().x;

	real_t total = 0;

	for (int row = 0; row < point_count; ++row) {
		_input_set->row_get_into_mlpp_vector(row, point);

		for (int c = 0; c < cluster_count; ++c) {
			_mu->row_get_into_mlpp_vector(c, centroid);

			diff->subb(point, centroid);
			total += _r->element_get(row, c) * diff->norm_sq();
		}
	}

	return total;
}
// Registers MLPPKMeans' scripting API with ClassDB: the input_set, k and
// mean_type properties (with their getters/setters), the training and
// inference methods, and the MeanType enum constants.
void MLPPKMeans::_bind_methods() {
// "input_set": the training data matrix (resource hint restricts it to MLPPMatrix).
ClassDB::bind_method(D_METHOD("get_input_set"), &MLPPKMeans::get_input_set);
ClassDB::bind_method(D_METHOD("set_input_set", "value"), &MLPPKMeans::set_input_set);
ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "input_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "set_input_set", "get_input_set");
// "k": number of clusters.
ClassDB::bind_method(D_METHOD("get_k"), &MLPPKMeans::get_k);
ClassDB::bind_method(D_METHOD("set_k", "value"), &MLPPKMeans::set_k);
ADD_PROPERTY(PropertyInfo(Variant::INT, "k"), "set_k", "get_k");
// "mean_type": seeding strategy. NOTE(review): the hint labels are presumably
// matched to MEAN_TYPE_CENTROID / MEAN_TYPE_KMEANSPP by enum order — confirm in kmeans.h.
ClassDB::bind_method(D_METHOD("get_mean_type"), &MLPPKMeans::get_mean_type);
ClassDB::bind_method(D_METHOD("set_mean_type", "value"), &MLPPKMeans::set_mean_type);
ADD_PROPERTY(PropertyInfo(Variant::INT, "mean_type", PROPERTY_HINT_ENUM, "Centroid,KMeansPP"), "set_mean_type", "get_mean_type");
// Model lifecycle and inference.
ClassDB::bind_method(D_METHOD("initialize"), &MLPPKMeans::initialize);
ClassDB::bind_method(D_METHOD("model_set_test", "X"), &MLPPKMeans::model_set_test);
ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPKMeans::model_test);
// The trailing `false` is presumably the default value for the last argument ("UI").
ClassDB::bind_method(D_METHOD("train", "epoch_num", "UI"), &MLPPKMeans::train, false);
ClassDB::bind_method(D_METHOD("score"), &MLPPKMeans::score);
ClassDB::bind_method(D_METHOD("silhouette_scores"), &MLPPKMeans::silhouette_scores);
// Expose the MeanType values used by the "mean_type" property.
BIND_ENUM_CONSTANT(MEAN_TYPE_CENTROID);
BIND_ENUM_CONSTANT(MEAN_TYPE_KMEANSPP);
}
/*
2023-01-27 13:01:16 +01:00
std::vector<std::vector<real_t>> MLPPKMeans::modelSetTest(std::vector<std::vector<real_t>> X) {
2023-01-25 00:29:02 +01:00
MLPPLinAlg alg;
2023-01-27 13:01:16 +01:00
std::vector<std::vector<real_t>> closestCentroids;
2023-01-24 19:00:54 +01:00
for (int i = 0; i < inputSet.size(); i++) {
2023-01-27 13:01:16 +01:00
std::vector<real_t> closestCentroid = mu[0];
2023-01-24 19:00:54 +01:00
for (int j = 0; j < r[0].size(); j++) {
bool isCentroidCloser = alg.euclideanDistance(X[i], mu[j]) < alg.euclideanDistance(X[i], closestCentroid);
if (isCentroidCloser) {
closestCentroid = mu[j];
}
}
closestCentroids.push_back(closestCentroid);
}
return closestCentroids;
}
2023-01-27 13:01:16 +01:00
std::vector<real_t> MLPPKMeans::modelTest(std::vector<real_t> x) {
2023-01-25 00:29:02 +01:00
MLPPLinAlg alg;
2023-01-27 13:01:16 +01:00
std::vector<real_t> closestCentroid = mu[0];
2023-01-24 19:00:54 +01:00
for (int j = 0; j < mu.size(); j++) {
if (alg.euclideanDistance(x, mu[j]) < alg.euclideanDistance(x, closestCentroid)) {
closestCentroid = mu[j];
}
}
return closestCentroid;
}
2023-01-25 00:25:18 +01:00
void MLPPKMeans::train(int epoch_num, bool UI) {
2023-01-27 13:01:16 +01:00
real_t cost_prev = 0;
2023-01-24 19:00:54 +01:00
int epoch = 1;
Evaluate();
while (true) {
// STEPS OF THE ALGORITHM
// 1. DETERMINE r_nk
// 2. DETERMINE J
// 3. DETERMINE mu_k
// STOP IF CONVERGED, ELSE REPEAT
cost_prev = Cost();
computeMu();
Evaluate();
// UI PORTION
if (UI) {
MLPPUtilities::CostInfo(epoch, cost_prev, Cost());
2023-01-24 19:00:54 +01:00
}
epoch++;
if (epoch > epoch_num) {
break;
}
}
}
2023-01-27 13:01:16 +01:00
real_t MLPPKMeans::score() {
2023-01-24 19:00:54 +01:00
return Cost();
}
2023-01-27 13:01:16 +01:00
std::vector<real_t> MLPPKMeans::silhouette_scores() {
2023-01-25 00:29:02 +01:00
MLPPLinAlg alg;
2023-01-27 13:01:16 +01:00
std::vector<std::vector<real_t>> closestCentroids = modelSetTest(inputSet);
std::vector<real_t> silhouette_scores;
2023-01-24 19:00:54 +01:00
for (int i = 0; i < inputSet.size(); i++) {
// COMPUTING a[i]
2023-01-27 13:01:16 +01:00
real_t a = 0;
2023-01-24 19:00:54 +01:00
for (int j = 0; j < inputSet.size(); j++) {
if (i != j && r[i] == r[j]) {
a += alg.euclideanDistance(inputSet[i], inputSet[j]);
}
}
// NORMALIZE a[i]
a /= closestCentroids[i].size() - 1;
// COMPUTING b[i]
2023-01-27 13:01:16 +01:00
real_t b = INT_MAX;
2023-01-24 19:00:54 +01:00
for (int j = 0; j < mu.size(); j++) {
if (closestCentroids[i] != mu[j]) {
2023-01-27 13:01:16 +01:00
real_t sum = 0;
2023-01-24 19:00:54 +01:00
for (int k = 0; k < inputSet.size(); k++) {
sum += alg.euclideanDistance(inputSet[i], inputSet[k]);
}
// NORMALIZE b[i]
2023-01-27 13:01:16 +01:00
real_t k_clusterSize = 0;
2023-01-24 19:00:54 +01:00
for (int k = 0; k < closestCentroids.size(); k++) {
if (closestCentroids[k] == mu[j]) {
k_clusterSize++;
}
}
if (sum / k_clusterSize < b) {
b = sum / k_clusterSize;
}
}
}
silhouette_scores.push_back((b - a) / fmax(a, b));
// Or the expanded version:
// if(a < b) {
// silhouette_scores.push_back(1 - a/b);
// }
// else if(a == b){
// silhouette_scores.push_back(0);
// }
// else{
// silhouette_scores.push_back(b/a - 1);
// }
}
return silhouette_scores;
}
// This simply computes r_nk
2023-01-25 00:25:18 +01:00
void MLPPKMeans::Evaluate() {
2023-01-25 00:29:02 +01:00
MLPPLinAlg alg;
2023-01-24 19:00:54 +01:00
r.resize(inputSet.size());
for (int i = 0; i < r.size(); i++) {
r[i].resize(k);
}
for (int i = 0; i < r.size(); i++) {
2023-01-27 13:01:16 +01:00
std::vector<real_t> closestCentroid = mu[0];
2023-01-24 19:00:54 +01:00
for (int j = 0; j < r[0].size(); j++) {
bool isCentroidCloser = alg.euclideanDistance(inputSet[i], mu[j]) < alg.euclideanDistance(inputSet[i], closestCentroid);
if (isCentroidCloser) {
closestCentroid = mu[j];
}
}
for (int j = 0; j < r[0].size(); j++) {
if (mu[j] == closestCentroid) {
r[i][j] = 1;
} else {
r[i][j] = 0;
}
}
}
}
// This simply computes or re-computes mu_k
2023-01-25 00:25:18 +01:00
void MLPPKMeans::computeMu() {
2023-01-25 00:29:02 +01:00
MLPPLinAlg alg;
2023-01-24 19:00:54 +01:00
for (int i = 0; i < mu.size(); i++) {
2023-01-27 13:01:16 +01:00
std::vector<real_t> num;
2023-01-24 19:00:54 +01:00
num.resize(r.size());
for (int i = 0; i < num.size(); i++) {
num[i] = 0;
}
2023-01-27 13:01:16 +01:00
real_t den = 0;
2023-01-24 19:00:54 +01:00
for (int j = 0; j < r.size(); j++) {
num = alg.addition(num, alg.scalarMultiply(r[j][i], inputSet[j]));
}
for (int j = 0; j < r.size(); j++) {
den += r[j][i];
}
2023-01-27 13:01:16 +01:00
mu[i] = alg.scalarMultiply(real_t(1) / real_t(den), num);
2023-01-24 19:00:54 +01:00
}
}
2023-01-25 00:25:18 +01:00
void MLPPKMeans::centroidInitialization(int k) {
2023-01-24 19:00:54 +01:00
mu.resize(k);
for (int i = 0; i < k; i++) {
std::random_device rd;
std::default_random_engine generator(rd());
std::uniform_int_distribution<int> distribution(0, int(inputSet.size() - 1));
mu[i].resize(inputSet.size());
mu[i] = inputSet[distribution(generator)];
}
}
2023-01-25 00:25:18 +01:00
void MLPPKMeans::kmeansppInitialization(int k) {
2023-01-25 00:29:02 +01:00
MLPPLinAlg alg;
2023-01-24 19:00:54 +01:00
std::random_device rd;
std::default_random_engine generator(rd());
std::uniform_int_distribution<int> distribution(0, int(inputSet.size() - 1));
mu.push_back(inputSet[distribution(generator)]);
for (int i = 0; i < k - 1; i++) {
2023-01-27 13:01:16 +01:00
std::vector<real_t> farthestCentroid;
2023-01-24 19:00:54 +01:00
for (int j = 0; j < inputSet.size(); j++) {
2023-01-27 13:01:16 +01:00
real_t max_dist = 0;
2023-01-28 14:35:05 +01:00
// SUM ALL THE SQUARED DISTANCES, CHOOSE THE ONE THAT'S FARTHEST
// AS TO SPREAD OUT THE CLUSTER CENTROIDS.
2023-01-27 13:01:16 +01:00
real_t sum = 0;
2023-01-24 19:00:54 +01:00
for (int k = 0; k < mu.size(); k++) {
sum += alg.euclideanDistance(inputSet[j], mu[k]);
}
if (sum * sum > max_dist) {
farthestCentroid = inputSet[j];
max_dist = sum * sum;
}
}
mu.push_back(farthestCentroid);
}
}
2023-01-27 13:01:16 +01:00
real_t MLPPKMeans::Cost() {
2023-01-25 00:29:02 +01:00
MLPPLinAlg alg;
2023-01-27 13:01:16 +01:00
real_t sum = 0;
2023-01-24 19:00:54 +01:00
for (int i = 0; i < r.size(); i++) {
for (int j = 0; j < r[0].size(); j++) {
sum += r[i][j] * alg.norm_sq(alg.subtraction(inputSet[i], mu[j]));
}
}
return sum;
}
2023-01-24 19:20:18 +01:00
2023-01-28 14:35:05 +01:00
*/