13 typename B,
typename C,
14 template <
typename, my_
size_t,
typename>
class Op>
17 static_assert(is_same_v<typename A::value_type, typename B::value_type>,
18 "FmaExpr: A and B must have the same value_type");
19 static_assert(is_same_v<typename A::value_type, typename C::value_type>,
20 "FmaExpr: A and C must have the same value_type");
22#ifdef COMPILETIME_CHECK_DIMENSIONS_COUNT_MISMATCH
23 static_assert(A::NumDims == B::NumDims && A::NumDims == C::NumDims,
24 "FmaExpr: number of dimensions mismatch");
26#ifdef COMPILETIME_CHECK_DIMENSIONS_SIZE_MISMATCH
27 static_assert(dims_match<A::NumDims>(A::Dim, B::Dim),
28 "FmaExpr: dimension mismatch between A and B");
29 static_assert(dims_match<A::NumDims>(A::Dim, C::Dim),
30 "FmaExpr: dimension mismatch between A and C");
44 FmaExpr(
const A &a,
const B &b,
const C &c) : _a(a), _b(b), _c(c) {}
46 const A &
lhs() const noexcept {
return _a; }
47 const B &
rhs() const noexcept {
return _b; }
48 const C &
addend() const noexcept {
return _c; }
50 template <
typename Output>
53 return _a.may_alias(output) || _b.may_alias(output) || _c.may_alias(output);
56 template <my_
size_t length>
59 using T = std::decay_t<
decltype(_a(indices))>;
60 return Op<T, 0, GENERICARCH>::apply(
61 _a(indices), _b(indices), _c(indices));
64 template <
typename T, my_
size_t Bits,
typename Arch>
67 return Op<T, Bits, Arch>::apply(
68 _a.template evalu<T, Bits, Arch>(flat),
69 _b.template evalu<T, Bits, Arch>(flat),
70 _c.template evalu<T, Bits, Arch>(flat));
73 template <
typename T, my_
size_t Bits,
typename Arch>
76 return Op<T, Bits, Arch>::apply(
77 _a.template logical_evalu<T, Bits, Arch>(logical_flat),
78 _b.template logical_evalu<T, Bits, Arch>(logical_flat),
79 _c.template logical_evalu<T, Bits, Arch>(logical_flat));
100 typename ScalarT,
typename C,
101 template <
typename, my_
size_t,
typename>
class Op>
104 static_assert(is_same_v<typename EXPR::value_type, ScalarT>,
105 "ScalarFmaExpr: EXPR value_type and ScalarT must be the same");
106 static_assert(is_same_v<typename EXPR::value_type, typename C::value_type>,
107 "ScalarFmaExpr: EXPR and C must have the same value_type");
109#ifdef COMPILETIME_CHECK_DIMENSIONS_COUNT_MISMATCH
110 static_assert(EXPR::NumDims == C::NumDims,
111 "ScalarFmaExpr: number of dimensions mismatch");
113#ifdef COMPILETIME_CHECK_DIMENSIONS_SIZE_MISMATCH
114 static_assert(dims_match<EXPR::NumDims>(EXPR::Dim, C::Dim),
115 "ScalarFmaExpr: dimension mismatch between EXPR and C");
132 const EXPR &
expr() const noexcept {
return _expr; }
133 ScalarT
scalar() const noexcept {
return _scalar; }
134 const C &
addend() const noexcept {
return _c; }
136 template <
typename Output>
139 return _expr.may_alias(output) || _c.may_alias(output);
142 template <my_
size_t length>
145 using T = std::decay_t<
decltype(_expr(indices))>;
146 return Op<T, 0, GENERICARCH>::apply(
147 _expr(indices), _scalar, _c(indices));
150 template <
typename T, my_
size_t Bits,
typename Arch>
153 return Op<T, Bits, Arch>::apply(
154 _expr.template evalu<T, Bits, Arch>(flat),
156 _c.template evalu<T, Bits, Arch>(flat));
159 template <
typename T, my_
size_t Bits,
typename Arch>
162 return Op<T, Bits, Arch>::apply(
163 _expr.template logical_evalu<T, Bits, Arch>(logical_flat),
165 _c.template logical_evalu<T, Bits, Arch>(logical_flat));
static constexpr const my_size_t * Dim
Definition FmaExpr.h:39
const A & lhs() const noexcept
Definition FmaExpr.h:46
Op< T, Bits, Arch >::type evalu(my_size_t flat) const noexcept
Definition FmaExpr.h:65
my_size_t getNumDims() const noexcept
Definition FmaExpr.h:82
my_size_t getTotalSize() const noexcept
Definition FmaExpr.h:84
Op< T, Bits, Arch >::type logical_evalu(my_size_t logical_flat) const noexcept
Definition FmaExpr.h:74
auto operator()(my_size_t(&indices)[length]) const noexcept
Definition FmaExpr.h:57
static constexpr my_size_t TotalSize
Definition FmaExpr.h:40
const B & rhs() const noexcept
Definition FmaExpr.h:47
static constexpr my_size_t NumDims
Definition FmaExpr.h:38
my_size_t getDim(my_size_t i) const
Definition FmaExpr.h:83
const C & addend() const noexcept
Definition FmaExpr.h:48
typename A::value_type value_type
Definition FmaExpr.h:41
FmaExpr(const A &a, const B &b, const C &c)
Definition FmaExpr.h:44
typename A::Layout Layout
Definition FmaExpr.h:42
bool may_alias(const Output &output) const noexcept
Definition FmaExpr.h:51
my_size_t getDim(my_size_t i) const
Definition FmaExpr.h:169
Op< T, Bits, Arch >::type logical_evalu(my_size_t logical_flat) const noexcept
Definition FmaExpr.h:160
static constexpr my_size_t TotalSize
Definition FmaExpr.h:125
my_size_t getNumDims() const noexcept
Definition FmaExpr.h:168
ScalarFmaExpr(const EXPR &expr, ScalarT scalar, const C &c)
Definition FmaExpr.h:129
typename EXPR::Layout Layout
Definition FmaExpr.h:127
const C & addend() const noexcept
Definition FmaExpr.h:134
auto operator()(my_size_t(&indices)[length]) const noexcept
Definition FmaExpr.h:143
Op< T, Bits, Arch >::type evalu(my_size_t flat) const noexcept
Definition FmaExpr.h:151
bool may_alias(const Output &output) const noexcept
Definition FmaExpr.h:137
const EXPR & expr() const noexcept
Definition FmaExpr.h:132
static constexpr my_size_t NumDims
Definition FmaExpr.h:123
typename EXPR::value_type value_type
Definition FmaExpr.h:126
my_size_t getTotalSize() const noexcept
Definition FmaExpr.h:170
ScalarT scalar() const noexcept
Definition FmaExpr.h:133
static constexpr const my_size_t * Dim
Definition FmaExpr.h:124
Global configuration for the tesseract tensor library.
#define my_size_t
Size/index type used throughout the library.
Definition config.h:126