From b1d00a629f370c6151b0c8fa61edae9a6658ff15 Mon Sep 17 00:00:00 2001 From: Relintai Date: Sun, 23 Apr 2023 11:09:46 +0200 Subject: [PATCH] Added vector based feature map get/set api to MLPPTensor3. --- mlpp/lin_alg/mlpp_tensor3.h | 142 ++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) diff --git a/mlpp/lin_alg/mlpp_tensor3.h b/mlpp/lin_alg/mlpp_tensor3.h index 2bbc2f7..859a7a1 100644 --- a/mlpp/lin_alg/mlpp_tensor3.h +++ b/mlpp/lin_alg/mlpp_tensor3.h @@ -403,6 +403,148 @@ public: } } + _FORCE_INLINE_ Vector get_feature_map_vector(int p_index_z) { + ERR_FAIL_INDEX_V(p_index_z, _size.z, Vector()); + + Vector ret; + + int fmds = feature_map_data_size(); + + if (unlikely(fmds == 0)) { + return ret; + } + + ret.resize(fmds); + + int ind_start = calculate_feature_map_index(p_index_z); + + real_t *row_ptr = ret.ptrw(); + + for (int i = 0; i < fmds; ++i) { + row_ptr[i] = _data[ind_start + i]; + } + + return ret; + } + + _FORCE_INLINE_ PoolRealArray get_feature_map_pool_vector(int p_index_z) { + ERR_FAIL_INDEX_V(p_index_z, _size.z, PoolRealArray()); + + PoolRealArray ret; + + int fmds = feature_map_data_size(); + + if (unlikely(fmds == 0)) { + return ret; + } + + ret.resize(fmds); + + int ind_start = calculate_feature_map_index(p_index_z); + + PoolRealArray::Write w = ret.write(); + real_t *row_ptr = w.ptr(); + + for (int i = 0; i < fmds; ++i) { + row_ptr[i] = _data[ind_start + i]; + } + + return ret; + } + + _FORCE_INLINE_ Ref get_feature_map_mlpp_vector(int p_index_z) { + ERR_FAIL_INDEX_V(p_index_z, _size.z, Ref()); + + Ref ret; + ret.instance(); + + int fmds = feature_map_data_size(); + + if (unlikely(fmds == 0)) { + return ret; + } + + ret->resize(fmds); + + int ind_start = calculate_feature_map_index(p_index_z); + + real_t *row_ptr = ret->ptrw(); + + for (int i = 0; i < fmds; ++i) { + row_ptr[i] = _data[ind_start + i]; + } + + return ret; + } + + _FORCE_INLINE_ void get_feature_map_into_mlpp_vector(int p_index_z, Ref target) const { + ERR_FAIL_INDEX(p_index_z, _size.z); + + int fmds = feature_map_data_size(); + + if (unlikely(target->size() != fmds)) { + target->resize(fmds); + } + + int ind_start = calculate_feature_map_index(p_index_z); + + real_t *row_ptr = target->ptrw(); + + for (int i = 0; i < fmds; ++i) { + row_ptr[i] = _data[ind_start + i]; + } + } + + _FORCE_INLINE_ void set_feature_map_vector(int p_index_z, const Vector &p_row) { + ERR_FAIL_INDEX(p_index_z, _size.z); + + int fmds = feature_map_data_size(); + + ERR_FAIL_COND(p_row.size() != fmds); + + int ind_start = calculate_feature_map_index(p_index_z); + + const real_t *row_ptr = p_row.ptr(); + + for (int i = 0; i < fmds; ++i) { + _data[ind_start + i] = row_ptr[i]; + } + } + + _FORCE_INLINE_ void set_feature_map_pool_vector(int p_index_z, const PoolRealArray &p_row) { + ERR_FAIL_INDEX(p_index_z, _size.z); + + int fmds = feature_map_data_size(); + + ERR_FAIL_COND(p_row.size() != fmds); + + int ind_start = calculate_feature_map_index(p_index_z); + + PoolRealArray::Read r = p_row.read(); + const real_t *row_ptr = r.ptr(); + + for (int i = 0; i < fmds; ++i) { + _data[ind_start + i] = row_ptr[i]; + } + } + + _FORCE_INLINE_ void set_feature_map_mlpp_vector(int p_index_z, const Ref &p_row) { + ERR_FAIL_INDEX(p_index_z, _size.z); + ERR_FAIL_COND(!p_row.is_valid()); + + int fmds = feature_map_data_size(); + + ERR_FAIL_COND(p_row->size() != fmds); + + int ind_start = calculate_feature_map_index(p_index_z); + + const real_t *row_ptr = p_row->ptr(); + + for (int i = 0; i < fmds; ++i) { + _data[ind_start + i] = row_ptr[i]; + } + } + void fill(real_t p_val) { if (!_data) { return;