mirror of https://github.com/Relintai/MLPP.git
"Vectorized" implementation of SGD for Tanh Reg
This commit is contained in:
parent
687dada9f1
commit
e5598185dd
@@ -66,6 +66,7 @@ namespace MLPP{
     }
 
     void TanhReg::SGD(double learning_rate, int max_epoch, bool UI){
+        LinAlg alg;
         Reg regularization;
         Utilities util;
         double cost_prev = 0;
@@ -80,24 +81,15 @@ namespace MLPP{
             double y_hat = Evaluate(inputSet[outputIndex]);
             cost_prev = Cost({y_hat}, {outputSet[outputIndex]});
 
+            double error = y_hat - outputSet[outputIndex];
 
-            for(int i = 0; i < k; i++){
-
-                // Calculating the weight gradients
-                double w_gradient = (y_hat - outputSet[outputIndex]) * (1 - y_hat * y_hat) * inputSet[outputIndex][i];
-
-                // Weight updation
-                weights[i] -= learning_rate * w_gradient;
-            }
-
+            // Weight Updation
+            weights = alg.subtraction(weights, alg.scalarMultiply(learning_rate * error * (1 - y_hat * y_hat), inputSet[outputIndex]));
             weights = regularization.regWeights(weights, lambda, alpha, reg);
 
-            // Calculating the bias gradients
-            double b_gradient = (y_hat - outputSet[outputIndex]) * (1 - y_hat * y_hat);
-
             // Bias updation
-            bias -= learning_rate * b_gradient;
+            bias -= learning_rate * error * (1 - y_hat * y_hat);
 
             y_hat = Evaluate({inputSet[outputIndex]});
 
             if(UI) {
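For context, the update this commit vectorizes follows from the model y_hat = tanh(w.x + b): since d tanh(z)/dz = 1 - tanh^2(z), the per-sample gradient of the squared-error cost with respect to weight i is (y_hat - y) * (1 - y_hat^2) * x[i], which is exactly the factor the old loop computed element by element and the new code applies to the whole input vector at once. Below is a minimal self-contained sketch of that step; vectorSubtract and scalarMultiply are hypothetical stand-ins for the repo's LinAlg::subtraction and LinAlg::scalarMultiply, and the regularization pass is left out.

// Minimal sketch of the vectorized SGD step above, assuming y_hat = tanh(w.x + b)
// and a squared-error cost. vectorSubtract/scalarMultiply are illustrative
// stand-ins for LinAlg::subtraction / LinAlg::scalarMultiply; the real class
// also applies regularization, omitted here.
#include <cmath>
#include <cstddef>
#include <vector>

static std::vector<double> scalarMultiply(double s, const std::vector<double>& v){
    std::vector<double> out(v.size());
    for(std::size_t i = 0; i < v.size(); i++) out[i] = s * v[i];
    return out;
}

static std::vector<double> vectorSubtract(const std::vector<double>& a, const std::vector<double>& b){
    std::vector<double> out(a.size());
    for(std::size_t i = 0; i < a.size(); i++) out[i] = a[i] - b[i];
    return out;
}

// One SGD step on a single sample (x, y).
void sgdStep(std::vector<double>& weights, double& bias, const std::vector<double>& x, double y, double learning_rate){
    // Forward pass: y_hat = tanh(w.x + b)
    double z = bias;
    for(std::size_t i = 0; i < x.size(); i++) z += weights[i] * x[i];
    double y_hat = std::tanh(z);

    double error = y_hat - y;

    // Vectorized weight update: w -= lr * (y_hat - y) * (1 - y_hat^2) * x
    weights = vectorSubtract(weights, scalarMultiply(learning_rate * error * (1 - y_hat * y_hat), x));

    // Bias update shares the same scalar factor.
    bias -= learning_rate * error * (1 - y_hat * y_hat);
}

Both the removed loop and the vectorized call perform the same per-sample arithmetic; the change only moves the index loop into the LinAlg helpers.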