# mypy: allow-untyped-defs
"""
This is a simple interpreter for Sympy expressions that dispatches to
classes following the torch._inductor.virtualized calling convention.
For directness, the interpreter takes the handler object as an argument
rather than consulting thread-local storage (TLS). It does not use most of
the methods of a full handler, only those that correspond to Sympy
expressions. For an example of a full handler, see
torch.utils._sympy.value_ranges.ValueRangeAnalysis.
"""
import functools
import logging
from typing import Any, Dict, Union

import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom

import torch

from .functions import (
    CeilToInt,
    CleanDiv,
    FloatPow,
    FloatTrueDiv,
    FloorDiv,
    FloorToInt,
    Identity,
    IntTrueDiv,
    IsNonOverlappingAndDenseIndicator,
    Max,
    Min,
    Mod,
    ModularIndexing,
    PowByNatural,
    PythonMod,
    RoundDecimal,
    RoundToInt,
    ToFloat,
    TruncToFloat,
    TruncToInt,
    Where,
)

# Logger for this module: _run_sympy_handler reports each dispatched handler
# call at DEBUG level and handler failures at WARNING level.
log = logging.getLogger(name=__name__)

# TODO: Dedupe this with SYMPY_INTERP


@functools.lru_cache(None)
def handlers():
    """Map Sympy expression classes to the analysis method that evaluates them.

    ``_run_sympy_handler`` looks up ``expr.func`` in the returned dict and then
    calls the method of that name on the analysis object. The dict is built on
    the first call and cached by ``lru_cache``, so every caller receives the
    very same object; treat it as read-only.
    """
    # TODO add CeilDiv (it doesn't appear in the index_expr)

    # TODO default to some decompositions if the interpreter doesn't have them
    # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)

    table = {
        # Boolean connectives and comparisons
        sympy.Or: "or_",
        sympy.And: "and_",
        sympy.Eq: "eq",
        sympy.Ne: "ne",
        sympy.Lt: "lt",
        sympy.Gt: "gt",
        sympy.Le: "le",
        sympy.Ge: "ge",
        sympy.Not: "not_",
        # Division and truncation
        IntTrueDiv: "int_truediv",
        FloatTrueDiv: "truediv",
        FloorDiv: "floordiv",
        CleanDiv: "floordiv",  # TODO: hmm?
        TruncToFloat: "trunc",
        Where: "where",
        # Arithmetic
        sympy.Add: "add",
        sympy.Mul: "mul",
        FloatPow: "pow",
        PowByNatural: "pow_by_natural",
        # sympy simplifies x * x into Pow(x, 2), so we need to handle this.
        # Do NOT use builtin Pow for floats
        # TODO: There is a hazard here, if we have float * float it will
        # also get turned into Pow(float, 2) but we don't want this because
        # pow_by_natural is assumed to only be integers. Probably the fix is
        # to add a FloatMul to impede this optimization
        sympy.Pow: "pow_by_natural",
        Mod: "mod",
        PythonMod: "mod",  # TODO: this is wrong
        # TODO: Inductor can generate these, but it's ill-specified which
        # semantics were intended here. Needs to be cleaned up along with
        # FloorDiv in a bigger cleanup
        sympy.Mod: "mod",
        sympy.Abs: "abs",
        sympy.log: "log",
        sympy.exp: "exp",
        # Both Sympy's and our own Min/Max dispatch to the same handlers
        sympy.Min: "minimum",
        sympy.Max: "maximum",
        Min: "minimum",
        Max: "maximum",
        # Indexing, piecewise and miscellaneous helpers
        ModularIndexing: "modular_indexing",
        sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
        sympy.Piecewise: "piecewise",
        Identity: "identity",
        IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
        RoundDecimal: "round_decimal",
    }

    # Trigonometric and hyperbolic functions: the handler method shares the
    # name of the Sympy function.
    unary_math = ("cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan")
    table.update({getattr(sympy, fn_name): fn_name for fn_name in unary_math})
    return table

# Handler names whose (binary) handler is folded left-to-right over all the
# arguments of an n-ary Sympy node by _run_sympy_handler, e.g. Add(a, b, c)
# is interpreted as add(add(a, b), c). Frozen so this shared, module-level
# table cannot be mutated by accident.
ASSOCIATIVE_OPS = frozenset({"minimum", "maximum", "mul", "add", "and_", "or_"})

# These handlers are special because they take an extra dtype argument
# specifying what they should convert to, and we need to appropriately set
# this up when we convert from Sympy. A reasonable default when you
# are translating is to conservatively do int64, and then narrow these
# arguments later when you discover you can narrow the index range. But
# if you already know that 32-bit indexing is OK, you can directly do the
# sympy translation with index_dtype=torch.int32
# Built once at import time rather than on every (recursive) dispatch.
_INDEX_DTYPE_HANDLERS = {
    TruncToInt: "trunc_to_int",
    sympy.floor: "floor_to_int",
    sympy.ceiling: "ceil_to_int",
    FloorToInt: "floor_to_int",
    CeilToInt: "ceil_to_int",
    RoundToInt: "round_to_int",
}


def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
    """Apply the ``analysis`` method that corresponds to ``expr.func``.

    Args:
        analysis: handler object providing methods named as in ``handlers()``
            (plus ``sqrt``, ``to_dtype`` and the index-dtype conversions).
        args: the already-interpreted values of ``expr.args``, in order.
        expr: the Sympy node being evaluated; only its type/``func`` (and, for
            ``Pow``, its exponent) is inspected here.
        index_dtype: dtype appended as the final argument to the
            index-dtype conversion handlers (``trunc_to_int``,
            ``floor_to_int``, ``ceil_to_int``, ``round_to_int``).

    Returns:
        Whatever the selected handler returns.

    Raises:
        KeyError: if ``expr.func`` has neither a ``_torch_handler_name``
            attribute nor an entry in ``handlers()``.
        AttributeError: if ``analysis`` lacks the selected method.
        Any exception raised by the handler itself is logged and re-raised.
    """
    # Special cases
    # Pow(x, 1/2) is evaluated as sqrt rather than as pow_by_natural.
    if isinstance(expr, sympy.Pow) and isinstance(
        expr.args[1], sympy.core.numbers.Half
    ):
        return analysis.sqrt(args[0])
    if isinstance(expr, ToFloat):
        return analysis.to_dtype(args[0], torch.float64)

    if (handler_name := _INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
        return getattr(analysis, handler_name)(*args, index_dtype)

    # A Sympy function class may name its own handler; otherwise use the table.
    if hasattr(expr.func, "_torch_handler_name"):
        handler_name = expr.func._torch_handler_name
    else:
        handler_name = handlers()[expr.func]
    handler = getattr(analysis, handler_name)

    try:
        if handler_name in ASSOCIATIVE_OPS:
            # Handlers are binary; fold them left-to-right over the n-ary
            # node: handler(handler(args[0], args[1]), args[2]) ...
            assert len(args) > 1
            result = functools.reduce(handler, args)
        else:
            result = handler(*args)
    except Exception:
        log.warning("failed while executing %s(%s)", handler_name, args)
        raise
    else:
        log.debug("%s(%s) -> %s", handler_name, args, result)
        return result

def sympy_interp(
    analysis,
    env: Dict[sympy.Symbol, Any],
    expr: Union[sympy.Expr, SympyBoolean],
    *,
    index_dtype=torch.int64,
):
    """Evaluate a Sympy expression with the methods of ``analysis``.

    Leaves are handled directly:

    - Boolean atoms become ``analysis.constant(expr, torch.bool)``.
    - Integers become ``analysis.constant(expr, torch.int64)``.
    - Other numbers become ``analysis.constant(expr, torch.double)``.
    - Symbols are looked up in ``env``.

    Every other node has its arguments interpreted recursively and is then
    dispatched to the matching handler via ``_run_sympy_handler``.

    Args:
        analysis: handler object following the
            torch._inductor.virtualized calling convention.
        env: value to use for each free Symbol of ``expr``.
        expr: the Sympy expression (or boolean) to interpret.
        index_dtype: dtype handed to the index-dtype conversion handlers
            (``floor_to_int``, ``ceil_to_int``, ``trunc_to_int``,
            ``round_to_int``) at every depth of ``expr``. Note that integer
            *constants* are always emitted as torch.int64 regardless.

    Raises:
        KeyError: if ``expr`` contains a Symbol that is missing from ``env``.
    """
    # Handle base cases
    dtype = None
    if isinstance(expr, BooleanAtom):
        dtype = torch.bool
    elif isinstance(expr, sympy.Integer):
        dtype = torch.int64
    elif isinstance(expr, sympy.Number):
        dtype = torch.double

    if dtype is not None:
        return analysis.constant(expr, dtype)
    elif isinstance(expr, sympy.Symbol):
        return env[expr]

    # Recursive case. index_dtype must be forwarded to the children too;
    # otherwise e.g. a FloorToInt nested inside an Add would silently fall
    # back to the int64 default instead of the caller's requested dtype.
    interpreted_args = [
        sympy_interp(analysis, env, arg, index_dtype=index_dtype)  # type: ignore[arg-type]
        for arg in expr.args
    ]
    return _run_sympy_handler(
        analysis,
        interpreted_args,
        expr,
        index_dtype=index_dtype,
    )  # type: ignore[arg-type]