"""
This file provides a number of "global" variables/handlers that are actually
thread local and dynamically scoped, with Inductor patching them to various
implementations depending on the situation.

These handlers are interacted with in a fairly stylized way.  Typically,
we will import V from this module::

    from .virtualized import V

Various handlers are accessible as attributes on this module; for example,
you might access ``V.graph.sizevars.size_hint`` to resolve a size hint.

There are a few distinct usage patterns for virtualized global variables:

1. Implicit argument passing.  Examples: ``V.current_node``, ``V.aot_compilation``.
   Use ``V.set_current_node`` to change what the current node is while we're
   executing some region of code, so code inside that region can query ``V.current_node``
   to find out what it is.  This is often more convenient than manually threading
   the current node as an argument through all call stacks.

2. Per-compilation global state.  Examples: ``V.fake_mode``, ``V.graph``.  For a
   given ``compile_fx`` invocation, these typically don't change, but they are
   associated with some internal state so they cannot just be global functions.
   We install these objects at the beginning of compilation and then you can
   conveniently access them without having to pass them around.

3. Alternate define-by-run interpretations.  Examples: ``V.ops``, ``V.kernel``.
   A commonly used IR in Inductor is define-by-run: instead of maintaining
   explicit syntax data structures, we instead represent loop bodies as
   callable functions, which internally invoke operations defined on
   ``V.ops``.  To perform semantic analysis, print or code generate these
   operations, we dynamically patch ``V.ops`` with an alternate handler with
   the intended semantics and then run the callable function.  For example, to
   extract out a traditional (FX) graph representation of the define-by-run
   IR, simply install a handler that records each ``ops`` call to a graph.
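
As a standalone sketch of this pattern (hypothetical names, not Inductor's
real handlers), the same callable body yields a different interpretation
depending on which handler it runs against:

```python
# Standalone sketch of the define-by-run pattern; names here are
# hypothetical, not Inductor's real API.
class RecordingHandler:
    # Records each op call as (name, args) instead of computing anything.
    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        def op(*args):
            self.calls.append((name, args))
            return f"tmp{len(self.calls) - 1}"  # symbolic result name

        return op


def body(ops, x):
    # A define-by-run "loop body": its IR is just the ops calls it makes.
    return ops.add(ops.mul(x, x), x)


handler = RecordingHandler()
result = body(handler, "x")
# handler.calls now holds the recorded trace: the mul call, then the add.
```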

TODO: Define a parent class / protocol that defines all of the operations
V.ops is expected to support.
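
Such a protocol might be sketched as follows (a hypothetical method list;
the actual operation set is defined by the handlers in ``ops_handler.py``):

```python
# Hypothetical sketch of an ops protocol; the real operation set lives in
# the handlers imported from ops_handler.py.
from typing import Any, Protocol


class OpsProtocol(Protocol):
    def add(self, a: Any, b: Any) -> Any: ...
    def mul(self, a: Any, b: Any) -> Any: ...
    def constant(self, value: Any, dtype: Any) -> Any: ...
    # ...one method per supported pointwise operation
```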

It is typically an error to access a virtualized global without having installed
an appropriate handler (you will get a NullHandler), although in some cases we
provide a default implementation.
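
The install-then-access discipline can be sketched in isolation (hypothetical
names; the real mechanism is the ``Virtualized`` class below):

```python
# Standalone sketch of a thread-local, dynamically scoped "global";
# hypothetical names, mirroring the Virtualized class in this file.
import threading
from contextlib import contextmanager

_local = threading.local()


@contextmanager
def set_handler(value):
    # Install a handler for the dynamic extent of the with-block.
    prior = getattr(_local, "handler", None)
    _local.handler = value
    try:
        yield
    finally:
        _local.handler = prior  # restore the previous handler on exit


def get_handler():
    return getattr(_local, "handler", None)


with set_handler("outer"):
    assert get_handler() == "outer"
    with set_handler("inner"):  # nested scopes shadow the outer handler
        assert get_handler() == "inner"
    assert get_handler() == "outer"
assert get_handler() is None  # unset again outside the with-block
```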

One last thing: although most virtualized globals are accessed via ``V``, ``ops`` is
ubiquitous enough to have its own top level variable, so you will typically see
``ops.constant(...)`` rather than ``V.ops.constant(...)``.  In fact, these are not
equivalent; the former interface supports arithmetic overloads like ``x + y``
instead of forcing ``ops.add(x, y)``, so it should be preferred.
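
The overload machinery amounts to delegating magic methods back to an ops
namespace; a minimal standalone sketch (hypothetical classes, not the real
``OpsValue``/``OpsWrapper`` defined below):

```python
# Minimal sketch of arithmetic overloads delegating to an ops namespace;
# hypothetical classes, mirroring the OpsValue/OpsWrapper idea.
class SymbolicOps:
    @staticmethod
    def add(a, b):
        return Val(f"add({a}, {b})")

    @staticmethod
    def mul(a, b):
        return Val(f"mul({a}, {b})")


class Val:
    def __init__(self, expr):
        self.expr = expr

    def __str__(self):
        return self.expr

    def __add__(self, other):
        return SymbolicOps.add(self, other)

    def __mul__(self, other):
        return SymbolicOps.mul(self, other)


x, y = Val("x"), Val("y")
# `x + y * x` builds the same value as SymbolicOps.add(x, SymbolicOps.mul(y, x))
assert str(x + y * x) == "add(x, mul(y, x))"
```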

Some operators are seemingly unused, but they are implicitly used by ops_wrapper.
In particular, we typically have an operator for every basic pointwise PyTorch
operation.
"""

from __future__ import annotations

from contextlib import AbstractContextManager, contextmanager
from threading import local
from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union

from .ops_handler import (  # noqa: F401
    KernelFormatterHandler,
    MockHandler,
    OpsHandler,
    WrapperHandler,
)

if TYPE_CHECKING:
    import torch
    from torch._inductor.debug import DebugContext
    from torch._inductor.graph import GraphLowering
    from torch._inductor.ir import InterpreterShim
    from torch._subclasses import FakeTensorMode

threadlocal = local()

T = TypeVar("T")


class NullHandler:
    """
    Sentinel indicating that a global variable is unset ala None.  Typically,
    attempting to access the global variable before it's set is an error, but with
    NullHandler it won't fail until you try to access an attribute on it.
    """


class Virtualized(Generic[T]):
    """
    Implements a global variable that redirects via thread local variable
    (NB: construct this class to create the global variable; this is not
    a singleton class!)

    This allows us to swap in different op implementations in codegen.

    NB: Despite the fact that we typically call these "handlers" (e.g., NullHandler is
    the default value of the variable), we sometimes use these variables to
    store other things, like booleans.
    """

    def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]):
        self._key: str = f"__torchinductor_{vname}"
        self._default = default

    def _set_handler(self, value: T) -> AbstractContextManager[None]:
        prior = self._get_handler()
        setattr(threadlocal, self._key, value)

        @contextmanager
        def ctx():
            try:
                yield
            finally:
                self._set_handler(prior)

        return ctx()

    def _get_handler(self) -> T:
        try:
            return getattr(threadlocal, self._key)
        except AttributeError:
            # TODO: To be honest, I feel we probably should just error in this
            # case, instead of making a null handler that will probably error
            # when you getattr on it
            return self._default()  # type: ignore[return-value]

    def __getattr__(self, name: str) -> Any:
        return getattr(self._get_handler(), name)

class NullKernelHandler(NullHandler):
    """
    We need access to `V.kernel.removed_buffers` in the DeferredLine class when there
    is no kernel in the context.  This happens when codegening the wrapper.
    Initialize `removed_buffers` and `inplaced_to_remove` explicitly so we don't
    need to call 'getattr' with a default value, which is error prone due to
    typos in attribute names.
    """

    def __init__(self):
        super().__init__()
        self.removed_buffers = set()
        self.inplaced_to_remove = set()
        self.index_dtype = "tl.int64"

_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
_real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
_kernel: Virtualized[NullKernelHandler] = Virtualized(
    "kernel", NullKernelHandler
)  # TODO: improve type
_debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler)
_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler)
_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler)
_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler)

class OpsValue:
    """The return type of most ops calls.

    This exists so we can overload magic methods, and write mathematical
    expressions much more fluently.  So instead of

        ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1)

    we can write

        (_Ap2 * x - _Ap3) * x * x + _1
    """

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return f"OpsValue({self.value!r})"

    def __add__(self, other):
        return ops.add(self, other)

    def __mul__(self, other):
        return ops.mul(self, other)

    def __sub__(self, other):
        return ops.sub(self, other)

    def __neg__(self):
        return ops.neg(self)

    def __truediv__(self, other):
        return ops.truediv(self, other)

    def __floordiv__(self, other):
        return ops.floordiv(self, other)

    def __mod__(self, other):
        return ops.mod(self, other)

    def __pow__(self, other):
        return ops.pow(self, other)

    def __lt__(self, other):
        return ops.lt(self, other)

    def __le__(self, other):
        return ops.le(self, other)

    def __eq__(self, other):
        return ops.eq(self, other)

    def __ne__(self, other):
        return ops.ne(self, other)

    def __gt__(self, other):
        return ops.gt(self, other)

    def __ge__(self, other):
        return ops.ge(self, other)

    def __and__(self, other):
        return ops.bitwise_and(self, other)

    def __or__(self, other):
        return ops.bitwise_or(self, other)

    def __xor__(self, other):
        return ops.bitwise_xor(self, other)

    def __invert__(self):
        return ops.bitwise_not(self)

    def __rshift__(self, n):
        return ops.bitwise_right_shift(self, n)

    def __lshift__(self, n):
        return ops.bitwise_left_shift(self, n)

class OpsWrapper:
    """This wraps any returned IR values into an `OpsValue` instance, so that we
    can overload the magic methods for writing mathematical expressions fluently.
    """

    def __getattr__(self, name):
        def inner(*args, **kwargs):
            new_args = [OpsWrapper._unwrap(a) for a in args]
            new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()}
            return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs))

        return inner

    @staticmethod
    def _unwrap(x):
        if isinstance(x, (list, tuple)):
            return tuple(OpsWrapper._unwrap(v) for v in x)
        if isinstance(x, OpsValue):
            return x.value
        return x

    @staticmethod
    def _wrap(x):
        if isinstance(x, (list, tuple)):
            return tuple(OpsValue(v) for v in x)
        return OpsValue(x)

    @staticmethod
    def indirect_indexing(index, size, check=True):
        # Returns a sympy value, not IR value
        index = OpsWrapper._unwrap(index)
        return _ops.indirect_indexing(index, size, check)


ops = OpsWrapper()

class _V:
    MockHandler = MockHandler
    KernelFormatterHandler = KernelFormatterHandler
    WrapperHandler = WrapperHandler

    set_ops_handler: Callable[[Any], Any] = _ops._set_handler
    get_ops_handler: Callable[[], Any] = _ops._get_handler
    set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
    set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
    get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
    set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler
    get_fake_mode: Callable[[], Any] = _fake_mode._get_handler
    set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler
    set_debug_handler: Callable[[Any], Any] = _debug._set_handler
    set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler
    set_aot_compilation: Callable[[bool], Any] = _aot_compilation._set_handler
    get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
    set_current_node: Callable[[Any], Any] = _current_node._set_handler
    get_current_node: Callable[[], Any] = _current_node._get_handler

    @property
    def ops(self) -> OpsHandler[Any]:
        """The operator handler specific to the current codegen task"""
        return _ops._get_handler()

    @property
    def graph(self) -> GraphLowering:
        """The graph currently being generated"""
        return _graph._get_handler()

    @property
    def real_inputs(self):
        """non-fake example inputs"""
        return _real_inputs._get_handler()

    @property
    def fake_mode(self):
        """The fake mode used during compilation"""
        return _fake_mode._get_handler()

    @property
    def kernel(self):
        """The kernel currently being generated"""
        return _kernel._get_handler()

    @property
    def debug(self):
        return _debug._get_handler()

    @property
    def interpreter(self):
        return _interpreter._get_handler()

    @property
    def aot_compilation(self):
        return _aot_compilation._get_handler()

    @property
    def current_node(self):
        return _current_node._get_handler()


V = _V()