Fixed typos in MLPPData::train_test_split.

This commit is contained in:
Relintai 2023-12-28 11:30:56 +01:00
parent 3f865aab1d
commit fe4ac8625c

View File

@ -335,6 +335,9 @@ MLPPData::SplitComplexData MLPPData::train_test_split(Ref<MLPPDataComplex> data,
Ref<MLPPMatrix> orig_input = data->get_input();
Ref<MLPPMatrix> orig_output = data->get_output();
ERR_FAIL_COND_V(!orig_input.is_valid(), res);
ERR_FAIL_COND_V(!orig_output.is_valid(), res);
Size2i orig_input_size = orig_input->size();
Size2i orig_output_size = orig_output->size();
@ -371,8 +374,8 @@ MLPPData::SplitComplexData MLPPData::train_test_split(Ref<MLPPDataComplex> data,
orig_input->row_get_into_mlpp_vector(index, orig_input_row_tmp);
orig_output->row_get_into_mlpp_vector(index, orig_output_row_tmp);
res_test_input->row_set_mlpp_vector(i, orig_input);
res_test_output->row_set_mlpp_vector(i, orig_output);
res_test_input->row_set_mlpp_vector(i, orig_input_row_tmp);
res_test_output->row_set_mlpp_vector(i, orig_output_row_tmp);
}
Ref<MLPPMatrix> res_train_input = res.train->get_input();
@ -384,13 +387,13 @@ MLPPData::SplitComplexData MLPPData::train_test_split(Ref<MLPPDataComplex> data,
res_train_output->resize(Size2i(orig_output_size.x, train_input_number));
for (int i = 0; i < train_input_number; ++i) {
int index = indices[train_input_number + i];
int index = indices[test_input_number + i];
orig_input->row_get_into_mlpp_vector(index, orig_input_row_tmp);
orig_output->row_get_into_mlpp_vector(index, orig_output_row_tmp);
res_train_input->row_set_mlpp_vector(i, orig_input);
res_train_output->row_set_mlpp_vector(i, orig_output);
res_train_input->row_set_mlpp_vector(i, orig_input_row_tmp);
res_train_output->row_set_mlpp_vector(i, orig_output_row_tmp);
}
return res;