#ifndef MLPP_ACTIVATION_H #define MLPP_ACTIVATION_H // // Activation.hpp // // Created by Marc Melikyan on 1/16/21. // #include "core/math/math_defs.h" #include "core/object/reference.h" #include "core/object/func_ref.h" #include "../lin_alg/mlpp_matrix.h" #include "../lin_alg/mlpp_vector.h" #include //TODO this should probably be a singleton //TODO Activation functions should either have a variant which does not allocate, or they should just be reworked altogether //TODO Methods here should probably use error macros, in a way where they get disabled in non-tools(?) (maybe release?) builds class MLPPActivation : public Reference { GDCLASS(MLPPActivation, Reference); public: enum ActivationFunction { ACTIVATION_FUNCTION_LINEAR = 0, ACTIVATION_FUNCTION_SIGMOID, ACTIVATION_FUNCTION_SWISH, ACTIVATION_FUNCTION_MISH, ACTIVATION_FUNCTION_SIN_C, ACTIVATION_FUNCTION_SOFTPLUS, ACTIVATION_FUNCTION_SOFTSIGN, ACTIVATION_FUNCTION_C_LOG_LOG, ACTIVATION_FUNCTION_LOGIT, ACTIVATION_FUNCTION_GAUSSIAN_CDF, ACTIVATION_FUNCTION_RELU, ACTIVATION_FUNCTION_GELU, ACTIVATION_FUNCTION_SIGN, ACTIVATION_FUNCTION_UNIT_STEP, ACTIVATION_FUNCTION_SINH, ACTIVATION_FUNCTION_COSH, ACTIVATION_FUNCTION_TANH, ACTIVATION_FUNCTION_CSCH, ACTIVATION_FUNCTION_SECH, ACTIVATION_FUNCTION_COTH, ACTIVATION_FUNCTION_ARSINH, ACTIVATION_FUNCTION_ARCOSH, ACTIVATION_FUNCTION_ARTANH, ACTIVATION_FUNCTION_ARCSCH, ACTIVATION_FUNCTION_ARSECH, ACTIVATION_FUNCTION_ARCOTH, }; public: //TODO add override for vec, and real_t typedef Ref (MLPPActivation::*ActivationFunctionPointer)(const Ref &); ActivationFunctionPointer get_activation_function_ptr(const ActivationFunction func, const bool deriv = false); Ref get_activation_function_funcref(const ActivationFunction func, const bool deriv = false); Ref run_activation_vector(const ActivationFunction func, const Ref &z, const bool deriv = false); Ref run_activation_matrix(const ActivationFunction func, const Ref &z, const bool deriv = false); Ref run_activation_norm_vector(const ActivationFunction func, const Ref &z); Ref run_activation_norm_matrix(const ActivationFunction func, const Ref &z); Ref run_activation_deriv_vector(const ActivationFunction func, const Ref &z); Ref run_activation_deriv_matrix(const ActivationFunction func, const Ref &z); Ref activation(const Ref &z, real_t (*function)(real_t)); //ACTIVATION FUNCTIONS //LINEAR real_t linear_norm(real_t z); Ref linear_norm(const Ref &z); Ref linear_norm(const Ref &z); real_t linear_deriv(real_t z); Ref linear_deriv(const Ref &z); Ref linear_deriv(const Ref &z); //SIGMOID real_t sigmoid_norm(real_t z); Ref sigmoid_norm(const Ref &z); Ref sigmoid_norm(const Ref &z); real_t sigmoid_deriv(real_t z); Ref sigmoid_deriv(const Ref &z); Ref sigmoid_deriv(const Ref &z); //SOFTMAX Ref softmax_norm(const Ref &z); Ref softmax_norm(const Ref &z); Ref softmax_deriv(const Ref &z); Ref softmax_deriv(const Ref &z); //ADJ_SOFTMAX Ref adj_softmax_norm(const Ref &z); Ref adj_softmax_norm(const Ref &z); Ref adj_softmax_deriv(const Ref &z); Ref adj_softmax_deriv(const Ref &z); //SOFTMAX DERIV Ref softmax_deriv_norm(const Ref &z); Vector> softmax_deriv_norm(const Ref &z); Ref softmax_deriv_deriv(const Ref &z); Vector> softmax_deriv_deriv(const Ref &z); //SOFTPLUS real_t softplus_norm(real_t z); Ref softplus_norm(const Ref &z); Ref softplus_norm(const Ref &z); real_t softplus_deriv(real_t z); Ref softplus_deriv(const Ref &z); Ref softplus_deriv(const Ref &z); //SOFTSIGN real_t softsign_norm(real_t z); Ref softsign_norm(const Ref &z); Ref softsign_norm(const Ref &z); real_t softsign_deriv(real_t z); Ref softsign_deriv(const Ref &z); Ref softsign_deriv(const Ref &z); //GAUSSIANCDF real_t gaussian_cdf_norm(real_t z); Ref gaussian_cdf_norm(const Ref &z); Ref gaussian_cdf_norm(const Ref &z); real_t gaussian_cdf_deriv(real_t z); Ref gaussian_cdf_deriv(const Ref &z); Ref gaussian_cdf_deriv(const Ref &z); //CLOGLOG real_t cloglog_norm(real_t z); Ref cloglog_norm(const Ref &z); Ref cloglog_norm(const Ref &z); real_t cloglog_deriv(real_t z); Ref cloglog_deriv(const Ref &z); Ref cloglog_deriv(const Ref &z); //LOGIT real_t logit_norm(real_t z); Ref logit_norm(const Ref &z); Ref logit_norm(const Ref &z); real_t logit_deriv(real_t z); Ref logit_deriv(const Ref &z); Ref logit_deriv(const Ref &z); //UNITSTEP real_t unit_step_norm(real_t z); Ref unit_step_norm(const Ref &z); Ref unit_step_norm(const Ref &z); real_t unit_step_deriv(real_t z); Ref unit_step_deriv(const Ref &z); Ref unit_step_deriv(const Ref &z); //SWISH real_t swish_norm(real_t z); Ref swish_norm(const Ref &z); Ref swish_norm(const Ref &z); real_t swish_deriv(real_t z); Ref swish_deriv(const Ref &z); Ref swish_deriv(const Ref &z); //MISH real_t mish_norm(real_t z); Ref mish_norm(const Ref &z); Ref mish_norm(const Ref &z); real_t mish_deriv(real_t z); Ref mish_deriv(const Ref &z); Ref mish_deriv(const Ref &z); //SINC real_t sinc_norm(real_t z); Ref sinc_norm(const Ref &z); Ref sinc_norm(const Ref &z); real_t sinc_deriv(real_t z); Ref sinc_deriv(const Ref &z); Ref sinc_deriv(const Ref &z); //RELU real_t relu_norm(real_t z); Ref relu_norm(const Ref &z); Ref relu_norm(const Ref &z); real_t relu_deriv(real_t z); Ref relu_deriv(const Ref &z); Ref relu_deriv(const Ref &z); //LEAKYRELU real_t leaky_relu_norm(real_t z, real_t c); Ref leaky_relu_norm(const Ref &z, real_t c); Ref leaky_relu_norm(const Ref &z, real_t c); real_t leaky_relu_deriv(real_t z, real_t c); Ref leaky_relu_deriv(const Ref &z, real_t c); Ref leaky_relu_deriv(const Ref &z, real_t c); //ELU real_t elu_norm(real_t z, real_t c); Ref elu_norm(const Ref &z, real_t c); Ref elu_norm(const Ref &z, real_t c); real_t elu_deriv(real_t z, real_t c); Ref elu_deriv(const Ref &z, real_t c); Ref elu_deriv(const Ref &z, real_t c); //SELU real_t selu_norm(real_t z, real_t lambda, real_t c); Ref selu_norm(const Ref &z, real_t lambda, real_t c); Ref selu_norm(Ref, real_t lambda, real_t c); real_t selu_deriv(real_t z, real_t lambda, real_t c); Ref selu_deriv(const Ref &z, real_t lambda, real_t c); Ref selu_deriv(Ref, real_t lambda, real_t c); //GELU real_t gelu_norm(real_t z); Ref gelu_norm(const Ref &z); Ref gelu_norm(const Ref &z); real_t gelu_deriv(real_t z); Ref gelu_deriv(const Ref &z); Ref gelu_deriv(const Ref &z); //SIGN real_t sign_norm(real_t z); Ref sign_norm(const Ref &z); Ref sign_norm(const Ref &z); real_t sign_deriv(real_t z); Ref sign_deriv(const Ref &z); Ref sign_deriv(const Ref &z); //SINH real_t sinh_norm(real_t z); Ref sinh_norm(const Ref &z); Ref sinh_norm(const Ref &z); real_t sinh_deriv(real_t z); Ref sinh_deriv(const Ref &z); Ref sinh_deriv(const Ref &z); //COSH real_t cosh_norm(real_t z); Ref cosh_norm(const Ref &z); Ref cosh_norm(const Ref &z); real_t cosh_deriv(real_t z); Ref cosh_deriv(const Ref &z); Ref cosh_deriv(const Ref &z); //TANH real_t tanh_norm(real_t z); Ref tanh_norm(const Ref &z); Ref tanh_norm(const Ref &z); real_t tanh_deriv(real_t z); Ref tanh_deriv(const Ref &z); Ref tanh_deriv(const Ref &z); //CSCH real_t csch_norm(real_t z); Ref csch_norm(const Ref &z); Ref csch_norm(const Ref &z); real_t csch_deriv(real_t z); Ref csch_deriv(const Ref &z); Ref csch_deriv(const Ref &z); //SECH real_t sech_norm(real_t z); Ref sech_norm(const Ref &z); Ref sech_norm(const Ref &z); real_t sech_deriv(real_t z); Ref sech_deriv(const Ref &z); Ref sech_deriv(const Ref &z); //COTH real_t coth_norm(real_t z); Ref coth_norm(const Ref &z); Ref coth_norm(const Ref &z); real_t coth_deriv(real_t z); Ref coth_deriv(const Ref &z); Ref coth_deriv(const Ref &z); //ARSINH real_t arsinh_norm(real_t z); Ref arsinh_norm(const Ref &z); Ref arsinh_norm(const Ref &z); real_t arsinh_deriv(real_t z); Ref arsinh_deriv(const Ref &z); Ref arsinh_deriv(const Ref &z); //ARCOSH real_t arcosh_norm(real_t z); Ref arcosh_norm(const Ref &z); Ref arcosh_norm(const Ref &z); real_t arcosh_deriv(real_t z); Ref arcosh_deriv(const Ref &z); Ref arcosh_deriv(const Ref &z); //ARTANH real_t artanh_norm(real_t z); Ref artanh_norm(const Ref &z); Ref artanh_norm(const Ref &z); real_t artanh_deriv(real_t z); Ref artanh_deriv(const Ref &z); Ref artanh_deriv(const Ref &z); //ARCSCH real_t arcsch_norm(real_t z); Ref arcsch_norm(const Ref &z); Ref arcsch_norm(const Ref &z); real_t arcsch_deriv(real_t z); Ref arcsch_deriv(const Ref &z); Ref arcsch_deriv(const Ref &z); //ARSECH real_t arsech_norm(real_t z); Ref arsech_norm(const Ref &z); Ref arsech_norm(const Ref &z); real_t arsech_deriv(real_t z); Ref arsech_deriv(const Ref &z); Ref arsech_deriv(const Ref &z); //ARCOTH real_t arcoth_norm(real_t z); Ref arcoth_norm(const Ref &z); Ref arcoth_norm(const Ref &z); real_t arcoth_deriv(real_t z); Ref arcoth_deriv(const Ref &z); Ref arcoth_deriv(const Ref &z); // ========= OLD =========== real_t linear(real_t z, bool deriv = false); std::vector linear(std::vector z, bool deriv = false); std::vector> linear(std::vector> z, bool deriv = false); real_t sigmoid(real_t z, bool deriv = false); std::vector sigmoid(std::vector z, bool deriv = false); std::vector> sigmoid(std::vector> z, bool deriv = false); std::vector softmax(std::vector z, bool deriv = false); std::vector> softmax(std::vector> z, bool deriv = false); std::vector adjSoftmax(std::vector z); std::vector> adjSoftmax(std::vector> z); std::vector> softmaxDeriv(std::vector z); std::vector>> softmaxDeriv(std::vector> z); real_t softplus(real_t z, bool deriv = false); std::vector softplus(std::vector z, bool deriv = false); std::vector> softplus(std::vector> z, bool deriv = false); real_t softsign(real_t z, bool deriv = false); std::vector softsign(std::vector z, bool deriv = false); std::vector> softsign(std::vector> z, bool deriv = false); real_t gaussianCDF(real_t z, bool deriv = false); std::vector gaussianCDF(std::vector z, bool deriv = false); std::vector> gaussianCDF(std::vector> z, bool deriv = false); real_t cloglog(real_t z, bool deriv = false); std::vector cloglog(std::vector z, bool deriv = false); std::vector> cloglog(std::vector> z, bool deriv = false); real_t logit(real_t z, bool deriv = false); std::vector logit(std::vector z, bool deriv = false); std::vector> logit(std::vector> z, bool deriv = false); real_t unitStep(real_t z, bool deriv = false); std::vector unitStep(std::vector z, bool deriv = false); std::vector> unitStep(std::vector> z, bool deriv = false); real_t swish(real_t z, bool deriv = false); std::vector swish(std::vector z, bool deriv = false); std::vector> swish(std::vector> z, bool deriv = false); real_t mish(real_t z, bool deriv = false); std::vector mish(std::vector z, bool deriv = false); std::vector> mish(std::vector> z, bool deriv = false); real_t sinc(real_t z, bool deriv = false); std::vector sinc(std::vector z, bool deriv = false); std::vector> sinc(std::vector> z, bool deriv = false); real_t RELU(real_t z, bool deriv = false); std::vector RELU(std::vector z, bool deriv = false); std::vector> RELU(std::vector> z, bool deriv = false); real_t leakyReLU(real_t z, real_t c, bool deriv = false); std::vector leakyReLU(std::vector z, real_t c, bool deriv = false); std::vector> leakyReLU(std::vector> z, real_t c, bool deriv = false); real_t ELU(real_t z, real_t c, bool deriv = false); std::vector ELU(std::vector z, real_t c, bool deriv = false); std::vector> ELU(std::vector> z, real_t c, bool deriv = false); real_t SELU(real_t z, real_t lambda, real_t c, bool deriv = false); std::vector SELU(std::vector z, real_t lambda, real_t c, bool deriv = false); std::vector> SELU(std::vector>, real_t lambda, real_t c, bool deriv = false); real_t GELU(real_t z, bool deriv = false); std::vector GELU(std::vector z, bool deriv = false); std::vector> GELU(std::vector> z, bool deriv = false); real_t sign(real_t z, bool deriv = false); std::vector sign(std::vector z, bool deriv = false); std::vector> sign(std::vector> z, bool deriv = false); real_t sinh(real_t z, bool deriv = false); std::vector sinh(std::vector z, bool deriv = false); std::vector> sinh(std::vector> z, bool deriv = false); real_t cosh(real_t z, bool deriv = false); std::vector cosh(std::vector z, bool deriv = false); std::vector> cosh(std::vector> z, bool deriv = false); real_t tanh(real_t z, bool deriv = false); std::vector tanh(std::vector z, bool deriv = false); std::vector> tanh(std::vector> z, bool deriv = false); real_t csch(real_t z, bool deriv = false); std::vector csch(std::vector z, bool deriv = false); std::vector> csch(std::vector> z, bool deriv = false); real_t sech(real_t z, bool deriv = false); std::vector sech(std::vector z, bool deriv = false); std::vector> sech(std::vector> z, bool deriv = false); real_t coth(real_t z, bool deriv = false); std::vector coth(std::vector z, bool deriv = false); std::vector> coth(std::vector> z, bool deriv = false); real_t arsinh(real_t z, bool deriv = false); std::vector arsinh(std::vector z, bool deriv = false); std::vector> arsinh(std::vector> z, bool deriv = false); real_t arcosh(real_t z, bool deriv = false); std::vector arcosh(std::vector z, bool deriv = false); std::vector> arcosh(std::vector> z, bool deriv = false); real_t artanh(real_t z, bool deriv = false); std::vector artanh(std::vector z, bool deriv = false); std::vector> artanh(std::vector> z, bool deriv = false); real_t arcsch(real_t z, bool deriv = false); std::vector arcsch(std::vector z, bool deriv = false); std::vector> arcsch(std::vector> z, bool deriv = false); real_t arsech(real_t z, bool deriv = false); std::vector arsech(std::vector z, bool deriv = false); std::vector> arsech(std::vector> z, bool deriv = false); real_t arcoth(real_t z, bool deriv = false); std::vector arcoth(std::vector z, bool deriv = false); std::vector> arcoth(std::vector> z, bool deriv = false); std::vector activation(std::vector z, bool deriv, real_t (*function)(real_t, bool)); private: }; VARIANT_ENUM_CAST(MLPPActivation::ActivationFunction); #endif /* Activation_hpp */