From fe4ac8625c2aefc8f4c3a68663b0d1796797552a Mon Sep 17 00:00:00 2001 From: Relintai Date: Thu, 28 Dec 2023 11:30:56 +0100 Subject: [PATCH] Fixed typos in MLPPData::train_test_split. --- mlpp/data/data.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/mlpp/data/data.cpp b/mlpp/data/data.cpp index c07e717..2ccac6e 100644 --- a/mlpp/data/data.cpp +++ b/mlpp/data/data.cpp @@ -335,6 +335,9 @@ MLPPData::SplitComplexData MLPPData::train_test_split(Ref data, Ref orig_input = data->get_input(); Ref orig_output = data->get_output(); + ERR_FAIL_COND_V(!orig_input.is_valid(), res); + ERR_FAIL_COND_V(!orig_output.is_valid(), res); + Size2i orig_input_size = orig_input->size(); Size2i orig_output_size = orig_output->size(); @@ -371,8 +374,8 @@ MLPPData::SplitComplexData MLPPData::train_test_split(Ref data, orig_input->row_get_into_mlpp_vector(index, orig_input_row_tmp); orig_output->row_get_into_mlpp_vector(index, orig_output_row_tmp); - res_test_input->row_set_mlpp_vector(i, orig_input); - res_test_output->row_set_mlpp_vector(i, orig_output); + res_test_input->row_set_mlpp_vector(i, orig_input_row_tmp); + res_test_output->row_set_mlpp_vector(i, orig_output_row_tmp); } Ref res_train_input = res.train->get_input(); @@ -384,13 +387,13 @@ MLPPData::SplitComplexData MLPPData::train_test_split(Ref data, res_train_output->resize(Size2i(orig_output_size.x, train_input_number)); for (int i = 0; i < train_input_number; ++i) { - int index = indices[train_input_number + i]; + int index = indices[test_input_number + i]; orig_input->row_get_into_mlpp_vector(index, orig_input_row_tmp); orig_output->row_get_into_mlpp_vector(index, orig_output_row_tmp); - res_train_input->row_set_mlpp_vector(i, orig_input); - res_train_output->row_set_mlpp_vector(i, orig_output); + res_train_input->row_set_mlpp_vector(i, orig_input_row_tmp); + res_train_output->row_set_mlpp_vector(i, orig_output_row_tmp); } return res;