pytorch

Форк
0
/
index_propagation.py 
277 строк · 9.6 Кб
1
"""This file implements the IndexPropagation ops handler, which wraps an
underlying handler to add a limited form of constant propagation, as well as
propagation of sympy expressions downstream of ops.index_expr calls.

For example, say we have the IR:

   tmp0 = ops.index_expr(x, torch.int32)
   tmp1 = ops.constant(2, torch.int32)
   tmp2 = ops.mul(tmp0, tmp1)
   tmp3 = ops.indirect_indexing(tmp2, x_size)
   tmp4 = ops.load("buf0", tmp3)

The underlying handler would just see:

   ops.load("buf0", x * 2)

This is limited by the set of operators handled in the sympy expression
printers, and by the ops that SymPyOps knows how to translate (for example,
minimum and maximum map to sympy.Min and sympy.Max). Any op without a SymPy
translation is forwarded to the underlying handler instead.

"""
22
import itertools
23
from dataclasses import dataclass
24
from typing import Any, Callable, Dict, Literal, Optional, overload, Tuple, Union
25

26
import sympy
27

28
from typing_extensions import TypeAlias
29

30
import torch
31
from torch._prims_common import is_boolean_dtype, is_integer_dtype
32
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
33

34

35
@dataclass()
class TypedExpr:
    """Pair a SymPy expression with the torch dtype of the value it describes."""

    expr: sympy.Expr  # symbolic expression or SymPy literal
    dtype: torch.dtype  # dtype the corresponding IR value carries
41

42

43
def _c_and_python_mod_agree(lhs: Any, rhs: Any) -> bool:
    """Return True if ``lhs % rhs`` provably means the same thing in C and Python.

    C/C++ integer ``%`` truncates (the result takes the sign of the dividend),
    while Python's ``%`` floors (the result takes the sign of the divisor). The
    two agree when a non-negative dividend meets a positive divisor, or a
    negative dividend meets a negative divisor. Both cases also prove the
    divisor is non-zero. Unknown signs (SymPy's ``None``) count as "no".
    """
    # sympify so plain Python ints (which lack the is_* assumptions) are accepted
    lhs = sympy.sympify(lhs)
    rhs = sympy.sympify(rhs)
    return bool(
        (lhs.is_nonnegative and rhs.is_positive)
        or (lhs.is_negative and rhs.is_negative)
    )


class SymPyOps:
    """An ops handler where all IR values are SymPy expressions

    When a value cannot be represented as a SymPy expression, the method is
    either not defined, or returns NotImplemented. IndexPropagation treats
    NotImplemented as a request to fall back to the wrapped handler.

    """

    @staticmethod
    def identity(value: Any) -> Any:
        return value

    @staticmethod
    def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr:
        """Build a literal: bool/integer dtypes give sympy.Integer, others sympy.Float."""
        if is_boolean_dtype(dtype):
            expr = sympy.Integer(bool(value))
        elif is_integer_dtype(dtype):
            expr = sympy.Integer(int(value))
        else:
            expr = sympy.Float(float(value))
        return TypedExpr(expr, dtype)

    @staticmethod
    def index_expr(value: sympy.Expr, dtype: torch.dtype) -> Union[int, TypedExpr]:
        """Wrap an index expression; plain ints are converted to sympy.Integer."""
        if isinstance(value, int):
            value = sympy.Integer(value)
        return TypedExpr(value, dtype)

    @staticmethod
    def to_dtype(
        value: Any, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None
    ) -> Union[int, TypedExpr]:
        """Cast ``value`` to ``dtype``.

        Literals are rebuilt as constants of the new dtype. Integer-to-integer
        casts keep the expression unchanged. Everything else returns
        NotImplemented.
        """
        if isinstance(value.expr, (sympy.Integer, sympy.Float)):
            return SymPyOps.constant(value.expr, dtype)
        elif is_integer_dtype(dtype) and is_integer_dtype(value.dtype):
            return SymPyOps.index_expr(value.expr, dtype)
        else:
            # TODO: Inductor doesn't handle floating point in sympy expressions well at the moment
            return NotImplemented

    @staticmethod
    def square(x: TypedExpr) -> TypedExpr:
        return TypedExpr(x.expr * x.expr, x.dtype)

    @staticmethod
    def add(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr + y.expr, result_type)

    @staticmethod
    def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr - y.expr, result_type)

    @staticmethod
    def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr * y.expr, result_type)

    @staticmethod
    def neg(x: TypedExpr) -> TypedExpr:
        return TypedExpr(-x.expr, x.dtype)

    @staticmethod
    def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        """Integer floor division via FloorDiv; NotImplemented for non-integer dtypes."""
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented

        return TypedExpr(FloorDiv(x.expr, y.expr), result_type)

    @staticmethod
    def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        """Integer modulus as ModularIndexing, only when the rewrite is sign-safe.

        Returns NotImplemented for non-integer dtypes, or when the result could
        depend on floored vs. truncated ``%`` or the divisor might be zero.
        """
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented
        # Without this guard a negative constant dividend could be folded with
        # the "wrong" sign convention, and a zero divisor could divide by zero.
        if not _c_and_python_mod_agree(x.expr, y.expr):
            return NotImplemented

        result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
        return TypedExpr(result_expr, result_type)

    @staticmethod
    def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        """Python-style remainder as ModularIndexing, only when the rewrite is sign-safe."""
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented
        # In these cases, remainder in Python == remainder in C++ and the divisor
        # is provably non-zero, so this transformation is sound
        if _c_and_python_mod_agree(x.expr, y.expr):
            result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
            return TypedExpr(result_expr, result_type)
        return NotImplemented

    @staticmethod
    def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(sympy.Min(x.expr, y.expr), result_type)

    @staticmethod
    def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(sympy.Max(x.expr, y.expr), result_type)
147

148

149
@dataclass
class IndexPropVar:
    """A value flowing through IndexPropagation.

    ``value`` holds an IR value produced by the wrapped handler, unless
    ``is_symbolic`` is set, in which case it holds a TypedExpr.
    """

    value: Any  # inner-handler IR value, or a TypedExpr when is_symbolic
    is_symbolic: bool = False

    @staticmethod
    def new_symbolic(expr: TypedExpr) -> "IndexPropVar":
        """Wrap ``expr`` as a symbolic IndexPropVar."""
        return IndexPropVar(value=expr, is_symbolic=True)

    def __post_init__(self):
        # Only symbolic variables are constrained: they must carry a TypedExpr.
        if self.is_symbolic:
            assert isinstance(
                self.value, TypedExpr
            ), "Symbolic IndexPropVar must contain a TypedExpr"
162

163

164
# What IndexPropagation ops return: a single IndexPropVar, or a (possibly
# nested) tuple of results when the wrapped handler returned a list/tuple.
IndexPropResult: TypeAlias = Union[IndexPropVar, Tuple["IndexPropResult", ...]]
165

166

167
class IndexPropagation:
    """Ops wrapper that keeps constants and index_expr values in SymPy form.

    If SymPyOps implements an op and every IndexPropVar argument is already
    symbolic, the result stays a SymPy expression instead of being emitted by
    the wrapped handler. This maximizes compile-time simplification and lets
    indirect indexing built from arange become ordinary static indexing.
    """

    def __init__(self, inner: Any):
        # Handler that receives every op we cannot keep in SymPy form.
        self._inner = inner

    def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any:
        """Emit ``expr`` through the wrapped handler.

        Integer literals become integer constants and other numeric literals
        become float constants. Anything symbolic becomes an ``index_expr``.
        """
        if isinstance(expr, sympy.Integer):
            return self._inner.constant(int(expr), dtype)
        if not expr.is_number:
            return self._inner.index_expr(expr, dtype)
        return self._inner.constant(float(expr), dtype)

    def unwrap(self, a: Union[Any, IndexPropVar]) -> Any:
        """Convert ``a`` into something the wrapped handler understands.

        Lists and tuples are converted element-wise and always come back as a
        tuple. Symbolic variables are materialized. Values that are not
        IndexPropVar pass through untouched.
        """
        if isinstance(a, (list, tuple)):
            return tuple(map(self.unwrap, a))
        if not isinstance(a, IndexPropVar):
            return a
        if not a.is_symbolic:
            return a.value
        # Prefer rebuilding the value from its SymPy representation
        typed = a.value
        return self.materialize_expr(typed.expr, typed.dtype)

    def wrap(self, a) -> IndexPropResult:
        """Wrap a wrapped-handler result (or a nested list/tuple of them)."""
        if not isinstance(a, (list, tuple)):
            return IndexPropVar(a)
        return tuple(map(self.wrap, a))

    @overload
    def fallback(
        self,
        name: Literal["indirect_indexing"],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> IndexPropVar:
        ...

    @overload
    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        ...

    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        """Forward op ``name`` to the wrapped handler and wrap what it returns."""
        inner_args = [self.unwrap(arg) for arg in args]
        inner_kwargs = {key: self.unwrap(val) for key, val in kwargs.items()}
        op = getattr(self._inner, name)
        return self.wrap(op(*inner_args, **inner_kwargs))

    def propagate_sympy(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        """Evaluate op ``name`` with SymPyOps, falling back if the result is unusable."""

        def as_typed_expr(arg: Union[Any, IndexPropVar]) -> Any:
            # Strip the IndexPropVar wrapper; other arguments (e.g. dtypes) pass through
            return arg.value if isinstance(arg, IndexPropVar) else arg

        sympy_args = [as_typed_expr(arg) for arg in args]
        sympy_kwargs = {key: as_typed_expr(val) for key, val in kwargs.items()}
        result = getattr(SymPyOps, name)(*sympy_args, **sympy_kwargs)
        if result is NotImplemented:
            return self.fallback(name, args, kwargs)
        # Inductor doesn't expect floating point in sympy expressions, but
        # floating point constants are still allowed to propagate
        if isinstance(result.expr, sympy.Number) or result.expr.is_integer:
            return IndexPropVar.new_symbolic(result)
        return self.fallback(name, args, kwargs)

    def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
        def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
            # Stay symbolic only if SymPyOps knows the op and every IndexPropVar
            # argument already carries a SymPy expression.
            if hasattr(SymPyOps, name) and all(
                arg.is_symbolic
                for arg in itertools.chain(args, kwargs.values())
                if isinstance(arg, IndexPropVar)
            ):
                return self.propagate_sympy(name, args, kwargs)
            return self.fallback(name, args, kwargs)

        return inner

    def indirect_indexing(
        self, index: Union[Any, IndexPropVar], size: Any, check: bool = True
    ) -> Any:
        """Turn a symbolic ``index`` into a direct SymPy index, else defer to the inner handler.

        The return value is not wrapped in an IndexPropVar. ``check`` is only
        forwarded on the fallback path.
        """
        if not (isinstance(index, IndexPropVar) and index.is_symbolic):
            return self.fallback("indirect_indexing", (index, size, check), {}).value
        # Direct indexing must still wrap negative indices around by ``size``.
        # nb. index + Where(...) is used rather than Where(idx >= 0, idx, idx + sz)
        #     because SymPy expressions get no CSE, so idx shouldn't be repeated.
        expr = index.value.expr
        return expr + Where(expr >= 0, 0, size)
278

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.