From 6a5db4e60a6928c5695c934b28cbe9f3580274d2 Mon Sep 17 00:00:00 2001 From: Relintai Date: Sat, 4 Feb 2023 00:49:16 +0100 Subject: [PATCH] Added new api to MLPPReg. --- mlpp/regularization/reg.cpp | 179 +++++++++++++++++++++++++++++++++++- mlpp/regularization/reg.h | 34 ++++++- 2 files changed, 209 insertions(+), 4 deletions(-) diff --git a/mlpp/regularization/reg.cpp b/mlpp/regularization/reg.cpp index b1897d1..8c725f2 100644 --- a/mlpp/regularization/reg.cpp +++ b/mlpp/regularization/reg.cpp @@ -5,12 +5,190 @@ // #include "reg.h" + +#include "core/math/math_defs.h" + #include "../activation/activation.h" #include "../lin_alg/lin_alg.h" + #include #include +real_t MLPPReg::reg_termv(const Ref &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg) { + int size = weights->size(); + const real_t *weights_ptr = weights->ptr(); + if (reg == REGULARIZATION_TYPE_RIDGE) { + real_t reg = 0; + for (int i = 0; i < size; ++i) { + real_t wi = weights_ptr[i]; + reg += wi * wi; + } + return reg * lambda / 2; + } else if (reg == REGULARIZATION_TYPE_LASSO) { + real_t reg = 0; + for (int i = 0; i < size; ++i) { + reg += ABS(weights_ptr[i]); + } + return reg * lambda; + } else if (reg == REGULARIZATION_TYPE_ELASTIC_NET) { + real_t reg = 0; + for (int i = 0; i < size; ++i) { + real_t wi = weights_ptr[i]; + reg += alpha * ABS(wi); // Lasso Reg + reg += ((1 - alpha) / 2) * wi * wi; // Ridge Reg + } + return reg * lambda; + } + + return 0; +} +real_t MLPPReg::reg_termm(const Ref &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg) { + int size = weights->data_size(); + const real_t *weights_ptr = weights->ptr(); + + if (reg == REGULARIZATION_TYPE_RIDGE) { + real_t reg = 0; + for (int i = 0; i < size; ++i) { + real_t wi = weights_ptr[i]; + reg += wi * wi; + } + return reg * lambda / 2; + } else if (reg == REGULARIZATION_TYPE_LASSO) { + real_t reg = 0; + for (int i = 0; i < size; ++i) { + reg += ABS(weights_ptr[i]); + } + return reg * lambda; + } else if (reg == REGULARIZATION_TYPE_ELASTIC_NET) { + real_t reg = 0; + for (int i = 0; i < size; ++i) { + real_t wi = weights_ptr[i]; + reg += alpha * ABS(wi); // Lasso Reg + reg += ((1 - alpha) / 2) * wi * wi; // Ridge Reg + } + return reg * lambda; + } + + return 0; +} + +Ref MLPPReg::reg_weightsv(const Ref &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg) { + MLPPLinAlg alg; + + if (reg == REGULARIZATION_TYPE_WEIGHT_CLIPPING) { + return reg_deriv_termv(weights, lambda, alpha, reg); + } + + return alg.subtractionnv(weights, reg_deriv_termv(weights, lambda, alpha, reg)); + + // for(int i = 0; i < weights.size(); i++){ + // weights[i] -= regDerivTerm(weights, lambda, alpha, reg, i); + // } + // return weights; +} +Ref MLPPReg::reg_weightsm(const Ref &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg) { + MLPPLinAlg alg; + + if (reg == REGULARIZATION_TYPE_WEIGHT_CLIPPING) { + return reg_deriv_termm(weights, lambda, alpha, reg); + } + + return alg.subtractionm(weights, reg_deriv_termm(weights, lambda, alpha, reg)); + + // for(int i = 0; i < weights.size(); i++){ + // for(int j = 0; j < weights[i].size(); j++){ + // weights[i][j] -= regDerivTerm(weights, lambda, alpha, reg, i, j); + // } + // } + // return weights; +} + +Ref MLPPReg::reg_deriv_termv(const Ref &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg) { + Ref reg_driv; + reg_driv.instance(); + + int size = weights->size(); + + reg_driv->resize(size); + + real_t *reg_driv_ptr = reg_driv->ptrw(); + + for (int i = 0; i < size; ++i) { + reg_driv_ptr[i] = reg_deriv_termvr(weights, lambda, alpha, reg, i); + } + + return reg_driv; +} +Ref MLPPReg::reg_deriv_termm(const Ref &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg) { + Ref reg_driv; + reg_driv.instance(); + + Size2i size = weights->size(); + + reg_driv->resize(size); + + real_t *reg_driv_ptr = reg_driv->ptrw(); + + for (int i = 0; i < size.y; ++i) { + for (int j = 0; j < size.x; ++j) { + reg_driv_ptr[reg_driv->calculate_index(i, j)] = reg_deriv_termmr(weights, lambda, alpha, reg, i, j); + } + } + + return reg_driv; +} + +real_t MLPPReg::reg_deriv_termvr(const Ref &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg, int j) { + MLPPActivation act; + + real_t wj = weights->get_element(j); + + if (reg == REGULARIZATION_TYPE_RIDGE) { + return lambda * wj; + } else if (reg == REGULARIZATION_TYPE_LASSO) { + return lambda * act.sign(wj); + } else if (reg == REGULARIZATION_TYPE_ELASTIC_NET) { + return alpha * lambda * act.sign(wj) + (1 - alpha) * lambda * wj; + } else if (reg == REGULARIZATION_TYPE_WEIGHT_CLIPPING) { // Preparation for Wasserstein GANs. + // We assume lambda is the lower clipping threshold, while alpha is the higher clipping threshold. + // alpha > lambda. + if (wj > alpha) { + return alpha; + } else if (wj < lambda) { + return lambda; + } else { + return wj; + } + } else { + return 0; + } +} +real_t MLPPReg::reg_deriv_termmr(const Ref &weights, real_t lambda, real_t alpha, MLPPReg::RegularizationType reg, int i, int j) { + MLPPActivation act; + + real_t wj = weights->get_element(i, j); + + if (reg == REGULARIZATION_TYPE_RIDGE) { + return lambda * wj; + } else if (reg == REGULARIZATION_TYPE_LASSO) { + return lambda * act.sign(wj); + } else if (reg == REGULARIZATION_TYPE_ELASTIC_NET) { + return alpha * lambda * act.sign(wj) + (1 - alpha) * lambda * wj; + } else if (reg == REGULARIZATION_TYPE_WEIGHT_CLIPPING) { // Preparation for Wasserstein GANs. + // We assume lambda is the lower clipping threshold, while alpha is the higher clipping threshold. + // alpha > lambda. + if (wj > alpha) { + return alpha; + } else if (wj < lambda) { + return lambda; + } else { + return wj; + } + } else { + return 0; + } +} real_t MLPPReg::regTerm(std::vector weights, real_t lambda, real_t alpha, std::string reg) { if (reg == "Ridge") { @@ -162,4 +340,3 @@ real_t MLPPReg::regDerivTerm(std::vector> weights, real_t la return 0; } } - diff --git a/mlpp/regularization/reg.h b/mlpp/regularization/reg.h index deca56a..e27827d 100644 --- a/mlpp/regularization/reg.h +++ b/mlpp/regularization/reg.h @@ -11,12 +11,41 @@ #include "core/math/math_defs.h" -#include +#include "core/object/reference.h" + +#include "../lin_alg/mlpp_matrix.h" +#include "../lin_alg/mlpp_vector.h" + #include +#include +class MLPPReg : public Reference { + GDCLASS(MLPPReg, Reference); -class MLPPReg { public: + enum RegularizationType { + REGULARIZATION_TYPE_RIDGE = 0, + REGULARIZATION_TYPE_LASSO, + REGULARIZATION_TYPE_ELASTIC_NET, + REGULARIZATION_TYPE_WEIGHT_CLIPPING, + }; + + real_t reg_termv(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg); + real_t reg_termm(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg); + + Ref reg_weightsv(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg); + Ref reg_weightsm(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg); + + Ref reg_deriv_termv(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg); + Ref reg_deriv_termm(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg); + +private: + real_t reg_deriv_termvr(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg, int j); + real_t reg_deriv_termmr(const Ref &weights, real_t lambda, real_t alpha, RegularizationType reg, int i, int j); + +public: + // ======== OLD ========= + real_t regTerm(std::vector weights, real_t lambda, real_t alpha, std::string reg); real_t regTerm(std::vector> weights, real_t lambda, real_t alpha, std::string reg); @@ -31,5 +60,4 @@ private: real_t regDerivTerm(std::vector> weights, real_t lambda, real_t alpha, std::string reg, int i, int j); }; - #endif /* Reg_hpp */