mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-02-08 18:00:04 +01:00
Initial bindings for MLPPUtilities.
This commit is contained in:
parent
bbe334856b
commit
1aa239720b
@ -54,8 +54,10 @@ MLPPHiddenLayer::MLPPHiddenLayer(int p_n_hidden, MLPPActivation::ActivationFunct
|
|||||||
weights->resize(Size2i(input->size().x, n_hidden));
|
weights->resize(Size2i(input->size().x, n_hidden));
|
||||||
bias->resize(n_hidden);
|
bias->resize(n_hidden);
|
||||||
|
|
||||||
MLPPUtilities::weight_initializationm(weights, weight_init);
|
MLPPUtilities utils;
|
||||||
MLPPUtilities::bias_initializationv(bias);
|
|
||||||
|
utils.weight_initializationm(weights, weight_init);
|
||||||
|
utils.bias_initializationv(bias);
|
||||||
}
|
}
|
||||||
|
|
||||||
MLPPHiddenLayer::MLPPHiddenLayer() {
|
MLPPHiddenLayer::MLPPHiddenLayer() {
|
||||||
|
@ -605,3 +605,23 @@ real_t MLPPUtilities::accuracy(std::vector<real_t> y_hat, std::vector<real_t> y)
|
|||||||
real_t MLPPUtilities::f1_score(std::vector<real_t> y_hat, std::vector<real_t> y) {
|
real_t MLPPUtilities::f1_score(std::vector<real_t> y_hat, std::vector<real_t> y) {
|
||||||
return 2 * precision(y_hat, y) * recall(y_hat, y) / (precision(y_hat, y) + recall(y_hat, y));
|
return 2 * precision(y_hat, y) * recall(y_hat, y) / (precision(y_hat, y) + recall(y_hat, y));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MLPPUtilities::_bind_methods() {
|
||||||
|
ClassDB::bind_method(D_METHOD("weight_initializationv", "weights", "type"), &MLPPUtilities::weight_initializationv, WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
|
||||||
|
ClassDB::bind_method(D_METHOD("weight_initializationm", "weights", "type"), &MLPPUtilities::weight_initializationm, WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
|
||||||
|
ClassDB::bind_method(D_METHOD("bias_initializationr"), &MLPPUtilities::bias_initializationr);
|
||||||
|
ClassDB::bind_method(D_METHOD("bias_initializationv", "z"), &MLPPUtilities::bias_initializationv);
|
||||||
|
|
||||||
|
ClassDB::bind_method(D_METHOD("performance_vec", "y_hat", "output_set"), &MLPPUtilities::performance_vec);
|
||||||
|
ClassDB::bind_method(D_METHOD("performance_mat", "y_hat", "y"), &MLPPUtilities::performance_mat);
|
||||||
|
ClassDB::bind_method(D_METHOD("performance_pool_int_array_vec", "y_hat", "output_set"), &MLPPUtilities::performance_pool_int_array_vec);
|
||||||
|
|
||||||
|
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
|
||||||
|
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_XAVIER_NORMAL);
|
||||||
|
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_XAVIER_UNIFORM);
|
||||||
|
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_HE_NORMAL);
|
||||||
|
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_HE_UNIFORM);
|
||||||
|
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_LE_CUN_NORMAL);
|
||||||
|
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_LE_CUN_UNIFORM);
|
||||||
|
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_UNIFORM);
|
||||||
|
}
|
||||||
|
@ -13,6 +13,8 @@
|
|||||||
#include "core/string/ustring.h"
|
#include "core/string/ustring.h"
|
||||||
#include "core/variant/variant.h"
|
#include "core/variant/variant.h"
|
||||||
|
|
||||||
|
#include "core/object/reference.h"
|
||||||
|
|
||||||
#include "../lin_alg/mlpp_matrix.h"
|
#include "../lin_alg/mlpp_matrix.h"
|
||||||
#include "../lin_alg/mlpp_vector.h"
|
#include "../lin_alg/mlpp_vector.h"
|
||||||
|
|
||||||
@ -20,7 +22,9 @@
|
|||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
class MLPPUtilities {
|
class MLPPUtilities : public Reference {
|
||||||
|
GDCLASS(MLPPUtilities, Reference);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// Weight Init
|
// Weight Init
|
||||||
static std::vector<real_t> weightInitialization(int n, std::string type = "Default");
|
static std::vector<real_t> weightInitialization(int n, std::string type = "Default");
|
||||||
@ -40,10 +44,10 @@ public:
|
|||||||
WEIGHT_DISTRIBUTION_TYPE_UNIFORM,
|
WEIGHT_DISTRIBUTION_TYPE_UNIFORM,
|
||||||
};
|
};
|
||||||
|
|
||||||
static void weight_initializationv(Ref<MLPPVector> weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
|
void weight_initializationv(Ref<MLPPVector> weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
|
||||||
static void weight_initializationm(Ref<MLPPMatrix> weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
|
void weight_initializationm(Ref<MLPPMatrix> weights, WeightDistributionType type = WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
|
||||||
static real_t bias_initializationr();
|
real_t bias_initializationr();
|
||||||
static void bias_initializationv(Ref<MLPPVector> z);
|
void bias_initializationv(Ref<MLPPVector> z);
|
||||||
|
|
||||||
// Cost/Performance related Functions
|
// Cost/Performance related Functions
|
||||||
real_t performance(std::vector<real_t> y_hat, std::vector<real_t> y);
|
real_t performance(std::vector<real_t> y_hat, std::vector<real_t> y);
|
||||||
@ -76,7 +80,10 @@ public:
|
|||||||
real_t accuracy(std::vector<real_t> y_hat, std::vector<real_t> y);
|
real_t accuracy(std::vector<real_t> y_hat, std::vector<real_t> y);
|
||||||
real_t f1_score(std::vector<real_t> y_hat, std::vector<real_t> y);
|
real_t f1_score(std::vector<real_t> y_hat, std::vector<real_t> y);
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
|
static void _bind_methods();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
VARIANT_ENUM_CAST(MLPPUtilities::WeightDistributionType);
|
||||||
|
|
||||||
#endif /* Utilities_hpp */
|
#endif /* Utilities_hpp */
|
||||||
|
@ -28,6 +28,7 @@ SOFTWARE.
|
|||||||
#include "mlpp/lin_alg/mlpp_vector.h"
|
#include "mlpp/lin_alg/mlpp_vector.h"
|
||||||
|
|
||||||
#include "mlpp/regularization/reg.h"
|
#include "mlpp/regularization/reg.h"
|
||||||
|
#include "mlpp/utilities/utilities.h"
|
||||||
|
|
||||||
#include "mlpp/activation/activation.h"
|
#include "mlpp/activation/activation.h"
|
||||||
|
|
||||||
@ -41,6 +42,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
|
|||||||
ClassDB::register_class<MLPPVector>();
|
ClassDB::register_class<MLPPVector>();
|
||||||
ClassDB::register_class<MLPPMatrix>();
|
ClassDB::register_class<MLPPMatrix>();
|
||||||
|
|
||||||
|
ClassDB::register_class<MLPPUtilities>();
|
||||||
ClassDB::register_class<MLPPReg>();
|
ClassDB::register_class<MLPPReg>();
|
||||||
|
|
||||||
ClassDB::register_class<MLPPActivation>();
|
ClassDB::register_class<MLPPActivation>();
|
||||||
|
Loading…
Reference in New Issue
Block a user