Split the wgan test.

This commit is contained in:
Relintai 2023-02-06 14:27:43 +01:00
parent bdfa69f1e9
commit a42cfac723
3 changed files with 16 additions and 13 deletions

View File

@ -636,18 +636,6 @@ std::tuple<std::vector<std::vector<std::vector<real_t>>>, std::vector<real_t>> M
if (!network.empty()) {
auto hiddenLayerAvn = network[network.size() - 1].activation_map[network[network.size() - 1].activation];
//std::cout << "=-------=--==-=-=-=" << std::endl;
//alg.printVector(outputLayer->delta);
//std::cout << "=-------=--==-=-=-=" << std::endl;
//alg.printVector(outputLayer->weights);
//std::cout << "=-------=--==-=-=-=" << std::endl;
//alg.printMatrix(alg.outerProduct(outputLayer->delta, outputLayer->weights));
//std::cout << "=-------=--==-=-=-=" << std::endl;
//alg.printMatrix((avn.*hiddenLayerAvn)(network[network.size() - 1].z, 1));
//CRASH_NOW();
network[network.size() - 1].delta = alg.hadamard_product(alg.outerProduct(outputLayer->delta, outputLayer->weights), (avn.*hiddenLayerAvn)(network[network.size() - 1].z, 1));
std::vector<std::vector<real_t>> hiddenLayerWGrad = alg.matmult(alg.transpose(network[network.size() - 1].input), network[network.size() - 1].delta);

View File

@ -475,7 +475,7 @@ void MLPPTests::test_dynamically_sized_ann(bool ui) {
alg.printVector(ann.modelSetTest(alg.transpose(inputSet)));
std::cout << "ACCURACY: " << 100 * ann.score() << "%" << std::endl;
}
void MLPPTests::test_wgan(bool ui) {
void MLPPTests::test_wgan_old(bool ui) {
//MLPPStat stat;
MLPPLinAlg alg;
//MLPPActivation avn;
@ -496,6 +496,19 @@ void MLPPTests::test_wgan(bool ui) {
gan_old.gradientDescent(0.1, 55000, ui);
std::cout << "GENERATED INPUT: (Gaussian-sampled noise):" << std::endl;
alg.printMatrix(gan_old.generateExample(100));
}
void MLPPTests::test_wgan(bool ui) {
//MLPPStat stat;
MLPPLinAlg alg;
//MLPPActivation avn;
//MLPPCost cost;
//MLPPData data;
//MLPPConvolutions conv;
std::vector<std::vector<real_t>> outputSet = {
{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 },
{ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 }
};
Ref<MLPPMatrix> output_set;
output_set.instance();
@ -1250,6 +1263,7 @@ void MLPPTests::_bind_methods() {
ClassDB::bind_method(D_METHOD("test_soft_max_network", "ui"), &MLPPTests::test_soft_max_network, false);
ClassDB::bind_method(D_METHOD("test_autoencoder", "ui"), &MLPPTests::test_autoencoder, false);
ClassDB::bind_method(D_METHOD("test_dynamically_sized_ann", "ui"), &MLPPTests::test_dynamically_sized_ann, false);
ClassDB::bind_method(D_METHOD("test_wgan_old", "ui"), &MLPPTests::test_wgan_old, false);
ClassDB::bind_method(D_METHOD("test_wgan", "ui"), &MLPPTests::test_wgan, false);
ClassDB::bind_method(D_METHOD("test_ann", "ui"), &MLPPTests::test_ann, false);
ClassDB::bind_method(D_METHOD("test_dynamically_sized_mann", "ui"), &MLPPTests::test_dynamically_sized_mann, false);

View File

@ -44,6 +44,7 @@ public:
void test_soft_max_network(bool ui = false);
void test_autoencoder(bool ui = false);
void test_dynamically_sized_ann(bool ui = false);
void test_wgan_old(bool ui = false);
void test_wgan(bool ui = false);
void test_ann(bool ui = false);
void test_dynamically_sized_mann(bool ui = false);