Fixed warnings in MLPPLogReg.

This commit is contained in:
Relintai 2023-02-10 21:44:27 +01:00
parent 73e22e5a7c
commit 20892804ba
2 changed files with 13 additions and 8 deletions

View File

@ -14,9 +14,15 @@
#include <iostream>
#include <random>
MLPPLogReg::MLPPLogReg(std::vector<std::vector<real_t>> pinputSet, std::vector<real_t> poutputSet, std::string preg, real_t plambda, real_t palpha) {
inputSet = pinputSet;
outputSet = poutputSet;
n = pinputSet.size();
k = pinputSet[0].size();
reg = preg;
lambda = plambda;
alpha = palpha;
MLPPLogReg::MLPPLogReg(std::vector<std::vector<real_t>> inputSet, std::vector<real_t> outputSet, std::string reg, real_t lambda, real_t alpha) :
inputSet(inputSet), outputSet(outputSet), n(inputSet.size()), k(inputSet[0].size()), reg(reg), lambda(lambda), alpha(alpha) {
y_hat.resize(n);
weights = MLPPUtilities::weightInitialization(k);
bias = MLPPUtilities::biasInitialization();
@ -140,7 +146,9 @@ void MLPPLogReg::MBGD(real_t learning_rate, int max_epoch, int mini_batch_size,
// Creating the mini-batches
int n_mini_batch = n / mini_batch_size;
auto [inputMiniBatches, outputMiniBatches] = MLPPUtilities::createMiniBatches(inputSet, outputSet, n_mini_batch);
auto bacthes = MLPPUtilities::createMiniBatches(inputSet, outputSet, n_mini_batch);
auto inputMiniBatches = std::get<0>(bacthes);
auto outputMiniBatches = std::get<1>(bacthes);
while (true) {
for (int i = 0; i < n_mini_batch; i++) {

View File

@ -13,8 +13,6 @@
#include <string>
#include <vector>
class MLPPLogReg {
public:
MLPPLogReg(std::vector<std::vector<real_t>> inputSet, std::vector<real_t> outputSet, std::string reg = "None", real_t lambda = 0.5, real_t alpha = 0.5);
@ -50,5 +48,4 @@ private:
real_t alpha; /* This is the controlling param for Elastic Net*/
};
#endif /* LogReg_hpp */