mirror of
https://github.com/Relintai/MLPP.git
synced 2025-02-04 15:55:53 +01:00
added lecun weight init methods
This commit is contained in:
parent
d60cae6189
commit
38cc0c1502
@ -34,6 +34,14 @@ namespace MLPP{
|
|||||||
std::uniform_real_distribution<double> distribution(-sqrt(6 / n), sqrt(6 / n));
|
std::uniform_real_distribution<double> distribution(-sqrt(6 / n), sqrt(6 / n));
|
||||||
weights.push_back(distribution(generator));
|
weights.push_back(distribution(generator));
|
||||||
}
|
}
|
||||||
|
else if(type == "LeCunNormal"){
|
||||||
|
std::normal_distribution<double> distribution(0, sqrt(1 / n));
|
||||||
|
weights.push_back(distribution(generator));
|
||||||
|
}
|
||||||
|
else if(type == "LeCunUniform"){
|
||||||
|
std::uniform_real_distribution<double> distribution(-sqrt(3/n), sqrt(3/n));
|
||||||
|
weights.push_back(distribution(generator));
|
||||||
|
}
|
||||||
else if(type == "Uniform"){
|
else if(type == "Uniform"){
|
||||||
std::uniform_real_distribution<double> distribution(-1/sqrt(n), 1/sqrt(n));
|
std::uniform_real_distribution<double> distribution(-1/sqrt(n), 1/sqrt(n));
|
||||||
weights.push_back(distribution(generator));
|
weights.push_back(distribution(generator));
|
||||||
@ -79,6 +87,14 @@ namespace MLPP{
|
|||||||
std::uniform_real_distribution<double> distribution(-sqrt(6 / n), sqrt(6 / n));
|
std::uniform_real_distribution<double> distribution(-sqrt(6 / n), sqrt(6 / n));
|
||||||
weights[i].push_back(distribution(generator));
|
weights[i].push_back(distribution(generator));
|
||||||
}
|
}
|
||||||
|
else if(type == "LeCunNormal"){
|
||||||
|
std::normal_distribution<double> distribution(0, sqrt(1 / n));
|
||||||
|
weights[i].push_back(distribution(generator));
|
||||||
|
}
|
||||||
|
else if(type == "LeCunUniform"){
|
||||||
|
std::uniform_real_distribution<double> distribution(-sqrt(3/n), sqrt(3/n));
|
||||||
|
weights[i].push_back(distribution(generator));
|
||||||
|
}
|
||||||
else if(type == "Uniform"){
|
else if(type == "Uniform"){
|
||||||
std::uniform_real_distribution<double> distribution(-1/sqrt(n), 1/sqrt(n));
|
std::uniform_real_distribution<double> distribution(-1/sqrt(n), 1/sqrt(n));
|
||||||
weights[i].push_back(distribution(generator));
|
weights[i].push_back(distribution(generator));
|
||||||
|
Loading…
Reference in New Issue
Block a user