mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-01-02 16:29:35 +01:00
Also register MLPPExpReg to the ClassDB.
This commit is contained in:
parent
7659f8c4ea
commit
61793e4a4d
@ -229,6 +229,11 @@ MLPPExpReg::MLPPExpReg(std::vector<std::vector<real_t>> p_input_set, std::vector
|
||||
_bias = MLPPUtilities::biasInitialization();
|
||||
}
|
||||
|
||||
MLPPExpReg::MLPPExpReg() {
|
||||
}
|
||||
MLPPExpReg::~MLPPExpReg() {
|
||||
}
|
||||
|
||||
real_t MLPPExpReg::cost(std::vector<real_t> y_hat, std::vector<real_t> y) {
|
||||
MLPPReg regularization;
|
||||
MLPPCost mlpp_cost;
|
||||
@ -265,3 +270,6 @@ std::vector<real_t> MLPPExpReg::evaluatem(std::vector<std::vector<real_t>> X) {
|
||||
void MLPPExpReg::forward_pass() {
|
||||
_y_hat = evaluatem(_input_set);
|
||||
}
|
||||
|
||||
void MLPPExpReg::_bind_methods() {
|
||||
}
|
||||
|
@ -10,10 +10,14 @@
|
||||
|
||||
#include "core/math/math_defs.h"
|
||||
|
||||
#include "core/object/reference.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
class MLPPExpReg {
|
||||
class MLPPExpReg : public Reference {
|
||||
GDCLASS(MLPPExpReg, Reference);
|
||||
|
||||
public:
|
||||
std::vector<real_t> model_set_test(std::vector<std::vector<real_t>> X);
|
||||
real_t model_test(std::vector<real_t> x);
|
||||
@ -28,7 +32,10 @@ public:
|
||||
|
||||
MLPPExpReg(std::vector<std::vector<real_t>> p_input_set, std::vector<real_t> p_output_set, std::string p_reg = "None", real_t p_lambda = 0.5, real_t p_alpha = 0.5);
|
||||
|
||||
private:
|
||||
MLPPExpReg();
|
||||
~MLPPExpReg();
|
||||
|
||||
protected:
|
||||
real_t cost(std::vector<real_t> y_hat, std::vector<real_t> y);
|
||||
|
||||
real_t evaluatev(std::vector<real_t> x);
|
||||
@ -36,6 +43,8 @@ private:
|
||||
|
||||
void forward_pass();
|
||||
|
||||
static void _bind_methods();
|
||||
|
||||
std::vector<std::vector<real_t>> _input_set;
|
||||
std::vector<real_t> _output_set;
|
||||
std::vector<real_t> _y_hat;
|
||||
|
@ -60,6 +60,7 @@ SOFTWARE.
|
||||
#include "mlpp/lin_reg/lin_reg.h"
|
||||
#include "mlpp/gaussian_nb/gaussian_nb.h"
|
||||
#include "mlpp/gan/gan.h"
|
||||
#include "mlpp/exp_reg/exp_reg.h"
|
||||
|
||||
#include "test/mlpp_tests.h"
|
||||
|
||||
@ -102,6 +103,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
|
||||
ClassDB::register_class<MLPPLinReg>();
|
||||
ClassDB::register_class<MLPPGaussianNB>();
|
||||
ClassDB::register_class<MLPPGAN>();
|
||||
ClassDB::register_class<MLPPExpReg>();
|
||||
|
||||
ClassDB::register_class<MLPPDataESimple>();
|
||||
ClassDB::register_class<MLPPDataSimple>();
|
||||
|
Loading…
Reference in New Issue
Block a user