Fix logic in MLPPSoftmaxReg.

This commit is contained in:
Relintai 2023-12-27 18:15:34 +01:00
parent 44893d7aae
commit 681118c9d2

View File

@ -273,7 +273,7 @@ bool MLPPSoftmaxReg::needs_init() const {
int k = _input_set->size().x; int k = _input_set->size().x;
int n_class = _output_set->size().x; int n_class = _output_set->size().x;
if (_y_hat->size().x != n) { if (_y_hat->size().y != n) {
return true; return true;
} }
@ -294,7 +294,7 @@ void MLPPSoftmaxReg::initialize() {
int k = _input_set->size().x; int k = _input_set->size().x;
int n_class = _output_set->size().x; int n_class = _output_set->size().x;
_y_hat->resize(Size2i(n, 0)); _y_hat->resize(Size2i(0, n));
MLPPUtilities util; MLPPUtilities util;
@ -315,6 +315,8 @@ MLPPSoftmaxReg::MLPPSoftmaxReg(const Ref<MLPPMatrix> &p_input_set, const Ref<MLP
_y_hat.instance(); _y_hat.instance();
_weights.instance(); _weights.instance();
_bias.instance(); _bias.instance();
initialize();
} }
MLPPSoftmaxReg::MLPPSoftmaxReg() { MLPPSoftmaxReg::MLPPSoftmaxReg() {