tesseract++ 0.0.1
N-dimensional tensor library for embedded systems
Loading...
Searching...
No Matches
FmaExpr.h
Go to the documentation of this file.
1#pragma once
2#include "config.h"
3#include "fused/BaseExpr.h"
4#include "fused/Operations.h"
5#include "helper_traits.h"
7
8// ===============================
9// FMA Expression Template (A * B ± C)
10// ===============================
11template <
12 typename A,
13 typename B, typename C,
14 template <typename, my_size_t, typename> class Op>
15class FmaExpr : public BaseExpr<FmaExpr<A, B, C, Op>>
16{
17 static_assert(is_same_v<typename A::value_type, typename B::value_type>,
18 "FmaExpr: A and B must have the same value_type");
19 static_assert(is_same_v<typename A::value_type, typename C::value_type>,
20 "FmaExpr: A and C must have the same value_type");
21
22#ifdef COMPILETIME_CHECK_DIMENSIONS_COUNT_MISMATCH
23 static_assert(A::NumDims == B::NumDims && A::NumDims == C::NumDims,
24 "FmaExpr: number of dimensions mismatch");
25#endif
26#ifdef COMPILETIME_CHECK_DIMENSIONS_SIZE_MISMATCH
27 static_assert(dims_match<A::NumDims>(A::Dim, B::Dim),
28 "FmaExpr: dimension mismatch between A and B");
29 static_assert(dims_match<A::NumDims>(A::Dim, C::Dim),
30 "FmaExpr: dimension mismatch between A and C");
31#endif
32
33 const A &_a;
34 const B &_b;
35 const C &_c;
36
37public:
38 static constexpr my_size_t NumDims = A::NumDims;
39 static constexpr const my_size_t *Dim = A::Dim;
40 static constexpr my_size_t TotalSize = A::TotalSize;
41 using value_type = typename A::value_type;
42 using Layout = typename A::Layout;
43
44 FmaExpr(const A &a, const B &b, const C &c) : _a(a), _b(b), _c(c) {}
45
46 const A &lhs() const noexcept { return _a; }
47 const B &rhs() const noexcept { return _b; }
48 const C &addend() const noexcept { return _c; }
49
50 template <typename Output>
51 bool may_alias(const Output &output) const noexcept
52 {
53 return _a.may_alias(output) || _b.may_alias(output) || _c.may_alias(output);
54 }
55
56 template <my_size_t length>
57 inline auto operator()(my_size_t (&indices)[length]) const noexcept
58 {
59 using T = std::decay_t<decltype(_a(indices))>;
60 return Op<T, 0, GENERICARCH>::apply(
61 _a(indices), _b(indices), _c(indices));
62 }
63
64 template <typename T, my_size_t Bits, typename Arch>
65 inline typename Op<T, Bits, Arch>::type evalu(my_size_t flat) const noexcept
66 {
67 return Op<T, Bits, Arch>::apply(
68 _a.template evalu<T, Bits, Arch>(flat),
69 _b.template evalu<T, Bits, Arch>(flat),
70 _c.template evalu<T, Bits, Arch>(flat));
71 }
72
73 template <typename T, my_size_t Bits, typename Arch>
74 inline typename Op<T, Bits, Arch>::type logical_evalu(my_size_t logical_flat) const noexcept
75 {
76 return Op<T, Bits, Arch>::apply(
77 _a.template logical_evalu<T, Bits, Arch>(logical_flat),
78 _b.template logical_evalu<T, Bits, Arch>(logical_flat),
79 _c.template logical_evalu<T, Bits, Arch>(logical_flat));
80 }
81
82 inline my_size_t getNumDims() const noexcept { return _a.getNumDims(); }
83 inline my_size_t getDim(my_size_t i) const { return _a.getDim(i); }
84 my_size_t getTotalSize() const noexcept { return _a.getTotalSize(); }
85
86// protected:
87// inline auto operator()(const my_size_t *indices) const noexcept
88// {
89// using T = std::decay_t<decltype(_a(indices))>;
90// return Op<T, 0, GENERICARCH>::apply(
91// _a(indices), _b(indices), _c(indices));
92// }
93};
94
95// ===============================
96// Scalar FMA Expression Template (A * scalar ± C)
97// ===============================
98template <
99 typename EXPR,
100 typename ScalarT, typename C,
101 template <typename, my_size_t, typename> class Op>
102class ScalarFmaExpr : public BaseExpr<ScalarFmaExpr<EXPR, ScalarT, C, Op>>
103{
104 static_assert(is_same_v<typename EXPR::value_type, ScalarT>,
105 "ScalarFmaExpr: EXPR value_type and ScalarT must be the same");
106 static_assert(is_same_v<typename EXPR::value_type, typename C::value_type>,
107 "ScalarFmaExpr: EXPR and C must have the same value_type");
108
109#ifdef COMPILETIME_CHECK_DIMENSIONS_COUNT_MISMATCH
110 static_assert(EXPR::NumDims == C::NumDims,
111 "ScalarFmaExpr: number of dimensions mismatch");
112#endif
113#ifdef COMPILETIME_CHECK_DIMENSIONS_SIZE_MISMATCH
114 static_assert(dims_match<EXPR::NumDims>(EXPR::Dim, C::Dim),
115 "ScalarFmaExpr: dimension mismatch between EXPR and C");
116#endif
117
118 const EXPR &_expr;
119 ScalarT _scalar;
120 const C &_c;
121
122public:
123 static constexpr my_size_t NumDims = EXPR::NumDims;
124 static constexpr const my_size_t *Dim = EXPR::Dim;
125 static constexpr my_size_t TotalSize = EXPR::TotalSize;
126 using value_type = typename EXPR::value_type;
127 using Layout = typename EXPR::Layout;
128
129 ScalarFmaExpr(const EXPR &expr, ScalarT scalar, const C &c)
130 : _expr(expr), _scalar(scalar), _c(c) {}
131
132 const EXPR &expr() const noexcept { return _expr; }
133 ScalarT scalar() const noexcept { return _scalar; }
134 const C &addend() const noexcept { return _c; }
135
136 template <typename Output>
137 bool may_alias(const Output &output) const noexcept
138 {
139 return _expr.may_alias(output) || _c.may_alias(output);
140 }
141
142 template <my_size_t length>
143 inline auto operator()(my_size_t (&indices)[length]) const noexcept
144 {
145 using T = std::decay_t<decltype(_expr(indices))>;
146 return Op<T, 0, GENERICARCH>::apply(
147 _expr(indices), _scalar, _c(indices));
148 }
149
150 template <typename T, my_size_t Bits, typename Arch>
151 inline typename Op<T, Bits, Arch>::type evalu(my_size_t flat) const noexcept
152 {
153 return Op<T, Bits, Arch>::apply(
154 _expr.template evalu<T, Bits, Arch>(flat),
155 _scalar,
156 _c.template evalu<T, Bits, Arch>(flat));
157 }
158
159 template <typename T, my_size_t Bits, typename Arch>
160 inline typename Op<T, Bits, Arch>::type logical_evalu(my_size_t logical_flat) const noexcept
161 {
162 return Op<T, Bits, Arch>::apply(
163 _expr.template logical_evalu<T, Bits, Arch>(logical_flat),
164 _scalar,
165 _c.template logical_evalu<T, Bits, Arch>(logical_flat));
166 }
167
168 inline my_size_t getNumDims() const noexcept { return _expr.getNumDims(); }
169 inline my_size_t getDim(my_size_t i) const { return _expr.getDim(i); }
170 my_size_t getTotalSize() const noexcept { return _expr.getTotalSize(); }
171
172// protected:
173// inline auto operator()(const my_size_t *indices) const noexcept
174// {
175// using T = std::decay_t<decltype(_expr(indices))>;
176// return Op<T, 0, GENERICARCH>::apply(
177// _expr(indices), _scalar, _c(indices));
178// }
179};
Definition BaseExpr.h:15
Definition FmaExpr.h:16
static constexpr const my_size_t * Dim
Definition FmaExpr.h:39
const A & lhs() const noexcept
Definition FmaExpr.h:46
Op< T, Bits, Arch >::type evalu(my_size_t flat) const noexcept
Definition FmaExpr.h:65
my_size_t getNumDims() const noexcept
Definition FmaExpr.h:82
my_size_t getTotalSize() const noexcept
Definition FmaExpr.h:84
Op< T, Bits, Arch >::type logical_evalu(my_size_t logical_flat) const noexcept
Definition FmaExpr.h:74
auto operator()(my_size_t(&indices)[length]) const noexcept
Definition FmaExpr.h:57
static constexpr my_size_t TotalSize
Definition FmaExpr.h:40
const B & rhs() const noexcept
Definition FmaExpr.h:47
static constexpr my_size_t NumDims
Definition FmaExpr.h:38
my_size_t getDim(my_size_t i) const
Definition FmaExpr.h:83
const C & addend() const noexcept
Definition FmaExpr.h:48
typename A::value_type value_type
Definition FmaExpr.h:41
FmaExpr(const A &a, const B &b, const C &c)
Definition FmaExpr.h:44
typename A::Layout Layout
Definition FmaExpr.h:42
bool may_alias(const Output &output) const noexcept
Definition FmaExpr.h:51
Definition FmaExpr.h:103
my_size_t getDim(my_size_t i) const
Definition FmaExpr.h:169
Op< T, Bits, Arch >::type logical_evalu(my_size_t logical_flat) const noexcept
Definition FmaExpr.h:160
static constexpr my_size_t TotalSize
Definition FmaExpr.h:125
my_size_t getNumDims() const noexcept
Definition FmaExpr.h:168
ScalarFmaExpr(const EXPR &expr, ScalarT scalar, const C &c)
Definition FmaExpr.h:129
typename EXPR::Layout Layout
Definition FmaExpr.h:127
const C & addend() const noexcept
Definition FmaExpr.h:134
auto operator()(my_size_t(&indices)[length]) const noexcept
Definition FmaExpr.h:143
Op< T, Bits, Arch >::type evalu(my_size_t flat) const noexcept
Definition FmaExpr.h:151
bool may_alias(const Output &output) const noexcept
Definition FmaExpr.h:137
const EXPR & expr() const noexcept
Definition FmaExpr.h:132
static constexpr my_size_t NumDims
Definition FmaExpr.h:123
typename EXPR::value_type value_type
Definition FmaExpr.h:126
my_size_t getTotalSize() const noexcept
Definition FmaExpr.h:170
ScalarT scalar() const noexcept
Definition FmaExpr.h:133
static constexpr const my_size_t * Dim
Definition FmaExpr.h:124
Global configuration for the tesseract tensor library.
#define my_size_t
Size/index type used throughout the library.
Definition config.h:126