tesseract++ 0.0.1
N-dimensional tensor library for embedded systems
Loading...
Searching...
No Matches
einsum_old.h
Go to the documentation of this file.
1 // // contract two tensor expressions along one axis of each (axis a of the first, axis b of the second) and return the result
2 // template <typename LeftExpr, typename RightExpr>
3 // requires(
4 // algebra::is_tensor_v<LeftExpr> &&
5 // algebra::is_tensor_v<RightExpr>)
6 // static FusedTensorND einsum_old(const BaseExpr<LeftExpr> &_tensor1, const BaseExpr<RightExpr> &_tensor2, const my_size_t a, const my_size_t b)
7 // {
8 // static const my_size_t Dims1 = LeftExpr::NumDims;
9 // static const my_size_t Dims2 = RightExpr::NumDims;
10
11 // // static_assert(Dims1 >= 2, "Tensor 1 must have at least 2 dimension");
12 // // static_assert(Dims2 >= 2, "Tensor 2 must have at least 2 dimension");
13
14 // if constexpr (Dims1 < 2)
15 // {
16 // MyErrorHandler::error("Tensor 1 must have at least 2 dimension");
17 // }
18 // if constexpr (Dims2 < 2)
19 // {
20 // MyErrorHandler::error("Tensor 2 must have at least 2 dimension");
21 // }
22
23 // // check if a and b are valid dimensions at runtime
24 // if (a >= Dims1 || b >= Dims2)
25 // {
26 // MyErrorHandler::error("Invalid dimensions");
27 // }
28
29 // // check at runtime that axis a of tensor1 and axis b of tensor2 have the same extent
30 // if (_tensor1.derived().getDim(a) != _tensor2.derived().getDim(b))
31 // {
32 // MyErrorHandler::error("Dimensions mismatch between tensors for einsum operation");
33 // }
34
35 // // ------------------------------------------------------
36 // // TODO: everything between the two dashed lines could be computed entirely at compile time
37 // // calculate the new dimensions
38 // constexpr my_size_t n_newDims = Dims1 + Dims2 - 2;
39 // my_size_t newDims[n_newDims];
40 // my_size_t k = 0;
41 // for (my_size_t i = 0; i < Dims1; ++i)
42 // {
43 // if (i != a)
44 // {
45 // newDims[k++] = _tensor1.derived().getDim(i);
46 // }
47 // }
48
49 // for (my_size_t i = 0; i < Dims2; ++i)
50 // {
51 // if (i != b)
52 // {
53 // newDims[k++] = _tensor2.derived().getDim(i);
54 // }
55 // }
56
57 // // create a new tensor with the new dimensions
58 // FusedTensorND<T, Dims...> _outp;
59
60 // // verify, one by one, that the computed dimensions match the dimensions of the output tensor
61 // for (my_size_t i = 0; i < n_newDims; ++i)
62 // {
63 // if (newDims[i] != _outp.getDim(i))
64 // {
65 // MyErrorHandler::error("Dimensions mismatch in output tensor");
66 // }
67 // }
68 // // ------------------------------------------------------
69
70 // // calculate the total number of combinations and create a 2D array to store them
71 // constexpr my_size_t total_combinations = (1 * ... * Dims);
72 // my_size_t combinations[total_combinations][n_newDims];
73
74 // // generate all the combinations
75 // generate_combinations(newDims, combinations);
76
77 // // calculate the contraction
78 // for (my_size_t comb = 0; comb < total_combinations; ++comb)
79 // {
80 // T sum = 0;
81 // my_size_t K = _tensor1.derived().getDim(a); // or _tensor2.derived().getDim(b) since they are equal
82 // for (my_size_t k = 0; k < K; ++k)
83 // {
84 // my_size_t indices1[Dims1] = {0};
85 // my_size_t indices2[Dims2] = {0};
86
87 // my_size_t l = 0;
88 // for (my_size_t i = 0; i < Dims1; ++i)
89 // {
90 // if (i != a)
91 // {
92 // indices1[i] = combinations[comb][l++];
93 // }
94 // else
95 // {
96 // indices1[i] = k;
97 // }
98 // }
99
100 // l = Dims1 - 1;
101 // for (my_size_t i = 0; i < Dims2; ++i)
102 // {
103 // if (i != b)
104 // {
105 // indices2[i] = combinations[comb][l++];
106 // }
107 // else
108 // {
109 // indices2[i] = k;
110 // }
111 // }
112 // sum += _tensor1.derived()(indices1) * _tensor2.derived()(indices2);
113 // }
114 // _outp(combinations[comb]) = sum;
115 // }
116 // return _outp;
117 // }
118
119
120// /**
121 // * @brief Contract two tensors along specified axes using SIMD dot products.
122 // *
123 // * Computes C[free_indices] = sum_k A[...,k,...] * B[...,k,...]
124 // * where k runs along axis `a` of tensor1 and axis `b` of tensor2.
125 // *
126 // * For each output element, builds the fiber start (base) for each tensor
127 // * by setting the contraction index to 0, then calls dot() with the
128 // * physical stride along the contraction axis.
129 // *
130 // * ============================================================================
131 // * EXAMPLE: C[2,2] = A[2,3] * B[3,2], contract a=1, b=0
132 // * ============================================================================
133 // *
134 // * A[2,3] padded to [2,4]: B[3,2] padded to [3,4]:
135 // * [a00 a01 a02 P] [b00 b01 P P]
136 // * [a10 a11 a12 P] [b10 b11 P P]
137 // * Layout1::stride(1) = 1 [b20 b21 P P]
138 // * Layout2::stride(0) = 4
139 // *
140 // * Output C[i,j]:
141 // * C[0,0]: A[0,:] dot B[:,0]
142 // * base1=0, stride1=1 → [a00, a01, a02] contiguous
143 // * base2=0, stride2=4 → [b00, b10, b20] strided
144 // * → dot(A, 0, 1, B, 0, 4, 3)
145 // *
146 // * C[0,1]: A[0,:] dot B[:,1]
147 // * base1=0, stride1=1
148 // * base2=1, stride2=4 → [b01, b11, b21]
149 // * → dot(A, 0, 1, B, 1, 4, 3)
150 // *
151 // * C[1,0]: A[1,:] dot B[:,0]
152 // * base1=4, stride1=1 → [a10, a11, a12]
153 // * base2=0, stride2=4
154 // * → dot(A, 4, 1, B, 0, 4, 3)
155 // *
156 // * ============================================================================
157 // */
158 // template <typename LeftExpr, typename RightExpr>
159 // requires(expression::traits<LeftExpr>::IsPhysical &&
160 // expression::traits<RightExpr>::IsPhysical)
161 // static FusedTensorND einsum_new_old(
162 // const BaseExpr<LeftExpr> &_tensor1,
163 // const BaseExpr<RightExpr> &_tensor2,
164 // const my_size_t a,
165 // const my_size_t b)
166 // {
167 // static constexpr my_size_t Dims1 = LeftExpr::NumDims;
168 // static constexpr my_size_t Dims2 = RightExpr::NumDims;
169
170 // static_assert(Dims1 >= 2, "Tensor 1 must have at least 2 dimensions");
171 // static_assert(Dims2 >= 2, "Tensor 2 must have at least 2 dimensions");
172
173 // // Runtime validation
174 // if (a >= Dims1 || b >= Dims2)
175 // MyErrorHandler::error("Invalid contraction axis");
176
177 // if (_tensor1.derived().getDim(a) != _tensor2.derived().getDim(b))
178 // MyErrorHandler::error("Contraction dimensions mismatch");
179
180 // // Contraction length and physical strides along contraction axes
181 // using Layout1 = typename LeftExpr::Layout;
182 // using Layout2 = typename RightExpr::Layout;
183 // using Kern = KernelOps<T, BITS, DefaultArch>;
184
185 // const my_size_t K = _tensor1.derived().getDim(a);
186 // const my_size_t stride1 = Layout1::stride(a);
187 // const my_size_t stride2 = Layout2::stride(b);
188
189 // // Build output dimensions (all dims except contracted ones)
190 // static constexpr my_size_t n_newDims = Dims1 + Dims2 - 2;
191 // my_size_t newDims[n_newDims];
192 // my_size_t d = 0;
193 // for (my_size_t i = 0; i < Dims1; ++i)
194 // if (i != a)
195 // newDims[d++] = _tensor1.derived().getDim(i);
196 // for (my_size_t i = 0; i < Dims2; ++i)
197 // if (i != b)
198 // newDims[d++] = _tensor2.derived().getDim(i);
199
200 // // Validate output dimensions match
201 // FusedTensorND _outp;
202 // for (my_size_t i = 0; i < n_newDims; ++i)
203 // {
204 // if (newDims[i] != _outp.getDim(i))
205 // MyErrorHandler::error("Output dimensions mismatch");
206 // }
207
208 // // Generate all output index combinations
209 // static constexpr my_size_t total_combinations = (1 * ... * Dims);
210 // my_size_t combinations[total_combinations][n_newDims];
211 // generate_combinations(newDims, combinations);
212
213 // // For each output element, compute dot product of two fibers
214 // for (my_size_t comb = 0; comb < total_combinations; ++comb)
215 // {
216 // // Build tensor1 indices with contraction axis = 0
217 // my_size_t indices1[Dims1] = {0};
218 // my_size_t l = 0;
219 // for (my_size_t i = 0; i < Dims1; ++i)
220 // {
221 // if (i != a)
222 // indices1[i] = combinations[comb][l++];
223 // // else indices1[i] = 0 (fiber start)
224 // }
225
226 // // Build tensor2 indices with contraction axis = 0
227 // my_size_t indices2[Dims2] = {0};
228 // for (my_size_t i = 0; i < Dims2; ++i)
229 // {
230 // if (i != b)
231 // indices2[i] = combinations[comb][l++];
232 // // else indices2[i] = 0 (fiber start)
233 // }
234
235 // // Physical base offsets where each fiber starts
236 // const my_size_t base1 = Layout1::logical_coords_to_physical_flat(indices1);
237 // const my_size_t base2 = Layout2::logical_coords_to_physical_flat(indices2);
238
239 // // Dot product replaces the inner k-loop
240 // _outp(combinations[comb]) = Kern::dot(
241 // _tensor1.derived(), base1, stride1,
242 // _tensor2.derived(), base2, stride2,
243 // K);
244 // }
245
246 // return _outp;
247 // }
248
249 // /**
250 // * @brief Contract two tensors along specified axes using SIMD dot products.
251 // *
252 // * Computes C[free_indices] = sum_k A[...,k,...] * B[...,k,...]
253 // * where k runs along axis `a` of tensor1 and axis `b` of tensor2.
254 // *
255 // * Uses pre-computed stride maps to convert output coordinates directly
256 // * into physical base offsets for each input tensor, avoiding per-element
257 // * index array construction and logical_coords_to_physical_flat calls.
258 // *
259 // * ============================================================================
260 // * STRIDE MAP EXAMPLE: C[2,2] = A[2,3] * B[3,2], contract a=1, b=0
261 // * ============================================================================
262 // *
263 // * A[2,3] padded to [2,4]: B[3,2] padded to [3,4]:
264 // * strides: [4, 1] strides: [4, 1]
265 // * contract dim 1 (stride=1) contract dim 0 (stride=4)
266 // *
267 // * Output dims: [2, 2] (A's dim 0, B's dim 1)
268 // *
269 // * Stride maps (one entry per output dim):
270 // * strides1_map = [4, 0] ← output dim 0 → A's dim 0 (stride 4)
271 // * strides2_map = [0, 1] ← output dim 1 → B's dim 1 (stride 1)
272 // *
273 // * For output coord (1, 1):
274 // * base1 = 1*4 + 1*0 = 4 (start of A[1,:])
275 // * base2 = 1*0 + 1*1 = 1 (start of B[:,1])
276 // * → dot(A, 4, 1, B, 1, 4, 3)
277 // *
278 // * ============================================================================
279 // */
280 // template <typename LeftExpr, typename RightExpr>
281 // requires(expression::traits<LeftExpr>::IsPhysical &&
282 // expression::traits<RightExpr>::IsPhysical)
283 // static FusedTensorND einsum_new(
284 // const BaseExpr<LeftExpr> &_tensor1,
285 // const BaseExpr<RightExpr> &_tensor2,
286 // const my_size_t a,
287 // const my_size_t b)
288 // {
289 // static constexpr my_size_t Dims1 = LeftExpr::NumDims;
290 // static constexpr my_size_t Dims2 = RightExpr::NumDims;
291
292 // static_assert(Dims1 >= 2, "Tensor 1 must have at least 2 dimensions");
293 // static_assert(Dims2 >= 2, "Tensor 2 must have at least 2 dimensions");
294
295 // // Runtime validation
296 // if (a >= Dims1 || b >= Dims2)
297 // MyErrorHandler::error("Invalid contraction axis");
298
299 // if (_tensor1.derived().getDim(a) != _tensor2.derived().getDim(b))
300 // MyErrorHandler::error("Contraction dimensions mismatch");
301
302 // using Layout1 = typename LeftExpr::Layout;
303 // using Layout2 = typename RightExpr::Layout;
304 // using OutputLayout = Layout;
305 // using Kern = KernelOps<T, BITS, DefaultArch>;
306
307 // const my_size_t K_len = _tensor1.derived().getDim(a);
308 // const my_size_t contract_stride1 = Layout1::stride(a);
309 // const my_size_t contract_stride2 = Layout2::stride(b);
310
311 // // ====================================================================
312 // // Build stride maps
313 // // ====================================================================
314 // // For each output dimension d:
315 // // strides1_map[d] = physical stride in tensor1 (0 if d comes from tensor2)
316 // // strides2_map[d] = physical stride in tensor2 (0 if d comes from tensor1)
317 // // out_dims[d] = size of output dimension d
318
319 // static constexpr my_size_t n_newDims = Dims1 + Dims2 - 2;
320 // my_size_t strides1_map[n_newDims];
321 // my_size_t strides2_map[n_newDims];
322 // my_size_t out_dims[n_newDims];
323
324 // my_size_t d = 0;
325 // for (my_size_t i = 0; i < Dims1; ++i)
326 // {
327 // if (i != a)
328 // {
329 // out_dims[d] = _tensor1.derived().getDim(i);
330 // strides1_map[d] = Layout1::stride(i);
331 // strides2_map[d] = 0;
332 // ++d;
333 // }
334 // }
335 // for (my_size_t i = 0; i < Dims2; ++i)
336 // {
337 // if (i != b)
338 // {
339 // out_dims[d] = _tensor2.derived().getDim(i);
340 // strides1_map[d] = 0;
341 // strides2_map[d] = Layout2::stride(i);
342 // ++d;
343 // }
344 // }
345
346 // // ====================================================================
347 // // Validate output dimensions
348 // // ====================================================================
349
350 // FusedTensorND _outp;
351 // for (my_size_t i = 0; i < n_newDims; ++i)
352 // {
353 // if (out_dims[i] != _outp.getDim(i))
354 // MyErrorHandler::error("Output dimensions mismatch");
355 // }
356
357 // // ====================================================================
358 // // Pre-compute output physical strides for direct memory writes
359 // // ====================================================================
360
361 // my_size_t out_strides[n_newDims];
362 // for (my_size_t i = 0; i < n_newDims; ++i)
363 // out_strides[i] = OutputLayout::stride(i);
364
365 // T *out_ptr = _outp.data();
366
367 // // ====================================================================
368 // // Main loop: iterate all output elements
369 // // ====================================================================
370
371 // static constexpr my_size_t total_elements = (1 * ... * Dims);
372
373 // for (my_size_t flat = 0; flat < total_elements; ++flat)
374 // {
375 // // Decompose flat index into output coordinates (row-major)
376 // my_size_t coords[n_newDims];
377 // my_size_t tmp = flat;
378 // for (my_size_t i = n_newDims; i-- > 0;)
379 // {
380 // coords[i] = tmp % out_dims[i];
381 // tmp /= out_dims[i];
382 // }
383
384 // // Compute physical bases via stride maps (dot product of coords × strides)
385 // my_size_t base1 = 0;
386 // my_size_t base2 = 0;
387 // my_size_t out_phys = 0;
388 // for (my_size_t i = 0; i < n_newDims; ++i)
389 // {
390 // base1 += coords[i] * strides1_map[i];
391 // base2 += coords[i] * strides2_map[i];
392 // out_phys += coords[i] * out_strides[i];
393 // }
394
395 // // Dot product along contraction axis → write directly to output
396 // out_ptr[out_phys] = Kern::dot(
397 // _tensor1.derived(), base1, contract_stride1,
398 // _tensor2.derived(), base2, contract_stride2,
399 // K_len);
400 // }
401
402 // return _outp;
403 // }