mirror of https://github.com/Relintai/MLPP.git (synced 2024-11-12 10:15:01 +01:00)

Commit b0b686799c: removed line search
Parent: 11ca2d952f
@@ -71,11 +71,13 @@ namespace MLPP {

            cost_prev = Cost(y_hat, outputSet);

            auto [cumulativeHiddenLayerWGrad, outputWGrad] = computeGradients(y_hat, outputSet);

            cumulativeHiddenLayerWGrad = alg.scalarMultiply(learning_rate/n, cumulativeHiddenLayerWGrad);
            outputWGrad = alg.scalarMultiply(learning_rate/n, outputWGrad);

            updateParameters(cumulativeHiddenLayerWGrad, outputWGrad, learning_rate); // subject to change. may want bias to have this matrix too.

            std::cout << learning_rate << std::endl;

            forwardPass();

            if(UI) { ANN::UI(epoch, cost_prev, y_hat, outputSet); }
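This hunk is the body of ANN's full-batch gradient descent loop: computeGradients returns the hidden-layer and output-layer weight gradients, both are scaled by learning_rate/n (the step size divided by the number of training examples), and updateParameters applies them before the next forwardPass(). Below is a minimal self-contained sketch of the same scaled update on a one-variable linear model; the data and every name in it are illustrative stand-ins, not MLPP code.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
        // Toy dataset for y = 2x; purely illustrative.
        std::vector<double> x = {1, 2, 3, 4}, y = {2, 4, 6, 8};
        double w = 0, b = 0, learning_rate = 0.01;
        std::size_t n = x.size();

        for (int epoch = 0; epoch < 1000; epoch++) {
            double wGrad = 0, bGrad = 0;
            for (std::size_t i = 0; i < n; i++) {
                double err = (w * x[i] + b) - y[i]; // forward pass + residual
                wGrad += err * x[i];                // accumulate dCost/dw
                bGrad += err;                       // accumulate dCost/db
            }
            // Same scaling as scalarMultiply(learning_rate/n, ...) above.
            w -= (learning_rate / n) * wGrad;
            b -= (learning_rate / n) * bGrad;
        }
        std::cout << "w=" << w << " b=" << b << std::endl; // approaches w=2, b=0
    }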
@@ -557,52 +559,6 @@ void ANN::Adam(double learning_rate, int max_epoch, int mini_batch_size, double

            util.saveParameters(fileName, outputLayer->weights, outputLayer->bias, 0, network.size() + 1);
        }
    }

    // https://www.youtube.com/watch?v=4qDt4QUl4zE
    // The above video detailed the necessary components of the line search algorithm.
    double ANN::backTrackingLineSearch(double beta, double learningRate){
        LinAlg alg;
        forwardPass();
        std::vector<double> outputLayerWeights = outputLayer->weights;
        std::vector<std::vector<std::vector<double>>> cumulativeHiddenWeights;

        if(!network.empty()){
            cumulativeHiddenWeights.push_back(network[network.size() - 1].weights);
            for(int i = network.size() - 2; i >= 0; i--){
                cumulativeHiddenWeights.push_back(network[i].weights);
            }
        }

        while(true){
            auto [cumulativeHiddenLayerWGrad, outputWGrad] = computeGradients(y_hat, outputSet);
            cumulativeHiddenLayerWGrad = alg.scalarMultiply(learningRate/n, cumulativeHiddenLayerWGrad);
            outputWGrad = alg.scalarMultiply(learningRate/n, outputWGrad);

            updateParameters(cumulativeHiddenLayerWGrad, outputWGrad, learningRate); // subject to change. may want bias to have this matrix too.

            forwardPass();

            if(Cost(y_hat, outputSet) > Cost(y_hat, outputSet) - (learningRate/2) * (alg.norm_2(cumulativeHiddenLayerWGrad) + alg.norm_2(outputWGrad))){
                learningRate *= beta;
            }
            else {
                outputLayer->weights = outputLayerWeights;

                if(!network.empty()){
                    network[network.size() - 1].weights = cumulativeHiddenWeights[0];
                    for(int i = network.size() - 2; i >= 0; i--){
                        network[i].weights = cumulativeHiddenWeights[(network.size() - 2) - i + 1];
                    }
                }
                return learningRate;
                break;
            }
        }
    }

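The removed function above is a backtracking (Armijo) line search: take a trial step, and if the cost has not dropped by at least (learningRate/2) times the squared gradient norm, shrink the rate by beta and try again. As written, though, the accept test evaluates Cost(y_hat, outputSet) on both sides after forwardPass() has refreshed y_hat, so it compares the new cost against itself rather than against the cost at the starting point, and updateParameters keeps compounding steps inside the loop while the saved weights are only restored at the end. Below is a minimal self-contained sketch of the textbook rule on a toy quadratic cost; f, grad, and backtrackingLineSearch are illustrative stand-ins, not MLPP API.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Toy cost f(w) = 0.5 * ||w||^2 and its gradient; stand-ins for
    // the Cost/computeGradients pair used in the removed routine.
    double f(const std::vector<double>& w) {
        double s = 0;
        for (double v : w) s += v * v;
        return 0.5 * s;
    }
    std::vector<double> grad(const std::vector<double>& w) { return w; }

    // Shrink eta by beta until the Armijo condition
    //   f(w - eta*g) <= f(w) - (eta/2) * ||g||^2
    // holds, then return the accepted step size.
    double backtrackingLineSearch(const std::vector<double>& w, double beta, double eta) {
        std::vector<double> g = grad(w);
        double gNormSq = 0;
        for (double v : g) gNormSq += v * v;
        double fw = f(w); // cache the cost at the current point once
        while (true) {
            std::vector<double> trial = w; // each trial starts from the original weights
            for (std::size_t i = 0; i < w.size(); i++) trial[i] -= eta * g[i];
            if (f(trial) <= fw - (eta / 2) * gNormSq) return eta;
            eta *= beta; // step too large: shrink and retry
        }
    }

    int main() {
        std::cout << backtrackingLineSearch({3.0, -4.0}, 0.707, 10.0) << std::endl;
    }

The two details that matter are caching fw before the loop and rebuilding each trial point from the original weights, so failed trials never accumulate.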
    void ANN::setLearningRateScheduler(std::string type, double decayConstant){
        lrScheduler = type;
@@ -32,9 +32,7 @@ class ANN{

        void Nadam(double learning_rate, int max_epoch, int mini_batch_size, double b1, double b2, double e, bool UI = 1);
        void AMSGrad(double learning_rate, int max_epoch, int mini_batch_size, double b1, double b2, double e, bool UI = 1);
        double score();
        void save(std::string fileName);

        double backTrackingLineSearch(double beta, double learningRate); // Use this to find an optimal learning rate value.

        void setLearningRateScheduler(std::string type, double decayConstant);
        void setLearningRateScheduler(std::string type, double decayConstant, double dropRate);
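The header hunk drops the backTrackingLineSearch declaration and keeps the two setLearningRateScheduler overloads. How the "Step" scheduler uses decayConstant and dropRate is not shown in this diff; one common step-decay rule consistent with that signature, offered purely as an illustrative guess and not as MLPP's actual implementation, is:

    #include <cmath>
    #include <iostream>

    // Hypothetical step decay: multiply the rate by decayConstant once every
    // dropRate epochs. Name and formula are assumptions, not taken from MLPP.
    double steppedLearningRate(double initialRate, double decayConstant,
                               double dropRate, int epoch) {
        return initialRate * std::pow(decayConstant, std::floor(epoch / dropRate));
    }

    int main() {
        // With (0.5, 1000), matching the commented-out call in main.cpp below:
        for (int epoch = 0; epoch <= 3000; epoch += 1000)
            std::cout << epoch << " -> " << steppedLearningRate(0.1, 0.5, 1000, epoch) << std::endl;
    }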
main.cpp
@@ -373,9 +373,8 @@ int main() {

    //ann.Adadelta(1, 1000, 2, 0.9, 0.000001, 1);
    //ann.Momentum(0.1, 8000, 2, 0.9, true, 1);

    std::cout << ann.backTrackingLineSearch(0.707, 0.1) << std::endl;
    //ann.setLearningRateScheduler("Step", 0.5, 1000);
    //ann.gradientDescent(0.1, 20000, 1);
    ann.gradientDescent(1, 5, 1);
    //alg.printVector(ann.modelSetTest(alg.transpose(inputSet)));
    //std::cout << "ACCURACY: " << 100 * ann.score() << "%" << std::endl;
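The driver prints the rate found by backTrackingLineSearch(0.707, 0.1) but then trains with a hard-coded rate. A natural way to chain the two calls, assuming the same ANN interface shown above (hypothetical usage, not part of this commit):

    double tunedRate = ann.backTrackingLineSearch(0.707, 0.1); // beta, initial rate
    ann.gradientDescent(tunedRate, 20000, 1);                  // rate, max_epoch, UI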