diff --git a/test/mlpp_tests.cpp b/test/mlpp_tests.cpp index 40f0027..4f582f7 100644 --- a/test/mlpp_tests.cpp +++ b/test/mlpp_tests.cpp @@ -562,7 +562,7 @@ void MLPPTests::test_soft_max_network(bool ui) { Ref dt = data.load_wine(_wine_data_path); MLPPSoftmaxNet model(dt->get_input(), dt->get_output(), 1); - model.train_gradient_descent(0.01, 100000, ui); + model.train_gradient_descent(0.000001, 300, ui); PLOG_MSG(model.model_set_test(dt->get_input())->to_string()); std::cout << "ACCURACY: " << 100 * model.score() << "%" << std::endl; } @@ -602,30 +602,20 @@ void MLPPTests::test_dynamically_sized_ann(bool ui) { output_set->set_from_std_vector(outputSet); MLPPANN ann(algn.transposenm(input_set), output_set); + ann.add_layer(2, MLPPActivation::ACTIVATION_FUNCTION_COSH); ann.add_output_layer(MLPPActivation::ACTIVATION_FUNCTION_SIGMOID, MLPPCost::COST_TYPE_LOGISTIC_LOSS); ann.amsgrad(0.1, 10000, 1, 0.9, 0.999, 0.000001, ui); ann.adadelta(1, 1000, 2, 0.9, 0.000001, ui); ann.momentum(0.1, 8000, 2, 0.9, true, ui); - ann.set_learning_rate_scheduler_drop(MLPPANN::SCHEDULER_TYPE_STEP, 0.5, 1000); ann.gradient_descent(0.01, 30000); + PLOG_MSG(ann.model_set_test(algn.transposenm(input_set))->to_string()); PLOG_MSG("ACCURACY: " + String::num(100 * ann.score()) + "%"); } void MLPPTests::test_wgan_old(bool ui) { - //MLPPStat stat; - MLPPLinAlg alg; - //MLPPActivation avn; - //MLPPCost cost; - //MLPPData data; - //MLPPConvolutions conv; - - std::vector> outputSet = { - { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 }, - { 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 } - }; } void MLPPTests::test_wgan(bool ui) { //MLPPStat stat; diff --git a/test/mlpp_tests_old.cpp b/test/mlpp_tests_old.cpp index 76ff7b6..5acb2d4 100644 --- a/test/mlpp_tests_old.cpp +++ b/test/mlpp_tests_old.cpp @@ -121,54 +121,10 @@ void MLPPTestsOld::test_support_vector_classification(bool ui) { void MLPPTestsOld::test_mlp(bool ui) { } void MLPPTestsOld::test_soft_max_network(bool ui) { - MLPPLinAlgOld alg; - MLPPData data; - - // SOFTMAX NETWORK - Ref dt = data.load_wine(_wine_data_path); - - MLPPSoftmaxNetOld model_old(dt->get_input()->to_std_vector(), dt->get_output()->to_std_vector(), 1); - model_old.gradientDescent(0.01, 100000, ui); - alg.printMatrix(model_old.modelSetTest(dt->get_input()->to_std_vector())); - std::cout << "ACCURACY: " << 100 * model_old.score() << "%" << std::endl; } void MLPPTestsOld::test_autoencoder(bool ui) { - MLPPLinAlgOld alg; - - std::vector> inputSet = { { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, { 3, 5, 9, 12, 15, 18, 21, 24, 27, 30 } }; - - // AUTOENCODER - MLPPAutoEncoderOld model_old(alg.transpose(inputSet), 5); - model_old.SGD(0.001, 300000, ui); - alg.printMatrix(model_old.modelSetTest(alg.transpose(inputSet))); - std::cout << "ACCURACY (Old): " << 100 * model_old.score() << "%" << std::endl; - - Ref input_set; - input_set.instance(); - input_set->set_from_std_vectors(inputSet); } void MLPPTestsOld::test_dynamically_sized_ann(bool ui) { - MLPPLinAlgOld alg; - - // DYNAMICALLY SIZED ANN - // Possible Weight Init Methods: Default, Uniform, HeNormal, HeUniform, XavierNormal, XavierUniform - // Possible Activations: Linear, Sigmoid, Swish, Softplus, Softsign, CLogLog, Ar{Sinh, Cosh, Tanh, Csch, Sech, Coth}, GaussianCDF, GELU, UnitStep - // Possible Loss Functions: MSE, RMSE, MBE, LogLoss, CrossEntropy, HingeLoss - std::vector> inputSet = { { 0, 0, 1, 1 }, { 0, 1, 0, 1 } }; - std::vector outputSet = { 0, 1, 1, 0 }; - - MLPPANNOld ann_old(alg.transpose(inputSet), outputSet); - ann_old.addLayer(2, "Cosh"); - ann_old.addOutputLayer("Sigmoid", "LogLoss"); - - ann_old.AMSGrad(0.1, 10000, 1, 0.9, 0.999, 0.000001, ui); - ann_old.Adadelta(1, 1000, 2, 0.9, 0.000001, ui); - ann_old.Momentum(0.1, 8000, 2, 0.9, true, ui); - - ann_old.setLearningRateScheduler("Step", 0.5, 1000); - ann_old.gradientDescent(0.01, 30000); - alg.printVector(ann_old.modelSetTest(alg.transpose(inputSet))); - std::cout << "ACCURACY: " << 100 * ann_old.score() << "%" << std::endl; } void MLPPTestsOld::test_wgan_old(bool ui) { //MLPPStat stat;