mirror of
https://github.com/Relintai/MLPP.git
synced 2024-11-12 10:15:01 +01:00
Line search method- find optimal learning rates for neural nets.
This commit is contained in:
parent
559e55fd89
commit
d483662ce2
@ -558,6 +558,50 @@ void ANN::Adam(double learning_rate, int max_epoch, int mini_batch_size, double
|
||||
}
|
||||
}
|
||||
|
||||
double ANN::backTrackingLineSearch(double beta, double learningRate){
|
||||
LinAlg alg;
|
||||
forwardPass();
|
||||
std::vector<double> outputLayerWeights = outputLayer->weights;
|
||||
std::vector<std::vector<std::vector<double>>> cumulativeHiddenWeights;
|
||||
|
||||
if(!network.empty()){
|
||||
|
||||
cumulativeHiddenWeights.push_back(network[network.size() - 1].weights);
|
||||
|
||||
for(int i = network.size() - 2; i >= 0; i--){
|
||||
cumulativeHiddenWeights.push_back(network[i].weights);
|
||||
}
|
||||
}
|
||||
|
||||
while(true){
|
||||
auto [cumulativeHiddenLayerWGrad, outputWGrad] = computeGradients(y_hat, outputSet);
|
||||
cumulativeHiddenLayerWGrad = alg.scalarMultiply(learningRate/n, cumulativeHiddenLayerWGrad);
|
||||
outputWGrad = alg.scalarMultiply(learningRate/n, outputWGrad);
|
||||
|
||||
updateParameters(cumulativeHiddenLayerWGrad, outputWGrad, learningRate); // subject to change. may want bias to have this matrix too.
|
||||
|
||||
forwardPass();
|
||||
|
||||
if(Cost(y_hat, outputSet) > Cost(y_hat, outputSet) - (learningRate/2) * (alg.norm_2(cumulativeHiddenLayerWGrad) + alg.norm_2(outputWGrad))){
|
||||
learningRate *= beta;
|
||||
}
|
||||
else {
|
||||
outputLayer->weights = outputLayerWeights;
|
||||
|
||||
if(!network.empty()){
|
||||
network[network.size() - 1].weights = cumulativeHiddenWeights[0];
|
||||
|
||||
for(int i = network.size() - 2; i >= 0; i--){
|
||||
network[i].weights = cumulativeHiddenWeights[(network.size() - 2) - i + 1];
|
||||
}
|
||||
}
|
||||
return learningRate;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void ANN::setLearningRateScheduler(std::string type, double decayConstant){
|
||||
lrScheduler = type;
|
||||
ANN::decayConstant = decayConstant;
|
||||
|
@ -34,6 +34,8 @@ class ANN{
|
||||
double score();
|
||||
void save(std::string fileName);
|
||||
|
||||
double backTrackingLineSearch(double beta, double learningRate); // Use this to find an optimal learning rate value.
|
||||
|
||||
void setLearningRateScheduler(std::string type, double decayConstant);
|
||||
void setLearningRateScheduler(std::string type, double decayConstant, double dropRate);
|
||||
|
||||
|
@ -480,6 +480,16 @@ namespace MLPP{
|
||||
return B;
|
||||
}
|
||||
|
||||
double LinAlg::norm_2(std::vector<std::vector<double>> A){
|
||||
double sum = 0;
|
||||
for(int i = 0; i < A.size(); i++){
|
||||
for(int j = 0; j < A[i].size(); j++){
|
||||
sum += A[i][j] * A[i][j];
|
||||
}
|
||||
}
|
||||
return std::sqrt(sum);
|
||||
}
|
||||
|
||||
std::vector<std::vector<double>> LinAlg::identity(double d){
|
||||
std::vector<std::vector<double>> identityMat;
|
||||
identityMat.resize(d);
|
||||
@ -1183,4 +1193,16 @@ namespace MLPP{
|
||||
}
|
||||
return A;
|
||||
}
|
||||
|
||||
double LinAlg::norm_2(std::vector<std::vector<std::vector<double>>> A){
|
||||
double sum = 0;
|
||||
for(int i = 0; i < A.size(); i++){
|
||||
for(int j = 0; j < A[i].size(); j++){
|
||||
for(int k = 0; k < A[i][j].size(); k++){
|
||||
sum += A[i][j][k] * A[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::sqrt(sum);
|
||||
}
|
||||
}
|
@ -90,6 +90,8 @@ namespace MLPP{
|
||||
|
||||
std::vector<std::vector<double>> round(std::vector<std::vector<double>> A);
|
||||
|
||||
double norm_2(std::vector<std::vector<double>> A);
|
||||
|
||||
std::vector<std::vector<double>> identity(double d);
|
||||
|
||||
std::vector<std::vector<double>> cov(std::vector<std::vector<double>> A);
|
||||
@ -222,6 +224,8 @@ namespace MLPP{
|
||||
|
||||
std::vector<std::vector<std::vector<double>>> abs(std::vector<std::vector<std::vector<double>>> A);
|
||||
|
||||
double norm_2(std::vector<std::vector<std::vector<double>>> A);
|
||||
|
||||
private:
|
||||
};
|
||||
|
||||
|
10
main.cpp
10
main.cpp
@ -372,10 +372,12 @@ int main() {
|
||||
//ann.AMSGrad(0.1, 10000, 1, 0.9, 0.999, 0.000001, 1);
|
||||
//ann.Adadelta(1, 1000, 2, 0.9, 0.000001, 1);
|
||||
//ann.Momentum(0.1, 8000, 2, 0.9, true, 1);
|
||||
ann.setLearningRateScheduler("Step", 0.5, 1000);
|
||||
ann.gradientDescent(0.1, 20000, 1);
|
||||
alg.printVector(ann.modelSetTest(alg.transpose(inputSet)));
|
||||
std::cout << "ACCURACY: " << 100 * ann.score() << "%" << std::endl;
|
||||
|
||||
std::cout << ann.backTrackingLineSearch(0.707, 0.1) << std::endl;
|
||||
//ann.setLearningRateScheduler("Step", 0.5, 1000);
|
||||
//ann.gradientDescent(0.1, 20000, 1);
|
||||
//alg.printVector(ann.modelSetTest(alg.transpose(inputSet)));
|
||||
//std::cout << "ACCURACY: " << 100 * ann.score() << "%" << std::endl;
|
||||
|
||||
//std::vector<std::vector<double>> outputSet = {{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20},
|
||||
// {2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40}};
|
||||
|
Loading…
Reference in New Issue
Block a user