mirror of
https://github.com/Relintai/MLPP.git
synced 2024-11-12 10:15:01 +01:00
"Vectorized" implementation of SGD for ProbitReg
This commit is contained in:
parent
009fec444a
commit
13b0d76c5c
@ -94,6 +94,7 @@ namespace MLPP{
|
||||
}
|
||||
|
||||
void ProbitReg::SGD(double learning_rate, int max_epoch, bool UI){
|
||||
// NOTE: ∂y_hat/∂z is sparse
|
||||
LinAlg alg;
|
||||
Activation avn;
|
||||
Reg regularization;
|
||||
@ -111,24 +112,15 @@ namespace MLPP{
|
||||
double z = propagate(inputSet[outputIndex]);
|
||||
cost_prev = Cost({y_hat}, {outputSet[outputIndex]});
|
||||
|
||||
double error = y_hat - outputSet[outputIndex];
|
||||
|
||||
for(int i = 0; i < k; i++){
|
||||
|
||||
// Calculating the weight gradients
|
||||
|
||||
double w_gradient = (y_hat - outputSet[outputIndex]) * ((1 / sqrt(2 * M_PI)) * exp(-z * z / 2)) * inputSet[outputIndex][i];
|
||||
|
||||
std::cout << exp(-z * z / 2) << std::endl;
|
||||
// Weight updation
|
||||
weights[i] -= learning_rate * w_gradient;
|
||||
}
|
||||
// Weight Updation
|
||||
weights = alg.subtraction(weights, alg.scalarMultiply(learning_rate * error * ((1 / sqrt(2 * M_PI)) * exp(-z * z / 2)), inputSet[outputIndex]));
|
||||
weights = regularization.regWeights(weights, lambda, alpha, reg);
|
||||
|
||||
// Calculating the bias gradients
|
||||
double b_gradient = (y_hat - outputSet[outputIndex]);
|
||||
|
||||
// Bias updation
|
||||
bias -= learning_rate * b_gradient * ((1 / sqrt(2 * M_PI)) * exp(-z * z / 2));
|
||||
bias -= learning_rate * error * ((1 / sqrt(2 * M_PI)) * exp(-z * z / 2));
|
||||
|
||||
y_hat = Evaluate({inputSet[outputIndex]});
|
||||
|
||||
if(UI) {
|
||||
|
18
main.cpp
18
main.cpp
@ -142,22 +142,22 @@ int main() {
|
||||
// alg.printVector(model.modelSetTest((alg.transpose(inputSet))));
|
||||
// std::cout << "ACCURACY: " << 100 * model.score() << "%" << std::endl;
|
||||
|
||||
// LOGISTIC REGRESSION
|
||||
std::vector<std::vector<double>> inputSet;
|
||||
std::vector<double> outputSet;
|
||||
data.setData(30, "/Users/marcmelikyan/Desktop/Data/BreastCancer.csv", inputSet, outputSet);
|
||||
LogReg model(inputSet, outputSet);
|
||||
model.SGD(0.001, 100000, 0);
|
||||
// // LOGISTIC REGRESSION
|
||||
// std::vector<std::vector<double>> inputSet;
|
||||
// std::vector<double> outputSet;
|
||||
// data.setData(30, "/Users/marcmelikyan/Desktop/Data/BreastCancer.csv", inputSet, outputSet);
|
||||
// LogReg model(inputSet, outputSet);
|
||||
// model.SGD(0.001, 100000, 0);
|
||||
// model.MLE(0.1, 10000, 0);
|
||||
alg.printVector(model.modelSetTest(inputSet));
|
||||
std::cout << "ACCURACY: " << 100 * model.score() << "%" << std::endl;
|
||||
// alg.printVector(model.modelSetTest(inputSet));
|
||||
// std::cout << "ACCURACY: " << 100 * model.score() << "%" << std::endl;
|
||||
|
||||
// // PROBIT REGRESSION
|
||||
// std::vector<std::vector<double>> inputSet;
|
||||
// std::vector<double> outputSet;
|
||||
// data.setData(30, "/Users/marcmelikyan/Desktop/Data/BreastCancer.csv", inputSet, outputSet);
|
||||
// ProbitReg model(inputSet, outputSet);
|
||||
// model.gradientDescent(0.0001, 10000, 1);
|
||||
// model.SGD(0.001, 10000, 1);
|
||||
// alg.printVector(model.modelSetTest(inputSet));
|
||||
// std::cout << "ACCURACY: " << 100 * model.score() << "%" << std::endl;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user