#ifndef MLPP_ACTIVATION_H #define MLPP_ACTIVATION_H // // Activation.hpp // // Created by Marc Melikyan on 1/16/21. // #include "core/math/math_defs.h" #include "core/object/func_ref.h" #include "core/object/reference.h" #include "../lin_alg/mlpp_matrix.h" #include "../lin_alg/mlpp_vector.h" #include //TODO Activation functions should either have a variant which does not allocate, or they should just be reworked altogether //TODO Methods here should probably use error macros, in a way where they get disabled in non-tools(?) (maybe release?) builds class MLPPActivation : public Reference { GDCLASS(MLPPActivation, Reference); public: enum ActivationFunction { ACTIVATION_FUNCTION_LINEAR = 0, ACTIVATION_FUNCTION_SIGMOID, ACTIVATION_FUNCTION_SWISH, ACTIVATION_FUNCTION_MISH, ACTIVATION_FUNCTION_SIN_C, ACTIVATION_FUNCTION_SOFTMAX, ACTIVATION_FUNCTION_SOFTPLUS, ACTIVATION_FUNCTION_SOFTSIGN, ACTIVATION_FUNCTION_ADJ_SOFTMAX, ACTIVATION_FUNCTION_C_LOG_LOG, ACTIVATION_FUNCTION_LOGIT, ACTIVATION_FUNCTION_GAUSSIAN_CDF, ACTIVATION_FUNCTION_RELU, ACTIVATION_FUNCTION_GELU, ACTIVATION_FUNCTION_SIGN, ACTIVATION_FUNCTION_UNIT_STEP, ACTIVATION_FUNCTION_SINH, ACTIVATION_FUNCTION_COSH, ACTIVATION_FUNCTION_TANH, ACTIVATION_FUNCTION_CSCH, ACTIVATION_FUNCTION_SECH, ACTIVATION_FUNCTION_COTH, ACTIVATION_FUNCTION_ARSINH, ACTIVATION_FUNCTION_ARCOSH, ACTIVATION_FUNCTION_ARTANH, ACTIVATION_FUNCTION_ARCSCH, ACTIVATION_FUNCTION_ARSECH, ACTIVATION_FUNCTION_ARCOTH, }; public: typedef real_t (MLPPActivation::*RealActivationFunctionPointer)(real_t); typedef Ref (MLPPActivation::*VectorActivationFunctionPointer)(const Ref &); typedef Ref (MLPPActivation::*MatrixActivationFunctionPointer)(const Ref &); RealActivationFunctionPointer get_activation_function_ptr_real(const ActivationFunction func, const bool deriv = false); VectorActivationFunctionPointer get_activation_function_ptr_vector(const ActivationFunction func, const bool deriv = false); MatrixActivationFunctionPointer get_activation_function_ptr_matrix(const ActivationFunction func, const bool deriv = false); RealActivationFunctionPointer get_activation_function_ptr_normal_real(const ActivationFunction func); VectorActivationFunctionPointer get_activation_function_ptr_normal_vector(const ActivationFunction func); MatrixActivationFunctionPointer get_activation_function_ptr_normal_matrix(const ActivationFunction func); RealActivationFunctionPointer get_activation_function_ptr_deriv_real(const ActivationFunction func); VectorActivationFunctionPointer get_activation_function_ptr_deriv_vector(const ActivationFunction func); MatrixActivationFunctionPointer get_activation_function_ptr_deriv_matrix(const ActivationFunction func); real_t run_activation_real(const ActivationFunction func, const real_t z, const bool deriv = false); Ref run_activation_vector(const ActivationFunction func, const Ref &z, const bool deriv = false); Ref run_activation_matrix(const ActivationFunction func, const Ref &z, const bool deriv = false); real_t run_activation_norm_real(const ActivationFunction func, const real_t z); Ref run_activation_norm_vector(const ActivationFunction func, const Ref &z); Ref run_activation_norm_matrix(const ActivationFunction func, const Ref &z); real_t run_activation_deriv_real(const ActivationFunction func, const real_t z); Ref run_activation_deriv_vector(const ActivationFunction func, const Ref &z); Ref run_activation_deriv_matrix(const ActivationFunction func, const Ref &z); Ref activationr(const Ref &z, real_t (*function)(real_t)); //ACTIVATION FUNCTIONS //LINEAR real_t linear_normr(real_t z); Ref linear_normv(const Ref &z); Ref linear_normm(const Ref &z); real_t linear_derivr(real_t z); Ref linear_derivv(const Ref &z); Ref linear_derivm(const Ref &z); //SIGMOID real_t sigmoid_normr(real_t z); Ref sigmoid_normv(const Ref &z); Ref sigmoid_normm(const Ref &z); real_t sigmoid_derivr(real_t z); Ref sigmoid_derivv(const Ref &z); Ref sigmoid_derivm(const Ref &z); //SOFTMAX real_t softmax_normr(real_t z); Ref softmax_normv(const Ref &z); Ref softmax_normm(const Ref &z); real_t softmax_derivr(real_t z); Ref softmax_derivv(const Ref &z); Ref softmax_derivm(const Ref &z); //ADJ_SOFTMAX real_t adj_softmax_normr(real_t z); Ref adj_softmax_normv(const Ref &z); Ref adj_softmax_normm(const Ref &z); real_t adj_softmax_derivr(real_t z); Ref adj_softmax_derivv(const Ref &z); Ref adj_softmax_derivm(const Ref &z); //SOFTMAX DERIV Ref softmax_deriv_normv(const Ref &z); Vector> softmax_deriv_normm(const Ref &z); Ref softmax_deriv_derivv(const Ref &z); Vector> softmax_deriv_derivm(const Ref &z); //SOFTPLUS real_t softplus_normr(real_t z); Ref softplus_normv(const Ref &z); Ref softplus_normm(const Ref &z); real_t softplus_derivr(real_t z); Ref softplus_derivv(const Ref &z); Ref softplus_derivm(const Ref &z); //SOFTSIGN real_t softsign_normr(real_t z); Ref softsign_normv(const Ref &z); Ref softsign_normm(const Ref &z); real_t softsign_derivr(real_t z); Ref softsign_derivv(const Ref &z); Ref softsign_derivm(const Ref &z); //GAUSSIANCDF real_t gaussian_cdf_normr(real_t z); Ref gaussian_cdf_normv(const Ref &z); Ref gaussian_cdf_normm(const Ref &z); real_t gaussian_cdf_derivr(real_t z); Ref gaussian_cdf_derivv(const Ref &z); Ref gaussian_cdf_derivm(const Ref &z); //CLOGLOG real_t cloglog_normr(real_t z); Ref cloglog_normv(const Ref &z); Ref cloglog_normm(const Ref &z); real_t cloglog_derivr(real_t z); Ref cloglog_derivv(const Ref &z); Ref cloglog_derivm(const Ref &z); //LOGIT real_t logit_normr(real_t z); Ref logit_normv(const Ref &z); Ref logit_normm(const Ref &z); real_t logit_derivr(real_t z); Ref logit_derivv(const Ref &z); Ref logit_derivm(const Ref &z); //UNITSTEP real_t unit_step_normr(real_t z); Ref unit_step_normv(const Ref &z); Ref unit_step_normm(const Ref &z); real_t unit_step_derivr(real_t z); Ref unit_step_derivv(const Ref &z); Ref unit_step_derivm(const Ref &z); //SWISH real_t swish_normr(real_t z); Ref swish_normv(const Ref &z); Ref swish_normm(const Ref &z); real_t swish_derivr(real_t z); Ref swish_derivv(const Ref &z); Ref swish_derivm(const Ref &z); //MISH real_t mish_normr(real_t z); Ref mish_normv(const Ref &z); Ref mish_normm(const Ref &z); real_t mish_derivr(real_t z); Ref mish_derivv(const Ref &z); Ref mish_derivm(const Ref &z); //SINC real_t sinc_normr(real_t z); Ref sinc_normv(const Ref &z); Ref sinc_normm(const Ref &z); real_t sinc_derivr(real_t z); Ref sinc_derivv(const Ref &z); Ref sinc_derivm(const Ref &z); //RELU real_t relu_normr(real_t z); Ref relu_normv(const Ref &z); Ref relu_normm(const Ref &z); real_t relu_derivr(real_t z); Ref relu_derivv(const Ref &z); Ref relu_derivm(const Ref &z); //LEAKYRELU real_t leaky_relu_normr(real_t z, real_t c); Ref leaky_relu_normv(const Ref &z, real_t c); Ref leaky_relu_normm(const Ref &z, real_t c); real_t leaky_relu_derivr(real_t z, real_t c); Ref leaky_relu_derivv(const Ref &z, real_t c); Ref leaky_relu_derivm(const Ref &z, real_t c); //ELU real_t elu_normr(real_t z, real_t c); Ref elu_normv(const Ref &z, real_t c); Ref elu_normm(const Ref &z, real_t c); real_t elu_derivr(real_t z, real_t c); Ref elu_derivv(const Ref &z, real_t c); Ref elu_derivm(const Ref &z, real_t c); //SELU real_t selu_normr(real_t z, real_t lambda, real_t c); Ref selu_normv(const Ref &z, real_t lambda, real_t c); Ref selu_normm(const Ref &z, real_t lambda, real_t c); real_t selu_derivr(real_t z, real_t lambda, real_t c); Ref selu_derivv(const Ref &z, real_t lambda, real_t c); Ref selu_derivm(const Ref &z, real_t lambda, real_t c); //GELU real_t gelu_normr(real_t z); Ref gelu_normv(const Ref &z); Ref gelu_normm(const Ref &z); real_t gelu_derivr(real_t z); Ref gelu_derivv(const Ref &z); Ref gelu_derivm(const Ref &z); //SIGN real_t sign_normr(real_t z); Ref sign_normv(const Ref &z); Ref sign_normm(const Ref &z); real_t sign_derivr(real_t z); Ref sign_derivv(const Ref &z); Ref sign_derivm(const Ref &z); //SINH real_t sinh_normr(real_t z); Ref sinh_normv(const Ref &z); Ref sinh_normm(const Ref &z); real_t sinh_derivr(real_t z); Ref sinh_derivv(const Ref &z); Ref sinh_derivm(const Ref &z); //COSH real_t cosh_normr(real_t z); Ref cosh_normv(const Ref &z); Ref cosh_normm(const Ref &z); real_t cosh_derivr(real_t z); Ref cosh_derivv(const Ref &z); Ref cosh_derivm(const Ref &z); //TANH real_t tanh_normr(real_t z); Ref tanh_normv(const Ref &z); Ref tanh_normm(const Ref &z); real_t tanh_derivr(real_t z); Ref tanh_derivv(const Ref &z); Ref tanh_derivm(const Ref &z); //CSCH real_t csch_normr(real_t z); Ref csch_normv(const Ref &z); Ref csch_normm(const Ref &z); real_t csch_derivr(real_t z); Ref csch_derivv(const Ref &z); Ref csch_derivm(const Ref &z); //SECH real_t sech_normr(real_t z); Ref sech_normv(const Ref &z); Ref sech_normm(const Ref &z); real_t sech_derivr(real_t z); Ref sech_derivv(const Ref &z); Ref sech_derivm(const Ref &z); //COTH real_t coth_normr(real_t z); Ref coth_normv(const Ref &z); Ref coth_normm(const Ref &z); real_t coth_derivr(real_t z); Ref coth_derivv(const Ref &z); Ref coth_derivm(const Ref &z); //ARSINH real_t arsinh_normr(real_t z); Ref arsinh_normv(const Ref &z); Ref arsinh_normm(const Ref &z); real_t arsinh_derivr(real_t z); Ref arsinh_derivv(const Ref &z); Ref arsinh_derivm(const Ref &z); //ARCOSH real_t arcosh_normr(real_t z); Ref arcosh_normv(const Ref &z); Ref arcosh_normm(const Ref &z); real_t arcosh_derivr(real_t z); Ref arcosh_derivv(const Ref &z); Ref arcosh_derivm(const Ref &z); //ARTANH real_t artanh_normr(real_t z); Ref artanh_normv(const Ref &z); Ref artanh_normm(const Ref &z); real_t artanh_derivr(real_t z); Ref artanh_derivv(const Ref &z); Ref artanh_derivm(const Ref &z); //ARCSCH real_t arcsch_normr(real_t z); Ref arcsch_normv(const Ref &z); Ref arcsch_normm(const Ref &z); real_t arcsch_derivr(real_t z); Ref arcsch_derivv(const Ref &z); Ref arcsch_derivm(const Ref &z); //ARSECH real_t arsech_normr(real_t z); Ref arsech_normv(const Ref &z); Ref arsech_normm(const Ref &z); real_t arsech_derivr(real_t z); Ref arsech_derivv(const Ref &z); Ref arsech_derivm(const Ref &z); //ARCOTH real_t arcoth_normr(real_t z); Ref arcoth_normv(const Ref &z); Ref arcoth_normm(const Ref &z); real_t arcoth_derivr(real_t z); Ref arcoth_derivv(const Ref &z); Ref arcoth_derivm(const Ref &z); protected: static void _bind_methods(); }; VARIANT_ENUM_CAST(MLPPActivation::ActivationFunction); #endif /* Activation_hpp */