tesseract++ 0.0.1
N-dimensional tensor library for embedded systems
Loading...
Searching...
No Matches
permuted_view.h
Go to the documentation of this file.
1#ifndef FUSED_PERMUTED_VIEW_H
2#define FUSED_PERMUTED_VIEW_H
3
4#include "config.h"
5#include "fused/BaseExpr.h"
7
8#include <algorithm> // for std::max_element, std::min_element
9
10template <typename Tensor, my_size_t N>
11class PermutedView : public BaseExpr<PermutedView<Tensor, N>>
12{
13public:
14 using value_type = typename Tensor::value_type;
15 static constexpr my_size_t NumDims = N;
16 static constexpr const my_size_t *Dim = Tensor::Dim;
17 static constexpr my_size_t TotalSize = Tensor::TotalSize;
18
19 explicit PermutedView(const Tensor &t, const my_size_t perm[NumDims])
20 : t_(t), layout_(t.layout_) // Bind the reference member t_ to the existing object t
21 // and copy the layout from the base tensor
22 {
23 // runtime checks TODO: get rid of std
24 auto max_it = std::max_element(perm, perm + NumDims);
25 auto min_it = std::min_element(perm, perm + NumDims);
26
27 if (*max_it != NumDims - 1)
28 MyErrorHandler::error("Max value of permutation array is greater than the tensor's number of dimensions");
29
30 if (*min_it != 0)
31 MyErrorHandler::error("Min value of permutation array is not equal to 0");
32
33 // TODO: check that all values in perm are unique
34 // and "Permutation pack must match tensor's number of dimensions"
35 for (std::size_t i = 0; i < NumDims; ++i)
36 {
37 // then set the permuted shape and stride
38 layout_.shape[i] = t_.getDim(perm[i]);
39 layout_.stride[i] = t_.getStride(perm[i]);
40 }
41 }
42
43 // delete copy constructor and copy assignment to avoid accidental copies
44 PermutedView(const PermutedView &) = delete;
46
47 // delete move constructor and move assignment to avoid accidental moves
50
51 template <typename Output>
52 bool may_alias(const Output &output) const noexcept
53 {
54 return t_.may_alias(output); // recurse to underlying tensor
55 }
56
57 // Const version of the access operator, because this is a view
58 template <typename... Indices>
59 requires(sizeof...(Indices) == NumDims)
60 FORCE_INLINE const value_type &operator()(Indices... indices) const noexcept
61 {
62 my_size_t idxArray[] = {static_cast<my_size_t>(indices)...};
63 return t_.data_.data()[layout_.compute_flat_index(idxArray)];
64 }
65
66 // Const version of the access operator with array of indices, because this is a view
67 FORCE_INLINE const value_type &operator()(my_size_t (&indices)[NumDims]) const noexcept
68 {
69 return t_.data_.data()[layout_.compute_flat_index(indices)];
70 }
71
72 FORCE_INLINE const value_type &operator()(const my_size_t *indices) const noexcept
73 {
74 return t_.data_.data()[layout_.compute_flat_index(indices)];
75 }
76
77 template <typename T, my_size_t Bits, typename Arch>
79 {
81 constexpr my_size_t width = K::simdWidth;
82
83 my_size_t idxList[width];
84 for (my_size_t i = 0; i < width; ++i)
85 idxList[i] = layout_.computeOffsetFromFlat(flat + i);
86
87 return K::gather(t_.data_.data(), idxList);
88 }
89
90 FORCE_INLINE constexpr my_size_t getNumDims() const noexcept { return NumDims; }
91
92 FORCE_INLINE constexpr my_size_t getDim(my_size_t i) const // TODO: conditionally noexcept
93 {
94 return layout_.getDim(i);
95 }
96
97 FORCE_INLINE constexpr my_size_t getTotalSize() const noexcept
98 {
99 // total size is the same as the base tensor, simply return it
100 return t_.getTotalSize();
101 }
102
103 // Inverse permutation — restores the base tensor
104 FORCE_INLINE const Tensor &transpose() const noexcept { return t_; }
105
106 // Utility function to retrieve the shape of the tensor as (1,5,6) for a 3D tensor use the getNumDims
107 std::string getShape() const
108 {
109 std::string shape = "(";
110 for (my_size_t i = 0; i < getNumDims(); ++i)
111 {
112 shape += std::to_string(getDim(i));
113 if (i < getNumDims() - 1)
114 shape += ",";
115 }
116 shape += ")";
117 return shape;
118 }
119
120private:
121 const Tensor &t_;
122
123 using Layout = StridedLayout<NumDims>;
124 Layout layout_;
125};
126
127#endif // FUSED_PERMUTED_VIEW_H
Definition BaseExpr.h:15
static void error(const T &msg)
Definition error_handler.h:30
Definition permuted_view.h:12
static constexpr const my_size_t * Dim
Definition permuted_view.h:16
PermutedView(PermutedView &&)=delete
PermutedView(const Tensor &t, const my_size_t perm[NumDims])
Definition permuted_view.h:19
PermutedView & operator=(const PermutedView &)=delete
FORCE_INLINE constexpr my_size_t getNumDims() const noexcept
Definition permuted_view.h:90
static constexpr my_size_t NumDims
Definition permuted_view.h:15
PermutedView(const PermutedView &)=delete
FORCE_INLINE Microkernel< T, Bits, Arch >::VecType evalu(my_size_t flat) const noexcept
Definition permuted_view.h:78
FORCE_INLINE constexpr my_size_t getTotalSize() const noexcept
Definition permuted_view.h:97
std::string getShape() const
Definition permuted_view.h:107
PermutedView & operator=(PermutedView &&)=delete
FORCE_INLINE constexpr my_size_t getDim(my_size_t i) const
Definition permuted_view.h:92
static constexpr my_size_t TotalSize
Definition permuted_view.h:17
bool may_alias(const Output &output) const noexcept
Definition permuted_view.h:52
FORCE_INLINE const value_type & operator()(my_size_t(&indices)[NumDims]) const noexcept
Definition permuted_view.h:67
typename Tensor::value_type value_type
Definition permuted_view.h:14
FORCE_INLINE const value_type & operator()(const my_size_t *indices) const noexcept
Definition permuted_view.h:72
FORCE_INLINE const Tensor & transpose() const noexcept
Definition permuted_view.h:104
Global configuration for the tesseract tensor library.
#define my_size_t
Size/index type used throughout the library.
Definition config.h:126
#define FORCE_INLINE
Hint the compiler to always inline a function.
Definition config.h:26
Definition microkernel_base.h:16
T VecType
Definition microkernel_base.h:18
Definition strided_layout.h:9
FORCE_INLINE my_size_t getDim(my_size_t i) const
Definition strided_layout.h:33
FORCE_INLINE my_size_t computeOffsetFromFlat(my_size_t flat) const noexcept
Definition strided_layout.h:93
FORCE_INLINE my_size_t compute_flat_index(const my_size_t *indices) const
Definition strided_layout.h:77
my_size_t stride[NumberOfDims]
Definition strided_layout.h:11
my_size_t shape[NumberOfDims]
Definition strided_layout.h:10