tesseract++ 0.0.1
N-dimensional tensor library for embedded systems
Loading...
Searching...
No Matches
arithmetic.h
Go to the documentation of this file.
1#pragma once
2#include "config.h"
3#include "fused/BinaryExpr.h"
4#include "fused/ScalarExpr.h"
5#include "fused/FmaExpr.h"
6#include "fused/Operations.h"
10
11// ===============================
12// Operator Overloads
13// ===============================
14
15// ===============================
16// FMA detection: operator+
17// ===============================
18
19#ifdef TESSERACT_USE_FMAD
20// (A * B) + C → Fma
// Fused rewrite of `operator+` chosen when the left operand is an element-wise
// product node, BinaryExpr<L, R, Mul>. Compiled only inside the
// TESSERACT_USE_FMAD guard.
// NOTE(review): the return-type lines and the second parameter are not visible
// in this listing; the body refers to the second operand as `rhs`.
// Returns an FmaExpr<L, R, C, Fma> built from the product's two factors and
// the addend, instead of nesting a Mul node inside an Add node.
21template <typename L, typename R, typename C>
25operator+(const BaseExpr<BinaryExpr<L, R, Mul>> &lhs,
27{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
28#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
29 checkDimsMatch(lhs.derived(), rhs.derived(), "operator+ for FMA pattern (A*B + C)");
30#endif
// Unwrap the product node so its factors feed the fused node directly.
31 const auto &mul = lhs.derived();
32 return FmaExpr<L, R, C, Fma>(mul.lhs(), mul.rhs(), rhs.derived());
33}
34
35// C + (A * B) → Fma
// Mirror of the (A * B) + C rewrite with the product on the right-hand side.
// NOTE(review): the return-type lines and the second parameter (used as `rhs`
// in the body, presumably a BinaryExpr<L, R, Mul> expression) are not visible
// in this listing.
// Returns an FmaExpr<L, R, C, Fma>; the factors come from `rhs`, the addend is `lhs`.
36template <typename C, typename L, typename R>
40operator+(const BaseExpr<C> &lhs,
42{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
43#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
44 checkDimsMatch(lhs.derived(), rhs.derived(), "operator+ for FMA pattern (C + A*B)");
45#endif
// The product is the right operand here, so it is the one taken apart.
46 const auto &mul = rhs.derived();
47 return FmaExpr<L, R, C, Fma>(mul.lhs(), mul.rhs(), lhs.derived());
48}
49
50// (A * scalar) + C → Fma
// Fused rewrite of `operator+` chosen when the left operand is an
// expression-times-scalar node, ScalarExprRHS<L, T, Mul>.
// NOTE(review): the return-type lines and the second parameter (`rhs` in the
// body) are not visible in this listing.
// Returns a ScalarFmaExpr<L, T, C, Fma> built from the inner expression, the
// scalar factor and the addend.
51template <typename L, typename T, typename C>
55operator+(const BaseExpr<ScalarExprRHS<L, T, Mul>> &lhs,
57{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
58#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
59 checkDimsMatch(lhs.derived(), rhs.derived(), "operator+ for FMA pattern (A*scalar + C)");
60#endif
// Split the scaled node into its expression and scalar parts.
61 const auto &mul = lhs.derived();
62 return ScalarFmaExpr<L, T, C, Fma>(mul.expr(), mul.scalar(), rhs.derived());
63}
64
65// C + (A * scalar) → Fma
// Mirror of the (A * scalar) + C rewrite with the scaled expression on the right.
// NOTE(review): the return-type lines and the second parameter (used as `rhs`
// in the body, presumably a ScalarExprRHS<L, T, Mul> expression) are not
// visible in this listing.
// Returns a ScalarFmaExpr<L, T, C, Fma>; expression and scalar come from `rhs`,
// the addend is `lhs`.
66template <typename C, typename L, typename T>
70operator+(const BaseExpr<C> &lhs,
72{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
73#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
74 checkDimsMatch(lhs.derived(), rhs.derived(), "operator+ for FMA pattern (C + A*scalar)");
75#endif
// The scaled expression is the right operand, so it is the one taken apart.
76 const auto &mul = rhs.derived();
77 return ScalarFmaExpr<L, T, C, Fma>(mul.expr(), mul.scalar(), lhs.derived());
78}
79
80// -(A * B) + C → Fnma
// Fused rewrite of `operator+` chosen when the left operand has the type
// ScalarExprLHS<BinaryExpr<L, R, Mul>, T, Sub>, which is what unary minus in
// this file produces for a product (it stores T(0) as the scalar).
// NOTE(review): the return-type lines and the second parameter (`rhs` in the
// body) are not visible in this listing.
// NOTE(review): the same ScalarExprLHS<..., T, Sub> type is also returned by the
// `scalar - expr` overload in this file, and the scalar stored in `neg` is
// never read here. An expression such as (s - A*B) + C with s != 0 would
// therefore presumably be rewritten as C - A*B, dropping s. Confirm, and
// consider a dedicated negation node/tag so only true negation matches.
// Returns an FmaExpr<L, R, C, Fnma> built from the product's factors and `rhs`.
81template <typename L, typename R, typename T, typename C>
85operator+(const BaseExpr<ScalarExprLHS<BinaryExpr<L, R, Mul>, T, Sub>> &lhs,
87{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
88#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
89 checkDimsMatch(lhs.derived(), rhs.derived(), "operator+ for FMA pattern (-(A*B) + C)");
90#endif
// Peel two layers: the "scalar minus expr" wrapper, then the product inside it.
91 const auto &neg = lhs.derived();
92 const auto &mul = neg.expr(); // the BinaryExpr<L, R, Mul>
93 return FmaExpr<L, R, C, Fnma>(mul.lhs(), mul.rhs(), rhs.derived());
94}
95
96// -(A * scalar) + C → Fnma
// Fused rewrite of `operator+` chosen when the left operand has the type
// ScalarExprLHS<ScalarExprRHS<L, T1, Mul>, T2, Sub>, which is what unary minus
// in this file produces for a scaled expression (it stores T2(0) as the scalar).
// NOTE(review): the return-type lines and the second parameter (`rhs` in the
// body) are not visible in this listing.
// NOTE(review): the `scalar - expr` overload in this file returns the same
// ScalarExprLHS<..., Sub> type, and the outer scalar held by `neg` is never
// read here. (s - A*k) + C with s != 0 would presumably be rewritten as
// C - A*k, dropping s. Confirm, and consider a dedicated negation node/tag.
// Returns a ScalarFmaExpr<L, T1, C, Fnma> built from the inner expression, its
// scalar factor (type T1) and `rhs`.
97template <typename L, typename T1, typename T2, typename C>
101operator+(const BaseExpr<ScalarExprLHS<ScalarExprRHS<L, T1, Mul>, T2, Sub>> &lhs,
103{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
104#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
105 checkDimsMatch(lhs.derived(), rhs.derived(), "operator+ for FMA pattern (-(A*scalar) + C)");
106#endif
// Peel the "scalar minus expr" wrapper, then the scaled expression inside it.
107 const auto &neg = lhs.derived();
108 const auto &mul = neg.expr();
109 return ScalarFmaExpr<L, T1, C, Fnma>(mul.expr(), mul.scalar(), rhs.derived());
110}
111
112// ===============================
113// FMA detection: operator-
114// ===============================
115
116// (A * B) - C → Fms
// Fused rewrite of `operator-` chosen when the left operand is an element-wise
// product node, BinaryExpr<L, R, Mul>. Compiled only inside the
// TESSERACT_USE_FMAD guard.
// NOTE(review): the return-type lines and the second parameter (`rhs` in the
// body) are not visible in this listing.
// Returns an FmaExpr<L, R, C, Fms> built from the product's factors and the
// subtrahend.
117template <typename L, typename R, typename C>
121operator-(const BaseExpr<BinaryExpr<L, R, Mul>> &lhs,
123{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
124#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
125 checkDimsMatch(lhs.derived(), rhs.derived(), "operator- for FMA pattern (A*B - C)");
126#endif
// Unwrap the product node so its factors feed the fused node directly.
127 const auto &mul = lhs.derived();
128 return FmaExpr<L, R, C, Fms>(mul.lhs(), mul.rhs(), rhs.derived());
129}
130
131// C - (A * B) → Fnma
// Fused rewrite of `operator-` with the product on the right-hand side.
// NOTE(review): the return-type lines and the second parameter (used as `rhs`
// in the body, presumably a BinaryExpr<L, R, Mul> expression) are not visible
// in this listing.
// Returns an FmaExpr<L, R, C, Fnma>, the same operation tag the -(A * B) + C
// rewrite uses; the factors come from `rhs` and `lhs` is passed as the third operand.
132template <typename C, typename L, typename R>
136operator-(const BaseExpr<C> &lhs,
138{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
139#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
140 checkDimsMatch(lhs.derived(), rhs.derived(), "operator- for FMA pattern (C - A*B)");
141#endif
// The product is the right operand here, so it is the one taken apart.
142 const auto &mul = rhs.derived();
143 return FmaExpr<L, R, C, Fnma>(mul.lhs(), mul.rhs(), lhs.derived());
144}
145
146// -(A * B) - C → Fnms
// Fused rewrite of `operator-` chosen when the left operand has the type
// ScalarExprLHS<BinaryExpr<L, R, Mul>, T, Sub>, which is what unary minus in
// this file produces for a product (it stores T(0) as the scalar).
// NOTE(review): the return-type lines and the second parameter (`rhs` in the
// body) are not visible in this listing.
// NOTE(review): the `scalar - expr` overload in this file returns the same
// ScalarExprLHS<..., T, Sub> type, and the scalar held by `neg` is never read
// here. (s - A*B) - C with s != 0 would presumably be rewritten as
// -(A*B) - C, dropping s. Confirm, and consider a dedicated negation node/tag.
// Returns an FmaExpr<L, R, C, Fnms> built from the product's factors and `rhs`.
147template <typename L, typename R, typename T, typename C>
151operator-(const BaseExpr<ScalarExprLHS<BinaryExpr<L, R, Mul>, T, Sub>> &lhs,
153{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
154#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
155 checkDimsMatch(lhs.derived(), rhs.derived(), "operator- for FMA pattern (-(A*B) - C)");
156#endif
// Peel the "scalar minus expr" wrapper, then the product inside it.
157 const auto &neg = lhs.derived();
158 const auto &mul = neg.expr();
159 return FmaExpr<L, R, C, Fnms>(mul.lhs(), mul.rhs(), rhs.derived());
160}
161
162// -(A * scalar) - C → Fnms
// Fused rewrite of `operator-` chosen when the left operand has the type
// ScalarExprLHS<ScalarExprRHS<L, T1, Mul>, T2, Sub>, which is what unary minus
// in this file produces for a scaled expression (it stores T2(0) as the scalar).
// NOTE(review): the return-type lines and the second parameter (`rhs` in the
// body) are not visible in this listing.
// NOTE(review): the `scalar - expr` overload in this file returns the same
// ScalarExprLHS<..., Sub> type, and the outer scalar held by `neg` is never
// read here. (s - A*k) - C with s != 0 would presumably be rewritten as
// -(A*k) - C, dropping s. Confirm, and consider a dedicated negation node/tag.
// Returns a ScalarFmaExpr<L, T1, C, Fnms> built from the inner expression, its
// scalar factor (type T1) and `rhs`.
163template <typename L, typename T1, typename T2, typename C>
167operator-(const BaseExpr<ScalarExprLHS<ScalarExprRHS<L, T1, Mul>, T2, Sub>> &lhs,
169{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
170#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
171 checkDimsMatch(lhs.derived(), rhs.derived(), "operator- for FMA pattern (-(A*scalar) - C)");
172#endif
// Peel the "scalar minus expr" wrapper, then the scaled expression inside it.
173 const auto &neg = lhs.derived();
174 const auto &mul = neg.expr();
175 return ScalarFmaExpr<L, T1, C, Fnms>(mul.expr(), mul.scalar(), rhs.derived());
176}
177
178// (A * scalar) - C → Fms
// Fused rewrite of `operator-` chosen when the left operand is an
// expression-times-scalar node, ScalarExprRHS<L, T, Mul>.
// NOTE(review): the return-type lines and the second parameter (`rhs` in the
// body) are not visible in this listing.
// Returns a ScalarFmaExpr<L, T, C, Fms> built from the inner expression, the
// scalar factor and the subtrahend.
179template <typename L, typename T, typename C>
183operator-(const BaseExpr<ScalarExprRHS<L, T, Mul>> &lhs,
185{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
186#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
187 checkDimsMatch(lhs.derived(), rhs.derived(), "operator- for FMA pattern (A*scalar - C)");
188#endif
// Split the scaled node into its expression and scalar parts.
189 const auto &mul = lhs.derived();
190 return ScalarFmaExpr<L, T, C, Fms>(mul.expr(), mul.scalar(), rhs.derived());
191}
192
193// C - (A * scalar) → Fnma
// Fused rewrite of `operator-` with the scaled expression on the right-hand side.
// NOTE(review): the return-type lines and the second parameter (used as `rhs`
// in the body, presumably a ScalarExprRHS<L, T, Mul> expression) are not
// visible in this listing.
// Returns a ScalarFmaExpr<L, T, C, Fnma>; expression and scalar come from
// `rhs`, and `lhs` is passed as the third operand.
194template <typename C, typename L, typename T>
198operator-(const BaseExpr<C> &lhs,
200{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
201#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
202 checkDimsMatch(lhs.derived(), rhs.derived(), "operator- for FMA pattern (C - A*scalar)");
203#endif
// The scaled expression is the right operand, so it is the one taken apart.
204 const auto &mul = rhs.derived();
205 return ScalarFmaExpr<L, T, C, Fnma>(mul.expr(), mul.scalar(), lhs.derived());
206}
207#endif // TESSERACT_USE_FMAD
208
209// ===============================
210// Element-wise binary operators: +, -, *, /
211// ===============================
212
// Generic expression + expression overload.
// NOTE(review): the signature lines (return type, function name, parameters)
// are not visible in this listing; the name is presumably `operator+`, matching
// the label passed to checkDimsMatch, and the body uses operands `lhs` and `rhs`.
// Returns a BinaryExpr<LHS, RHS, Add> node holding both unwrapped operands.
213template <typename LHS, typename RHS>
217{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
218#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
219 checkDimsMatch(lhs.derived(), rhs.derived(), "operator+");
220#endif
221 return BinaryExpr<LHS, RHS, Add>(lhs.derived(), rhs.derived());
222}
223
// Generic expression - expression overload.
// NOTE(review): the signature lines (return type, function name, parameters)
// are not visible in this listing; the name is presumably `operator-`, matching
// the label passed to checkDimsMatch, and the body uses operands `lhs` and `rhs`.
// Returns a BinaryExpr<LHS, RHS, Sub> node holding both unwrapped operands.
224template <typename LHS, typename RHS>
228{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
229#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
230 checkDimsMatch(lhs.derived(), rhs.derived(), "operator-");
231#endif
232 return BinaryExpr<LHS, RHS, Sub>(lhs.derived(), rhs.derived());
233}
234
// Element-wise (Hadamard) expression * expression overload.
// NOTE(review): the body of the requires-clause and the signature lines are
// not visible in this listing; the name is presumably `operator*`, matching the
// label passed to checkDimsMatch, and the body uses operands `lhs` and `rhs`.
// Returns a BinaryExpr<LHS, RHS, Mul> node holding both unwrapped operands.
235template <typename LHS, typename RHS>
236 requires( // the Hadamard product is offered only for tensors, not for general algebras
243{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
244#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
245 checkDimsMatch(lhs.derived(), rhs.derived(), "operator*");
246#endif
247 return BinaryExpr<LHS, RHS, Mul>(lhs.derived(), rhs.derived());
248}
249
// Element-wise expression / expression overload.
// NOTE(review): the body of the requires-clause and the signature lines are
// not visible in this listing; the name is presumably `operator/`, matching the
// label passed to checkDimsMatch, and the body uses operands `lhs` and `rhs`.
// Returns a BinaryExpr<LHS, RHS, Div> node holding both unwrapped operands.
250template <typename LHS, typename RHS>
251 requires( // element-wise division is offered only for tensors, not for general algebras
258{
// Optional runtime shape check; active when either RUNTIME_CHECK_* macro is defined.
259#if defined(RUNTIME_CHECK_DIMENSIONS_COUNT_MISMATCH) || defined(RUNTIME_CHECK_DIMENSIONS_SIZE_MISMATCH)
260 checkDimsMatch(lhs.derived(), rhs.derived(), "operator/");
261#endif
262 return BinaryExpr<LHS, RHS, Div>(lhs.derived(), rhs.derived());
263}
264
265// matrix + scalar (scalar on RHS)
// Wraps the expression and the scalar (taken by value) in a
// ScalarExprRHS<LHS, T, Add> node. Declared noexcept; no dimension check is made.
// NOTE(review): only the tail of the constraint is visible in this listing. It
// rejects T when is_base_of_v<detail::BaseExprTag, T> holds, presumably so that
// expression types are not treated as scalars. The return-type line is also
// not visible.
266template <typename LHS, typename T>
268 !is_base_of_v<detail::BaseExprTag, T>)
270operator+(const BaseExpr<LHS> &lhs, T scalar) noexcept
271{
272 return ScalarExprRHS<LHS, T, Add>(lhs.derived(), scalar);
273}
274
275// scalar + matrix (scalar on LHS)
// Builds the same kind of node as `expr + scalar`: a ScalarExprRHS<RHS, T, Add>
// with the expression first and the scalar second, i.e. the operands are
// stored in swapped order relative to how they were written.
// Declared noexcept; no dimension check is made.
// NOTE(review): only the tail of the constraint is visible in this listing. It
// rejects T when is_base_of_v<detail::BaseExprTag, T> holds. The return-type
// line is also not visible.
276template <typename RHS, typename T>
278 !is_base_of_v<detail::BaseExprTag, T>)
280operator+(T scalar, const BaseExpr<RHS> &rhs) noexcept
281{
282 return ScalarExprRHS<RHS, T, Add>(rhs.derived(), scalar);
283}
284
285// Unary operator-: negation of an expression
// Expressed as "zero minus expr": returns a ScalarExprLHS<RHS, T, Sub> holding
// the expression and a zero of the expression's value_type. Declared noexcept.
// This is the same node type that the `scalar - expr` overload in this file
// returns; only the stored scalar value (zero) distinguishes a negation.
// NOTE(review): the return-type lines are not visible in this listing.
286template <typename RHS>
289operator-(const BaseExpr<RHS> &expr) noexcept
290{
// The scalar type is taken from the expression itself, so no conversion is needed.
291 using T = typename RHS::value_type;
292 return ScalarExprLHS<RHS, T, Sub>(expr.derived(), T(0)); // negation is 0 - expr
293}
294
295// matrix - scalar (scalar on RHS)
// Wraps the expression and the scalar (taken by value) in a
// ScalarExprRHS<LHS, T, Sub> node. Declared noexcept; no dimension check is made.
// NOTE(review): only the tail of the constraint is visible in this listing. It
// rejects T when is_base_of_v<detail::BaseExprTag, T> holds. The return-type
// line is also not visible.
296template <typename LHS, typename T>
298 !is_base_of_v<detail::BaseExprTag, T>)
300operator-(const BaseExpr<LHS> &lhs, T scalar) noexcept
301{
302 return ScalarExprRHS<LHS, T, Sub>(lhs.derived(), scalar);
303}
304
305// scalar - matrix (scalar on LHS)
// Subtraction is order-sensitive, so this overload uses the ScalarExprLHS
// variant (scalar as the left operand) and returns ScalarExprLHS<RHS, T, Sub>.
// Unary minus in this file returns the same node type with a zero scalar.
// Declared noexcept; no dimension check is made.
// NOTE(review): only the tail of the constraint is visible in this listing. It
// rejects T when is_base_of_v<detail::BaseExprTag, T> holds. The return-type
// line is also not visible.
306template <typename RHS, typename T>
308 !is_base_of_v<detail::BaseExprTag, T>)
310operator-(T scalar, const BaseExpr<RHS> &rhs) noexcept
311{
312 return ScalarExprLHS<RHS, T, Sub>(rhs.derived(), scalar);
313}
314
315// matrix * scalar (scalar on RHS)
// Wraps the expression and the scalar (taken by value) in a
// ScalarExprRHS<LHS, T, Mul> node. This is the node type the
// TESSERACT_USE_FMAD overloads in this file match as "(A * scalar)".
// Declared noexcept; no dimension check is made.
// NOTE(review): only the tail of the constraint is visible in this listing. It
// rejects T when is_base_of_v<detail::BaseExprTag, T> holds. The return-type
// line is also not visible.
316template <typename LHS, typename T>
318 !is_base_of_v<detail::BaseExprTag, T>)
320operator*(const BaseExpr<LHS> &lhs, T scalar) noexcept
321{
322 return ScalarExprRHS<LHS, T, Mul>(lhs.derived(), scalar);
323}
324
325// scalar * matrix (scalar on LHS)
// Builds the same kind of node as `expr * scalar`: a ScalarExprRHS<RHS, T, Mul>
// with the expression first and the scalar second, so both spellings yield the
// node type that the FMA overloads in this file match.
// Declared noexcept; no dimension check is made.
// NOTE(review): only the tail of the constraint is visible in this listing. It
// rejects T when is_base_of_v<detail::BaseExprTag, T> holds. The return-type
// line is also not visible.
326template <typename RHS, typename T>
328 !is_base_of_v<detail::BaseExprTag, T>)
330operator*(T scalar, const BaseExpr<RHS> &rhs) noexcept
331{
332 return ScalarExprRHS<RHS, T, Mul>(rhs.derived(), scalar);
333}
334
335// matrix / scalar (scalar on RHS)
// Wraps the expression and the scalar (taken by value) in a
// ScalarExprRHS<LHS, T, Div> node. Declared noexcept; neither a dimension check
// nor a zero-divisor check is made here.
// NOTE(review): only the tail of the constraint is visible in this listing. It
// rejects T when is_base_of_v<detail::BaseExprTag, T> holds. The return-type
// line is also not visible.
336template <typename LHS, typename T>
338 !is_base_of_v<detail::BaseExprTag, T>)
340operator/(const BaseExpr<LHS> &lhs, T scalar) noexcept
341{
342 return ScalarExprRHS<LHS, T, Div>(lhs.derived(), scalar);
343}
344
345// scalar / matrix (scalar on LHS)
// Division is order-sensitive, so this overload uses the ScalarExprLHS variant
// (scalar as the left operand) and returns ScalarExprLHS<RHS, T, Div>.
// Declared noexcept; no dimension check is made.
// NOTE(review): only the tail of the constraint is visible in this listing. It
// rejects T when is_base_of_v<detail::BaseExprTag, T> holds. The return-type
// line is also not visible.
346template <typename RHS, typename T>
348 !is_base_of_v<detail::BaseExprTag, T>)
350operator/(T scalar, const BaseExpr<RHS> &rhs) noexcept
351{
352 return ScalarExprLHS<RHS, T, Div>(rhs.derived(), scalar);
353}
Definition BaseExpr.h:15
const Derived & derived() const
Definition BaseExpr.h:17
Definition BinaryExpr.h:15
Definition FmaExpr.h:16
Definition ScalarExpr.h:95
Definition ScalarExpr.h:14
Definition FmaExpr.h:103
Global configuration for the tesseract tensor library.
#define TESSERACT_CONDITIONAL_NOEXCEPT
Definition config.h:56
constexpr bool is_vector_space_v
Definition basic_algebraic_traits.h:123
void checkDimsMatch(const Expr1 &lhs, const Expr2 &rhs, const char *opName) TESSERACT_CONDITIONAL_NOEXCEPT
Definition operators_common.h:5
Definition Operations.h:109