diff --git a/mlpp/kmeans/kmeans.cpp b/mlpp/kmeans/kmeans.cpp index 0e5d1a1..d0db130 100644 --- a/mlpp/kmeans/kmeans.cpp +++ b/mlpp/kmeans/kmeans.cpp @@ -12,16 +12,101 @@ #include #include +Ref MLPPKMeans::get_input_set() { + return _input_set; +} +void MLPPKMeans::set_input_set(const Ref &val) { + _input_set = val; + _initialized = false; +} -MLPPKMeans::MLPPKMeans(std::vector> inputSet, int k, std::string init_type) : - inputSet(inputSet), k(k), init_type(init_type) { - if (init_type == "KMeans++") { - kmeansppInitialization(k); +int MLPPKMeans::get_k() { + return _k; +} +void MLPPKMeans::set_k(const int val) { + _k = val; + _initialized = false; +} + +MLPPKMeans::MeanType MLPPKMeans::get_mean_type() { + return _mean_type; +} +void MLPPKMeans::set_mean_type(const MLPPKMeans::MeanType val) { + _mean_type = val; + _initialized = false; +} + +void MLPPKMeans::initialize() { + if (_mean_type == MEAN_TYPE_KMEANSPP) { + _kmeanspp_initialization(_k); } else { - centroidInitialization(k); + _centroid_initialization(_k); } } +Ref MLPPKMeans::model_set_test(const Ref &X) { + return Ref(); +} +Ref MLPPKMeans::model_test(const Ref &x) { + return Ref(); +} +void MLPPKMeans::train(int epoch_num, bool UI) { +} +real_t MLPPKMeans::score() { + return 0; +} +Ref MLPPKMeans::silhouette_scores() { + return Ref(); +} + +MLPPKMeans::MLPPKMeans() { + _accuracy_threshold = 0; + _k = 0; + _initialized = false; + + _mean_type = MEAN_TYPE_CENTROID; +} +MLPPKMeans::~MLPPKMeans() { +} + +void MLPPKMeans::_evaluate() { +} +void MLPPKMeans::_compute_mu() { +} + +void MLPPKMeans::_centroid_initialization(int k) { +} +void MLPPKMeans::_kmeanspp_initialization(int k) { +} +real_t MLPPKMeans::_cost() { + return 0; +} + +void MLPPKMeans::_bind_methods() { + ClassDB::bind_method(D_METHOD("get_input_set"), &MLPPKMeans::get_input_set); + ClassDB::bind_method(D_METHOD("set_input_set", "value"), &MLPPKMeans::set_input_set); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "input_set", PROPERTY_HINT_RESOURCE_TYPE, "MLPPMatrix"), "set_input_set", "get_input_set"); + + ClassDB::bind_method(D_METHOD("get_k"), &MLPPKMeans::get_k); + ClassDB::bind_method(D_METHOD("set_k", "value"), &MLPPKMeans::set_k); + ADD_PROPERTY(PropertyInfo(Variant::INT, "k"), "set_k", "get_k"); + + ClassDB::bind_method(D_METHOD("get_mean_type"), &MLPPKMeans::get_mean_type); + ClassDB::bind_method(D_METHOD("set_mean_type", "value"), &MLPPKMeans::set_mean_type); + ADD_PROPERTY(PropertyInfo(Variant::INT, "mean_type", PROPERTY_HINT_ENUM, "Centroid,KMeansPP"), "set_mean_type", "get_mean_type"); + + ClassDB::bind_method(D_METHOD("initialize"), &MLPPKMeans::initialize); + ClassDB::bind_method(D_METHOD("model_set_test", "X"), &MLPPKMeans::model_set_test); + ClassDB::bind_method(D_METHOD("model_test", "x"), &MLPPKMeans::model_test); + ClassDB::bind_method(D_METHOD("train", "epoch_num", "UI"), &MLPPKMeans::train, false); + ClassDB::bind_method(D_METHOD("score"), &MLPPKMeans::score); + ClassDB::bind_method(D_METHOD("silhouette_scores"), &MLPPKMeans::silhouette_scores); + + BIND_ENUM_CONSTANT(MEAN_TYPE_CENTROID); + BIND_ENUM_CONSTANT(MEAN_TYPE_KMEANSPP); +} + +/* std::vector> MLPPKMeans::modelSetTest(std::vector> X) { MLPPLinAlg alg; std::vector> closestCentroids; @@ -207,8 +292,8 @@ void MLPPKMeans::kmeansppInitialization(int k) { std::vector farthestCentroid; for (int j = 0; j < inputSet.size(); j++) { real_t max_dist = 0; - /* SUM ALL THE SQUARED DISTANCES, CHOOSE THE ONE THAT'S FARTHEST - AS TO SPREAD OUT THE CLUSTER CENTROIDS. */ + // SUM ALL THE SQUARED DISTANCES, CHOOSE THE ONE THAT'S FARTHEST + // AS TO SPREAD OUT THE CLUSTER CENTROIDS. real_t sum = 0; for (int k = 0; k < mu.size(); k++) { sum += alg.euclideanDistance(inputSet[j], mu[k]); @@ -233,3 +318,4 @@ real_t MLPPKMeans::Cost() { return sum; } +*/ \ No newline at end of file diff --git a/mlpp/kmeans/kmeans.h b/mlpp/kmeans/kmeans.h index 5a93012..8dbea48 100644 --- a/mlpp/kmeans/kmeans.h +++ b/mlpp/kmeans/kmeans.h @@ -10,38 +10,64 @@ #include "core/math/math_defs.h" -#include -#include +#include "core/object/reference.h" +#include "../lin_alg/mlpp_matrix.h" +#include "../lin_alg/mlpp_vector.h" + +class MLPPKMeans : public Reference { + GDCLASS(MLPPKMeans, Reference); -class MLPPKMeans { public: - MLPPKMeans(std::vector> inputSet, int k, std::string init_type = "Default"); - std::vector> modelSetTest(std::vector> X); - std::vector modelTest(std::vector x); - void train(int epoch_num, bool UI = 1); + enum MeanType { + MEAN_TYPE_CENTROID = 0, + MEAN_TYPE_KMEANSPP, + }; + +public: + Ref get_input_set(); + void set_input_set(const Ref &val); + + int get_k(); + void set_k(const int val); + + MeanType get_mean_type(); + void set_mean_type(const MeanType val); + + void initialize(); + + Ref model_set_test(const Ref &X); + Ref model_test(const Ref &x); + void train(int epoch_num, bool UI = false); real_t score(); - std::vector silhouette_scores(); + Ref silhouette_scores(); -private: - void Evaluate(); - void computeMu(); + MLPPKMeans(); + ~MLPPKMeans(); - void centroidInitialization(int k); - void kmeansppInitialization(int k); - real_t Cost(); +protected: + - std::vector> inputSet; - std::vector> mu; - std::vector> r; + void _evaluate(); + void _compute_mu(); - real_t euclideanDistance(std::vector A, std::vector B); + void _centroid_initialization(int k); + void _kmeanspp_initialization(int k); + real_t _cost(); - real_t accuracy_threshold; - int k; + static void _bind_methods(); - std::string init_type; + Ref _input_set; + Ref _mu; + Ref _r; + + real_t _accuracy_threshold; + int _k; + bool _initialized; + + MeanType _mean_type; }; +VARIANT_ENUM_CAST(MLPPKMeans::MeanType); #endif /* KMeans_hpp */ diff --git a/register_types.cpp b/register_types.cpp index 67379ce..e6fd666 100644 --- a/register_types.cpp +++ b/register_types.cpp @@ -28,6 +28,7 @@ SOFTWARE. #include "mlpp/lin_alg/mlpp_vector.h" #include "mlpp/knn/knn.h" +#include "mlpp/kmeans/kmeans.h" #include "test/mlpp_tests.h" @@ -37,6 +38,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) { ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); diff --git a/test/mlpp_tests.cpp b/test/mlpp_tests.cpp index 74eafa7..9822cbb 100644 --- a/test/mlpp_tests.cpp +++ b/test/mlpp_tests.cpp @@ -541,6 +541,7 @@ void MLPPTests::test_k_means(bool ui) { MLPPLinAlg alg; // KMeans + /* std::vector> inputSet = { { 32, 0, 7 }, { 2, 28, 17 }, { 0, 9, 23 } }; MLPPKMeans kmeans(inputSet, 3, "KMeans++"); kmeans.train(3, ui); @@ -548,6 +549,7 @@ void MLPPTests::test_k_means(bool ui) { alg.printMatrix(kmeans.modelSetTest(inputSet)); // Returns the assigned centroids to each of the respective training examples std::cout << std::endl; alg.printVector(kmeans.silhouette_scores()); + */ } void MLPPTests::test_knn(bool ui) { MLPPLinAlg alg;