diff --git a/mlpp/wgan/wgan.cpp b/mlpp/wgan/wgan.cpp
index f46de94..fbf4e6d 100644
--- a/mlpp/wgan/wgan.cpp
+++ b/mlpp/wgan/wgan.cpp
@@ -636,18 +636,6 @@ std::tuple<std::vector<std::vector<std::vector<real_t>>>, std::vector<real_t>> M
 	if (!network.empty()) {
 		auto hiddenLayerAvn = network[network.size() - 1].activation_map[network[network.size() - 1].activation];
-
-		//std::cout << "=-------=--==-=-=-=" << std::endl;
-		//alg.printVector(outputLayer->delta);
-		//std::cout << "=-------=--==-=-=-=" << std::endl;
-		//alg.printVector(outputLayer->weights);
-
-		//std::cout << "=-------=--==-=-=-=" << std::endl;
-		//alg.printMatrix(alg.outerProduct(outputLayer->delta, outputLayer->weights));
-		//std::cout << "=-------=--==-=-=-=" << std::endl;
-		//alg.printMatrix((avn.*hiddenLayerAvn)(network[network.size() - 1].z, 1));
-		//CRASH_NOW();
-
 		network[network.size() - 1].delta = alg.hadamard_product(alg.outerProduct(outputLayer->delta, outputLayer->weights), (avn.*hiddenLayerAvn)(network[network.size() - 1].z, 1));
 
 		std::vector<std::vector<real_t>> hiddenLayerWGrad = alg.matmult(alg.transpose(network[network.size() - 1].input), network[network.size() - 1].delta);
diff --git a/test/mlpp_tests.cpp b/test/mlpp_tests.cpp
index e26b1fb..3c3edad 100644
--- a/test/mlpp_tests.cpp
+++ b/test/mlpp_tests.cpp
@@ -475,7 +475,7 @@ void MLPPTests::test_dynamically_sized_ann(bool ui) {
 	alg.printVector(ann.modelSetTest(alg.transpose(inputSet)));
 	std::cout << "ACCURACY: " << 100 * ann.score() << "%" << std::endl;
 }
-void MLPPTests::test_wgan(bool ui) {
+void MLPPTests::test_wgan_old(bool ui) {
 	//MLPPStat stat;
 	MLPPLinAlg alg;
 	//MLPPActivation avn;
@@ -496,6 +496,19 @@ void MLPPTests::test_wgan(bool ui) {
 	gan_old.gradientDescent(0.1, 55000, ui);
 	std::cout << "GENERATED INPUT: (Gaussian-sampled noise):" << std::endl;
 	alg.printMatrix(gan_old.generateExample(100));
+}
+void MLPPTests::test_wgan(bool ui) {
+	//MLPPStat stat;
+	MLPPLinAlg alg;
+	//MLPPActivation avn;
+	//MLPPCost cost;
+	//MLPPData data;
+	//MLPPConvolutions conv;
+
+	std::vector<std::vector<real_t>> outputSet = {
+		{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 },
+		{ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 }
+	};
 
 	Ref<MLPPMatrix> output_set;
 	output_set.instance();
@@ -1250,6 +1263,7 @@ void MLPPTests::_bind_methods() {
 	ClassDB::bind_method(D_METHOD("test_soft_max_network", "ui"), &MLPPTests::test_soft_max_network, false);
 	ClassDB::bind_method(D_METHOD("test_autoencoder", "ui"), &MLPPTests::test_autoencoder, false);
 	ClassDB::bind_method(D_METHOD("test_dynamically_sized_ann", "ui"), &MLPPTests::test_dynamically_sized_ann, false);
+	ClassDB::bind_method(D_METHOD("test_wgan_old", "ui"), &MLPPTests::test_wgan_old, false);
 	ClassDB::bind_method(D_METHOD("test_wgan", "ui"), &MLPPTests::test_wgan, false);
 	ClassDB::bind_method(D_METHOD("test_ann", "ui"), &MLPPTests::test_ann, false);
 	ClassDB::bind_method(D_METHOD("test_dynamically_sized_mann", "ui"), &MLPPTests::test_dynamically_sized_mann, false);
diff --git a/test/mlpp_tests.h b/test/mlpp_tests.h
index 4124e93..a88d1b0 100644
--- a/test/mlpp_tests.h
+++ b/test/mlpp_tests.h
@@ -44,6 +44,7 @@ public:
 	void test_soft_max_network(bool ui = false);
 	void test_autoencoder(bool ui = false);
 	void test_dynamically_sized_ann(bool ui = false);
+	void test_wgan_old(bool ui = false);
 	void test_wgan(bool ui = false);
 	void test_ann(bool ui = false);
 	void test_dynamically_sized_mann(bool ui = false);