mirror of
https://github.com/Relintai/pmlpp.git
synced 2024-12-22 15:06:47 +01:00
Fixed warnings in MLPPMANN.
This commit is contained in:
parent
d467f1ccf1
commit
73e22e5a7c
@ -13,7 +13,6 @@
|
||||
|
||||
#include <iostream>
|
||||
|
||||
|
||||
MLPPMANN::MLPPMANN(std::vector<std::vector<real_t>> inputSet, std::vector<std::vector<real_t>> outputSet) :
|
||||
inputSet(inputSet), outputSet(outputSet), n(inputSet.size()), k(inputSet[0].size()), n_output(outputSet[0].size()) {
|
||||
}
|
||||
@ -27,7 +26,7 @@ std::vector<std::vector<real_t>> MLPPMANN::modelSetTest(std::vector<std::vector<
|
||||
network[0].input = X;
|
||||
network[0].forwardPass();
|
||||
|
||||
for (int i = 1; i < network.size(); i++) {
|
||||
for (uint32_t i = 1; i < network.size(); i++) {
|
||||
network[i].input = network[i - 1].a;
|
||||
network[i].forwardPass();
|
||||
}
|
||||
@ -42,7 +41,7 @@ std::vector<std::vector<real_t>> MLPPMANN::modelSetTest(std::vector<std::vector<
|
||||
std::vector<real_t> MLPPMANN::modelTest(std::vector<real_t> x) {
|
||||
if (!network.empty()) {
|
||||
network[0].Test(x);
|
||||
for (int i = 1; i < network.size(); i++) {
|
||||
for (uint32_t i = 1; i < network.size(); i++) {
|
||||
network[i].Test(network[i - 1].a_test);
|
||||
}
|
||||
outputLayer->Test(network[network.size() - 1].a_test);
|
||||
@ -89,9 +88,9 @@ void MLPPMANN::gradientDescent(real_t learning_rate, int max_epoch, bool UI) {
|
||||
network[network.size() - 1].bias = alg.subtractMatrixRows(network[network.size() - 1].bias, alg.scalarMultiply(learning_rate / n, network[network.size() - 1].delta));
|
||||
|
||||
for (int i = network.size() - 2; i >= 0; i--) {
|
||||
auto hiddenLayerAvn = network[i].activation_map[network[i].activation];
|
||||
hiddenLayerAvn = network[i].activation_map[network[i].activation];
|
||||
network[i].delta = alg.hadamard_product(alg.matmult(network[i + 1].delta, network[i + 1].weights), (avn.*hiddenLayerAvn)(network[i].z, 1));
|
||||
std::vector<std::vector<real_t>> hiddenLayerWGrad = alg.matmult(alg.transpose(network[i].input), network[i].delta);
|
||||
hiddenLayerWGrad = alg.matmult(alg.transpose(network[i].input), network[i].delta);
|
||||
network[i].weights = alg.subtraction(network[i].weights, alg.scalarMultiply(learning_rate / n, hiddenLayerWGrad));
|
||||
network[i].weights = regularization.regWeights(network[i].weights, network[i].lambda, network[i].alpha, network[i].reg);
|
||||
network[i].bias = alg.subtractMatrixRows(network[i].bias, alg.scalarMultiply(learning_rate / n, network[i].delta));
|
||||
@ -121,16 +120,16 @@ void MLPPMANN::gradientDescent(real_t learning_rate, int max_epoch, bool UI) {
|
||||
}
|
||||
|
||||
real_t MLPPMANN::score() {
|
||||
MLPPUtilities util;
|
||||
MLPPUtilities util;
|
||||
forwardPass();
|
||||
return util.performance(y_hat, outputSet);
|
||||
}
|
||||
|
||||
void MLPPMANN::save(std::string fileName) {
|
||||
MLPPUtilities util;
|
||||
MLPPUtilities util;
|
||||
if (!network.empty()) {
|
||||
util.saveParameters(fileName, network[0].weights, network[0].bias, 0, 1);
|
||||
for (int i = 1; i < network.size(); i++) {
|
||||
for (uint32_t i = 1; i < network.size(); i++) {
|
||||
util.saveParameters(fileName, network[i].weights, network[i].bias, 1, i + 1);
|
||||
}
|
||||
util.saveParameters(fileName, outputLayer->weights, outputLayer->bias, 1, network.size() + 1);
|
||||
@ -164,7 +163,7 @@ real_t MLPPMANN::Cost(std::vector<std::vector<real_t>> y_hat, std::vector<std::v
|
||||
|
||||
auto cost_function = outputLayer->cost_map[outputLayer->cost];
|
||||
if (!network.empty()) {
|
||||
for (int i = 0; i < network.size() - 1; i++) {
|
||||
for (uint32_t i = 0; i < network.size() - 1; i++) {
|
||||
totalRegTerm += regularization.regTerm(network[i].weights, network[i].lambda, network[i].alpha, network[i].reg);
|
||||
}
|
||||
}
|
||||
@ -176,7 +175,7 @@ void MLPPMANN::forwardPass() {
|
||||
network[0].input = inputSet;
|
||||
network[0].forwardPass();
|
||||
|
||||
for (int i = 1; i < network.size(); i++) {
|
||||
for (uint32_t i = 1; i < network.size(); i++) {
|
||||
network[i].input = network[i - 1].a;
|
||||
network[i].forwardPass();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user