diff --git a/mlpp/hidden_layer/hidden_layer.cpp b/mlpp/hidden_layer/hidden_layer.cpp index 483c07e..814945b 100644 --- a/mlpp/hidden_layer/hidden_layer.cpp +++ b/mlpp/hidden_layer/hidden_layer.cpp @@ -54,8 +54,10 @@ MLPPHiddenLayer::MLPPHiddenLayer(int p_n_hidden, MLPPActivation::ActivationFunct weights->resize(Size2i(input->size().x, n_hidden)); bias->resize(n_hidden); - MLPPUtilities::weight_initializationm(weights, weight_init); - MLPPUtilities::bias_initializationv(bias); + MLPPUtilities utils; + + utils.weight_initializationm(weights, weight_init); + utils.bias_initializationv(bias); } MLPPHiddenLayer::MLPPHiddenLayer() { diff --git a/mlpp/utilities/utilities.cpp b/mlpp/utilities/utilities.cpp index 69c116e..a9feca5 100644 --- a/mlpp/utilities/utilities.cpp +++ b/mlpp/utilities/utilities.cpp @@ -605,3 +605,23 @@ real_t MLPPUtilities::accuracy(std::vector y_hat, std::vector y) real_t MLPPUtilities::f1_score(std::vector y_hat, std::vector y) { return 2 * precision(y_hat, y) * recall(y_hat, y) / (precision(y_hat, y) + recall(y_hat, y)); } + +void MLPPUtilities::_bind_methods() { + ClassDB::bind_method(D_METHOD("weight_initializationv", "weights", "type"), &MLPPUtilities::weight_initializationv, WEIGHT_DISTRIBUTION_TYPE_DEFAULT); + ClassDB::bind_method(D_METHOD("weight_initializationm", "weights", "type"), &MLPPUtilities::weight_initializationm, WEIGHT_DISTRIBUTION_TYPE_DEFAULT); + ClassDB::bind_method(D_METHOD("bias_initializationr"), &MLPPUtilities::bias_initializationr); + ClassDB::bind_method(D_METHOD("bias_initializationv", "z"), &MLPPUtilities::bias_initializationv); + + ClassDB::bind_method(D_METHOD("performance_vec", "y_hat", "output_set"), &MLPPUtilities::performance_vec); + ClassDB::bind_method(D_METHOD("performance_mat", "y_hat", "y"), &MLPPUtilities::performance_mat); + ClassDB::bind_method(D_METHOD("performance_pool_int_array_vec", "y_hat", "output_set"), &MLPPUtilities::performance_pool_int_array_vec); + + BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_DEFAULT); + BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_XAVIER_NORMAL); + BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_XAVIER_UNIFORM); + BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_HE_NORMAL); + BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_HE_UNIFORM); + BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_LE_CUN_NORMAL); + BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_LE_CUN_UNIFORM); + BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_UNIFORM); +} diff --git a/mlpp/utilities/utilities.h b/mlpp/utilities/utilities.h index a660744..60685e3 100644 --- a/mlpp/utilities/utilities.h +++ b/mlpp/utilities/utilities.h @@ -13,6 +13,8 @@ #include "core/string/ustring.h" #include "core/variant/variant.h" +#include "core/object/reference.h" + #include "../lin_alg/mlpp_matrix.h" #include "../lin_alg/mlpp_vector.h" @@ -20,7 +22,9 @@ #include #include -class MLPPUtilities { +class MLPPUtilities : public Reference { + GDCLASS(MLPPUtilities, Reference); + public: // Weight Init static std::vector weightInitialization(int n, std::string type = "Default"); @@ -40,10 +44,10 @@ public: WEIGHT_DISTRIBUTION_TYPE_UNIFORM, }; - static void weight_initializationv(Ref weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT); - static void weight_initializationm(Ref weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT); - static real_t bias_initializationr(); - static void bias_initializationv(Ref z); + void weight_initializationv(Ref weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT); + void weight_initializationm(Ref weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT); + real_t bias_initializationr(); + void bias_initializationv(Ref z); // Cost/Performance related Functions real_t performance(std::vector y_hat, std::vector y); @@ -76,7 +80,10 @@ public: real_t accuracy(std::vector y_hat, std::vector y); real_t f1_score(std::vector y_hat, std::vector y); -private: +protected: + static void _bind_methods(); }; +VARIANT_ENUM_CAST(MLPPUtilities::WeightDistributionType); + #endif /* Utilities_hpp */ diff --git a/register_types.cpp b/register_types.cpp index 19f9db5..227ceaf 100644 --- a/register_types.cpp +++ b/register_types.cpp @@ -28,6 +28,7 @@ SOFTWARE. #include "mlpp/lin_alg/mlpp_vector.h" #include "mlpp/regularization/reg.h" +#include "mlpp/utilities/utilities.h" #include "mlpp/activation/activation.h" @@ -41,6 +42,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) { ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class();