diff --git a/mlpp/utilities/utilities.cpp b/mlpp/utilities/utilities.cpp index a9feca5..dbd0308 100644 --- a/mlpp/utilities/utilities.cpp +++ b/mlpp/utilities/utilities.cpp @@ -568,6 +568,189 @@ std::tuple>>, std::vector> MLPPUtilities::create_mini_batchesm(const Ref &input_set, int n_mini_batch) { + Size2i size = input_set->size(); + + int n = size.y; + int mini_batch_element_count = n / n_mini_batch; + + Ref row_tmp; + row_tmp.instance(); + row_tmp->resize(size.x); + + Vector> input_mini_batches; + + // Creating the mini-batches + for (int i = 0; i < n_mini_batch; i++) { + int mini_batch_start_offset = n_mini_batch * i; + Ref current_input_set; + current_input_set.instance(); + current_input_set->resize(Size2i(size.x, mini_batch_element_count)); + + for (int j = 0; j < mini_batch_element_count; j++) { + input_set->get_row_into_mlpp_vector(mini_batch_start_offset + j, row_tmp); + current_input_set->set_row_mlpp_vector(j, row_tmp); + } + + input_mini_batches.push_back(current_input_set); + } + + /* Don't think this can ever happen, todo double check + if (real_t(n) / real_t(n_mini_batch) - int(n / n_mini_batch) != 0) { + for (int i = 0; i < n - n / n_mini_batch * n_mini_batch; i++) { + inputMiniBatches[n_mini_batch - 1].push_back(inputSet[n_mini_batch * n_mini_batch + i]); + } + } + */ + + return input_mini_batches; +} +MLPPUtilities::CreateMiniBatchMVBatch MLPPUtilities::create_mini_batchesmv(const Ref &input_set, const Ref &output_set, int n_mini_batch) { + Size2i size = input_set->size(); + + int n = size.y; + int mini_batch_element_count = n / n_mini_batch; + + Ref row_tmp; + row_tmp.instance(); + row_tmp->resize(size.x); + + CreateMiniBatchMVBatch ret; + + for (int i = 0; i < n_mini_batch; i++) { + int mini_batch_start_offset = n_mini_batch * i; + Ref current_input_set; + current_input_set.instance(); + current_input_set->resize(Size2i(size.x, mini_batch_element_count)); + + Ref current_output_set; + current_output_set.instance(); + current_output_set->resize(mini_batch_element_count); + + for (int j = 0; j < mini_batch_element_count; j++) { + int main_indx = mini_batch_start_offset + j; + + input_set->get_row_into_mlpp_vector(main_indx, row_tmp); + current_input_set->set_row_mlpp_vector(j, row_tmp); + + current_output_set->set_element(j, output_set->get_element(j)); + } + + ret.input_sets.push_back(current_input_set); + ret.output_sets.push_back(current_output_set); + } + + /* Don't think this can ever happen, todo double check + if (real_t(n) / real_t(n_mini_batch) - int(n / n_mini_batch) != 0) { + for (int i = 0; i < n - n / n_mini_batch * n_mini_batch; i++) { + inputMiniBatches[n_mini_batch - 1].push_back(inputSet[n / n_mini_batch * n_mini_batch + i]); + outputMiniBatches[n_mini_batch - 1].push_back(outputSet[n / n_mini_batch * n_mini_batch + i]); + } + } + */ + + return ret; +} +MLPPUtilities::CreateMiniBatchMMBatch MLPPUtilities::create_mini_batchesmm(const Ref &input_set, const Ref &output_set, int n_mini_batch) { + Size2i input_set_size = input_set->size(); + Size2i output_set_size = output_set->size(); + + int n = input_set_size.y; + int mini_batch_element_count = n / n_mini_batch; + + Ref input_row_tmp; + input_row_tmp.instance(); + input_row_tmp->resize(input_set_size.x); + + Ref output_row_tmp; + output_row_tmp.instance(); + output_row_tmp->resize(output_set_size.x); + + CreateMiniBatchMMBatch ret; + + for (int i = 0; i < n_mini_batch; i++) { + int mini_batch_start_offset = n_mini_batch * i; + Ref current_input_set; + current_input_set.instance(); + current_input_set->resize(Size2i(input_set_size.x, mini_batch_element_count)); + + Ref current_output_set; + current_output_set.instance(); + current_output_set->resize(Size2i(output_set_size.x, mini_batch_element_count)); + + for (int j = 0; j < mini_batch_element_count; j++) { + int main_indx = mini_batch_start_offset + j; + + input_set->get_row_into_mlpp_vector(main_indx, input_row_tmp); + current_input_set->set_row_mlpp_vector(j, input_row_tmp); + + output_set->get_row_into_mlpp_vector(main_indx, output_row_tmp); + current_output_set->set_row_mlpp_vector(j, output_row_tmp); + } + + ret.input_sets.push_back(current_input_set); + ret.output_sets.push_back(current_output_set); + } + + /* Don't think this can ever happen, todo double check + if (real_t(n) / real_t(n_mini_batch) - int(n / n_mini_batch) != 0) { + for (int i = 0; i < n - n / n_mini_batch * n_mini_batch; i++) { + inputMiniBatches[n_mini_batch - 1].push_back(inputSet[n / n_mini_batch * n_mini_batch + i]); + } + } + */ + + return ret; +} + +Array MLPPUtilities::create_mini_batchesm_bind(const Ref &input_set, int n_mini_batch) { + Vector> batches = create_mini_batchesm(input_set, n_mini_batch); + + Array ret; + + for (int i = 0; i < batches.size(); ++i) { + ret.push_back(batches[i].get_ref_ptr()); + } + + return ret; +} +Array MLPPUtilities::create_mini_batchesmv_bind(const Ref &input_set, const Ref &output_set, int n_mini_batch) { + CreateMiniBatchMVBatch batches = create_mini_batchesmv(input_set, output_set, n_mini_batch); + + Array inputs; + Array outputs; + + for (int i = 0; i < batches.input_sets.size(); ++i) { + inputs.push_back(batches.input_sets[i].get_ref_ptr()); + outputs.push_back(batches.output_sets[i].get_ref_ptr()); + } + + Array ret; + + ret.push_back(inputs); + ret.push_back(outputs); + + return ret; +} +Array MLPPUtilities::create_mini_batchesmm_bind(const Ref &input_set, const Ref &output_set, int n_mini_batch) { + CreateMiniBatchMMBatch batches = create_mini_batchesmm(input_set, output_set, n_mini_batch); + + Array inputs; + Array outputs; + + for (int i = 0; i < batches.input_sets.size(); ++i) { + inputs.push_back(batches.input_sets[i].get_ref_ptr()); + outputs.push_back(batches.output_sets[i].get_ref_ptr()); + } + + Array ret; + + ret.push_back(inputs); + ret.push_back(outputs); + + return ret; +} + std::tuple MLPPUtilities::TF_PN(std::vector y_hat, std::vector y) { real_t TP, FP, TN, FN = 0; for (int i = 0; i < y_hat.size(); i++) { @@ -616,6 +799,10 @@ void MLPPUtilities::_bind_methods() { ClassDB::bind_method(D_METHOD("performance_mat", "y_hat", "y"), &MLPPUtilities::performance_mat); ClassDB::bind_method(D_METHOD("performance_pool_int_array_vec", "y_hat", "output_set"), &MLPPUtilities::performance_pool_int_array_vec); + ClassDB::bind_method(D_METHOD("create_mini_batchesm", "input_set", "n_mini_batch"), &MLPPUtilities::create_mini_batchesm_bind); + ClassDB::bind_method(D_METHOD("create_mini_batchesmv", "input_set", "output_set", "n_mini_batch"), &MLPPUtilities::create_mini_batchesmv_bind); + ClassDB::bind_method(D_METHOD("create_mini_batchesmm", "input_set", "output_set", "n_mini_batch"), &MLPPUtilities::create_mini_batchesmm_bind); + BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_DEFAULT); BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_XAVIER_NORMAL); BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_XAVIER_UNIFORM); diff --git a/mlpp/utilities/utilities.h b/mlpp/utilities/utilities.h index 9572c51..30fd465 100644 --- a/mlpp/utilities/utilities.h +++ b/mlpp/utilities/utilities.h @@ -73,6 +73,24 @@ public: static std::tuple>>, std::vector>> createMiniBatches(std::vector> inputSet, std::vector outputSet, int n_mini_batch); static std::tuple>>, std::vector>>> createMiniBatches(std::vector> inputSet, std::vector> outputSet, int n_mini_batch); + struct CreateMiniBatchMVBatch { + Vector> input_sets; + Vector> output_sets; + }; + + struct CreateMiniBatchMMBatch { + Vector> input_sets; + Vector> output_sets; + }; + + static Vector> create_mini_batchesm(const Ref &input_set, int n_mini_batch); + static CreateMiniBatchMVBatch create_mini_batchesmv(const Ref &input_set, const Ref &output_set, int n_mini_batch); + static CreateMiniBatchMMBatch create_mini_batchesmm(const Ref &input_set, const Ref &output_set, int n_mini_batch); + + Array create_mini_batchesm_bind(const Ref &input_set, int n_mini_batch); + Array create_mini_batchesmv_bind(const Ref &input_set, const Ref &output_set, int n_mini_batch); + Array create_mini_batchesmm_bind(const Ref &input_set, const Ref &output_set, int n_mini_batch); + // F1 score, Precision/Recall, TP, FP, TN, FN, etc. std::tuple TF_PN(std::vector y_hat, std::vector y); //TF_PN = "True", "False", "Positive", "Negative" real_t recall(std::vector y_hat, std::vector y);