mirror of
https://github.com/Relintai/pmlpp.git
synced 2024-12-22 15:06:47 +01:00
Fixed MLPPWGAN's handle_ui.
This commit is contained in:
parent
7e71ea49d9
commit
f8710e6432
@ -5,6 +5,9 @@
|
||||
//
|
||||
|
||||
#include "wgan.h"
|
||||
|
||||
#include "core/log/logger.h"
|
||||
|
||||
#include "../activation/activation.h"
|
||||
#include "../cost/cost.h"
|
||||
#include "../lin_alg/lin_alg.h"
|
||||
@ -314,6 +317,7 @@ MLPPWGAN::DiscriminatorGradientResult MLPPWGAN::compute_discriminator_gradients(
|
||||
Ref<MLPPHiddenLayer> layer = network[network.size() - 1];
|
||||
|
||||
layer->set_delta(alg.hadamard_productm(alg.outer_product(output_layer->get_delta(), output_layer->get_weights()), avn.run_activation_deriv_matrix(layer->get_activation(), layer->get_z())));
|
||||
|
||||
Ref<MLPPMatrix> hidden_layer_w_grad = alg.matmultm(alg.transposem(layer->get_input()), layer->get_delta());
|
||||
|
||||
data.cumulative_hidden_layer_w_grad.push_back(alg.additionm(hidden_layer_w_grad, regularization.reg_deriv_termm(layer->get_weights(), layer->get_lambda(), layer->get_alpha(), layer->get_reg()))); // Adding to our cumulative hidden layer grads. Maintain reg terms as well.
|
||||
@ -378,7 +382,8 @@ Vector<Ref<MLPPMatrix>> MLPPWGAN::compute_generator_gradients(const Ref<MLPPVect
|
||||
|
||||
void MLPPWGAN::handle_ui(int epoch, real_t cost_prev, const Ref<MLPPVector> &y_hat, const Ref<MLPPVector> &output_set) {
|
||||
MLPPUtilities::cost_info(epoch, cost_prev, cost(y_hat, output_set));
|
||||
std::cout << "Layer " << network.size() + 1 << ": " << std::endl;
|
||||
|
||||
PLOG_MSG("Layer " + itos(network.size() + 1) + ":");
|
||||
|
||||
MLPPUtilities::print_ui_vb(output_layer->get_weights(), output_layer->get_bias());
|
||||
|
||||
@ -386,9 +391,9 @@ void MLPPWGAN::handle_ui(int epoch, real_t cost_prev, const Ref<MLPPVector> &y_h
|
||||
for (int i = network.size() - 1; i >= 0; i--) {
|
||||
Ref<MLPPHiddenLayer> layer = network[i];
|
||||
|
||||
std::cout << "Layer " << i + 1 << ": " << std::endl;
|
||||
PLOG_MSG("Layer " + itos(i + 1) + ":");
|
||||
|
||||
MLPPUtilities::print_ui_vib(layer->get_weights(), layer->get_bias(), 0);
|
||||
MLPPUtilities::print_ui_mb(layer->get_weights(), layer->get_bias());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user