//
// SVC.cpp
//
// Created by Marc Melikyan on 10/2/20.
//
#include "svc.h"

#include "../activation/activation.h"
#include "../cost/cost.h"
#include "../lin_alg/lin_alg.h"
#include "../regularization/reg.h"
#include "../utilities/utilities.h"

#include <iostream>
#include <random>

MLPPSVC::MLPPSVC(std::vector<std::vector<real_t>> inputSet, std::vector<real_t> outputSet, real_t C) :
		inputSet(inputSet), outputSet(outputSet), n(inputSet.size()), k(inputSet[0].size()), C(C) {
	y_hat.resize(n);

	weights = MLPPUtilities::weightInitialization(k);
	bias = MLPPUtilities::biasInitialization();
}
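
// Example usage (a minimal sketch; the API calls are the ones defined in this
// file, but the data below is made up for illustration). Hinge loss and the
// sign activation assume labels in {-1, 1}:
//
//   std::vector<std::vector<real_t>> X = { { 0.0, 1.0 }, { 1.0, 0.0 } };
//   std::vector<real_t> y = { -1, 1 };
//   MLPPSVC svc(X, y, /*C=*/1.0);
//   svc.gradientDescent(/*learning_rate=*/0.01, /*max_epoch=*/1000, /*UI=*/false);
//   std::cout << svc.score() << std::endl;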
std::vector<real_t> MLPPSVC::modelSetTest(std::vector<std::vector<real_t>> X) {
	return Evaluate(X);
}
real_t MLPPSVC::modelTest(std::vector<real_t> x) {
	return Evaluate(x);
}
void MLPPSVC::gradientDescent(real_t learning_rate, int max_epoch, bool UI) {
	class MLPPCost cost;
	MLPPLinAlg alg;
	MLPPReg regularization;

	real_t cost_prev = 0;
	int epoch = 1;
	forwardPass();

	while (true) {
		cost_prev = Cost(y_hat, outputSet, weights, C);

		// Calculating the weight gradients
		weights = alg.subtraction(weights, alg.scalarMultiply(learning_rate / n, alg.mat_vec_mult(alg.transpose(inputSet), cost.HingeLossDeriv(z, outputSet, C))));
		weights = regularization.regWeights(weights, learning_rate / n, 0, "Ridge");

		// Calculating the bias gradients. As with the weights, the derivative is
		// taken with respect to the pre-activation z (not y_hat), and the step
		// moves against the gradient.
		bias -= learning_rate * alg.sum_elements(cost.HingeLossDeriv(z, outputSet, C)) / n;

		forwardPass();

		// UI PORTION
		if (UI) {
			MLPPUtilities::CostInfo(epoch, cost_prev, Cost(y_hat, outputSet, weights, C));
			MLPPUtilities::UI(weights, bias);
		}
		epoch++;

		if (epoch > max_epoch) {
			break;
		}
	}
}
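
// Subgradient sketch (assuming cost.HingeLossDeriv(z, y, C) returns the
// per-sample derivative of C * max(0, 1 - y_i * z_i) with respect to z_i,
// i.e. -C * y_i for margin violators and 0 otherwise):
//
//   dJ/dw = (1/n) * X^T * HingeLossDeriv(z, y, C)   (+ the Ridge term via regWeights)
//   dJ/db = (1/n) * sum_i HingeLossDeriv(z, y, C)_i
//
// which is what the two update lines above compute.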
void MLPPSVC::SGD(real_t learning_rate, int max_epoch, bool UI) {
	class MLPPCost cost;
	MLPPLinAlg alg;
	MLPPReg regularization;

	real_t cost_prev = 0;
	int epoch = 1;

	while (true) {
		std::random_device rd;
		std::default_random_engine generator(rd());
		std::uniform_int_distribution<int> distribution(0, int(n - 1));
		int outputIndex = distribution(generator);

		real_t y_hat = Evaluate(inputSet[outputIndex]);
		real_t z = propagate(inputSet[outputIndex]);
		cost_prev = Cost({ z }, { outputSet[outputIndex] }, weights, C);

		real_t costDeriv = cost.HingeLossDeriv(std::vector<real_t>({ z }), std::vector<real_t>({ outputSet[outputIndex] }), C)[0]; // Explicit conversion to avoid ambiguity with overloaded function. Error occurred on Ubuntu.

		// Weight update
		weights = alg.subtraction(weights, alg.scalarMultiply(learning_rate * costDeriv, inputSet[outputIndex]));
		weights = regularization.regWeights(weights, learning_rate, 0, "Ridge");

		// Bias update
		bias -= learning_rate * costDeriv;

		y_hat = Evaluate(inputSet[outputIndex]);

		if (UI) {
			MLPPUtilities::CostInfo(epoch, cost_prev, Cost({ z }, { outputSet[outputIndex] }, weights, C));
			MLPPUtilities::UI(weights, bias);
		}
		epoch++;

		if (epoch > max_epoch) {
			break;
		}
	}
	forwardPass();
}
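
// SGD draws a single random sample per iteration: learning_rate * costDeriv * x_i
// is the one-sample estimate of the batch gradient used in gradientDescent,
// so each step is noisier but far cheaper to compute.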
void MLPPSVC::MBGD(real_t learning_rate, int max_epoch, int mini_batch_size, bool UI) {
	class MLPPCost cost;
	MLPPLinAlg alg;
	MLPPReg regularization;

	real_t cost_prev = 0;
	int epoch = 1;

	// Creating the mini-batches
	int n_mini_batch = n / mini_batch_size;
	auto [inputMiniBatches, outputMiniBatches] = MLPPUtilities::createMiniBatches(inputSet, outputSet, n_mini_batch);

	while (true) {
		for (int i = 0; i < n_mini_batch; i++) {
			std::vector<real_t> y_hat = Evaluate(inputMiniBatches[i]);
			std::vector<real_t> z = propagate(inputMiniBatches[i]);
			cost_prev = Cost(z, outputMiniBatches[i], weights, C);

			// Calculating the weight gradients
			weights = alg.subtraction(weights, alg.scalarMultiply(learning_rate / n, alg.mat_vec_mult(alg.transpose(inputMiniBatches[i]), cost.HingeLossDeriv(z, outputMiniBatches[i], C))));
			weights = regularization.regWeights(weights, learning_rate / n, 0, "Ridge");

			// Calculating the bias gradients. As in gradientDescent, the derivative
			// is taken with respect to the pre-activation z.
			bias -= learning_rate * alg.sum_elements(cost.HingeLossDeriv(z, outputMiniBatches[i], C)) / n;

			forwardPass();

			y_hat = Evaluate(inputMiniBatches[i]);

			if (UI) {
				MLPPUtilities::CostInfo(epoch, cost_prev, Cost(z, outputMiniBatches[i], weights, C));
				MLPPUtilities::UI(weights, bias);
			}
		}
		epoch++;

		if (epoch > max_epoch) {
			break;
		}
	}
	forwardPass();
}
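
// MLPPUtilities::createMiniBatches partitions (inputSet, outputSet) into
// n_mini_batch chunks; each inner iteration then applies the same hinge
// subgradient step as gradientDescent, restricted to one chunk.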
real_t MLPPSVC::score() {
	MLPPUtilities util;
	return util.performance(y_hat, outputSet);
}
void MLPPSVC::save(std::string fileName) {
	MLPPUtilities util;
	util.saveParameters(fileName, weights, bias);
}
real_t MLPPSVC::Cost(std::vector<real_t> z, std::vector<real_t> y, std::vector<real_t> weights, real_t C) {
	class MLPPCost cost;
	return cost.HingeLoss(z, y, weights, C);
}
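
// Assuming cost.HingeLoss(z, y, weights, C) implements the usual soft-margin
// SVM objective, this evaluates something of the form
//
//   J(w, b) = sum_i C * max(0, 1 - y_i * z_i) + penalty(||w||),
//
// with z_i = w . x_i + b. HingeLoss receives weights, so the weight-norm term
// presumably lives there; the Ridge step in the optimizers applies the
// corresponding shrinkage to w during training.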
std::vector<real_t> MLPPSVC::Evaluate(std::vector<std::vector<real_t>> X) {
	MLPPLinAlg alg;
	MLPPActivation avn;
	return avn.sign(alg.scalarAdd(bias, alg.mat_vec_mult(X, weights)));
}
std::vector<real_t> MLPPSVC::propagate(std::vector<std::vector<real_t>> X) {
	MLPPLinAlg alg;
	return alg.scalarAdd(bias, alg.mat_vec_mult(X, weights));
}
real_t MLPPSVC::Evaluate(std::vector<real_t> x) {
	MLPPLinAlg alg;
	MLPPActivation avn;
	return avn.sign(alg.dot(weights, x) + bias);
}
real_t MLPPSVC::propagate(std::vector<real_t> x) {
	MLPPLinAlg alg;
	return alg.dot(weights, x) + bias;
}

// sign(wTx + b)
void MLPPSVC::forwardPass() {
	MLPPActivation avn;
	z = propagate(inputSet);
	y_hat = avn.sign(z);
}