tesseract++ 0.0.1
N-dimensional tensor library for embedded systems
Loading...
Searching...
No Matches
expr_diag.h
Go to the documentation of this file.
#pragma once
#include <cctype>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <typeinfo>
#include <vector>

#if defined(__GNUC__) || defined(__clang__)
#include <cxxabi.h>
#endif
9
10// ===============================
11// AI generated, use with caution
12// ===============================
13
namespace expr_diag
{

    // Per-parse label counters. to_string() resets both to 0; every tensor leaf
    // and every scalar operand met during a parse takes the next value. They are
    // plain (unsynchronized) ints, so concurrent to_string() calls would race.
    inline int tensor_count = 0;
    inline int scalar_count = 0;

    // UTF-8 glyphs indexed by decimal digit value.
    inline constexpr const char *subscript_digits[10] = {"₀", "₁", "₂", "₃", "₄", "₅", "₆", "₇", "₈", "₉"};
    inline constexpr const char *superscript_digits[10] = {"⁰", "¹", "²", "³", "⁴", "⁵", "⁶", "⁷", "⁸", "⁹"};

    /// @brief Replace each ASCII digit of @p text with its glyph from @p glyphs.
    /// @param text   Text to convert (typically a decimal number).
    /// @param glyphs Ten UTF-8 glyphs for the digits 0-9.
    /// @param minus  Glyph emitted in place of '-'.
    /// @return The converted text; any other character is copied unchanged.
    inline std::string map_digits(const std::string &text, const char *const (&glyphs)[10], const char *minus)
    {
        std::string out;
        for (const char c : text)
        {
            if (c >= '0' && c <= '9')
                out += glyphs[c - '0'];
            else if (c == '-')
                out += minus;
            else
                out += c;
        }
        return out;
    }

    /// @brief Greek letter naming the n-th scalar: α, β, ..., ω, then α1, β1, ...
    /// @param n Zero-based scalar index.
    /// @return The letter, or "?" for a negative index (which would otherwise
    ///         read outside the table).
    inline std::string get_greek_letter(int n)
    {
        static constexpr const char *greeks[] = {"α", "β", "γ", "δ", "ε", "ζ", "η", "θ", "ι", "κ", "λ", "μ", "ν", "ξ", "π", "ρ", "σ", "τ", "υ", "φ", "χ", "ψ", "ω"};
        constexpr int num_greeks = static_cast<int>(sizeof(greeks) / sizeof(greeks[0]));
        if (n < 0)
            return "?";
        if (n < num_greeks)
            return greeks[n];
        // Past ω: cycle through the alphabet again and append the cycle number.
        return greeks[n % num_greeks] + std::to_string(n / num_greeks);
    }

    /// @brief Render @p n with Unicode subscript digits (e.g. 12 -> "₁₂", -3 -> "₋₃").
    inline std::string to_subscript(int n)
    {
        return map_digits(std::to_string(n), subscript_digits, "₋");
    }

    /// @brief Map a scalar type name to a single modifier-letter glyph.
    /// Despite the name, the glyphs are Unicode modifier (raised) letters.
    /// Not referenced elsewhere in this header.
    /// @return "ᵈ", "ᶠ", "ⁱ", "ˡ", "ᵘ", or "?" for any other type name.
    inline std::string to_subscript_type(const std::string &type_name)
    {
        if (type_name == "double")
            return "ᵈ";
        else if (type_name == "float")
            return "ᶠ";
        else if (type_name == "int")
            return "ⁱ";
        else if (type_name == "long")
            return "ˡ";
        else if (type_name == "unsigned")
            return "ᵘ";
        else
            return "?";
    }

    /// @brief Render @p n with Unicode superscript digits (used for permutation indices).
    inline std::string to_superscript(int n)
    {
        return map_digits(std::to_string(n), superscript_digits, "⁻");
    }

    /// @brief Advance @p pos past any whitespace.
    inline void skip_ws(const std::string &s, size_t &pos)
    {
        while (pos < s.size() && std::isspace(static_cast<unsigned char>(s[pos])))
            pos++;
    }

    /// @brief Read an identifier ([A-Za-z0-9_]*) starting at @p pos; may return "".
    inline std::string read_ident(const std::string &s, size_t &pos)
    {
        const size_t start = pos;
        // Cast to unsigned char: passing a negative char to isalnum is UB.
        while (pos < s.size() && (std::isalnum(static_cast<unsigned char>(s[pos])) || s[pos] == '_'))
            pos++;
        return s.substr(start, pos - start);
    }

    /// @brief Read a run of decimal digits, then skip any integer-literal suffix
    ///        letters (u/l/U/L, e.g. the "ul" in demangled "3ul").
    /// @return The digits only; "" if @p pos is not at a digit.
    inline std::string read_number(const std::string &s, size_t &pos)
    {
        std::string num;
        while (pos < s.size() && std::isdigit(static_cast<unsigned char>(s[pos])))
            num += s[pos++];
        while (pos < s.size() && (s[pos] == 'u' || s[pos] == 'l' || s[pos] == 'U' || s[pos] == 'L'))
            pos++;
        return num;
    }

    /// @brief Drop leading zeros from a non-empty digit string ("007" -> "7", "00" -> "0").
    inline std::string strip_leading_zeros(const std::string &digits)
    {
        const size_t first = digits.find_first_not_of('0');
        if (first == std::string::npos)
            return digits.empty() ? digits : std::string("0");
        return digits.substr(first);
    }

    /// @brief Advance @p pos to the ',' or '>' that ends the current template
    ///        argument, stepping over nested <...> and (...) groups.
    /// The delimiter itself is not consumed. Stops at end of string if none is found.
    inline void skip_to_arg_end(const std::string &s, size_t &pos)
    {
        int depth = 0;
        for (; pos < s.size(); ++pos)
        {
            const char c = s[pos];
            if (c == '<' || c == '(')
            {
                ++depth;
            }
            else if (c == '>' || c == ')')
            {
                if (depth > 0)
                    --depth;
                else if (c == '>')
                    return;
                // An unmatched ')' at depth 0 is simply skipped.
            }
            else if (c == ',' && depth == 0)
            {
                return;
            }
        }
    }

    /// @brief If the next non-space character is '<', skip through its matching '>'.
    /// Used for nodes the parser does not understand, so that the enclosing node
    /// can still parse its remaining arguments.
    inline void skip_template_args(const std::string &s, size_t &pos)
    {
        skip_ws(s, pos);
        if (pos >= s.size() || s[pos] != '<')
            return;
        int depth = 0;
        while (pos < s.size())
        {
            const char c = s[pos++];
            if (c == '<')
                ++depth;
            else if (c == '>' && --depth == 0)
                return;
        }
    }

    /// @brief Skip whitespace, then consume @p ch if it is next.
    /// @return true if @p ch was consumed.
    inline bool consume(const std::string &s, size_t &pos, char ch)
    {
        skip_ws(s, pos);
        if (pos < s.size() && s[pos] == ch)
        {
            ++pos;
            return true;
        }
        return false;
    }

    /// @brief Discard leftover text in the current argument (e.g. a trailing
    ///        " const&"), then consume the delimiter @p delim if present.
    inline void finish_arg(const std::string &s, size_t &pos, char delim)
    {
        skip_to_arg_end(s, pos);
        consume(s, pos, delim);
    }

    /// @brief Read a type template argument verbatim ("double", "long long",
    ///        "unsigned int", ...) up to the next top-level ',' or '>'.
    /// @return The argument text without surrounding whitespace.
    inline std::string read_type_name(const std::string &s, size_t &pos)
    {
        skip_ws(s, pos);
        const size_t start = pos;
        skip_to_arg_end(s, pos);
        size_t end = pos;
        while (end > start && std::isspace(static_cast<unsigned char>(s[end - 1])))
            --end;
        return s.substr(start, end - start);
    }

    /// @brief Read a node or op name and return only its last unqualified
    ///        component. "ns::BinaryExpr" and "::BinaryExpr" give "BinaryExpr".
    /// Elaborated-type keywords (MSVC's "class X", "struct Y") and a leading
    /// "const" are skipped. Leaves @p pos just after the name.
    inline std::string read_type_head(const std::string &s, size_t &pos)
    {
        for (;;)
        {
            skip_ws(s, pos);
            const std::string word = read_ident(s, pos);
            if (word == "class" || word == "struct" || word == "enum" || word == "const")
                continue;
            if (pos + 1 < s.size() && s[pos] == ':' && s[pos + 1] == ':')
            {
                pos += 2;
                continue;
            }
            return word;
        }
    }

    /// @brief Collect the numeric template arguments of ", n1, n2, ... >" and
    ///        consume the closing '>'.
    /// Each value is returned as a digit string with its literal suffix and any
    /// leading zeros removed. Non-numeric arguments are skipped. The scan always
    /// makes progress, so malformed input cannot hang it.
    inline std::vector<std::string> read_numeric_args(const std::string &s, size_t &pos)
    {
        std::vector<std::string> values;
        while (pos < s.size())
        {
            skip_ws(s, pos);
            if (pos >= s.size())
                break;
            if (s[pos] == '>')
            {
                ++pos;
                break;
            }
            if (s[pos] == ',')
            {
                ++pos;
                skip_ws(s, pos);
                const std::string digits = read_number(s, pos);
                if (!digits.empty())
                    values.push_back(strip_leading_zeros(digits));
            }
            // Skip the rest of this argument (casts, qualifiers, unexpected text).
            skip_to_arg_end(s, pos);
        }
        return values;
    }

    /// @brief Element-type code used in tensor labels (the "d" in "T₀d₃×₃").
    /// @return "d", "f", "i32", "i64", or "?" for unrecognized types.
    inline std::string tensor_type_code(const std::string &type_name)
    {
        if (type_name == "double")
            return "d";
        if (type_name == "float")
            return "f";
        if (type_name == "int" || type_name == "int32_t")
            return "i32";
        if (type_name == "long" || type_name == "int64_t" || type_name == "long long")
            return "i64";
        return "?";
    }

    /// @brief Infix symbol (with surrounding spaces) for an element-wise op name.
    inline std::string op_symbol(const std::string &op)
    {
        if (op == "Add")
            return " + ";
        if (op == "Sub")
            return " − ";
        if (op == "Mul")
            return " · ";
        if (op == "Div")
            return " / ";
        if (op == "Min")
            return " ∧ ";
        if (op == "Max")
            return " ∨ ";
        return " ? ";
    }

    /// @brief Render a fused multiply-add of the given kind as ⟨…⟩.
    /// @return e.g. "⟨a · b + c⟩" for Fma, or "?" for an unknown op.
    inline std::string fma_form(const std::string &op, const std::string &a, const std::string &b, const std::string &c)
    {
        const std::string product = a + " · " + b;
        if (op == "Fma")
            return "⟨" + product + " + " + c + "⟩";
        if (op == "Fms")
            return "⟨" + product + " − " + c + "⟩";
        if (op == "Fnma")
            return "⟨−(" + product + ") + " + c + "⟩";
        if (op == "Fnms")
            return "⟨−(" + product + ") − " + c + "⟩";
        return "?";
    }

    inline std::string parse_expr(const std::string &s, size_t &pos);

    /// @brief Parse "<Type, d0, d1, ...>" following "FusedTensorND".
    /// @return A label such as "T₀d₃×₃": T, then the tensor number, type code
    ///         and dimensions. Increments tensor_count.
    inline std::string parse_tensor(const std::string &s, size_t &pos)
    {
        if (!consume(s, pos, '<'))
            return "T" + to_subscript(tensor_count++) + "?";

        // read_type_name (not read_ident) so multi-word types like "long long"
        // are read whole instead of stalling the dimension scan.
        const std::string type_char = tensor_type_code(read_type_name(s, pos));
        const std::vector<std::string> dims = read_numeric_args(s, pos);

        std::string result = "T" + to_subscript(tensor_count++) + type_char;
        for (size_t i = 0; i < dims.size(); ++i)
        {
            if (i > 0)
                result += "×";
            // Map digit strings directly: no stoi, so huge extents cannot throw.
            result += map_digits(dims[i], subscript_digits, "₋");
        }
        return result;
    }

    /// @brief Parse "<Inner, p0, p1, ...>" following "PermutedViewConstExpr".
    /// @return "innerᵀ" for the 2-D swap (1,0); otherwise "inner⁽p0'p1'…⁾" with
    ///         superscript indices.
    inline std::string parse_permuted_view(const std::string &s, size_t &pos)
    {
        consume(s, pos, '<');
        const std::string inner = parse_expr(s, pos);
        const std::vector<std::string> perm = read_numeric_args(s, pos);

        if (perm.size() == 2 && perm[0] == "1" && perm[1] == "0")
            return inner + "ᵀ";

        std::string result = inner + "⁽";
        for (size_t i = 0; i < perm.size(); ++i)
        {
            if (i > 0)
                result += "'";
            result += map_digits(perm[i], superscript_digits, "⁻");
        }
        result += "⁾";
        return result;
    }

    /// @brief Create the label for the next scalar operand: a type prefix
    ///        (d/f/i/l, or ? if unknown) followed by a Greek letter.
    /// Increments scalar_count.
    inline std::string make_scalar(const std::string &scalar_type)
    {
        std::string type_char = "?";
        if (scalar_type == "double")
            type_char = "d";
        else if (scalar_type == "float")
            type_char = "f";
        else if (scalar_type == "int")
            type_char = "i";
        else if (scalar_type == "long" || scalar_type == "long long")
            type_char = "l";
        return type_char + get_greek_letter(scalar_count++);
    }

    /// @brief Recursively render one expression node of a demangled type name.
    /// Recognized nodes:
    /// - FusedTensorND
    /// - PermutedViewConstExpr
    /// - BinaryExpr<L, R, Op>
    /// - ScalarExprRHS<E, T, Op>
    /// - ScalarExprLHS<E, T, Op>
    /// - FmaExpr<A, B, C, Op>
    /// - ScalarFmaExpr<E, T, C, Op>
    /// Anything else renders as "?". Its template arguments are still consumed,
    /// so the enclosing node keeps parsing correctly.
    /// @param s   Full (demangled) type name.
    /// @param pos In/out cursor; left just past this node on return.
    inline std::string parse_expr(const std::string &s, size_t &pos)
    {
        const std::string name = read_type_head(s, pos);
        skip_ws(s, pos);

        if (name == "FusedTensorND")
            return parse_tensor(s, pos);

        if (name == "PermutedViewConstExpr")
            return parse_permuted_view(s, pos);

        if (name == "BinaryExpr")
        {
            consume(s, pos, '<');
            const std::string lhs = parse_expr(s, pos);
            finish_arg(s, pos, ',');
            const std::string rhs = parse_expr(s, pos);
            finish_arg(s, pos, ',');
            const std::string op = read_type_head(s, pos);
            finish_arg(s, pos, '>');
            return "(" + lhs + op_symbol(op) + rhs + ")";
        }

        if (name == "ScalarExprRHS")
        {
            consume(s, pos, '<');
            const std::string expr = parse_expr(s, pos);
            finish_arg(s, pos, ',');
            const std::string scalar_type = read_type_name(s, pos);
            finish_arg(s, pos, ',');
            const std::string op = read_type_head(s, pos);
            finish_arg(s, pos, '>');
            // Scalar numbered after the inner expression (left-to-right order).
            const std::string scalar = make_scalar(scalar_type);
            return "(" + expr + op_symbol(op) + scalar + ")";
        }

        if (name == "ScalarExprLHS")
        {
            consume(s, pos, '<');
            const std::string expr = parse_expr(s, pos);
            finish_arg(s, pos, ',');
            const std::string scalar_type = read_type_name(s, pos);
            finish_arg(s, pos, ',');
            const std::string op = read_type_head(s, pos);
            finish_arg(s, pos, '>');

            // Negation special case: rendered as "−expr" and no scalar label is
            // used. The scalar's value is not part of the type, so any
            // "scalar − expr" renders this way (presumably negation is built as
            // 0 − expr — confirm).
            if (op == "Sub")
                return "−" + expr;

            const std::string scalar = make_scalar(scalar_type);
            return "(" + scalar + op_symbol(op) + expr + ")";
        }

        if (name == "FmaExpr")
        {
            consume(s, pos, '<');
            const std::string a = parse_expr(s, pos);
            finish_arg(s, pos, ',');
            const std::string b = parse_expr(s, pos);
            finish_arg(s, pos, ',');
            const std::string c = parse_expr(s, pos);
            finish_arg(s, pos, ',');
            const std::string op = read_type_head(s, pos);
            finish_arg(s, pos, '>');
            return fma_form(op, a, b, c);
        }

        if (name == "ScalarFmaExpr")
        {
            consume(s, pos, '<');
            const std::string expr = parse_expr(s, pos);
            finish_arg(s, pos, ',');
            const std::string scalar_type = read_type_name(s, pos);
            finish_arg(s, pos, ',');
            const std::string c = parse_expr(s, pos);
            finish_arg(s, pos, ',');
            const std::string op = read_type_head(s, pos);
            finish_arg(s, pos, '>');
            // The scalar is labelled even when op is unknown (matches left-to-right numbering).
            const std::string scalar = make_scalar(scalar_type);
            return fma_form(op, expr, scalar, c);
        }

        skip_template_args(s, pos);
        return "?";
    }

    /// @brief Render the expression type @p Expr in compact notation.
    /// Resets tensor_count and scalar_count, so labels start at T₀ / α.
    /// On GCC/Clang the typeid name is demangled first. Elsewhere the raw
    /// typeid name is parsed; MSVC-style "class X<...>" names are accepted.
    template <typename Expr>
    std::string to_string()
    {
        tensor_count = 0;
        scalar_count = 0;
        std::string type_name = typeid(Expr).name();

#if defined(__GNUC__) || defined(__clang__)
        int status = 0;
        char *demangled = abi::__cxa_demangle(type_name.c_str(), nullptr, nullptr, &status);
        if (status == 0 && demangled != nullptr)
            type_name = demangled;
        std::free(demangled); // __cxa_demangle mallocs the result; free(nullptr) is a no-op
#endif

        size_t pos = 0;
        return parse_expr(type_name, pos);
    }

    /// @brief Print to_string<Expr>() followed by a newline to std::cout.
    template <typename Expr>
    void print_expr()
    {
        std::cout << to_string<Expr>() << "\n";
    }

    /// @brief Print a key to the notation produced by to_string() to std::cout.
    inline void print_legend()
    {
        std::cout << "═══════════════════════════════════════════════\n";
        std::cout << "  Expression Notation Legend\n";
        std::cout << "═══════════════════════════════════════════════\n";
        std::cout << "  Tensors:  Tₙtᵢ×ⱼ = Tensor #n, type t, dims i×j\n";
        std::cout << "            t: d=double, f=float, i32=int, i64=long\n";
        std::cout << "  Scalars:  tα, tβ, tγ... (type prefix + greek)\n";
        std::cout << "            dα=double, fβ=float, iγ=int, lδ=long\n";
        std::cout << "  Views:    ᵀ = transpose (permutation 1,0)\n";
        std::cout << "            ⁽ⁱ'ʲ'ᵏ⁾ = general permutation\n";
        std::cout << "  Ops:      + − · /   ∧ = min   ∨ = max\n";
        std::cout << "            −x = negation (also any scalar − x)\n";
        std::cout << "  FMA:      ⟨a·b+c⟩ ⟨a·b−c⟩ ⟨−(a·b)+c⟩ ⟨−(a·b)−c⟩\n";
        std::cout << "            ⟨ ⟩ = fused (single instruction)\n";
        std::cout << "═══════════════════════════════════════════════\n";
    }

} // namespace expr_diag
Definition expr_diag.h:15
std::string parse_permuted_view(const std::string &s, size_t &pos)
Definition expr_diag.h:162
std::string to_subscript(int n)
Definition expr_diag.h:34
void print_expr()
Definition expr_diag.h:396
std::string read_ident(const std::string &s, size_t &pos)
Definition expr_diag.h:88
std::string make_scalar(const std::string &scalar_type)
Definition expr_diag.h:212
std::string to_string()
Definition expr_diag.h:375
std::string parse_tensor(const std::string &s, size_t &pos)
Definition expr_diag.h:114
int tensor_count
Definition expr_diag.h:17
std::string parse_expr(const std::string &s, size_t &pos)
Definition expr_diag.h:222
std::string get_greek_letter(int n)
Definition expr_diag.h:21
std::string to_subscript_type(const std::string &type_name)
Definition expr_diag.h:50
std::string read_number(const std::string &s, size_t &pos)
Definition expr_diag.h:97
void print_legend()
Definition expr_diag.h:401
int scalar_count
Definition expr_diag.h:18
std::string to_superscript(int n)
Definition expr_diag.h:67
void skip_ws(const std::string &s, size_t &pos)
Definition expr_diag.h:82