1
"""This file implements the IndexPropagation ops handler, which wraps an
underlying handler to add a limited form of constant propagation, as well as
propagation of sympy expressions downstream of ops.index_expr calls.

For example, say we have the IR:

    tmp0 = ops.index_expr(x, torch.int32)
    tmp1 = ops.constant(2, torch.int32)
    tmp2 = ops.mul(tmp0, tmp1)
    tmp3 = ops.indirect_indexing(tmp2, x_size)
    tmp4 = ops.load("buf0", tmp3)

The underlying handler would just see:

    ops.load("buf0", x * 2)

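This is limited by the set of operators SymPyOps can express and the sympy
expression printers can handle. Operations that SymPyOps does not define,
that it returns NotImplemented for, or whose result is neither a numeric
constant nor integer-valued fall back to the underlying handler. minimum and
maximum are propagated as sympy.Min and sympy.Max.
"""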
23
from dataclasses import dataclass
24
from typing import Any, Callable, Dict, Literal, Optional, overload, Tuple, Union
28
from typing_extensions import TypeAlias
31
from torch._prims_common import is_boolean_dtype, is_integer_dtype
32
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
37
"""A SymPy expression with associated type"""
44
"""An ops handler where all IR values are SymPy expressions.

When a value cannot be represented as a SymPy expression, the method is
either not defined or returns NotImplemented.
"""
52
def identity(value: Any) -> Any:
    """Return ``value`` unchanged (the identity op)."""
    return value
56
def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr:
    """Build a TypedExpr holding the literal ``value`` for ``dtype``.

    Boolean dtypes store ``bool(value)`` as a ``sympy.Integer`` (0 or 1),
    integer dtypes store ``int(value)`` as a ``sympy.Integer`` and every
    other dtype stores ``float(value)`` as a ``sympy.Float``.
    """
    if is_boolean_dtype(dtype):
        # Normalise through bool() first so any truthy value becomes 1.
        expr = sympy.Integer(bool(value))
    elif is_integer_dtype(dtype):
        expr = sympy.Integer(int(value))
    else:
        expr = sympy.Float(float(value))
    return TypedExpr(expr, dtype)
66
def index_expr(value: sympy.Expr, dtype: torch.dtype) -> Union[int, TypedExpr]:
    """Pair an indexing expression with its dtype.

    A plain Python ``int`` is first converted to ``sympy.Integer``; any other
    value is stored unchanged.
    """
    expr = sympy.Integer(value) if isinstance(value, int) else value
    return TypedExpr(expr, dtype)
73
# NOTE(review): the line that opens this signature (`def <name>(`) is not
# visible in this excerpt; the (value, dtype, src_dtype) parameters suggest a
# dtype-conversion op (presumably `to_dtype`) -- confirm against the full file.
#
# Re-types the TypedExpr `value` as `dtype`:
#   * numeric constants (sympy.Integer / sympy.Float) are rebuilt through
#     SymPyOps.constant, which re-casts the literal for the target dtype;
#   * integer -> integer conversions keep the expression and only change its
#     dtype, via SymPyOps.index_expr.
# `src_dtype` is accepted but not read by the visible code.
value: Any, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None
74
) -> Union[int, TypedExpr]:
75
if isinstance(value.expr, (sympy.Integer, sympy.Float)):
76
return SymPyOps.constant(value.expr, dtype)
77
elif is_integer_dtype(dtype) and is_integer_dtype(value.dtype):
78
return SymPyOps.index_expr(value.expr, dtype)
# NOTE(review): no branch is visible for the remaining cases (e.g. a symbolic
# integer expression converted to a floating dtype). As shown, control falls
# off the end and returns None rather than NotImplemented, while the
# caller's validity check in propagate_sympy only tests for NotImplemented --
# confirm whether an `else` branch returning NotImplemented is missing here.
84
def square(x: TypedExpr) -> TypedExpr:
    """Return ``x * x`` as a TypedExpr that keeps the dtype of ``x``."""
    squared = x.expr * x.expr
    return TypedExpr(squared, x.dtype)
88
def add(x: TypedExpr, y: TypedExpr) -> TypedExpr:
    """Symbolic ``x + y``, typed with the promotion of both operand dtypes."""
    promoted = torch.promote_types(x.dtype, y.dtype)
    total = x.expr + y.expr
    return TypedExpr(total, promoted)
93
def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr:
    """Symbolic ``x - y``, typed with the promotion of both operand dtypes."""
    promoted = torch.promote_types(x.dtype, y.dtype)
    difference = x.expr - y.expr
    return TypedExpr(difference, promoted)
98
def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr:
    """Symbolic ``x * y``, typed with the promotion of both operand dtypes."""
    promoted = torch.promote_types(x.dtype, y.dtype)
    product = x.expr * y.expr
    return TypedExpr(product, promoted)
103
def neg(x: TypedExpr) -> TypedExpr:
    """Symbolic negation; the dtype of ``x`` is preserved."""
    negated = -x.expr
    return TypedExpr(negated, x.dtype)
107
def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr:
    """Symbolic floor division, only for integer result dtypes.

    Returns NotImplemented when the promoted dtype is not an integer type,
    which makes the op fall back to the wrapped handler.
    """
    promoted = torch.promote_types(x.dtype, y.dtype)
    if is_integer_dtype(promoted):
        return TypedExpr(FloorDiv(x.expr, y.expr), promoted)
    return NotImplemented
115
def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
    """Symbolic integer modulus expressed via ModularIndexing.

    Returns NotImplemented (so the op falls back to the wrapped handler) when
    the promoted dtype is not an integer type, or when the sign of ``x`` is
    not known to agree with ``y`` being positive. This mirrors the guard in
    ``remainder``, which uses the same ModularIndexing rewrite: for operands
    of opposite sign, floor-based and truncation-based modulus disagree
    (e.g. -7 mod 3 is 2 vs -1), so declining the rewrite is the conservative
    choice.
    """
    result_type = torch.promote_types(x.dtype, y.dtype)
    if not is_integer_dtype(result_type):
        return NotImplemented

    # Only rewrite when x >= 0 and y > 0 (or x < 0 and y <= 0); sympy's
    # three-valued assumptions make unknown signs compare unequal here.
    if (
        x.expr.is_nonnegative is not None
        and x.expr.is_nonnegative == y.expr.is_positive
    ):
        result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
        return TypedExpr(result_expr, result_type)
    return NotImplemented
124
def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
    """Symbolic integer remainder expressed via ModularIndexing.

    The rewrite is only performed when the promoted dtype is an integer type
    and the sign of ``x`` is known and agrees with ``y`` being positive. In
    every other case NotImplemented is returned, so the op falls back to the
    wrapped handler.
    """
    result_type = torch.promote_types(x.dtype, y.dtype)
    if not is_integer_dtype(result_type):
        return NotImplemented
    # For these sign combinations floor-based (Python) and truncation-based
    # (C) remainder agree, which is presumably why the rewrite is limited
    # to them.
    if (
        x.expr.is_nonnegative is not None
        and x.expr.is_nonnegative == y.expr.is_positive
    ):
        result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
        return TypedExpr(result_expr, result_type)
    return NotImplemented
139
def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
    """Symbolic ``min(x, y)`` as ``sympy.Min``, typed with the promoted dtype."""
    promoted = torch.promote_types(x.dtype, y.dtype)
    smaller = sympy.Min(x.expr, y.expr)
    return TypedExpr(smaller, promoted)
144
def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
    """Symbolic ``max(x, y)`` as ``sympy.Max``, typed with the promoted dtype."""
    promoted = torch.promote_types(x.dtype, y.dtype)
    larger = sympy.Max(x.expr, y.expr)
    return TypedExpr(larger, promoted)
152
is_symbolic: bool = False
155
def new_symbolic(expr: TypedExpr) -> "IndexPropVar":
    """Create an IndexPropVar that carries ``expr`` and is flagged symbolic."""
    symbolic_var = IndexPropVar(expr, is_symbolic=True)
    return symbolic_var
158
def __post_init__(self):
    """Check the invariant that a symbolic variable always wraps a TypedExpr.

    Non-symbolic variables may hold any value and are not inspected.
    """
    if self.is_symbolic:
        assert isinstance(
            self.value, TypedExpr
        ), "Symbolic IndexPropVar must contain a TypedExpr"
164
# What an IndexPropagation op produces: a single IndexPropVar, or a (possibly
# nested) tuple of them when the wrapped op returned a list/tuple (`wrap`
# recurses over lists and tuples and always rebuilds them as tuples).
IndexPropResult: TypeAlias = Union[IndexPropVar, Tuple["IndexPropResult", ...]]
167
class IndexPropagation:
168
"""Ops wrapper that tries to propagate constant and index_expr values through the computation.

This aims to maximize the compile-time simplification possible, and to
convert indirect indexing from arange into normal static indexing. Ops that
SymPyOps implements are evaluated as SymPy expressions when all of their
IndexPropVar operands are symbolic; otherwise the op is forwarded to the
wrapped handler.
"""
175
def __init__(self, inner: Any):
    """Store ``inner``, the ops handler that receives every op (and every
    materialized constant / index expression) that is not kept symbolic."""
    self._inner = inner
178
def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any:
    """Lower a SymPy expression back into an op of the wrapped handler.

    Integer constants become ``constant(int(expr))``, other numeric constants
    become ``constant(float(expr))`` and anything else becomes
    ``index_expr(expr)``.
    """
    if isinstance(expr, sympy.Integer):
        return self._inner.constant(int(expr), dtype)
    elif isinstance(expr, sympy.Number):
        return self._inner.constant(float(expr), dtype)
    return self._inner.index_expr(expr, dtype)
186
def unwrap(self, a: Union[Any, IndexPropVar]) -> Any:
    """Convert ``a`` into a value the wrapped handler understands.

    Lists/tuples are unwrapped element-wise (always returning a tuple),
    values that are not IndexPropVars pass through unchanged, symbolic
    variables are materialized into ops of the wrapped handler, and
    non-symbolic variables yield the value they wrap.
    """
    if isinstance(a, (list, tuple)):
        return tuple(self.unwrap(v) for v in a)

    if not isinstance(a, IndexPropVar):
        return a

    # Only symbolic variables are guaranteed to hold a TypedExpr (see
    # IndexPropVar.__post_init__); variables built by `wrap` hold whatever
    # the wrapped handler returned and must be handed back untouched.
    if a.is_symbolic:
        return self.materialize_expr(a.value.expr, a.value.dtype)
    return a.value
199
def wrap(self, a) -> IndexPropResult:
    """Box a value returned by the wrapped handler as non-symbolic vars.

    Lists and tuples are wrapped element-wise and rebuilt as a tuple; any
    other value becomes a single IndexPropVar.
    """
    if not isinstance(a, (list, tuple)):
        return IndexPropVar(a)
    return tuple(self.wrap(item) for item in a)
207
@overload
def fallback(
    self,
    name: Literal["indirect_indexing"],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> IndexPropVar:
    ...


@overload
def fallback(
    self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> IndexPropResult:
    ...


def fallback(
    self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> IndexPropResult:
    """Run op ``name`` on the wrapped handler instead of propagating it.

    Arguments are converted with ``unwrap`` and the handler's result is
    re-boxed with ``wrap``. The first overload records that
    ``indirect_indexing`` is expected to yield a single IndexPropVar, whose
    ``.value`` the ``indirect_indexing`` method reads.
    """
    new_args = [self.unwrap(a) for a in args]
    new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()}
    return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))
228
def propagate_sympy(
    self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> IndexPropResult:
    """Evaluate op ``name`` symbolically through SymPyOps.

    ``__getattr__`` only calls this when every IndexPropVar argument is
    symbolic. If SymPyOps declines (returns anything that is not a
    TypedExpr, e.g. NotImplemented) or the result is neither a numeric
    constant nor integer-valued, the op is forwarded to the wrapped handler.
    """

    def unwrap(a: Union[Any, IndexPropVar]) -> Any:
        # Symbolic vars carry the TypedExpr that SymPyOps consumes; other
        # arguments (e.g. dtypes, raw constants) are passed through as-is.
        if not isinstance(a, IndexPropVar):
            return a
        return a.value

    new_args = [unwrap(a) for a in args]
    new_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
    new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs)
    # isinstance (rather than `is not NotImplemented`) also rejects None or
    # raw passthrough values, which have no `.expr` to inspect.
    is_valid_expr = isinstance(new_expr, TypedExpr) and (
        # Numeric constants (floats included) may propagate; other
        # expressions must be integer-valued.
        isinstance(new_expr.expr, sympy.Number)
        or new_expr.expr.is_integer
    )
    if not is_valid_expr:
        return self.fallback(name, args, kwargs)
    return IndexPropVar.new_symbolic(new_expr)
249
def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
    """Return a callable implementing op ``name``.

    The op is evaluated symbolically when SymPyOps defines it and every
    IndexPropVar argument is symbolic; otherwise it is forwarded to the
    wrapped handler via ``fallback``.
    """

    def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
        if not hasattr(SymPyOps, name):
            return self.fallback(name, args, kwargs)

        var_arguments = [
            a
            for a in itertools.chain(args, kwargs.values())
            if isinstance(a, IndexPropVar)
        ]
        # A single non-symbolic operand makes the op's value opaque to
        # SymPy, so the whole op must go through the wrapped handler.
        if not all(v.is_symbolic for v in var_arguments):
            return self.fallback(name, args, kwargs)

        return self.propagate_sympy(name, args, kwargs)

    return inner
266
def indirect_indexing(
    self, index: Union[Any, IndexPropVar], size: Any, check: bool = True
) -> Any:
    """Turn ``index`` into an indexing expression into a dimension of ``size``.

    A symbolic index is returned as a SymPy expression (direct indexing),
    with negative values wrapped by adding ``size``. Any other index is
    forwarded to the wrapped handler's ``indirect_indexing`` and its boxed
    result is unwrapped via ``.value``.
    """
    if isinstance(index, IndexPropVar) and index.is_symbolic:
        index = index.value.expr
        # Where(index >= 0, 0, size) contributes 0 for non-negative indices
        # and `size` for negative ones, giving Python-style wrap-around.
        # NOTE(review): `check` is not consulted on this path, so no bounds
        # check is requested for direct indexing -- confirm this is intended.
        return index + Where(index >= 0, 0, size)
    return self.fallback("indirect_indexing", (index, size, check), {}).value