2023-04-16 15:08:11 +02:00
# ifndef MLPP_TENSOR3_H
# define MLPP_TENSOR3_H
2023-10-24 21:38:45 +02:00
# ifndef GDNATIVE
2023-04-16 15:08:11 +02:00
# include "core/math/math_defs.h"
# include "core/containers/pool_vector.h"
# include "core/containers/sort_array.h"
# include "core/containers/vector.h"
# include "core/error/error_macros.h"
# include "core/math/vector2i.h"
# include "core/os/memory.h"
2023-04-29 12:31:57 +02:00
# include "core/object/resource.h"
2023-04-16 15:08:11 +02:00
2023-12-29 17:32:31 +01:00
# else
2023-10-24 21:38:45 +02:00
2023-12-29 17:32:31 +01:00
# include "core/containers/vector.h"
2023-10-24 21:38:45 +02:00
# include "core/defs.h"
# include "core/math_funcs.h"
# include "core/os/memory.h"
2023-12-29 17:32:31 +01:00
# include "core/pool_arrays.h"
2023-10-24 21:38:45 +02:00
# include "gen/resource.h"
# endif
2023-04-16 15:08:11 +02:00
# include "mlpp_matrix.h"
# include "mlpp_vector.h"
2023-04-23 13:16:04 +02:00
class Image ;
2023-04-29 12:31:57 +02:00
class MLPPTensor3 : public Resource {
GDCLASS ( MLPPTensor3 , Resource ) ;
2023-04-16 15:08:11 +02:00
public :
2023-04-29 12:47:45 +02:00
Array get_data ( ) ;
void set_data ( const Array & p_from ) ;
2023-04-29 13:44:18 +02:00
2023-04-25 18:04:34 +02:00
_FORCE_INLINE_ real_t * ptrw ( ) {
2023-04-16 15:08:11 +02:00
return _data ;
}
2023-04-25 18:04:34 +02:00
_FORCE_INLINE_ const real_t * ptr ( ) const {
2023-04-16 15:08:11 +02:00
return _data ;
}
2023-04-29 13:44:18 +02:00
void z_slice_add ( const Vector < real_t > & p_row ) ;
void z_slice_add_pool_vector ( const PoolRealArray & p_row ) ;
void z_slice_add_mlpp_vector ( const Ref < MLPPVector > & p_row ) ;
void z_slice_add_mlpp_matrix ( const Ref < MLPPMatrix > & p_matrix ) ;
void z_slice_remove ( int p_index ) ;
2023-04-16 15:08:11 +02:00
// Removes the item copying the last value into the position of the one to
// remove. It's generally faster than `remove`.
2023-04-29 13:44:18 +02:00
void z_slice_remove_unordered ( int p_index ) ;
2023-04-16 15:08:11 +02:00
2023-04-29 13:44:18 +02:00
void z_slice_swap ( int p_index_1 , int p_index_2 ) ;
2023-04-16 15:08:11 +02:00
2023-04-16 20:28:50 +02:00
_FORCE_INLINE_ void clear ( ) { resize ( Size3i ( ) ) ; }
2023-04-16 15:08:11 +02:00
_FORCE_INLINE_ void reset ( ) {
if ( _data ) {
memfree ( _data ) ;
_data = NULL ;
2023-04-16 20:28:50 +02:00
_size = Size3i ( ) ;
2023-04-16 15:08:11 +02:00
}
}
2023-04-16 20:28:50 +02:00
_FORCE_INLINE_ bool empty ( ) const { return _size = = Size3i ( ) ; }
2023-04-25 17:49:08 +02:00
_FORCE_INLINE_ int z_slice_data_size ( ) const { return _size . x * _size . y ; }
_FORCE_INLINE_ Size2i z_slice_size ( ) const { return Size2i ( _size . x , _size . y ) ; }
2023-04-16 20:28:50 +02:00
_FORCE_INLINE_ int data_size ( ) const { return _size . x * _size . y * _size . z ; }
_FORCE_INLINE_ Size3i size ( ) const { return _size ; }
2023-04-16 15:08:11 +02:00
2023-04-25 18:03:22 +02:00
void resize ( const Size3i & p_size ) ;
2023-04-29 13:50:35 +02:00
void shape_set ( const Size3i & p_size ) ;
2023-04-23 11:46:35 +02:00
2023-12-29 18:57:31 +01:00
_FORCE_INLINE_ int calculate_index ( int p_index_z , int p_index_y , int p_index_x ) const {
2023-04-16 20:28:50 +02:00
return p_index_y * _size . x + p_index_x + _size . x * _size . y * p_index_z ;
2023-04-16 15:08:11 +02:00
}
2023-04-25 17:49:08 +02:00
_FORCE_INLINE_ int calculate_z_slice_index ( int p_index_z ) const {
2023-04-23 10:59:50 +02:00
return _size . x * _size . y * p_index_z ;
}
2023-04-16 15:08:11 +02:00
_FORCE_INLINE_ const real_t & operator [ ] ( int p_index ) const {
CRASH_BAD_INDEX ( p_index , data_size ( ) ) ;
return _data [ p_index ] ;
}
_FORCE_INLINE_ real_t & operator [ ] ( int p_index ) {
CRASH_BAD_INDEX ( p_index , data_size ( ) ) ;
return _data [ p_index ] ;
}
2023-04-29 13:44:18 +02:00
_FORCE_INLINE_ real_t element_get_index ( int p_index ) const {
2023-04-16 20:38:50 +02:00
ERR_FAIL_INDEX_V ( p_index , data_size ( ) , 0 ) ;
return _data [ p_index ] ;
}
2023-04-29 13:44:18 +02:00
_FORCE_INLINE_ void element_set_index ( int p_index , real_t p_val ) {
2023-04-16 20:38:50 +02:00
ERR_FAIL_INDEX ( p_index , data_size ( ) ) ;
_data [ p_index ] = p_val ;
}
2023-12-29 18:57:31 +01:00
_FORCE_INLINE_ real_t element_get ( int p_index_z , int p_index_y , int p_index_x ) const {
2023-04-16 15:08:11 +02:00
ERR_FAIL_INDEX_V ( p_index_x , _size . x , 0 ) ;
ERR_FAIL_INDEX_V ( p_index_y , _size . y , 0 ) ;
2023-04-16 20:28:50 +02:00
ERR_FAIL_INDEX_V ( p_index_z , _size . z , 0 ) ;
2023-04-16 15:08:11 +02:00
2023-04-16 20:28:50 +02:00
return _data [ p_index_y * _size . x + p_index_x + _size . x * _size . y * p_index_z ] ;
2023-04-16 15:08:11 +02:00
}
2023-12-29 18:57:31 +01:00
_FORCE_INLINE_ void element_set ( int p_index_z , int p_index_y , int p_index_x , real_t p_val ) {
2023-04-16 15:08:11 +02:00
ERR_FAIL_INDEX ( p_index_x , _size . x ) ;
ERR_FAIL_INDEX ( p_index_y , _size . y ) ;
2023-04-16 20:28:50 +02:00
ERR_FAIL_INDEX ( p_index_z , _size . z ) ;
2023-04-16 15:08:11 +02:00
2023-04-16 20:28:50 +02:00
_data [ p_index_y * _size . x + p_index_x + _size . x * _size . y * p_index_z ] = p_val ;
2023-04-16 15:08:11 +02:00
}
2023-12-29 17:32:31 +01:00
Vector < real_t > row_get_vector ( int p_index_z , int p_index_y ) const ;
PoolRealArray row_get_pool_vector ( int p_index_z , int p_index_y ) const ;
Ref < MLPPVector > row_get_mlpp_vector ( int p_index_z , int p_index_y ) const ;
void row_get_into_mlpp_vector ( int p_index_z , int p_index_y , Ref < MLPPVector > target ) const ;
2023-04-23 11:53:57 +02:00
2023-12-29 18:57:31 +01:00
void row_set_vector ( int p_index_z , int p_index_y , const Vector < real_t > & p_row ) ;
void row_set_pool_vector ( int p_index_z , int p_index_y , const PoolRealArray & p_row ) ;
void row_set_mlpp_vector ( int p_index_z , int p_index_y , const Ref < MLPPVector > & p_row ) ;
2023-04-23 11:53:57 +02:00
2023-04-29 13:44:18 +02:00
Vector < real_t > z_slice_get_vector ( int p_index_z ) const ;
PoolRealArray z_slice_get_pool_vector ( int p_index_z ) const ;
Ref < MLPPVector > z_slice_get_mlpp_vector ( int p_index_z ) const ;
void z_slice_get_into_mlpp_vector ( int p_index_z , Ref < MLPPVector > target ) const ;
Ref < MLPPMatrix > z_slice_get_mlpp_matrix ( int p_index_z ) const ;
void z_slice_get_into_mlpp_matrix ( int p_index_z , Ref < MLPPMatrix > target ) const ;
2023-04-23 11:09:46 +02:00
2023-04-29 13:44:18 +02:00
void z_slice_set_vector ( int p_index_z , const Vector < real_t > & p_row ) ;
void z_slice_set_pool_vector ( int p_index_z , const PoolRealArray & p_row ) ;
void z_slice_set_mlpp_vector ( int p_index_z , const Ref < MLPPVector > & p_row ) ;
void z_slice_set_mlpp_matrix ( int p_index_z , const Ref < MLPPMatrix > & p_mat ) ;
2023-04-23 11:53:57 +02:00
2023-04-25 20:21:12 +02:00
//TODO resize() need to be reworked for add and remove to work, in any other direction than z
2023-04-29 13:44:18 +02:00
//void x_slice_add(const Ref<MLPPMatrix> &p_matrix);
//void x_slice_remove(int p_index);
void x_slice_get_into ( int p_index_x , Ref < MLPPMatrix > target ) const ;
Ref < MLPPMatrix > x_slice_get ( int p_index_x ) const ;
void x_slice_set ( int p_index_x , const Ref < MLPPMatrix > & p_mat ) ;
//void y_slice_add(const Ref<MLPPMatrix> &p_matrix);
//void y_slice_remove(int p_index);
void y_slice_get_into ( int p_index_y , Ref < MLPPMatrix > target ) const ;
Ref < MLPPMatrix > y_slice_get ( int p_index_y ) const ;
void y_slice_set ( int p_index_y , const Ref < MLPPMatrix > & p_mat ) ;
2023-04-25 20:21:12 +02:00
2023-04-23 13:16:04 +02:00
public :
//Image api
2023-04-23 15:42:38 +02:00
enum ImageChannelFlags {
IMAGE_CHANNEL_FLAG_R = 1 < < 0 ,
IMAGE_CHANNEL_FLAG_G = 1 < < 1 ,
IMAGE_CHANNEL_FLAG_B = 1 < < 2 ,
IMAGE_CHANNEL_FLAG_A = 1 < < 3 ,
IMAGE_CHANNEL_FLAG_NONE = 0 ,
IMAGE_CHANNEL_FLAG_RG = IMAGE_CHANNEL_FLAG_R | IMAGE_CHANNEL_FLAG_G ,
IMAGE_CHANNEL_FLAG_RGB = IMAGE_CHANNEL_FLAG_R | IMAGE_CHANNEL_FLAG_G | IMAGE_CHANNEL_FLAG_B ,
IMAGE_CHANNEL_FLAG_GB = IMAGE_CHANNEL_FLAG_G | IMAGE_CHANNEL_FLAG_B ,
IMAGE_CHANNEL_FLAG_GBA = IMAGE_CHANNEL_FLAG_G | IMAGE_CHANNEL_FLAG_B | IMAGE_CHANNEL_FLAG_A ,
IMAGE_CHANNEL_FLAG_BA = IMAGE_CHANNEL_FLAG_B | IMAGE_CHANNEL_FLAG_A ,
IMAGE_CHANNEL_FLAG_RGBA = IMAGE_CHANNEL_FLAG_R | IMAGE_CHANNEL_FLAG_G | IMAGE_CHANNEL_FLAG_B | IMAGE_CHANNEL_FLAG_A ,
2023-04-23 13:16:04 +02:00
} ;
2023-04-29 13:44:18 +02:00
void z_slices_add_image ( const Ref < Image > & p_img , const int p_channels = IMAGE_CHANNEL_FLAG_RGBA ) ;
2023-04-23 13:16:04 +02:00
2023-04-29 13:44:18 +02:00
Ref < Image > z_slice_get_image ( const int p_index_z ) const ;
Ref < Image > z_slices_get_image ( const int p_index_r = - 1 , const int p_index_g = - 1 , const int p_index_b = - 1 , const int p_index_a = - 1 ) const ;
2023-04-23 13:16:04 +02:00
2023-04-29 13:44:18 +02:00
void z_slice_get_into_image ( Ref < Image > p_target , const int p_index_z , const int p_target_channels = IMAGE_CHANNEL_FLAG_RGB ) const ;
void z_slices_get_into_image ( Ref < Image > p_target , const int p_index_r = - 1 , const int p_index_g = - 1 , const int p_index_b = - 1 , const int p_index_a = - 1 ) const ;
2023-04-23 13:16:04 +02:00
2023-04-29 13:44:18 +02:00
void z_slice_set_image ( const Ref < Image > & p_img , const int p_index_z , const int p_image_channel_flag = IMAGE_CHANNEL_FLAG_R ) ;
void z_slices_set_image ( const Ref < Image > & p_img , const int p_index_r = - 1 , const int p_index_g = - 1 , const int p_index_b = - 1 , const int p_index_a = - 1 ) ;
2023-04-23 13:16:04 +02:00
2023-04-23 15:42:38 +02:00
void set_from_image ( const Ref < Image > & p_img , const int p_channels = IMAGE_CHANNEL_FLAG_RGBA ) ;
2023-04-23 13:16:04 +02:00
2023-04-29 13:44:18 +02:00
//void x_slices_add_image(const Ref<Image> &p_img, const int p_channels = IMAGE_CHANNEL_FLAG_RGBA);
Ref < Image > x_slice_get_image ( const int p_index_x ) const ;
void x_slice_get_into_image ( Ref < Image > p_target , const int p_index_x , const int p_target_channels = IMAGE_CHANNEL_FLAG_RGB ) const ;
void x_slice_set_image ( const Ref < Image > & p_img , const int p_index_x , const int p_image_channel_flag = IMAGE_CHANNEL_FLAG_R ) ;
2023-04-25 20:21:12 +02:00
2023-04-29 13:44:18 +02:00
//void y_slices_add_image(const Ref<Image> &p_img, const int p_channels = IMAGE_CHANNEL_FLAG_RGBA);
Ref < Image > y_slice_get_image ( const int p_index_y ) const ;
void y_slice_get_into_image ( Ref < Image > p_target , const int p_index_y , const int p_target_channels = IMAGE_CHANNEL_FLAG_RGB ) const ;
void y_slice_set_image ( const Ref < Image > & p_img , const int p_index_y , const int p_image_channel_flag = IMAGE_CHANNEL_FLAG_R ) ;
2023-04-25 20:21:12 +02:00
2023-04-24 11:40:46 +02:00
public :
//math api
2023-04-25 14:06:12 +02:00
void add ( const Ref < MLPPTensor3 > & B ) ;
2023-04-25 17:46:42 +02:00
Ref < MLPPTensor3 > addn ( const Ref < MLPPTensor3 > & B ) const ;
2023-04-25 14:06:12 +02:00
void addb ( const Ref < MLPPTensor3 > & A , const Ref < MLPPTensor3 > & B ) ;
2023-04-24 11:40:46 +02:00
2023-04-25 14:06:12 +02:00
void sub ( const Ref < MLPPTensor3 > & B ) ;
2023-04-25 17:46:42 +02:00
Ref < MLPPTensor3 > subn ( const Ref < MLPPTensor3 > & B ) const ;
2023-04-25 14:06:12 +02:00
void subb ( const Ref < MLPPTensor3 > & A , const Ref < MLPPTensor3 > & B ) ;
2023-04-24 11:40:46 +02:00
2023-04-29 13:50:35 +02:00
void division_element_wise ( const Ref < MLPPTensor3 > & B ) ;
Ref < MLPPTensor3 > division_element_wisen ( const Ref < MLPPTensor3 > & B ) const ;
void division_element_wiseb ( const Ref < MLPPTensor3 > & A , const Ref < MLPPTensor3 > & B ) ;
2023-04-24 11:40:46 +02:00
2023-04-25 14:06:12 +02:00
void sqrt ( ) ;
Ref < MLPPTensor3 > sqrtn ( ) const ;
void sqrtb ( const Ref < MLPPTensor3 > & A ) ;
void exponentiate ( real_t p ) ;
Ref < MLPPTensor3 > exponentiaten ( real_t p ) const ;
void exponentiateb ( const Ref < MLPPTensor3 > & A , real_t p ) ;
2023-04-24 11:40:46 +02:00
2023-04-25 14:06:12 +02:00
void scalar_multiply ( const real_t scalar ) ;
Ref < MLPPTensor3 > scalar_multiplyn ( const real_t scalar ) const ;
void scalar_multiplyb ( const real_t scalar , const Ref < MLPPTensor3 > & A ) ;
2023-04-24 11:40:46 +02:00
2023-04-25 14:06:12 +02:00
void scalar_add ( const real_t scalar ) ;
Ref < MLPPTensor3 > scalar_addn ( const real_t scalar ) const ;
void scalar_addb ( const real_t scalar , const Ref < MLPPTensor3 > & A ) ;
2023-04-24 11:40:46 +02:00
2023-04-25 14:06:12 +02:00
void hadamard_product ( const Ref < MLPPTensor3 > & B ) ;
Ref < MLPPTensor3 > hadamard_productn ( const Ref < MLPPTensor3 > & B ) const ;
void hadamard_productb ( const Ref < MLPPTensor3 > & A , const Ref < MLPPTensor3 > & B ) ;
2023-04-24 11:40:46 +02:00
2023-04-25 14:06:12 +02:00
void max ( const Ref < MLPPTensor3 > & B ) ;
2023-04-25 17:46:42 +02:00
Ref < MLPPTensor3 > maxn ( const Ref < MLPPTensor3 > & B ) const ;
2023-04-25 14:06:12 +02:00
void maxb ( const Ref < MLPPTensor3 > & A , const Ref < MLPPTensor3 > & B ) ;
2023-04-24 11:40:46 +02:00
2023-04-25 14:06:12 +02:00
void abs ( ) ;
Ref < MLPPTensor3 > absn ( ) const ;
void absb ( const Ref < MLPPTensor3 > & A ) ;
Ref < MLPPVector > flatten ( ) const ;
void flatteno ( Ref < MLPPVector > out ) const ;
2023-04-24 11:40:46 +02:00
//real_t norm_2(std::vector<std::vector<std::vector<real_t>>> A);
2023-12-29 17:32:31 +01:00
Ref < MLPPMatrix > tensor_vec_mult ( const Ref < MLPPVector > & b ) ;
2023-04-24 11:40:46 +02:00
//std::vector<std::vector<std::vector<real_t>>> vector_wise_tensor_product(std::vector<std::vector<std::vector<real_t>>> A, std::vector<std::vector<real_t>> B);
2023-04-23 13:16:04 +02:00
public :
2023-04-25 18:03:22 +02:00
void fill ( real_t p_val ) ;
2023-04-16 15:08:11 +02:00
2023-04-25 18:03:22 +02:00
Vector < real_t > to_flat_vector ( ) const ;
PoolRealArray to_flat_pool_vector ( ) const ;
Vector < uint8_t > to_flat_byte_array ( ) const ;
2023-04-16 15:08:11 +02:00
2023-04-29 12:31:57 +02:00
Ref < MLPPTensor3 > duplicate_fast ( ) const ;
2023-04-16 15:08:11 +02:00
2023-04-25 18:03:22 +02:00
void set_from_mlpp_tensor3 ( const Ref < MLPPTensor3 > & p_from ) ;
void set_from_mlpp_tensor3r ( const MLPPTensor3 & p_from ) ;
2023-04-23 12:15:56 +02:00
2023-04-25 18:03:22 +02:00
void set_from_mlpp_matrix ( const Ref < MLPPMatrix > & p_from ) ;
void set_from_mlpp_matrixr ( const MLPPMatrix & p_from ) ;
void set_from_mlpp_vectors ( const Vector < Ref < MLPPVector > > & p_from ) ;
void set_from_mlpp_matricess ( const Vector < Ref < MLPPMatrix > > & p_from ) ;
2023-04-16 15:08:11 +02:00
2023-04-25 18:03:22 +02:00
void set_from_mlpp_vectors_array ( const Array & p_from ) ;
void set_from_mlpp_matrices_array ( const Array & p_from ) ;
2023-04-16 15:08:11 +02:00
2023-04-25 18:03:22 +02:00
bool is_equal_approx ( const Ref < MLPPTensor3 > & p_with , real_t tolerance = static_cast < real_t > ( CMP_EPSILON ) ) const ;
2023-04-16 15:08:11 +02:00
2023-04-16 20:28:50 +02:00
String to_string ( ) ;
2023-04-16 15:08:11 +02:00
2023-04-25 18:03:22 +02:00
MLPPTensor3 ( ) ;
MLPPTensor3 ( const MLPPMatrix & p_from ) ;
MLPPTensor3 ( const Array & p_from ) ;
~ MLPPTensor3 ( ) ;
2023-04-16 15:08:11 +02:00
// TODO: These are temporary
std : : vector < real_t > to_flat_std_vector ( ) const ;
2023-04-16 20:28:50 +02:00
void set_from_std_vectors ( const std : : vector < std : : vector < std : : vector < real_t > > > & p_from ) ;
std : : vector < std : : vector < std : : vector < real_t > > > to_std_vector ( ) ;
MLPPTensor3 ( const std : : vector < std : : vector < std : : vector < real_t > > > & p_from ) ;
2023-04-16 15:08:11 +02:00
protected :
static void _bind_methods ( ) ;
protected :
2023-04-16 20:28:50 +02:00
Size3i _size ;
2023-04-16 15:08:11 +02:00
real_t * _data ;
} ;
2023-04-23 15:42:38 +02:00
// Registers MLPPTensor3::ImageChannelFlags with the engine's Variant enum-casting
// macro — presumably so methods bound in _bind_methods() can accept/return it
// (confirm against the .cpp bindings).
VARIANT_ENUM_CAST(MLPPTensor3::ImageChannelFlags);
2023-04-23 13:16:04 +02:00
2023-04-16 15:08:11 +02:00
# endif