Improved performance() methods in MLPPUtilities.

This commit is contained in:
Relintai 2023-12-27 18:40:41 +01:00
parent dda63a53e4
commit d5294823ef

View File

@ -309,13 +309,13 @@ real_t MLPPUtilities::performance_vec(const Ref<MLPPVector> &y_hat, const Ref<ML
ERR_FAIL_COND_V(!y_hat.is_valid(), 0);
ERR_FAIL_COND_V(!output_set.is_valid(), 0);
real_t correct = 0;
int correct = 0;
for (int i = 0; i < y_hat->size(); i++) {
if (Math::is_equal_approx(y_hat->element_get(i), output_set->element_get(i))) {
if (Math::is_equal_approx(Math::round(y_hat->element_get(i)), output_set->element_get(i))) {
correct++;
}
}
return correct / y_hat->size();
return correct / (real_t)y_hat->size();
}
real_t MLPPUtilities::performance_mat(const Ref<MLPPMatrix> &y_hat, const Ref<MLPPMatrix> &y) {
ERR_FAIL_COND_V(!y_hat.is_valid(), 0);
@ -326,7 +326,7 @@ real_t MLPPUtilities::performance_mat(const Ref<MLPPMatrix> &y_hat, const Ref<ML
int sub_correct = 0;
for (int j = 0; j < y_hat->size().x; j++) {
if (Math::round(y_hat->element_get(i, j)) == y->element_get(i, j)) {
if (Math::is_equal_approx(Math::round(y_hat->element_get(i, j)), y->element_get(i, j))) {
sub_correct++;
}
@ -335,7 +335,7 @@ real_t MLPPUtilities::performance_mat(const Ref<MLPPMatrix> &y_hat, const Ref<ML
}
}
}
return correct / y_hat->size().y;
return correct / (real_t)y_hat->size().y;
}
real_t MLPPUtilities::performance_pool_int_array_vec(PoolIntArray y_hat, const Ref<MLPPVector> &output_set) {
ERR_FAIL_COND_V(!output_set.is_valid(), 0);
@ -346,7 +346,7 @@ real_t MLPPUtilities::performance_pool_int_array_vec(PoolIntArray y_hat, const R
correct++;
}
}
return correct / y_hat.size();
return correct / (real_t)y_hat.size();
}
void MLPPUtilities::saveParameters(std::string fileName, std::vector<real_t> weights, real_t bias, bool app, int layer) {