diff --git a/mlpp/wgan/wgan.cpp b/mlpp/wgan/wgan.cpp
index f78234f..dfe0741 100644
--- a/mlpp/wgan/wgan.cpp
+++ b/mlpp/wgan/wgan.cpp
@@ -20,12 +20,6 @@ Ref<MLPPMatrix> MLPPWGAN::get_output_set() {
 }
 
 void MLPPWGAN::set_output_set(const Ref<MLPPMatrix> &val) {
 	_output_set = val;
-
-	_n = 0;
-
-	if (_output_set.is_valid()) {
-		_n = _output_set->size().y;
-	}
 }
 
 int MLPPWGAN::get_k() const {
@@ -43,13 +37,14 @@ void MLPPWGAN::gradient_descent(real_t learning_rate, int max_epoch, bool ui) {
 	//MLPPCost mlpp_cost;
 
 	real_t cost_prev = 0;
 	int epoch = 1;
+	int n = _output_set->size().y;
 
 	forward_pass();
 
 	const int CRITIC_INTERATIONS = 5; // Wasserstein GAN specific parameter.
 
 	while (true) {
-		cost_prev = cost(_y_hat, MLPPVector::create_vec_one(_n));
+		cost_prev = cost(_y_hat, MLPPVector::create_vec_one(n));
 
 		Ref<MLPPMatrix> generator_input_set;
 		Ref<MLPPMatrix> discriminator_input_set;
@@ -60,38 +55,38 @@ void MLPPWGAN::gradient_descent(real_t learning_rate, int max_epoch, bool ui) {
 		// Training of the discriminator.
 		for (int i = 0; i < CRITIC_INTERATIONS; i++) {
-			generator_input_set = MLPPMatrix::create_gaussian_noise(_n, _k);
+			generator_input_set = MLPPMatrix::create_gaussian_noise(n, _k);
 			discriminator_input_set->set_from_mlpp_matrix(model_set_test_generator(generator_input_set));
 			discriminator_input_set->rows_add_mlpp_matrix(_output_set); // Fake + real inputs.
 
 			ly_hat = model_set_test_discriminator(discriminator_input_set);
-			loutput_set = MLPPVector::create_vec_one(_n)->scalar_multiplyn(-1); // WGAN changes y_i = 1 and y_i = 0 to y_i = 1 and y_i = -1
-			Ref<MLPPVector> output_set_real = MLPPVector::create_vec_one(_n);
+			loutput_set = MLPPVector::create_vec_one(n)->scalar_multiplyn(-1); // WGAN changes y_i = 1 and y_i = 0 to y_i = 1 and y_i = -1
+			Ref<MLPPVector> output_set_real = MLPPVector::create_vec_one(n);
 			loutput_set->append_mlpp_vector(output_set_real); // Fake + real output scores.
 
 			DiscriminatorGradientResult discriminator_gradient_results = compute_discriminator_gradients(ly_hat, loutput_set);
 			Ref<MLPPTensor3> cumulative_discriminator_hidden_layer_w_grad = discriminator_gradient_results.cumulative_hidden_layer_w_grad;
 			Ref<MLPPVector> output_discriminator_w_grad = discriminator_gradient_results.output_w_grad;
 
-			cumulative_discriminator_hidden_layer_w_grad->scalar_multiply(learning_rate / _n);
-			output_discriminator_w_grad->scalar_multiply(learning_rate / _n);
+			cumulative_discriminator_hidden_layer_w_grad->scalar_multiply(learning_rate / n);
+			output_discriminator_w_grad->scalar_multiply(learning_rate / n);
 
 			update_discriminator_parameters(cumulative_discriminator_hidden_layer_w_grad, output_discriminator_w_grad, learning_rate);
 		}
 
 		// Training of the generator.
-		generator_input_set = MLPPMatrix::create_gaussian_noise(_n, _k);
+		generator_input_set = MLPPMatrix::create_gaussian_noise(n, _k);
 		discriminator_input_set->set_from_mlpp_matrix(model_set_test_generator(generator_input_set));
 		ly_hat = model_set_test_discriminator(discriminator_input_set);
-		loutput_set = MLPPVector::create_vec_one(_n);
+		loutput_set = MLPPVector::create_vec_one(n);
 
 		Ref<MLPPTensor3> cumulative_generator_hidden_layer_w_grad = compute_generator_gradients(_y_hat, loutput_set);
-		cumulative_generator_hidden_layer_w_grad->scalar_multiply(learning_rate / _n);
+		cumulative_generator_hidden_layer_w_grad->scalar_multiply(learning_rate / n);
 		update_generator_parameters(cumulative_generator_hidden_layer_w_grad, learning_rate);
 
 		forward_pass();
 
 		if (ui) {
-			handle_ui(epoch, cost_prev, _y_hat, MLPPVector::create_vec_one(_n));
+			handle_ui(epoch, cost_prev, _y_hat, MLPPVector::create_vec_one(n));
 		}
 
 		epoch++;
@@ -104,7 +99,9 @@ void MLPPWGAN::gradient_descent(real_t learning_rate, int max_epoch, bool ui) {
 real_t MLPPWGAN::score() {
 	MLPPUtilities util;
 	forward_pass();
-	return util.performance_vec(_y_hat, MLPPVector::create_vec_one(_n));
+	int n = _output_set->size().y;
+
+	return util.performance_vec(_y_hat, MLPPVector::create_vec_one(n));
 }
 
 void MLPPWGAN::save(const String &file_name) {
@@ -134,8 +131,10 @@ void MLPPWGAN::create_layer(int n_hidden, MLPPActivation::ActivationFunction act
 	layer->set_lambda(lambda);
 	layer->set_alpha(alpha);
 
+	int n = _output_set->size().y;
+
 	if (_network.empty()) {
-		layer->set_input(MLPPMatrix::create_gaussian_noise(_n, _k));
+		layer->set_input(MLPPMatrix::create_gaussian_noise(n, _k));
 	} else {
 		layer->set_input(_network.write[_network.size() - 1]->get_a());
 	}
@@ -149,7 +148,9 @@ void MLPPWGAN::add_layer(Ref<MLPPHiddenLayer> layer) {
 	}
 
 	if (_network.empty()) {
-		layer->set_input(MLPPMatrix::create_gaussian_noise(_n, _k));
+		int n = _output_set->size().y;
+
+		layer->set_input(MLPPMatrix::create_gaussian_noise(n, _k));
 	} else {
 		layer->set_input(_network.write[_network.size() - 1]->get_a());
 	}
@@ -187,16 +188,14 @@ void MLPPWGAN::add_output_layer(MLPPUtilities::WeightDistributionType weight_ini
 	_output_layer->set_alpha(alpha);
 }
 
-MLPPWGAN::MLPPWGAN(real_t p_k, const Ref<MLPPMatrix> &p_output_set) {
+MLPPWGAN::MLPPWGAN(int p_k, const Ref<MLPPMatrix> &p_output_set) {
 	_output_set = p_output_set;
-	_n = p_output_set->size().y;
 	_k = p_k;
 
 	_y_hat.instance();
 }
 
 MLPPWGAN::MLPPWGAN() {
-	_n = 0;
 	_k = 0;
 
 	_y_hat.instance();
@@ -256,10 +255,12 @@ real_t MLPPWGAN::cost(const Ref<MLPPVector> &y_hat, const Ref<MLPPVector> &y) {
 }
 
 void MLPPWGAN::forward_pass() {
+	int n = _output_set->size().y;
+
 	if (!_network.empty()) {
 		Ref<MLPPHiddenLayer> layer = _network[0];
 
-		layer->set_input(MLPPMatrix::create_gaussian_noise(_n, _k));
+		layer->set_input(MLPPMatrix::create_gaussian_noise(n, _k));
 		layer->forward_pass();
 
 		for (int i = 1; i < _network.size(); i++) {
@@ -271,7 +272,7 @@
 		_output_layer->set_input(_network.write[_network.size() - 1]->get_a());
 	} else { // Should never happen, though.
-		_output_layer->set_input(MLPPMatrix::create_gaussian_noise(_n, _k));
+		_output_layer->set_input(MLPPMatrix::create_gaussian_noise(n, _k));
 	}
 
 	_output_layer->forward_pass();
@@ -280,8 +281,10 @@
 }
 
 void MLPPWGAN::update_discriminator_parameters(Ref<MLPPTensor3> hidden_layer_updations, const Ref<MLPPVector> &output_layer_updation, real_t learning_rate) {
+	int n = _output_set->size().y;
+
 	_output_layer->set_weights(_output_layer->get_weights()->subn(output_layer_updation));
-	_output_layer->set_bias(_output_layer->get_bias() - learning_rate * _output_layer->get_delta()->sum_elements() / _n);
+	_output_layer->set_bias(_output_layer->get_bias() - learning_rate * _output_layer->get_delta()->sum_elements() / n);
 
 	if (!_network.empty()) {
 		Ref<MLPPHiddenLayer> layer = _network[_network.size() - 1];
@@ -292,7 +295,7 @@ void MLPPWGAN::update_discriminator_parameters(Ref<MLPPTensor3> hidden_layer_upd
 		hidden_layer_updations->z_slice_get_into_mlpp_matrix(0, slice);
 
 		layer->set_weights(layer->get_weights()->subn(slice));
-		layer->set_bias(layer->get_bias()->subtract_matrix_rowsn(layer->get_delta()->scalar_multiplyn(learning_rate / _n)));
+		layer->set_bias(layer->get_bias()->subtract_matrix_rowsn(layer->get_delta()->scalar_multiplyn(learning_rate / n)));
 
 		for (int i = _network.size() - 2; i > _network.size() / 2; i--) {
 			layer = _network[i];
@@ -300,13 +303,15 @@
 			hidden_layer_updations->z_slice_get_into_mlpp_matrix((_network.size() - 2) - i + 1, slice);
 
 			layer->set_weights(layer->get_weights()->subn(slice));
-			layer->set_bias(layer->get_bias()->subtract_matrix_rowsn(layer->get_delta()->scalar_multiplyn(learning_rate / _n)));
+			layer->set_bias(layer->get_bias()->subtract_matrix_rowsn(layer->get_delta()->scalar_multiplyn(learning_rate / n)));
 		}
 	}
 }
 
 void MLPPWGAN::update_generator_parameters(Ref<MLPPTensor3> hidden_layer_updations, real_t learning_rate) {
 	if (!_network.empty()) {
+		int n = _output_set->size().y;
+
 		Ref<MLPPMatrix> slice;
 		slice.instance();
@@ -318,7 +323,7 @@ void MLPPWGAN::update_generator_parameters(Ref<MLPPTensor3> hidden_layer_updatio
 			//std::cout << network[i].weights.size() << "x" << network[i].weights[0].size() << std::endl;
 			//std::cout << hiddenLayerUpdations[(network.size() - 2) - i + 1].size() << "x" << hiddenLayerUpdations[(network.size() - 2) - i + 1][0].size() << std::endl;
 			layer->set_weights(layer->get_weights()->subn(slice));
-			layer->set_bias(layer->get_bias()->subtract_matrix_rowsn(layer->get_delta()->scalar_multiplyn(learning_rate / _n)));
+			layer->set_bias(layer->get_bias()->subtract_matrix_rowsn(layer->get_delta()->scalar_multiplyn(learning_rate / n)));
 		}
 	}
 }
diff --git a/mlpp/wgan/wgan.h b/mlpp/wgan/wgan.h
index d114ed8..10cf596 100644
--- a/mlpp/wgan/wgan.h
+++ b/mlpp/wgan/wgan.h
@@ -49,7 +49,7 @@ public:
 	void add_output_layer(MLPPUtilities::WeightDistributionType weight_init = MLPPUtilities::WEIGHT_DISTRIBUTION_TYPE_DEFAULT, MLPPReg::RegularizationType reg = MLPPReg::REGULARIZATION_TYPE_NONE, real_t lambda = 0.5, real_t alpha = 0.5);
 
-	MLPPWGAN(real_t k, const Ref<MLPPMatrix> &output_set);
+	MLPPWGAN(int k, const Ref<MLPPMatrix> &output_set);
 	MLPPWGAN();
 	~MLPPWGAN();
 
@@ -82,13 +82,12 @@ protected:
 	static void _bind_methods();
 
 	Ref<MLPPMatrix> _output_set;
-	Ref<MLPPVector> _y_hat;
+	int _k;
 
 	Vector<Ref<MLPPHiddenLayer>> _network;
 	Ref<MLPPOutputLayer> _output_layer;
 
-	int _n;
-	int _k;
+	Ref<MLPPVector> _y_hat;
 };
 
 #endif /* WGAN_hpp */
\ No newline at end of file
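
Reviewer note: the behavioral core of this patch is replacing the cached sample count _n (previously maintained in set_output_set and the constructor) with a local int n = _output_set->size().y computed at each point of use, so the count can never go stale if the output set is swapped after construction; the constructor's k parameter also becomes an int, matching how it is used as a matrix dimension. A minimal, self-contained sketch of that pattern follows. OutputSet and WganLike are hypothetical stand-ins written for illustration, not the module's real classes.

#include <cstdio>
#include <vector>

// Hypothetical stand-in for MLPPMatrix in its role as the real-sample set.
struct OutputSet {
	std::vector<std::vector<double>> rows; // one row per real sample
	int size_y() const { return (int)rows.size(); }
};

// Hypothetical stand-in for MLPPWGAN, reduced to the pattern this patch adopts.
class WganLike {
public:
	WganLike(int k, OutputSet *output_set) :
			_k(k), _output_set(output_set) {}

	void train_step() {
		// As in the patched gradient_descent()/forward_pass(): n is derived
		// from the output set at the point of use, never cached in a member.
		int n = _output_set->size_y();
		std::printf("gaussian noise batch: %d x %d\n", n, _k);
	}

private:
	int _k; // latent size; an int after this patch (was real_t)
	OutputSet *_output_set;
};

int main() {
	OutputSet data{ { { 0.1, 0.2 }, { 0.3, 0.4 } } };
	WganLike gan(5, &data);
	gan.train_step(); // prints "gaussian noise batch: 2 x 5"

	data.rows.push_back({ 0.5, 0.6 }); // grow the real-sample set
	gan.train_step(); // prints "gaussian noise batch: 3 x 5", no stale count
	return 0;
}

The trade-off is one size lookup per call instead of a cached int, which is negligible next to the matrix work in these functions; in exchange, set_output_set no longer needs the _n bookkeeping this patch deletes.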