From 5e82d4a907d8b8f4d0d9a23dae7509941e23c119 Mon Sep 17 00:00:00 2001 From: Relintai Date: Sat, 4 Feb 2023 12:34:00 +0100 Subject: [PATCH] Added bindings for Cost. --- mlpp/cost/cost.cpp | 65 ++++++++++++++++++++++++++++++++++++++++++++++ mlpp/cost/cost.h | 3 +++ register_types.cpp | 4 ++- 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/mlpp/cost/cost.cpp b/mlpp/cost/cost.cpp index 9ae1450..c7194af 100644 --- a/mlpp/cost/cost.cpp +++ b/mlpp/cost/cost.cpp @@ -980,3 +980,68 @@ std::vector MLPPCost::dualFormSVMDeriv(std::vector alpha, std::v return alg.subtraction(alphaQDeriv, one); } + +void MLPPCost::_bind_methods() { + ClassDB::bind_method(D_METHOD("msev", "y_hat", "y"), &MLPPCost::msev); + ClassDB::bind_method(D_METHOD("msem", "y_hat", "y"), &MLPPCost::msem); + + ClassDB::bind_method(D_METHOD("mse_derivv", "y_hat", "y"), &MLPPCost::mse_derivv); + ClassDB::bind_method(D_METHOD("mse_derivm", "y_hat", "y"), &MLPPCost::mse_derivm); + + ClassDB::bind_method(D_METHOD("rmsev", "y_hat", "y"), &MLPPCost::rmsev); + ClassDB::bind_method(D_METHOD("rmsem", "y_hat", "y"), &MLPPCost::rmsem); + + ClassDB::bind_method(D_METHOD("rmse_derivv", "y_hat", "y"), &MLPPCost::rmse_derivv); + ClassDB::bind_method(D_METHOD("rmse_derivm", "y_hat", "y"), &MLPPCost::rmse_derivm); + + ClassDB::bind_method(D_METHOD("maev", "y_hat", "y"), &MLPPCost::maev); + ClassDB::bind_method(D_METHOD("maem", "y_hat", "y"), &MLPPCost::maem); + + ClassDB::bind_method(D_METHOD("mae_derivv", "y_hat", "y"), &MLPPCost::mae_derivv); + ClassDB::bind_method(D_METHOD("mae_derivm", "y_hat", "y"), &MLPPCost::mae_derivm); + + ClassDB::bind_method(D_METHOD("mbev", "y_hat", "y"), &MLPPCost::mbev); + ClassDB::bind_method(D_METHOD("mbem", "y_hat", "y"), &MLPPCost::mbem); + + ClassDB::bind_method(D_METHOD("mbe_derivv", "y_hat", "y"), &MLPPCost::mbe_derivv); + ClassDB::bind_method(D_METHOD("mbe_derivm", "y_hat", "y"), &MLPPCost::mbe_derivm); + + ClassDB::bind_method(D_METHOD("log_lossv", "y_hat", "y"), &MLPPCost::log_lossv); + ClassDB::bind_method(D_METHOD("log_lossm", "y_hat", "y"), &MLPPCost::log_lossm); + + ClassDB::bind_method(D_METHOD("log_loss_derivv", "y_hat", "y"), &MLPPCost::log_loss_derivv); + ClassDB::bind_method(D_METHOD("log_loss_derivm", "y_hat", "y"), &MLPPCost::log_loss_derivm); + + ClassDB::bind_method(D_METHOD("cross_entropyv", "y_hat", "y"), &MLPPCost::cross_entropyv); + ClassDB::bind_method(D_METHOD("cross_entropym", "y_hat", "y"), &MLPPCost::cross_entropym); + + ClassDB::bind_method(D_METHOD("cross_entropy_derivv", "y_hat", "y"), &MLPPCost::cross_entropy_derivv); + ClassDB::bind_method(D_METHOD("cross_entropy_derivm", "y_hat", "y"), &MLPPCost::cross_entropy_derivm); + + ClassDB::bind_method(D_METHOD("huber_lossv", "y_hat", "y"), &MLPPCost::huber_lossv); + ClassDB::bind_method(D_METHOD("huber_lossm", "y_hat", "y"), &MLPPCost::huber_lossm); + + ClassDB::bind_method(D_METHOD("huber_loss_derivv", "y_hat", "y"), &MLPPCost::huber_loss_derivv); + ClassDB::bind_method(D_METHOD("huber_loss_derivm", "y_hat", "y"), &MLPPCost::huber_loss_derivm); + + ClassDB::bind_method(D_METHOD("hinge_lossv", "y_hat", "y"), &MLPPCost::hinge_lossv); + ClassDB::bind_method(D_METHOD("hinge_lossm", "y_hat", "y"), &MLPPCost::hinge_lossm); + + ClassDB::bind_method(D_METHOD("hinge_loss_derivv", "y_hat", "y"), &MLPPCost::hinge_loss_derivv); + ClassDB::bind_method(D_METHOD("hinge_loss_derivm", "y_hat", "y"), &MLPPCost::hinge_loss_derivm); + + ClassDB::bind_method(D_METHOD("hinge_losswv", "y_hat", "y"), &MLPPCost::hinge_losswv); + ClassDB::bind_method(D_METHOD("hinge_losswm", "y_hat", "y"), &MLPPCost::hinge_losswm); + + ClassDB::bind_method(D_METHOD("hinge_loss_derivwv", "y_hat", "y", "C"), &MLPPCost::hinge_loss_derivwv); + ClassDB::bind_method(D_METHOD("hinge_loss_derivwm", "y_hat", "y", "C"), &MLPPCost::hinge_loss_derivwm); + + ClassDB::bind_method(D_METHOD("wasserstein_lossv", "y_hat", "y"), &MLPPCost::wasserstein_lossv); + ClassDB::bind_method(D_METHOD("wasserstein_lossm", "y_hat", "y"), &MLPPCost::wasserstein_lossm); + + ClassDB::bind_method(D_METHOD("wasserstein_loss_derivv", "y_hat", "y"), &MLPPCost::wasserstein_loss_derivv); + ClassDB::bind_method(D_METHOD("wasserstein_loss_derivm", "y_hat", "y"), &MLPPCost::wasserstein_loss_derivm); + + ClassDB::bind_method(D_METHOD("dual_form_svm", "alpha", "X", "y"), &MLPPCost::dual_form_svm); + ClassDB::bind_method(D_METHOD("dual_form_svm_deriv", "alpha", "X", "y"), &MLPPCost::dual_form_svm_deriv); +} diff --git a/mlpp/cost/cost.h b/mlpp/cost/cost.h index aa5ce26..d849322 100644 --- a/mlpp/cost/cost.h +++ b/mlpp/cost/cost.h @@ -154,6 +154,9 @@ public: real_t dualFormSVM(std::vector alpha, std::vector> X, std::vector y); // TO DO: DON'T forget to add non-linear kernelizations. std::vector dualFormSVMDeriv(std::vector alpha, std::vector> X, std::vector y); + +protected: + static void _bind_methods(); }; #endif /* Cost_hpp */ diff --git a/register_types.cpp b/register_types.cpp index 9c25853..374b7c4 100644 --- a/register_types.cpp +++ b/register_types.cpp @@ -27,9 +27,10 @@ SOFTWARE. #include "mlpp/lin_alg/mlpp_matrix.h" #include "mlpp/lin_alg/mlpp_vector.h" +#include "mlpp/activation/activation.h" +#include "mlpp/cost/cost.h" #include "mlpp/regularization/reg.h" #include "mlpp/utilities/utilities.h" -#include "mlpp/activation/activation.h" #include "mlpp/hidden_layer/hidden_layer.h" @@ -46,6 +47,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) { ClassDB::register_class(); ClassDB::register_class(); ClassDB::register_class(); + ClassDB::register_class(); ClassDB::register_class();