Removed backtracking line search.

This commit is contained in:
novak_99 2022-02-10 18:36:56 -08:00
parent 11ca2d952f
commit b0b686799c
4 changed files with 5 additions and 52 deletions

View File

@ -71,11 +71,13 @@ namespace MLPP {
cost_prev = Cost(y_hat, outputSet);
auto [cumulativeHiddenLayerWGrad, outputWGrad] = computeGradients(y_hat, outputSet);
cumulativeHiddenLayerWGrad = alg.scalarMultiply(learning_rate/n, cumulativeHiddenLayerWGrad);
outputWGrad = alg.scalarMultiply(learning_rate/n, outputWGrad);
updateParameters(cumulativeHiddenLayerWGrad, outputWGrad, learning_rate); // subject to change. may want bias to have this matrix too.
std::cout << learning_rate << std::endl;
forwardPass();
if(UI) { ANN::UI(epoch, cost_prev, y_hat, outputSet); }
@ -557,52 +559,6 @@ void ANN::Adam(double learning_rate, int max_epoch, int mini_batch_size, double
util.saveParameters(fileName, outputLayer->weights, outputLayer->bias, 0, network.size() + 1);
}
}
// https://www.youtube.com/watch?v=4qDt4QUl4zE
// The above video detailed the necessary components of the line search algorithm.
//
// Backtracking line search: starting from `learningRate`, repeatedly shrink the
// trial step by the factor `beta` (0 < beta < 1) until the Armijo
// sufficient-decrease condition holds, then return the accepted learning rate.
// The network's weights are restored after every trial step, so the model is
// left unmodified on return.
// NOTE(review): biases are not snapshotted/restored here even though
// updateParameters presumably touches them — confirm against updateParameters.
double ANN::backTrackingLineSearch(double beta, double learningRate){
    LinAlg alg;
    forwardPass();
    // Snapshot the pre-step cost and all weights so every trial step is
    // evaluated from the same starting point.
    double initialCost = Cost(y_hat, outputSet);
    std::vector<double> outputLayerWeights = outputLayer->weights;
    std::vector<std::vector<std::vector<double>>> cumulativeHiddenWeights;
    if(!network.empty()){
        // Stored back-to-front (last hidden layer first), matching the
        // original traversal order.
        for(int i = network.size() - 1; i >= 0; i--){
            cumulativeHiddenWeights.push_back(network[i].weights);
        }
    }
    while(true){
        auto [cumulativeHiddenLayerWGrad, outputWGrad] = computeGradients(y_hat, outputSet);
        cumulativeHiddenLayerWGrad = alg.scalarMultiply(learningRate/n, cumulativeHiddenLayerWGrad);
        outputWGrad = alg.scalarMultiply(learningRate/n, outputWGrad);
        updateParameters(cumulativeHiddenLayerWGrad, outputWGrad, learningRate); // subject to change. may want bias to have this matrix too.
        forwardPass();
        double trialCost = Cost(y_hat, outputSet);
        // Restore the snapshot immediately so a rejected trial step does not
        // compound into the next iteration (the original only restored on
        // acceptance, so failed steps accumulated parameter updates).
        outputLayer->weights = outputLayerWeights;
        if(!network.empty()){
            for(int i = network.size() - 1; i >= 0; i--){
                network[i].weights = cumulativeHiddenWeights[(network.size() - 1) - i];
            }
        }
        // Armijo condition: accept when the trial step reduced the cost by at
        // least (lr/2) * ||grad||^2. BUG FIX: the original compared the
        // post-step cost against itself minus this term, which is always true
        // for any nonzero gradient, so learningRate decayed toward zero.
        if(trialCost > initialCost - (learningRate/2) * (alg.norm_2(cumulativeHiddenLayerWGrad) + alg.norm_2(outputWGrad))){
            learningRate *= beta; // insufficient decrease: shrink the step and retry
        }
        else {
            forwardPass(); // resynchronize y_hat with the restored weights
            return learningRate; // (original had an unreachable `break` after this return)
        }
    }
}
void ANN::setLearningRateScheduler(std::string type, double decayConstant){
lrScheduler = type;

View File

@ -32,9 +32,7 @@ class ANN{
void Nadam(double learning_rate, int max_epoch, int mini_batch_size, double b1, double b2, double e, bool UI = 1);
void AMSGrad(double learning_rate, int max_epoch, int mini_batch_size, double b1, double b2, double e, bool UI = 1);
double score();
void save(std::string fileName);
double backTrackingLineSearch(double beta, double learningRate); // Use this to find an optimal learning rate value.
void save(std::string fileName);
void setLearningRateScheduler(std::string type, double decayConstant);
void setLearningRateScheduler(std::string type, double decayConstant, double dropRate);

BIN
a.out

Binary file not shown.

View File

@ -373,9 +373,8 @@ int main() {
//ann.Adadelta(1, 1000, 2, 0.9, 0.000001, 1);
//ann.Momentum(0.1, 8000, 2, 0.9, true, 1);
std::cout << ann.backTrackingLineSearch(0.707, 0.1) << std::endl;
//ann.setLearningRateScheduler("Step", 0.5, 1000);
//ann.gradientDescent(0.1, 20000, 1);
ann.gradientDescent(1, 5, 1);
//alg.printVector(ann.modelSetTest(alg.transpose(inputSet)));
//std::cout << "ACCURACY: " << 100 * ann.score() << "%" << std::endl;