From f6a32ca301b2b3dd11094cb501bc529f0e9c34d9 Mon Sep 17 00:00:00 2001 From: Relintai Date: Fri, 3 Feb 2023 01:58:17 +0100 Subject: [PATCH] Implement the new helper methods in Activation. --- mlpp/activation/activation.cpp | 800 ++++++++++++++++++++++++++++++++- mlpp/activation/activation.h | 30 +- 2 files changed, 816 insertions(+), 14 deletions(-) diff --git a/mlpp/activation/activation.cpp b/mlpp/activation/activation.cpp index 35d03e2..60cea06 100644 --- a/mlpp/activation/activation.cpp +++ b/mlpp/activation/activation.cpp @@ -10,32 +10,799 @@ #include #include -MLPPActivation::ActivationFunctionPointer MLPPActivation::get_activation_function_ptr(const ActivationFunction func, const bool deriv) { - return NULL; +MLPPActivation::RealActivationFunctionPointer MLPPActivation::get_activation_function_ptr_real(const ActivationFunction func, const bool deriv) { + if (deriv) { + return get_activation_function_ptr_normal_real(func); + } else { + return get_activation_function_ptr_deriv_real(func); + } +} +MLPPActivation::VectorActivationFunctionPointer MLPPActivation::get_activation_function_ptr_vector(const ActivationFunction func, const bool deriv) { + if (deriv) { + return get_activation_function_ptr_normal_vector(func); + } else { + return get_activation_function_ptr_deriv_vector(func); + } +} +MLPPActivation::MatrixActivationFunctionPointer MLPPActivation::get_activation_function_ptr_matrix(const ActivationFunction func, const bool deriv) { + if (deriv) { + return get_activation_function_ptr_normal_matrix(func); + } else { + return get_activation_function_ptr_deriv_matrix(func); + } } +MLPPActivation::RealActivationFunctionPointer MLPPActivation::get_activation_function_ptr_normal_real(const ActivationFunction func) { + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return &MLPPActivation::linear_normr; + case ACTIVATION_FUNCTION_SIGMOID: + return &MLPPActivation::sigmoid_normr; + case ACTIVATION_FUNCTION_SWISH: + return &MLPPActivation::swish_normr; + case ACTIVATION_FUNCTION_MISH: + return &MLPPActivation::mish_normr; + case ACTIVATION_FUNCTION_SIN_C: + return &MLPPActivation::sinc_normr; + case ACTIVATION_FUNCTION_SOFTMAX: + return &MLPPActivation::softmax_normr; + case ACTIVATION_FUNCTION_SOFTPLUS: + return &MLPPActivation::softplus_normr; + case ACTIVATION_FUNCTION_SOFTSIGN: + return &MLPPActivation::softsign_normr; + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return &MLPPActivation::adj_softmax_normr; + case ACTIVATION_FUNCTION_C_LOG_LOG: + return &MLPPActivation::cloglog_normr; + case ACTIVATION_FUNCTION_LOGIT: + return &MLPPActivation::logit_normr; + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return &MLPPActivation::gaussian_cdf_normr; + case ACTIVATION_FUNCTION_RELU: + return &MLPPActivation::relu_normr; + case ACTIVATION_FUNCTION_GELU: + return &MLPPActivation::gelu_normr; + case ACTIVATION_FUNCTION_SIGN: + return &MLPPActivation::sign_normr; + case ACTIVATION_FUNCTION_UNIT_STEP: + return &MLPPActivation::unit_step_normr; + case ACTIVATION_FUNCTION_SINH: + return &MLPPActivation::sinh_normr; + case ACTIVATION_FUNCTION_COSH: + return &MLPPActivation::cosh_normr; + case ACTIVATION_FUNCTION_TANH: + return &MLPPActivation::tanh_normr; + case ACTIVATION_FUNCTION_CSCH: + return &MLPPActivation::csch_normr; + case ACTIVATION_FUNCTION_SECH: + return &MLPPActivation::sech_normr; + case ACTIVATION_FUNCTION_COTH: + return &MLPPActivation::coth_normr; + case ACTIVATION_FUNCTION_ARSINH: + return &MLPPActivation::arsinh_normr; + case ACTIVATION_FUNCTION_ARCOSH: + return &MLPPActivation::arcosh_normr; + case ACTIVATION_FUNCTION_ARTANH: + return &MLPPActivation::artanh_normr; + case ACTIVATION_FUNCTION_ARCSCH: + return &MLPPActivation::arcsch_normr; + case ACTIVATION_FUNCTION_ARSECH: + return &MLPPActivation::arsech_normr; + case ACTIVATION_FUNCTION_ARCOTH: + return &MLPPActivation::arcoth_normr; + default: + return NULL; + } +} +MLPPActivation::VectorActivationFunctionPointer MLPPActivation::get_activation_function_ptr_normal_vector(const ActivationFunction func) { + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return &MLPPActivation::linear_normv; + case ACTIVATION_FUNCTION_SIGMOID: + return &MLPPActivation::sigmoid_normv; + case ACTIVATION_FUNCTION_SWISH: + return &MLPPActivation::swish_normv; + case ACTIVATION_FUNCTION_MISH: + return &MLPPActivation::mish_normv; + case ACTIVATION_FUNCTION_SIN_C: + return &MLPPActivation::sinc_normv; + case ACTIVATION_FUNCTION_SOFTMAX: + return &MLPPActivation::softmax_normv; + case ACTIVATION_FUNCTION_SOFTPLUS: + return &MLPPActivation::softplus_normv; + case ACTIVATION_FUNCTION_SOFTSIGN: + return &MLPPActivation::softsign_normv; + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return &MLPPActivation::adj_softmax_normv; + case ACTIVATION_FUNCTION_C_LOG_LOG: + return &MLPPActivation::cloglog_normv; + case ACTIVATION_FUNCTION_LOGIT: + return &MLPPActivation::logit_normv; + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return &MLPPActivation::gaussian_cdf_normv; + case ACTIVATION_FUNCTION_RELU: + return &MLPPActivation::relu_normv; + case ACTIVATION_FUNCTION_GELU: + return &MLPPActivation::gelu_normv; + case ACTIVATION_FUNCTION_SIGN: + return &MLPPActivation::sign_normv; + case ACTIVATION_FUNCTION_UNIT_STEP: + return &MLPPActivation::unit_step_normv; + case ACTIVATION_FUNCTION_SINH: + return &MLPPActivation::sinh_normv; + case ACTIVATION_FUNCTION_COSH: + return &MLPPActivation::cosh_normv; + case ACTIVATION_FUNCTION_TANH: + return &MLPPActivation::tanh_normv; + case ACTIVATION_FUNCTION_CSCH: + return &MLPPActivation::csch_normv; + case ACTIVATION_FUNCTION_SECH: + return &MLPPActivation::sech_normv; + case ACTIVATION_FUNCTION_COTH: + return &MLPPActivation::coth_normv; + case ACTIVATION_FUNCTION_ARSINH: + return &MLPPActivation::arsinh_normv; + case ACTIVATION_FUNCTION_ARCOSH: + return &MLPPActivation::arcosh_normv; + case ACTIVATION_FUNCTION_ARTANH: + return &MLPPActivation::artanh_normv; + case ACTIVATION_FUNCTION_ARCSCH: + return &MLPPActivation::arcsch_normv; + case ACTIVATION_FUNCTION_ARSECH: + return &MLPPActivation::arsech_normv; + case ACTIVATION_FUNCTION_ARCOTH: + return &MLPPActivation::arcoth_normv; + default: + return NULL; + } +} +MLPPActivation::MatrixActivationFunctionPointer MLPPActivation::get_activation_function_ptr_normal_matrix(const ActivationFunction func) { + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return &MLPPActivation::linear_normm; + case ACTIVATION_FUNCTION_SIGMOID: + return &MLPPActivation::sigmoid_normm; + case ACTIVATION_FUNCTION_SWISH: + return &MLPPActivation::swish_normm; + case ACTIVATION_FUNCTION_MISH: + return &MLPPActivation::mish_normm; + case ACTIVATION_FUNCTION_SIN_C: + return &MLPPActivation::sinc_normm; + case ACTIVATION_FUNCTION_SOFTMAX: + return &MLPPActivation::softmax_normm; + case ACTIVATION_FUNCTION_SOFTPLUS: + return &MLPPActivation::softplus_normm; + case ACTIVATION_FUNCTION_SOFTSIGN: + return &MLPPActivation::softsign_normm; + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return &MLPPActivation::adj_softmax_normm; + case ACTIVATION_FUNCTION_C_LOG_LOG: + return &MLPPActivation::cloglog_normm; + case ACTIVATION_FUNCTION_LOGIT: + return &MLPPActivation::logit_normm; + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return &MLPPActivation::gaussian_cdf_normm; + case ACTIVATION_FUNCTION_RELU: + return &MLPPActivation::relu_normm; + case ACTIVATION_FUNCTION_GELU: + return &MLPPActivation::gelu_normm; + case ACTIVATION_FUNCTION_SIGN: + return &MLPPActivation::sign_normm; + case ACTIVATION_FUNCTION_UNIT_STEP: + return &MLPPActivation::unit_step_normm; + case ACTIVATION_FUNCTION_SINH: + return &MLPPActivation::sinh_normm; + case ACTIVATION_FUNCTION_COSH: + return &MLPPActivation::cosh_normm; + case ACTIVATION_FUNCTION_TANH: + return &MLPPActivation::tanh_normm; + case ACTIVATION_FUNCTION_CSCH: + return &MLPPActivation::csch_normm; + case ACTIVATION_FUNCTION_SECH: + return &MLPPActivation::sech_normm; + case ACTIVATION_FUNCTION_COTH: + return &MLPPActivation::coth_normm; + case ACTIVATION_FUNCTION_ARSINH: + return &MLPPActivation::arsinh_normm; + case ACTIVATION_FUNCTION_ARCOSH: + return &MLPPActivation::arcosh_normm; + case ACTIVATION_FUNCTION_ARTANH: + return &MLPPActivation::artanh_normm; + case ACTIVATION_FUNCTION_ARCSCH: + return &MLPPActivation::arcsch_normm; + case ACTIVATION_FUNCTION_ARSECH: + return &MLPPActivation::arsech_normm; + case ACTIVATION_FUNCTION_ARCOTH: + return &MLPPActivation::arcoth_normm; + default: + return NULL; + } +} + +MLPPActivation::RealActivationFunctionPointer MLPPActivation::get_activation_function_ptr_deriv_real(const ActivationFunction func) { + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return &MLPPActivation::linear_normr; + case ACTIVATION_FUNCTION_SIGMOID: + return &MLPPActivation::sigmoid_normr; + case ACTIVATION_FUNCTION_SWISH: + return &MLPPActivation::swish_normr; + case ACTIVATION_FUNCTION_MISH: + return &MLPPActivation::mish_normr; + case ACTIVATION_FUNCTION_SIN_C: + return &MLPPActivation::sinc_normr; + case ACTIVATION_FUNCTION_SOFTMAX: + return &MLPPActivation::softmax_normr; + case ACTIVATION_FUNCTION_SOFTPLUS: + return &MLPPActivation::softplus_normr; + case ACTIVATION_FUNCTION_SOFTSIGN: + return &MLPPActivation::softsign_normr; + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return &MLPPActivation::adj_softmax_normr; + case ACTIVATION_FUNCTION_C_LOG_LOG: + return &MLPPActivation::cloglog_normr; + case ACTIVATION_FUNCTION_LOGIT: + return &MLPPActivation::logit_normr; + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return &MLPPActivation::gaussian_cdf_normr; + case ACTIVATION_FUNCTION_RELU: + return &MLPPActivation::relu_normr; + case ACTIVATION_FUNCTION_GELU: + return &MLPPActivation::gelu_normr; + case ACTIVATION_FUNCTION_SIGN: + return &MLPPActivation::sign_normr; + case ACTIVATION_FUNCTION_UNIT_STEP: + return &MLPPActivation::unit_step_normr; + case ACTIVATION_FUNCTION_SINH: + return &MLPPActivation::sinh_normr; + case ACTIVATION_FUNCTION_COSH: + return &MLPPActivation::cosh_normr; + case ACTIVATION_FUNCTION_TANH: + return &MLPPActivation::tanh_normr; + case ACTIVATION_FUNCTION_CSCH: + return &MLPPActivation::csch_normr; + case ACTIVATION_FUNCTION_SECH: + return &MLPPActivation::sech_normr; + case ACTIVATION_FUNCTION_COTH: + return &MLPPActivation::coth_normr; + case ACTIVATION_FUNCTION_ARSINH: + return &MLPPActivation::arsinh_normr; + case ACTIVATION_FUNCTION_ARCOSH: + return &MLPPActivation::arcosh_normr; + case ACTIVATION_FUNCTION_ARTANH: + return &MLPPActivation::artanh_normr; + case ACTIVATION_FUNCTION_ARCSCH: + return &MLPPActivation::arcsch_normr; + case ACTIVATION_FUNCTION_ARSECH: + return &MLPPActivation::arsech_normr; + case ACTIVATION_FUNCTION_ARCOTH: + return &MLPPActivation::arcoth_normr; + default: + return NULL; + } +} +MLPPActivation::VectorActivationFunctionPointer MLPPActivation::get_activation_function_ptr_deriv_vector(const ActivationFunction func) { + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return &MLPPActivation::linear_derivv; + case ACTIVATION_FUNCTION_SIGMOID: + return &MLPPActivation::sigmoid_derivv; + case ACTIVATION_FUNCTION_SWISH: + return &MLPPActivation::swish_derivv; + case ACTIVATION_FUNCTION_MISH: + return &MLPPActivation::mish_derivv; + case ACTIVATION_FUNCTION_SIN_C: + return &MLPPActivation::sinc_derivv; + case ACTIVATION_FUNCTION_SOFTMAX: + return &MLPPActivation::softmax_derivv; + case ACTIVATION_FUNCTION_SOFTPLUS: + return &MLPPActivation::softplus_derivv; + case ACTIVATION_FUNCTION_SOFTSIGN: + return &MLPPActivation::softsign_derivv; + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return &MLPPActivation::adj_softmax_derivv; + case ACTIVATION_FUNCTION_C_LOG_LOG: + return &MLPPActivation::cloglog_derivv; + case ACTIVATION_FUNCTION_LOGIT: + return &MLPPActivation::logit_derivv; + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return &MLPPActivation::gaussian_cdf_derivv; + case ACTIVATION_FUNCTION_RELU: + return &MLPPActivation::relu_derivv; + case ACTIVATION_FUNCTION_GELU: + return &MLPPActivation::gelu_derivv; + case ACTIVATION_FUNCTION_SIGN: + return &MLPPActivation::sign_derivv; + case ACTIVATION_FUNCTION_UNIT_STEP: + return &MLPPActivation::unit_step_derivv; + case ACTIVATION_FUNCTION_SINH: + return &MLPPActivation::sinh_derivv; + case ACTIVATION_FUNCTION_COSH: + return &MLPPActivation::cosh_derivv; + case ACTIVATION_FUNCTION_TANH: + return &MLPPActivation::tanh_derivv; + case ACTIVATION_FUNCTION_CSCH: + return &MLPPActivation::csch_derivv; + case ACTIVATION_FUNCTION_SECH: + return &MLPPActivation::sech_derivv; + case ACTIVATION_FUNCTION_COTH: + return &MLPPActivation::coth_derivv; + case ACTIVATION_FUNCTION_ARSINH: + return &MLPPActivation::arsinh_derivv; + case ACTIVATION_FUNCTION_ARCOSH: + return &MLPPActivation::arcosh_derivv; + case ACTIVATION_FUNCTION_ARTANH: + return &MLPPActivation::artanh_derivv; + case ACTIVATION_FUNCTION_ARCSCH: + return &MLPPActivation::arcsch_derivv; + case ACTIVATION_FUNCTION_ARSECH: + return &MLPPActivation::arsech_derivv; + case ACTIVATION_FUNCTION_ARCOTH: + return &MLPPActivation::arcoth_derivv; + default: + return NULL; + } +} +MLPPActivation::MatrixActivationFunctionPointer MLPPActivation::get_activation_function_ptr_deriv_matrix(const ActivationFunction func) { + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return &MLPPActivation::linear_derivm; + case ACTIVATION_FUNCTION_SIGMOID: + return &MLPPActivation::sigmoid_derivm; + case ACTIVATION_FUNCTION_SWISH: + return &MLPPActivation::swish_derivm; + case ACTIVATION_FUNCTION_MISH: + return &MLPPActivation::mish_derivm; + case ACTIVATION_FUNCTION_SIN_C: + return &MLPPActivation::sinc_derivm; + case ACTIVATION_FUNCTION_SOFTMAX: + return &MLPPActivation::softmax_derivm; + case ACTIVATION_FUNCTION_SOFTPLUS: + return &MLPPActivation::softplus_derivm; + case ACTIVATION_FUNCTION_SOFTSIGN: + return &MLPPActivation::softsign_derivm; + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return &MLPPActivation::adj_softmax_derivm; + case ACTIVATION_FUNCTION_C_LOG_LOG: + return &MLPPActivation::cloglog_derivm; + case ACTIVATION_FUNCTION_LOGIT: + return &MLPPActivation::logit_derivm; + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return &MLPPActivation::gaussian_cdf_derivm; + case ACTIVATION_FUNCTION_RELU: + return &MLPPActivation::relu_derivm; + case ACTIVATION_FUNCTION_GELU: + return &MLPPActivation::gelu_derivm; + case ACTIVATION_FUNCTION_SIGN: + return &MLPPActivation::sign_derivm; + case ACTIVATION_FUNCTION_UNIT_STEP: + return &MLPPActivation::unit_step_derivm; + case ACTIVATION_FUNCTION_SINH: + return &MLPPActivation::sinh_derivm; + case ACTIVATION_FUNCTION_COSH: + return &MLPPActivation::cosh_derivm; + case ACTIVATION_FUNCTION_TANH: + return &MLPPActivation::tanh_derivm; + case ACTIVATION_FUNCTION_CSCH: + return &MLPPActivation::csch_derivm; + case ACTIVATION_FUNCTION_SECH: + return &MLPPActivation::sech_derivm; + case ACTIVATION_FUNCTION_COTH: + return &MLPPActivation::coth_derivm; + case ACTIVATION_FUNCTION_ARSINH: + return &MLPPActivation::arsinh_derivm; + case ACTIVATION_FUNCTION_ARCOSH: + return &MLPPActivation::arcosh_derivm; + case ACTIVATION_FUNCTION_ARTANH: + return &MLPPActivation::artanh_derivm; + case ACTIVATION_FUNCTION_ARCSCH: + return &MLPPActivation::arcsch_derivm; + case ACTIVATION_FUNCTION_ARSECH: + return &MLPPActivation::arsech_derivm; + case ACTIVATION_FUNCTION_ARCOTH: + return &MLPPActivation::arcoth_derivm; + default: + return NULL; + } +} + +real_t MLPPActivation::run_activation_real(const ActivationFunction func, const real_t z, const bool deriv) { + if (deriv) { + return run_activation_norm_real(func, z); + } else { + return run_activation_deriv_real(func, z); + } +} Ref MLPPActivation::run_activation_vector(const ActivationFunction func, const Ref &z, const bool deriv) { - return Ref(); + if (deriv) { + return run_activation_norm_vector(func, z); + } else { + return run_activation_deriv_vector(func, z); + } } Ref MLPPActivation::run_activation_matrix(const ActivationFunction func, const Ref &z, const bool deriv) { - return Ref(); + if (deriv) { + return run_activation_norm_matrix(func, z); + } else { + return run_activation_deriv_matrix(func, z); + } } +real_t MLPPActivation::run_activation_norm_real(const ActivationFunction func, const real_t z) { + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return linear_normr(z); + case ACTIVATION_FUNCTION_SIGMOID: + return sigmoid_normr(z); + case ACTIVATION_FUNCTION_SWISH: + return swish_normr(z); + case ACTIVATION_FUNCTION_MISH: + return mish_normr(z); + case ACTIVATION_FUNCTION_SIN_C: + return sinc_normr(z); + case ACTIVATION_FUNCTION_SOFTMAX: + return softmax_normr(z); + case ACTIVATION_FUNCTION_SOFTPLUS: + return softplus_normr(z); + case ACTIVATION_FUNCTION_SOFTSIGN: + return softsign_normr(z); + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return adj_softmax_normr(z); + case ACTIVATION_FUNCTION_C_LOG_LOG: + return cloglog_normr(z); + case ACTIVATION_FUNCTION_LOGIT: + return logit_normr(z); + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return gaussian_cdf_normr(z); + case ACTIVATION_FUNCTION_RELU: + return relu_normr(z); + case ACTIVATION_FUNCTION_GELU: + return gelu_normr(z); + case ACTIVATION_FUNCTION_SIGN: + return sign_normr(z); + case ACTIVATION_FUNCTION_UNIT_STEP: + return unit_step_normr(z); + case ACTIVATION_FUNCTION_SINH: + return sinh_normr(z); + case ACTIVATION_FUNCTION_COSH: + return cosh_normr(z); + case ACTIVATION_FUNCTION_TANH: + return tanh_normr(z); + case ACTIVATION_FUNCTION_CSCH: + return csch_normr(z); + case ACTIVATION_FUNCTION_SECH: + return sech_normr(z); + case ACTIVATION_FUNCTION_COTH: + return coth_normr(z); + case ACTIVATION_FUNCTION_ARSINH: + return arsinh_normr(z); + case ACTIVATION_FUNCTION_ARCOSH: + return arcosh_normr(z); + case ACTIVATION_FUNCTION_ARTANH: + return artanh_normr(z); + case ACTIVATION_FUNCTION_ARCSCH: + return arcsch_normr(z); + case ACTIVATION_FUNCTION_ARSECH: + return arsech_normr(z); + case ACTIVATION_FUNCTION_ARCOTH: + return arcoth_normr(z); + default: + ERR_FAIL_V(0); + } +} Ref MLPPActivation::run_activation_norm_vector(const ActivationFunction func, const Ref &z) { - return Ref(); + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return linear_normv(z); + case ACTIVATION_FUNCTION_SIGMOID: + return sigmoid_normv(z); + case ACTIVATION_FUNCTION_SWISH: + return swish_normv(z); + case ACTIVATION_FUNCTION_MISH: + return mish_normv(z); + case ACTIVATION_FUNCTION_SIN_C: + return sinc_normv(z); + case ACTIVATION_FUNCTION_SOFTMAX: + return softmax_normv(z); + case ACTIVATION_FUNCTION_SOFTPLUS: + return softplus_normv(z); + case ACTIVATION_FUNCTION_SOFTSIGN: + return softsign_normv(z); + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return adj_softmax_normv(z); + case ACTIVATION_FUNCTION_C_LOG_LOG: + return cloglog_normv(z); + case ACTIVATION_FUNCTION_LOGIT: + return logit_normv(z); + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return gaussian_cdf_normv(z); + case ACTIVATION_FUNCTION_RELU: + return relu_normv(z); + case ACTIVATION_FUNCTION_GELU: + return gelu_normv(z); + case ACTIVATION_FUNCTION_SIGN: + return sign_normv(z); + case ACTIVATION_FUNCTION_UNIT_STEP: + return unit_step_normv(z); + case ACTIVATION_FUNCTION_SINH: + return sinh_normv(z); + case ACTIVATION_FUNCTION_COSH: + return cosh_normv(z); + case ACTIVATION_FUNCTION_TANH: + return tanh_normv(z); + case ACTIVATION_FUNCTION_CSCH: + return csch_normv(z); + case ACTIVATION_FUNCTION_SECH: + return sech_normv(z); + case ACTIVATION_FUNCTION_COTH: + return coth_normv(z); + case ACTIVATION_FUNCTION_ARSINH: + return arsinh_normv(z); + case ACTIVATION_FUNCTION_ARCOSH: + return arcosh_normv(z); + case ACTIVATION_FUNCTION_ARTANH: + return artanh_normv(z); + case ACTIVATION_FUNCTION_ARCSCH: + return arcsch_normv(z); + case ACTIVATION_FUNCTION_ARSECH: + return arsech_normv(z); + case ACTIVATION_FUNCTION_ARCOTH: + return arcoth_normv(z); + default: + ERR_FAIL_V(Ref()); + } } Ref MLPPActivation::run_activation_norm_matrix(const ActivationFunction func, const Ref &z) { - return Ref(); + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return linear_normm(z); + case ACTIVATION_FUNCTION_SIGMOID: + return sigmoid_normm(z); + case ACTIVATION_FUNCTION_SWISH: + return swish_normm(z); + case ACTIVATION_FUNCTION_MISH: + return mish_normm(z); + case ACTIVATION_FUNCTION_SIN_C: + return sinc_normm(z); + case ACTIVATION_FUNCTION_SOFTMAX: + return softmax_normm(z); + case ACTIVATION_FUNCTION_SOFTPLUS: + return softplus_normm(z); + case ACTIVATION_FUNCTION_SOFTSIGN: + return softsign_normm(z); + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return adj_softmax_normm(z); + case ACTIVATION_FUNCTION_C_LOG_LOG: + return cloglog_normm(z); + case ACTIVATION_FUNCTION_LOGIT: + return logit_normm(z); + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return gaussian_cdf_normm(z); + case ACTIVATION_FUNCTION_RELU: + return relu_normm(z); + case ACTIVATION_FUNCTION_GELU: + return gelu_normm(z); + case ACTIVATION_FUNCTION_SIGN: + return sign_normm(z); + case ACTIVATION_FUNCTION_UNIT_STEP: + return unit_step_normm(z); + case ACTIVATION_FUNCTION_SINH: + return sinh_normm(z); + case ACTIVATION_FUNCTION_COSH: + return cosh_normm(z); + case ACTIVATION_FUNCTION_TANH: + return tanh_normm(z); + case ACTIVATION_FUNCTION_CSCH: + return csch_normm(z); + case ACTIVATION_FUNCTION_SECH: + return sech_normm(z); + case ACTIVATION_FUNCTION_COTH: + return coth_normm(z); + case ACTIVATION_FUNCTION_ARSINH: + return arsinh_normm(z); + case ACTIVATION_FUNCTION_ARCOSH: + return arcosh_normm(z); + case ACTIVATION_FUNCTION_ARTANH: + return artanh_normm(z); + case ACTIVATION_FUNCTION_ARCSCH: + return arcsch_normm(z); + case ACTIVATION_FUNCTION_ARSECH: + return arsech_normm(z); + case ACTIVATION_FUNCTION_ARCOTH: + return arcoth_normm(z); + default: + ERR_FAIL_V(Ref()); + } } +real_t MLPPActivation::run_activation_deriv_real(const ActivationFunction func, const real_t z) { + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return linear_normr(z); + case ACTIVATION_FUNCTION_SIGMOID: + return sigmoid_normr(z); + case ACTIVATION_FUNCTION_SWISH: + return swish_normr(z); + case ACTIVATION_FUNCTION_MISH: + return mish_normr(z); + case ACTIVATION_FUNCTION_SIN_C: + return sinc_normr(z); + case ACTIVATION_FUNCTION_SOFTMAX: + return softmax_normr(z); + case ACTIVATION_FUNCTION_SOFTPLUS: + return softplus_normr(z); + case ACTIVATION_FUNCTION_SOFTSIGN: + return softsign_normr(z); + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return adj_softmax_normr(z); + case ACTIVATION_FUNCTION_C_LOG_LOG: + return cloglog_normr(z); + case ACTIVATION_FUNCTION_LOGIT: + return logit_normr(z); + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return gaussian_cdf_normr(z); + case ACTIVATION_FUNCTION_RELU: + return relu_normr(z); + case ACTIVATION_FUNCTION_GELU: + return gelu_normr(z); + case ACTIVATION_FUNCTION_SIGN: + return sign_normr(z); + case ACTIVATION_FUNCTION_UNIT_STEP: + return unit_step_normr(z); + case ACTIVATION_FUNCTION_SINH: + return sinh_normr(z); + case ACTIVATION_FUNCTION_COSH: + return cosh_normr(z); + case ACTIVATION_FUNCTION_TANH: + return tanh_normr(z); + case ACTIVATION_FUNCTION_CSCH: + return csch_normr(z); + case ACTIVATION_FUNCTION_SECH: + return sech_normr(z); + case ACTIVATION_FUNCTION_COTH: + return coth_normr(z); + case ACTIVATION_FUNCTION_ARSINH: + return arsinh_normr(z); + case ACTIVATION_FUNCTION_ARCOSH: + return arcosh_normr(z); + case ACTIVATION_FUNCTION_ARTANH: + return artanh_normr(z); + case ACTIVATION_FUNCTION_ARCSCH: + return arcsch_normr(z); + case ACTIVATION_FUNCTION_ARSECH: + return arsech_normr(z); + case ACTIVATION_FUNCTION_ARCOTH: + return arcoth_normr(z); + default: + ERR_FAIL_V(0); + } +} Ref MLPPActivation::run_activation_deriv_vector(const ActivationFunction func, const Ref &z) { - return Ref(); + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return linear_derivv(z); + case ACTIVATION_FUNCTION_SIGMOID: + return sigmoid_derivv(z); + case ACTIVATION_FUNCTION_SWISH: + return swish_derivv(z); + case ACTIVATION_FUNCTION_MISH: + return mish_derivv(z); + case ACTIVATION_FUNCTION_SIN_C: + return sinc_derivv(z); + case ACTIVATION_FUNCTION_SOFTMAX: + return softmax_derivv(z); + case ACTIVATION_FUNCTION_SOFTPLUS: + return softplus_derivv(z); + case ACTIVATION_FUNCTION_SOFTSIGN: + return softsign_derivv(z); + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return adj_softmax_derivv(z); + case ACTIVATION_FUNCTION_C_LOG_LOG: + return cloglog_derivv(z); + case ACTIVATION_FUNCTION_LOGIT: + return logit_derivv(z); + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return gaussian_cdf_derivv(z); + case ACTIVATION_FUNCTION_RELU: + return relu_derivv(z); + case ACTIVATION_FUNCTION_GELU: + return gelu_derivv(z); + case ACTIVATION_FUNCTION_SIGN: + return sign_derivv(z); + case ACTIVATION_FUNCTION_UNIT_STEP: + return unit_step_derivv(z); + case ACTIVATION_FUNCTION_SINH: + return sinh_derivv(z); + case ACTIVATION_FUNCTION_COSH: + return cosh_derivv(z); + case ACTIVATION_FUNCTION_TANH: + return tanh_derivv(z); + case ACTIVATION_FUNCTION_CSCH: + return csch_derivv(z); + case ACTIVATION_FUNCTION_SECH: + return sech_derivv(z); + case ACTIVATION_FUNCTION_COTH: + return coth_derivv(z); + case ACTIVATION_FUNCTION_ARSINH: + return arsinh_derivv(z); + case ACTIVATION_FUNCTION_ARCOSH: + return arcosh_derivv(z); + case ACTIVATION_FUNCTION_ARTANH: + return artanh_derivv(z); + case ACTIVATION_FUNCTION_ARCSCH: + return arcsch_derivv(z); + case ACTIVATION_FUNCTION_ARSECH: + return arsech_derivv(z); + case ACTIVATION_FUNCTION_ARCOTH: + return arcoth_derivv(z); + default: + ERR_FAIL_V(Ref()); + } } Ref MLPPActivation::run_activation_deriv_matrix(const ActivationFunction func, const Ref &z) { - return Ref(); + switch (func) { + case ACTIVATION_FUNCTION_LINEAR: + return linear_derivm(z); + case ACTIVATION_FUNCTION_SIGMOID: + return sigmoid_derivm(z); + case ACTIVATION_FUNCTION_SWISH: + return swish_derivm(z); + case ACTIVATION_FUNCTION_MISH: + return mish_derivm(z); + case ACTIVATION_FUNCTION_SIN_C: + return sinc_derivm(z); + case ACTIVATION_FUNCTION_SOFTMAX: + return softmax_derivm(z); + case ACTIVATION_FUNCTION_SOFTPLUS: + return softplus_derivm(z); + case ACTIVATION_FUNCTION_SOFTSIGN: + return softsign_derivm(z); + case ACTIVATION_FUNCTION_ADJ_SOFTMAX: + return adj_softmax_derivm(z); + case ACTIVATION_FUNCTION_C_LOG_LOG: + return cloglog_derivm(z); + case ACTIVATION_FUNCTION_LOGIT: + return logit_derivm(z); + case ACTIVATION_FUNCTION_GAUSSIAN_CDF: + return gaussian_cdf_derivm(z); + case ACTIVATION_FUNCTION_RELU: + return relu_derivm(z); + case ACTIVATION_FUNCTION_GELU: + return gelu_derivm(z); + case ACTIVATION_FUNCTION_SIGN: + return sign_derivm(z); + case ACTIVATION_FUNCTION_UNIT_STEP: + return unit_step_derivm(z); + case ACTIVATION_FUNCTION_SINH: + return sinh_derivm(z); + case ACTIVATION_FUNCTION_COSH: + return cosh_derivm(z); + case ACTIVATION_FUNCTION_TANH: + return tanh_derivm(z); + case ACTIVATION_FUNCTION_CSCH: + return csch_derivm(z); + case ACTIVATION_FUNCTION_SECH: + return sech_derivm(z); + case ACTIVATION_FUNCTION_COTH: + return coth_derivm(z); + case ACTIVATION_FUNCTION_ARSINH: + return arsinh_derivm(z); + case ACTIVATION_FUNCTION_ARCOSH: + return arcosh_derivm(z); + case ACTIVATION_FUNCTION_ARTANH: + return artanh_derivm(z); + case ACTIVATION_FUNCTION_ARCSCH: + return arcsch_derivm(z); + case ACTIVATION_FUNCTION_ARSECH: + return arsech_derivm(z); + case ACTIVATION_FUNCTION_ARCOTH: + return arcoth_derivm(z); + default: + ERR_FAIL_V(Ref()); + } } -Ref MLPPActivation::activation(const Ref &z, real_t (*function)(real_t)) { +Ref MLPPActivation::activationr(const Ref &z, real_t (*function)(real_t)) { Ref a; a.instance(); @@ -113,6 +880,10 @@ Ref MLPPActivation::sigmoid_derivm(const Ref &z) { } //SOFTMAX + +real_t MLPPActivation::softmax_normr(real_t z) { + return z; +} Ref MLPPActivation::softmax_normv(const Ref &z) { MLPPLinAlg alg; @@ -162,6 +933,9 @@ Ref MLPPActivation::softmax_normm(const Ref &z) { return a; } +real_t MLPPActivation::softmax_derivr(real_t z) { + return z; +} Ref MLPPActivation::softmax_derivv(const Ref &z) { MLPPLinAlg alg; @@ -213,6 +987,10 @@ Ref MLPPActivation::softmax_derivm(const Ref &z) { //ADJ_SOFTMAX +real_t MLPPActivation::adj_softmax_normr(real_t z) { + return 0; +} + Ref MLPPActivation::adj_softmax_normv(const Ref &z) { MLPPLinAlg alg; @@ -256,6 +1034,10 @@ Ref MLPPActivation::adj_softmax_normm(const Ref &z) { return n; } +real_t MLPPActivation::adj_softmax_derivr(real_t z) { + return 0; +} + Ref MLPPActivation::adj_softmax_derivv(const Ref &z) { MLPPLinAlg alg; diff --git a/mlpp/activation/activation.h b/mlpp/activation/activation.h index 68470d3..007ddd1 100644 --- a/mlpp/activation/activation.h +++ b/mlpp/activation/activation.h @@ -32,8 +32,10 @@ public: ACTIVATION_FUNCTION_SWISH, ACTIVATION_FUNCTION_MISH, ACTIVATION_FUNCTION_SIN_C, + ACTIVATION_FUNCTION_SOFTMAX, ACTIVATION_FUNCTION_SOFTPLUS, ACTIVATION_FUNCTION_SOFTSIGN, + ACTIVATION_FUNCTION_ADJ_SOFTMAX, ACTIVATION_FUNCTION_C_LOG_LOG, ACTIVATION_FUNCTION_LOGIT, ACTIVATION_FUNCTION_GAUSSIAN_CDF, @@ -56,21 +58,35 @@ public: }; public: - //TODO add override for vec, and real_t - typedef Ref (MLPPActivation::*ActivationFunctionPointer)(const Ref &); - ActivationFunctionPointer get_activation_function_ptr(const ActivationFunction func, const bool deriv = false); - Ref get_activation_function_funcref(const ActivationFunction func, const bool deriv = false); + typedef real_t (MLPPActivation::*RealActivationFunctionPointer)(real_t); + typedef Ref (MLPPActivation::*VectorActivationFunctionPointer)(const Ref &); + typedef Ref (MLPPActivation::*MatrixActivationFunctionPointer)(const Ref &); + RealActivationFunctionPointer get_activation_function_ptr_real(const ActivationFunction func, const bool deriv = false); + VectorActivationFunctionPointer get_activation_function_ptr_vector(const ActivationFunction func, const bool deriv = false); + MatrixActivationFunctionPointer get_activation_function_ptr_matrix(const ActivationFunction func, const bool deriv = false); + + RealActivationFunctionPointer get_activation_function_ptr_normal_real(const ActivationFunction func); + VectorActivationFunctionPointer get_activation_function_ptr_normal_vector(const ActivationFunction func); + MatrixActivationFunctionPointer get_activation_function_ptr_normal_matrix(const ActivationFunction func); + + RealActivationFunctionPointer get_activation_function_ptr_deriv_real(const ActivationFunction func); + VectorActivationFunctionPointer get_activation_function_ptr_deriv_vector(const ActivationFunction func); + MatrixActivationFunctionPointer get_activation_function_ptr_deriv_matrix(const ActivationFunction func); + + real_t run_activation_real(const ActivationFunction func, const real_t z, const bool deriv = false); Ref run_activation_vector(const ActivationFunction func, const Ref &z, const bool deriv = false); Ref run_activation_matrix(const ActivationFunction func, const Ref &z, const bool deriv = false); + real_t run_activation_norm_real(const ActivationFunction func, const real_t z); Ref run_activation_norm_vector(const ActivationFunction func, const Ref &z); Ref run_activation_norm_matrix(const ActivationFunction func, const Ref &z); + real_t run_activation_deriv_real(const ActivationFunction func, const real_t z); Ref run_activation_deriv_vector(const ActivationFunction func, const Ref &z); Ref run_activation_deriv_matrix(const ActivationFunction func, const Ref &z); - Ref activation(const Ref &z, real_t (*function)(real_t)); + Ref activationr(const Ref &z, real_t (*function)(real_t)); //ACTIVATION FUNCTIONS @@ -96,17 +112,21 @@ public: //SOFTMAX + real_t softmax_normr(real_t z); Ref softmax_normv(const Ref &z); Ref softmax_normm(const Ref &z); + real_t softmax_derivr(real_t z); Ref softmax_derivv(const Ref &z); Ref softmax_derivm(const Ref &z); //ADJ_SOFTMAX + real_t adj_softmax_normr(real_t z); Ref adj_softmax_normv(const Ref &z); Ref adj_softmax_normm(const Ref &z); + real_t adj_softmax_derivr(real_t z); Ref adj_softmax_derivv(const Ref &z); Ref adj_softmax_derivm(const Ref &z);