tesseract++
0.0.1
N-dimensional tensor library for embedded systems
Loading...
Searching...
No Matches
core
include
fused
kernel_ops
deprecated
einsum_old.h
Go to the documentation of this file.
1
// // Contract two expressions along a specific axis (one axis from each) and return the resulting tensor
2
// template <typename LeftExpr, typename RightExpr>
3
// requires(
4
// algebra::is_tensor_v<LeftExpr> &&
5
// algebra::is_tensor_v<RightExpr>)
6
// static FusedTensorND einsum_old(const BaseExpr<LeftExpr> &_tensor1, const BaseExpr<RightExpr> &_tensor2, const my_size_t a, const my_size_t b)
7
// {
8
// static const my_size_t Dims1 = LeftExpr::NumDims;
9
// static const my_size_t Dims2 = RightExpr::NumDims;
10
11
// // static_assert(Dims1 >= 2, "Tensor 1 must have at least 2 dimension");
12
// // static_assert(Dims2 >= 2, "Tensor 2 must have at least 2 dimension");
13
14
// if constexpr (Dims1 < 2)
15
// {
16
// MyErrorHandler::error("Tensor 1 must have at least 2 dimension");
17
// }
18
// if constexpr (Dims2 < 2)
19
// {
20
// MyErrorHandler::error("Tensor 2 must have at least 2 dimension");
21
// }
22
23
// // Check at runtime that a and b are valid axis indices for their respective tensors
24
// if (a >= Dims1 || b >= Dims2)
25
// {
26
// MyErrorHandler::error("Invalid dimensions");
27
// }
28
29
// // Check at runtime that axis a of tensor1 has the same size as axis b of tensor2
30
// if (_tensor1.derived().getDim(a) != _tensor2.derived().getDim(b))
31
// {
32
// MyErrorHandler::error("Dimensions mismatch between tensors for einsum operation");
33
// }
34
35
// // ------------------------------------------------------
36
// // TODO: everything between the ----- markers could be computed entirely at compile time
37
// // calculate the new dimensions
38
// constexpr my_size_t n_newDims = Dims1 + Dims2 - 2;
39
// my_size_t newDims[n_newDims];
40
// my_size_t k = 0;
41
// for (my_size_t i = 0; i < Dims1; ++i)
42
// {
43
// if (i != a)
44
// {
45
// newDims[k++] = _tensor1.derived().getDim(i);
46
// }
47
// }
48
49
// for (my_size_t i = 0; i < Dims2; ++i)
50
// {
51
// if (i != b)
52
// {
53
// newDims[k++] = _tensor2.derived().getDim(i);
54
// }
55
// }
56
57
// // create a new tensor with the new dimensions
58
// FusedTensorND<T, Dims...> _outp;
59
60
// // Check, one dimension at a time, that the computed output dimensions match those of the output tensor
61
// for (my_size_t i = 0; i < n_newDims; ++i)
62
// {
63
// if (newDims[i] != _outp.getDim(i))
64
// {
65
// MyErrorHandler::error("Dimensions mismatch in output tensor");
66
// }
67
// }
68
// // ------------------------------------------------------
69
70
// // Compute the total number of output index combinations and create a 2D array to store them
71
// constexpr my_size_t total_combinations = (1 * ... * Dims);
72
// my_size_t combinations[total_combinations][n_newDims];
73
74
// // generate all the combinations
75
// generate_combinations(newDims, combinations);
76
77
// // calculate the contraction
78
// for (my_size_t comb = 0; comb < total_combinations; ++comb)
79
// {
80
// T sum = 0;
81
// my_size_t K = _tensor1.derived().getDim(a); // or _tensor2.derived().getDim(b) since they are equal
82
// for (my_size_t k = 0; k < K; ++k)
83
// {
84
// my_size_t indices1[Dims1] = {0};
85
// my_size_t indices2[Dims2] = {0};
86
87
// my_size_t l = 0;
88
// for (my_size_t i = 0; i < Dims1; ++i)
89
// {
90
// if (i != a)
91
// {
92
// indices1[i] = combinations[comb][l++];
93
// }
94
// else
95
// {
96
// indices1[i] = k;
97
// }
98
// }
99
100
// l = Dims1 - 1;
101
// for (my_size_t i = 0; i < Dims2; ++i)
102
// {
103
// if (i != b)
104
// {
105
// indices2[i] = combinations[comb][l++];
106
// }
107
// else
108
// {
109
// indices2[i] = k;
110
// }
111
// }
112
// sum += _tensor1.derived()(indices1) * _tensor2.derived()(indices2);
113
// }
114
// _outp(combinations[comb]) = sum;
115
// }
116
// return _outp;
117
// }
118
119
120
// /**
121
// * @brief Contract two tensors along specified axes using SIMD dot products.
122
// *
123
// * Computes C[free_indices] = sum_k A[...,k,...] * B[...,k,...]
124
// * where k runs along axis `a` of tensor1 and axis `b` of tensor2.
125
// *
126
// * For each output element, builds the fiber start (base) for each tensor
127
// * by setting the contraction index to 0, then calls dot() with the
128
// * physical stride along the contraction axis.
129
// *
130
// * ============================================================================
131
// * EXAMPLE: C[2,2] = A[2,3] * B[3,2], contract a=1, b=0
132
// * ============================================================================
133
// *
134
// * A[2,3] padded to [2,4]: B[3,2] padded to [3,4]:
135
// * [a00 a01 a02 P] [b00 b01 P P]
136
// * [a10 a11 a12 P] [b10 b11 P P]
137
// * Layout1::stride(1) = 1 [b20 b21 P P]
138
// * Layout2::stride(0) = 4
139
// *
140
// * Output C[i,j]:
141
// * C[0,0]: A[0,:] dot B[:,0]
142
// * base1=0, stride1=1 → [a00, a01, a02] contiguous
143
// * base2=0, stride2=4 → [b00, b10, b20] strided
144
// * → dot(A, 0, 1, B, 0, 4, 3)
145
// *
146
// * C[0,1]: A[0,:] dot B[:,1]
147
// * base1=0, stride1=1
148
// * base2=1, stride2=4 → [b01, b11, b21]
149
// * → dot(A, 0, 1, B, 1, 4, 3)
150
// *
151
// * C[1,0]: A[1,:] dot B[:,0]
152
// * base1=4, stride1=1 → [a10, a11, a12]
153
// * base2=0, stride2=4
154
// * → dot(A, 4, 1, B, 0, 4, 3)
155
// *
156
// * ============================================================================
157
// */
158
// template <typename LeftExpr, typename RightExpr>
159
// requires(expression::traits<LeftExpr>::IsPhysical &&
160
// expression::traits<RightExpr>::IsPhysical)
161
// static FusedTensorND einsum_new_old(
162
// const BaseExpr<LeftExpr> &_tensor1,
163
// const BaseExpr<RightExpr> &_tensor2,
164
// const my_size_t a,
165
// const my_size_t b)
166
// {
167
// static constexpr my_size_t Dims1 = LeftExpr::NumDims;
168
// static constexpr my_size_t Dims2 = RightExpr::NumDims;
169
170
// static_assert(Dims1 >= 2, "Tensor 1 must have at least 2 dimensions");
171
// static_assert(Dims2 >= 2, "Tensor 2 must have at least 2 dimensions");
172
173
// // Runtime validation
174
// if (a >= Dims1 || b >= Dims2)
175
// MyErrorHandler::error("Invalid contraction axis");
176
177
// if (_tensor1.derived().getDim(a) != _tensor2.derived().getDim(b))
178
// MyErrorHandler::error("Contraction dimensions mismatch");
179
180
// // Contraction length and physical strides along contraction axes
181
// using Layout1 = typename LeftExpr::Layout;
182
// using Layout2 = typename RightExpr::Layout;
183
// using Kern = KernelOps<T, BITS, DefaultArch>;
184
185
// const my_size_t K = _tensor1.derived().getDim(a);
186
// const my_size_t stride1 = Layout1::stride(a);
187
// const my_size_t stride2 = Layout2::stride(b);
188
189
// // Build output dimensions (all dims except contracted ones)
190
// static constexpr my_size_t n_newDims = Dims1 + Dims2 - 2;
191
// my_size_t newDims[n_newDims];
192
// my_size_t d = 0;
193
// for (my_size_t i = 0; i < Dims1; ++i)
194
// if (i != a)
195
// newDims[d++] = _tensor1.derived().getDim(i);
196
// for (my_size_t i = 0; i < Dims2; ++i)
197
// if (i != b)
198
// newDims[d++] = _tensor2.derived().getDim(i);
199
200
// // Validate output dimensions match
201
// FusedTensorND _outp;
202
// for (my_size_t i = 0; i < n_newDims; ++i)
203
// {
204
// if (newDims[i] != _outp.getDim(i))
205
// MyErrorHandler::error("Output dimensions mismatch");
206
// }
207
208
// // Generate all output index combinations
209
// static constexpr my_size_t total_combinations = (1 * ... * Dims);
210
// my_size_t combinations[total_combinations][n_newDims];
211
// generate_combinations(newDims, combinations);
212
213
// // For each output element, compute dot product of two fibers
214
// for (my_size_t comb = 0; comb < total_combinations; ++comb)
215
// {
216
// // Build tensor1 indices with contraction axis = 0
217
// my_size_t indices1[Dims1] = {0};
218
// my_size_t l = 0;
219
// for (my_size_t i = 0; i < Dims1; ++i)
220
// {
221
// if (i != a)
222
// indices1[i] = combinations[comb][l++];
223
// // else indices1[i] = 0 (fiber start)
224
// }
225
226
// // Build tensor2 indices with contraction axis = 0
227
// my_size_t indices2[Dims2] = {0};
228
// for (my_size_t i = 0; i < Dims2; ++i)
229
// {
230
// if (i != b)
231
// indices2[i] = combinations[comb][l++];
232
// // else indices2[i] = 0 (fiber start)
233
// }
234
235
// // Physical base offsets where each fiber starts
236
// const my_size_t base1 = Layout1::logical_coords_to_physical_flat(indices1);
237
// const my_size_t base2 = Layout2::logical_coords_to_physical_flat(indices2);
238
239
// // Dot product replaces the inner k-loop
240
// _outp(combinations[comb]) = Kern::dot(
241
// _tensor1.derived(), base1, stride1,
242
// _tensor2.derived(), base2, stride2,
243
// K);
244
// }
245
246
// return _outp;
247
// }
248
249
// /**
250
// * @brief Contract two tensors along specified axes using SIMD dot products.
251
// *
252
// * Computes C[free_indices] = sum_k A[...,k,...] * B[...,k,...]
253
// * where k runs along axis `a` of tensor1 and axis `b` of tensor2.
254
// *
255
// * Uses pre-computed stride maps to convert output coordinates directly
256
// * into physical base offsets for each input tensor, avoiding per-element
257
// * index array construction and logical_coords_to_physical_flat calls.
258
// *
259
// * ============================================================================
260
// * STRIDE MAP EXAMPLE: C[2,2] = A[2,3] * B[3,2], contract a=1, b=0
261
// * ============================================================================
262
// *
263
// * A[2,3] padded to [2,4]: B[3,2] padded to [3,4]:
264
// * strides: [4, 1] strides: [4, 1]
265
// * contract dim 1 (stride=1) contract dim 0 (stride=4)
266
// *
267
// * Output dims: [2, 2] (A's dim 0, B's dim 1)
268
// *
269
// * Stride maps (one entry per output dim):
270
// * strides1_map = [4, 0] ← output dim 0 → A's dim 0 (stride 4)
271
// * strides2_map = [0, 1] ← output dim 1 → B's dim 1 (stride 1)
272
// *
273
// * For output coord (1, 1):
274
// * base1 = 1*4 + 1*0 = 4 (start of A[1,:])
275
// * base2 = 1*0 + 1*1 = 1 (start of B[:,1])
276
// * → dot(A, 4, 1, B, 1, 4, 3)
277
// *
278
// * ============================================================================
279
// */
280
// template <typename LeftExpr, typename RightExpr>
281
// requires(expression::traits<LeftExpr>::IsPhysical &&
282
// expression::traits<RightExpr>::IsPhysical)
283
// static FusedTensorND einsum_new(
284
// const BaseExpr<LeftExpr> &_tensor1,
285
// const BaseExpr<RightExpr> &_tensor2,
286
// const my_size_t a,
287
// const my_size_t b)
288
// {
289
// static constexpr my_size_t Dims1 = LeftExpr::NumDims;
290
// static constexpr my_size_t Dims2 = RightExpr::NumDims;
291
292
// static_assert(Dims1 >= 2, "Tensor 1 must have at least 2 dimensions");
293
// static_assert(Dims2 >= 2, "Tensor 2 must have at least 2 dimensions");
294
295
// // Runtime validation
296
// if (a >= Dims1 || b >= Dims2)
297
// MyErrorHandler::error("Invalid contraction axis");
298
299
// if (_tensor1.derived().getDim(a) != _tensor2.derived().getDim(b))
300
// MyErrorHandler::error("Contraction dimensions mismatch");
301
302
// using Layout1 = typename LeftExpr::Layout;
303
// using Layout2 = typename RightExpr::Layout;
304
// using OutputLayout = Layout;
305
// using Kern = KernelOps<T, BITS, DefaultArch>;
306
307
// const my_size_t K_len = _tensor1.derived().getDim(a);
308
// const my_size_t contract_stride1 = Layout1::stride(a);
309
// const my_size_t contract_stride2 = Layout2::stride(b);
310
311
// // ====================================================================
312
// // Build stride maps
313
// // ====================================================================
314
// // For each output dimension d:
315
// // strides1_map[d] = physical stride in tensor1 (0 if d comes from tensor2)
316
// // strides2_map[d] = physical stride in tensor2 (0 if d comes from tensor1)
317
// // out_dims[d] = size of output dimension d
318
319
// static constexpr my_size_t n_newDims = Dims1 + Dims2 - 2;
320
// my_size_t strides1_map[n_newDims];
321
// my_size_t strides2_map[n_newDims];
322
// my_size_t out_dims[n_newDims];
323
324
// my_size_t d = 0;
325
// for (my_size_t i = 0; i < Dims1; ++i)
326
// {
327
// if (i != a)
328
// {
329
// out_dims[d] = _tensor1.derived().getDim(i);
330
// strides1_map[d] = Layout1::stride(i);
331
// strides2_map[d] = 0;
332
// ++d;
333
// }
334
// }
335
// for (my_size_t i = 0; i < Dims2; ++i)
336
// {
337
// if (i != b)
338
// {
339
// out_dims[d] = _tensor2.derived().getDim(i);
340
// strides1_map[d] = 0;
341
// strides2_map[d] = Layout2::stride(i);
342
// ++d;
343
// }
344
// }
345
346
// // ====================================================================
347
// // Validate output dimensions
348
// // ====================================================================
349
350
// FusedTensorND _outp;
351
// for (my_size_t i = 0; i < n_newDims; ++i)
352
// {
353
// if (out_dims[i] != _outp.getDim(i))
354
// MyErrorHandler::error("Output dimensions mismatch");
355
// }
356
357
// // ====================================================================
358
// // Pre-compute output physical strides for direct memory writes
359
// // ====================================================================
360
361
// my_size_t out_strides[n_newDims];
362
// for (my_size_t i = 0; i < n_newDims; ++i)
363
// out_strides[i] = OutputLayout::stride(i);
364
365
// T *out_ptr = _outp.data();
366
367
// // ====================================================================
368
// // Main loop: iterate all output elements
369
// // ====================================================================
370
371
// static constexpr my_size_t total_elements = (1 * ... * Dims);
372
373
// for (my_size_t flat = 0; flat < total_elements; ++flat)
374
// {
375
// // Decompose flat index into output coordinates (row-major)
376
// my_size_t coords[n_newDims];
377
// my_size_t tmp = flat;
378
// for (my_size_t i = n_newDims; i-- > 0;)
379
// {
380
// coords[i] = tmp % out_dims[i];
381
// tmp /= out_dims[i];
382
// }
383
384
// // Compute physical bases via stride maps (dot product of coords × strides)
385
// my_size_t base1 = 0;
386
// my_size_t base2 = 0;
387
// my_size_t out_phys = 0;
388
// for (my_size_t i = 0; i < n_newDims; ++i)
389
// {
390
// base1 += coords[i] * strides1_map[i];
391
// base2 += coords[i] * strides2_map[i];
392
// out_phys += coords[i] * out_strides[i];
393
// }
394
395
// // Dot product along contraction axis → write directly to output
396
// out_ptr[out_phys] = Kern::dot(
397
// _tensor1.derived(), base1, contract_stride1,
398
// _tensor2.derived(), base2, contract_stride2,
399
// K_len);
400
// }
401
402
// return _outp;
403
// }
Generated by
1.9.8