"""
This file provides a number of "global" variables/handlers that are actually
thread local and dynamically scoped, with Inductor patching them to various
implementations depending on the situation.

These handlers are interacted with in a fairly stylized way.  Typically,
we will import V from this module::

    from .virtualized import V

Various handlers are accessible as attributes on ``V``; for example,
you might access ``V.graph.sizevars.size_hint`` to resolve a size hint associated with
a number.

There are a few distinct usage patterns for virtualized global variables:

1. Implicit argument passing.  Examples: ``V.current_node``, ``V.aot_compilation``.
   Use ``V.set_current_node`` to change what the current node is while we're
   executing some region of code, so code inside that region can query ``V.current_node``
   to find out what it is.  This is often more convenient than manually threading
   the current node as an argument through all call stacks.
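
   For example (an illustrative sketch; ``node`` and ``lower`` are
   stand-ins for a ``torch.fx.Node`` and whatever processes it)::

       with V.set_current_node(node):
           lower(node)  # anything called from here can read V.current_node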

2. Per-compilation global state.  Examples: ``V.fake_mode``, ``V.graph``.  For a
   given ``compile_fx`` invocation, these typically don't change, but they are
   associated with some internal state so they cannot just be global functions.
   We install these objects at the beginning of compilation and then you can
   conveniently access them without having to pass them around.
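
   For example (an illustrative sketch; ``graph`` and ``run_compilation``
   are stand-ins for a ``GraphLowering`` and the compilation driver)::

       with V.set_graph_handler(graph):
           run_compilation()  # V.graph is available everywhere inside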

3. Alternate define-by-run interpretations.  Examples: ``V.ops``, ``V.kernel``.
   A commonly used IR in Inductor is define-by-run: instead of maintaining
   explicit syntax data structures, we instead represent loop bodies as
   callable functions, which internally invoke operations defined on
   ``V.ops``.  To perform semantic analysis, printing, or code generation
   for these operations, we dynamically patch ``V.ops`` with an alternate
   handler that has the intended semantics and then run the callable function.
   For example, to extract a traditional (FX) graph representation of the
   define-by-run IR, simply install a handler that records each ``ops`` call
   to a graph.
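
   For example (an illustrative sketch; ``inner_fn`` and ``index`` are
   stand-ins for a loop body and its index arguments), ``MockHandler``
   formats each op call as a string, so running the body under it
   pretty-prints the define-by-run IR::

       with V.set_ops_handler(MockHandler()):
           print(inner_fn(index))  # e.g. roughly "add(load(buf0, i0), ...)"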

   TODO: Define a parent class / protocol that defines all of the operations
   V.ops is expected to support.

It is typically an error to access a virtualized global without having installed
an appropriate handler (you will get a NullHandler), although in some cases we
provide a default implementation.

One last thing: although most virtualized globals are accessed via ``V``, ``ops`` is
ubiquitous enough to have its own top level variable, so you will typically see
``ops.constant(...)`` rather than ``V.ops.constant(...)``.  In fact, these are not
equivalent; the former interface supports arithmetic overloads like ``x + y``
instead of forcing ``ops.add(x, y)``, so it should be preferred.

Some operators are seemingly unused, but they are implicitly used by ops_wrapper.
In particular, we typically have an operator for every basic pointwise PyTorch operation
supported.
"""

from __future__ import annotations

from contextlib import AbstractContextManager, contextmanager
from threading import local
from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union

from .ops_handler import (  # noqa: F401
    KernelFormatterHandler,
    MockHandler,
    OpsHandler,
    ReductionType,
    StoreMode,
    WrapperHandler,
)

if TYPE_CHECKING:
    import torch
    from torch._inductor.debug import DebugContext
    from torch._inductor.graph import GraphLowering
    from torch._inductor.ir import InterpreterShim
    from torch._subclasses import FakeTensorMode

threadlocal = local()

T = TypeVar("T")


class NullHandler:
    """
    Sentinel indicating that a global variable is unset, a la None.  Typically,
    attempting to access the global variable before it's set is an error, but with
    NullHandler it won't fail until you try to access an attribute on it.
    """

    pass


class Virtualized(Generic[T]):
    """
    Implements a global variable that redirects via a thread-local variable
    (NB: construct this class to create the global variable; this is not
    a singleton class!)

    This allows us to swap in different op implementations in codegen.

    NB: Despite the fact that we typically call these "handlers" (e.g., NullHandler is
    the default value of the variable), we sometimes use these variables to
    store other things, like booleans.
    """

    def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]):
        self._key: str = f"__torchinductor_{vname}"
        self._default = default

    def _set_handler(self, value: T) -> AbstractContextManager[None]:
        prior = self._get_handler()
        setattr(threadlocal, self._key, value)

        @contextmanager
        def ctx():
            try:
                yield
            finally:
                self._set_handler(prior)

        return ctx()

    def _get_handler(self) -> T:
        try:
            return getattr(threadlocal, self._key)
        except AttributeError:
            # TODO: To be honest, I feel we probably should just error in this
            # case, instead of making a null handler that will probably error
            # when you getattr on it
            return self._default()  # type: ignore[return-value]

    def __getattr__(self, name: str) -> Any:
        return getattr(self._get_handler(), name)
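

# Example (an illustrative sketch, not code that runs in this module): a
# Virtualized variable is declared once at module scope, and ``_set_handler``
# installs a value for a region of code, returning a context manager that
# restores the prior value on exit:
#
#   _example: Virtualized[int] = Virtualized("example", NullHandler)
#   with _example._set_handler(42):
#       assert _example._get_handler() == 42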


class NullKernelHandler(NullHandler):
    """
    We need to access ``V.kernel.removed_buffers`` in the DeferredLine class
    when there is no kernel in the context.  This happens when codegen-ing the
    wrapper.  Initialize ``removed_buffers`` and ``inplaced_to_remove``
    explicitly so we don't need to call ``getattr`` with a default value, which
    is error-prone to typos in the attribute name.
    """

    def __init__(self):
        super().__init__()
        self.removed_buffers = set()
        self.inplaced_to_remove = set()
        self.index_dtype = "tl.int64"


_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
_real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
_kernel: Virtualized[NullKernelHandler] = Virtualized(
    "kernel", NullKernelHandler
)  # TODO: improve type
_debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler)
_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler)
_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler)
_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler)


class OpsValue:
    """The return type of most ops calls.

    This exists so we can overload magic methods, and write mathematical
    expressions much more fluently. So instead of

        ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1)

    we can write

        (_Ap2 * x - _Ap3) * x * x + _1

    """

    value: Any

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return f"OpsValue({self.value!r})"

    def __add__(self, other):
        return ops.add(self, other)

    def __mul__(self, other):
        return ops.mul(self, other)

    def __sub__(self, other):
        return ops.sub(self, other)

    def __neg__(self):
        return ops.neg(self)

    def __truediv__(self, other):
        return ops.truediv(self, other)

    def __floordiv__(self, other):
        return ops.floordiv(self, other)

    def __mod__(self, other):
        return ops.mod(self, other)

    def __pow__(self, other):
        return ops.pow(self, other)

    def __lt__(self, other):
        return ops.lt(self, other)

    def __le__(self, other):
        return ops.le(self, other)

    def __eq__(self, other):
        return ops.eq(self, other)

    def __ne__(self, other):
        return ops.ne(self, other)

    def __gt__(self, other):
        return ops.gt(self, other)

    def __ge__(self, other):
        return ops.ge(self, other)

    def __and__(self, other):
        return ops.bitwise_and(self, other)

    def __or__(self, other):
        return ops.bitwise_or(self, other)

    def __xor__(self, other):
        return ops.bitwise_xor(self, other)

    def __invert__(self):
        return ops.bitwise_not(self)

    def __rshift__(self, n):
        return ops.bitwise_right_shift(self, n)

    def __lshift__(self, n):
        return ops.bitwise_left_shift(self, n)


class OpsWrapper:
    """This wraps any returned IR values into an `OpsValue` instance, so that we
    can overload the magic methods for writing mathematical expressions fluently.
    """

    def __getattr__(self, name):
        def inner(*args, **kwargs):
            new_args = [OpsWrapper._unwrap(a) for a in args]
            new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
            return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))

        return inner

    @staticmethod
    def _unwrap(x):
        if isinstance(x, (list, tuple)):
            return tuple(OpsWrapper._unwrap(v) for v in x)
        if isinstance(x, OpsValue):
            return x.value
        return x

    @staticmethod
    def _wrap(x):
        if isinstance(x, (list, tuple)):
            return tuple(OpsValue(v) for v in x)
        return OpsValue(x)

    @staticmethod
    def indirect_indexing(index, size, check=True):
        # Returns a sympy value, not an IR value.
        index = OpsWrapper._unwrap(index)
        return _ops.indirect_indexing(index, size, check)


ops = OpsWrapper()
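
# Example (an illustrative sketch; ``index`` is a stand-in): with
# ``MockHandler`` installed, each ``ops`` call returns an ``OpsValue``
# wrapping a formatted string, so the operator overloads compose into
# readable expressions:
#
#   with V.set_ops_handler(MockHandler()):
#       v = ops.load("buf0", index) - ops.constant(1, torch.float32)
#       # ``v`` is an OpsValue; str(v) is roughly "sub(load(buf0, ...), ...)"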


class _V:
    MockHandler = MockHandler
    KernelFormatterHandler = KernelFormatterHandler
    WrapperHandler = WrapperHandler

    set_ops_handler: Callable[[Any], Any] = _ops._set_handler
    get_ops_handler: Callable[[], Any] = _ops._get_handler
    set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
    set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
    get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
    set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler
    get_fake_mode: Callable[[], Any] = _fake_mode._get_handler
    set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler
    set_debug_handler: Callable[[Any], Any] = _debug._set_handler
    set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler
    set_aot_compilation: Callable[[bool], Any] = _aot_compilation._set_handler
    get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
    set_current_node: Callable[[Any], Any] = _current_node._set_handler
    get_current_node: Callable[[], Any] = _current_node._get_handler

    @property
    def ops(self) -> OpsHandler[Any]:
        """The operator handler specific to the current codegen task"""
        return _ops._get_handler()

    @property
    def graph(self) -> GraphLowering:
        """The graph currently being generated"""
        return _graph._get_handler()

    @property
    def real_inputs(self):
        """Non-fake example inputs"""
        return _real_inputs._get_handler()

    @property
    def fake_mode(self):
        """The FakeTensorMode used during compilation"""
        return _fake_mode._get_handler()

    @property
    def kernel(self):
        """The kernel currently being generated"""
        return _kernel._get_handler()

    @property
    def debug(self):
        return _debug._get_handler()

    @property
    def interpreter(self):
        return _interpreter._get_handler()

    @property
    def aot_compilation(self):
        return _aot_compilation._get_handler()

    @property
    def current_node(self):
        return _current_node._get_handler()


V = _V()
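
# Typical per-compilation setup (an illustrative sketch; ``graph`` and
# ``fake_mode`` stand for a GraphLowering and FakeTensorMode built elsewhere):
#
#   with V.set_graph_handler(graph), V.set_fake_mode(fake_mode):
#       ...  # V.graph and V.fake_mode are readable anywhere in here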