2
from typing import Any, Callable, Generic, Literal, Optional, Tuple, TypeVar, Union
3
from unittest.mock import patch
6
from typing_extensions import Protocol
9
import torch.utils._pytree as pytree
10
from torch.fx.graph import inplace_methods, magic_methods
11
from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str
14
StoreMode = Optional[Literal["atomic_add"]]
15
ReductionType = Literal[
29
def _arg_str(a) -> str:
30
if isinstance(a, sympy.Expr):
41
class OpsHandler(Protocol[T]):
43
Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
44
as well as the contract for op handlers. The type T signifies the domain
45
of the abstract analysis AKA what all of the functions return / take as arguments
46
anywhere compute occurs.
48
While these operators are typically dtype polymorphic (e.g., you can use mul
49
on both integers and floats), they do NOT do promotion and usually return the
50
same dtype as the input. You are expected to have handled type promotion
51
during ATen decompositions. Most operators correspond exactly to pointwise
52
operations as defined by torch, so when in doubt about semantics, check the
53
corresponding torch documentation. These are all scalar operations (so they
54
are defined to operate on a single element at a time.)
56
For convenience, many operators take a src_dtype which indicates what the dtype
57
of the input argument is. Although in principle this can be derived by an
58
analysis, providing this for ops where it is useful helps avoid having to repeatedly
59
recompute dtype in code generation.
61
Note that this often describes a class of static methods, for stateless
64
Handlers are often defined using ``__getattr__`` metaprogramming, which means
65
that you cannot declare that a type implements a protocol by inheriting from
66
it (as the type stubs count as attribute declarations and impede the getattr
67
magic method from being called). Instead, define a function that casts an
68
argument of your type to the protocol, which is sufficient to induce mypy to
69
test that the protocol is implemented correctly. Search for ``_typecheck_``
70
in this file to see some examples. If you see an obscure error where a
71
class doesn't implement a Protocol, but mypy doesn't say why, check to see
72
that ``__getattr__`` is typed correctly (typically, it is not possible to
73
type ``__getattr__`` without typing it as ``Callable[..., Any]``)
76
def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T:
77
"""Produces a scalar constant of type dtype."""
80
def load_seed(self, name: str, offset: T):
81
"""Computes inductor_prims.lookup_seed."""
84
def rand(self, seed: T, offset: T) -> T:
85
"""Computes inductor_prims.random with mode="rand". offset has dtype int32."""
88
def randn(self, seed: T, offset: T) -> T:
89
"""Computes inductor_prims.random with mode="randn". offset has dtype int32."""
92
def randint64(self, seed: T, offset: T, low: T, high: T) -> T:
93
"""Computes inductor_prims.randint. offset has dtype int32."""
96
def masked(self, mask: T, body: Callable[[], T], other: T) -> T:
98
Computes body, but only perform loads/stores if the boolean mask
99
evaluates to true. For example, you would use this if you needed to
100
perform an indirect load that may not be valid on some elements;
101
without masking, invalid accesses can cause IMAs. When mask is true,
102
the result is the result of body; otherwise it is other.
104
Contrast this with ops.where, which can multiplex between two values
105
that have been unconditionally computed.
109
def where(self, condition: T, input: T, other: T) -> T:
111
Computes torch.where: when condition is true, return input; otherwise return other.
115
def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T:
117
Converts a sympy expression into a scalar of type dtype. expr is typically
118
an indexing expression, thus the name; however, it can also be used in
119
non-indexing situations.
124
self, x: T, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None
127
Convert x to dtype. src_dtype can be optionally set to specify what the original
128
dtype of x was, which can improve code generation (used by torch to(dtype=dtype)).
132
def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T:
134
Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.)
135
src_dtype must be the original type of x.
148
def indirect_indexing(
149
self, x: T, size: sympy.Expr, check: bool = True
152
Convert an integral x into a sympy.Expr that can be subsequently used in
153
indexing computation. 'size' represents an upper bound on the what valid
154
indexes can be; when 'check' is True, we check that the x is in bounds.
156
NB: This is typically mandatory to implement for any analysis, because you
157
MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol).
161
def load(self, name: str, index: sympy.Expr) -> T:
163
Load from the memory location 'name', offset by some indexing expression 'index'.
172
mode: StoreMode = None,
175
Store 'value' to the memory location 'name' offset by 'expr'. If
176
specified, 'mode' can require the store to be an atomic addition.
186
src_dtype: torch.dtype,
187
reduction_type: ReductionType,
189
) -> Union[T, Tuple[T, ...]]:
191
Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype',
192
using 'dtype' as the accumulation dtype for the reduction. The result
193
is an intermediate computation which should be stored to the final
194
location using 'ops.store_reduction'.
196
Valid reduction types are . For Welford reduction types, this
197
function returns multiple outputs; consult reduction_num_outputs to
198
determine the amount in metaprogramming applications.
205
def store_reduction(self, name: str, index: sympy.Expr, value: T) -> T:
207
Store the fully accumulated result of 'reduction' to the memory
208
location 'name' offset by 'expr'.
213
self, dtype: torch.dtype, combine_fn: Callable[[T, T], T], value: T, init: int
216
Perform an associative scan on 'value'.
225
offsets_size: sympy.Expr,
226
indexing_dtype: torch.dtype,
236
def abs(self, x0: T) -> T:
239
def exp(self, x0: T) -> T:
242
def exp2(self, x0: T) -> T:
245
def expm1(self, x0: T) -> T:
248
def sqrt(self, x0: T) -> T:
251
def relu(self, x0: T) -> T:
254
def minimum(self, x0: T, x1: T) -> T:
257
def maximum(self, x0: T, x1: T) -> T:
260
def cos(self, x0: T) -> T:
263
def sin(self, x0: T) -> T:
266
def lgamma(self, x0: T) -> T:
269
def erf(self, x0: T) -> T:
272
def cosh(self, x0: T) -> T:
275
def sinh(self, x0: T) -> T:
278
def acos(self, x0: T) -> T:
281
def acosh(self, x0: T) -> T:
284
def asin(self, x0: T) -> T:
287
def asinh(self, x0: T) -> T:
290
def atan2(self, x0: T, x1: T) -> T:
293
def atan(self, x0: T) -> T:
296
def atanh(self, x0: T) -> T:
299
def copysign(self, x0: T, x1: T) -> T:
302
def erfc(self, x0: T) -> T:
305
def erfinv(self, x0: T) -> T:
308
def frexp(self, x0: T):
311
def hypot(self, x0: T, x1: T) -> T:
314
def log10(self, x0: T) -> T:
317
def nextafter(self, x0: T, x1: T) -> T:
320
def logical_and(self, x0: T, x1: T) -> T:
323
def logical_not(self, x0: T) -> T:
326
def logical_or(self, x0: T, x1: T) -> T:
329
def logical_xor(self, x0: T, x1: T) -> T:
332
def bitwise_and(self, x0: T, x1: T) -> T:
335
def bitwise_not(self, x0: T) -> T:
338
def bitwise_or(self, x0: T, x1: T) -> T:
341
def bitwise_xor(self, x0: T, x1: T) -> T:
344
def bitwise_left_shift(self, x0: T, x1: T) -> T:
347
def bitwise_right_shift(self, x0: T, x1: T) -> T:
350
def rsqrt(self, x0: T) -> T:
353
def log1p(self, x0: T) -> T:
356
def tan(self, x0: T) -> T:
359
def tanh(self, x0: T) -> T:
362
def sigmoid(self, x0: T) -> T:
365
def signbit(self, x0: T) -> T:
368
def fmod(self, x0: T, x1: T) -> T:
371
def log(self, x0: T) -> T:
374
def isinf(self, x0: T) -> T:
377
def isnan(self, x0: T) -> T:
380
def round(self, x0: T) -> T:
383
def floor(self, x0: T) -> T:
386
def sign(self, x0: T) -> T:
389
def to_int(self, x0: T) -> T:
392
def trunc(self, x0: T) -> T:
395
def truncdiv(self, x0: T, x1: T) -> T:
398
def ceil(self, x0: T) -> T:
401
def neg(self, x0: T) -> T:
404
def reciprocal(self, x0: T) -> T:
407
def eq(self, x0: T, x1: T) -> T:
410
def ne(self, x0: T, x1: T) -> T:
413
def lt(self, x0: T, x1: T) -> T:
416
def gt(self, x0: T, x1: T) -> T:
419
def le(self, x0: T, x1: T) -> T:
422
def ge(self, x0: T, x1: T) -> T:
425
def add(self, x0: T, x1: T) -> T:
428
def sub(self, x0: T, x1: T) -> T:
431
def mul(self, x0: T, x1: T) -> T:
434
def floordiv(self, x0: T, x1: T) -> T:
437
def truediv(self, x0: T, x1: T) -> T:
440
def div(self, x0: T, x1: T) -> T:
443
def mod(self, x0: T, x1: T) -> T:
446
def pow(self, x0: T, x1: T) -> T:
449
def and_(self, x0: T, x1: T) -> T:
452
def or_(self, x0: T, x1: T) -> T:
455
def xor(self, x0: T, x1: T) -> T:
471
def libdevice_abs(self, x0: T) -> T:
474
def libdevice_exp(self, x0: T) -> T:
477
def libdevice_sqrt(self, x0: T) -> T:
480
def libdevice_cos(self, x0: T) -> T:
483
def libdevice_sin(self, x0: T) -> T:
486
def libdevice_sigmoid(self, x0: T) -> T:
489
def libdevice_log(self, x0: T) -> T:
494
def __getattr__(self, name):
498
def inner(*args, **kwargs):
499
fargs = [_arg_str(a) for a in args]
500
fargs.extend(f"{k}={v}" for k, v in kwargs.items())
501
return f"ops.{name}({', '.join(fargs)})"
506
def masked(mask, body, other) -> str:
507
return f"ops.masked({mask}, {body()}, {other})"
511
return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]")
514
def indirect_indexing(index_var, size, check=True) -> sympy.Symbol:
515
return sympy_index_symbol(f"({str(index_var)})")
519
def make_handler(format_string):
522
return format_string.format(*args)
526
for name, format_string in itertools.chain(
527
magic_methods.items(), inplace_methods.items()
529
setattr(cls, name, make_handler(format_string))
532
MockHandler._init_cls()
536
def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]:
540
class KernelFormatterHandler:
541
def __init__(self, parent_handler):
542
self.parent_handler = parent_handler
543
self.output = IndentedBuffer(1)
544
self.var_counter = itertools.count()
547
def ir_to_string(ir_fn, index, rindex=None) -> str:
548
from .ir import FlexibleLayout
549
from .virtualized import V
551
args = [index, rindex] if rindex is not None else [index]
552
names = ["index", "rindex"] if rindex is not None else ["index"]
553
formatter = KernelFormatterHandler(MockHandler())
555
with formatter.output.indent(-1):
556
formatter.output.writeline(f"def inner_fn({', '.join(names)}):")
557
for name, arg in zip(names, args):
561
str("_" if isinstance(v, (int, sympy.Integer)) else v)
565
formatter.output.writeline(f"{lhs} = {name}")
567
with V.set_ops_handler(formatter), patch.object(
568
FlexibleLayout, "allow_indexing", True
570
result = ir_fn(*args)
571
return formatter.getvalue(result)
573
def __getattr__(self, name) -> Callable[..., Any]:
574
def inner(*args, **kwargs):
575
line = getattr(self.parent_handler, name)(*args, **kwargs)
576
if name == "indirect_indexing":
581
varname = f"tmp{next(self.var_counter)}"
582
self.output.writeline(f"{varname} = {line}")
585
return pytree.tree_map(write, line)
592
src_dtype: torch.dtype,
593
reduction_type: ReductionType,
594
value: Union[str, Tuple[str, ...]],
595
) -> Union[str, Tuple[str, ...]]:
596
line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value)
597
num_values = reduction_num_outputs(reduction_type)
598
varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)]
599
self.output.writeline(f"{','.join(varnames)} = {line}")
600
return tuple(varnames) if num_values > 1 else varnames[0]
602
def getvalue(self, result):
603
self.output.writeline(f"return {result}")
604
return self.output.getvalue()
608
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
612
class WrapperHandler(Generic[T]):
613
def __init__(self, inner: OpsHandler[T]):
616
def __getattr__(self, item):
617
return getattr(self._inner, item)
621
def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]:
626
"""Shim to count how many ops are used"""
628
def __init__(self, inner):
630
self.parent_handler = inner
634
def __getattr__(self, name):
635
def inner(*args, **kwargs):
636
val = getattr(self.parent_handler, name)(*args, **kwargs)
637
if name == "indirect_indexing":
641
if val not in self.var_names:
642
varname = f"tmp{self.op_count}"
644
self.var_names[val] = varname
647
return self.var_names[val]
649
return pytree.tree_map(count, val)
654
def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]: