From fa4b7b6b56cf6965d74e2af7ac013447038c1ec7 Mon Sep 17 00:00:00 2001
From: Relintai
Date: Fri, 3 Feb 2023 02:08:48 +0100
Subject: [PATCH] Added bindings for Activation.

---
 mlpp/activation/activation.cpp | 47 ++++++++++++++++++++++++++++++++++
 mlpp/activation/activation.h   |  4 +--
 register_types.cpp             |  6 ++++-
 3 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/mlpp/activation/activation.cpp b/mlpp/activation/activation.cpp
index adc37ea..1880276 100644
--- a/mlpp/activation/activation.cpp
+++ b/mlpp/activation/activation.cpp
@@ -2479,6 +2479,53 @@ Ref<MLPPMatrix> MLPPActivation::arcoth_derivm(const Ref<MLPPMatrix> &z) {
 	return alg.element_wise_divisionm(alg.onematm(z->size().x, z->size().y), alg.subtractionm(alg.onematm(z->size().x, z->size().y), alg.hadamard_productm(z, z)));
 }
 
+void MLPPActivation::_bind_methods() {
+	ClassDB::bind_method(D_METHOD("run_activation_real", "func", "z", "deriv"), &MLPPActivation::run_activation_real, false);
+	ClassDB::bind_method(D_METHOD("run_activation_vector", "func", "z", "deriv"), &MLPPActivation::run_activation_vector, false);
+	ClassDB::bind_method(D_METHOD("run_activation_matrix", "func", "z", "deriv"), &MLPPActivation::run_activation_matrix, false);
+
+	ClassDB::bind_method(D_METHOD("run_activation_norm_real", "func", "z"), &MLPPActivation::run_activation_norm_real);
+	ClassDB::bind_method(D_METHOD("run_activation_norm_vector", "func", "z"), &MLPPActivation::run_activation_norm_vector);
+	ClassDB::bind_method(D_METHOD("run_activation_norm_matrix", "func", "z"), &MLPPActivation::run_activation_norm_matrix);
+
+	real_t run_activation_norm_real(const ActivationFunction func, const real_t z);
+	Ref<MLPPVector> run_activation_norm_vector(const ActivationFunction func, const Ref<MLPPVector> &z);
+	Ref<MLPPMatrix> run_activation_norm_matrix(const ActivationFunction func, const Ref<MLPPMatrix> &z);
+
+	ClassDB::bind_method(D_METHOD("run_activation_deriv_real", "func", "z"), &MLPPActivation::run_activation_deriv_real);
+	ClassDB::bind_method(D_METHOD("run_activation_deriv_vector", "func", "z"), &MLPPActivation::run_activation_deriv_vector);
+	ClassDB::bind_method(D_METHOD("run_activation_deriv_matrix", "func", "z"), &MLPPActivation::run_activation_deriv_matrix);
+
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_LINEAR);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SIGMOID);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SWISH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_MISH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SIN_C);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SOFTMAX);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SOFTPLUS);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SOFTSIGN);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ADJ_SOFTMAX);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_C_LOG_LOG);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_LOGIT);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_GAUSSIAN_CDF);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_RELU);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_GELU);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SIGN);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_UNIT_STEP);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SINH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_COSH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_TANH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_CSCH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SECH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_COTH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARSINH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARCOSH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARTANH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARCSCH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARSECH);
+	BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARCOTH);
+}
+
 //======================== OLD =============================
 
 real_t MLPPActivation::linear(real_t z, bool deriv) {
diff --git a/mlpp/activation/activation.h b/mlpp/activation/activation.h
index 007ddd1..c2d6df2 100644
--- a/mlpp/activation/activation.h
+++ b/mlpp/activation/activation.h
@@ -18,7 +18,6 @@
 
 #include <vector>
 
-//TODO this should probably be a singleton
 //TODO Activation functions should either have a variant which does not allocate, or they should just be reworked altogether
 //TODO Methods here should probably use error macros, in a way where they get disabled in non-tools(?) (maybe release?) builds
 
@@ -537,7 +536,8 @@ public:
 
 	std::vector<real_t> activation(std::vector<real_t> z, bool deriv, real_t (*function)(real_t, bool));
 
-private:
+protected:
+	static void _bind_methods();
 };
 
 VARIANT_ENUM_CAST(MLPPActivation::ActivationFunction);
diff --git a/register_types.cpp b/register_types.cpp
index e6fd666..47b0042 100644
--- a/register_types.cpp
+++ b/register_types.cpp
@@ -27,8 +27,10 @@ SOFTWARE.
 #include "mlpp/lin_alg/mlpp_matrix.h"
 #include "mlpp/lin_alg/mlpp_vector.h"
 
-#include "mlpp/knn/knn.h"
+#include "mlpp/activation/activation.h"
+
 #include "mlpp/kmeans/kmeans.h"
+#include "mlpp/knn/knn.h"
 
 #include "test/mlpp_tests.h"
 
@@ -37,6 +39,8 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
 	ClassDB::register_class<MLPPVector>();
 	ClassDB::register_class<MLPPMatrix>();
 
+	ClassDB::register_class<MLPPActivation>();
+
 	ClassDB::register_class<MLPPKNN>();
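
For orientation, a minimal C++ sketch (an editor's illustration, not part of the commit) of what these bindings enable: once ClassDB::register_class<MLPPActivation>() has run and _bind_methods() has registered the methods and enum constants, they become reachable by name through the engine's Variant call machinery, and therefore from scripts. The example_call_bound_activation() helper below is hypothetical and assumes the Godot 3-style Ref<T>::instance() and Object::call() APIs used by Pandemonium.

#include "mlpp/activation/activation.h"

// Hypothetical usage sketch, not part of the patch.
void example_call_bound_activation() {
	Ref<MLPPActivation> activation;
	activation.instance(); // construct a reference-counted instance

	// Look up the bound method by the name registered in _bind_methods();
	// the enum value travels as an int Variant thanks to VARIANT_ENUM_CAST.
	Variant result = activation->call("run_activation_norm_real",
			MLPPActivation::ACTIVATION_FUNCTION_SIGMOID, 0.5);
}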