From 38cc0c15023f280babd29530f46a222cd474eac9 Mon Sep 17 00:00:00 2001 From: novak_99 Date: Tue, 23 Nov 2021 23:45:20 -0800 Subject: [PATCH] added lecun weight init methods --- MLPP/Utilities/Utilities.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/MLPP/Utilities/Utilities.cpp b/MLPP/Utilities/Utilities.cpp index 2dbb179..53b5962 100644 --- a/MLPP/Utilities/Utilities.cpp +++ b/MLPP/Utilities/Utilities.cpp @@ -34,6 +34,14 @@ namespace MLPP{ std::uniform_real_distribution distribution(-sqrt(6 / n), sqrt(6 / n)); weights.push_back(distribution(generator)); } + else if(type == "LeCunNormal"){ + std::normal_distribution distribution(0, sqrt(1 / n)); + weights.push_back(distribution(generator)); + } + else if(type == "LeCunUniform"){ + std::uniform_real_distribution distribution(-sqrt(3/n), sqrt(3/n)); + weights.push_back(distribution(generator)); + } else if(type == "Uniform"){ std::uniform_real_distribution distribution(-1/sqrt(n), 1/sqrt(n)); weights.push_back(distribution(generator)); @@ -79,6 +87,14 @@ namespace MLPP{ std::uniform_real_distribution distribution(-sqrt(6 / n), sqrt(6 / n)); weights[i].push_back(distribution(generator)); } + else if(type == "LeCunNormal"){ + std::normal_distribution distribution(0, sqrt(1 / n)); + weights[i].push_back(distribution(generator)); + } + else if(type == "LeCunUniform"){ + std::uniform_real_distribution distribution(-sqrt(3/n), sqrt(3/n)); + weights[i].push_back(distribution(generator)); + } else if(type == "Uniform"){ std::uniform_real_distribution distribution(-1/sqrt(n), 1/sqrt(n)); weights[i].push_back(distribution(generator));