tesseract++ 0.0.1
N-dimensional tensor library for embedded systems
Loading...
Searching...
No Matches
avx2_microkernel.h
Go to the documentation of this file.
1#ifndef __AVX2_MICROKERNEL_H__
2#define __AVX2_MICROKERNEL_H__
3
4#include <immintrin.h>
5#include "config.h"
6
7// Architecture tag
// Architecture dispatch tag: selects the 256-bit AVX/AVX2 Microkernel
// specializations below. Empty by design — used only for overload/template
// selection, never instantiated with state.
struct X86_AVX
{
}; // 256-bit AVX/AVX2
11
12// ============================================================================
13// AVX2 (256-bit) specializations
14// ============================================================================
15
template <>
struct Microkernel<float, 256, X86_AVX>
{
    static constexpr my_size_t simdWidth = 8; // 256 bits / 32 bits per float = 8
    // GEMM tiling constants (register-blocked): an MR x NR micro-tile of the
    // output is accumulated in registers; NR spans NR_VECS full vectors.
    static constexpr my_size_t num_registers = 16;
    static constexpr my_size_t MR = 4;
    static constexpr my_size_t NR_VECS = 3;
    static constexpr my_size_t NR = NR_VECS * simdWidth; // 24
    using VecType = __m256;
    using ScalarType = float;

    // load/store require 32-byte-aligned pointers; the *u variants accept any address.
    FORCE_INLINE static VecType load(const ScalarType *ptr) noexcept { return _mm256_load_ps(ptr); }
    FORCE_INLINE static VecType loadu(const ScalarType *ptr) noexcept { return _mm256_loadu_ps(ptr); }
    FORCE_INLINE static void store(ScalarType *ptr, VecType val) noexcept { _mm256_store_ps(ptr, val); }
    FORCE_INLINE static void storeu(ScalarType *ptr, VecType val) noexcept { _mm256_storeu_ps(ptr, val); }
    FORCE_INLINE static VecType set1(ScalarType scalar) noexcept { return _mm256_set1_ps(scalar); }

    // Element-wise arithmetic; scalar overloads broadcast the scalar via set1.
    FORCE_INLINE static VecType add(VecType a, VecType b) noexcept { return _mm256_add_ps(a, b); }
    FORCE_INLINE static VecType add(VecType a, ScalarType b) noexcept { return _mm256_add_ps(a, set1(b)); }

    FORCE_INLINE static VecType mul(VecType a, VecType b) noexcept { return _mm256_mul_ps(a, b); }
    FORCE_INLINE static VecType mul(VecType a, ScalarType b) noexcept { return _mm256_mul_ps(a, set1(b)); }

    FORCE_INLINE static VecType sub(VecType a, VecType b) noexcept { return _mm256_sub_ps(a, b); }
    FORCE_INLINE static VecType sub(VecType a, ScalarType b) noexcept { return _mm256_sub_ps(a, set1(b)); }
    FORCE_INLINE static VecType sub(ScalarType a, VecType b) noexcept { return _mm256_sub_ps(set1(a), b); }

    FORCE_INLINE static VecType div(VecType a, VecType b) noexcept { return _mm256_div_ps(a, b); }
    FORCE_INLINE static VecType div(VecType a, ScalarType b) noexcept { return _mm256_div_ps(a, set1(b)); }
    FORCE_INLINE static VecType div(ScalarType a, VecType b) noexcept { return _mm256_div_ps(set1(a), b); }

    // Fused multiply-add family (FMA3 encodings; single rounding step).
    // fmadd: a*b + c
    FORCE_INLINE static VecType fmadd(VecType a, VecType b, VecType c) noexcept { return _mm256_fmadd_ps(a, b, c); }
    FORCE_INLINE static VecType fmadd(VecType a, ScalarType b, VecType c) noexcept { return _mm256_fmadd_ps(a, set1(b), c); }

    // fmsub: a*b - c
    FORCE_INLINE static VecType fmsub(VecType a, VecType b, VecType c) noexcept { return _mm256_fmsub_ps(a, b, c); }
    FORCE_INLINE static VecType fmsub(VecType a, ScalarType b, VecType c) noexcept { return _mm256_fmsub_ps(a, set1(b), c); }

    // fnmadd: -(a*b) + c
    FORCE_INLINE static VecType fnmadd(VecType a, VecType b, VecType c) noexcept { return _mm256_fnmadd_ps(a, b, c); }
    FORCE_INLINE static VecType fnmadd(VecType a, ScalarType b, VecType c) noexcept { return _mm256_fnmadd_ps(a, set1(b), c); }

    // fnmsub: -(a*b) - c
    FORCE_INLINE static VecType fnmsub(VecType a, VecType b, VecType c) noexcept { return _mm256_fnmsub_ps(a, b, c); }
    FORCE_INLINE static VecType fnmsub(VecType a, ScalarType b, VecType c) noexcept { return _mm256_fnmsub_ps(a, set1(b), c); }

    FORCE_INLINE static VecType min(VecType a, VecType b) noexcept { return _mm256_min_ps(a, b); }
    FORCE_INLINE static VecType min(VecType a, ScalarType b) noexcept { return _mm256_min_ps(a, set1(b)); }

    FORCE_INLINE static VecType max(VecType a, VecType b) noexcept { return _mm256_max_ps(a, b); }
    FORCE_INLINE static VecType max(VecType a, ScalarType b) noexcept { return _mm256_max_ps(a, set1(b)); }

    // ============================================================================
    // Gather: non-contiguous load; lane i is read from base[indices[i]].
    // ============================================================================
    FORCE_INLINE static VecType gather(const ScalarType *base, const my_size_t *indices) noexcept
    {
        // _mm256_i32gather_ps requires 8 x 32-bit indices,
        // so we convert my_size_t -> int32_t first.
        // NOTE(review): an index above INT32_MAX would be truncated here —
        // assumed never to occur; confirm against callers.
        alignas(32) int32_t idx32[simdWidth];
        for (my_size_t i = 0; i < simdWidth; ++i)
        {
            idx32[i] = static_cast<int32_t>(indices[i]);
        }

        // loadu is used deliberately even though idx32 is 32-byte aligned: on
        // aligned addresses it performs the same as the aligned load and it
        // never depends on the pointee's declared alignment.
        __m256i vindex = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(idx32));
        return _mm256_i32gather_ps(base, vindex, sizeof(ScalarType));
    }

    // Scatter: non-contiguous store; AVX2 has no scatter instruction, so the
    // vector is spilled to a stack buffer and written out lane by lane.
    FORCE_INLINE static void scatter(ScalarType *base, const my_size_t *indices, VecType val) noexcept
    {
        alignas(32) ScalarType tmp[simdWidth];
        _mm256_storeu_ps(tmp, val);
        for (my_size_t i = 0; i < simdWidth; ++i)
            base[indices[i]] = tmp[i];
    }

    FORCE_INLINE static VecType abs(VecType v) noexcept
    {
        // Clear the sign bit: andnot with -0.0f (a mask with only the sign bit set).
        __m256 sign_mask = _mm256_set1_ps(-0.0f);
        return _mm256_andnot_ps(sign_mask, v);
    }

    // True iff |a[i] - b[i]| <= tol in every lane. _CMP_LE_OQ is an ordered
    // compare, so a NaN in any lane makes the whole check return false.
    FORCE_INLINE static bool all_within_tolerance(VecType a, VecType b, ScalarType tol) noexcept
    {
        __m256 diff = _mm256_sub_ps(a, b);
        __m256 abs_diff = abs(diff);
        __m256 tol_vec = _mm256_set1_ps(tol);
        __m256 cmp = _mm256_cmp_ps(abs_diff, tol_vec, _CMP_LE_OQ); // abs_diff <= tol
        int mask = _mm256_movemask_ps(cmp);
        return mask == 0xFF; // all 8 lanes passed
    }
};
117
118template <>
119struct Microkernel<double, 256, X86_AVX>
120{
121 static constexpr my_size_t simdWidth = 4; // 256 bits / 64 bits per double = 4
122 // GEMM tiling constants (register-blocked)
123 static constexpr my_size_t num_registers = 16;
124 static constexpr my_size_t MR = 4;
125 static constexpr my_size_t NR_VECS = 3;
126 static constexpr my_size_t NR = NR_VECS * simdWidth; // 12
127 using VecType = __m256d;
128 using ScalarType = double;
129
130 FORCE_INLINE static VecType load(const ScalarType *ptr) noexcept { return _mm256_load_pd(ptr); }
131 FORCE_INLINE static VecType loadu(const ScalarType *ptr) noexcept { return _mm256_loadu_pd(ptr); }
132 FORCE_INLINE static void store(ScalarType *ptr, VecType val) noexcept { _mm256_store_pd(ptr, val); }
133 FORCE_INLINE static void storeu(ScalarType *ptr, VecType val) noexcept { _mm256_storeu_pd(ptr, val); }
134 FORCE_INLINE static VecType set1(ScalarType scalar) noexcept { return _mm256_set1_pd(scalar); }
135
136 FORCE_INLINE static VecType add(VecType a, VecType b) noexcept { return _mm256_add_pd(a, b); }
137 FORCE_INLINE static VecType add(VecType a, ScalarType b) noexcept { return _mm256_add_pd(a, set1(b)); }
138
139 FORCE_INLINE static VecType mul(VecType a, VecType b) noexcept { return _mm256_mul_pd(a, b); }
140 FORCE_INLINE static VecType mul(VecType a, ScalarType b) noexcept { return _mm256_mul_pd(a, set1(b)); }
141
142 FORCE_INLINE static VecType sub(VecType a, VecType b) noexcept { return _mm256_sub_pd(a, b); }
143 FORCE_INLINE static VecType sub(VecType a, ScalarType b) noexcept { return _mm256_sub_pd(a, set1(b)); }
144 FORCE_INLINE static VecType sub(ScalarType a, VecType b) noexcept { return _mm256_sub_pd(set1(a), b); }
145
146 FORCE_INLINE static VecType div(VecType a, VecType b) noexcept { return _mm256_div_pd(a, b); }
147 FORCE_INLINE static VecType div(VecType a, ScalarType b) noexcept { return _mm256_div_pd(a, set1(b)); }
148 FORCE_INLINE static VecType div(ScalarType a, VecType b) noexcept { return _mm256_div_pd(set1(a), b); }
149
150 // fmadd: a*b + c
151 FORCE_INLINE static VecType fmadd(VecType a, VecType b, VecType c) noexcept { return _mm256_fmadd_pd(a, b, c); }
152 FORCE_INLINE static VecType fmadd(VecType a, ScalarType b, VecType c) noexcept { return _mm256_fmadd_pd(a, set1(b), c); }
153
154 // fmsub: a*b - c
155 FORCE_INLINE static VecType fmsub(VecType a, VecType b, VecType c) noexcept { return _mm256_fmsub_pd(a, b, c); }
156 FORCE_INLINE static VecType fmsub(VecType a, ScalarType b, VecType c) noexcept { return _mm256_fmsub_pd(a, set1(b), c); }
157
158 // fnmadd: -(a*b) + c
159 FORCE_INLINE static VecType fnmadd(VecType a, VecType b, VecType c) noexcept { return _mm256_fnmadd_pd(a, b, c); }
160 FORCE_INLINE static VecType fnmadd(VecType a, ScalarType b, VecType c) noexcept { return _mm256_fnmadd_pd(a, set1(b), c); }
161
162 // fnmsub: -(a*b) - c
163 FORCE_INLINE static VecType fnmsub(VecType a, VecType b, VecType c) noexcept { return _mm256_fnmsub_pd(a, b, c); }
164 FORCE_INLINE static VecType fnmsub(VecType a, ScalarType b, VecType c) noexcept { return _mm256_fnmsub_pd(a, set1(b), c); }
165
166 FORCE_INLINE static VecType min(VecType a, VecType b) noexcept { return _mm256_min_pd(a, b); }
167 FORCE_INLINE static VecType min(VecType a, ScalarType b) noexcept { return _mm256_min_pd(a, set1(b)); }
168
169 FORCE_INLINE static VecType max(VecType a, VecType b) noexcept { return _mm256_max_pd(a, b); }
170 FORCE_INLINE static VecType max(VecType a, ScalarType b) noexcept { return _mm256_max_pd(a, set1(b)); }
171
172 FORCE_INLINE static VecType gather(const ScalarType *base, const my_size_t *indices) noexcept
173 {
174 __m256i vindex = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(indices));
175 return _mm256_i64gather_pd(base, vindex, sizeof(ScalarType));
176 }
177
178 FORCE_INLINE static void scatter(ScalarType *base, const my_size_t *indices, VecType val) noexcept
179 {
180 alignas(32) ScalarType tmp[simdWidth];
181 _mm256_storeu_pd(tmp, val);
182 for (my_size_t i = 0; i < simdWidth; ++i)
183 base[indices[i]] = tmp[i];
184 }
185
186 FORCE_INLINE static VecType abs(VecType v) noexcept
187 {
188 __m256d sign_mask = _mm256_set1_pd(-0.0);
189 return _mm256_andnot_pd(sign_mask, v);
190 }
191
193 {
194 __m256d diff = _mm256_sub_pd(a, b);
195 __m256d abs_diff = abs(diff);
196 __m256d tol_vec = _mm256_set1_pd(tol);
197 __m256d cmp = _mm256_cmp_pd(abs_diff, tol_vec, _CMP_LE_OQ);
198 int mask = _mm256_movemask_pd(cmp);
199 return mask == 0xF; // all 4 lanes passed
200 }
201};
202
203// ============================================================================
204// AVX2 (256-bit) int32_t specialization
205// ============================================================================
206
template <>
struct Microkernel<int32_t, 256, X86_AVX>
{
    static constexpr my_size_t simdWidth = 8; // 256 bits / 32 bits = 8
    // GEMM tiling constants (register-blocked), matching the float layout.
    static constexpr my_size_t num_registers = 16;
    static constexpr my_size_t MR = 4;
    static constexpr my_size_t NR_VECS = 3;
    static constexpr my_size_t NR = NR_VECS * simdWidth; // 24
    using VecType = __m256i;
    using ScalarType = int32_t;

    // load/store require 32-byte-aligned pointers; the *u variants accept any address.
    FORCE_INLINE static VecType load(const ScalarType *ptr) noexcept { return _mm256_load_si256(reinterpret_cast<const __m256i *>(ptr)); }
    FORCE_INLINE static VecType loadu(const ScalarType *ptr) noexcept { return _mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptr)); }
    FORCE_INLINE static void store(ScalarType *ptr, VecType val) noexcept { _mm256_store_si256(reinterpret_cast<__m256i *>(ptr), val); }
    FORCE_INLINE static void storeu(ScalarType *ptr, VecType val) noexcept { _mm256_storeu_si256(reinterpret_cast<__m256i *>(ptr), val); }
    FORCE_INLINE static VecType set1(ScalarType scalar) noexcept { return _mm256_set1_epi32(scalar); }

    // Element-wise arithmetic; the intrinsics wrap on overflow.
    FORCE_INLINE static VecType add(VecType a, VecType b) noexcept { return _mm256_add_epi32(a, b); }
    FORCE_INLINE static VecType add(VecType a, ScalarType b) noexcept { return _mm256_add_epi32(a, set1(b)); }

    // mullo keeps the low 32 bits of each 32x32 product.
    FORCE_INLINE static VecType mul(VecType a, VecType b) noexcept { return _mm256_mullo_epi32(a, b); }
    FORCE_INLINE static VecType mul(VecType a, ScalarType b) noexcept { return _mm256_mullo_epi32(a, set1(b)); }

    FORCE_INLINE static VecType sub(VecType a, VecType b) noexcept { return _mm256_sub_epi32(a, b); }
    FORCE_INLINE static VecType sub(VecType a, ScalarType b) noexcept { return _mm256_sub_epi32(a, set1(b)); }
    FORCE_INLINE static VecType sub(ScalarType a, VecType b) noexcept { return _mm256_sub_epi32(set1(a), b); }

    // No SIMD integer divide exists on x86; scalar fallback.
    // NOTE(review): a zero divisor (or INT32_MIN / -1) in any lane is UB,
    // exactly as in scalar code — callers are assumed to guarantee validity.
    FORCE_INLINE static VecType div(VecType a, VecType b) noexcept
    {
        alignas(32) ScalarType va[simdWidth], vb[simdWidth];
        _mm256_storeu_si256(reinterpret_cast<__m256i *>(va), a);
        _mm256_storeu_si256(reinterpret_cast<__m256i *>(vb), b);
        for (my_size_t i = 0; i < simdWidth; ++i)
            va[i] /= vb[i];
        return _mm256_loadu_si256(reinterpret_cast<const __m256i *>(va));
    }
    FORCE_INLINE static VecType div(VecType a, ScalarType b) noexcept
    {
        alignas(32) ScalarType va[simdWidth];
        _mm256_storeu_si256(reinterpret_cast<__m256i *>(va), a);
        for (my_size_t i = 0; i < simdWidth; ++i)
            va[i] /= b;
        return _mm256_loadu_si256(reinterpret_cast<const __m256i *>(va));
    }
    FORCE_INLINE static VecType div(ScalarType a, VecType b) noexcept
    {
        alignas(32) ScalarType vb[simdWidth];
        _mm256_storeu_si256(reinterpret_cast<__m256i *>(vb), b);
        alignas(32) ScalarType vr[simdWidth];
        for (my_size_t i = 0; i < simdWidth; ++i)
            vr[i] = a / vb[i];
        return _mm256_loadu_si256(reinterpret_cast<const __m256i *>(vr));
    }

    // NOTE: No FMA for integers in AVX2. Emulate as mul + add/sub.
    // fmadd: a*b + c
    FORCE_INLINE static VecType fmadd(VecType a, VecType b, VecType c) noexcept { return _mm256_add_epi32(_mm256_mullo_epi32(a, b), c); }
    FORCE_INLINE static VecType fmadd(VecType a, ScalarType b, VecType c) noexcept { return _mm256_add_epi32(_mm256_mullo_epi32(a, set1(b)), c); }

    // fmsub: a*b - c
    FORCE_INLINE static VecType fmsub(VecType a, VecType b, VecType c) noexcept { return _mm256_sub_epi32(_mm256_mullo_epi32(a, b), c); }
    FORCE_INLINE static VecType fmsub(VecType a, ScalarType b, VecType c) noexcept { return _mm256_sub_epi32(_mm256_mullo_epi32(a, set1(b)), c); }

    // fnmadd: -(a*b) + c => c - a*b
    FORCE_INLINE static VecType fnmadd(VecType a, VecType b, VecType c) noexcept { return _mm256_sub_epi32(c, _mm256_mullo_epi32(a, b)); }
    FORCE_INLINE static VecType fnmadd(VecType a, ScalarType b, VecType c) noexcept { return _mm256_sub_epi32(c, _mm256_mullo_epi32(a, set1(b))); }

    // fnmsub: -(a*b) - c
    FORCE_INLINE static VecType fnmsub(VecType a, VecType b, VecType c) noexcept
    {
        // 0 - a*b - c
        __m256i neg_ab = _mm256_sub_epi32(_mm256_setzero_si256(), _mm256_mullo_epi32(a, b));
        return _mm256_sub_epi32(neg_ab, c);
    }
    FORCE_INLINE static VecType fnmsub(VecType a, ScalarType b, VecType c) noexcept
    {
        __m256i neg_ab = _mm256_sub_epi32(_mm256_setzero_si256(), _mm256_mullo_epi32(a, set1(b)));
        return _mm256_sub_epi32(neg_ab, c);
    }

    FORCE_INLINE static VecType min(VecType a, VecType b) noexcept { return _mm256_min_epi32(a, b); }
    FORCE_INLINE static VecType min(VecType a, ScalarType b) noexcept { return _mm256_min_epi32(a, set1(b)); }

    FORCE_INLINE static VecType max(VecType a, VecType b) noexcept { return _mm256_max_epi32(a, b); }
    FORCE_INLINE static VecType max(VecType a, ScalarType b) noexcept { return _mm256_max_epi32(a, set1(b)); }

    // ============================================================================
    // Gather: non-contiguous load; lane i is read from base[indices[i]].
    // ============================================================================
    FORCE_INLINE static VecType gather(const ScalarType *base, const my_size_t *indices) noexcept
    {
        // _mm256_i32gather_epi32 requires 8 x 32-bit indices; narrow
        // my_size_t -> int32_t first (indices assumed to fit — see the
        // float specialization's note).
        alignas(32) int32_t idx32[simdWidth];
        for (my_size_t i = 0; i < simdWidth; ++i)
        {
            idx32[i] = static_cast<int32_t>(indices[i]);
        }
        __m256i vindex = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(idx32));
        return _mm256_i32gather_epi32(base, vindex, sizeof(ScalarType));
    }

    // Scatter: non-contiguous store; AVX2 has no scatter instruction, so the
    // vector is spilled to a stack buffer and written out lane by lane.
    FORCE_INLINE static void scatter(ScalarType *base, const my_size_t *indices, VecType val) noexcept
    {
        alignas(32) ScalarType tmp[simdWidth];
        _mm256_storeu_si256(reinterpret_cast<__m256i *>(tmp), val);
        for (my_size_t i = 0; i < simdWidth; ++i)
            base[indices[i]] = tmp[i];
    }

    FORCE_INLINE static VecType abs(VecType v) noexcept
    {
        // NOTE(review): _mm256_abs_epi32 leaves INT32_MIN unchanged (no
        // positive counterpart in two's complement).
        return _mm256_abs_epi32(v);
    }

    // True iff |a[i] - b[i]| <= tol in every lane (tol expected >= 0).
    FORCE_INLINE static bool all_within_tolerance(VecType a, VecType b, ScalarType tol) noexcept
    {
        __m256i diff = _mm256_sub_epi32(a, b);
        __m256i abs_diff = _mm256_abs_epi32(diff);
        __m256i tol_vec = _mm256_set1_epi32(tol);
        // cmpgt: abs_diff > tol → we want none to be greater
        __m256i cmp = _mm256_cmpgt_epi32(abs_diff, tol_vec);
        return _mm256_testz_si256(cmp, cmp); // true if cmp is all-zero
    }
};
331
332// ============================================================================
333// AVX2 (256-bit) int64_t specialization
334// ============================================================================
335
336template <>
337struct Microkernel<int64_t, 256, X86_AVX>
338{
339 static constexpr my_size_t simdWidth = 4; // 256 bits / 64 bits = 4
340 static constexpr my_size_t num_registers = 16;
341 static constexpr my_size_t MR = 4;
342 static constexpr my_size_t NR_VECS = 3;
343 static constexpr my_size_t NR = NR_VECS * simdWidth; // 12
344 using VecType = __m256i;
345 using ScalarType = int64_t;
346
347 FORCE_INLINE static VecType load(const ScalarType *ptr) noexcept { return _mm256_load_si256(reinterpret_cast<const __m256i *>(ptr)); }
348 FORCE_INLINE static VecType loadu(const ScalarType *ptr) noexcept { return _mm256_loadu_si256(reinterpret_cast<const __m256i *>(ptr)); }
349 FORCE_INLINE static void store(ScalarType *ptr, VecType val) noexcept { _mm256_store_si256(reinterpret_cast<__m256i *>(ptr), val); }
350 FORCE_INLINE static void storeu(ScalarType *ptr, VecType val) noexcept { _mm256_storeu_si256(reinterpret_cast<__m256i *>(ptr), val); }
351 FORCE_INLINE static VecType set1(ScalarType scalar) noexcept { return _mm256_set1_epi64x(scalar); }
352
353 FORCE_INLINE static VecType add(VecType a, VecType b) noexcept { return _mm256_add_epi64(a, b); }
354 FORCE_INLINE static VecType add(VecType a, ScalarType b) noexcept { return _mm256_add_epi64(a, set1(b)); }
355
356 // NOTE: AVX2 has NO native 64-bit integer multiply.
357 // Emulate via 32-bit partial products.
358 FORCE_INLINE static VecType mul(VecType a, VecType b) noexcept
359 {
360 // Low 32 bits of each 64-bit lane: _mm256_mul_epu32 gives 64-bit results
361 // For full 64×64→low64, we need:
362 // result = lo(a)*lo(b) + (lo(a)*hi(b) + hi(a)*lo(b)) << 32
363 __m256i a_hi = _mm256_srli_epi64(a, 32);
364 __m256i b_hi = _mm256_srli_epi64(b, 32);
365
366 __m256i lo_lo = _mm256_mul_epu32(a, b); // lo(a) * lo(b) → 64-bit
367 __m256i lo_hi = _mm256_mul_epu32(a, b_hi); // lo(a) * hi(b) → 64-bit
368 __m256i hi_lo = _mm256_mul_epu32(a_hi, b); // hi(a) * lo(b) → 64-bit
369
370 __m256i cross = _mm256_add_epi64(lo_hi, hi_lo);
371 __m256i cross_shifted = _mm256_slli_epi64(cross, 32);
372
373 return _mm256_add_epi64(lo_lo, cross_shifted);
374 }
375 FORCE_INLINE static VecType mul(VecType a, ScalarType b) noexcept { return mul(a, set1(b)); }
376
377 FORCE_INLINE static VecType sub(VecType a, VecType b) noexcept { return _mm256_sub_epi64(a, b); }
378 FORCE_INLINE static VecType sub(VecType a, ScalarType b) noexcept { return _mm256_sub_epi64(a, set1(b)); }
379 FORCE_INLINE static VecType sub(ScalarType a, VecType b) noexcept { return _mm256_sub_epi64(set1(a), b); }
380
381 // No SIMD integer divide exists on x86; scalar fallback.
382 FORCE_INLINE static VecType div(VecType a, VecType b) noexcept
383 {
384 alignas(32) ScalarType va[simdWidth], vb[simdWidth];
385 _mm256_storeu_si256(reinterpret_cast<__m256i *>(va), a);
386 _mm256_storeu_si256(reinterpret_cast<__m256i *>(vb), b);
387 for (my_size_t i = 0; i < simdWidth; ++i)
388 va[i] /= vb[i];
389 return _mm256_loadu_si256(reinterpret_cast<const __m256i *>(va));
390 }
392 {
393 alignas(32) ScalarType va[simdWidth];
394 _mm256_storeu_si256(reinterpret_cast<__m256i *>(va), a);
395 for (my_size_t i = 0; i < simdWidth; ++i)
396 va[i] /= b;
397 return _mm256_loadu_si256(reinterpret_cast<const __m256i *>(va));
398 }
400 {
401 alignas(32) ScalarType vb[simdWidth];
402 _mm256_storeu_si256(reinterpret_cast<__m256i *>(vb), b);
403 alignas(32) ScalarType vr[simdWidth];
404 for (my_size_t i = 0; i < simdWidth; ++i)
405 vr[i] = a / vb[i];
406 return _mm256_loadu_si256(reinterpret_cast<const __m256i *>(vr));
407 }
408
409 // Emulated FMA
410 FORCE_INLINE static VecType fmadd(VecType a, VecType b, VecType c) noexcept { return add(mul(a, b), c); }
411 FORCE_INLINE static VecType fmadd(VecType a, ScalarType b, VecType c) noexcept { return add(mul(a, b), c); }
412
413 FORCE_INLINE static VecType fmsub(VecType a, VecType b, VecType c) noexcept { return sub(mul(a, b), c); }
414 FORCE_INLINE static VecType fmsub(VecType a, ScalarType b, VecType c) noexcept { return sub(mul(a, b), c); }
415
416 FORCE_INLINE static VecType fnmadd(VecType a, VecType b, VecType c) noexcept { return sub(c, mul(a, b)); }
417 FORCE_INLINE static VecType fnmadd(VecType a, ScalarType b, VecType c) noexcept { return sub(c, mul(a, b)); }
418
420 {
421 return sub(_mm256_setzero_si256(), add(mul(a, b), c));
422 }
424 {
425 return sub(_mm256_setzero_si256(), add(mul(a, b), c));
426 }
427
428 // NOTE: AVX2 has no _mm256_min/max_epi64. Emulate via comparison.
429 FORCE_INLINE static VecType min(VecType a, VecType b) noexcept
430 {
431 // AVX-512 has _mm256_min_epi64, but for AVX2 we emulate:
432 __m256i gt = _mm256_cmpgt_epi64(a, b); // a > b ? 0xFFF... : 0
433 return _mm256_blendv_epi8(a, b, gt); // pick b where a > b
434 }
435 FORCE_INLINE static VecType min(VecType a, ScalarType b) noexcept { return min(a, set1(b)); }
436
437 FORCE_INLINE static VecType max(VecType a, VecType b) noexcept
438 {
439 __m256i gt = _mm256_cmpgt_epi64(a, b);
440 return _mm256_blendv_epi8(b, a, gt); // pick a where a > b
441 }
442 FORCE_INLINE static VecType max(VecType a, ScalarType b) noexcept { return max(a, set1(b)); }
443
444 // ============================================================================
445 // Gather
446 // ============================================================================
447 FORCE_INLINE static VecType gather(const ScalarType *base, const my_size_t *indices) noexcept
448 {
449 // _mm256_i64gather_epi64 expects 4 × 64-bit indices — matches my_size_t on 64-bit
450 __m256i vindex = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(indices));
451 return _mm256_i64gather_epi64(reinterpret_cast<const long long *>(base), vindex, sizeof(ScalarType));
452 }
453
454 FORCE_INLINE static void scatter(ScalarType *base, const my_size_t *indices, VecType val) noexcept
455 {
456 alignas(32) ScalarType tmp[simdWidth];
457 _mm256_storeu_si256(reinterpret_cast<__m256i *>(tmp), val);
458 for (my_size_t i = 0; i < simdWidth; ++i)
459 base[indices[i]] = tmp[i];
460 }
461
462 FORCE_INLINE static VecType abs(VecType v) noexcept
463 {
464 // AVX2 has no _mm256_abs_epi64. Emulate:
465 __m256i sign = _mm256_cmpgt_epi64(_mm256_setzero_si256(), v); // sign = (v < 0) ? 0xFFF... : 0
466 __m256i neg_v = _mm256_sub_epi64(_mm256_setzero_si256(), v);
467 return _mm256_blendv_epi8(v, neg_v, sign); // pick neg_v where v < 0
468 }
469
471 {
472 __m256i diff = _mm256_sub_epi64(a, b);
473 __m256i abs_diff = abs(diff);
474 __m256i tol_vec = _mm256_set1_epi64x(tol);
475 // abs_diff > tol ?
476 __m256i gt = _mm256_cmpgt_epi64(abs_diff, tol_vec);
477 return _mm256_testz_si256(gt, gt);
478 }
479};
480
481#endif // __AVX2_MICROKERNEL_H__
Global configuration for the tesseract tensor library.
#define my_size_t
Size/index type used throughout the library.
Definition config.h:126
#define FORCE_INLINE
Hint the compiler to always inline a function.
Definition config.h:26
static FORCE_INLINE void storeu(ScalarType *ptr, VecType val) noexcept
Definition avx2_microkernel.h:133
static FORCE_INLINE void store(ScalarType *ptr, VecType val) noexcept
Definition avx2_microkernel.h:132
static FORCE_INLINE VecType max(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:170
static FORCE_INLINE VecType fnmadd(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:159
static FORCE_INLINE VecType sub(ScalarType a, VecType b) noexcept
Definition avx2_microkernel.h:144
static FORCE_INLINE VecType fmsub(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:156
static FORCE_INLINE VecType abs(VecType v) noexcept
Definition avx2_microkernel.h:186
static FORCE_INLINE VecType add(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:136
__m256d VecType
Definition avx2_microkernel.h:127
static FORCE_INLINE VecType loadu(const ScalarType *ptr) noexcept
Definition avx2_microkernel.h:131
static FORCE_INLINE VecType min(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:166
static FORCE_INLINE VecType set1(ScalarType scalar) noexcept
Definition avx2_microkernel.h:134
static FORCE_INLINE VecType div(ScalarType a, VecType b) noexcept
Definition avx2_microkernel.h:148
static FORCE_INLINE VecType div(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:147
static FORCE_INLINE bool all_within_tolerance(VecType a, VecType b, ScalarType tol) noexcept
Definition avx2_microkernel.h:192
static FORCE_INLINE VecType max(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:169
static FORCE_INLINE VecType fnmsub(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:164
static FORCE_INLINE VecType add(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:137
static FORCE_INLINE VecType mul(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:140
static FORCE_INLINE VecType mul(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:139
static FORCE_INLINE VecType fmsub(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:155
static FORCE_INLINE VecType fmadd(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:152
static FORCE_INLINE VecType load(const ScalarType *ptr) noexcept
Definition avx2_microkernel.h:130
static FORCE_INLINE VecType gather(const ScalarType *base, const my_size_t *indices) noexcept
Definition avx2_microkernel.h:172
static FORCE_INLINE VecType sub(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:143
double ScalarType
Definition avx2_microkernel.h:128
static FORCE_INLINE void scatter(ScalarType *base, const my_size_t *indices, VecType val) noexcept
Definition avx2_microkernel.h:178
static FORCE_INLINE VecType div(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:146
static FORCE_INLINE VecType min(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:167
static FORCE_INLINE VecType fnmsub(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:163
static FORCE_INLINE VecType sub(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:142
static FORCE_INLINE VecType fmadd(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:151
static FORCE_INLINE VecType fnmadd(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:160
static FORCE_INLINE VecType add(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:34
static FORCE_INLINE void scatter(ScalarType *base, const my_size_t *indices, VecType val) noexcept
Definition avx2_microkernel.h:92
static FORCE_INLINE VecType sub(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:40
static FORCE_INLINE VecType fnmadd(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:57
static FORCE_INLINE VecType gather(const ScalarType *base, const my_size_t *indices) noexcept
Definition avx2_microkernel.h:73
static FORCE_INLINE VecType add(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:35
float ScalarType
Definition avx2_microkernel.h:26
static FORCE_INLINE VecType set1(ScalarType scalar) noexcept
Definition avx2_microkernel.h:32
static FORCE_INLINE VecType sub(ScalarType a, VecType b) noexcept
Definition avx2_microkernel.h:42
static FORCE_INLINE VecType max(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:67
static FORCE_INLINE VecType sub(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:41
static FORCE_INLINE VecType load(const ScalarType *ptr) noexcept
Definition avx2_microkernel.h:28
static FORCE_INLINE VecType mul(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:38
__m256 VecType
Definition avx2_microkernel.h:25
static FORCE_INLINE VecType loadu(const ScalarType *ptr) noexcept
Definition avx2_microkernel.h:29
static FORCE_INLINE VecType max(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:68
static FORCE_INLINE VecType fmsub(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:53
static FORCE_INLINE VecType min(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:65
static FORCE_INLINE void storeu(ScalarType *ptr, VecType val) noexcept
Definition avx2_microkernel.h:31
static FORCE_INLINE VecType div(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:44
static FORCE_INLINE VecType min(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:64
static FORCE_INLINE VecType fmsub(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:54
static FORCE_INLINE VecType div(ScalarType a, VecType b) noexcept
Definition avx2_microkernel.h:46
static FORCE_INLINE VecType fmadd(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:50
static FORCE_INLINE VecType mul(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:37
static FORCE_INLINE VecType fmadd(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:49
static FORCE_INLINE bool all_within_tolerance(VecType a, VecType b, ScalarType tol) noexcept
Definition avx2_microkernel.h:107
static FORCE_INLINE VecType fnmadd(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:58
static FORCE_INLINE VecType abs(VecType v) noexcept
Definition avx2_microkernel.h:100
static FORCE_INLINE VecType fnmsub(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:61
static FORCE_INLINE void store(ScalarType *ptr, VecType val) noexcept
Definition avx2_microkernel.h:30
static FORCE_INLINE VecType div(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:45
static FORCE_INLINE VecType fnmsub(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:62
static FORCE_INLINE VecType div(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:244
static FORCE_INLINE VecType load(const ScalarType *ptr) noexcept
Definition avx2_microkernel.h:218
__m256i VecType
Definition avx2_microkernel.h:215
static FORCE_INLINE VecType fnmadd(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:273
static FORCE_INLINE VecType mul(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:227
static FORCE_INLINE VecType gather(const ScalarType *base, const my_size_t *indices) noexcept
Definition avx2_microkernel.h:297
static FORCE_INLINE VecType add(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:225
static FORCE_INLINE bool all_within_tolerance(VecType a, VecType b, ScalarType tol) noexcept
Definition avx2_microkernel.h:321
static FORCE_INLINE VecType fnmsub(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:282
static FORCE_INLINE VecType sub(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:230
static FORCE_INLINE VecType fnmadd(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:272
static FORCE_INLINE void scatter(ScalarType *base, const my_size_t *indices, VecType val) noexcept
Definition avx2_microkernel.h:308
int32_t ScalarType
Definition avx2_microkernel.h:216
static FORCE_INLINE VecType fmsub(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:268
static FORCE_INLINE VecType abs(VecType v) noexcept
Definition avx2_microkernel.h:316
static FORCE_INLINE VecType sub(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:231
static FORCE_INLINE void storeu(ScalarType *ptr, VecType val) noexcept
Definition avx2_microkernel.h:221
static FORCE_INLINE VecType loadu(const ScalarType *ptr) noexcept
Definition avx2_microkernel.h:219
static FORCE_INLINE VecType mul(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:228
static FORCE_INLINE VecType fmadd(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:264
static FORCE_INLINE VecType add(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:224
static FORCE_INLINE VecType fnmsub(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:276
static FORCE_INLINE VecType max(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:291
static FORCE_INLINE VecType fmsub(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:269
static FORCE_INLINE VecType div(ScalarType a, VecType b) noexcept
Definition avx2_microkernel.h:252
static FORCE_INLINE VecType set1(ScalarType scalar) noexcept
Definition avx2_microkernel.h:222
static FORCE_INLINE VecType fmadd(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:265
static FORCE_INLINE VecType sub(ScalarType a, VecType b) noexcept
Definition avx2_microkernel.h:232
static FORCE_INLINE VecType max(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:292
static FORCE_INLINE VecType div(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:235
static FORCE_INLINE void store(ScalarType *ptr, VecType val) noexcept
Definition avx2_microkernel.h:220
static FORCE_INLINE VecType min(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:289
static FORCE_INLINE VecType min(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:288
__m256i VecType
Definition avx2_microkernel.h:344
static FORCE_INLINE VecType sub(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:377
static FORCE_INLINE VecType mul(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:358
static FORCE_INLINE VecType min(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:435
static FORCE_INLINE VecType div(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:382
static FORCE_INLINE VecType fmadd(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:410
static FORCE_INLINE VecType fnmsub(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:423
static FORCE_INLINE VecType fnmadd(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:417
static FORCE_INLINE VecType max(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:437
static FORCE_INLINE VecType loadu(const ScalarType *ptr) noexcept
Definition avx2_microkernel.h:348
static FORCE_INLINE VecType max(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:442
static FORCE_INLINE VecType set1(ScalarType scalar) noexcept
Definition avx2_microkernel.h:351
static FORCE_INLINE VecType fnmadd(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:416
int64_t ScalarType
Definition avx2_microkernel.h:345
static FORCE_INLINE bool all_within_tolerance(VecType a, VecType b, ScalarType tol) noexcept
Definition avx2_microkernel.h:470
static FORCE_INLINE VecType sub(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:378
static FORCE_INLINE VecType abs(VecType v) noexcept
Definition avx2_microkernel.h:462
static FORCE_INLINE void store(ScalarType *ptr, VecType val) noexcept
Definition avx2_microkernel.h:349
static FORCE_INLINE VecType sub(ScalarType a, VecType b) noexcept
Definition avx2_microkernel.h:379
static FORCE_INLINE void scatter(ScalarType *base, const my_size_t *indices, VecType val) noexcept
Definition avx2_microkernel.h:454
static FORCE_INLINE void storeu(ScalarType *ptr, VecType val) noexcept
Definition avx2_microkernel.h:350
static FORCE_INLINE VecType add(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:354
static FORCE_INLINE VecType fnmsub(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:419
static FORCE_INLINE VecType load(const ScalarType *ptr) noexcept
Definition avx2_microkernel.h:347
static FORCE_INLINE VecType fmsub(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:414
static FORCE_INLINE VecType min(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:429
static FORCE_INLINE VecType mul(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:375
static FORCE_INLINE VecType div(ScalarType a, VecType b) noexcept
Definition avx2_microkernel.h:399
static FORCE_INLINE VecType add(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:353
static FORCE_INLINE VecType div(VecType a, ScalarType b) noexcept
Definition avx2_microkernel.h:391
static FORCE_INLINE VecType fmadd(VecType a, ScalarType b, VecType c) noexcept
Definition avx2_microkernel.h:411
static FORCE_INLINE VecType gather(const ScalarType *base, const my_size_t *indices) noexcept
Definition avx2_microkernel.h:447
static FORCE_INLINE VecType fmsub(VecType a, VecType b, VecType c) noexcept
Definition avx2_microkernel.h:413
Definition microkernel_base.h:16
static FORCE_INLINE VecType mul(VecType a, VecType b) noexcept
static constexpr my_size_t simdWidth
Definition microkernel_base.h:17
static FORCE_INLINE VecType set1(T scalar) noexcept
static FORCE_INLINE VecType add(VecType a, VecType b) noexcept
static FORCE_INLINE VecType sub(VecType a, VecType b) noexcept
Definition avx2_microkernel.h:9