mirror of
https://github.com/Relintai/pmlpp.git
synced 2025-02-01 17:07:02 +01:00
Ported mini batch creation methods in Utilities.
This commit is contained in:
parent
5f63aebc99
commit
7581be0e7f
@ -568,6 +568,189 @@ std::tuple<std::vector<std::vector<std::vector<real_t>>>, std::vector<std::vecto
|
||||
return { inputMiniBatches, outputMiniBatches };
|
||||
}
|
||||
|
||||
Vector<Ref<MLPPMatrix>> MLPPUtilities::create_mini_batchesm(const Ref<MLPPMatrix> &input_set, int n_mini_batch) {
|
||||
Size2i size = input_set->size();
|
||||
|
||||
int n = size.y;
|
||||
int mini_batch_element_count = n / n_mini_batch;
|
||||
|
||||
Ref<MLPPVector> row_tmp;
|
||||
row_tmp.instance();
|
||||
row_tmp->resize(size.x);
|
||||
|
||||
Vector<Ref<MLPPMatrix>> input_mini_batches;
|
||||
|
||||
// Creating the mini-batches
|
||||
for (int i = 0; i < n_mini_batch; i++) {
|
||||
int mini_batch_start_offset = n_mini_batch * i;
|
||||
Ref<MLPPMatrix> current_input_set;
|
||||
current_input_set.instance();
|
||||
current_input_set->resize(Size2i(size.x, mini_batch_element_count));
|
||||
|
||||
for (int j = 0; j < mini_batch_element_count; j++) {
|
||||
input_set->get_row_into_mlpp_vector(mini_batch_start_offset + j, row_tmp);
|
||||
current_input_set->set_row_mlpp_vector(j, row_tmp);
|
||||
}
|
||||
|
||||
input_mini_batches.push_back(current_input_set);
|
||||
}
|
||||
|
||||
/* Don't think this can ever happen, todo double check
|
||||
if (real_t(n) / real_t(n_mini_batch) - int(n / n_mini_batch) != 0) {
|
||||
for (int i = 0; i < n - n / n_mini_batch * n_mini_batch; i++) {
|
||||
inputMiniBatches[n_mini_batch - 1].push_back(inputSet[n_mini_batch * n_mini_batch + i]);
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
return input_mini_batches;
|
||||
}
|
||||
MLPPUtilities::CreateMiniBatchMVBatch MLPPUtilities::create_mini_batchesmv(const Ref<MLPPMatrix> &input_set, const Ref<MLPPVector> &output_set, int n_mini_batch) {
|
||||
Size2i size = input_set->size();
|
||||
|
||||
int n = size.y;
|
||||
int mini_batch_element_count = n / n_mini_batch;
|
||||
|
||||
Ref<MLPPVector> row_tmp;
|
||||
row_tmp.instance();
|
||||
row_tmp->resize(size.x);
|
||||
|
||||
CreateMiniBatchMVBatch ret;
|
||||
|
||||
for (int i = 0; i < n_mini_batch; i++) {
|
||||
int mini_batch_start_offset = n_mini_batch * i;
|
||||
Ref<MLPPMatrix> current_input_set;
|
||||
current_input_set.instance();
|
||||
current_input_set->resize(Size2i(size.x, mini_batch_element_count));
|
||||
|
||||
Ref<MLPPVector> current_output_set;
|
||||
current_output_set.instance();
|
||||
current_output_set->resize(mini_batch_element_count);
|
||||
|
||||
for (int j = 0; j < mini_batch_element_count; j++) {
|
||||
int main_indx = mini_batch_start_offset + j;
|
||||
|
||||
input_set->get_row_into_mlpp_vector(main_indx, row_tmp);
|
||||
current_input_set->set_row_mlpp_vector(j, row_tmp);
|
||||
|
||||
current_output_set->set_element(j, output_set->get_element(j));
|
||||
}
|
||||
|
||||
ret.input_sets.push_back(current_input_set);
|
||||
ret.output_sets.push_back(current_output_set);
|
||||
}
|
||||
|
||||
/* Don't think this can ever happen, todo double check
|
||||
if (real_t(n) / real_t(n_mini_batch) - int(n / n_mini_batch) != 0) {
|
||||
for (int i = 0; i < n - n / n_mini_batch * n_mini_batch; i++) {
|
||||
inputMiniBatches[n_mini_batch - 1].push_back(inputSet[n / n_mini_batch * n_mini_batch + i]);
|
||||
outputMiniBatches[n_mini_batch - 1].push_back(outputSet[n / n_mini_batch * n_mini_batch + i]);
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
return ret;
|
||||
}
|
||||
MLPPUtilities::CreateMiniBatchMMBatch MLPPUtilities::create_mini_batchesmm(const Ref<MLPPMatrix> &input_set, const Ref<MLPPMatrix> &output_set, int n_mini_batch) {
|
||||
Size2i input_set_size = input_set->size();
|
||||
Size2i output_set_size = output_set->size();
|
||||
|
||||
int n = input_set_size.y;
|
||||
int mini_batch_element_count = n / n_mini_batch;
|
||||
|
||||
Ref<MLPPVector> input_row_tmp;
|
||||
input_row_tmp.instance();
|
||||
input_row_tmp->resize(input_set_size.x);
|
||||
|
||||
Ref<MLPPVector> output_row_tmp;
|
||||
output_row_tmp.instance();
|
||||
output_row_tmp->resize(output_set_size.x);
|
||||
|
||||
CreateMiniBatchMMBatch ret;
|
||||
|
||||
for (int i = 0; i < n_mini_batch; i++) {
|
||||
int mini_batch_start_offset = n_mini_batch * i;
|
||||
Ref<MLPPMatrix> current_input_set;
|
||||
current_input_set.instance();
|
||||
current_input_set->resize(Size2i(input_set_size.x, mini_batch_element_count));
|
||||
|
||||
Ref<MLPPMatrix> current_output_set;
|
||||
current_output_set.instance();
|
||||
current_output_set->resize(Size2i(output_set_size.x, mini_batch_element_count));
|
||||
|
||||
for (int j = 0; j < mini_batch_element_count; j++) {
|
||||
int main_indx = mini_batch_start_offset + j;
|
||||
|
||||
input_set->get_row_into_mlpp_vector(main_indx, input_row_tmp);
|
||||
current_input_set->set_row_mlpp_vector(j, input_row_tmp);
|
||||
|
||||
output_set->get_row_into_mlpp_vector(main_indx, output_row_tmp);
|
||||
current_output_set->set_row_mlpp_vector(j, output_row_tmp);
|
||||
}
|
||||
|
||||
ret.input_sets.push_back(current_input_set);
|
||||
ret.output_sets.push_back(current_output_set);
|
||||
}
|
||||
|
||||
/* Don't think this can ever happen, todo double check
|
||||
if (real_t(n) / real_t(n_mini_batch) - int(n / n_mini_batch) != 0) {
|
||||
for (int i = 0; i < n - n / n_mini_batch * n_mini_batch; i++) {
|
||||
inputMiniBatches[n_mini_batch - 1].push_back(inputSet[n / n_mini_batch * n_mini_batch + i]);
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
Array MLPPUtilities::create_mini_batchesm_bind(const Ref<MLPPMatrix> &input_set, int n_mini_batch) {
|
||||
Vector<Ref<MLPPMatrix>> batches = create_mini_batchesm(input_set, n_mini_batch);
|
||||
|
||||
Array ret;
|
||||
|
||||
for (int i = 0; i < batches.size(); ++i) {
|
||||
ret.push_back(batches[i].get_ref_ptr());
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
Array MLPPUtilities::create_mini_batchesmv_bind(const Ref<MLPPMatrix> &input_set, const Ref<MLPPVector> &output_set, int n_mini_batch) {
|
||||
CreateMiniBatchMVBatch batches = create_mini_batchesmv(input_set, output_set, n_mini_batch);
|
||||
|
||||
Array inputs;
|
||||
Array outputs;
|
||||
|
||||
for (int i = 0; i < batches.input_sets.size(); ++i) {
|
||||
inputs.push_back(batches.input_sets[i].get_ref_ptr());
|
||||
outputs.push_back(batches.output_sets[i].get_ref_ptr());
|
||||
}
|
||||
|
||||
Array ret;
|
||||
|
||||
ret.push_back(inputs);
|
||||
ret.push_back(outputs);
|
||||
|
||||
return ret;
|
||||
}
|
||||
Array MLPPUtilities::create_mini_batchesmm_bind(const Ref<MLPPMatrix> &input_set, const Ref<MLPPMatrix> &output_set, int n_mini_batch) {
|
||||
CreateMiniBatchMMBatch batches = create_mini_batchesmm(input_set, output_set, n_mini_batch);
|
||||
|
||||
Array inputs;
|
||||
Array outputs;
|
||||
|
||||
for (int i = 0; i < batches.input_sets.size(); ++i) {
|
||||
inputs.push_back(batches.input_sets[i].get_ref_ptr());
|
||||
outputs.push_back(batches.output_sets[i].get_ref_ptr());
|
||||
}
|
||||
|
||||
Array ret;
|
||||
|
||||
ret.push_back(inputs);
|
||||
ret.push_back(outputs);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::tuple<real_t, real_t, real_t, real_t> MLPPUtilities::TF_PN(std::vector<real_t> y_hat, std::vector<real_t> y) {
|
||||
real_t TP, FP, TN, FN = 0;
|
||||
for (int i = 0; i < y_hat.size(); i++) {
|
||||
@ -616,6 +799,10 @@ void MLPPUtilities::_bind_methods() {
|
||||
ClassDB::bind_method(D_METHOD("performance_mat", "y_hat", "y"), &MLPPUtilities::performance_mat);
|
||||
ClassDB::bind_method(D_METHOD("performance_pool_int_array_vec", "y_hat", "output_set"), &MLPPUtilities::performance_pool_int_array_vec);
|
||||
|
||||
ClassDB::bind_method(D_METHOD("create_mini_batchesm", "input_set", "n_mini_batch"), &MLPPUtilities::create_mini_batchesm_bind);
|
||||
ClassDB::bind_method(D_METHOD("create_mini_batchesmv", "input_set", "output_set", "n_mini_batch"), &MLPPUtilities::create_mini_batchesmv_bind);
|
||||
ClassDB::bind_method(D_METHOD("create_mini_batchesmm", "input_set", "output_set", "n_mini_batch"), &MLPPUtilities::create_mini_batchesmm_bind);
|
||||
|
||||
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_DEFAULT);
|
||||
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_XAVIER_NORMAL);
|
||||
BIND_ENUM_CONSTANT(WEIGHT_DISTRIBUTION_TYPE_XAVIER_UNIFORM);
|
||||
|
@ -73,6 +73,24 @@ public:
|
||||
static std::tuple<std::vector<std::vector<std::vector<real_t>>>, std::vector<std::vector<real_t>>> createMiniBatches(std::vector<std::vector<real_t>> inputSet, std::vector<real_t> outputSet, int n_mini_batch);
|
||||
static std::tuple<std::vector<std::vector<std::vector<real_t>>>, std::vector<std::vector<std::vector<real_t>>>> createMiniBatches(std::vector<std::vector<real_t>> inputSet, std::vector<std::vector<real_t>> outputSet, int n_mini_batch);
|
||||
|
||||
struct CreateMiniBatchMVBatch {
|
||||
Vector<Ref<MLPPMatrix>> input_sets;
|
||||
Vector<Ref<MLPPVector>> output_sets;
|
||||
};
|
||||
|
||||
struct CreateMiniBatchMMBatch {
|
||||
Vector<Ref<MLPPMatrix>> input_sets;
|
||||
Vector<Ref<MLPPMatrix>> output_sets;
|
||||
};
|
||||
|
||||
static Vector<Ref<MLPPMatrix>> create_mini_batchesm(const Ref<MLPPMatrix> &input_set, int n_mini_batch);
|
||||
static CreateMiniBatchMVBatch create_mini_batchesmv(const Ref<MLPPMatrix> &input_set, const Ref<MLPPVector> &output_set, int n_mini_batch);
|
||||
static CreateMiniBatchMMBatch create_mini_batchesmm(const Ref<MLPPMatrix> &input_set, const Ref<MLPPMatrix> &output_set, int n_mini_batch);
|
||||
|
||||
Array create_mini_batchesm_bind(const Ref<MLPPMatrix> &input_set, int n_mini_batch);
|
||||
Array create_mini_batchesmv_bind(const Ref<MLPPMatrix> &input_set, const Ref<MLPPVector> &output_set, int n_mini_batch);
|
||||
Array create_mini_batchesmm_bind(const Ref<MLPPMatrix> &input_set, const Ref<MLPPMatrix> &output_set, int n_mini_batch);
|
||||
|
||||
// F1 score, Precision/Recall, TP, FP, TN, FN, etc.
|
||||
std::tuple<real_t, real_t, real_t, real_t> TF_PN(std::vector<real_t> y_hat, std::vector<real_t> y); //TF_PN = "True", "False", "Positive", "Negative"
|
||||
real_t recall(std::vector<real_t> y_hat, std::vector<real_t> y);
|
||||
|
Loading…
Reference in New Issue
Block a user