tesseract++ 0.0.1 — N-dimensional tensor library for embedded systems.

Generated documentation: source listing for kernel_dot.h.
(Original file lines 1-23 — the license/header comment block — are collapsed
in this listing and are not shown below.)
24#ifndef KERNEL_DOT_H
25#define KERNEL_DOT_H
26
27#include "config.h"
30
31namespace detail
32{
33
34 template <typename T, my_size_t Bits, typename Arch>
35 struct KernelDot
36 {
39 static constexpr my_size_t simdWidth = K::simdWidth;
40
41 // ========================================================================
42 // Public API
43 // ========================================================================
44
54 template <typename Expr1, typename Expr2>
55 FORCE_INLINE static T dot(
56 const Expr1 &expr1, my_size_t base1, my_size_t stride1,
57 const Expr2 &expr2, my_size_t base2, my_size_t stride2,
58 my_size_t len) noexcept
59 {
60 if (stride1 == 1 && stride2 == 1)
61 {
62 // std::cout << "dot: dispatching to contiguous impl" << std::endl;
63 return dot_contiguous_impl(expr1, expr2, base1, base2, len);
64 }
65 else
66 {
67 // std::cout << "dot: dispatching to strided impl" << std::endl;
68 return dot_strided_impl(expr1, expr2, base1, base2, stride1, stride2, len);
69 }
70 }
71
78 template <typename Expr1, typename Expr2>
80 const Expr1 &expr1, my_size_t base1, my_size_t stride1,
81 const Expr2 &expr2, my_size_t base2, my_size_t stride2,
82 my_size_t len) noexcept
83 {
84 T sum = T{0};
85 for (my_size_t i = 0; i < len; ++i)
86 sum += expr1.data()[base1 + i * stride1] *
87 expr2.data()[base2 + i * stride2];
88 return sum;
89 }
90
91 private:
92 // ========================================================================
93 // Contiguous dot — both fibers have stride 1
94 // ========================================================================
95
107 template <typename Expr1, typename Expr2>
108 FORCE_INLINE static T dot_contiguous_impl(
109 const Expr1 &expr1,
110 const Expr2 &expr2,
111 my_size_t base1,
112 my_size_t base2,
113 my_size_t len) noexcept
114 {
115 // std::cout << "dot_contiguous_impl" << std::endl;
116 const T *ptr1 = expr1.data() + base1;
117 const T *ptr2 = expr2.data() + base2;
118
119 const my_size_t simdSteps = len / simdWidth;
120 const my_size_t scalarStart = simdSteps * simdWidth;
121
122 T result = T{0};
123
124 if (simdSteps > 0)
125 {
126 typename K::VecType acc = K::set1(T{0});
127
128 for (my_size_t i = 0; i < simdSteps; ++i)
129 {
130 auto v1 = K::load(ptr1 + i * simdWidth);
131 auto v2 = K::load(ptr2 + i * simdWidth);
132 acc = Helpers::fmadd_safe(v1, v2, acc);
133 }
134
135 alignas(DATA_ALIGNAS) T tmp[simdWidth];
136 K::store(tmp, acc);
137
138 for (my_size_t i = 0; i < simdWidth; ++i)
139 result += tmp[i];
140 }
141
142 for (my_size_t i = scalarStart; i < len; ++i)
143 result += ptr1[i] * ptr2[i];
144
145 return result;
146 }
147
148 // ========================================================================
149 // Strided dot — one or both fibers have stride > 1
150 // ========================================================================
151
163 template <typename Expr1, typename Expr2>
164 FORCE_INLINE static T dot_strided_impl(
165 const Expr1 &expr1,
166 const Expr2 &expr2,
167 my_size_t idx1,
168 my_size_t idx2,
169 my_size_t stride1,
170 my_size_t stride2,
171 my_size_t len) noexcept
172 {
173 // std::cout << "dot_strided_impl" << std::endl;
174 const my_size_t simdSteps = len / simdWidth;
175 const my_size_t scalarStart = simdSteps * simdWidth;
176
177 T result = T{0};
178
179 if (simdSteps > 0)
180 {
181 typename K::VecType acc = K::set1(T{0});
182
183 for (my_size_t i = 0; i < simdSteps; ++i)
184 {
185 // Build gather indices for this chunk
186 my_size_t idxList1[simdWidth];
187 my_size_t idxList2[simdWidth];
188 for (my_size_t j = 0; j < simdWidth; ++j)
189 {
190 idxList1[j] = idx1 + j * stride1;
191 idxList2[j] = idx2 + j * stride2;
192 }
193
194 auto v1 = K::gather(expr1.data(), idxList1);
195 auto v2 = K::gather(expr2.data(), idxList2);
196 acc = Helpers::fmadd_safe(v1, v2, acc);
197
198 idx1 += simdWidth * stride1;
199 idx2 += simdWidth * stride2;
200 }
201
202 alignas(DATA_ALIGNAS) T tmp[simdWidth];
203 K::store(tmp, acc);
204
205 for (my_size_t i = 0; i < simdWidth; ++i)
206 result += tmp[i];
207 }
208
209 // Scalar tail
210 for (my_size_t i = scalarStart; i < len; ++i)
211 {
212 result += expr1.data()[idx1] * expr2.data()[idx2];
213 idx1 += stride1;
214 idx2 += stride2;
215 }
216
217 return result;
218 }
219 };
220
221} // namespace detail
222
223#endif // KERNEL_DOT_H
Cross-references (from the generated documentation):

- config.h — global configuration for the tesseract tensor library.
  - #define my_size_t (config.h:126): size/index type used throughout the library.
  - #define FORCE_INLINE (config.h:26): hint the compiler to always inline a function.
- kernel_helpers.h — shared SIMD helper utilities for kernel operations (struct at kernel_helpers.h:19).
  - static FORCE_INLINE K::VecType fmadd_safe(typename K::VecType a, typename K::VecType b, typename K::VecType c) noexcept (kernel_helpers.h:27):
    fused multiply-add with fallback for architectures without native FMA.
- microkernel_base.h — micro-kernel base (struct at microkernel_base.h:16) providing:
  - T VecType (microkernel_base.h:18)
  - static constexpr my_size_t simdWidth (microkernel_base.h:17)
  - static FORCE_INLINE VecType load(const T *ptr) noexcept
  - static FORCE_INLINE void store(T *ptr, VecType val) noexcept
  - static FORCE_INLINE VecType set1(T scalar) noexcept
  - constexpr my_size_t DATA_ALIGNAS (microkernel_base.h:145)
- BaseExpr.h:4 — expression base class.
- reductions.h:30 — Expr::value_type sum(const BaseExpr<Expr> &expr).
- KernelDot (kernel_dot.h:36):
  - static constexpr my_size_t simdWidth (kernel_dot.h:39)
  - static FORCE_INLINE T dot(const Expr1 &expr1, my_size_t base1, my_size_t stride1, const Expr2 &expr2, my_size_t base2, my_size_t stride2, my_size_t len) noexcept (kernel_dot.h:55):
    dispatch dot product based on stride values.
  - static FORCE_INLINE T naive_dot_physical(const Expr1 &expr1, my_size_t base1, my_size_t stride1, const Expr2 &expr2, my_size_t base2, my_size_t stride2, my_size_t len) noexcept (kernel_dot.h:79):
    naive scalar dot product for testing/validation.