Registered MLPPANN into the ClassDB.

This commit is contained in:
Relintai 2023-02-12 18:35:53 +01:00
parent b62df601fb
commit 34f81277cd
3 changed files with 15 additions and 4 deletions

View File

@ -848,3 +848,6 @@ void MLPPANN::print_ui(int epoch, real_t cost_prev, std::vector<real_t> y_hat, s
}
}
}
void MLPPANN::_bind_methods() {
}

View File

@ -9,6 +9,8 @@
#include "core/math/math_defs.h"
#include "core/object/reference.h"
#include "../hidden_layer/hidden_layer.h"
#include "../output_layer/output_layer.h"
@ -19,7 +21,9 @@
#include <tuple>
#include <vector>
class MLPPANN {
class MLPPANN : public Reference {
GDCLASS(MLPPANN, Reference);
public:
std::vector<real_t> model_set_test(std::vector<std::vector<real_t>> X);
real_t model_test(std::vector<real_t> x);
@ -49,7 +53,7 @@ public:
MLPPANN();
~MLPPANN();
private:
protected:
real_t apply_learning_rate_scheduler(real_t learningRate, real_t decayConstant, real_t epoch, real_t dropRate);
real_t cost(std::vector<real_t> y_hat, std::vector<real_t> y);
@ -60,6 +64,8 @@ private:
void print_ui(int epoch, real_t cost_prev, std::vector<real_t> y_hat, std::vector<real_t> outputSet);
static void _bind_methods();
std::vector<std::vector<real_t>> inputSet;
std::vector<real_t> outputSet;
std::vector<real_t> y_hat;

View File

@ -42,7 +42,10 @@ SOFTWARE.
#include "mlpp/multi_output_layer/multi_output_layer.h"
#include "mlpp/output_layer/output_layer.h"
#include "mlpp/ann/ann.h"
#include "mlpp/auto_encoder/auto_encoder.h"
#include "mlpp/bernoulli_nb/bernoulli_nb.h"
#include "mlpp/c_log_log_reg/c_log_log_reg.h"
#include "mlpp/dual_svc/dual_svc.h"
#include "mlpp/exp_reg/exp_reg.h"
#include "mlpp/gan/gan.h"
@ -63,8 +66,6 @@ SOFTWARE.
#include "mlpp/tanh_reg/tanh_reg.h"
#include "mlpp/uni_lin_reg/uni_lin_reg.h"
#include "mlpp/wgan/wgan.h"
#include "mlpp/c_log_log_reg/c_log_log_reg.h"
#include "mlpp/bernoulli_nb/bernoulli_nb.h"
#include "test/mlpp_tests.h"
@ -112,6 +113,7 @@ void register_pmlpp_types(ModuleRegistrationLevel p_level) {
ClassDB::register_class<MLPPDualSVC>();
ClassDB::register_class<MLPPCLogLogReg>();
ClassDB::register_class<MLPPBernoulliNB>();
ClassDB::register_class<MLPPANN>();
ClassDB::register_class<MLPPDataESimple>();
ClassDB::register_class<MLPPDataSimple>();