Line search method- find optimal learning rates for neural nets.

This commit is contained in:
novak_99 2022-02-05 00:42:41 -08:00
parent 559e55fd89
commit d483662ce2
6 changed files with 78 additions and 4 deletions

View File

@ -558,6 +558,50 @@ void ANN::Adam(double learning_rate, int max_epoch, int mini_batch_size, double
}
}
double ANN::backTrackingLineSearch(double beta, double learningRate){
LinAlg alg;
forwardPass();
std::vector<double> outputLayerWeights = outputLayer->weights;
std::vector<std::vector<std::vector<double>>> cumulativeHiddenWeights;
if(!network.empty()){
cumulativeHiddenWeights.push_back(network[network.size() - 1].weights);
for(int i = network.size() - 2; i >= 0; i--){
cumulativeHiddenWeights.push_back(network[i].weights);
}
}
while(true){
auto [cumulativeHiddenLayerWGrad, outputWGrad] = computeGradients(y_hat, outputSet);
cumulativeHiddenLayerWGrad = alg.scalarMultiply(learningRate/n, cumulativeHiddenLayerWGrad);
outputWGrad = alg.scalarMultiply(learningRate/n, outputWGrad);
updateParameters(cumulativeHiddenLayerWGrad, outputWGrad, learningRate); // subject to change. may want bias to have this matrix too.
forwardPass();
if(Cost(y_hat, outputSet) > Cost(y_hat, outputSet) - (learningRate/2) * (alg.norm_2(cumulativeHiddenLayerWGrad) + alg.norm_2(outputWGrad))){
learningRate *= beta;
}
else {
outputLayer->weights = outputLayerWeights;
if(!network.empty()){
network[network.size() - 1].weights = cumulativeHiddenWeights[0];
for(int i = network.size() - 2; i >= 0; i--){
network[i].weights = cumulativeHiddenWeights[(network.size() - 2) - i + 1];
}
}
return learningRate;
break;
}
}
}
void ANN::setLearningRateScheduler(std::string type, double decayConstant){
lrScheduler = type;
ANN::decayConstant = decayConstant;

View File

@ -34,6 +34,8 @@ class ANN{
double score();
void save(std::string fileName);
double backTrackingLineSearch(double beta, double learningRate); // Use this to find an optimal learning rate value.
void setLearningRateScheduler(std::string type, double decayConstant);
void setLearningRateScheduler(std::string type, double decayConstant, double dropRate);

View File

@ -480,6 +480,16 @@ namespace MLPP{
return B;
}
double LinAlg::norm_2(std::vector<std::vector<double>> A){
double sum = 0;
for(int i = 0; i < A.size(); i++){
for(int j = 0; j < A[i].size(); j++){
sum += A[i][j] * A[i][j];
}
}
return std::sqrt(sum);
}
std::vector<std::vector<double>> LinAlg::identity(double d){
std::vector<std::vector<double>> identityMat;
identityMat.resize(d);
@ -1183,4 +1193,16 @@ namespace MLPP{
}
return A;
}
double LinAlg::norm_2(std::vector<std::vector<std::vector<double>>> A){
double sum = 0;
for(int i = 0; i < A.size(); i++){
for(int j = 0; j < A[i].size(); j++){
for(int k = 0; k < A[i][j].size(); k++){
sum += A[i][j][k] * A[i][j][k];
}
}
}
return std::sqrt(sum);
}
}

View File

@ -90,6 +90,8 @@ namespace MLPP{
std::vector<std::vector<double>> round(std::vector<std::vector<double>> A);
double norm_2(std::vector<std::vector<double>> A);
std::vector<std::vector<double>> identity(double d);
std::vector<std::vector<double>> cov(std::vector<std::vector<double>> A);
@ -222,6 +224,8 @@ namespace MLPP{
std::vector<std::vector<std::vector<double>>> abs(std::vector<std::vector<std::vector<double>>> A);
double norm_2(std::vector<std::vector<std::vector<double>>> A);
private:
};

BIN
a.out

Binary file not shown.

View File

@ -372,10 +372,12 @@ int main() {
//ann.AMSGrad(0.1, 10000, 1, 0.9, 0.999, 0.000001, 1);
//ann.Adadelta(1, 1000, 2, 0.9, 0.000001, 1);
//ann.Momentum(0.1, 8000, 2, 0.9, true, 1);
ann.setLearningRateScheduler("Step", 0.5, 1000);
ann.gradientDescent(0.1, 20000, 1);
alg.printVector(ann.modelSetTest(alg.transpose(inputSet)));
std::cout << "ACCURACY: " << 100 * ann.score() << "%" << std::endl;
std::cout << ann.backTrackingLineSearch(0.707, 0.1) << std::endl;
//ann.setLearningRateScheduler("Step", 0.5, 1000);
//ann.gradientDescent(0.1, 20000, 1);
//alg.printVector(ann.modelSetTest(alg.transpose(inputSet)));
//std::cout << "ACCURACY: " << 100 * ann.score() << "%" << std::endl;
//std::vector<std::vector<double>> outputSet = {{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20},
// {2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40}};