Initial bindings for MLPPUtilities.

This commit is contained in:
Relintai 2023-02-04 01:18:50 +01:00
parent bbe334856b
commit 1aa239720b
4 changed files with 39 additions and 8 deletions

View File

@ -54,8 +54,10 @@ MLPPHiddenLayer::MLPPHiddenLayer(int p_n_hidden, MLPPActivation::ActivationFunct
weights->resize(Size2i(input->size().x, n_hidden));
bias->resize(n_hidden);
MLPPUtilities::weight_initializationm(weights, weight_init);
MLPPUtilities::bias_initializationv(bias);
MLPPUtilities utils;
utils.weight_initializationm(weights, weight_init);
utils.bias_initializationv(bias);
}
MLPPHiddenLayer::MLPPHiddenLayer() {

View File

@ -605,3 +605,23 @@ real_t MLPPUtilities::accuracy(std::vector<real_t> y_hat, std::vector<real_t> y)
real_t MLPPUtilities::f1_score(std::vector<real_t> y_hat, std::vector<real_t> y) {
return 2 * precision(y_hat, y) * recall(y_hat, y) / (precision(y_hat, y) + recall(y_hat, y));
}
void MLPPUtilities::_bind_methods() {
ClassDB::bind_method(D_METHOD("weight_initializationv", "weights", "type"), &MLPPUtilities::weight_initializationv, WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
ClassDB::bind_method(D_METHOD("weight_initializationm", "weights", "type"), &MLPPUtilities::weight_initializationm, WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
ClassDB::bind_method(D_METHOD("bias_initializationr"), &MLPPUtilities::bias_initializationr);
ClassDB::bind_method(D_METHOD("bias_initializationv", "z"), &MLPPUtilities::bias_initializationv);
ClassDB::bind_method(D_METHOD("performance_vec", "y_hat", "output_set"), &MLPPUtilities::performance_vec);
ClassDB::bind_method(D_METHOD("performance_mat", "y_hat", "y"), &MLPPUtilities::performance_mat);
ClassDB::bind_method(D_METHOD("performance_pool_int_array_vec", "y_hat", "output_set"), &MLPPUtilities::performance_pool_int_array_vec);
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_XAVIER_NORMAL);
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_XAVIER_UNIFORM);
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_HE_NORMAL);
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_HE_UNIFORM);
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_LE_CUN_NORMAL);
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_LE_CUN_UNIFORM);
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_UNIFORM);
}

View File

@ -13,6 +13,8 @@
#include "core/string/ustring.h"
#include "core/variant/variant.h"
#include "core/object/reference.h"
#include "../lin_alg/mlpp_matrix.h"
#include "../lin_alg/mlpp_vector.h"
@ -20,7 +22,9 @@
#include <tuple>
#include <vector>
class MLPPUtilities {
class MLPPUtilities : public Reference {
GDCLASS(MLPPUtilities, Reference);
public:
// Weight Init
static std::vector<real_t> weightInitialization(int n, std::string type = "Default");
@ -40,10 +44,10 @@ public:
WEIGHT_DISTRIBUTION_TYPE_UNIFORM,
};
static void weight_initializationv(Ref<MLPPVector> weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
static void weight_initializationm(Ref<MLPPMatrix> weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
static real_t bias_initializationr();
static void bias_initializationv(Ref<MLPPVector> z);
void weight_initializationv(Ref<MLPPVector> weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
void weight_initializationm(Ref<MLPPMatrix> weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
real_t bias_initializationr();
void bias_initializationv(Ref<MLPPVector> z);
// Cost/Performance related Functions
real_t performance(std::vector<real_t> y_hat, std::vector<real_t> y);
@ -76,7 +80,10 @@ public:
real_t accuracy(std::vector<real_t> y_hat, std::vector<real_t> y);
real_t f1_score(std::vector<real_t> y_hat, std::vector<real_t> y);
private:
protected:
static void _bind_methods();
};
VARIANT_ENUM_CAST(MLPPUtilities::WeightDistributionType);
#endif /* Utilities_hpp */

View File

@ -28,6 +28,7 @@ SOFTWARE.
#include "mlpp/lin_alg/mlpp_vector.h"
#include "mlpp/regularization/reg.h"
#include "mlpp/utilities/utilities.h"
#include "mlpp/activation/activation.h"
@ -41,6 +42,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
ClassDB::register_class<MLPPVector>();
ClassDB::register_class<MLPPMatrix>();
ClassDB::register_class<MLPPUtilities>();
ClassDB::register_class<MLPPReg>();
ClassDB::register_class<MLPPActivation>();