mirror of
https://github.com/Relintai/MLPP.git
synced 2025-02-04 15:55:53 +01:00
"Vectorized" implementation of SGD for CLogLog Reg
This commit is contained in:
parent
13b0d76c5c
commit
687dada9f1
@ -96,6 +96,7 @@ namespace MLPP{
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CLogLogReg::SGD(double learning_rate, int max_epoch, bool UI){
|
void CLogLogReg::SGD(double learning_rate, int max_epoch, bool UI){
|
||||||
|
LinAlg alg;
|
||||||
Reg regularization;
|
Reg regularization;
|
||||||
double cost_prev = 0;
|
double cost_prev = 0;
|
||||||
int epoch = 1;
|
int epoch = 1;
|
||||||
@ -109,26 +110,17 @@ namespace MLPP{
|
|||||||
|
|
||||||
double y_hat = Evaluate(inputSet[outputIndex]);
|
double y_hat = Evaluate(inputSet[outputIndex]);
|
||||||
double z = propagate(inputSet[outputIndex]);
|
double z = propagate(inputSet[outputIndex]);
|
||||||
|
|
||||||
cost_prev = Cost({y_hat}, {outputSet[outputIndex]});
|
cost_prev = Cost({y_hat}, {outputSet[outputIndex]});
|
||||||
|
|
||||||
for(int i = 0; i < k; i++){
|
double error = y_hat - outputSet[outputIndex];
|
||||||
|
|
||||||
// Calculating the weight gradients
|
// Weight Updation
|
||||||
double w_gradient = (y_hat - outputSet[outputIndex]) * exp(z-exp(z)) * inputSet[outputIndex][i];
|
weights = alg.subtraction(weights, alg.scalarMultiply(learning_rate * error * exp(z-exp(z)), inputSet[outputIndex]));
|
||||||
|
|
||||||
|
|
||||||
// Weight updation
|
|
||||||
weights[i] -= learning_rate * w_gradient;
|
|
||||||
}
|
|
||||||
weights = regularization.regWeights(weights, lambda, alpha, reg);
|
weights = regularization.regWeights(weights, lambda, alpha, reg);
|
||||||
|
|
||||||
|
|
||||||
// Calculating the bias gradients
|
|
||||||
double b_gradient = (y_hat - outputSet[outputIndex]) * exp(z-exp(z));
|
|
||||||
|
|
||||||
// Bias updation
|
// Bias updation
|
||||||
bias -= learning_rate * b_gradient;
|
bias -= learning_rate * error * exp(z-exp(z));
|
||||||
|
|
||||||
y_hat = Evaluate({inputSet[outputIndex]});
|
y_hat = Evaluate({inputSet[outputIndex]});
|
||||||
|
|
||||||
if(UI) {
|
if(UI) {
|
||||||
|
Loading…
Reference in New Issue
Block a user