6#if defined(__GNUC__) || defined(__clang__)
23 const char *greeks[] = {
"α",
"β",
"γ",
"δ",
"ε",
"ζ",
"η",
"θ",
"ι",
"κ",
"λ",
"μ",
"ν",
"ξ",
"π",
"ρ",
"σ",
"τ",
"υ",
"φ",
"χ",
"ψ",
"ω"};
24 constexpr int num_greeks = 23;
30 return greeks[n % num_greeks] + std::to_string(n / num_greeks);
36 const char *subscripts[] = {
"₀",
"₁",
"₂",
"₃",
"₄",
"₅",
"₆",
"₇",
"₈",
"₉"};
43 result = subscripts[n % 10] + result;
52 if (type_name ==
"double")
54 else if (type_name ==
"float")
56 else if (type_name ==
"int")
58 else if (type_name ==
"long")
60 else if (type_name ==
"unsigned")
69 const char *superscripts[] = {
"⁰",
"¹",
"²",
"³",
"⁴",
"⁵",
"⁶",
"⁷",
"⁸",
"⁹"};
71 return superscripts[0];
76 result = superscripts[n % 10] + result;
82 inline void skip_ws(
const std::string &s,
size_t &pos)
84 while (pos < s.size() && (s[pos] ==
' ' || s[pos] ==
'\n'))
88 inline std::string
read_ident(
const std::string &s,
size_t &pos)
91 while (pos < s.size() && (isalnum(s[pos]) || s[pos] ==
'_'))
93 return s.substr(start, pos - start);
97 inline std::string
read_number(
const std::string &s,
size_t &pos)
100 while (pos < s.size() && isdigit(s[pos]))
105 while (pos < s.size() && (s[pos] ==
'u' || s[pos] ==
'l' || s[pos] ==
'U' || s[pos] ==
'L'))
112 inline std::string
parse_expr(
const std::string &s,
size_t &pos);
122 std::string type_char = (type_name ==
"double") ?
"d"
123 : (type_name ==
"float") ?
"f"
124 : (type_name ==
"int" || type_name ==
"int32_t") ?
"i32"
125 : (type_name ==
"long" || type_name ==
"int64_t" || type_name ==
"long long") ?
"i64"
129 std::vector<std::string> dims;
130 while (pos < s.size())
152 for (
size_t i = 0; i < dims.size(); ++i)
172 std::vector<int> perm;
173 while (pos < s.size())
188 perm.push_back(std::stoi(idx));
194 if (perm.size() == 2 && perm[0] == 1 && perm[1] == 0)
200 std::string result = inner +
"⁽";
201 for (
size_t i = 0; i < perm.size(); ++i)
214 std::string type_char = (scalar_type ==
"double") ?
"d"
215 : (scalar_type ==
"float") ?
"f"
216 : (scalar_type ==
"int") ?
"i"
217 : (scalar_type ==
"long") ?
"l"
222 inline std::string
parse_expr(
const std::string &s,
size_t &pos)
228 if (name ==
"FusedTensorND")
232 else if (name ==
"PermutedViewConstExpr")
236 else if (name ==
"BinaryExpr")
250 std::string sym = (op ==
"Add") ?
" + "
251 : (op ==
"Sub") ?
" − "
252 : (op ==
"Mul") ?
" · "
253 : (op ==
"Div") ?
" / "
254 : (op ==
"Min") ?
" ∧ "
255 : (op ==
"Max") ?
" ∨ "
257 return "(" + lhs + sym + rhs +
")";
259 else if (name ==
"ScalarExprRHS")
275 std::string sym = (op ==
"Add") ?
" + "
276 : (op ==
"Sub") ?
" − "
277 : (op ==
"Mul") ?
" · "
278 : (op ==
"Div") ?
" / "
279 : (op ==
"Min") ?
" ∧ "
280 : (op ==
"Max") ?
" ∨ "
282 return "(" + expr + sym + scalar +
")";
284 else if (name ==
"ScalarExprLHS")
306 std::string sym = (op ==
"Add") ?
" + "
307 : (op ==
"Sub") ?
" − "
308 : (op ==
"Mul") ?
" · "
309 : (op ==
"Div") ?
" / "
311 return "(" + scalar + sym + expr +
")";
313 else if (name ==
"FmaExpr")
331 return "⟨" + a +
" · " + b +
" + " + c +
"⟩";
332 else if (op ==
"Fms")
333 return "⟨" + a +
" · " + b +
" − " + c +
"⟩";
334 else if (op ==
"Fnma")
335 return "⟨−(" + a +
" · " + b +
") + " + c +
"⟩";
336 else if (op ==
"Fnms")
337 return "⟨−(" + a +
" · " + b +
") − " + c +
"⟩";
340 else if (name ==
"ScalarFmaExpr")
361 return "⟨" + expr +
" · " + scalar +
" + " + c +
"⟩";
362 else if (op ==
"Fms")
363 return "⟨" + expr +
" · " + scalar +
" − " + c +
"⟩";
364 else if (op ==
"Fnma")
365 return "⟨−(" + expr +
" · " + scalar +
") + " + c +
"⟩";
366 else if (op ==
"Fnms")
367 return "⟨−(" + expr +
" · " + scalar +
") − " + c +
"⟩";
374 template <
typename Expr>
379 std::string type_name =
typeid(Expr).name();
381#if defined(__GNUC__) || defined(__clang__)
383 char *demangled = abi::__cxa_demangle(type_name.c_str(),
nullptr,
nullptr, &status);
386 type_name = demangled;
395 template <
typename Expr>
398 std::cout << to_string<Expr>() <<
"\n";
403 std::cout <<
"═══════════════════════════════════════════════\n";
404 std::cout <<
" Expression Notation Legend\n";
405 std::cout <<
"═══════════════════════════════════════════════\n";
406 std::cout <<
" Tensors: Tₙtᵢ×ⱼ = Tensor #n, type t, dims i×j\n";
407 std::cout <<
" Types: d=double, f=float, i=int, l=long\n";
408 std::cout <<
" Scalars: tα, tβ, tγ... (type prefix + greek)\n";
409 std::cout <<
" dα=double, fβ=float, iγ=int, lδ=long\n";
410 std::cout <<
" Views: ᵀ = transpose (permutation 1,0)\n";
411 std::cout <<
" ⁽ⁱ'ʲ'ᵏ⁾ = general permutation\n";
412 std::cout <<
" Ops: + − · /\n";
413 std::cout <<
" FMA: ⟨a·b+c⟩ ⟨a·b−c⟩ ⟨−(a·b)+c⟩ ⟨−(a·b)−c⟩\n";
414 std::cout <<
" ⟨ ⟩ = fused (single instruction)\n";
415 std::cout <<
"═══════════════════════════════════════════════\n";
Definition expr_diag.h:15
std::string parse_permuted_view(const std::string &s, size_t &pos)
Definition expr_diag.h:162
std::string to_subscript(int n)
Definition expr_diag.h:34
void print_expr()
Definition expr_diag.h:396
std::string read_ident(const std::string &s, size_t &pos)
Definition expr_diag.h:88
std::string make_scalar(const std::string &scalar_type)
Definition expr_diag.h:212
std::string to_string()
Definition expr_diag.h:375
std::string parse_tensor(const std::string &s, size_t &pos)
Definition expr_diag.h:114
int tensor_count
Definition expr_diag.h:17
std::string parse_expr(const std::string &s, size_t &pos)
Definition expr_diag.h:222
std::string get_greek_letter(int n)
Definition expr_diag.h:21
std::string to_subscript_type(const std::string &type_name)
Definition expr_diag.h:50
std::string read_number(const std::string &s, size_t &pos)
Definition expr_diag.h:97
void print_legend()
Definition expr_diag.h:401
int scalar_count
Definition expr_diag.h:18
std::string to_superscript(int n)
Definition expr_diag.h:67
void skip_ws(const std::string &s, size_t &pos)
Definition expr_diag.h:82