Added bindings for Cost.

This commit is contained in:
Relintai 2023-02-04 12:34:00 +01:00
parent 3731e8f12f
commit 5e82d4a907
3 changed files with 71 additions and 1 deletions

View File

@ -980,3 +980,68 @@ std::vector<real_t> MLPPCost::dualFormSVMDeriv(std::vector<real_t> alpha, std::v
return alg.subtraction(alphaQDeriv, one);
}
void MLPPCost::_bind_methods() {
ClassDB::bind_method(D_METHOD("msev", "y_hat", "y"), &MLPPCost::msev);
ClassDB::bind_method(D_METHOD("msem", "y_hat", "y"), &MLPPCost::msem);
ClassDB::bind_method(D_METHOD("mse_derivv", "y_hat", "y"), &MLPPCost::mse_derivv);
ClassDB::bind_method(D_METHOD("mse_derivm", "y_hat", "y"), &MLPPCost::mse_derivm);
ClassDB::bind_method(D_METHOD("rmsev", "y_hat", "y"), &MLPPCost::rmsev);
ClassDB::bind_method(D_METHOD("rmsem", "y_hat", "y"), &MLPPCost::rmsem);
ClassDB::bind_method(D_METHOD("rmse_derivv", "y_hat", "y"), &MLPPCost::rmse_derivv);
ClassDB::bind_method(D_METHOD("rmse_derivm", "y_hat", "y"), &MLPPCost::rmse_derivm);
ClassDB::bind_method(D_METHOD("maev", "y_hat", "y"), &MLPPCost::maev);
ClassDB::bind_method(D_METHOD("maem", "y_hat", "y"), &MLPPCost::maem);
ClassDB::bind_method(D_METHOD("mae_derivv", "y_hat", "y"), &MLPPCost::mae_derivv);
ClassDB::bind_method(D_METHOD("mae_derivm", "y_hat", "y"), &MLPPCost::mae_derivm);
ClassDB::bind_method(D_METHOD("mbev", "y_hat", "y"), &MLPPCost::mbev);
ClassDB::bind_method(D_METHOD("mbem", "y_hat", "y"), &MLPPCost::mbem);
ClassDB::bind_method(D_METHOD("mbe_derivv", "y_hat", "y"), &MLPPCost::mbe_derivv);
ClassDB::bind_method(D_METHOD("mbe_derivm", "y_hat", "y"), &MLPPCost::mbe_derivm);
ClassDB::bind_method(D_METHOD("log_lossv", "y_hat", "y"), &MLPPCost::log_lossv);
ClassDB::bind_method(D_METHOD("log_lossm", "y_hat", "y"), &MLPPCost::log_lossm);
ClassDB::bind_method(D_METHOD("log_loss_derivv", "y_hat", "y"), &MLPPCost::log_loss_derivv);
ClassDB::bind_method(D_METHOD("log_loss_derivm", "y_hat", "y"), &MLPPCost::log_loss_derivm);
ClassDB::bind_method(D_METHOD("cross_entropyv", "y_hat", "y"), &MLPPCost::cross_entropyv);
ClassDB::bind_method(D_METHOD("cross_entropym", "y_hat", "y"), &MLPPCost::cross_entropym);
ClassDB::bind_method(D_METHOD("cross_entropy_derivv", "y_hat", "y"), &MLPPCost::cross_entropy_derivv);
ClassDB::bind_method(D_METHOD("cross_entropy_derivm", "y_hat", "y"), &MLPPCost::cross_entropy_derivm);
ClassDB::bind_method(D_METHOD("huber_lossv", "y_hat", "y"), &MLPPCost::huber_lossv);
ClassDB::bind_method(D_METHOD("huber_lossm", "y_hat", "y"), &MLPPCost::huber_lossm);
ClassDB::bind_method(D_METHOD("huber_loss_derivv", "y_hat", "y"), &MLPPCost::huber_loss_derivv);
ClassDB::bind_method(D_METHOD("huber_loss_derivm", "y_hat", "y"), &MLPPCost::huber_loss_derivm);
ClassDB::bind_method(D_METHOD("hinge_lossv", "y_hat", "y"), &MLPPCost::hinge_lossv);
ClassDB::bind_method(D_METHOD("hinge_lossm", "y_hat", "y"), &MLPPCost::hinge_lossm);
ClassDB::bind_method(D_METHOD("hinge_loss_derivv", "y_hat", "y"), &MLPPCost::hinge_loss_derivv);
ClassDB::bind_method(D_METHOD("hinge_loss_derivm", "y_hat", "y"), &MLPPCost::hinge_loss_derivm);
ClassDB::bind_method(D_METHOD("hinge_losswv", "y_hat", "y"), &MLPPCost::hinge_losswv);
ClassDB::bind_method(D_METHOD("hinge_losswm", "y_hat", "y"), &MLPPCost::hinge_losswm);
ClassDB::bind_method(D_METHOD("hinge_loss_derivwv", "y_hat", "y", "C"), &MLPPCost::hinge_loss_derivwv);
ClassDB::bind_method(D_METHOD("hinge_loss_derivwm", "y_hat", "y", "C"), &MLPPCost::hinge_loss_derivwm);
ClassDB::bind_method(D_METHOD("wasserstein_lossv", "y_hat", "y"), &MLPPCost::wasserstein_lossv);
ClassDB::bind_method(D_METHOD("wasserstein_lossm", "y_hat", "y"), &MLPPCost::wasserstein_lossm);
ClassDB::bind_method(D_METHOD("wasserstein_loss_derivv", "y_hat", "y"), &MLPPCost::wasserstein_loss_derivv);
ClassDB::bind_method(D_METHOD("wasserstein_loss_derivm", "y_hat", "y"), &MLPPCost::wasserstein_loss_derivm);
ClassDB::bind_method(D_METHOD("dual_form_svm", "alpha", "X", "y"), &MLPPCost::dual_form_svm);
ClassDB::bind_method(D_METHOD("dual_form_svm_deriv", "alpha", "X", "y"), &MLPPCost::dual_form_svm_deriv);
}

View File

@ -154,6 +154,9 @@ public:
real_t dualFormSVM(std::vector<real_t> alpha, std::vector<std::vector<real_t>> X, std::vector<real_t> y); // TO DO: DON'T forget to add non-linear kernelizations.
std::vector<real_t> dualFormSVMDeriv(std::vector<real_t> alpha, std::vector<std::vector<real_t>> X, std::vector<real_t> y);
protected:
static void _bind_methods();
};
#endif /* Cost_hpp */

View File

@ -27,9 +27,10 @@ SOFTWARE.
#include "mlpp/lin_alg/mlpp_matrix.h"
#include "mlpp/lin_alg/mlpp_vector.h"
#include "mlpp/activation/activation.h"
#include "mlpp/cost/cost.h"
#include "mlpp/regularization/reg.h"
#include "mlpp/utilities/utilities.h"
#include "mlpp/activation/activation.h"
#include "mlpp/hidden_layer/hidden_layer.h"
@ -46,6 +47,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
ClassDB::register_class<MLPPUtilities>();
ClassDB::register_class<MLPPReg>();
ClassDB::register_class<MLPPActivation>();
ClassDB::register_class<MLPPCost>();
ClassDB::register_class<MLPPHiddenLayer>();