Added bindings for Activation.

This commit is contained in:
Relintai 2023-02-03 02:08:48 +01:00
parent 1aafbdf66a
commit fa4b7b6b56
3 changed files with 54 additions and 3 deletions

View File

@ -2479,6 +2479,53 @@ Ref<MLPPMatrix> MLPPActivation::arcoth_derivm(const Ref<MLPPMatrix> &z) {
return alg.element_wise_divisionm(alg.onematm(z->size().x, z->size().y), alg.subtractionm(alg.onematm(z->size().x, z->size().y), alg.hadamard_productm(z, z)));
}
void MLPPActivation::_bind_methods() {
ClassDB::bind_method(D_METHOD("run_activation_real", "func", "z", "deriv"), &MLPPActivation::run_activation_real, false);
ClassDB::bind_method(D_METHOD("run_activation_vector", "func", "z", "deriv"), &MLPPActivation::run_activation_vector, false);
ClassDB::bind_method(D_METHOD("run_activation_matrix", "func", "z", "deriv"), &MLPPActivation::run_activation_matrix, false);
ClassDB::bind_method(D_METHOD("run_activation_norm_real", "func", "z"), &MLPPActivation::run_activation_norm_real);
ClassDB::bind_method(D_METHOD("run_activation_norm_vector", "func", "z"), &MLPPActivation::run_activation_norm_vector);
ClassDB::bind_method(D_METHOD("run_activation_norm_matrix", "func", "z"), &MLPPActivation::run_activation_norm_matrix);
real_t run_activation_norm_real(const ActivationFunction func, const real_t z);
Ref<MLPPVector> run_activation_norm_vector(const ActivationFunction func, const Ref<MLPPVector> &z);
Ref<MLPPMatrix> run_activation_norm_matrix(const ActivationFunction func, const Ref<MLPPMatrix> &z);
ClassDB::bind_method(D_METHOD("run_activation_deriv_real", "func", "z"), &MLPPActivation::run_activation_deriv_real);
ClassDB::bind_method(D_METHOD("run_activation_deriv_vector", "func", "z"), &MLPPActivation::run_activation_deriv_vector);
ClassDB::bind_method(D_METHOD("run_activation_deriv_matrix", "func", "z"), &MLPPActivation::run_activation_deriv_matrix);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_LINEAR);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SIGMOID);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SWISH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_MISH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SIN_C);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SOFTMAX);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SOFTPLUS);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SOFTSIGN);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ADJ_SOFTMAX);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_C_LOG_LOG);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_LOGIT);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_GAUSSIAN_CDF);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_RELU);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_GELU);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SIGN);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_UNIT_STEP);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SINH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_COSH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_TANH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_CSCH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SECH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_COTH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARSINH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARCOSH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARTANH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARCSCH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARSECH);
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARCOTH);
}
//======================== OLD =============================
real_t MLPPActivation::linear(real_t z, bool deriv) {

View File

@ -18,7 +18,6 @@
#include <vector>
//TODO this should probably be a singleton
//TODO Activation functions should either have a variant which does not allocate, or they should just be reworked altogether
//TODO Methods here should probably use error macros, in a way where they get disabled in non-tools(?) (maybe release?) builds
@ -537,7 +536,8 @@ public:
std::vector<real_t> activation(std::vector<real_t> z, bool deriv, real_t (*function)(real_t, bool));
private:
protected:
static void _bind_methods();
};
VARIANT_ENUM_CAST(MLPPActivation::ActivationFunction);

View File

@ -27,8 +27,10 @@ SOFTWARE.
#include "mlpp/lin_alg/mlpp_matrix.h"
#include "mlpp/lin_alg/mlpp_vector.h"
#include "mlpp/knn/knn.h"
#include "mlpp/activation/activation.h"
#include "mlpp/kmeans/kmeans.h"
#include "mlpp/knn/knn.h"
#include "test/mlpp_tests.h"
@ -37,6 +39,8 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
ClassDB::register_class<MLPPVector>();
ClassDB::register_class<MLPPMatrix>();
ClassDB::register_class<MLPPActivation>();
ClassDB::register_class<MLPPKNN>();
ClassDB::register_class<MLPPKMeans>();