MLPPKMeans initial api cleanup pass.

This commit is contained in:
Relintai 2023-01-28 14:35:05 +01:00
parent 26117ac21d
commit bd67fcecc6
4 changed files with 144 additions and 28 deletions

View File

@ -12,16 +12,101 @@
#include <iostream>
#include <random>
Ref<MLPPMatrix> MLPPKMeans::get_input_set() {
return _input_set;
}
void MLPPKMeans::set_input_set(const Ref<MLPPMatrix> &val) {
_input_set = val;
_initialized = false;
}
MLPPKMeans::MLPPKMeans(std::vector<std::vector<real_t>> inputSet, int k, std::string init_type) :
inputSet(inputSet), k(k), init_type(init_type) {
if (init_type == "KMeans++") {
kmeansppInitialization(k);
int MLPPKMeans::get_k() {
return _k;
}
void MLPPKMeans::set_k(const int val) {
_k = val;
_initialized = false;
}
MLPPKMeans::MeanType MLPPKMeans::get_mean_type() {
return _mean_type;
}
void MLPPKMeans::set_mean_type(const MLPPKMeans::MeanType val) {
_mean_type = val;
_initialized = false;
}
void MLPPKMeans::initialize() {
if (_mean_type == MEAN_TYPE_KMEANSPP) {
_kmeanspp_initialization(_k);
} else {
centroidInitialization(k);
_centroid_initialization(_k);
}
}
Ref<MLPPMatrix> MLPPKMeans::model_set_test(const Ref<MLPPMatrix> &X) {
return Ref<MLPPMatrix>();
}
Ref<MLPPVector> MLPPKMeans::model_test(const Ref<MLPPVector> &x) {
return Ref<MLPPVector>();
}
void MLPPKMeans::train(int epoch_num, bool UI) {
}
real_t MLPPKMeans::score() {
return 0;
}
Ref<MLPPVector> MLPPKMeans::silhouette_scores() {
return Ref<MLPPVector>();
}
MLPPKMeans::MLPPKMeans() {
_accuracy_threshold = 0;
_k = 0;
_initialized = false;
_mean_type = MEAN_TYPE_CENTROID;
}
MLPPKMeans::~MLPPKMeans() {
}
void MLPPKMeans::_evaluate() {
}
void MLPPKMeans::_compute_mu() {
}
void MLPPKMeans::_centroid_initialization(int k) {
}
void MLPPKMeans::_kmeanspp_initialization(int k) {
}
real_t MLPPKMeans::_cost() {
return 0;
}
void MLPPKMeans::_bind_methods() {
ClassDB::bind_method(D_METHOD("get_input_set"), &MLPPKMeans::get_input_set);
ClassDB::bind_method(D_METHOD("set_input_set", "value"), &MLPPKMeans::set_input_set);
ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "input_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "set_input_set", "get_input_set");
ClassDB::bind_method(D_METHOD("get_k"), &MLPPKMeans::get_k);
ClassDB::bind_method(D_METHOD("set_k", "value"), &MLPPKMeans::set_k);
ADD_PROPERTY(PropertyInfo(Variant::INT, "k"), "set_k", "get_k");
ClassDB::bind_method(D_METHOD("get_mean_type"), &MLPPKMeans::get_mean_type);
ClassDB::bind_method(D_METHOD("set_mean_type", "value"), &MLPPKMeans::set_mean_type);
ADD_PROPERTY(PropertyInfo(Variant::INT, "mean_type", PROPERTY_HINT_ENUM, "Centroid,KMeansPP"), "set_mean_type", "get_mean_type");
ClassDB::bind_method(D_METHOD("initialize"), &MLPPKMeans::initialize);
ClassDB::bind_method(D_METHOD("model_set_test", "X"), &MLPPKMeans::model_set_test);
ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPKMeans::model_test);
ClassDB::bind_method(D_METHOD("train", "epoch_num", "UI"), &MLPPKMeans::train, false);
ClassDB::bind_method(D_METHOD("score"), &MLPPKMeans::score);
ClassDB::bind_method(D_METHOD("silhouette_scores"), &MLPPKMeans::silhouette_scores);
BIND_ENUM_CONSTANT(MEAN_TYPE_CENTROID);
BIND_ENUM_CONSTANT(MEAN_TYPE_KMEANSPP);
}
/*
std::vector<std::vector<real_t>> MLPPKMeans::modelSetTest(std::vector<std::vector<real_t>> X) {
MLPPLinAlg alg;
std::vector<std::vector<real_t>> closestCentroids;
@ -207,8 +292,8 @@ void MLPPKMeans::kmeansppInitialization(int k) {
std::vector<real_t> farthestCentroid;
for (int j = 0; j < inputSet.size(); j++) {
real_t max_dist = 0;
/* SUM ALL THE SQUARED DISTANCES, CHOOSE THE ONE THAT'S FARTHEST
AS TO SPREAD OUT THE CLUSTER CENTROIDS. */
// SUM ALL THE SQUARED DISTANCES, CHOOSE THE ONE THAT'S FARTHEST
// AS TO SPREAD OUT THE CLUSTER CENTROIDS.
real_t sum = 0;
for (int k = 0; k < mu.size(); k++) {
sum += alg.euclideanDistance(inputSet[j], mu[k]);
@ -233,3 +318,4 @@ real_t MLPPKMeans::Cost() {
return sum;
}
*/

View File

@ -10,38 +10,64 @@
#include "core/math/math_defs.h"
#include <string>
#include <vector>
#include "core/object/reference.h"
#include "../lin_alg/mlpp_matrix.h"
#include "../lin_alg/mlpp_vector.h"
class MLPPKMeans : public Reference {
GDCLASS(MLPPKMeans, Reference);
class MLPPKMeans {
public:
MLPPKMeans(std::vector<std::vector<real_t>> inputSet, int k, std::string init_type = "Default");
std::vector<std::vector<real_t>> modelSetTest(std::vector<std::vector<real_t>> X);
std::vector<real_t> modelTest(std::vector<real_t> x);
void train(int epoch_num, bool UI = 1);
enum MeanType {
MEAN_TYPE_CENTROID = 0,
MEAN_TYPE_KMEANSPP,
};
public:
Ref<MLPPMatrix> get_input_set();
void set_input_set(const Ref<MLPPMatrix> &val);
int get_k();
void set_k(const int val);
MeanType get_mean_type();
void set_mean_type(const MeanType val);
void initialize();
Ref<MLPPMatrix> model_set_test(const Ref<MLPPMatrix> &X);
Ref<MLPPVector> model_test(const Ref<MLPPVector> &x);
void train(int epoch_num, bool UI = false);
real_t score();
std::vector<real_t> silhouette_scores();
Ref<MLPPVector> silhouette_scores();
private:
void Evaluate();
void computeMu();
MLPPKMeans();
~MLPPKMeans();
void centroidInitialization(int k);
void kmeansppInitialization(int k);
real_t Cost();
protected:
std::vector<std::vector<real_t>> inputSet;
std::vector<std::vector<real_t>> mu;
std::vector<std::vector<real_t>> r;
void _evaluate();
void _compute_mu();
real_t euclideanDistance(std::vector<real_t> A, std::vector<real_t> B);
void _centroid_initialization(int k);
void _kmeanspp_initialization(int k);
real_t _cost();
real_t accuracy_threshold;
int k;
static void _bind_methods();
std::string init_type;
Ref<MLPPMatrix> _input_set;
Ref<MLPPMatrix> _mu;
Ref<MLPPMatrix> _r;
real_t _accuracy_threshold;
int _k;
bool _initialized;
MeanType _mean_type;
};
VARIANT_ENUM_CAST(MLPPKMeans::MeanType);
#endif /* KMeans_hpp */

View File

@ -28,6 +28,7 @@ SOFTWARE.
#include "mlpp/lin_alg/mlpp_vector.h"
#include "mlpp/knn/knn.h"
#include "mlpp/kmeans/kmeans.h"
#include "test/mlpp_tests.h"
@ -37,6 +38,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
ClassDB::register_class<MLPPMatrix>();
ClassDB::register_class<MLPPKNN>();
ClassDB::register_class<MLPPKMeans>();
ClassDB::register_class<MLPPDataESimple>();
ClassDB::register_class<MLPPDataSimple>();

View File

@ -541,6 +541,7 @@ void MLPPTests::test_k_means(bool ui) {
MLPPLinAlg alg;
// KMeans
/*
std::vector<std::vector<real_t>> inputSet = { { 32, 0, 7 }, { 2, 28, 17 }, { 0, 9, 23 } };
MLPPKMeans kmeans(inputSet, 3, "KMeans++");
kmeans.train(3, ui);
@ -548,6 +549,7 @@ void MLPPTests::test_k_means(bool ui) {
alg.printMatrix(kmeans.modelSetTest(inputSet)); // Returns the assigned centroids to each of the respective training examples
std::cout << std::endl;
alg.printVector(kmeans.silhouette_scores());
*/
}
void MLPPTests::test_knn(bool ui) {
MLPPLinAlg alg;