mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-02-22 20:04:19 +01:00
Added bindings for Activation.
This commit is contained in:
parent
1aafbdf66a
commit
fa4b7b6b56
@ -2479,6 +2479,53 @@ Ref<MLPPMatrix> MLPPActivation::arcoth_derivm(const Ref<MLPPMatrix> &z) {
|
||||
return alg.element_wise_divisionm(alg.onematm(z->size().x, z->size().y), alg.subtractionm(alg.onematm(z->size().x, z->size().y), alg.hadamard_productm(z, z)));
|
||||
}
|
||||
|
||||
void MLPPActivation::_bind_methods() {
|
||||
ClassDB::bind_method(D_METHOD("run_activation_real", "func", "z", "deriv"), &MLPPActivation::run_activation_real, false);
|
||||
ClassDB::bind_method(D_METHOD("run_activation_vector", "func", "z", "deriv"), &MLPPActivation::run_activation_vector, false);
|
||||
ClassDB::bind_method(D_METHOD("run_activation_matrix", "func", "z", "deriv"), &MLPPActivation::run_activation_matrix, false);
|
||||
|
||||
ClassDB::bind_method(D_METHOD("run_activation_norm_real", "func", "z"), &MLPPActivation::run_activation_norm_real);
|
||||
ClassDB::bind_method(D_METHOD("run_activation_norm_vector", "func", "z"), &MLPPActivation::run_activation_norm_vector);
|
||||
ClassDB::bind_method(D_METHOD("run_activation_norm_matrix", "func", "z"), &MLPPActivation::run_activation_norm_matrix);
|
||||
|
||||
real_t run_activation_norm_real(const ActivationFunction func, const real_t z);
|
||||
Ref<MLPPVector> run_activation_norm_vector(const ActivationFunction func, const Ref<MLPPVector> &z);
|
||||
Ref<MLPPMatrix> run_activation_norm_matrix(const ActivationFunction func, const Ref<MLPPMatrix> &z);
|
||||
|
||||
ClassDB::bind_method(D_METHOD("run_activation_deriv_real", "func", "z"), &MLPPActivation::run_activation_deriv_real);
|
||||
ClassDB::bind_method(D_METHOD("run_activation_deriv_vector", "func", "z"), &MLPPActivation::run_activation_deriv_vector);
|
||||
ClassDB::bind_method(D_METHOD("run_activation_deriv_matrix", "func", "z"), &MLPPActivation::run_activation_deriv_matrix);
|
||||
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_LINEAR);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SIGMOID);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SWISH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_MISH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SIN_C);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SOFTMAX);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SOFTPLUS);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SOFTSIGN);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ADJ_SOFTMAX);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_C_LOG_LOG);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_LOGIT);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_GAUSSIAN_CDF);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_RELU);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_GELU);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SIGN);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_UNIT_STEP);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SINH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_COSH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_TANH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_CSCH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_SECH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_COTH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARSINH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARCOSH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARTANH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARCSCH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARSECH);
|
||||
BIND_ENUM_CONSTANT(ACTIVATION_FUNCTION_ARCOTH);
|
||||
}
|
||||
|
||||
//======================== OLD =============================
|
||||
|
||||
real_t MLPPActivation::linear(real_t z, bool deriv) {
|
||||
|
@ -18,7 +18,6 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
//TODO this should probably be a singleton
|
||||
//TODO Activation functions should either have a variant which does not allocate, or they should just be reworked altogether
|
||||
//TODO Methods here should probably use error macros, in a way where they get disabled in non-tools(?) (maybe release?) builds
|
||||
|
||||
@ -537,7 +536,8 @@ public:
|
||||
|
||||
std::vector<real_t> activation(std::vector<real_t> z, bool deriv, real_t (*function)(real_t, bool));
|
||||
|
||||
private:
|
||||
protected:
|
||||
static void _bind_methods();
|
||||
};
|
||||
|
||||
VARIANT_ENUM_CAST(MLPPActivation::ActivationFunction);
|
||||
|
@ -27,8 +27,10 @@ SOFTWARE.
|
||||
#include "mlpp/lin_alg/mlpp_matrix.h"
|
||||
#include "mlpp/lin_alg/mlpp_vector.h"
|
||||
|
||||
#include "mlpp/knn/knn.h"
|
||||
#include "mlpp/activation/activation.h"
|
||||
|
||||
#include "mlpp/kmeans/kmeans.h"
|
||||
#include "mlpp/knn/knn.h"
|
||||
|
||||
#include "test/mlpp_tests.h"
|
||||
|
||||
@ -37,6 +39,8 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
|
||||
ClassDB::register_class<MLPPVector>();
|
||||
ClassDB::register_class<MLPPMatrix>();
|
||||
|
||||
ClassDB::register_class<MLPPActivation>();
|
||||
|
||||
ClassDB::register_class<MLPPKNN>();
|
||||
ClassDB::register_class<MLPPKMeans>();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user