2
This is a simple interpreter for Sympy expressions that dispatches to
3
classes following the torch._inductor.virtualized calling convention.
4
For directness, the interpreter takes the handler directly rather than
5
consulting the TLS. It does not use most of the methods on the full
6
handler; only those with corresponding Sympy expressions. To see an example
7
of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
11
from typing import Any, Dict, Union
14
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
17
from .functions import (
20
IsNonOverlappingAndDenseIndicator,
34
@functools.lru_cache(None)
65
sympy.ceiling: "ceil",
68
ModularIndexing: "modular_indexing",
69
sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
70
sympy.Piecewise: "piecewise",
71
IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
73
RoundDecimal: "round",
75
for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
76
HANDLERS[getattr(sympy, name)] = name
81
# Handler names that are invoked with exactly two operands at a time by the
# interpreter: when a Sympy node carries more than two arguments, the handler
# is applied as a left fold, i.e. handler(handler(a0, a1), a2), ...
ASSOCIATIVE_OPS = set(("add", "mul", "minimum", "maximum", "and_", "or_"))
85
analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean]
89
if isinstance(expr, BooleanAtom):
91
elif isinstance(expr, sympy.Integer):
93
elif isinstance(expr, sympy.Number):
97
return analysis.constant(expr, dtype)
98
elif isinstance(expr, sympy.Symbol):
102
if isinstance(expr, sympy.Pow) and isinstance(
103
expr.args[1], sympy.core.numbers.Half
105
return analysis.sqrt(sympy_interp(analysis, env, expr.args[0]))
108
args = [sympy_interp(analysis, env, arg) for arg in expr.args]
109
handler_name = handlers()[expr.func]
110
handler = getattr(analysis, handler_name)
111
if handler_name in ASSOCIATIVE_OPS:
113
acc = handler(args[0], args[1])
114
for i in range(2, len(args)):
115
acc = handler(acc, args[i])
118
return handler(*args)