diff --git a/mlpp/log_reg/log_reg.cpp b/mlpp/log_reg/log_reg.cpp index 196062b..588cf7e 100644 --- a/mlpp/log_reg/log_reg.cpp +++ b/mlpp/log_reg/log_reg.cpp @@ -14,9 +14,15 @@ #include #include +MLPPLogReg::MLPPLogReg(std::vector> pinputSet, std::vector poutputSet, std::string preg, real_t plambda, real_t palpha) { + inputSet = pinputSet; + outputSet = poutputSet; + n = pinputSet.size(); + k = pinputSet[0].size(); + reg = preg; + lambda = plambda; + alpha = palpha; -MLPPLogReg::MLPPLogReg(std::vector> inputSet, std::vector outputSet, std::string reg, real_t lambda, real_t alpha) : - inputSet(inputSet), outputSet(outputSet), n(inputSet.size()), k(inputSet[0].size()), reg(reg), lambda(lambda), alpha(alpha) { y_hat.resize(n); weights = MLPPUtilities::weightInitialization(k); bias = MLPPUtilities::biasInitialization(); @@ -140,7 +146,9 @@ void MLPPLogReg::MBGD(real_t learning_rate, int max_epoch, int mini_batch_size, // Creating the mini-batches int n_mini_batch = n / mini_batch_size; - auto [inputMiniBatches, outputMiniBatches] = MLPPUtilities::createMiniBatches(inputSet, outputSet, n_mini_batch); + auto bacthes = MLPPUtilities::createMiniBatches(inputSet, outputSet, n_mini_batch); + auto inputMiniBatches = std::get<0>(bacthes); + auto outputMiniBatches = std::get<1>(bacthes); while (true) { for (int i = 0; i < n_mini_batch; i++) { @@ -171,12 +179,12 @@ void MLPPLogReg::MBGD(real_t learning_rate, int max_epoch, int mini_batch_size, } real_t MLPPLogReg::score() { - MLPPUtilities util; + MLPPUtilities util; return util.performance(y_hat, outputSet); } void MLPPLogReg::save(std::string fileName) { - MLPPUtilities util; + MLPPUtilities util; util.saveParameters(fileName, weights, bias); } diff --git a/mlpp/log_reg/log_reg.h b/mlpp/log_reg/log_reg.h index 39580ee..e66a5c8 100644 --- a/mlpp/log_reg/log_reg.h +++ b/mlpp/log_reg/log_reg.h @@ -13,8 +13,6 @@ #include #include - - class MLPPLogReg { public: MLPPLogReg(std::vector> inputSet, std::vector outputSet, std::string reg = "None", real_t lambda = 0.5, real_t alpha = 0.5); @@ -50,5 +48,4 @@ private: real_t alpha; /* This is the controlling param for Elastic Net*/ }; - #endif /* LogReg_hpp */