mirror of
https://github.com/Relintai/pmlpp.git
synced 2024-12-22 15:06:47 +01:00
Split the wgan test.
This commit is contained in:
parent
bdfa69f1e9
commit
a42cfac723
@ -636,18 +636,6 @@ std::tuple<std::vector<std::vector<std::vector<real_t>>>, std::vector<real_t>> M
|
|||||||
|
|
||||||
if (!network.empty()) {
|
if (!network.empty()) {
|
||||||
auto hiddenLayerAvn = network[network.size() - 1].activation_map[network[network.size() - 1].activation];
|
auto hiddenLayerAvn = network[network.size() - 1].activation_map[network[network.size() - 1].activation];
|
||||||
|
|
||||||
//std::cout << "=-------=--==-=-=-=" << std::endl;
|
|
||||||
//alg.printVector(outputLayer->delta);
|
|
||||||
//std::cout << "=-------=--==-=-=-=" << std::endl;
|
|
||||||
//alg.printVector(outputLayer->weights);
|
|
||||||
|
|
||||||
//std::cout << "=-------=--==-=-=-=" << std::endl;
|
|
||||||
//alg.printMatrix(alg.outerProduct(outputLayer->delta, outputLayer->weights));
|
|
||||||
//std::cout << "=-------=--==-=-=-=" << std::endl;
|
|
||||||
//alg.printMatrix((avn.*hiddenLayerAvn)(network[network.size() - 1].z, 1));
|
|
||||||
//CRASH_NOW();
|
|
||||||
|
|
||||||
network[network.size() - 1].delta = alg.hadamard_product(alg.outerProduct(outputLayer->delta, outputLayer->weights), (avn.*hiddenLayerAvn)(network[network.size() - 1].z, 1));
|
network[network.size() - 1].delta = alg.hadamard_product(alg.outerProduct(outputLayer->delta, outputLayer->weights), (avn.*hiddenLayerAvn)(network[network.size() - 1].z, 1));
|
||||||
std::vector<std::vector<real_t>> hiddenLayerWGrad = alg.matmult(alg.transpose(network[network.size() - 1].input), network[network.size() - 1].delta);
|
std::vector<std::vector<real_t>> hiddenLayerWGrad = alg.matmult(alg.transpose(network[network.size() - 1].input), network[network.size() - 1].delta);
|
||||||
|
|
||||||
|
@ -475,7 +475,7 @@ void MLPPTests::test_dynamically_sized_ann(bool ui) {
|
|||||||
alg.printVector(ann.modelSetTest(alg.transpose(inputSet)));
|
alg.printVector(ann.modelSetTest(alg.transpose(inputSet)));
|
||||||
std::cout << "ACCURACY: " << 100 * ann.score() << "%" << std::endl;
|
std::cout << "ACCURACY: " << 100 * ann.score() << "%" << std::endl;
|
||||||
}
|
}
|
||||||
void MLPPTests::test_wgan(bool ui) {
|
void MLPPTests::test_wgan_old(bool ui) {
|
||||||
//MLPPStat stat;
|
//MLPPStat stat;
|
||||||
MLPPLinAlg alg;
|
MLPPLinAlg alg;
|
||||||
//MLPPActivation avn;
|
//MLPPActivation avn;
|
||||||
@ -496,6 +496,19 @@ void MLPPTests::test_wgan(bool ui) {
|
|||||||
gan_old.gradientDescent(0.1, 55000, ui);
|
gan_old.gradientDescent(0.1, 55000, ui);
|
||||||
std::cout << "GENERATED INPUT: (Gaussian-sampled noise):" << std::endl;
|
std::cout << "GENERATED INPUT: (Gaussian-sampled noise):" << std::endl;
|
||||||
alg.printMatrix(gan_old.generateExample(100));
|
alg.printMatrix(gan_old.generateExample(100));
|
||||||
|
}
|
||||||
|
void MLPPTests::test_wgan(bool ui) {
|
||||||
|
//MLPPStat stat;
|
||||||
|
MLPPLinAlg alg;
|
||||||
|
//MLPPActivation avn;
|
||||||
|
//MLPPCost cost;
|
||||||
|
//MLPPData data;
|
||||||
|
//MLPPConvolutions conv;
|
||||||
|
|
||||||
|
std::vector<std::vector<real_t>> outputSet = {
|
||||||
|
{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 },
|
||||||
|
{ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 }
|
||||||
|
};
|
||||||
|
|
||||||
Ref<MLPPMatrix> output_set;
|
Ref<MLPPMatrix> output_set;
|
||||||
output_set.instance();
|
output_set.instance();
|
||||||
@ -1250,6 +1263,7 @@ void MLPPTests::_bind_methods() {
|
|||||||
ClassDB::bind_method(D_METHOD("test_soft_max_network", "ui"), &MLPPTests::test_soft_max_network, false);
|
ClassDB::bind_method(D_METHOD("test_soft_max_network", "ui"), &MLPPTests::test_soft_max_network, false);
|
||||||
ClassDB::bind_method(D_METHOD("test_autoencoder", "ui"), &MLPPTests::test_autoencoder, false);
|
ClassDB::bind_method(D_METHOD("test_autoencoder", "ui"), &MLPPTests::test_autoencoder, false);
|
||||||
ClassDB::bind_method(D_METHOD("test_dynamically_sized_ann", "ui"), &MLPPTests::test_dynamically_sized_ann, false);
|
ClassDB::bind_method(D_METHOD("test_dynamically_sized_ann", "ui"), &MLPPTests::test_dynamically_sized_ann, false);
|
||||||
|
ClassDB::bind_method(D_METHOD("test_wgan_old", "ui"), &MLPPTests::test_wgan_old, false);
|
||||||
ClassDB::bind_method(D_METHOD("test_wgan", "ui"), &MLPPTests::test_wgan, false);
|
ClassDB::bind_method(D_METHOD("test_wgan", "ui"), &MLPPTests::test_wgan, false);
|
||||||
ClassDB::bind_method(D_METHOD("test_ann", "ui"), &MLPPTests::test_ann, false);
|
ClassDB::bind_method(D_METHOD("test_ann", "ui"), &MLPPTests::test_ann, false);
|
||||||
ClassDB::bind_method(D_METHOD("test_dynamically_sized_mann", "ui"), &MLPPTests::test_dynamically_sized_mann, false);
|
ClassDB::bind_method(D_METHOD("test_dynamically_sized_mann", "ui"), &MLPPTests::test_dynamically_sized_mann, false);
|
||||||
|
@ -44,6 +44,7 @@ public:
|
|||||||
void test_soft_max_network(bool ui = false);
|
void test_soft_max_network(bool ui = false);
|
||||||
void test_autoencoder(bool ui = false);
|
void test_autoencoder(bool ui = false);
|
||||||
void test_dynamically_sized_ann(bool ui = false);
|
void test_dynamically_sized_ann(bool ui = false);
|
||||||
|
void test_wgan_old(bool ui = false);
|
||||||
void test_wgan(bool ui = false);
|
void test_wgan(bool ui = false);
|
||||||
void test_ann(bool ui = false);
|
void test_ann(bool ui = false);
|
||||||
void test_dynamically_sized_mann(bool ui = false);
|
void test_dynamically_sized_mann(bool ui = false);
|
||||||
|
Loading…
Reference in New Issue
Block a user