diff --git a/MLPP/Utilities/Utilities.cpp b/MLPP/Utilities/Utilities.cpp index 2dbb179..53b5962 100644 --- a/MLPP/Utilities/Utilities.cpp +++ b/MLPP/Utilities/Utilities.cpp @@ -34,6 +34,14 @@ namespace MLPP{ std::uniform_real_distribution distribution(-sqrt(6 / n), sqrt(6 / n)); weights.push_back(distribution(generator)); } + else if(type == "LeCunNormal"){ + std::normal_distribution distribution(0, sqrt(1 / n)); + weights.push_back(distribution(generator)); + } + else if(type == "LeCunUniform"){ + std::uniform_real_distribution distribution(-sqrt(3/n), sqrt(3/n)); + weights.push_back(distribution(generator)); + } else if(type == "Uniform"){ std::uniform_real_distribution distribution(-1/sqrt(n), 1/sqrt(n)); weights.push_back(distribution(generator)); @@ -79,6 +87,14 @@ namespace MLPP{ std::uniform_real_distribution distribution(-sqrt(6 / n), sqrt(6 / n)); weights[i].push_back(distribution(generator)); } + else if(type == "LeCunNormal"){ + std::normal_distribution distribution(0, sqrt(1 / n)); + weights[i].push_back(distribution(generator)); + } + else if(type == "LeCunUniform"){ + std::uniform_real_distribution distribution(-sqrt(3/n), sqrt(3/n)); + weights[i].push_back(distribution(generator)); + } else if(type == "Uniform"){ std::uniform_real_distribution distribution(-1/sqrt(n), 1/sqrt(n)); weights[i].push_back(distribution(generator));