From c227786c40dd63845d604d0585a49aa48bcb5534 Mon Sep 17 00:00:00 2001 From: Relintai Date: Wed, 27 Dec 2023 23:23:50 +0100 Subject: [PATCH] Fixed Wgan. --- mlpp/wgan/wgan.cpp | 46 +++++++++++++++++++++++++++------------------- mlpp/wgan/wgan.h | 9 ++++----- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/mlpp/wgan/wgan.cpp b/mlpp/wgan/wgan.cpp index dfe0741..3982a12 100644 --- a/mlpp/wgan/wgan.cpp +++ b/mlpp/wgan/wgan.cpp @@ -65,10 +65,15 @@ void MLPPWGAN::gradient_descent(real_t learning_rate, int max_epoch, bool ui) { loutput_set->append_mlpp_vector(output_set_real); // Fake + real output scores. DiscriminatorGradientResult discriminator_gradient_results = compute_discriminator_gradients(ly_hat, loutput_set); - Ref cumulative_discriminator_hidden_layer_w_grad = discriminator_gradient_results.cumulative_hidden_layer_w_grad; + Vector> cumulative_discriminator_hidden_layer_w_grad = discriminator_gradient_results.cumulative_hidden_layer_w_grad; Ref output_discriminator_w_grad = discriminator_gradient_results.output_w_grad; - cumulative_discriminator_hidden_layer_w_grad->scalar_multiply(learning_rate / n); + real_t lrpn = learning_rate / n; + + for (int j = 0; j < cumulative_discriminator_hidden_layer_w_grad.size(); ++j) { + cumulative_discriminator_hidden_layer_w_grad.write[j]->scalar_multiply(lrpn); + } + output_discriminator_w_grad->scalar_multiply(learning_rate / n); update_discriminator_parameters(cumulative_discriminator_hidden_layer_w_grad, output_discriminator_w_grad, learning_rate); } @@ -79,8 +84,14 @@ void MLPPWGAN::gradient_descent(real_t learning_rate, int max_epoch, bool ui) { ly_hat = model_set_test_discriminator(discriminator_input_set); loutput_set = MLPPVector::create_vec_one(n); - Ref cumulative_generator_hidden_layer_w_grad = compute_generator_gradients(_y_hat, loutput_set); - cumulative_generator_hidden_layer_w_grad->scalar_multiply(learning_rate / n); + Vector> cumulative_generator_hidden_layer_w_grad = compute_generator_gradients(_y_hat, loutput_set); + + real_t lrpn = learning_rate / n; + + for (int i = 0; i < cumulative_generator_hidden_layer_w_grad.size(); ++i) { + cumulative_generator_hidden_layer_w_grad.write[i]->scalar_multiply(lrpn); + } + update_generator_parameters(cumulative_generator_hidden_layer_w_grad, learning_rate); forward_pass(); @@ -280,7 +291,7 @@ void MLPPWGAN::forward_pass() { _y_hat->set_from_mlpp_vector(_output_layer->get_a()); } -void MLPPWGAN::update_discriminator_parameters(Ref hidden_layer_updations, const Ref &output_layer_updation, real_t learning_rate) { +void MLPPWGAN::update_discriminator_parameters(const Vector> &hidden_layer_updations, const Ref &output_layer_updation, real_t learning_rate) { int n = _output_set->size().y; _output_layer->set_weights(_output_layer->get_weights()->subn(output_layer_updation)); @@ -289,10 +300,7 @@ void MLPPWGAN::update_discriminator_parameters(Ref hidden_layer_upd if (!_network.empty()) { Ref layer = _network[_network.size() - 1]; - Ref slice; - slice.instance(); - - hidden_layer_updations->z_slice_get_into_mlpp_matrix(0, slice); + Ref slice = hidden_layer_updations[0]; layer->set_weights(layer->get_weights()->subn(slice)); layer->set_bias(layer->get_bias()->subtract_matrix_rowsn(layer->get_delta()->scalar_multiplyn(learning_rate / n))); @@ -300,7 +308,7 @@ void MLPPWGAN::update_discriminator_parameters(Ref hidden_layer_upd for (int i = _network.size() - 2; i > _network.size() / 2; i--) { layer = _network[i]; - hidden_layer_updations->z_slice_get_into_mlpp_matrix((_network.size() - 2) - i + 1, slice); + slice = hidden_layer_updations[(_network.size() - 2) - i + 1]; layer->set_weights(layer->get_weights()->subn(slice)); layer->set_bias(layer->get_bias()->subtract_matrix_rowsn(layer->get_delta()->scalar_multiplyn(learning_rate / n))); @@ -308,17 +316,16 @@ void MLPPWGAN::update_discriminator_parameters(Ref hidden_layer_upd } } -void MLPPWGAN::update_generator_parameters(Ref hidden_layer_updations, real_t learning_rate) { +void MLPPWGAN::update_generator_parameters(const Vector> &hidden_layer_updations, real_t learning_rate) { if (!_network.empty()) { int n = _output_set->size().y; Ref slice; - slice.instance(); for (int i = _network.size() / 2; i >= 0; i--) { Ref layer = _network[i]; - hidden_layer_updations->z_slice_get_into_mlpp_matrix((_network.size() - 2) - i + 1, slice); + slice = hidden_layer_updations[(_network.size() - 2) - i + 1]; //std::cout << network[i].weights.size() << "x" << network[i].weights[0].size() << std::endl; //std::cout << hiddenLayerUpdations[(network.size() - 2) - i + 1].size() << "x" << hiddenLayerUpdations[(network.size() - 2) - i + 1][0].size() << std::endl; @@ -347,7 +354,7 @@ MLPPWGAN::DiscriminatorGradientResult MLPPWGAN::compute_discriminator_gradients( Ref hidden_layer_w_grad = layer->get_input()->transposen()->multn(layer->get_delta()); - data.cumulative_hidden_layer_w_grad->z_slice_add_mlpp_matrix(hidden_layer_w_grad->addn(regularization.reg_deriv_termm(layer->get_weights(), layer->get_lambda(), layer->get_alpha(), layer->get_reg()))); // Adding to our cumulative hidden layer grads. Maintain reg terms as well. + data.cumulative_hidden_layer_w_grad.push_back(hidden_layer_w_grad->addn(regularization.reg_deriv_termm(layer->get_weights(), layer->get_lambda(), layer->get_alpha(), layer->get_reg()))); // Adding to our cumulative hidden layer grads. Maintain reg terms as well. //std::cout << "HIDDENLAYER FIRST:" << hiddenLayerWGrad.size() << "x" << hiddenLayerWGrad[0].size() << std::endl; //std::cout << "WEIGHTS SECOND:" << layer.weights.size() << "x" << layer.weights[0].size() << std::endl; @@ -359,19 +366,19 @@ MLPPWGAN::DiscriminatorGradientResult MLPPWGAN::compute_discriminator_gradients( layer->set_delta(next_layer->get_delta()->multn(next_layer->get_weights()->transposen())->hadamard_productn(avn.run_activation_deriv_matrix(layer->get_activation(), layer->get_z()))); hidden_layer_w_grad = layer->get_input()->transposen()->multn(layer->get_delta()); - data.cumulative_hidden_layer_w_grad->z_slice_add_mlpp_matrix(hidden_layer_w_grad->addn(regularization.reg_deriv_termm(layer->get_weights(), layer->get_lambda(), layer->get_alpha(), layer->get_reg()))); // Adding to our cumulative hidden layer grads. Maintain reg terms as well. + data.cumulative_hidden_layer_w_grad.push_back(hidden_layer_w_grad->addn(regularization.reg_deriv_termm(layer->get_weights(), layer->get_lambda(), layer->get_alpha(), layer->get_reg()))); // Adding to our cumulative hidden layer grads. Maintain reg terms as well. } } return data; } -Ref MLPPWGAN::compute_generator_gradients(const Ref &y_hat, const Ref &output_set) { +Vector> MLPPWGAN::compute_generator_gradients(const Ref &y_hat, const Ref &output_set) { class MLPPCost cost; MLPPActivation avn; MLPPReg regularization; - Ref cumulative_hidden_layer_w_grad; // Tensor containing ALL hidden grads. + Vector> cumulative_hidden_layer_w_grad; // Tensor containing ALL hidden grads. Ref cost_deriv_vector = cost.run_cost_deriv_vector(_output_layer->get_cost(), y_hat, output_set); Ref activation_deriv_vector = avn.run_activation_deriv_vector(_output_layer->get_activation(), _output_layer->get_z()); @@ -388,7 +395,8 @@ Ref MLPPWGAN::compute_generator_gradients(const Ref &y_ layer->set_delta(_output_layer->get_delta()->outer_product(_output_layer->get_weights())->hadamard_productn(activation_deriv_matrix)); Ref hidden_layer_w_grad = layer->get_input()->transposen()->multn(layer->get_delta()); - cumulative_hidden_layer_w_grad->z_slice_add_mlpp_matrix(hidden_layer_w_grad->addn(regularization.reg_deriv_termm(layer->get_weights(), layer->get_lambda(), layer->get_alpha(), layer->get_reg()))); // Adding to our cumulative hidden layer grads. Maintain reg terms as well. + + cumulative_hidden_layer_w_grad.push_back(hidden_layer_w_grad->addn(regularization.reg_deriv_termm(layer->get_weights(), layer->get_lambda(), layer->get_alpha(), layer->get_reg()))); // Adding to our cumulative hidden layer grads. Maintain reg terms as well. for (int i = _network.size() - 2; i >= 0; i--) { layer = _network[i]; @@ -399,7 +407,7 @@ Ref MLPPWGAN::compute_generator_gradients(const Ref &y_ layer->set_delta(next_layer->get_delta()->multn(next_layer->get_weights()->transposen())->hadamard_productn(activation_deriv_matrix)); hidden_layer_w_grad = layer->get_input()->transposen()->multn(layer->get_delta()); - cumulative_hidden_layer_w_grad->z_slice_add_mlpp_matrix(hidden_layer_w_grad->addn(regularization.reg_deriv_termm(layer->get_weights(), layer->get_lambda(), layer->get_alpha(), layer->get_reg()))); // Adding to our cumulative hidden layer grads. Maintain reg terms as well. + cumulative_hidden_layer_w_grad.push_back(hidden_layer_w_grad->addn(regularization.reg_deriv_termm(layer->get_weights(), layer->get_lambda(), layer->get_alpha(), layer->get_reg()))); // Adding to our cumulative hidden layer grads. Maintain reg terms as well. } } diff --git a/mlpp/wgan/wgan.h b/mlpp/wgan/wgan.h index 10cf596..afc6dd8 100644 --- a/mlpp/wgan/wgan.h +++ b/mlpp/wgan/wgan.h @@ -61,21 +61,20 @@ protected: real_t cost(const Ref &y_hat, const Ref &y); void forward_pass(); - void update_discriminator_parameters(Ref hidden_layer_updations, const Ref &output_layer_updation, real_t learning_rate); - void update_generator_parameters(Ref hidden_layer_updations, real_t learning_rate); + void update_discriminator_parameters(const Vector> &hidden_layer_updations, const Ref &output_layer_updation, real_t learning_rate); + void update_generator_parameters(const Vector> &hidden_layer_updations, real_t learning_rate); struct DiscriminatorGradientResult { - Ref cumulative_hidden_layer_w_grad; // Tensor containing ALL hidden grads. + Vector> cumulative_hidden_layer_w_grad; // Tensor containing ALL hidden grads. Ref output_w_grad; DiscriminatorGradientResult() { - cumulative_hidden_layer_w_grad.instance(); output_w_grad.instance(); } }; DiscriminatorGradientResult compute_discriminator_gradients(const Ref &y_hat, const Ref &output_set); - Ref compute_generator_gradients(const Ref &y_hat, const Ref &output_set); + Vector> compute_generator_gradients(const Ref &y_hat, const Ref &output_set); void handle_ui(int epoch, real_t cost_prev, const Ref &y_hat, const Ref &output_set);