tesseract++ 0.0.1
N-dimensional tensor library for embedded systems
Loading...
Searching...
No Matches
strided_layout.h
Go to the documentation of this file.
1#ifndef FUSED_STRIDED_LAYOUT_H
2#define FUSED_STRIDED_LAYOUT_H
3
4#include "config.h"
5#include "memory/mem_utils.h"
6
7template <my_size_t NumberOfDims>
9{
10 my_size_t shape[NumberOfDims];
11 my_size_t stride[NumberOfDims];
12
13 StridedLayout(const my_size_t dims[NumberOfDims]) noexcept
14 {
15 copy_n_optimized(dims, shape, NumberOfDims);
17 }
18
19 // copy constructor
20 StridedLayout(const StridedLayout &other) = default;
21
22 // move constructor
23 StridedLayout(StridedLayout &&other) = default;
24
25 // copy assignment
26 StridedLayout &operator=(const StridedLayout &other) = default;
27
28 // move assignment
30
31 FORCE_INLINE constexpr my_size_t getNumDims() const noexcept { return NumberOfDims; }
32
33 FORCE_INLINE my_size_t getDim(my_size_t i) const // TODO: conditionally noexcept
34 {
35#ifdef RUNTIME_USE_BOUNDS_CHECKING
36 if (i >= getNumDims())
37 {
38 MyErrorHandler::error("In StridedLayout, getDim(): index out of range!");
39 }
40#endif
41 return shape[i];
42 }
43
44 FORCE_INLINE my_size_t getStride(my_size_t i) const // TODO: conditionally noexcept
45 {
46#ifdef RUNTIME_USE_BOUNDS_CHECKING
47 if (i >= getNumDims())
48 {
49 MyErrorHandler::error("In StridedLayout, getStride(): index out of range!");
50 }
51#endif
52 return stride[i];
53 }
54
56 {
57 stride[getNumDims() - 1] = 1;
58 for (my_size_t i = getNumDims() - 1; i > 0; --i)
59 {
60 stride[i - 1] = stride[i] * shape[i];
61 }
62 }
63
64 FORCE_INLINE void compute_indices_from_flat(my_size_t flatIdx, my_size_t (&indices)[NumberOfDims]) const noexcept
65 {
66 // We assume: flatIdx = sum(indices[i] * stride_[i])
67 // Solve for indices[i] from highest stride to lowest stride.
68 for (my_size_t i = 0; i < getNumDims(); ++i)
69 {
70 const my_size_t s = stride[i];
71 const my_size_t idx = flatIdx / s;
72 indices[i] = idx;
73 flatIdx -= idx * s;
74 }
75 }
76
77 FORCE_INLINE my_size_t compute_flat_index(const my_size_t *indices) const // TODO: conditionally noexcept
78 {
79 my_size_t flatIndex = 0;
80 for (my_size_t i = 0; i < getNumDims(); ++i)
81 {
82#ifdef RUNTIME_USE_BOUNDS_CHECKING
83 if (indices[i] >= shape[i])
84 {
85 MyErrorHandler::error("In StridedLayout, compute_flat_index(): index out of range!");
86 }
87#endif
88 flatIndex += indices[i] * stride[i];
89 }
90 return flatIndex;
91 }
92
94 {
95 my_size_t off = 0;
96
97 for (my_size_t i = getNumDims(); i-- > 0;)
98 {
99 my_size_t idx = flat % shape[i];
100 flat /= shape[i];
101 off += idx * stride[i];
102 }
103 return off;
104 }
105
106 // FORCE_INLINE my_size_t remapFlatIndex(my_size_t flatIdx, const my_size_t (&permutations)[sizeof...(Dims)]) const noexcept
107 // {
108 // // Step 1: Unravel flat index in **view order** to multi-index
109 // my_size_t idx[numDims]; // idx[i] = index along original axis i
110 // for (my_size_t i = numDims; i-- > 0;)
111 // {
112 // const my_size_t dim = getDim(permutations[i]); // dim of permuted axis
113 // idx[i] = flatIdx % dim; // store in original axis position
114 // flatIdx /= dim;
115 // };
116
117 // // Step 2: Compute flat index in original tensor layout
118 // my_size_t remapedFlatIdx = 0;
119 // my_size_t factor = 1;
120 // for (my_size_t i = numDims; i-- > 0;)
121 // {
122 // remapedFlatIdx += idx[permutations[i]] * factor;
123 // factor *= getDim(i); // multiply by original axis size
124 // }
125
126 // return remapedFlatIdx;
127 // }
128};
129
130#endif // FUSED_STRIDED_LAYOUT_H
static void error(const T &msg)
Definition error_handler.h:30
Global configuration for the tesseract tensor library.
#define my_size_t
Size/index type used throughout the library.
Definition config.h:126
#define FORCE_INLINE
Hint the compiler to always inline a function.
Definition config.h:26
STL-free memory utilities.
void copy_n_optimized(const T *src, T *dst, my_size_t count)
Copy elements from a source buffer to a destination buffer.
Definition mem_utils.h:100
StridedLayout
Definition strided_layout.h:9
void compute_row_major_strides() noexcept
Definition strided_layout.h:55
FORCE_INLINE my_size_t getDim(my_size_t i) const
Definition strided_layout.h:33
FORCE_INLINE my_size_t computeOffsetFromFlat(my_size_t flat) const noexcept
Definition strided_layout.h:93
FORCE_INLINE my_size_t compute_flat_index(const my_size_t *indices) const
Definition strided_layout.h:77
FORCE_INLINE my_size_t getStride(my_size_t i) const
Definition strided_layout.h:44
FORCE_INLINE void compute_indices_from_flat(my_size_t flatIdx, my_size_t(&indices)[NumberOfDims]) const noexcept
Definition strided_layout.h:64
FORCE_INLINE constexpr my_size_t getNumDims() const noexcept
Definition strided_layout.h:31
StridedLayout & operator=(StridedLayout &&other)=default
StridedLayout & operator=(const StridedLayout &other)=default
my_size_t stride[NumberOfDims]
Definition strided_layout.h:11
StridedLayout(StridedLayout &&other)=default
StridedLayout(const my_size_t dims[NumberOfDims]) noexcept
Definition strided_layout.h:13
my_size_t shape[NumberOfDims]
Definition strided_layout.h:10
StridedLayout(const StridedLayout &other)=default