diff --git a/mlpp/bernoulli_nb/bernoulli_nb.cpp b/mlpp/bernoulli_nb/bernoulli_nb.cpp index 7333fe5..6c491aa 100644 --- a/mlpp/bernoulli_nb/bernoulli_nb.cpp +++ b/mlpp/bernoulli_nb/bernoulli_nb.cpp @@ -85,7 +85,9 @@ MLPPBernoulliNB::MLPPBernoulliNB(const Ref &p_input_set, const Refresize(_output_set->size()); evaluate(); @@ -157,7 +159,7 @@ void MLPPBernoulliNB::evaluate() { Vector found_indices; - for (int j = 0; j < _input_set->size().y; j++) { + for (int j = 0; j < _input_set->size().x; j++) { for (int k = 0; k < _vocab->size(); k++) { if (_input_set->element_get(i, j) == _vocab->element_get(k)) { score_0 += Math::log(static_cast(_theta[0][_vocab->element_get(k)])); diff --git a/mlpp/multinomial_nb/multinomial_nb.cpp b/mlpp/multinomial_nb/multinomial_nb.cpp index a8c017e..58b8795 100644 --- a/mlpp/multinomial_nb/multinomial_nb.cpp +++ b/mlpp/multinomial_nb/multinomial_nb.cpp @@ -135,7 +135,10 @@ MLPPMultinomialNB::MLPPMultinomialNB(const Ref &p_input_set, const R _output_set = p_output_set; _class_num = pclass_num; + _priors.instance(); + _vocab.instance(); _y_hat.instance(); + _y_hat->resize(_output_set->size()); _initialized = true;