mirror of
https://github.com/Relintai/pmlpp.git
synced 2024-12-21 14:56:47 +01:00
Added matrix based getters and setters to Tensor3.
This commit is contained in:
parent
83f4c22e74
commit
711009b02d
@ -504,6 +504,50 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
_FORCE_INLINE_ Ref<MLPPMatrix> get_feature_map_mlpp_matrix(int p_index_z) {
|
||||
ERR_FAIL_INDEX_V(p_index_z, _size.z, Ref<MLPPMatrix>());
|
||||
|
||||
Ref<MLPPMatrix> ret;
|
||||
ret.instance();
|
||||
|
||||
int fmds = feature_map_data_size();
|
||||
|
||||
if (unlikely(fmds == 0)) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret->resize(feature_map_size());
|
||||
|
||||
int ind_start = calculate_feature_map_index(p_index_z);
|
||||
|
||||
real_t *row_ptr = ret->ptrw();
|
||||
|
||||
for (int i = 0; i < fmds; ++i) {
|
||||
row_ptr[i] = _data[ind_start + i];
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
_FORCE_INLINE_ void get_feature_map_into_mlpp_vector(int p_index_z, Ref<MLPPMatrix> target) const {
|
||||
ERR_FAIL_INDEX(p_index_z, _size.z);
|
||||
|
||||
int fmds = feature_map_data_size();
|
||||
Size2i fms = feature_map_size();
|
||||
|
||||
if (unlikely(target->size() != fms)) {
|
||||
target->resize(fms);
|
||||
}
|
||||
|
||||
int ind_start = calculate_feature_map_index(p_index_z);
|
||||
|
||||
real_t *row_ptr = target->ptrw();
|
||||
|
||||
for (int i = 0; i < fmds; ++i) {
|
||||
row_ptr[i] = _data[ind_start + i];
|
||||
}
|
||||
}
|
||||
|
||||
_FORCE_INLINE_ void set_feature_map_vector(int p_index_z, const Vector<real_t> &p_row) {
|
||||
ERR_FAIL_INDEX(p_index_z, _size.z);
|
||||
|
||||
@ -554,6 +598,23 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
_FORCE_INLINE_ void set_feature_map_mlpp_matrix(int p_index_z, const Ref<MLPPMatrix> &p_mat) {
|
||||
ERR_FAIL_INDEX(p_index_z, _size.z);
|
||||
ERR_FAIL_COND(!p_mat.is_valid());
|
||||
|
||||
int fmds = feature_map_data_size();
|
||||
|
||||
ERR_FAIL_COND(p_mat->size() != feature_map_size());
|
||||
|
||||
int ind_start = calculate_feature_map_index(p_index_z);
|
||||
|
||||
const real_t *row_ptr = p_mat->ptr();
|
||||
|
||||
for (int i = 0; i < fmds; ++i) {
|
||||
_data[ind_start + i] = row_ptr[i];
|
||||
}
|
||||
}
|
||||
|
||||
void fill(real_t p_val) {
|
||||
if (!_data) {
|
||||
return;
|
||||
|
Loading…
Reference in New Issue
Block a user