tesseract++ 0.0.1
N-dimensional tensor library for embedded systems
Loading...
Searching...
No Matches
kernel_gemm.h
Go to the documentation of this file.
1
102#ifndef KERNEL_GEMM_H
103#define KERNEL_GEMM_H
104
105#include "config.h"
108
109namespace detail
110{
111
112 template <typename T, my_size_t Bits, typename Arch>
// Register-blocked GEMM driver parameterized by element type T, register
// width Bits, and target architecture Arch. NOTE(review): the class name and
// the `K` micro-kernel alias live on lines elided by the doc extraction
// (orig. 113, 115-116) — K presumably selects the per-arch SIMD micro-kernel.
114 {
// Lanes per SIMD vector, taken from the micro-kernel K.
117 static constexpr my_size_t simdWidth = K::simdWidth;
118
// Tile height: rows of C computed per micro-kernel invocation.
120 static constexpr my_size_t MR = K::MR;
121
// Number of SIMD vectors per tile row; the tile width is NR = NR_VECS * simdWidth.
123 static constexpr my_size_t NR_VECS = K::NR_VECS;
124
// Tile width: columns of C computed per wide micro-kernel invocation.
126 static constexpr my_size_t NR = K::NR;
127
147 static void gemm(
148 const T *A, my_size_t M, my_size_t K_len, my_size_t strideA,
149 const T *B, my_size_t N, my_size_t strideB,
150 T *C, my_size_t strideC) noexcept
151 {
152 // Column boundaries for the three-pass tiling:
153 // [0, wide_N) → wide micro-kernel (steps of NR)
154 // [wide_N, narrow_N) → narrow micro-kernel (steps of simdWidth)
155 // [narrow_N, N) → scalar column loop (steps of 1)
156 const my_size_t wide_N = (N / NR) * NR;
157 const my_size_t narrow_N = (N / simdWidth) * simdWidth;
158
159 // ==============================================================
160 // Main body: MR rows at a time
161 // ==============================================================
162 my_size_t i = 0;
163 for (; i + MR <= M; i += MR)
164 {
165 my_size_t j = 0;
166
167 for (; j < wide_N; j += NR)
168 {
169 micro_kernel_wide(
170 A + i * strideA, strideA,
171 B + j, strideB,
172 C + i * strideC + j, strideC,
173 K_len);
174 }
175
176 for (; j < narrow_N; j += simdWidth)
177 {
178 micro_kernel_narrow(
179 A + i * strideA, strideA,
180 B + j, strideB,
181 C + i * strideC + j, strideC,
182 K_len);
183 }
184
185 for (; j < N; ++j)
186 {
187 scalar_column_MR(
188 A + i * strideA, strideA,
189 B + j, strideB,
190 C + i * strideC + j, strideC,
191 K_len);
192 }
193 }
194
195 // ==============================================================
196 // Remainder rows (< MR): same three-pass column strategy,
197 // but processing one row at a time.
198 // ==============================================================
199 for (; i < M; ++i)
200 {
201 my_size_t j = 0;
202
203 for (; j < wide_N; j += NR)
204 {
205 single_row_wide(
206 A + i * strideA,
207 B + j, strideB,
208 C + i * strideC + j,
209 K_len);
210 }
211
212 for (; j < narrow_N; j += simdWidth)
213 {
214 single_row_narrow(
215 A + i * strideA,
216 B + j, strideB,
217 C + i * strideC + j,
218 K_len);
219 }
220
221 for (; j < N; ++j)
222 {
223 T sum = T{0};
224 for (my_size_t k = 0; k < K_len; ++k)
225 sum += A[i * strideA + k] * B[k * strideB + j];
226 C[i * strideC + j] = sum;
227 }
228 }
229 }
230
231 private:
257 FORCE_INLINE static void micro_kernel_wide(
258 const T *A, my_size_t strideA,
259 const T *B, my_size_t strideB,
260 T *C, my_size_t strideC,
261 my_size_t K_len) noexcept
262 {
263 // Step 1: zero accumulators
264 typename K::VecType acc[MR][NR_VECS];
265 for (my_size_t r = 0; r < MR; ++r)
266 for (my_size_t v = 0; v < NR_VECS; ++v)
267 acc[r][v] = K::set1(T{0});
268
269 // Step 2: outer-product accumulation over k
270 for (my_size_t k = 0; k < K_len; ++k)
271 {
272 // 2a: load NR_VECS contiguous vectors from B[k, j..j+NR-1]
273 typename K::VecType b_vec[NR_VECS];
274 for (my_size_t v = 0; v < NR_VECS; ++v)
275 b_vec[v] = K::load(B + k * strideB + v * simdWidth);
276
277 // 2b: broadcast each A element and FMA into accumulators
278 for (my_size_t r = 0; r < MR; ++r)
279 {
280 auto a_bcast = K::set1(A[r * strideA + k]);
281 for (my_size_t v = 0; v < NR_VECS; ++v)
282 acc[r][v] = Helpers::fmadd_safe(a_bcast, b_vec[v], acc[r][v]);
283 }
284 }
285
286 // Step 3: store completed tile to C
287 for (my_size_t r = 0; r < MR; ++r)
288 for (my_size_t v = 0; v < NR_VECS; ++v)
289 K::store(C + r * strideC + v * simdWidth, acc[r][v]);
290 }
291
309 FORCE_INLINE static void micro_kernel_narrow(
310 const T *A, my_size_t strideA,
311 const T *B, my_size_t strideB,
312 T *C, my_size_t strideC,
313 my_size_t K_len) noexcept
314 {
315 typename K::VecType acc[MR];
316 for (my_size_t r = 0; r < MR; ++r)
317 acc[r] = K::set1(T{0});
318
319 for (my_size_t k = 0; k < K_len; ++k)
320 {
321 auto b_vec = K::load(B + k * strideB);
322
323 for (my_size_t r = 0; r < MR; ++r)
324 {
325 auto a_bcast = K::set1(A[r * strideA + k]);
326 acc[r] = Helpers::fmadd_safe(a_bcast, b_vec, acc[r]);
327 }
328 }
329
330 for (my_size_t r = 0; r < MR; ++r)
331 K::store(C + r * strideC, acc[r]);
332 }
333
348 FORCE_INLINE static void scalar_column_MR(
349 const T *A, my_size_t strideA,
350 const T *B, my_size_t strideB,
351 T *C, my_size_t strideC,
352 my_size_t K_len) noexcept
353 {
354 T acc[MR] = {};
355
356 for (my_size_t k = 0; k < K_len; ++k)
357 {
358 T b_val = B[k * strideB];
359 for (my_size_t r = 0; r < MR; ++r)
360 acc[r] += A[r * strideA + k] * b_val;
361 }
362
363 for (my_size_t r = 0; r < MR; ++r)
364 C[r * strideC] = acc[r];
365 }
366
379 FORCE_INLINE static void single_row_wide(
380 const T *A,
381 const T *B, my_size_t strideB,
382 T *C,
383 my_size_t K_len) noexcept
384 {
385 typename K::VecType acc[NR_VECS];
386 for (my_size_t v = 0; v < NR_VECS; ++v)
387 acc[v] = K::set1(T{0});
388
389 for (my_size_t k = 0; k < K_len; ++k)
390 {
391 auto a_bcast = K::set1(A[k]);
392 for (my_size_t v = 0; v < NR_VECS; ++v)
393 acc[v] = Helpers::fmadd_safe(a_bcast, K::load(B + k * strideB + v * simdWidth), acc[v]);
394 }
395
396 for (my_size_t v = 0; v < NR_VECS; ++v)
397 K::store(C + v * simdWidth, acc[v]);
398 }
399
412 FORCE_INLINE static void single_row_narrow(
413 const T *A,
414 const T *B, my_size_t strideB,
415 T *C,
416 my_size_t K_len) noexcept
417 {
418 typename K::VecType acc = K::set1(T{0});
419
420 for (my_size_t k = 0; k < K_len; ++k)
421 {
422 auto b_vec = K::load(B + k * strideB);
423 auto a_bcast = K::set1(A[k]);
424 acc = Helpers::fmadd_safe(a_bcast, b_vec, acc);
425 }
426
427 K::store(C, acc);
428 }
429 };
430
431} // namespace detail
432
433#endif // KERNEL_GEMM_H
Global configuration for the tesseract tensor library.
#define my_size_t
Size/index type used throughout the library.
Definition config.h:126
#define FORCE_INLINE
Hint the compiler to always inline a function.
Definition config.h:26
Shared SIMD helper utilities for kernel operations.
Definition BaseExpr.h:4
Expr::value_type sum(const BaseExpr< Expr > &expr)
Definition reductions.h:30
Definition microkernel_base.h:16
T VecType
Definition microkernel_base.h:18
static FORCE_INLINE void store(T *ptr, VecType val) noexcept
static constexpr my_size_t simdWidth
Definition microkernel_base.h:17
static FORCE_INLINE VecType load(const T *ptr) noexcept
static FORCE_INLINE VecType set1(T scalar) noexcept
Definition kernel_gemm.h:114
static void gemm(const T *A, my_size_t M, my_size_t K_len, my_size_t strideA, const T *B, my_size_t N, my_size_t strideB, T *C, my_size_t strideC) noexcept
Register-blocked GEMM: C[M,N] = A[M,K] × B[K,N].
Definition kernel_gemm.h:147
static constexpr my_size_t NR_VECS
Number of SIMD vectors per tile column. The tile width is NR = NR_VECS × simdWidth.
Definition kernel_gemm.h:123
static constexpr my_size_t MR
Tile height: rows of C computed per micro-kernel invocation.
Definition kernel_gemm.h:120
static constexpr my_size_t simdWidth
Definition kernel_gemm.h:117
static constexpr my_size_t NR
Tile width: columns of C computed per wide micro-kernel invocation.
Definition kernel_gemm.h:126
Definition kernel_helpers.h:19
static FORCE_INLINE K::VecType fmadd_safe(typename K::VecType a, typename K::VecType b, typename K::VecType c) noexcept
Fused multiply-add with fallback for architectures without native FMA.
Definition kernel_helpers.h:27