import contextlib
import dataclasses
import functools
import itertools
import logging
import textwrap
import traceback
from contextlib import nullcontext
from enum import Enum
from functools import partial
from typing import (
    Any,
    Callable,
    ClassVar,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)
from unittest.mock import patch

import sympy
from sympy import Expr, Integer

import torch
import torch._export.serde.schema as export_schema
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import identity
from torch._export.serde.serialize import GraphModuleSerializer
from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
from torch._prims_common import (
    compute_required_storage_length,
    is_boolean_dtype,
    is_float_dtype,
    make_channels_last_strides_for,
    make_contiguous_strides_for,
)
from torch._subclasses.fake_tensor import get_schema_info
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymTypes
from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing

from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .dependencies import (
    extract_free_unbacked_symbols,
    extract_input_node_reduction_ranges,
    extract_read_writes,
)
from .ops_handler import OpCounterCSE
from .utils import (
    argsort,
    convert_shape_to_inductor,
    convert_shape_to_symint,
    sympy_index_symbol,
    sympy_product,
    sympy_subs,
)
from .virtualized import ops, V

if TYPE_CHECKING:
    from .graph import GraphLowering
log = logging.getLogger(__name__)
indent = functools.partial(textwrap.indent, prefix="    ")

""" [Note: Inductor IR]

Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each
lowering is registered to a particular aten operator, and expects inputs that
correspond to the aten schema. However, in place of torch Tensor inputs, lowerings
expect Inductor TensorBox inputs.

TensorBox IR represents torch tensors. Tensors are sometimes single objects owning
storage, and sometimes views of another Tensor's storage. Mutating tensor operations
(such as add_()) affect the underlying storage and any associated views. Other operations
(such as .t_()) update metadata about the current view but don't modify the underlying storage.

To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.

TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
output from an operation. But just as torch.Tensors take different forms, TensorBox IR can
reference View IR or directly reference StorageBox IRs.

Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
may take an existing TensorBox and point it to a new underlying View IR.

Tensors that directly own storage are represented as a chain of:
TensorBox -> StorageBox -> Buffer
where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.

If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
(leaving the old buffer unmodified and functionalizing the operation).

Tensors backed by views add one more indirection to the IR.
TensorBox -> View -> StorageBox -> Buffer
In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
"""
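
# Illustrative sketch (standalone example, not Inductor's real classes): the
# pointer-swinging described in the note above, modeled with toy stand-ins.
# Mutation replaces the buffer a StorageBox points to; views keep referencing
# the same StorageBox and therefore observe the change.
def _example_box_pointer_swing():
    class ToyBuffer:
        def __init__(self, data):
            self.data = data

    class ToyStorageBox:
        def __init__(self, buffer):
            self.buffer = buffer

    class ToyTensorBox:
        def __init__(self, storage):
            self.storage = storage

    storage = ToyStorageBox(ToyBuffer([1, 2, 3]))
    base = ToyTensorBox(storage)  # TensorBox -> StorageBox -> Buffer
    view = ToyTensorBox(storage)  # a view shares the same StorageBox

    # functionalized mutation: swing the StorageBox to a fresh buffer,
    # leaving the old Buffer unmodified
    storage.buffer = ToyBuffer([x + 1 for x in storage.buffer.data])
    assert view.storage.buffer.data == [2, 3, 4]
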
def validate_ir(node_or_nodes):
    def _check_tensorbox(nodes):
        # Could expand this to check deeper properties
        # (e.g. TensorBox points to View or StorageBox)
        if isinstance(nodes, (list, tuple)):
            for node in nodes:
                _check_tensorbox(node)
        elif isinstance(nodes, dict):
            for node in nodes.values():
                _check_tensorbox(node)
        else:
            assert isinstance(
                nodes,
                (
                    torch._inductor.ir.ExpandView,
                    TensorBox,
                    sympy.logic.boolalg.Boolean,
                    sympy.Expr,
                ),
            ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"

    # Be picky about the accepted data structure (don't use pytree here)
    _check_tensorbox(node_or_nodes)
def ops_wrapper(name):
    assert isinstance(name, str)

    def fn(*args, **kwargs):
        return getattr(ops, name)(*args, **kwargs)

    return fn


def inverse_reorder(order):
    inv_order = dict(zip(order, range(len(order))))

    def reindex(index):
        assert len(index) == len(inv_order)
        return [index[inv_order[i]] for i in range(len(index))]

    return reindex


def same_reorder(order):
    def reindex(index):
        assert len(index) == len(order)
        return [index[order[i]] for i in range(len(index))]

    return reindex


def fuse_reindexing(reindex1, reindex2):
    def reindex(index):
        return reindex1(reindex2(index))

    return reindex
NHWC_STRIDE_ORDER = [3, 0, 2, 1]


def stride_order2fill_order(order):
    """
    Convert stride order to fill order
    For channel last format,
    stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0]
    """
    lookup = {pos: idx for idx, pos in enumerate(order)}
    fill_order = [lookup[i] for i in range(len(order))]
    return fill_order


def get_stride_order(seq: Sequence[int]) -> List[int]:
    """
    Convert strides to stride order
    """
    sorted_idx: List[int] = argsort(seq)
    out = [0 for _ in range(len(seq))]
    for i, elem in enumerate(sorted_idx):
        out[elem] = i
    return out
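
# Illustrative usage (standalone example with assumed shapes): a channels-last
# (NHWC) tensor of shape (2, 3, 4, 5) has strides (60, 1, 15, 3), so C is the
# innermost dimension and N the outermost.
def _example_stride_and_fill_order():
    order = get_stride_order([60, 1, 15, 3])
    assert order == NHWC_STRIDE_ORDER == [3, 0, 2, 1]
    # fill order lists dimensions from innermost to outermost: C, W, H, N
    assert stride_order2fill_order(order) == [1, 3, 2, 0]
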
def ir_node_to_tensor(x, guard_shape=True):
    if x is None:
        return None

    shape_fn: Callable[[Expr], Union[int, Expr]]
    if not guard_shape:
        shape_fn = V.graph.sizevars.size_hint
    else:
        shape_fn = identity
    size = [shape_fn(s) for s in x.get_size()]
    if is_storage_and_layout(x):
        stride = [shape_fn(s) for s in x.get_layout().stride]  # type: ignore[misc]
    else:
        stride = make_contiguous_strides_for(size)  # type: ignore[arg-type]
    dtype = x.get_dtype()
    device = x.get_device()
    size = convert_shape_to_symint(size)
    stride = convert_shape_to_symint(stride)
    t = torch.empty_strided(
        size=size, stride=stride, dtype=dtype, device=device
    ).zero_()
    return t
def may_convert_to_optional(value):
    if isinstance(value, list) and not value:
        # [None] makes sure the cpp wrapper codegen will generate something like
        # {c10::nullopt} instead of {}
        return [None]
    return value


def get_device_type(x):
    if getattr(x, "get_device", None):
        return get_device_type(x.get_device())
    if isinstance(x, torch.device):
        return x.type
    return None


def is_cuda(x):
    return get_device_type(x) == "cuda"


def is_cpu(x):
    return get_device_type(x) == "cpu"
class IRNode:
    _current_origins: ClassVar[Set[Any]] = set()

    @staticmethod
    @contextlib.contextmanager
    def current_origins(origins: Set[torch.fx.Node]):
        old = IRNode._current_origins
        IRNode._current_origins = old | origins
        try:
            yield
        finally:
            IRNode._current_origins = old

    def __post_init__(self):
        self.origins = set(self._current_origins)
        self.traceback = traceback.format_stack() if config.debug_ir_traceback else None

    def get_traceback(self):
        return self.traceback

    def common_repr(self):
        origins = f"origins={getattr(self, 'origins', '')}"
        if len(origins) > 64:
            # this can get *very* long
            origins = f"{origins[:61]}..."
        return [origins]

    def str_helper(self, lines):
        lines = lines + self.common_repr()
        lines = indent(",\n".join(map(str, lines)))
        return f"{type(self).__name__}(\n{lines}\n)"

    def is_user_of(self, name):
        return name in self.get_read_names()

    def get_read_names(self):
        return {dep.name for dep in self.get_reads()}

    def get_layout(self):
        raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!")

    def get_size(self):
        raise NotImplementedError(f"get_size() is not implemented by {type(self)}!")

    def get_numel(self):
        return sympy_product(self.get_size())

    def is_zero_elements(self):
        return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))  # type: ignore[arg-type]

    def realize(self):
        """
        If the IRNode refers to data which has not been materialized (e.g.,
        it is a Pointwise/Reduction that could potentially have more
        compute fused into it), realize the IRNode into physical memory,
        ending the possibility of fusing into it, but allowing, e.g., multiple
        users to access the data without having to recompute.

        Check StorageBox.realize for a particularly notable implementation.

        TODO(ezyang): I think, in principle, every IRNode should have an
        implementation of this, and most of the time no-op is OK, but you
        really do have to audit each IRNode for this, so for now, raise
        an error if it's not implemented. Note that some code in graph.py
        will catch this thrown error and suppress it with a warning.
        """
        raise NotImplementedError(f"realize NYI on {type(self)}")

    def codegen_reference(self, writer=None):
        raise NotImplementedError(f"codegen_reference NYI on {type(self)}")

    # The abstract method declarations below serve to convince mypy that all IRNode instances have these functions
    # defined, while having no effect at runtime. We cannot create stub implementations here because other parts of
    # the code dynamically check for defined attributes.
    get_device: Callable[[], torch.device]
    get_dtype: Callable[[], torch.dtype]
    get_name: Callable[[], str]
    get_reads: Callable[[], Any]
    get_stride: Callable[[], Any]
    get_storage_numel: Callable[[], Any]
    has_exceeded_max_reads: Callable[[], bool]
    make_loader: Callable[[], Callable[[Any], Any]]
    make_indexer: Callable[[], Callable[[Any], Any]]
    mark_reuse: Callable[[int], None]
    realize_hint: Callable[[], None]
    get_unbacked_symbol_uses: Callable[[], Set[sympy.Symbol]]
@dataclasses.dataclass
class Loops(IRNode):
    device: torch.device
    dtype: torch.dtype
    inner_fn: Callable[..., Any]
    ranges: List[Expr]

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        return set().union(
            *(free_unbacked_symbols(e) for e in self.ranges),
            self.inner_fn_free_unbacked_symbols(),
        )

    def __str__(self, names=("ranges",)):
        return self.str_helper(
            [
                f"'{self.device.type}'",
                str(self.dtype),
                self.inner_fn_str(),
            ]
            + [f"{name}={getattr(self, name)}" for name in names]
            + [f"origin_node={self.origin_node!r}"]
        )

    __repr__ = __str__

    def __post_init__(self):
        super().__post_init__()
        self.origin_node = None

    def get_device(self):
        return self.device

    def get_origin_node(self):
        return self.origin_node

    def get_size(self):
        return self.ranges

    def get_pointwise_size(self):
        return self.ranges

    @classmethod
    def create(cls, *args, **kwargs):
        origin_node = kwargs.pop("origin_node", None)
        tb = kwargs.pop("traceback", None)
        r = cls(*args, **kwargs)
        r.origin_node = origin_node
        r.traceback = (
            tb or traceback.format_stack() if config.debug_ir_traceback else None
        )
        return TensorBox.create(r)

    @staticmethod
    def _index(ranges, prefix="i"):
        return [
            sympy.Integer(0) if s == 1 else sympy_index_symbol(f"{prefix}{n}")
            for n, s in enumerate(ranges)
        ]

    def inner_fn_opcount(self):
        from .ir import FlexibleLayout

        opcounter = OpCounterCSE(V.MockHandler())

        with V.set_ops_handler(opcounter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            result = self.inner_fn(*self.inner_fn_args())
            return opcounter.op_count

    def inner_fn_args(self):
        return (self._index(self.ranges),)

    def inner_fn_str(self):
        return V.KernelFormatterHandler.ir_to_string(
            self.inner_fn, *self.inner_fn_args()
        )

    def has_large_inner_fn(self):
        return self.inner_fn_opcount() > config.realize_opcount_threshold

    def inner_fn_free_unbacked_symbols(self):
        index = self._index(self.ranges)
        return extract_free_unbacked_symbols(self.inner_fn, index)

    def get_reads(self):
        with patch.object(FlexibleLayout, "allow_indexing", True):
            if self.get_reduction_type():
                return extract_read_writes(
                    self.make_loader(),
                    self.get_size(),
                    self.get_reduction_size(),
                ).reads
            else:
                return extract_read_writes(
                    self.make_loader(),
                    self.get_size(),
                ).reads

    def get_reduction_size(self):
        raise NotImplementedError(
            f"get_reduction_size() is not implemented by {type(self)}!"
        )

    def get_reduction_type(self):
        raise NotImplementedError(
            f"get_reduction_type() is not implemented by {type(self)}!"
        )

    def constant_to_device(self, device):
        raise NotImplementedError(
            f"constant_to_device() is not implemented by {type(self)}!"
        )
def nop_loader_fn(idx, *, dtype):
    if dtype.is_floating_point:
        return ops.constant(float("nan"), dtype)
    else:
        return ops.constant(0, dtype)


class Pointwise(Loops):
    def make_loader(self):
        # Make zero-element loops into a no-op
        if self.is_zero_elements():
            return partial(nop_loader_fn, dtype=self.dtype)

        return self.inner_fn

    def get_reduction_size(self):
        return []

    def get_reduction_type(self):
        return None

    def store_output(self, output_name, indexer, vars):
        loader = self.make_loader()
        return ops.store(output_name, indexer(vars), loader(vars))

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Pointwise(device, self.dtype, loader, self.ranges)
@dataclasses.dataclass
class Scatter(Pointwise):
    output_indexer: Callable[[List[Expr]], Expr]
    scatter_mode: Optional[str] = None

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Scatter(
            device,
            self.dtype,
            loader,
            self.ranges,
            self.output_indexer,
            self.scatter_mode,
        )

    def store_output(self, output_name, indexer, vars):
        loader = self.make_loader()
        return ops.store(
            output_name,
            indexer(self.output_indexer(vars)),
            loader(vars),
            mode=self.scatter_mode,
        )
class ReductionHint(Enum):
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3


REDUCTION_COMBINE_FN = {
    "any": ops_wrapper("logical_or"),
    "max": ops_wrapper("maximum"),
    "min": ops_wrapper("minimum"),
    "prod": ops_wrapper("mul"),
    "sum": ops_wrapper("add"),
    "xor_sum": ops_wrapper("bitwise_xor"),
}


def get_reduction_combine_fn(reduction_type, dtype):
    if reduction_type in REDUCTION_COMBINE_FN:
        combine_fn = REDUCTION_COMBINE_FN[reduction_type]
    elif reduction_type in {"argmax", "argmin"}:

        def combine_fn(a, b):
            a_value, a_index = a
            b_value, b_index = b

            if reduction_type == "argmin":
                mask = ops.lt(a_value, b_value)
            else:
                mask = ops.gt(a_value, b_value)

            equal = ops.eq(a_value, b_value)
            if is_float_dtype(dtype):
                a_isnan = ops.ne(a_value, a_value)
                b_isnan = ops.ne(b_value, b_value)
                mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan))
                equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan))

            mask = ops.logical_or(
                mask, ops.logical_and(equal, ops.lt(a_index, b_index))
            )
            return (
                ops.where(mask, a_value, b_value),
                ops.where(mask, a_index, b_index),
            )

    elif reduction_type == "welford_combine":

        def combine_fn(a, b):
            a_mean, a_m2, a_weight = a
            b_mean, b_m2, b_weight = b

            delta = b_mean - a_mean
            new_weight = a_weight + b_weight
            w2_over_w = b_weight / new_weight
            return (
                a_mean + delta * w2_over_w,
                a_m2 + b_m2 + delta * delta * a_weight * w2_over_w,
                new_weight,
            )

    else:
        raise NotImplementedError(f"unknown reduction_type={reduction_type}")

    return combine_fn
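
# Sketch (standalone example): the "welford_combine" branch above implements
# the parallel variance merge of Chan et al.; because it only uses arithmetic
# operators, it can be checked directly on Python floats against a two-pass
# computation.
def _example_welford_combine():
    combine = get_reduction_combine_fn("welford_combine", torch.float32)

    def stats(xs):
        mean = sum(xs) / len(xs)
        # (mean, sum of squared deviations, count)
        return (mean, sum((x - mean) ** 2 for x in xs), float(len(xs)))

    mean, m2, weight = combine(stats([1.0, 2.0, 3.0]), stats([4.0, 5.0]))
    assert (mean, weight) == (3.0, 5.0)
    assert abs(m2 - 10.0) < 1e-12  # squared deviations of 1..5 about their mean 3
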
@dataclasses.dataclass
class Reduction(Loops):
    reduction_ranges: List[Expr]
    reduction_type: str
    # self.dtype represents the dst dtype
    src_dtype: torch.dtype
    reduction_hint: ReductionHint

    def __str__(self):
        return Loops.__str__(  # type: ignore[call-arg]
            self, names=("ranges", "reduction_ranges", "reduction_type")
        )

    def __repr__(self):
        return self.__str__()

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        return super().get_unbacked_symbol_uses() | set().union(
            *(free_unbacked_symbols(e) for e in self.reduction_ranges)
        )

    def get_reduction_size(self):
        return self.reduction_ranges

    def get_reduction_type(self):
        return self.reduction_type

    def store_reduction(self, output_name, indexer, vars, reduction_vars):
        value = ops.reduction(
            self.dtype,
            self.src_dtype,
            self.reduction_type,
            self.inner_fn(vars, reduction_vars),
        )
        return ops.store_reduction(output_name, indexer(vars), value)

    def index_length(self):
        return len(self.ranges) + len(self.reduction_ranges)

    def inner_fn_args(self):
        index = self._index(self.ranges)
        rindex = self._index(self.reduction_ranges, "r")
        return (index, rindex)

    def inner_fn_free_unbacked_symbols(self):
        index = self._index(self.ranges)
        rindex = self._index(self.reduction_ranges, "r")
        return extract_free_unbacked_symbols(self.inner_fn, index, rindex)

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Reduction(
            device,
            self.dtype,
            loader,
            self.ranges,
            self.reduction_ranges,
            self.reduction_type,
            self.src_dtype,
            ReductionHint.DEFAULT,
        )
    @classmethod
    def num_splits(
        cls,
        device,
        dst_dtype,
        src_dtype,
        inner_fn,
        ranges,
        reduction_ranges,
        reduction_type,
        reduction_numel,
        input_node: Optional[IRNode] = None,
    ):
        def _is_static(x):
            return isinstance(x, (int, sympy.Integer))

        reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel)
        numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges))

        should_split = (
            is_cuda(device)
            and reduction_type
            not in {
                "argmax",
                "argmin",
            }
            and config.split_reductions
            # We don't support unbacked symints
            and _is_static(reduction_numel_hint)
            and _is_static(numel_hint)
        )
        if not should_split:
            return ReductionHint.DEFAULT, 1

        device_interface = get_interface_for_device(get_device_type(device))
        num_sm = device_interface.Worker.get_device_properties(
            device
        ).multi_processor_count
        min_elements_per_thread = 32
        max_elements_per_thread = 512
        threads_per_sm = 2048
        min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm
        max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm
        def inner_reduction_splits(reduction_numel_hint, numel_hint):
            # use a heuristic close to eager mode for split inner reductions
            # we leak reduction autotune configs here, and will need to refactor to avoid this later
            num_warps = 8
            num_threads = 32 * num_warps
            if numel_hint >= 2 * num_sm:  # don't split if there are enough outputs
                return 1
            if reduction_numel_hint <= 8192:
                return 1

            if reduction_numel_hint * numel_hint <= min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (2 * num_threads)
                blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
                tmp_split_size = (
                    reduction_numel_hint + num_threads * blocks_per_output - 1
                ) // (num_threads * blocks_per_output)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(closest - tmp_split_size) < 30:
                    # prefer even splits, but never smaller than min_elements_per_thread
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread

            return (reduction_numel_hint + split_size * num_threads - 1) // (
                split_size * num_threads
            )
        def outer_reduction_splits(reduction_numel_hint, numel_hint):
            # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128
            # extend to even smaller number of outputs
            num_warps = 8
            num_threads = num_warps * 32
            rvals_per_thread = 4  # comes from heuristics, refactor to not leak here
            xvals_per_block = 128
            xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
            if reduction_numel_hint * numel_hint < min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (num_threads)
                target_blocks = (target_blocks + xblocks - 1) // xblocks
                tmp_split_size = (
                    reduction_numel_hint + rvals_per_thread * target_blocks - 1
                ) // (rvals_per_thread * target_blocks)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(tmp_split_size - closest) < 20:
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread

            return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
                rvals_per_thread * split_size
            )
        # easy cases
        if numel_hint == 1:
            split = inner_reduction_splits(reduction_numel_hint, numel_hint)
            if split == 1:
                # No need to split.
                return ReductionHint.INNER, split
            if (
                len(ranges) == 0
                and input_node is not None
                and isinstance(input_node, TensorBox)
            ):
                # Only handles the case where keep_dim = False.
                # Otherwise, we need to propagate reduction dim info to the stage where
                # the intermediate loader of the first Reduction is generated.
                new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
                    input_node
                )
                if new_ranges is not None and new_reduction_ranges is not None:
                    extracted_numel_hint = V.graph.sizevars.symbolic_hint(
                        sympy_product(new_ranges + new_reduction_ranges)
                    )
                    if reduction_numel_hint == extracted_numel_hint:
                        log.debug(
                            "Use previous IRNode's range and reduction_ranges instead of split. "
                            "current ranges: %s, current reduction ranges: %s, current split: %d, "
                            "new ranges: %s, new reduction ranges: %s",
                            ranges,
                            reduction_ranges,
                            split,
                            new_ranges,
                            new_reduction_ranges,
                        )
                        # If the input_node or its dependent nodes are also Reduction nodes,
                        # use reduction_sizes of this node or its dependent nodes directly.
                        return ReductionHint.INNER, -1
            return ReductionHint.INNER, split

        if (
            reduction_numel_hint <= min_elements_per_thread
            or numel_hint >= num_sm * 2 * 32
        ):
            return ReductionHint.DEFAULT, 1

        r = Reduction(
            device,
            dst_dtype,
            inner_fn,
            ranges,
            reduction_ranges,
            reduction_type,
            src_dtype,
            ReductionHint.DEFAULT,
        )
        def get_read_indices(r):
            cb = ComputedBuffer(
                name=None,
                layout=FlexibleLayout(
                    device=r.get_device(),
                    dtype=r.get_dtype(),
                    size=r.get_size(),
                ),
                data=r,
            )
            read_writes = cb.get_read_writes()
            # try finding the full size producer
            # TODO this will fail for something like ((1, N) * (N, 1)).sum()
            # this would also possibly be wrong for producers with different contiguity but we hope those cases are rare
            range_vars = [
                r
                for r in read_writes.range_vars
                if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number)
            ]
            indices = []
            changed = False
            for md in sorted(read_writes.reads, key=lambda x: x.name):
                if all(r in md.index.free_symbols for r in range_vars):
                    indices.append(md.index)
                    if md.name in V.graph.name_to_buffer:
                        buf = V.graph.name_to_buffer[md.name]
                        original_stride = buf.layout.stride
                        buf.decide_layout()
                        if buf.layout.stride != original_stride:
                            changed = True
            return indices, changed

        indices, changed = get_read_indices(r)
        if changed:
            indices, _ = get_read_indices(r)

        if len(indices) == 0:
            # TODO determine splits when all inputs are broadcast
            return ReductionHint.DEFAULT, 1

        (_, reduction_vars), ranges = dependencies.index_vars_squeeze(
            r.get_size(), r.get_reduction_size()
        )
        num_outer = 0
        num_inner = 0
        for i in indices:
            i = V.graph.sizevars.simplify_with_ranges(i, ranges)
            strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys())
            outer = all(s > 1 for s in strides)
            if outer:
                num_outer += 1
            else:
                num_inner += 1
        if num_inner > num_outer:
            return ReductionHint.INNER, inner_reduction_splits(
                reduction_numel_hint, numel_hint
            )
        else:
            return ReductionHint.OUTER, outer_reduction_splits(
                reduction_numel_hint, numel_hint
            )
    @staticmethod
    def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type, src_dtype):
        """Convert inner_fn from a reduction to a pointwise"""
        reduction_ranges = [
            V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges
        ]

        combine_fn = get_reduction_combine_fn(reduction_type, src_dtype)

        def fn(index):
            return functools.reduce(
                combine_fn,
                [
                    value_fn(index, rindex)
                    for rindex in itertools.product(
                        *[range(x) for x in reduction_ranges]
                    )
                ],
            )

        if reduction_type in ("argmin", "argmax"):
            flatten_index = FixedLayout(
                None,  # type: ignore[arg-type]
                None,  # type: ignore[arg-type]
                reduction_ranges,
                FlexibleLayout.contiguous_strides(reduction_ranges),
            ).make_indexer()

            def value_fn(index, rindex):
                rindex = [sympy.expand(i) for i in rindex]
                return (
                    inner_fn(index, rindex),
                    ops.index_expr(flatten_index(rindex), torch.int64),
                )

            return lambda index: fn(index)[1]
        else:
            value_fn = inner_fn
            return fn
    @classmethod
    def create(  # type: ignore[override]
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable[..., Any],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
        input_node: Optional[IRNode] = None,
    ):
        reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))

        if reduction_numel == 0:
            # N.B. This is a hack to generate the literal of the given type
            # Ideally, we should be fixing `def constant` in triton.py
            # but it breaks due to hardcoded dtypes in other places
            def py_cnst(val):
                return (
                    bool(val)
                    if dst_dtype == torch.bool
                    else float(val)
                    if dst_dtype.is_floating_point
                    else int(val)
                )

            rtypes_to_inits = {
                "sum": py_cnst(0),
                "xor_sum": py_cnst(0),
                "prod": py_cnst(1),
                "any": py_cnst(0),
                # "all" is desugared to `!any(!val)`
            }

            assert (
                reduction_type in rtypes_to_inits.keys()
            ), f"{reduction_type} not supported for zero-dimension tensors!"

            def const_fn(index):
                return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)

            return Pointwise.create(
                device=device, dtype=dst_dtype, inner_fn=const_fn, ranges=list(ranges)
            )

        if reduction_numel == 1:
            # this reduction is actually a pointwise op
            if reduction_type in ("argmin", "argmax"):

                def fn(index):
                    return ops.constant(0, dst_dtype)

            else:

                def fn(index):
                    reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
                    return inner_fn(index, reduction_index)

            return Pointwise.create(device, dst_dtype, fn, ranges)

        if (
            isinstance(reduction_numel, sympy.Integer)
            and V.graph.sizevars.size_hint(reduction_numel)
            < config.unroll_reductions_threshold
            and sympy_product(ranges) != 1
        ):
            return Pointwise.create(
                device,
                dst_dtype,
                cls._unroll_reduction_fn(
                    inner_fn, reduction_ranges, reduction_type, src_dtype
                ),
                ranges,
            )

        # triton doesn't support reduce to single element well, so break it up
        hint, split = cls.num_splits(
            device,
            dst_dtype,
            src_dtype,
            inner_fn,
            ranges,
            reduction_ranges,
            reduction_type,
            reduction_numel,
            input_node,
        )
        # intermediate reduction in split can contain complex indexing,
        # and num_splits will fail to correctly set the hint
        # reuse the passed hint if available
        if reduction_hint == ReductionHint.DEFAULT:
            reduction_hint = hint
        if split == -1:
            assert input_node is not None
            new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
                input_node  # type: ignore[arg-type]
            )
            assert new_ranges is not None
            assert new_reduction_ranges is not None
            return cls.create_multilayer_existing_ranges(
                device,
                dst_dtype,
                src_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                new_ranges,
                new_reduction_ranges,
                reduction_type,
                reduction_hint,
            )
        elif split > 1:
            # triton doesn't support reduce to single element well, so break it up
            return cls.create_multilayer(
                device,
                dst_dtype,
                src_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                split,
                reduction_hint,
            )

        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                src_dtype,
                reduction_hint,
            )
        )
    @staticmethod
    def default_accumulator(reduction_type, dtype):
        if reduction_type in {"max", "argmax"}:
            if is_float_dtype(dtype):
                return float("-inf")
            elif is_boolean_dtype(dtype):
                return 0
            else:
                return torch.iinfo(dtype).min
        if reduction_type in {"min", "argmin"}:
            if is_float_dtype(dtype):
                return float("inf")
            elif is_boolean_dtype(dtype):
                return 1
            else:
                return torch.iinfo(dtype).max
        return {
            "sum": 0,
            "prod": 1,
            "xor_sum": 0,
            "any": 0,
            "welford_reduce": (0, 0, 0),
            "welford_combine": (0, 0, 0),
        }[reduction_type]

    @staticmethod
    def default_value(reduction_type, dtype):
        if reduction_type == "welford_reduce":
            return 0
        return Reduction.default_accumulator(reduction_type, dtype)
    @staticmethod
    def _multilayer_second_step_hint(
        split: int, numel_hint: int, reduction_hint: ReductionHint
    ) -> ReductionHint:
        if split == -1:
            return reduction_hint
        if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER:
            return ReductionHint.OUTER_TINY
        if (
            split <= 1024
            and numel_hint <= 256
            and reduction_hint == ReductionHint.OUTER
        ):
            return ReductionHint.OUTER_TINY

        return reduction_hint
    @classmethod
    def _multilayer_wrap_loader(
        cls, loader, reduction_ranges, reduction_numel, split, block_size, default
    ):
        reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])
        need_mask = not V.graph.sizevars.is_expr_static_and_true(
            sympy.Eq(reduction_numel % split, 0)  # type: ignore[arg-type]
        )

        def wrapper_fn(index, reduction_index):
            (reduction_index,) = reduction_index
            *new_index, reduction_block = index
            indices = block_size * reduction_block + reduction_index

            def body():
                return loader(new_index, reindex([indices]))

            if need_mask:
                mask = ops.lt(
                    ops.index_expr(indices, torch.int32),
                    ops.index_expr(reduction_numel, torch.int32),
                )
                return ops.masked(mask, body, default)
            else:
                return body()

        return wrapper_fn
    @classmethod
    def _multilayer_wrap_loader_existing_ranges(
        cls,
        loader,
        original_ranges,
        original_reduction_ranges,
        new_ranges,
        new_reduction_ranges,
        default,
    ):
        assert len(original_ranges) == 0, f"{original_ranges} is not equal to []"
        reindex = View.dynamic_reshape_indexer(
            original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges)
        )

        def wrapper_fn(index, reduction_index):
            return loader([], reindex(tuple(index) + tuple(reduction_index)))

        return wrapper_fn
    @classmethod
    def create_multilayer_helper(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        wrapper_fn: Callable[..., Any],
        original_ranges: List[Expr],
        original_reduction_ranges: List[Expr],
        new_ranges: List[Expr],
        new_reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        """
        # triton will automatically compute reductions in fp32 if reducing over fp16/bf16
        # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction
        # in fp32 and not reduce precision by breaking up the kernel into multiple layers
        intermediate_dtype = (
            dst_dtype
            if dst_dtype not in (torch.float16, torch.bfloat16)
            else torch.float
        )
        intermediate = Reduction.create(
            device,
            intermediate_dtype,
            src_dtype,
            wrapper_fn,
            new_ranges,
            new_reduction_ranges,
            reduction_type,
            reduction_hint,
        )
        intermediate.realize()
        intermediate_loader = intermediate.make_loader()

        def intermediate_fn(index, reduction_index):
            return intermediate_loader([*index, *reduction_index])

        numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges))
        reduction_hint = cls._multilayer_second_step_hint(
            split, numel_hint, reduction_hint
        )

        assert original_ranges == new_ranges[: len(original_ranges)]
        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                intermediate_fn,
                original_ranges,
                new_ranges[len(original_ranges) :],
                reduction_type,
                src_dtype,
                reduction_hint,
            )
        )
    @classmethod
    def create_multilayer(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable[..., Any],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        """
        # TODO(jansel): realize the reduction so we can do dynamic indexing
        reduction_numel = sympy_product(reduction_ranges)
        block_size = FloorDiv(reduction_numel + (split - 1), split)
        default = cls.default_value(reduction_type, dst_dtype)
        wrapper_fn = cls._multilayer_wrap_loader(
            inner_fn, reduction_ranges, reduction_numel, split, block_size, default
        )

        return cls.create_multilayer_helper(
            device,
            dst_dtype,
            src_dtype,
            wrapper_fn,
            ranges,
            reduction_ranges,
            [*ranges, split],  # type: ignore[list-item]
            [block_size],
            reduction_type,
            split,
            reduction_hint,
        )
    @classmethod
    def create_multilayer_existing_ranges(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable[..., Any],
        original_ranges: List[Expr],
        original_reduction_ranges: List[Expr],
        new_ranges: List[Expr],
        new_reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        """
        default = cls.default_value(reduction_type, dst_dtype)
        wrapper_fn = cls._multilayer_wrap_loader_existing_ranges(
            inner_fn,
            original_ranges,
            original_reduction_ranges,
            new_ranges,
            new_reduction_ranges,
            default,
        )
        return cls.create_multilayer_helper(
            device,
            dst_dtype,
            src_dtype,
            wrapper_fn,
            original_ranges,
            original_reduction_ranges,
            new_ranges,
            new_reduction_ranges,
            reduction_type,
            -1,
            reduction_hint,
        )
def num_reduction_outputs(reduction_type):
    return 3 if "welford" in reduction_type else 1
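
# Pure-Python sketch (standalone example) of the index math used by
# Reduction._multilayer_wrap_loader above: a length-n reduction is split into
# `split` blocks of block_size = ceil(n / split); lanes past the end are
# masked to the reduction's default value, and a second reduction combines
# the per-block partials.
def _example_split_reduction():
    import math

    def split_sum(xs, split):
        n = len(xs)
        block_size = math.ceil(n / split)  # FloorDiv(n + (split - 1), split)
        partials = []
        for block in range(split):
            acc = 0
            for r in range(block_size):
                idx = block_size * block + r  # indices = block_size * reduction_block + reduction_index
                acc += xs[idx] if idx < n else 0  # mask idx >= n; default = 0
            partials.append(acc)
        return sum(partials)  # second-stage reduction over the split intermediates

    xs = list(range(10))
    assert split_sum(xs, split=3) == sum(xs)
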
class WelfordReduction(Reduction):
    output_index: int

    def __init__(
        self,
        device,
        dtype,
        inner_fns,
        ranges,
        reduction_ranges,
        reduction_type,
        reduction_hint,
        output_index,
    ):
        if len(inner_fns) == 1:
            loader = inner_fns[0]
        else:

            def loader(idx, reduction_idx):
                return tuple(fn(idx, reduction_idx) for fn in inner_fns)

        super().__init__(
            device,
            dtype,
            loader,
            ranges,
            reduction_ranges,
            reduction_type,
            dtype,
            reduction_hint,
        )
        self.output_index = output_index

    def store_reduction(self, output_name, indexer, vars, reduction_vars):
        values = ops.reduction(
            self.dtype,
            self.src_dtype,
            self.reduction_type,
            self.inner_fn(vars, reduction_vars),
        )
        value = values[self.output_index]
        return ops.store_reduction(output_name, indexer(vars), value)
    @classmethod
    def create(  # type: ignore[override]
        cls,
        device: torch.device,
        dtype: torch.dtype,
        inner_fns: Sequence[Callable[..., Any]],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
    ):
        assert reduction_type in {"welford_reduce", "welford_combine"}

        reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))

        def const(val):
            def inner_fn(idx):
                return ops.constant(
                    val,
                    dtype,
                )

            return Pointwise.create(
                device=device,
                dtype=dtype,
                inner_fn=inner_fn,
                ranges=list(ranges),
            )

        if reduction_numel == 0:
            mean = const(0)
            m2 = const(0)
            weight = const(0)
            return mean, m2, weight

        if reduction_numel == 1:

            def copy(loader):
                def inner_fn(idx):
                    reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
                    return loader(idx, reduction_index)

                return Pointwise.create(
                    device=device,
                    dtype=dtype,
                    inner_fn=inner_fn,
                    ranges=list(ranges),
                )

            if reduction_type == "welford_reduce":
                return copy(inner_fns[0]), const(0), const(1)
            else:
                return tuple(copy(fn) for fn in inner_fns)
        # TODO: Unrolled reduction
        # if (
        #     isinstance(reduction_numel, sympy.Integer)
        #     and V.graph.sizevars.size_hint(reduction_numel)
        #     < config.unroll_reductions_threshold
        #     and sympy_product(ranges) != 1
        # ):
        #     return Pointwise.create(
        #         device,
        #         dst_dtype,
        #         cls._unroll_reduction_fn(
        #             inner_fn, reduction_ranges, reduction_type, src_dtype
        #         ),
        #         ranges,
        #     )

        # triton doesn't support reduce to single element well, so break it up
        hint, split = Reduction.num_splits(
            device,
            dtype,
            dtype,
            inner_fns[0],
            ranges,
            reduction_ranges,
            reduction_type=reduction_type,
            reduction_numel=reduction_numel,
        )
        # intermediate reduction in split can contain complex indexing,
        # and num_splits will fail to correctly set the hint
        # reuse the passed hint if available
        if reduction_hint == ReductionHint.DEFAULT:
            reduction_hint = hint
        if split > 1:
            # triton doesn't support reduce to single element well, so break it up
            return cls.create_multilayer(
                device,
                dtype,
                inner_fns,
                ranges,
                reduction_ranges,
                reduction_type,
                split,
                reduction_hint,
            )

        results = [
            TensorBox.create(
                WelfordReduction(
                    device,
                    dtype,
                    inner_fns,
                    ranges,
                    reduction_ranges,
                    reduction_type,
                    reduction_hint,
                    output_idx,
                )
            )
            for output_idx in range(3)
        ]
        for t in results:
            t.realize()
        return results
    @staticmethod
    def default_value(reduction_type, dtype):
        return (0, 0, 0)
    @classmethod
    def create_multilayer(  # type: ignore[override]
        cls,
        device: torch.device,
        dtype: torch.dtype,
        inner_fns: Sequence[Callable[..., Any]],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        """
        reduction_numel = sympy_product(reduction_ranges)
        need_mask = not V.graph.sizevars.is_expr_static_and_true(
            sympy.Eq(reduction_numel % split, 0)  # type: ignore[arg-type]
        )

        if need_mask and reduction_type != "welford_combine":
            # If we need mask, then "welford_reduce" doesn't work because
            # masked inputs shouldn't count towards the welford weight

            def constant(idx, reduction_idx, value):
                return ops.constant(value, dtype)

            return cls.create_multilayer(
                device=device,
                dtype=dtype,
                inner_fns=(
                    inner_fns[0],
                    partial(constant, value=0),
                    partial(constant, value=1),
                ),
                ranges=ranges,
                reduction_ranges=reduction_ranges,
                reduction_type="welford_combine",
                split=split,
                reduction_hint=reduction_hint,
            )

        block_size = FloorDiv(reduction_numel + (split - 1), split)
        intermediates = WelfordReduction.create(
            device,
            dtype,
            tuple(
                cls._multilayer_wrap_loader(
                    loader,
                    reduction_ranges,
                    reduction_numel,
                    split,
                    block_size,
                    default=0,
                )
                for loader in inner_fns
            ),
            [*ranges, split],  # type: ignore[list-item]
            [block_size],
            reduction_type,
            reduction_hint,
        )
        for i in intermediates:
            i.realize()

        i_loaders = [i.make_loader() for i in intermediates]

        def intermediate_loader_fn(index, reduction_index, loader):
            return loader([*index, *reduction_index])

        numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
        reduction_hint = cls._multilayer_second_step_hint(
            split, numel_hint, reduction_hint
        )
        return WelfordReduction.create(
            device,
            dtype,
            tuple(
                partial(intermediate_loader_fn, loader=i.make_loader())
                for i in intermediates
            ),
            ranges,
            [split],  # type: ignore[list-item]
            # welford_reduce turns one input into three outputs, which are combined with welford_combine
            "welford_combine",
            reduction_hint,
        )
@dataclasses.dataclass
class Scan(Loops):
    scan_ranges: List[Expr]
    size: List[Expr]
    combine_fn: Callable[..., Any]
    reindex: Callable[[List[Expr], List[Expr]], List[Expr]]
    reduction_hint: ReductionHint
    init: int

    # HACK we mimic reduction

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
        # need to explicitly represent the closure so we can pull out unbacked
        # symbols here
        return (
            super().get_unbacked_symbol_uses()
            | set().union(*(free_unbacked_symbols(e) for e in self.scan_ranges))
            | set().union(*(free_unbacked_symbols(e) for e in self.size))
        )

    def __post_init__(self):
        assert len(self.ranges) + len(self.scan_ranges) == len(self.size)
        super().__post_init__()

    def store_reduction(self, output_name, indexer, vars, scan_vars):
        idx = self.reindex(vars, scan_vars)
        value = self.inner_fn(idx)
        result = ops.scan(self.dtype, self.combine_fn, value, self.init)
        return ops.store(output_name, indexer(idx), result)

    def get_reduction_type(self):
        # return self.scan_op
        return "custom"

    def get_reduction_size(self):
        return self.scan_ranges

    def get_size(self):
        return self.size

    def get_pointwise_size(self):
        return self.ranges

    def index_length(self):
        return len(self.ranges) + len(self.scan_ranges)

    def inner_fn_args(self):
        index = self._index(self.ranges)
        rindex = self._index(self.scan_ranges, "r")
        idx = self.reindex(index, rindex)
        return (idx,)

    def inner_fn_free_unbacked_symbols(self):
        index = self._index(self.ranges)
        rindex = self._index(self.scan_ranges, "r")
        idx = self.reindex(index, rindex)
        return extract_free_unbacked_symbols(self.inner_fn, idx)
    @classmethod
    def create(
        cls,
        device: torch.device,
        dtype: torch.dtype,
        inner_fn: Callable[[List[Expr]], Any],
        size: List[Expr],
        axis: int,
        combine_fn: Callable[..., Any],
        init: Any,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
    ) -> Optional["TensorBox"]:
        pointwise_ranges = [*size[:axis], *size[axis + 1 :]]
        scan_ranges = [size[axis]]

        if device.type != "cuda":
            # TODO: CPU support
            return None

        sizevars = V.graph.sizevars
        scan_numel = sizevars.simplify(sympy_product(scan_ranges))

        # Scan with a single element is just a copy
        if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)):  # type: ignore[arg-type]
            return Pointwise.create(
                device=device, dtype=dtype, inner_fn=inner_fn, ranges=size
            )

        reduction_hint, num_splits = cls.num_splits(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            axis=axis,
            pointwise_ranges=pointwise_ranges,
            scan_ranges=scan_ranges,
            combine_fn=combine_fn,
            scan_numel=scan_numel,
        )
        scan_type = Scan if num_splits <= 1 else SplitScan

        if num_splits > 1 and torch.version.hip is not None:
            # Fallback for split-scan on ROCm
            num_splits = 1
            scan_type = Scan

        def reindex(index, scan_index):
            assert len(scan_index) == len(scan_ranges)
            assert len(index) == len(pointwise_ranges)
            return [*index[:axis], *scan_index, *index[axis:]]

        result = TensorBox.create(
            scan_type(
                device=device,
                dtype=dtype,
                inner_fn=inner_fn,
                size=size,
                ranges=pointwise_ranges,
                scan_ranges=scan_ranges,
                combine_fn=combine_fn,
                reindex=reindex,
                init=init,
                reduction_hint=reduction_hint,
            )
        )
        result.realize()
        return result
    @classmethod
    def num_splits(
        cls,
        device: torch.device,
        dtype: torch.dtype,
        inner_fn: Callable[[List[Expr]], Any],
        axis: int,
        pointwise_ranges: List[Expr],
        scan_ranges: List[Expr],
        combine_fn: Callable[..., Any],
        scan_numel: Expr,
    ):
        # TODO: custom splitting heuristic for scan
        def wrapper_fn(idx, reduction_idx):
            return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]])

        return Reduction.num_splits(
            device=device,
            dst_dtype=dtype,
            src_dtype=dtype,
            inner_fn=wrapper_fn,
            ranges=pointwise_ranges,
            reduction_ranges=scan_ranges,
            reduction_type="sum",
            reduction_numel=scan_numel,
        )
# This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA.
@dataclasses.dataclass
class SplitScan(Scan):
    pass
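
# Sketch (standalone example): the ops.scan used by Scan.store_reduction
# computes an inclusive scan along the scan dimension with a user combine_fn;
# the pure-Python equivalent for a running maximum is:
def _example_inclusive_scan():
    xs = [3, 1, 4, 1, 5]
    assert list(itertools.accumulate(xs, max)) == [3, 3, 4, 4, 5]
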
def is_storage_and_layout(x):
    try:
        as_storage_and_layout(x, freeze=False)
        return True
    except NotImplementedError:
        return False


def is_contiguous_storage_and_layout(x):
    try:
        buffer, layout = as_storage_and_layout(x, freeze=False)
        return layout.is_contiguous()
    except NotImplementedError:
        return False
def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=None):
    """Try to simplify x into a StorageBox and a Layout"""
    if isinstance(x, TensorBox):
        return as_storage_and_layout(
            x.data,
            freeze=freeze,
            want_contiguous=want_contiguous,
            stride_order=stride_order,
        )
    if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
        if freeze:
            if want_contiguous:
                x.data.freeze_layout()
                assert x.data.layout.is_contiguous()
            elif stride_order is not None:
                x.data.freeze_layout_with_stride_order(stride_order)
            else:
                x.data.decide_layout()
        return x, x.data.layout
    if isinstance(x, ReinterpretView):
        # making the base of x contiguous or stride_ordered will not necessarily make
        # the ReinterpretView either, so don't pass along those arguments
        buffer, _ = as_storage_and_layout(
            x.data,
            freeze=freeze,
        )
        return buffer, x.layout
    raise NotImplementedError


as_contiguous_storage_and_layout = functools.partial(
    as_storage_and_layout, want_contiguous=True
)
def is_stride_order_storage_and_layout(x, stride_order):
    try:
        buffer, layout = as_storage_and_layout(x, freeze=False)
        return layout.is_stride_ordered(stride_order)
    except NotImplementedError:
        return False
@dataclasses.dataclass
class BaseView(IRNode):
    data: IRNode

    def get_unbacked_symbol_uses(self):
        return self.data.get_unbacked_symbol_uses()

    def make_reindexer(self):
        raise NotImplementedError(f"make_reindexer NYI on {self}")

    def make_indexer(self):
        inner = self.data.make_indexer()
        reindex = self.make_reindexer()

        def indexer(idx):
            return inner(reindex(idx))

        return indexer

    def make_loader(self):
        inner = self.data.make_loader()
        reindex = self.make_reindexer()

        def loader(idx):
            return inner(reindex(idx))

        return loader

    def get_dtype(self):
        return self.data.dtype

    def get_layout(self):
        return self.data.get_layout()

    def get_device(self):
        return self.data.get_device()

    def get_origin_node(self):
        return None

    def get_name(self):
        return self.data.get_name()

    def get_pointwise_size(self):
        return self.get_size()

    def mark_reuse(self, users):
        return self.data.mark_reuse(users)

    def has_exceeded_max_reads(self):
        return self.data.has_exceeded_max_reads()

    def realize(self):
        return self.data.realize()

    def realize_hint(self):
        return self.data.realize_hint()

    def get_storage_numel(self):
        return self.data.get_storage_numel()

    def is_extern(self):
        return self.data.is_extern()  # type: ignore[attr-defined]

    def get_reads(self):
        with patch.object(FlexibleLayout, "allow_indexing", True):
            return extract_read_writes(
                self.make_loader(),
                self.get_size(),
            ).reads

    def unwrap_view(self):
        x: IRNode = self
        while isinstance(x, BaseView):
            x = x.data
        return x

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Pointwise(device, self.get_dtype(), loader, self.get_size())
@dataclasses.dataclass
class ExpandView(BaseView):
    size: List[Expr]

    @staticmethod
    def _normalize_size(x, new_size):
        """Replace `-1` with correct sizes"""
        new_size = list(map(sympy.expand, new_size))
        old_size = x.get_size()
        old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
        assert len(new_size) == len(old_size)
        for i in range(len(new_size)):
            if new_size[i] == -1:
                assert old_size[i] is not None
                new_size[i] = old_size[i]
            elif old_size[i] is None or old_size[i] == 1:
                pass
            else:
                # Expect broadcast compatibility
                new_size[i] = V.graph.sizevars.expect_equals(
                    new_size[i],
                    old_size[i],
                    msg=f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}",
                )
        return new_size

    @classmethod
    def create(cls, x, new_size):
        new_size = cls._normalize_size(x, new_size)

        if is_storage_and_layout(x):
            storage, old_layout = as_storage_and_layout(x)
            skip = len(new_size) - len(old_layout.size)
            new_stride = [sympy.Integer(0)] * skip
            for stride, size in zip(old_layout.stride, old_layout.size):
                new_stride.append(stride if size != 1 else sympy.Integer(0))
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                list(new_size),
                new_stride,
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        return ExpandView(x, new_size)

    def get_size(self):
        return self.size

    def make_reindexer(self):
        target = self.get_size()
        actual = self.data.get_size()
        skip = len(target) - len(actual)

        def reindex(index):
            index = list(index[skip:])
            assert len(index) == len(actual)
            for i in range(len(actual)):
                if actual[i] == 1:
                    # zero out broadcast dimension
                    index[i] = sympy.Integer(0)
            return index

        return reindex
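
# The stride trick ExpandView.create relies on, checked with eager torch
# (standalone example): broadcast dimensions get stride 0, so expand copies no
# data.
def _example_expand_stride_zero():
    base = torch.arange(3.0)            # shape (3,), stride (1,)
    expanded = base.expand(4, 3)        # shape (4, 3)
    assert expanded.stride() == (0, 1)  # the new leading dim has stride 0
    assert expanded[2, 1] == base[1]    # every row aliases the same storage
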
@dataclasses.dataclass
class PermuteView(BaseView):
    dims: List[Expr]

    @classmethod
    def create(cls, x, dims):
        dims = cls._map_neg_dims(dims)
        assert set(dims) == set(range(len(dims)))

        if is_storage_and_layout(x):
            storage, old_layout = as_storage_and_layout(x)
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                [old_layout.size[i] for i in dims],
                [old_layout.stride[i] for i in dims],
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        return PermuteView(x, dims)

    @classmethod
    def _map_neg_dims(cls, dims):
        return [dim if dim >= 0 else len(dims) + dim for dim in dims]

    def get_size(self):
        assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims)))
        size = self.data.get_size()
        return [size[i] for i in self.dims]

    def make_reindexer(self):
        inv = {j: i for i, j in enumerate(self.dims)}
        inv = [inv[i] for i in range(len(self.dims))]  # type: ignore[index]
        assert set(inv) == set(range(len(self.dims)))

        def reindex(index):
            return [index[i] for i in inv]

        return reindex
class SqueezeView(BaseView):
    @classmethod
    def create(cls, x, *, dim=None):
        if is_storage_and_layout(x):
            storage, old_layout = as_storage_and_layout(x)
            new_size = []
            new_stride = []
            if dim is not None:
                assert isinstance(dim, int), "expected integer dim argument"
                assert 0 <= dim and dim < len(old_layout.size)

            for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)):
                if dim is None:
                    if size != 1:
                        new_size.append(size)
                        new_stride.append(stride)
                else:
                    if i != dim:
                        new_size.append(size)
                        new_stride.append(stride)
                    else:
                        assert size == 1, "expected squeezed size to be 1"

            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                new_stride,
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        if dim is None:
            # redirect to a generic view
            return View.create(x, [s for s in x.get_size() if s != 1])
        else:
            assert x.get_size()[dim] == 1
            return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim])

    @staticmethod
    def squeezer(size: Tuple[sympy.Expr, ...]):
        new_size = [s for s in size if s != 1]
        not_one = [i for i, s in enumerate(size) if s != 1]
        length = len(size)

        def reindex(index: List[sympy.Expr]) -> Tuple[sympy.Expr, ...]:
            assert len(index) == len(not_one), f"{index} {not_one}"
            new_index = [sympy.Integer(0)] * length
            for idx, s in zip(not_one, index):
                new_index[idx] = s
            return tuple(new_index)

        return new_size, reindex

    def __init__(self, data):
        raise AssertionError("use SqueezeView.create()")
@dataclasses.dataclass
class GenericView(BaseView):
    size: List[Expr]
    reindex: Callable[..., Any]

    def make_reindexer(self):
        return self.reindex

    def reindex_str(self):
        index_old = [sympy_index_symbol(f"i{n}") for n in range(len(self.size))]
        index_new = list(self.reindex(index_old))
        return f"lambda {', '.join(map(str, index_old))}: {index_new}"

    def __str__(self):
        return self.str_helper(
            [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
        )

    __repr__ = __str__

    @classmethod
    def create(cls, x, new_size, reindex):
        return cls(x, list(new_size), reindex)

    def get_size(self):
        return self.size
@dataclasses.dataclass
class View(GenericView):
    @staticmethod
    def handle_negative_index(idx, size):
        idx = sympy.expand(idx)
        size = sympy.expand(size)
        evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr
        if evaluate_expr(sympy.Lt(idx, 0)):
            idx = idx + size
        return idx

    @classmethod
    def create(cls, x, new_size):
        assert isinstance(new_size, (tuple, list))
        old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)

        # Skip pointless views
        if V.graph.sizevars.statically_known_list_equals(old_size, new_size):
            return x

        unbacked_symbols_in_sizes = False
        if (
            len(free_unbacked_symbols(old_size)) > 0
            or len(free_unbacked_symbols(new_size)) > 0
        ):
            unbacked_symbols_in_sizes = True

        if 0 in new_size:

            def fake_reindex(index):
                return tuple([0] * len(old_size))

            return cls(x, list(new_size), fake_reindex)
        # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout
        elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes:
            if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)):
                # realize x; otherwise, the dynamic_reshape_indexer below will fail
                # due to the size_hint's inability to process unbacked SymInts
                x = ExternKernel.realize_input(x)

            storage, old_layout = as_contiguous_storage_and_layout(x)
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                FlexibleLayout.contiguous_strides(new_size),
                old_layout.offset,
            )
            return ReinterpretView(storage, new_layout)

        reindex = cls.dynamic_reshape_indexer(old_size, new_size)
        return cls(x, list(new_size), reindex)
    @staticmethod
    def resolve_negative_size(old_size, new_size):
        new_size = [V.graph.sizevars.simplify(x) for x in new_size]
        old_size = [V.graph.sizevars.simplify(x) for x in old_size]

        new_size = list(new_size)
        for i in range(len(new_size)):
            if new_size[i] == -1:
                new_size[i] = sympy.Integer(1)
                new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
                break

        V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size))
        return old_size, new_size
    @classmethod
    def dynamic_reshape_indexer(cls, old_size, new_size):
        try:
            reindex = cls._dynamic_reshape_indexer(old_size, new_size)
        except (AssertionError, IndexError):
            # optimistic algorithm failed, lets do a fallback
            flat = [sympy_product(old_size)]
            reindex1 = cls._dynamic_reshape_indexer(old_size, flat)
            reindex2 = cls._dynamic_reshape_indexer(flat, new_size)
            reindex = fuse_reindexing(reindex1, reindex2)
        return reindex

    @staticmethod
    def _dynamic_reshape_indexer(old_size, new_size):
        """
        Perform a reshape entirely by modifying indexing math
        """
        size_hint = V.graph.sizevars.size_hint
        vars = [sympy_index_symbol(f"view{i}") for i in range(len(new_size))]

        stack_new = list(zip(vars, new_size))
        stack_old = list(old_size)

        view_expr = []
        while stack_new and stack_old:
            size_old = stack_old.pop()
            var, size_new = stack_new.pop()
            if size_old == 1:
                view_expr.append(sympy.Integer(0))
                stack_new.append((var, size_new))  # re-add
            elif size_new == 1:
                stack_old.append(size_old)  # re-add
            elif size_hint(size_new) == size_hint(size_old):
                view_expr.append(var)
                V.graph.sizevars.guard_equals(size_new, size_old)
            elif size_hint(size_new) < size_hint(size_old):
                while size_hint(size_new) < size_hint(size_old):
                    var2, size_new2 = stack_new.pop()
                    var = var2 * size_new + var
                    size_new = size_new * size_new2
                view_expr.append(var)
                V.graph.sizevars.guard_equals(size_new, size_old)
            elif size_hint(size_new) > size_hint(size_old):
                divisor = sympy.Integer(1)
                modulus = size_old
                view_expr.append(ModularIndexing(var, divisor, modulus))
                divisor = divisor * modulus
                while size_hint(size_new) > size_hint(size_old):
                    modulus = stack_old.pop()
                    view_expr.append(ModularIndexing(var, divisor, modulus))
                    divisor = divisor * modulus
                    size_old = size_old * modulus
                V.graph.sizevars.guard_equals(size_new, size_old)
            else:
                raise AssertionError()

        while stack_old:
            size_old = stack_old.pop()
            V.graph.sizevars.guard_equals(size_old, 1)  # type: ignore[arg-type]
            view_expr.append(sympy.Integer(0))

        while stack_new:
            var, size_new = stack_new.pop()
            V.graph.sizevars.guard_equals(size_new, 1)  # type: ignore[arg-type]

        view_expr = list(reversed(view_expr))
        assert len(view_expr) == len(old_size)

        def reindex(index):
            assert len(index) == len(vars), (len(index), len(vars))
            replacements = dict(zip(vars, index))
            return tuple(sympy_subs(x, replacements) for x in view_expr)  # type: ignore[arg-type]

        return reindex
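
# Pure-Python sketch (standalone example) of "reshape by indexing math" as done
# in View._dynamic_reshape_indexer above: a new-shape index is linearized, then
# delinearized into the old shape; the `%` and `//` steps are what the symbolic
# ModularIndexing/FloorDiv expressions encode.
def _example_reshape_by_index_math():
    def reshape_indexer(old_shape, new_shape):
        def reindex(new_index):
            flat = 0
            for idx, size in zip(new_index, new_shape):
                flat = flat * size + idx
            old_index = []
            for size in reversed(old_shape):
                old_index.append(flat % size)  # ModularIndexing(flat, 1, size)
                flat //= size                  # FloorDiv(flat, size)
            return tuple(reversed(old_index))

        return reindex

    reindex = reshape_indexer((2, 6), (3, 4))
    assert reindex((1, 1)) == (0, 5)  # flat position 5
    assert reindex((2, 3)) == (1, 5)  # flat position 11
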
@dataclasses.dataclass
class ReinterpretView(BaseView):
    """Pretend our storage has a different layout"""

    layout: "Layout"

    def __post_init__(self):
        super().__post_init__()
        if isinstance(self.data, BaseView):
            self.data = self.data.unwrap_view()

    def __str__(self):
        return self.str_helper(
            [
                self.data,
                self.layout,
            ]
        )

    __repr__ = __str__

    def get_name(self):
        return self.data.get_name()

    def get_device(self):
        return self.layout.device

    def get_origin_node(self):
        return None

    def get_dtype(self):
        return self.layout.dtype

    def get_size(self):
        return list(self.layout.size)

    def get_stride(self):
        return list(self.layout.stride)

    def make_loader(self):
        def loader(index):
            indexer = self.layout.make_indexer()
            return ops.load(self.get_name(), indexer(index))

        return loader

    def make_indexer(self):
        return self.layout.make_indexer()

    def get_layout(self):
        return self.layout

    def freeze_layout(self):
        pass

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        return (
            free_unbacked_symbols(self.layout.size)
            | free_unbacked_symbols(self.layout.stride)
            | free_unbacked_symbols(self.layout.offset)
        )

    def codegen_reference(self, writer=None):
        # reinterpret_tensor is similar to as_strided except:
        # - offset is added to the existing offset (rather than replacing it)
        # - view tracking is disabled similar to unsafe_view
        return V.graph.wrapper_code.codegen_reinterpret_view(
            self.data,
            self.get_size(),
            self.get_stride(),
            self.layout.offset,
            writer,
        )
class SliceView(View):
    @classmethod
    def normalize_start_end(cls, x, dim, start, end):
        """
        Normalize start and end such that both are in the range
        [0, x.get_size()[dim]] and start <= end.
        """
        sizevars = V.graph.sizevars
        dim_size = x.get_size()[dim]

        if any(free_unbacked_symbols(x) for x in (start, end, dim_size)):

            def clamp(x, lower, upper):
                return sympy.Min(sympy.Max(x, lower), upper)

        else:

            def clamp(x, lower, upper):
                return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper)

        def clamp_wrap(val, lower, upper, default):
            if val is None:
                return default
            val = cls.handle_negative_index(val, dim_size)
            return clamp(val, lower, upper)

        start = clamp_wrap(start, 0, dim_size, 0)
        end = clamp_wrap(end, start, dim_size, dim_size)
        return start, end

    @classmethod
    def create(cls, x, dim, start, end, step=1):
        step = sympy.expand(step)
        assert isinstance(step, sympy.Expr) or step > 0
        try:
            if start == 0 and end >= 2**63 - 1 and step == 1:
                return x
        except TypeError:
            pass

        sizevars = V.graph.sizevars
        new_size = list(x.get_size())

        start, end = cls.normalize_start_end(x, dim, start, end)

        new_size[dim] = FloorDiv(end - start + (step - 1), step)

        if is_storage_and_layout(x):
            # Fast path
            storage, old_layout = as_storage_and_layout(x)
            new_stride = list(old_layout.stride)
            new_stride[dim] = new_stride[dim] * step
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                new_stride,
                old_layout.offset + old_layout.stride[dim] * start,
            )
            return ReinterpretView(storage, new_layout)

        def reindex(index):
            assert len(index) == len(new_size), f"wrong ndim {index} {new_size}"
            index = list(index)
            index[dim] = index[dim] * step + start
            return index

        # redirect to a generic view
        return SliceView(x, size=new_size, reindex=reindex)
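
# The layout arithmetic of the SliceView fast path, checked against eager
# torch (standalone example): slicing dim 0 of a (10, 4) tensor with start=2,
# end=9, step=3 scales that stride by step and folds start into the offset.
def _example_slice_layout_arithmetic():
    base = torch.arange(40.0).reshape(10, 4)
    s = base[2:9:3]
    assert s.size(0) == (9 - 2 + (3 - 1)) // 3  # FloorDiv(end - start + (step - 1), step)
    assert s.stride() == (4 * 3, 1)             # new_stride[dim] = old_stride[dim] * step
    assert s.storage_offset() == 2 * 4          # offset += old_layout.stride[dim] * start
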
@dataclasses.dataclass
class BaseConstant(IRNode):
    dtype: torch.dtype
    device: torch.device

    def get_size(self):
        return ()

    def get_device(self):
        return self.device

    def get_origin_node(self):
        return None

    def mark_reuse(self, users):
        pass

    def has_exceeded_max_reads(self):
        return False

    def get_reads(self):
        return ()

    def is_extern(self):
        return False
@dataclasses.dataclass
class Constant(BaseConstant):
    value: Any
    dtype: torch.dtype
    device: torch.device

    def make_loader(self):
        def loader(index):
            return ops.constant(self.value, self.dtype)

        return loader

    def constant_to_device(self, device):
        return Constant(self.value, self.dtype, device)


@dataclasses.dataclass
class IndexingConstant(BaseConstant):
    index: Any
    dtype: torch.dtype
    device: torch.device

    def make_loader(self):
        def loader(index):
            return ops.index_expr(self.index, self.dtype)

        return loader

    def constant_to_device(self, device):
        return IndexingConstant(self.index, self.dtype, device)
def is_contiguous_strides_for_shape(stride, shape):
    return all(
        size == 1 or left == right
        for left, right, size in zip(
            stride, FlexibleLayout.contiguous_strides(shape), shape
        )
    )
@dataclasses.dataclass
class Layout(IRNode):
    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        size: List[Expr],
        stride: Optional[Sequence[Union[Expr, int]]],
        offset: Expr = Integer(0),
    ):
        assert stride is None or len(size) == len(
            stride
        ), f"size={size}, stride={stride}"
        self.device = device
        self.dtype = dtype
        assert all(isinstance(s, (Expr, int)) for s in size)
        self.size = size
        self._stride = stride
        self.offset = offset

    @property
    def stride(self):
        return self._stride

    def __str__(self):
        offset = ""
        if self.offset != 0:
            offset = f", offset={self.offset}"
        return (
            f"{type(self).__name__}('{self.device.type}', {self.dtype}, "
            f"size={self.size}, stride={self.stride}{offset})"
        )

    __repr__ = __str__
    def is_contiguous(self):
        return is_contiguous_strides_for_shape(self.stride, self.size)

    def is_channels_last_contiguous(self):
        ndim = len(self.size)
        if ndim not in [4, 5]:
            return False
        for left, right, size in zip(
            self.stride, make_channels_last_strides_for(self.size), self.size  # type: ignore[arg-type]
        ):
            if size != 1 and left != right:
                return False
        return True

    def is_transposed(self):
        for left, right, size in zip(
            self.stride,
            reversed(FlexibleLayout.contiguous_strides(self.size)),
            self.size,
        ):
            if size != 1 and left != right:
                return False
        return True
    def is_stride_ordered(self, order):
        assert len(self.stride) == len(order)

        # ignore dimensions of size 1, they don't affect layout
        non_1_indices = [
            i
            for i, dim in enumerate(self.size)
            if V.graph.sizevars.size_hint(dim, fallback=2) != 1
        ]

        stride = [self.stride[i] for i in non_1_indices]
        order = [order[i] for i in non_1_indices]

        def sorted_indices(arr):
            sorted_arr = sorted(arr)
            return [sorted_arr.index(element) for element in arr]

        # since we may have removed dimensions, need to re-sort & re-index order
        order = sorted_indices(order)

        # reorder the stride given order
        stride_ordered = [-1] * len(order)
        for i in range(len(order)):
            stride_ordered[order[i]] = V.graph.sizevars.size_hint(stride[i])
        # check if it is in ascending order
        for i in range(len(order) - 1):
            if stride_ordered[i] > stride_ordered[i + 1]:
                return False
        return True
    def is_channels_last_stride_ordered(self):
        # create channels_last order (NCHW, NCDHW; the C is the first order).
        order = [0] + list(reversed(range(1, len(self.stride) - 1)))
        order = [len(order)] + order
        return self.is_stride_ordered(order)
    def make_indexer(self):
        assert (
            FlexibleLayout.allow_indexing
        ), f"convert {type(self).__name__} to FixedLayout first"
        return self.as_fixed().make_indexer()

    def __eq__(self, other) -> bool:
        return (
            self.device == other.device
            and self.dtype == other.dtype
            and self.size == other.size
            and self.stride == other.stride
            and self.offset == other.offset
        )

    def storage_size(self) -> sympy.Expr:
        return compute_required_storage_length(self.size, self.stride, self.offset)  # type: ignore[arg-type, return-value]


class FixedLayout(Layout):
    """A Tensor layout we cannot change"""

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        size: Union[List[Expr], List[int]],
        stride: Optional[Sequence[Union[Expr, int]]] = None,
        offset: Union[Expr, int] = Integer(0),
    ):
        if stride is None:
            stride = FlexibleLayout.contiguous_strides(size)
        super().__init__(
            device,
            dtype,
            size,  # type: ignore[arg-type]
            stride,
            offset,  # type: ignore[arg-type]
        )

    def make_indexer(self):
        """A closure containing math to read a given element"""

        def indexer(index):
            assert len(index) == len(self.stride) == len(self.size)
            result = self.offset
            for idx, stride, sz in zip(index, self.stride, self.size):
                if sz != 1:
                    result = result + idx * stride
            return result

        return indexer
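

# Worked sketch of the indexer closure above (assumption: plain ints in
# place of sympy expressions). For size=[2, 3], stride=[3, 1], offset=5,
# indexer([i, j]) == 5 + 3 * i + j, so element (1, 2) reads flat storage
# position 10.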


class FlexibleLayout(Layout):
    """A Tensor layout we are allowed to change"""

    allow_indexing = False

    @staticmethod
    def contiguous_strides(sizes):
        if len(sizes) == 0:
            return []
        reversed_strides = [sympy.Integer(1)]
        for size in reversed(sizes[1:]):
            reversed_strides.append(size * reversed_strides[-1])
        return list(reversed(reversed_strides))
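
    # Worked sketch (assumption: plain ints in place of sympy.Integer):
    # contiguous_strides([2, 3, 4]) -> [12, 4, 1]; the last dimension gets
    # stride 1 and each earlier stride is the product of the sizes to its
    # right (row-major / C-contiguous order).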

    @staticmethod
    def fill_ordered(sizes, order):
        """
        Create a stride based on the order the dimensions should be filled in.

        In this format, channels last would be:
            [1, 3, 2, 0]
        """
        assert set(range(len(sizes))) == set(order)
        next_stride = sympy.Integer(1)
        strides = [None] * len(order)

        for i in order:
            strides[i] = next_stride
            next_stride = next_stride * sizes[i]
        return strides

    @staticmethod
    def stride_ordered(sizes, order):
        """
        Create a stride based on the sorted order of a permuted range.

        In this format, channels last would be:
            [3, 0, 2, 1]
        """
        assert set(range(len(sizes))) == set(order)
        fill_order = stride_order2fill_order(order)
        return FlexibleLayout.fill_ordered(sizes, fill_order)
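
    # Worked sketch relating the two formats above (assumption: concrete
    # sizes [2, 3, 4, 5] in NCHW order). fill_ordered(sizes, [1, 3, 2, 0])
    # fills C first, then W, H, N, producing strides [60, 1, 15, 3] -- the
    # channels-last layout; stride_ordered(sizes, [3, 0, 2, 1]) yields the
    # same strides expressed via per-dimension stride ranks.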

    @staticmethod
    def same_ordered(sizes, stride):
        """
        Create a stride that has the same stride order as given stride

        For example, if given stride is [1000, 1, 100, 10],
        the fill order should be [1, 3, 2, 0]
        """
        assert len(sizes) == len(stride)
        stride = [V.graph.sizevars.size_hint(x) for x in stride]
        fill_order = sorted(range(len(stride)), key=stride.__getitem__)
        return FlexibleLayout.fill_ordered(sizes, fill_order)
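
    # Worked sketch (assumption: concrete hints). For stride=[1000, 1, 100, 10]
    # sorting indices by ascending stride gives fill_order=[1, 3, 2, 0], so
    # same_ordered([2, 3, 4, 5], [1000, 1, 100, 10]) returns [60, 1, 15, 3],
    # preserving the original stride ordering with freshly computed values.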

    def as_stride_order(self, order):
        return FixedLayout(
            self.device, self.dtype, self.size,
            self.stride_ordered(self.size, order), self.offset,
        )

    def as_fill_order(self, order):
        return FixedLayout(
            self.device, self.dtype, self.size,
            self.fill_ordered(self.size, order), self.offset,
        )

    def as_same_order(self, stride):
        return FixedLayout(
            self.device, self.dtype, self.size,
            self.same_ordered(self.size, stride), self.offset,
        )

    def __init__(self, device, dtype, size, stride_order=None):
        if stride_order:
            strides = FlexibleLayout.fill_ordered(size, stride_order)
        else:
            strides = FlexibleLayout.contiguous_strides(size)
        super().__init__(device, dtype, size, strides)


class AliasedLayout(Layout):
    """Shares the same storage as another tensor"""

    def __init__(self, view: Union[BaseView, "TensorBox"]):
        layout = view.get_layout()
        super().__init__(
            layout.device,
            layout.dtype,
            layout.size,
            layout.stride,
        )
        self.view = view

    def make_indexer(self):
        return self.as_fixed().make_indexer()

    def maybe_guard_aligned(self):
        offset = self.view.get_layout().offset
        if offset == 0:
            return True
        from .compile_fx import ALIGNMENT

        return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)  # type: ignore[arg-type]


class NoneLayout(IRNode):
    # This is janky, I figured out what fields to populate by just running
    # the model I was interested in and adding properties/methods as needed.
    # This doesn't inherit from Layout because Layout assumes you have stuff
    # like sizes, but I don't really have anything here.
    #
    # If you have an ir.Node with NoneLayout, you probably need to setup
    # dependencies manually in scheduler

    def __init__(self, device):
        self.device = device
        self.size = [0]
        self.stride = [0]

    def storage_size(self):
        return 0

    def as_fixed(self):
        return self


class MutationLayout(Layout):
    def __init__(self, target: IRNode):
        super().__init__(
            target.get_device(),
            target.get_dtype(),
            target.get_size(),
            None,
        )
        self.target = target
        name = self.get_buffer().get_name()
        V.graph.mark_buffer_mutated(name)

    @Layout.stride.getter  # type: ignore[attr-defined]
    def stride(self):
        return self.real_layout().stride

    def storage_size(self) -> sympy.Expr:
        return self.real_layout().storage_size()

    def get_buffer(self) -> "Buffer":
        def unwrap_views(target):
            if isinstance(target, MutationLayout):
                return unwrap_views(target.target)
            if isinstance(target, BaseView):
                return unwrap_views(target.unwrap_view())
            if isinstance(target, MutableBox):
                return unwrap_views(target.data)
            return target

        result = unwrap_views(self.target)
        assert isinstance(result, Buffer), "MutationLayout must refer to a buffer"
        return result

    def real_layout(self):
        return self.get_buffer().layout

    @classmethod
    def realize_into(cls, src, dst, unsafe_alias=False):
        dst.realize()
        # NOTE: We must realize users of `dst` before we realize `src`, since
        # realization order determines scheduling order. Otherwise, src's
        # mutation would be scheduled before the existing users of dst!
        V.graph.mark_buffer_mutated(dst.get_name())

        if isinstance(src, TensorBox):
            src = src.data

        # We copy the contents of src into dst. In most cases this should
        # be fused into a single kernel by the scheduler.
        # NOTE: We cannot change src's layout to mutate dst directly as this
        # would alias src to dst, which is not correct as further mutations to
        # dst would affect users of src. However if there are no more users of
        # dst, we can alias src to dst.
        src.realize_hint()

        if not unsafe_alias:
            src = Pointwise.create(
                device=src.get_device(),
                dtype=src.get_dtype(),
                inner_fn=src.make_loader(),
                ranges=[
                    V.graph.sizevars.guard_equals(a, b)
                    for a, b in zip(src.get_size(), dst.get_size())
                ],
            ).data

        src.realize()
        assert isinstance(src.data.layout, FlexibleLayout)
        src.data.layout = MutationLayout(dst)
        return src.data

    def as_fixed(self):
        return self

    def make_indexer(self):
        return self.target.make_indexer()
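

# Sketch of the functionalized-mutation flow above (hypothetical names):
# for `dst.add_(src)` the lowering computes a fresh pointwise buffer
# `tmp = dst + src`, then MutationLayout.realize_into(tmp, dst) gives `tmp`
# a MutationLayout pointing at `dst`, so the scheduler writes tmp into dst's
# storage only after dst's existing readers have run.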


@dataclasses.dataclass
class Buffer(IRNode):
    # Name is sometimes None; e.g., ForceInPlace, where there isn't
    # a meaningful name
    name: Optional[str]
    layout: Layout

    # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly,
    # MultiOutput does NOT define this!

    def __post_init__(self):
        super().__post_init__()
        self.origin_node = None

    def make_indexer(self):
        return self.layout.make_indexer()

    def get_name(self) -> str:
        assert self.name
        return self.name

    def get_device(self):
        return self.layout.device

    def get_origin_node(self):
        return self.origin_node

    def get_dtype(self):
        return getattr(self.layout, "dtype", None)

    def get_size(self):
        return list(self.layout.size)

    def get_stride(self):
        return list(self.layout.stride)

    def get_offset(self):
        return self.layout.offset

    def get_layout(self):
        return self.layout

    def get_storage_numel(self):
        return self.get_numel()

    def is_extern(self):
        return False

    def freeze_layout(self):
        if not isinstance(self.layout, (MultiOutputLayout, AliasedLayout)):
            self.layout = self.layout.as_fixed()

    def freeze_layout_with_stride_order(self, order):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_stride_order(order)

    def freeze_layout_with_fill_order(self, order):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_fill_order(order)

    def freeze_layout_with_same_order(self, stride):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_same_order(stride)

    def is_zero_elements(self):
        return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))  # type: ignore[arg-type]

    def make_loader(self):
        # Loading from a zero-element buffer is a no-op
        if self.is_zero_elements():
            return partial(nop_loader_fn, dtype=self.get_dtype())

        def loader(index):
            indexer = self.layout.make_indexer()
            return ops.load(self.name, indexer(index))

        return loader

    def codegen_reference(self, writer=None):
        return self.get_name()

    def decide_layout(self):
        pass

    def get_alias_names(self):
        if isinstance(self.layout, AliasedLayout):
            return [self.layout.view.get_name()]
        return ()

    def get_mutation_names(self):
        if isinstance(self.layout, MutationLayout):
            return [self.layout.target.get_name()]
        return ()

    def get_read_writes(self):
        with patch.object(FlexibleLayout, "allow_indexing", True):
            return extract_read_writes(
                self.make_loader(),
                self.get_size(),
            )

    def get_reads(self):
        return self.get_read_writes().reads

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        """
        Returns the unbacked symbols which are defined by this IR node,
        because this is a data-dependent IR node, or item()
        """
        # So this is a little unusual.  In principle, you could imagine
        # defining a MultiOutputLayout buffer so that it DOES define
        # unbacked symints.  However, we can't easily tell what symints
        # such a buffer defines, because MultiOutputLayout doesn't actually
        # define any useful information about what it returns.
        #
        # An easier and better approach is to delay the symint allocation
        # to the MultiOutput IR nodes, which are when we actually extract
        # out the buffers and know what their sizes are.
        #
        # There are two subtleties here:
        #
        # 1. Suppose you have a kernel that produces out1: (i0,), out2: (i0,)
        #    Both of these actually count as defs!  The scheduler will just
        #    arbitrarily pick one of these as the canonical definer and
        #    ensure it stays live.  It's not a big deal if we pick the
        #    wrong one because tuple accesses are cheap, and all this means
        #    is we accidentally keep a MultiOutput node live when it wasn't
        #    strictly necessary.
        #
        # 2. Suppose you have a MultiOutput buffer whose size is (i0,), but
        #    the MultiOutputLayout buffer it is projecting from isn't actually
        #    dynamic; it has i0 as one of the arguments.  We cannot tell this
        #    directly from MultiOutput, we have to look at the input buffer's
        #    uses to work this out.  No big deal.
        if isinstance(self.layout, (NoneLayout, MultiOutputLayout)):
            return set()

        # This kernel defines all unbacked symbols... that it didn't get in as
        # arguments!
        defs = (
            free_unbacked_symbols(self.get_size())
            | free_unbacked_symbols(self.get_stride())
            | free_unbacked_symbols(self.get_offset())
        )
        return defs - self.get_unbacked_symbol_uses()

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        """
        Returns the unbacked symbols which are required to be in scope in
        order to successfully perform codegen for this buffer.  For example,
        a buffer that corresponds to an extern kernel call that takes i0 as
        an argument would return {i0} here.  This is used to generate necessary
        dependencies that ensure we actually bind i0 in codegen before you
        try to use it.

        Note that this is NOT transitive; in particular, if this buffer takes
        in as input another buffer with dynamic shape (e.g., (i0,)), we will
        not report it here, because you will already have a dependency
        on that buffer, which will eventually have a dependency on i0 if
        necessary.
        """
        return set()

    def codegen_unbacked_symbol_defs(self, wrapper):
        # NB: If it is possible for other ir node types to return unbacked
        # symints, you need to make sure their codegen calls this method.
        # Don't forget to update get_unbacked_symbol_defs too.
        symbols_to_define = self.get_unbacked_symbol_defs()
        for i, s in enumerate(self.get_size()):
            if s in symbols_to_define:
                wrapper.writeline(
                    f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.size({i}){wrapper.ending}"
                )
                symbols_to_define.remove(s)
        for i, s in enumerate(self.get_stride()):
            if s in symbols_to_define:
                wrapper.writeline(
                    f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.stride({i}){wrapper.ending}"
                )
                symbols_to_define.remove(s)
        if (s := self.get_offset()) in symbols_to_define:
            wrapper.writeline(
                f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.storage_offset(){wrapper.ending}"
            )
            symbols_to_define.remove(s)
        assert (
            not symbols_to_define
        ), f"unbacked symint {s} not written out, check comment above"

    def get_workspace_size(self):
        """
        Gets extra global memory size needed by this buffer.
        Some algorithms (e.g. group gemm) may require extra global memory in the generated code.
        """
        return 0

    def should_allocate(self):
        # Returns False by default.
        return False


class InputBuffer(Buffer):
    pass


class ConstantBuffer(InputBuffer):
    override_device: Optional[torch.device] = None

    def make_loader(self):
        def loader(index):
            indexer = self.layout.make_indexer()
            return ops.load(
                V.graph.constant_name(self.get_name(), self.override_device),
                indexer(index),
            )

        return loader

    def constant_to_device(self, device):
        return ConstantBuffer(
            V.graph.constant_name(self.get_name(), device), self.layout
        )


class NoneAsConstantBuffer(IRNode):
    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        return set()

    def codegen_reference(self, writer=None):
        return V.graph.wrapper_code.none_str


class ShapeAsConstantBuffer(IRNode):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        return free_unbacked_symbols(self.shape)

    def codegen_reference(self, writer=None):
        return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.shape))


@dataclasses.dataclass
class ComputedBuffer(Buffer):
    data: Loops

    def get_computed_buffer_name(self):
        """
        Returns self.name if it exists, otherwise returns the name of the data node if that exists.
        If neither exist, returns None.
        """
        if self.name is not None:
            return self.name
        if hasattr(self.data, "name"):
            return self.data.name
        return None

    @cache_on_self
    def num_reads(self):
        return len(self.get_read_writes().reads)

    def get_read_writes(self):
        with patch.object(FlexibleLayout, "allow_indexing", True):
            if self.data.get_reduction_type():
                return extract_read_writes(
                    self.get_store_function(),
                    self.data.get_pointwise_size(),
                    self.data.get_reduction_size(),
                )
            else:
                return extract_read_writes(
                    self.get_store_function(),
                    self.data.get_size(),
                )

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        # Ordinarily, we'd like to just peek at the arguments list,
        # but ComputedBuffers have no argument list.
        #
        # Morally, this logic needs to be synchronized with the
        # KernelArgs.size calls, which are responsible for making symbols make
        # their way as kernel arguments (and it is precisely passing in one of
        # those symbols that establishes a dependency).  However, we haven't
        # started codegen yet so we can't directly reuse that logic.
        #
        # For now, I'm just yoloing with the size of the buffer.  Not sure if
        # it is enough.
        #
        # One thing you might wonder is if this is enough for a ComputedBuffer
        # denoting a reduction over i0.  Empirically, it is enough, but for an
        # unusual reason: we only need accurate dependencies for item() call,
        # but it's impossible to end up with a reduction over i0 from an
        # item() call without a regular non-reduction buffer first.
        return (
            free_unbacked_symbols(self.get_size())
            | free_unbacked_symbols(self.get_stride())
            | free_unbacked_symbols(self.get_offset())
            | self.data.get_unbacked_symbol_uses()
        )

    def make_loader(self):
        # Inline constants and index_expressions
        if (
            hasattr(self.data, "make_loader")
            and self.name not in V.graph.mutated_buffers
            and self.num_reads() == 0
        ):
            return self.data.make_loader()
        return super().make_loader()

    def get_store_function(self):
        indexer = self.layout.as_fixed().make_indexer()
        if isinstance(self.data, (Reduction, Scan)):
            return partial(self.data.store_reduction, self.name, indexer)
        else:
            assert isinstance(self.data, Pointwise)
            return partial(self.data.store_output, self.name, indexer)

    def get_fill_order(self):
        """
        If our layout is still flexible, try to determine the stride order based on stride orders of reads.

        TODO(jansel): A better algorithm here would look at downstream consumers of this
                      value and try to do global graph-level layout optimization.
                      This is also something just begging to be autotuned.
        """
        if isinstance(self.layout, FlexibleLayout):
            (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(
                self.data.get_pointwise_size(), self.data.get_reduction_size()
            )
            reads = self.get_read_writes().reads
            reads_bufs = [
                V.graph.name_to_buffer[r.name]
                if r.name in V.graph.name_to_buffer.keys()
                else None
                for r in reads
            ]
            # only consider reads to buffer of same size
            # ignore StarDeps because they don't contribute stride information
            assert all(
                isinstance(r, (dependencies.StarDep, dependencies.MemoryDep))
                for r in reads
            )
            reads = [
                sympy_subs(
                    r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0}
                )
                for r in reads
                if isinstance(r, dependencies.MemoryDep)
            ]

            if reads:
                if isinstance(self.data, Scan):
                    indices = self.data.reindex(index_vars, reduction_vars)
                else:
                    indices = index_vars
                stride_lengths = [
                    V.graph.sizevars.stride_hints(expr, indices) for expr in reads  # type: ignore[arg-type]
                ]
                from .scheduler import pick_loop_order

                return pick_loop_order(stride_lengths, self.get_size())

        return None

    def decide_layout(self):
        if isinstance(self.layout, FlexibleLayout):
            order = self.get_fill_order()
            if order:
                self.freeze_layout_with_fill_order(order)
            else:
                self.freeze_layout()

    def simplify_and_reorder(self):
        """
        This is a main place where we do loop transformations in a
        backend-agnostic way.

        Here we:
            1) Remove any 1 dimensions
            2) Fuse contiguous dimensions together
            3) Reorder dimensions based on stride orders
        """
        args, var_ranges = dependencies.index_vars_squeeze(
            self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q"
        )
        with patch.object(ConstantBuffer, "override_device", self.get_device()):
            body = LoopBody(
                self.get_store_function(),
                (args if self.get_reduction_type() else args[:1]),
                var_ranges,
            )
        index_formulas = [*body.indexing_exprs.values()]
        reads_bufs = [
            V.graph.name_to_buffer[reads_name]
            if reads_name in V.graph.name_to_buffer.keys()
            else None
            for reads_name in body.reads_name2expr.keys()
        ]
        memory_addrs = [
            *body.reads_name2expr.values(),
            *body.writes_name2expr.values(),
        ]
        index_vars = []
        reduce_vars: List[Any] = []
        index_size = []
        reduce_size = []
        for v, s in var_ranges.items():
            if v in args[0]:
                assert not reduce_vars
                index_vars.append(v)
                index_size.append(s)
            else:
                assert v in args[1]
                reduce_vars.append(v)
                reduce_size.append(s)

        # the reordering_reindex in reads' simplify_reorder_and_tile
        reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs)
        for i, reads_buf in enumerate(reads_bufs):
            if isinstance(reads_buf, ComputedBuffer) and hasattr(
                reads_buf, "iter_reordering_reindex"
            ):
                reordering_reindex[i] = reads_buf.iter_reordering_reindex  # type: ignore[has-type]

        def simplify_and_reorder(x_vars, support_vars, sizes, reordering_reindex=None):
            sizes, reindex0, reindex1 = self._apply_loop_reordering(
                x_vars, support_vars, sizes, memory_addrs, reordering_reindex
            )
            # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1]
            x_vars = reindex0(x_vars)
            sizes, reindex2, prune = V.graph.sizevars._simplify_loops(
                x_vars,
                sizes,
                index_prevent_reordering(index_formulas, x_vars, sizes),
            )
            x_vars = prune(x_vars)
            # sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas)
            # x_vars = prune(x_vars)
            # sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs)
            reindex = fuse_reindexing(reindex1, reindex2)
            return sizes, reindex, reindex1

        support_vars = index_vars + reduce_vars
        iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder(
            index_vars, support_vars, index_size, reordering_reindex
        )
        reduce_ranges, reduce_reindex, _ = simplify_and_reorder(
            reduce_vars, support_vars, reduce_size
        )

        # remember the reordering if there was no loop collapse.
        if len(iter_ranges) == len(index_vars):
            self.iter_reordering_reindex = iter_reordering_reindex
        # retrace the loop body with simplification and reordering applied
        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
            iter_ranges, reduce_ranges, prefix="z"
        )
        body = LoopBody(
            body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges
        )
        return (iter_ranges, reduce_ranges), body
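
    # Worked sketch of the three transforms above (assumption: concrete
    # sizes). A pointwise kernel over sizes [8, 1, 16, 2] first drops the
    # size-1 dim ([8, 16, 2]); if the indexing is dense the remaining dims
    # can fuse into a single loop of 256; otherwise the loops are reordered
    # so the smallest-stride read varies fastest.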

    @staticmethod
    def _apply_loop_reordering(
        index_vars,
        support_vars,
        sizes,
        memory_addrs,
        reordering_reindex=None,
        priority_idx=None,
    ):
        """
        Shuffle the order of loops around to hopefully improve performance.
        """
        from .scheduler import pick_loop_order

        if priority_idx is None:
            priority_idx = []

        try:
            strides = [
                V.graph.sizevars.stride_hints(expr, index_vars, support_vars)
                for expr in memory_addrs
            ]
            assert len(strides) == len(memory_addrs) and len(strides[0]) == len(
                index_vars
            )
            # consider both layout(strides) and reordering(reordering_reindex)
            if reordering_reindex is not None:
                for i in range(len(memory_addrs)):
                    try:
                        strides[i] = reordering_reindex[i](strides[i])
                    # if len(order) != len(strides), do not reorder
                    except AssertionError:
                        pass
            order = list(reversed(pick_loop_order(strides, sizes, priority_idx)))
        except Exception:
            if config.debug:
                log.warning(
                    "Did not simplify complex index:\n%s\n%s",
                    dict(zip(index_vars, sizes)),
                    memory_addrs,
                )
            order = list(range(len(sizes)))
        sizes = [sizes[i] for i in order]
        return sizes, same_reorder(order), inverse_reorder(order)

    def get_reduction_size(self):
        return self.data.get_reduction_size()

    def get_reduction_type(self):
        return self.data.get_reduction_type()

    def is_zero_elements(self):
        return self.data.is_zero_elements()

    def should_allocate(self):
        return True

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        return self.data.constant_to_device(device)


class TemplateBuffer(Buffer):
    """
    Represents a Triton (in the future other type) of template operator
    that we can fuse an epilogue onto.
    """

    def __init__(self, layout, inputs, make_kernel_render):
        super().__init__(name=None, layout=layout)
        self.inputs = InputsKernel.unwrap_storage(inputs)
        self.make_kernel_render = make_kernel_render
        self.name = V.graph.register_buffer(self)

    def get_read_writes(self):
        return self.normalized_read_writes()

    def normalized_read_writes(self):
        name = self.get_name()
        indexer = self.layout.make_indexer()

        def dummy(index, rindex):
            assert len(rindex) == 0
            return ops.store(name, indexer(index), "fake")

        deps = dependencies.extract_read_writes(
            dummy, self.get_size(), (), normalize=True
        )
        deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}
        return deps

    def get_reduction_size(self):
        return 1

    def get_reduction_type(self):
        return None

    def should_allocate(self):
        return True

    def simplify_and_reorder(self):
        return (
            (
                self.get_size(),
                (),
            ),
            None,
        )


class TritonTemplateBuffer(TemplateBuffer):
    pass


class CUDATemplateBuffer(TemplateBuffer):
    def __init__(
        self,
        layout,
        inputs,
        make_kernel_render,
        workspace_size: int,
        template: "CUDATemplate",  # type: ignore[name-defined]  # noqa: F821
    ):
        super().__init__(layout, inputs, make_kernel_render)
        # Global memory (in bytes) needed for this template.
        self.workspace_size = workspace_size
        self.template = template

    def get_workspace_size(self):
        return self.workspace_size if self.workspace_size is not None else 0


@dataclasses.dataclass
class InputsKernel(Buffer):
    inputs: List[Buffer]

    def get_read_writes_input(self, x):
        return dependencies.StarDep(x.get_name())

    def get_read_writes(self):
        star_dep = []
        for input in self.inputs:
            if isinstance(input, list):
                star_dep.extend([self.get_read_writes_input(x) for x in input])
            else:
                star_dep.append(self.get_read_writes_input(input))

        return dependencies.ReadWrites(
            set(star_dep),
            {dependencies.StarDep(self.get_name())},
            set(),
            [],
            None,
            op_counts=collections.Counter(),
        )

    @classmethod
    def unwrap_storage_for_input(cls, x):
        if isinstance(x, TensorBox):
            x = x.data
        if isinstance(x, StorageBox):
            x = x.data
        if isinstance(x, BaseView) and not isinstance(x, ReinterpretView):
            x = ExternKernel.realize_input(x)
        if isinstance(x, TensorBox):
            # when converting to ReinterpretView fails in the
            # realize_input call above, the result will be wrapped
            # into TensorBox / StorageBox pair as a result of the
            # cls.copy_input call; so we should unwrap recursively
            return cls.unwrap_storage_for_input(x)
        assert isinstance(x, (Buffer, ReinterpretView)), x
        return x

    @staticmethod
    def unwrap_storage(inputs):
        inputs_new = []
        for x in inputs:
            if isinstance(x, list):
                x = [InputsKernel.unwrap_storage_for_input(i) for i in x]
            else:
                x = InputsKernel.unwrap_storage_for_input(x)
            inputs_new.append(x)
        return inputs_new

    def is_extern(self):
        return True


class NopKernel(InputsKernel):
    def is_no_op(self):
        return True


class ConcatKernel(NopKernel):
    """
    There isn't actually a real kernel for concat, we just change the
    storage for the upstream data.
    """

    @classmethod
    def create(cls, inputs, dim):
        device = inputs[0].get_device()
        dtype = inputs[0].get_dtype()
        new_size = list(inputs[0].get_size())
        offsets_start = [0]
        offsets_end = [new_size[dim]]
        assert 0 <= dim < len(new_size)
        for i in range(1, len(inputs)):
            input_size = inputs[i].get_size()
            offsets_start.append(new_size[dim])
            assert len(input_size) == len(new_size)
            assert inputs[i].get_dtype() == dtype
            assert inputs[i].get_device() == device
            for j in range(len(new_size)):
                if j == dim:
                    new_size[j] = new_size[j] + input_size[j]
                else:
                    new_size[j] = V.graph.sizevars.guard_equals(
                        new_size[j], input_size[j]
                    )
            offsets_end.append(new_size[dim])

        output_stride = FlexibleLayout.contiguous_strides(new_size)
        # If any of the inputs is in CL format, use CL format for the output
        for i in range(len(inputs)):
            x = inputs[i]
            if is_storage_and_layout(x):
                layout = x.get_layout()
                if (
                    isinstance(layout, FixedLayout)
                    and layout.is_channels_last_contiguous()
                ):
                    # use CL stride for the output
                    output_stride = make_channels_last_strides_for(new_size)
                    break

        concat_kernel = ConcatKernel(
            name=None,
            layout=FixedLayout(
                device=device,
                dtype=dtype,
                size=new_size,
                stride=output_stride,
            ),
            inputs=[],
        )
        kernel = StorageBox(concat_kernel)
        buffer_names = []
        for i in range(len(inputs)):
            input_buffer = cls.realize_into(
                inputs[i],
                SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]),
            )
            concat_kernel.inputs.append(input_buffer)

            if isinstance(inputs[i].data, BaseView):
                input_unwrapped = inputs[i].data.unwrap_view()
            else:
                input_unwrapped = inputs[i].data

            if (
                input_unwrapped.is_input_buffer()
                and inputs[i].get_device().type == "cuda"
                and not is_dynamic(input_buffer)
            ):
                buffer_names.append(input_buffer.get_name())

        if len(buffer_names) > 1:
            V.graph.register_list(buffer_names)

        concat_kernel.name = V.graph.register_buffer(concat_kernel)
        concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs)

        return kernel
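
    # Worked sketch of the offset bookkeeping in create() (assumption:
    # concrete sizes). Concatenating sizes [2, 3], [2, 5], [2, 1] on dim=1
    # gives new_size=[2, 9], offsets_start=[0, 3, 8], offsets_end=[3, 8, 9];
    # each input is then realized into the matching SliceView of the output.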

    @classmethod
    def can_realize_into_without_copy(cls, src):
        if isinstance(src, TensorBox):
            # unwrap a TensorBox
            return cls.can_realize_into_without_copy(src.data)

        return isinstance(src.data.layout, FlexibleLayout) and not isinstance(
            src.data, ExternKernelAlloc
        )

    @classmethod
    def realize_into(cls, src, dst):
        # Attempt to turn this into a ReinterpretView rather than assert.
        # This has concessions around layout, as as_storage_and_layout
        # can cause us to go from flexible to fixed layout.
        if not isinstance(dst, ReinterpretView):
            if is_storage_and_layout(dst):
                storage, layout = as_storage_and_layout(dst)
                dst = ReinterpretView(storage, layout)
        assert isinstance(dst, ReinterpretView), dst
        if isinstance(src, TensorBox):
            # unwrap a TensorBox
            return cls.realize_into(src.data, dst)
        if isinstance(src, StorageBox):
            src.realize()
            # ExternKernelAlloc has specific requirements for output layout, should create a copy
            assert hasattr(src.data, "layout")
            if cls.can_realize_into_without_copy(src):
                src.data.layout = AliasedLayout(dst)
                return src.data
        # introduce a copy
        pw = Pointwise.create(
            device=src.get_device(),
            dtype=src.get_dtype(),
            inner_fn=src.make_loader(),
            ranges=[
                V.graph.sizevars.guard_equals(a, b)
                for a, b in zip(src.get_size(), dst.get_size())
            ],
        )
        return cls.realize_into(pw, dst)

    def should_allocate(self):
        return True


@dataclasses.dataclass
class ExternKernel(InputsKernel):
    constant_args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
    output_view: Optional[ReinterpretView] = None
    python_kernel_name: Optional[str] = None
    cpp_kernel_name: Optional[str] = None
    # FIXME: in some cases we still need to explicitly pass in ordered_kwargs_for_cpp_kernel
    # We shouldn't need to do this since the information can be retrieved from op_overload._schema.
    ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field(
        default_factory=list
    )
    op_overload: Optional[
        Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
    ] = None
    arg_properties: Optional[List[Dict[str, Any]]] = None
    kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None

    def __init__(
        self,
        name,
        layout,
        inputs,
        constant_args=(),
        kwargs=None,
        output_view=None,
        python_kernel_name=None,
        cpp_kernel_name=None,
        ordered_kwargs_for_cpp_kernel=(),
        op_overload=None,
    ):
        super().__init__(
            name,
            layout,
            inputs,
        )
        self.constant_args = constant_args
        self.kwargs = kwargs if kwargs else {}
        self.output_view = output_view
        self.python_kernel_name = python_kernel_name
        self.cpp_kernel_name = cpp_kernel_name
        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
        self.op_overload = op_overload
        self.collect_arg_kwarg_properties()

    def collect_arg_kwarg_properties(self):
        # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional
        # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen
        if (
            isinstance(self.op_overload, torch._ops.OpOverload)
            and not self.ordered_kwargs_for_cpp_kernel
        ):
            self.ordered_kwargs_for_cpp_kernel = [
                x.name for x in self.op_overload._schema.arguments if x.kwarg_only
            ]
        self.arg_properties = (
            [
                {
                    "name": x.name,
                    "type": x.real_type,
                    "default_value": x.default_value,
                }
                for x in self.op_overload._schema.arguments
                if not x.kwarg_only
            ]
            if isinstance(self.op_overload, torch._ops.OpOverload)
            else [{} for i in range(len(self.inputs))]
        )
        self.kwarg_properties = (
            {
                x.name: {"type": x.real_type, "default_value": x.default_value}
                for x in self.op_overload._schema.arguments
                if x.kwarg_only
            }
            if isinstance(self.op_overload, torch._ops.OpOverload)
            else {}
        )

    def decide_layout(self):
        if isinstance(self.layout, FlexibleLayout):
            self.apply_constraint()
            self.freeze_layout()

    def codegen_comment(self, wrapper):
        origin_str, detailed_origin_str = get_kernel_metadata(self, wrapper)
        if origin_str:
            wrapper.writeline(origin_str)

    def codegen(self, wrapper):
        raise NotImplementedError()

    def get_kernel_name(self):
        return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name

    @classmethod
    def copy_input(cls, x):
        pw = Pointwise.create(
            device=x.get_device(),
            dtype=x.get_dtype(),
            inner_fn=x.make_loader(),
            ranges=x.get_size(),
            origin_node=x.get_origin_node(),
            traceback=x.get_traceback(),
        )
        pw.realize()
        return pw

    @classmethod
    def process_kernel(cls, kernel, *args, **kwargs):
        binded_args = {"args": args, "kwargs": kwargs}

        args_flat, args_spec = pytree.tree_flatten(binded_args)

        is_arg_tensor = []
        tensor_args = []
        non_tensor_args: List[Any] = []
        for arg in args_flat:
            is_arg_tensor.append(isinstance(arg, IRNode))
            if is_arg_tensor[-1]:
                tensor_args.append(arg)
            else:
                if isinstance(arg, sympy.Expr):
                    arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
                non_tensor_args.append(arg)

        def unflatten_args(new_tensor_args, new_non_tensor_args):
            result = []
            it_tensors = iter(new_tensor_args)
            it_non_tensors = iter(new_non_tensor_args)
            for is_tensor in is_arg_tensor:
                if is_tensor:
                    result.append(next(it_tensors))
                else:
                    result.append(next(it_non_tensors))
            r = pytree.tree_unflatten(result, args_spec)
            return r.get("args", []), r.get("kwargs", {})

        tensor_args = [cls.realize_input(x) for x in tensor_args]

        # freeze layout otherwise our output stride calculation might
        # get different strides.
        for x in tensor_args:
            if is_storage_and_layout(x):
                as_storage_and_layout(x, freeze=True)

        # We don't have generic shape formulas, so just burn in the
        # shapes and run an example input.
        # TODO(jansel): replace this with dynamic shape formulas
        example_args = []

        # We need to retain the constant values of fake tensors that we originally
        # propagated the graph with, because for some operators running without a
        # constant would trigger an error / DataDependentException
        for x in tensor_args:
            if x.get_name() in V.graph.constants:
                example_args.append(V.graph.constants[x.get_name()])
            else:
                example_args.append(ir_node_to_tensor(x, guard_shape=True))

        new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
        example_output = kernel(*new_args, **new_kwargs)

        example_out_li = (
            [example_output]
            if not isinstance(example_output, (list, tuple))
            else example_output
        )
        for t in example_out_li:
            if isinstance(t, torch.Tensor) and t.is_sparse:
                msg = "sparsity not handled. Please file issue for sparse inference weights."
                if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
                    msg = f"{msg} Found from : \n {stack_trace}"
                V.graph.disable_cudagraphs_reason = msg

        # TODO: Unconditionally do this, not just when example_output has
        # unbacked symbols
        if maybe_free_unbacked_symbols(example_output):
            example_output = V.graph.current_node.meta["val"]

        return example_output, tensor_args, non_tensor_args, unflatten_args
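
    # Minimal sketch of the flatten/route/unflatten pattern used in
    # process_kernel (assumption: plain values in place of IRNodes; pytree is
    # the torch.utils._pytree module imported at the top of this file):
    #
    #   flat, spec = pytree.tree_flatten({"args": (1, [2, 3]), "kwargs": {}})
    #   # flat == [1, 2, 3]; IRNode leaves go down one path, the rest down
    #   # another, and pytree.tree_unflatten(flat, spec) rebuilds the
    #   # original {"args": ..., "kwargs": ...} structure.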

    @classmethod
    def convert_to_reinterpret_view(cls, x):
        """
        In order to pass this to an extern kernel we need a
        ReinterpretView not a View.  This allows us to avoid some
        unneeded copies.
        """
        assert isinstance(x, BaseView)
        if isinstance(x, ReinterpretView):
            return x

        # NOTE: Don't use extract_read_writes here as it fails when
        # make_loader() inlines the computation
        x.unwrap_view().freeze_layout()
        index_args, var_ranges = dependencies.index_vars_squeeze(
            x.get_size(), prefix="r"
        )
        range_vars = index_args[0]
        index = x.make_indexer()(range_vars)

        index = V.graph.sizevars.simplify_with_ranges(index, var_ranges)
        strides = V.graph.sizevars.stride_vars(index, range_vars)
        offset = V.graph.sizevars.offset_var(index, range_vars)
        expected = sympy_dot(range_vars, strides) + offset

        if index != expected:
            log.debug(
                "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s",
                strides,
                offset,
                index,
            )
            raise NotImplementedError()

        return ReinterpretView(
            data=x.data,
            layout=FixedLayout(
                device=x.get_device(),
                dtype=x.get_dtype(),
                size=x.get_size(),
                stride=strides,
                offset=offset,
            ),
        )

    @classmethod
    def realize_input(cls, x):
        if x is None:
            return NoneAsConstantBuffer()
        if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)):
            return ShapeAsConstantBuffer(x)
        if isinstance(x, Constant):
            return V.graph.add_tensor_constant(
                torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
            )
        if isinstance(x, ConstantBuffer):
            return x
        if isinstance(x, TensorBox):
            return cls.realize_input(x.data)
        if isinstance(x, ReinterpretView):
            return ReinterpretView(cls.realize_input(x.data), x.get_layout())
        if isinstance(x, BaseView):
            x.realize()
            if is_storage_and_layout(x.unwrap_view()):
                try:
                    return cls.convert_to_reinterpret_view(x)
                except NotImplementedError:
                    pass
        if isinstance(x, StorageBox):
            # TODO(jansel): impose layout preference on realized buffer
            x.realize()
            return x
        return cls.copy_input(x)

    @classmethod
    def require_stride1(cls, x):
        if is_storage_and_layout(x):
            if len(x.get_stride()) == 0:
                return x
            for stride in x.get_stride():
                if stride == 1:
                    return x
        return cls.copy_input(x)

    @classmethod
    def require_stride_order(cls, x, order):
        if x.get_numel() == 0:  # Layout doesn't matter
            return x

        # require x to have the layout as strided_ordered as order
        if is_storage_and_layout(x):
            while isinstance(x.get_layout(), AliasedLayout):
                x = x.get_layout().view
            if isinstance(x.get_layout(), FlexibleLayout):
                # fix flexiblelayout to be FixedLayout with stride_order
                as_storage_and_layout(
                    x, freeze=True, want_contiguous=False, stride_order=order
                )
                return x
            elif isinstance(
                x.get_layout(), FixedLayout
            ) and x.get_layout().is_stride_ordered(order):
                return x
            elif isinstance(x.get_layout(), MutationLayout):
                if isinstance(x.get_layout().real_layout(), FlexibleLayout):
                    raise AssertionError(
                        "the MutationLayout's real layout shouldn't be FlexibleLayout"
                    )
                elif isinstance(
                    x.get_layout().real_layout(), FixedLayout
                ) and x.get_layout().real_layout().is_stride_ordered(order):
                    return x

        # TODO - Storage to InputBuffer
        if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order):
            return x
        if (
            isinstance(x, TensorBox)
            and isinstance(x.data, BaseView)
            and not isinstance(x.data, ReinterpretView)
            and is_storage_and_layout(x.unwrap_view())
            and not isinstance(x.unwrap_view().data, ExternKernelAlloc)
        ):
            try:
                x.data = cls.convert_to_reinterpret_view(x.data)
                return cls.require_stride_order(x, order)
            except NotImplementedError:
                pass
        x = cls.copy_input(x)
        as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order)
        assert is_stride_order_storage_and_layout(x, order)
        return x

    @classmethod
    def require_channels_last(cls, x):
        return cls.require_stride_order(x, NHWC_STRIDE_ORDER)

    @classmethod
    def require_contiguous(cls, x):
        return cls.require_stride_order(x, list(reversed(range(len(x.get_size())))))

    def apply_constraint(self):
        pass

    def codegen_const_args(self):
        return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args)

    def codegen_args(self):
        args = []
        for i, x in enumerate(self.inputs):
            if isinstance(x, list):
                names = [i.codegen_reference() for i in x]
                codegen_reference = f'[{", ".join(names)}]'
                args.append(codegen_reference)
            else:
                if V.graph.cpp_wrapper:
                    assert self.arg_properties and i < len(
                        self.arg_properties
                    ), "Invalid arg_properties accessing"
                    type_ = self.arg_properties[i].get("type")
                    args.append(
                        V.graph.wrapper_code.val_to_cpp_arg_str(  # type: ignore[arg-type]
                            type_, x, self.is_legacy_abi_kernel()
                        )
                    )
                else:
                    args.append(x.codegen_reference())
        args.extend(self.codegen_const_args())
        return args

    def get_kwargs_value(self, arg_name):
        if arg_name in self.kwargs:
            return self.kwargs.get(arg_name)
        if self.kwarg_properties and self.kwarg_properties.get(arg_name):
            return self.kwarg_properties.get(arg_name).get("default_value")  # type: ignore[union-attr]
        else:
            raise AssertionError(f"{arg_name} not in self.kwarg_properties")

    def is_legacy_abi_kernel(self):
        return False

    def codegen_kwargs(self):
        kwargs = []
        if V.graph.cpp_wrapper:
            for arg_name in self.ordered_kwargs_for_cpp_kernel:
                v = self.get_kwargs_value(arg_name)
                if isinstance(v, sympy.Expr):
                    kwargs.append(v)
                else:
                    type_ = (
                        self.kwarg_properties.get(arg_name).get("type")  # type: ignore[union-attr]
                        if self.kwarg_properties and arg_name in self.kwarg_properties
                        else None
                    )
                    kwargs.append(
                        V.graph.wrapper_code.val_to_cpp_arg_str(  # type: ignore[arg-type]
                            type_, v, self.is_legacy_abi_kernel()
                        )
                    )
        else:
            kwargs = [
                f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"  # type: ignore[misc]
                for k, v in self.kwargs.items()
            ]
        return kwargs

    def codegen_size_asserts(self, wrapper):
        if config.size_asserts and not V.graph.cpp_wrapper:
            size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
            stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
            wrapper.writeline(
                f"assert_size_stride({self.get_name()}, {size}, {stride})"
            )

    def get_group_stride(self):
        """
        get output sizes and strides, for template_codegen
        """
        _size = self.get_size()
        _stride = self.get_stride()
        # iter_ranges = _size of output tensor, reduce_range = [] because no reduction
        return [_size, []], _stride

    def canonicalize(self):
        """
        Manually get canonicalization of the output index
        """
        # manually generate index formula for conv
        sizevars = V.graph.sizevars
        sizes = self.get_size()
        strides = self.get_stride()
        strides = [sizevars.size_hint(x) for x in strides]
        index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))]
        # reorder index vars according to stride
        index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
        lookup = {pos: idx for idx, pos in enumerate(index_order)}
        order = [lookup[i] for i in range(len(lookup))]
        index_vars = [index_vars[i] for i in order]
        indexer = self.make_indexer()
        index = indexer(index_vars)

        new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
            index_vars, sizes, [index]
        )

        # assign new variables each dimension to deal with numbering mismatches
        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
        _, add_var = var_builder("c")
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))

        index = sympy_subs(sympy.expand(index), replacement)  # type: ignore[arg-type]
        return index, tuple(new_sizes)
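
    # Worked sketch of the stride-based reordering in canonicalize()
    # (assumption: concrete stride hints). For strides [3, 12, 1]:
    # index_order = [1, 0, 2] (descending stride), lookup = {1: 0, 0: 1, 2: 2},
    # order = [1, 0, 2], so the largest-stride dimension iterates outermost.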

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        # NB: It's not necessary to check regular inputs as we automatically
        # have dependencies on them
        r = set()
        for arg in self.constant_args:
            r |= maybe_free_unbacked_symbols(arg)
        for arg in self.kwargs.values():
            r |= maybe_free_unbacked_symbols(arg)
        return r

    def __str__(self):
        kernel_name = getattr(self, "python_kernel_name", None)
        lines = [
            f"python_kernel_name={kernel_name!r}",
        ]
        lines += [
            f"{field.name}={getattr(self, field.name)}"
            for field in dataclasses.fields(self)
        ]
        lines.append(f"origin_node={self.origin_node!r}")
        return self.str_helper(lines)

    __repr__ = __str__


@dataclasses.dataclass
class ExternKernelOut(ExternKernel):
    def codegen(self, wrapper):
        self.codegen_comment(wrapper)
        args = [*self.codegen_args(), *self.codegen_kwargs()]
        wrapper.generate_extern_kernel_out(
            self.output_view,
            self.codegen_reference(),
            args,
            self.get_kernel_name(),
        )

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kwargs=None,
        output_view=None,
        python_kernel_name=None,
        cpp_kernel_name=None,
        ordered_kwargs_for_cpp_kernel=(),
        op_overload=None,
    ):
        super().__init__(
            None,
            layout,
            self.unwrap_storage(inputs),
            constant_args,
            kwargs or {},
            output_view,
            python_kernel_name,
            cpp_kernel_name,
            ordered_kwargs_for_cpp_kernel,
            op_overload,
        )
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return True


class RandomSeeds(ExternKernelOut):
    def __init__(self, count: int, device: torch.device):
        limits = torch.iinfo(torch.int64)
        super().__init__(
            layout=FixedLayout(
                device=device,
                dtype=torch.int64,
                size=[count],
            ),
            inputs=[],
            constant_args=[limits.min, limits.max, [count]],
            python_kernel_name="aten.randint.low_out",
            cpp_kernel_name="at::randint_out",
        )


class ExternKernelAlloc(ExternKernel):
    def codegen(self, wrapper):
        self.codegen_comment(wrapper)
        args = [*self.codegen_args(), *self.codegen_kwargs()]
        V.graph.wrapper_code.generate_extern_kernel_alloc(self, args)
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kwargs=None,
        python_kernel_name=None,
        cpp_kernel_name=None,
        ordered_kwargs_for_cpp_kernel=(),
        op_overload=None,
    ):
        super().__init__(
            None,
            layout,
            self.unwrap_storage(inputs),
            constant_args,
            kwargs or {},
            None,
            python_kernel_name,
            cpp_kernel_name,
            ordered_kwargs_for_cpp_kernel,
            op_overload,
        )
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return False

    def apply_constraint(self):
        raise NotImplementedError


class UserDefinedTritonKernel(ExternKernel):
    def get_kernel_and_configs(self):
        from triton.runtime.autotuner import Autotuner

        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        kernel = kernel_side_table.get_kernel(self.kernel_idx)
        configs = []
        if isinstance(kernel, Autotuner):
            configs = kernel.configs
            kernel = kernel.fn
        return kernel, configs

    def codegen(self, wrapper):
        kernel, configs = self.get_kernel_and_configs()

        # Definition of kernel
        new_name, triton_meta = wrapper.define_user_defined_triton_kernel(
            kernel, configs, self.kwargs
        )

        args = self.codegen_kwargs()
        if V.graph.cpp_wrapper:
            # in C++ wrapper, we don't pass constexpr args, as they don't
            # get added as parameters to the PTX code compiled from the
            # user-defined Triton kernel (only non-constexpr args do)
            args = [arg for i, arg in enumerate(args) if i not in kernel.constexprs]

        # Call to kernel
        self.codegen_comment(wrapper)
        wrapper.generate_user_defined_triton_kernel(
            new_name,
            self.grid,
            configs,
            args,
            triton_meta,
        )

    def should_allocate(self):
        return False

    def has_side_effects(self):
        # UserDefinedTritonKernel does not return anything, but rather
        # modifies input in place, do not let it get DCEd
        return True

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    def get_mutation_names(self):
        return []

    def __init__(self, *, kernel_idx, grid, kernel_args):
        inputs = []
        kwargs = dict()
        constant_args = []
        for k, v in kernel_args.items():
            if isinstance(v, TensorBox):
                t = InputsKernel.unwrap_storage_for_input(self.realize_input(v))
                inputs.append(t)
                kwargs[k] = t
            else:
                constant_args.append(v)
                kwargs[k] = v

        assert len(inputs) != 0
        device = inputs[0].get_device()

        super().__init__(
            None,
            NoneLayout(device),  # type: ignore[arg-type]
            inputs,
            tuple(constant_args),
            kwargs,
        )
        self.name = V.graph.register_buffer(self)
        self.kernel_idx = kernel_idx
        self.grid = grid

        kernel, _ = self.get_kernel_and_configs()
        # If we are autotuning, not all arguments will be passed
        self.ordered_kwargs_for_cpp_kernel = [
            arg for arg in kernel.arg_names if arg in kernel_args
        ]

        mark_node_as_mutating(
            self, *[a for a in kernel_args.values() if isinstance(a, TensorBox)]
        )

    def get_alias_names(self):
        return [i.get_name() for i in self.inputs]


def mark_node_as_mutating(cur_buffer, *mutated_ops):
    """
    Allows ops in mutated_ops to be marked as being mutated as well as
    indicates to the scheduler that these ops depend on cur_buffer.
    """
    for op in mutated_ops:
        assert isinstance(op, IRNode), op
        V.graph.mark_buffer_mutated(op.get_name())
        assert hasattr(op, "layout")
        MutationOutput(op.layout, op, cur_buffer)


class MutationOutput(ExternKernel):
    def get_mutation_names(self):
        return [self.inputs[0].get_name()]

    def __init__(self, layout, input, parent):
        super().__init__(None, layout, [input, parent], ())
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return False

    def is_no_op(self):
        return True

    def has_side_effects(self):
        return True

    def get_alias_names(self):
        return [self.inputs[0].get_name()]


class InplaceBernoulliFallback(ExternKernel):
    """
    This needs to be a custom class to handle mutation properly
    """

    def codegen(self, wrapper):
        (x,) = (t.codegen_reference() for t in self.inputs)
        wrapper.writeline(
            f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
        )

    def should_allocate(self):
        return False

    def get_mutation_names(self):
        return [self.inputs[0].get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    def __init__(self, x, *constant_args):
        super().__init__(
            None,
            NoneLayout(x.get_device()),  # type: ignore[arg-type]
            self.unwrap_storage([x]),
            constant_args,
        )
        self.name = V.graph.register_buffer(self)
        self.python_kernel_name = "aten.bernoulli_"
        self.cpp_kernel_name = "at::native::bernoulli_"
        mark_node_as_mutating(self, x)


# Used to deal with torch.complex types
class InplaceCopyFallback(ExternKernel):
    """
    This needs to be a custom class to handle mutation properly
    """

    def codegen(self, wrapper):
        (dst, src, non_blocking) = self.codegen_args()
        wrapper.writeline(
            f"{self.get_kernel_name()}({dst}, {src}, {non_blocking}){wrapper.ending}"
        )

    def should_allocate(self):
        return False

    def get_mutation_names(self):
        return [self.inputs[0].get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    def __init__(
        self,
        layout,
        inputs,
        constant_args,
    ):
        super().__init__(
            None,
            layout,
            inputs,
            constant_args,
            python_kernel_name="aten.copy_",
            cpp_kernel_name=(
                "aoti_torch_copy_" if config.abi_compatible else "at::_ops::copy_::call"
            ),
        )
        self.name = V.graph.register_buffer(self)

    @classmethod
    def create(cls, dst, src, non_blocking: bool = False):
        inputs = [cls.realize_input(t) for t in [dst, src]]
        constant_args = (non_blocking,)
        result = InplaceCopyFallback(
            NoneLayout(dst.get_device()),  # type: ignore[arg-type]
            inputs,
            constant_args,
        )
        mark_node_as_mutating(result, dst)
        return result


class MutatingFirstArgExternKernel(ExternKernel):
    """
    This needs to be a custom class to handle mutation properly
    """

    def codegen(self, wrapper):
        argrefs = [
            *(t.codegen_reference() for t in self.inputs),
            *map(repr, self.constant_args),
        ]
        wrapper.writeline(
            f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}"
        )

    def should_allocate(self):
        return False

    def get_mutation_names(self):
        return [self.inputs[0].get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    def has_side_effects(self):
        return True


class ResizeStorageBytes(MutatingFirstArgExternKernel):
    def __init__(self, variable, new_size):
        assert isinstance(new_size, int), "TODO: dynamic shapes"
        super().__init__(
            None,
            NoneLayout(variable.get_device()),  # type: ignore[arg-type]
            self.unwrap_storage([variable]),
            constant_args=(new_size,),
        )
        V.graph.mark_buffer_mutated(variable.get_name())
        self.name = V.graph.register_buffer(self)
        self.python_kernel_name = "inductor_ops.resize_storage_bytes_"
        self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_"
        V.graph.never_reuse_buffers.add(variable.data.get_name())
        mark_node_as_mutating(self, variable)


class ScatterFallback(ExternKernel):
    """
    This needs to be a custom class to handle mutation properly.
    This class handles both aten.scatter_ and aten.scatter_reduce_.
    It also handles the case of `src` being a scalar properly.
    """

    def codegen(self, wrapper):
        reduce = self.kwargs["reduce"]
        if V.graph.cpp_wrapper:
            # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum
            get_operator_enum = {"add": "sum", "multiply": "prod"}
            if reduce in get_operator_enum:
                reduce = get_operator_enum[reduce]

        if self.src_is_tensor:
            (x, index, src) = (t.codegen_reference() for t in self.inputs)
        else:
            (x, index) = (t.codegen_reference() for t in self.inputs)
            src = self.constant_args[1]
        wrapper.generate_scatter_fallback(
            x,
            [x, self.constant_args[0], index, src],
            self.get_kernel_name(),
            self.python_kernel_name,
            self.src_is_tensor,
            reduce,
            self.codegen_kwargs(),
        )

    def should_allocate(self):
        return False

    def get_cpp_kernel(self):
        reduce = self.kwargs["reduce"]
        if self.python_kernel_name == "aten.scatter_":
            if self.src_is_tensor:
                kernel = (
                    "at::scatter_out" if reduce is None else "at::scatter_reduce_out"
                )
            else:
                assert (
                    reduce is None
                ), "Expect reduce to be None for aten.scatter_ with scalar src"
                kernel = "at::scatter_out"
        else:
            assert (
                reduce is not None
            ), "Expect reduce to be not None for aten.scatter_reduce_"
            kernel = "at::scatter_reduce_out"
        return kernel

    def get_mutation_names(self):
        return [self.inputs[0].get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    def __init__(
        self,
        op_overload,
        python_kernel_name,
        x,
        dim: int,
        index,
        src,
        *,
        reduce: Optional[str] = None,
        include_self: bool = True,
    ):
        assert python_kernel_name in {"aten.scatter_", "aten.scatter_reduce_"}
        self.src_is_tensor = isinstance(src, TensorBox)

        constant_args: Tuple[Any, ...]
        if self.src_is_tensor:
            tensors = [self.realize_input(t) for t in [x, index, src]]
            constant_args = (dim,)
        else:
            tensors = [self.realize_input(t) for t in [x, index]]
            constant_args = (dim, src)

        super().__init__(
            None,
            NoneLayout(x.get_device()),  # type: ignore[arg-type]
            self.unwrap_storage(tensors),
            constant_args,
            {"reduce": reduce, "include_self": include_self},
            python_kernel_name=python_kernel_name,
            ordered_kwargs_for_cpp_kernel=["reduce", "include_self"],
            op_overload=op_overload,
        )
        self.cpp_kernel_name = self.get_cpp_kernel()
        self.name = V.graph.register_buffer(self)
        mark_node_as_mutating(self, x)


class IndexPutFallback(ExternKernel):
    """
    This needs to be a custom class to handle mutation and indices properly
    """

    def codegen(self, wrapper):
        (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs)
        indices = []
        iter_valid_indices = iter(valid_indices)
        for i, _ in enumerate(self.indices):
            if self.indices[i] is not None:
                indices.append(next(iter_valid_indices))
            else:
                indices.append(V.graph.wrapper_code.none_str)

        wrapper.generate_index_put_fallback(
            self.get_kernel_name(), x, indices, values, *self.codegen_const_args()
        )

    def should_allocate(self):
        return False

    def get_mutation_names(self):
        return [self.inputs[0].get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    def __init__(self, op_overload, x, indices, values, accumulate):
        self.indices = indices
        valid_indices = [i for i in indices if i is not None]
        tensors = [self.realize_input(x) for x in [x, values, *valid_indices]]
        cpp_kernel_name = (
            "aoti_torch_index_put_out" if config.abi_compatible else "at::index_put_out"
        )
        super().__init__(
            None,
            NoneLayout(x.get_device()),  # type: ignore[arg-type]
            self.unwrap_storage(tensors),
            (accumulate,),
            python_kernel_name="aten.index_put_",
            cpp_kernel_name=cpp_kernel_name,
            op_overload=op_overload,
        )
        self.name = V.graph.register_buffer(self)
        mark_node_as_mutating(self, x)


class DeviceCopy(ExternKernelOut):
    @classmethod
    def create(cls, x, device):
        if (
            not x.is_extern()
            and all(
                (r.name in V.graph.constants and isinstance(r, dependencies.MemoryDep))
                for r in x.get_reads()
            )
            and not config.aot_inductor.use_runtime_constant_folding
        ):
            return x.constant_to_device(device)

        V.graph.add_device_info(device)
        V.graph.add_device_info(x.get_device())

        developer_warning("DeviceCopy in input program")
        return DeviceCopy(
            FlexibleLayout(
                device=device,
                dtype=x.get_dtype(),
                size=x.get_size(),
            ),
            [cls.realize_input(x)],
        )

    def codegen(self, wrapper):
        args = self.codegen_args()
        assert len(args) == 1
        if self.output_view:
            wrapper.codegen_device_copy(args[0], self.output_view.codegen_reference())
        else:
            wrapper.codegen_device_copy(args[0], self.codegen_reference())


class DynamicScalar(ExternKernel):
    """
    The result of a call to aten._local_scalar_dense.
    """

    def get_reads(self):
        return ()

    def should_allocate(self):
        return False

    # TODO: handle bools carefully
    def __init__(self, sym, data):
        data.realize()
        super().__init__(None, NoneLayout(torch.device("cpu")), self.unwrap_storage([data]))  # type: ignore[arg-type]
        if isinstance(sym, sympy.Symbol):
            self.sym = sym
            self.is_bool = False
        else:
            # Special case for boolean.  For Reasons(TM), we don't represent
            # boolean variables directly in sympy; instead, we generate an
            # indicator integer variable which we then convert to a boolean by
            # testing i0 == 1.  We have to identify the underlying indicator
            # variable, and then bind i0 to the appropriate integer value
            # based on the runtime boolean.
            assert isinstance(sym, sympy.Eq), sym
            assert isinstance(sym.args[0], sympy.Symbol), sym
            assert sym.args[1] == 1, sym
            self.sym = sym.args[0]
            self.is_bool = True

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return {self.sym}

    def codegen(self, wrapper):
        wrapper.codegen_dynamic_scalar(self)


class AssertScalar(ExternKernel):
    """
    The result of a call to aten._assert_scalar
    """

    def get_reads(self):
        return ()

    def should_allocate(self):
        return False

    def __init__(self, scalar, msg):
        super().__init__(
            # Buffer(name, layout)
            None,
            NoneLayout(torch.device("cpu")),  # type: ignore[arg-type]
            # InputsKernel(inputs)
            [],
        )  # type: ignore[arg-type]
        self.scalar = scalar
        self.msg = msg

    def has_side_effects(self):
        return True

    def get_unbacked_symbol_uses(self):
        return free_unbacked_symbols(self.scalar)

    def codegen(self, wrapper):
        if V.graph.cpp_wrapper:
            pass
        else:
            wrapper.writeline(
                f"if not {V.graph.wrapper_code.codegen_python_sizevar(self.scalar)}:"
            )
            wrapper.writeline(f"    raise RuntimeError({repr(self.msg)})")
            # No one should ever use this buffer, but for uniformity
            # define the variable and assign it None
            wrapper.writeline(f"{self.get_name()} = None")


@dataclasses.dataclass
class ExternKernelNode:
    name: str
    node: export_schema.Node


has_c_shim = {
    aten._embedding_bag.default,
    aten._fft_c2c.default,
    aten._scaled_dot_product_efficient_attention.default,
    aten._scaled_dot_product_flash_attention.default,
    aten._scaled_mm.default,
    aten.addmm.out,
    aten.bmm.out,
    aten.copy_.default,
    aten.mm.out,
    aten.repeat_interleave.Tensor,
    aten.nonzero.default,
    aten.view.dtype,
    aten.view_as_real.default,
}


def get_aten_cpp_kernel_name(kernel):
    # Calling with the default kernel name can lead to ambiguous behavior like the following example.
    # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
    # repeat_interleave(const at::Tensor & self, int64_t repeats,
    #     c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
    assert (
        isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten"
    ), "Invalid aten kernel"
    opname = (
        kernel.__name__.split(".")[0]
        if kernel._overloadname == "default"
        else kernel.__name__.replace(".", "_")
    )
    return f"at::_ops::{opname}::call"
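

# Illustrative mapping (assumption: standard OpOverload naming):
#   aten.nonzero.default          -> "at::_ops::nonzero::call"
#   aten.repeat_interleave.Tensor -> "at::_ops::repeat_interleave_Tensor::call"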
class FallbackKernel(ExternKernelAlloc):
4723
args_default_value: List[Dict[str, Any]]
4737
tuple(nontensor_args),
4740
# We need output buffers for generating kernel arguments in the
4741
# abi-compatible mode, where we retrieve outputs by pass each individual
4742
# output through the abi-compatible interface.
4743
self.outputs: Sequence[Any] = []
4744
self.use_runtime_dispatch = False
4745
self.abi_compatible_kernel = None
4750
torch._ops.OpOverload,
4751
torch._ops.HigherOrderOperator,
4753
), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported"
4754
self.op_overload = kernel
4756
self.unflatten_args = unflatten_args
4757
self.kwargs = {} if kwargs is None else kwargs
4758
V.graph.warn_fallback(self.python_kernel_name)
4760
# args that are aliased
4761
self.alias_names: List[str] = []
4762
# args that are mutated AND returned from the op
4763
self.mutation_names: List[str] = []
4765
if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
4766
# We assume here that HOPs with FallbackKernel are functional.
4767
# This may not always be true! HOPs must individually opt-in to
4768
# FallbackKernel, so please check this if you opt-in.
4771
if "_c10d_functional" in self.op_overload.name():
4772
# _c10d_functional kernels are lowered into _CollectiveKernel which
4773
# derives from FallbackKernel for the cpp codegen. The kernels
4774
# don't pass the can_auto_functionalize check, but their mutation
4775
# is handled properly by _CollectiveKernel.
4778
schema = self.op_overload._schema
4780
# NOTE: [FallbackKernel supported operators]
4781
# We only support three types of operators:
4784
# - inplace aten ops
4785
# - mutating ops that are auto-functionalizable. That is,
4786
# the operator may mutate any number of inputs, but its outputs
4787
# may not alias any of the inputs.
4789
# The unsupported cases usually do not show up here (because
4790
# AOTAutograd functionalized them away); the only way for an in-place
4791
# op to show up here is if a lowering or pass introduced it.
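# For example (hypothetical schemas, illustration only): an op with schema
# "foo_(Tensor(a!) self) -> Tensor(a!)" is the inplace case handled by
# mutates_and_returns_first_arg below, while "bar(Tensor(a!) x) -> Tensor"
# mutates x without aliasing its output, so it is auto-functionalizable.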
if torch._library.utils.mutates_and_returns_first_arg(self.op_overload):
self.mutation_names.append(tensor_args[0].get_name())
if schema.is_mutable and not can_auto_functionalize(kernel):
raise NotImplementedError(
f"NYI: Can't generate FallbackKernel for {kernel}"
schema_args = schema.arguments
args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
def handle_aliasing_and_mutation(info, arg):
# Assertions to make sure we didn't mismatch args
if isinstance(info.type, torch.ListType):
assert isinstance(arg, (list, tuple))
is_optional_tensor = isinstance(
info.type, torch.OptionalType
) and isinstance(info.type.getElementType(), torch.TensorType)
if is_optional_tensor or isinstance(info.type, torch.TensorType):
# PyTorch also accepts None and scalar types for args marked as "Tensor".
# We're not going to check all of them here.
assert not isinstance(arg, (tuple, list))
if info.alias_info is None:
# can_auto_functionalize already filters out mutable List[Tensor].
# We can support this in the future, but this is very uncommon.
assert isinstance(info.type, torch.TensorType) or is_optional_tensor
self.alias_names.append(arg.get_name())
if info.alias_info.is_write:
mark_node_as_mutating(self, arg)
for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
handle_aliasing_and_mutation(info, arg)
def set_cpp_kernel(self, kernel):
from .codegen.wrapper import get_cpp_op_schema
not kernel._schema.is_mutable
), f"mutable {kernel.__name__} is not supported with cpp_wrapper"
# These checks are here because ops that return aliasing tensors will
# return type Tensor& instead of Tensor, but codegen will always write
# type Tensor on the LHS.
def is_not_write(arg):
return arg.alias_info is None or not arg.alias_info.is_write
is_not_write(x) for x in kernel._schema.arguments
), f"{kernel.__name__} with alias_info arguments is not supported with cpp_wrapper"
is_not_write(x) for x in kernel._schema.returns
), f"{kernel.__name__} with alias_info returns is not supported with cpp_wrapper"
self.cpp_kernel_name = kernel._schema.name
self.cpp_kernel_overload_name = kernel._schema.overload_name
self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr]
self.cpp_op_schema = get_cpp_op_schema(kernel)
self.init_args_default_value(kernel._schema)
def is_legacy_abi_kernel(self):
return "_scaled_dot_product_flash_attention" in str(self.python_kernel_name)
def init_args_default_value(self, schema):
self.args_default_value = [
"type": x.real_type,
"value": x.default_value,
for x in schema.arguments
def get_pos_arg_value(self, pos, kwargs):
# positional args may be provided in kwargs
pos_arg_name = self.args_default_value[pos]["name"]
if pos_arg_name in kwargs:
"Found argument %s with value %s from kwargs",
kwargs[pos_arg_name],
return kwargs[pos_arg_name]
self, "args_default_value"
), "self.args_default_value has to be provided"
self.args_default_value
), f"expected the index {pos} to be smaller than len(self.args_default_value): {len(self.args_default_value)}"
arg_default_value = self.args_default_value[pos]["value"]
"Use default value %s for argument %s", arg_default_value, pos_arg_name
return arg_default_value
def codegen_args(self):
@dataclasses.dataclass
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
# Now we set up abi_compatible_kernel after self.python_kernel_name
# and kwargs are adjusted appropriately.
# For sdpa, we need the v2 version since v1 didn't consider optional arg
# FIXME: no need to do this after we switch to the torchgen-ed C shim
self.abi_compatible_kernel = (
f"{self.cpp_kernel_name}_v2"
if self.cpp_kernel_name in {"at::_scaled_dot_product_flash_attention"}
else self.cpp_kernel_name
if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
V.graph.wrapper_code.val_to_cpp_arg_str(
param.real_type, x, self.is_legacy_abi_kernel()
for param, x in zip(self.op_overload._schema.arguments, args)
args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]
# Previously, we wanted to maintain forward-compatibility by skipping
# default args in the serialized artifacts in fbcode. However,
# some of our shim interfaces require default values to be set.
# Discussed with Sherlock offline and we decided to allow serializing
# default args into the C++ wrapper code for now. We will refine this
# part if we see a real FC requirement. More details related to FC:
# https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing
if V.graph.cpp_wrapper and hasattr(self, "args_default_value"):
self.fill_non_provided_args(args, kwargs, convert_val_to_str=True)
# let self.codegen_kwargs handle kwargs
self.kwargs.update(kwargs)
def find_device(tensor_args, example_output):
return tensor_args[0].get_device()
if isinstance(example_output, torch.Tensor):
return example_output.device
if isinstance(example_output, (list, tuple)):
devices = {FallbackKernel.find_device(None, x) for x in example_output}
devices = [device for device in devices if device]
if len(devices) == 1:
for device in devices:
if device.type == "cuda":
def has_side_effects(self):
if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
return get_schema_info(self.op_overload).is_mutable()
def get_alias_names(self):
return self.alias_names
def get_mutation_names(self):
assert len(self.mutation_names) <= 1
return self.mutation_names
def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False):
assert isinstance(args, (list, tuple))
if isinstance(args, tuple):
assert hasattr(self, "args_default_value")
n_pos_args = len(self.args_default_value)
# For cpp wrapper, if some positional args are not provided, we need to check
# if they're in the kwargs or use their default value
if n_args < n_pos_args:
"%s has %d unprovided positional arguments. "
"Will check if they are in the keyword arguments or will use default values.",
n_pos_args - n_args,
self.get_pos_arg_value(i, kwargs) for i in range(n_args, n_pos_args)
if convert_val_to_str:
pos_args = [V.graph.wrapper_code.val_to_arg_str(x) for x in pos_args]
args.extend(pos_args)
# ProxyExecutor Design Note
# We export the ExternFallbackNodes (for custom ops) into a serialized file
# and run it with a host side proxy executor to address the ABI problem
# This is currently only implemented for fbcode. Eventually, we will also make this work for OSS.
# Detailed design doc can be found at
# https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing
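# Schematically, export_extern_kernel_node below records something like the
# following for each fallback (field values illustrative, op name hypothetical):
#   ExternKernelNode(
#       name="buf3",
#       node=export_schema.Node(
#           target="mylib::custom_op",
#           inputs=[...serialized arguments...],
#           outputs=[export_schema.Argument.create(
#               as_tensor=export_schema.TensorArgument(name="buf3"))],
#       ),
#   )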
def export_extern_kernel_node(self):
assert isinstance(self, FallbackKernel)
args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
args = self.fill_non_provided_args(args, kwargs)
kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel
serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type]
named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs) # type: ignore[arg-type]
def handle_single_output(return_type, output):
if isinstance(return_type, torch.TensorType):
if isinstance(output, (list, tuple)):
assert len(output) == 1
return export_schema.Argument.create(
as_tensor=export_schema.TensorArgument(name=out.get_name())
elif isinstance(return_type, torch.ListType) and isinstance(
return_type.getElementType(), torch.TensorType
# For single TensorList
return export_schema.Argument.create(
export_schema.TensorArgument(name=out.get_name())
raise RuntimeError(f"Unsupported return type {type(return_type)}")
target = self.op_overload
returns = target._schema.returns # type: ignore[union-attr]
if len(returns) == 1:
return_type = returns[0].real_type
output_arguments = [handle_single_output(return_type, self.outputs)]
# For tuple returns, e.g. "-> (Tensor, Tensor)" or "-> (Tensor, Tensor[])"
assert isinstance(self.outputs, tuple)
assert len(returns) == len(self.outputs)
output_arguments = [
handle_single_output(return_schema.real_type, output)
for return_schema, output in zip(returns, self.outputs)
node = ExternKernelNode(
name=self.get_name(),
node=export_schema.Node(
target=self.op_overload.name(), # type: ignore[union-attr]
inputs=named_arguments,
outputs=output_arguments,
V.graph.extern_kernel_nodes.append(node)
return [*args, *ordered_kwargs]
def codegen(self, wrapper):
kernel = self.op_overload
if kernel.namespace == "aten": # type: ignore[union-attr]
assert isinstance(kernel, torch._ops.OpOverload)
if V.graph.cpp_wrapper:
if config.is_fbcode() and kernel not in has_c_shim:
"%s is missing a c-shim implementation, using proxy executor as fallback",
self.use_runtime_dispatch = True
self.set_cpp_kernel(kernel)
self.cpp_kernel_name = get_aten_cpp_kernel_name(kernel)
schema = kernel._schema
self.init_args_default_value(schema)
self.python_kernel_name = str(kernel)
elif isinstance(kernel, torch._ops.HigherOrderOperator):
self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}"
# For non-aten OpOverload, i.e. custom ops
if V.graph.cpp_wrapper:
self.use_runtime_dispatch = True
self.set_cpp_kernel(kernel)
self.python_kernel_name = f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" # type: ignore[union-attr]
if self.use_runtime_dispatch:
self.codegen_comment(wrapper)
exported_args = None
if config.is_fbcode() and V.graph.cpp_wrapper:
exported_args = self.export_extern_kernel_node()
args = [*self.codegen_args(), *self.codegen_kwargs()]
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_kernel_name(),
self.cpp_kernel_key,
self.cpp_kernel_overload_name,
self.codegen_comment(wrapper)
args = [*self.codegen_args(), *self.codegen_kwargs()]
V.graph.wrapper_code.generate_fallback_kernel(self, args)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
def tensor_to_layout(output: torch.Tensor):
convert_shape_to_inductor(output.size()),
convert_shape_to_inductor(output.stride()),
def create(cls, kernel, *args, **kwargs):
fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext()
) = cls.process_kernel(kernel, *args, **kwargs)
device = cls.find_device(tensor_args, example_output)
assert device, "Not sure where to find device info"
MultiOutputLayout(device),
def generate_output(output, indices):
if isinstance(output, (list, tuple)):
return type(output)(
generate_output(output[i], indices + [(type(output), i)])
for i in range(len(output))
elif isinstance(output, dict):
key: generate_output(val, indices + [(type(output), key)])
for key, val in output.items()
elif isinstance(output, torch.Tensor):
cls.tensor_to_layout(output),
elif isinstance(output, int):
elif isinstance(output, torch.SymInt):
return output.node.expr
), f"FallbackKernel output type {type(output)} is not supported"
outputs = generate_output(example_output, [])
if isinstance(outputs, (list, tuple, dict)):
packed.outputs = outputs # type: ignore[assignment]
packed.outputs = [outputs]
def apply_constraint(self):
return super().apply_constraint()
@dataclasses.dataclass
class ComplexView(FallbackKernel):
"""View a complex number as two dtyped numbers or vice versa"""
def should_allocate(self):
def get_alias_names(self):
# Signal to codegen that our output buffer isn't safe to reuse
return [self.inputs[0].get_name()]
@dataclasses.dataclass
class MultiOutputLayout(IRNode):
device: torch.device
class MultiOutput(ExternKernel):
# Given an input MultiOutputLayout buffer, indexes out an actual buffer
# from that result. This doesn't actually produce multiple outputs,
# that's MultiOutputLayout!
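# For example, indices of [(list, 0), (tuple, 1)] (as built by FallbackKernel's
# generate_output above) walk the result as basename[0] first, then element 1
# of the inner tuple (via std::get<1> in the cpp wrapper); see
# codegen_list_tuple_access below.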
def codegen_list_tuple_access(self, basename, indices):
if len(indices) > 0:
itype, i = indices[0]
return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:])
elif itype == tuple:
# cpp wrapper code needs to use std::get<> to access a tuple
tuple_access = V.graph.wrapper_code.codegen_tuple_access(
basename, self.get_name(), str(i)
return self.codegen_list_tuple_access(tuple_access, indices[1:])
return self.codegen_list_tuple_access(f"{basename}['{i}']", indices[1:])
raise AssertionError("unsupported index type")
def codegen(self, wrapper):
wrapper.codegen_multi_output(
self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
self.codegen_unbacked_symbol_defs(wrapper)
def __init__(self, layout, input, indices: List[Tuple[Any, ...]]):
super().__init__(None, layout, [input], ())
self.name = V.graph.register_buffer(self)
self.indices = indices
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
return self.inputs[0].get_unbacked_symbol_uses()
def should_allocate(self):
def get_alias_names(self):
for inp in self.inputs
if isinstance(inp, FallbackKernel) and len(inp.get_alias_names()) > 0
def _prepare_convolution_fusion_create(
weight: "TensorBox",
dilation: List[int],
transposed: bool = False,
output_padding: Optional[List[int]] = None,
This helper function prepares inputs, layout, and constant args for a
convolution post-op fusion's create function, including deciding the output
layout (channels first or channels last) and realizing inputs. The
function only supports the CPU device since the conv post-op fusion kernel is only
supported on CPU right now.
# Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size
def _conv_input_size(
output_size, weight_size, padding, output_padding, stride, dilation, groups
assert len(output_size) == len(weight_size), "Expect input dim == weight dim"
dim = len(output_size)
assert dim > 2, "Expect input dim > 2"
WEIGHT_INPUT_CHANNELS_DIM = 1
input_size.append(output_size[BATCH_DIM])
input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups)
for d in range(2, dim):
kernel = (weight_size[d] - 1) * dilation[d - 2] + 1
(output_size[d] - 1) * stride[d - 2]
- (padding[d - 2] * 2)
+ output_padding[d - 2]
input_size.append(input_size_d)
return list(map(int, input_size))
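# Worked example (illustrative): output_size[d]=5, stride=2, padding=1,
# dilation=1, weight_size[d]=3, output_padding=0 gives kernel=3 and
# input_size_d = (5 - 1) * 2 - 2 + 3 + 0 = 9, i.e. the usual
# transposed-conv output-size formula run in reverse.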
# The size of prepacked_weight is the prepacked weight size of deconv:
# Groups > 1: [g*o, i/g, ...]
# Groups == 1: [o, i, ...]
# Returns original weight size in [i, o, ...]
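# e.g. with groups=2 and a prepacked_weight of size [16, 4, 3, 3], this
# returns [4 * 2, 16 / 2, 3, 3] = [8, 8, 3, 3].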
def _original_deconv_weight_size(
prepacked_weight_size = prepacked_weight.size()
dim = len(prepacked_weight_size)
assert dim > 2, "Expect weight dim > 2"
weight_size.append(prepacked_weight_size[1] * groups)
weight_size.append(prepacked_weight_size[0] / groups)
for d in range(2, dim):
weight_size.append(prepacked_weight_size[d])
weight_size = prepacked_weight.transpose(0, 1).size()
if bias is not None:
with V.graph.fake_mode:
# TODO <Leslie> clean up the fake_tensor trace as in the Linear implementation
x_fake = ir_node_to_tensor(x, guard_shape=True)
weight_fake = ir_node_to_tensor(weight, guard_shape=True)
dims = len(x_fake.size()) - 2
assert 0 < len(padding) <= dims
assert 0 < len(dilation) <= dims
assert 0 < len(stride) <= dims
padding = pad_listlike(padding, dims)
dilation = pad_listlike(dilation, dims)
stride = pad_listlike(stride, dims)
if output_padding is None:
output_padding = pad_listlike([0], dims)
assert 0 < len(output_padding) <= dims
output_padding = pad_listlike(output_padding, dims)
assert isinstance(groups, int)
# When transposed, the size of the prepacked oneDNN weight is different
# from the PyTorch weight. We're not able to run aten conv with such
# size. We infer the output size from the input params here:
weight_size = _original_deconv_weight_size(weight_fake, groups)
input_size = x_fake.size()
output_size = _conv_input_size(
ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
output = torch.ops.aten.convolution(
output_size = output.size()
req_stride_order = [0] + list(reversed(range(1, len(stride) + 1)))
req_stride_order = [len(req_stride_order)] + req_stride_order
output_stride = make_channels_last_strides_for(output_size)
x = cls.require_stride_order(x, req_stride_order)
assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
inputs = [x, weight]
kernel_layout = FixedLayout(
convert_shape_to_inductor(output_size),
convert_shape_to_inductor(output_stride),
constant_args = [padding, stride, dilation, groups]
constant_args.insert(1, output_padding)
if bias is not None:
constant_args.insert(0, bias)
return inputs, constant_args, kernel_layout, req_stride_order
def _prepare_linear_fusion_create(
weight: "TensorBox",
This helper function prepares inputs, layout, and constant args for a
linear post-op fusion's create function. The function only supports the CPU device
since the linear post-op fusion kernel is only supported on CPU right now.
if bias is not None:
*m, _ = x.get_size()
# The weight has been transposed during the qlinear weight prepack process.
# https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/
# aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291
_, oc = weight.get_size()
output_size = list(m) + [oc]
req_stride_order = list(reversed(range(len(x.get_size()))))
x = cls.require_stride_order(x, req_stride_order)
assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
inputs = [x, weight]
output_stride = make_contiguous_strides_for(output_size)
kernel_layout = FixedLayout(
constant_args: List[Any] = []
if bias is not None:
constant_args.insert(0, bias)
return inputs, constant_args, kernel_layout, req_stride_order
class ConvolutionUnary(ExternKernelAlloc):
python_kernel_name="torch.ops.mkldnn._convolution_pointwise",
cpp_kernel_name="mkldnn::_convolution_pointwise",
self.cpp_kernel_key = "convolution_pointwise"
self.cpp_op_schema = """
const at::Tensor& input_t,
const at::Tensor& weight_t,
const c10::optional<at::Tensor>& bias_opt,
at::IntArrayRef padding,
at::IntArrayRef stride,
at::IntArrayRef dilation,
c10::string_view attr,
torch::List<c10::optional<at::Scalar>> scalars,
c10::optional<c10::string_view> algorithm)"""
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_kernel_name(),
self.codegen_args(),
self.cpp_kernel_key,
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
weight: "TensorBox",
padding_: List[int],
dilation_: List[int],
scalars: Optional[List[Any]],
(inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
cls, x, weight, bias, padding_, stride_, dilation_, groups
constant_args = constant_args + [
may_convert_to_optional(scalars),
return ConvolutionUnary(
layout=kernel_layout,
constant_args=constant_args,
class ConvolutionBinary(ExternKernelAlloc):
cpp_constant_args=(),
python_kernel_name="torch.ops.mkldnn._convolution_pointwise.binary",
cpp_kernel_name="mkldnn::_convolution_pointwise",
self.cpp_kernel_overload_name = "binary"
self.cpp_kernel_key = "convolution_pointwise_binary"
self.cpp_op_schema = """
const at::Tensor& input_t,
const at::Tensor& other_t,
const at::Tensor& weight_t,
const c10::optional<at::Tensor>& bias_opt,
at::IntArrayRef padding,
at::IntArrayRef stride,
at::IntArrayRef dilation,
c10::string_view binary_attr,
c10::optional<at::Scalar> alpha,
c10::optional<c10::string_view> unary_attr,
torch::List<c10::optional<at::Scalar>> unary_scalars,
c10::optional<c10::string_view> unary_algorithm)"""
self.cpp_constant_args = cpp_constant_args
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_kernel_name(),
self.codegen_args(),
self.cpp_kernel_key,
self.cpp_kernel_overload_name,
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
weight: "TensorBox",
padding_: List[int],
dilation_: List[int],
binary_alpha: Optional[float],
unary_attr: Optional[str],
unary_scalars: Optional[List[Any]],
unary_algorithm: Optional[str],
) = _prepare_convolution_fusion_create(
cls, x, weight, bias, padding_, stride_, dilation_, groups
other = cls.require_stride_order(other, req_stride_order)
inputs.insert(1, other)
constant_args = constant_args + [
may_convert_to_optional(unary_scalars),
return ConvolutionBinary(
layout=kernel_layout,
constant_args=constant_args,
class ConvolutionBinaryInplace(ExternKernelAlloc):
# Due to constraints of op.call, other (Tensor&) should be at input[0]
reordered_inputs = [inputs[1], inputs[0]] + inputs[2:]
python_kernel_name="torch.ops.mkldnn._convolution_pointwise_.binary",
cpp_kernel_name="mkldnn::_convolution_pointwise_",
self.cpp_kernel_overload_name = "binary"
self.cpp_kernel_key = "convolution_pointwise_binary_"
# TODO: op.call: input[0] should be at::Tensor&
self.cpp_op_schema = """
at::Tensor& other_t,
const at::Tensor& input_t,
const at::Tensor& weight_t,
const c10::optional<at::Tensor>& bias_opt,
at::IntArrayRef padding,
at::IntArrayRef stride,
at::IntArrayRef dilation,
c10::string_view binary_attr,
c10::optional<at::Scalar> alpha,
c10::optional<c10::string_view> unary_attr,
torch::List<c10::optional<at::Scalar>> unary_scalars,
c10::optional<c10::string_view> unary_algorithm)"""
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_kernel_name(),
self.codegen_args(),
self.cpp_kernel_key,
self.cpp_kernel_overload_name,
def get_mutation_names(self):
return [self.inputs[0].get_name()]
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
weight: "TensorBox",
padding_: List[int],
dilation_: List[int],
binary_alpha: Optional[float],
unary_attr: Optional[str],
unary_scalars: Optional[List[Any]],
unary_algorithm: Optional[str],
) = _prepare_convolution_fusion_create(
cls, x, weight, bias, padding_, stride_, dilation_, groups
other = cls.require_stride_order(other, req_stride_order)
inputs.insert(1, other)
constant_args = constant_args + [
may_convert_to_optional(unary_scalars),
packed = ConvolutionBinaryInplace(
kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type]
constant_args=constant_args,
mark_node_as_mutating(packed, inputs[1])
# This op mutates in place, which means that the result is not the
# target but rather the input that is being mutated.
# __init__ reorders the inputs, so inputs[1] becomes packed.inputs[0]
return packed.inputs[0]
class MKLPackedLinear(ExternKernelAlloc):
python_kernel_name="torch.ops.mkl._mkl_linear",
cpp_kernel_name="mkl::_mkl_linear",
self.cpp_kernel_key = "mkl_linear"
self.cpp_op_schema = """
const at::Tensor& self,
const at::Tensor& mkl_weight_t,
const at::Tensor& origin_weight_t,
const c10::optional<at::Tensor>& bias_opt,
const int64_t prepack_batch_size)"""
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_kernel_name(),
self.codegen_args(),
self.cpp_kernel_key,
def create(cls, x, packed_w, orig_w, batch_size):
x = cls.require_stride1(cls.realize_input(x))
orig_w = cls.require_stride1(cls.realize_input(orig_w))
*m, _ = x.get_size()
oc, _ = orig_w.get_size()
output_size = list(m) + [oc]
output_stride = make_contiguous_strides_for(output_size)
inputs = [x, packed_w, orig_w]
constant_args = [None, batch_size]
return MKLPackedLinear(
x.get_device(), x.get_dtype(), output_size, output_stride
constant_args=constant_args,
class LinearUnary(ExternKernelAlloc):
python_kernel_name="torch.ops.mkldnn._linear_pointwise",
cpp_kernel_name="mkldnn::_linear_pointwise",
self.cpp_kernel_key = "linear_pointwise"
self.cpp_op_schema = """
const at::Tensor& input_t,
const at::Tensor& weight_t,
const c10::optional<at::Tensor>& bias_opt,
c10::string_view attr,
torch::List<c10::optional<at::Scalar>> scalars,
c10::optional<c10::string_view> algorithm)"""
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_kernel_name(),
self.codegen_args(),
self.cpp_kernel_key,
def create(cls, x, w, b, attr, scalars, algorithm):
x = cls.require_contiguous(cls.realize_input(x))
w = cls.require_contiguous(cls.realize_input(w))
*m, ic = x.get_size()
oc, ic = w.get_size()
constant_args = [attr, scalars if scalars else [-1], algorithm]
b = cls.require_contiguous(cls.realize_input(b))
constant_args.insert(0, None)
layout=FlexibleLayout(
device=x.get_device(),
dtype=x.get_dtype(),
size=list(m) + [oc],
constant_args=constant_args,
def apply_constraint(self):
class LinearBinary(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._linear_pointwise.binary"
python_kernel_name="torch.ops.mkldnn._linear_pointwise.binary",
cpp_kernel_name="mkldnn::_linear_pointwise",
self.cpp_kernel_overload_name = "binary"
self.cpp_kernel_key = "linear_pointwise_binary"
self.cpp_op_schema = """
const at::Tensor& input_t,
const at::Tensor& other_t,
const at::Tensor& weight_t,
const c10::optional<at::Tensor>& bias_opt,
c10::string_view attr)
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_kernel_name(),
self.codegen_args(),
self.cpp_kernel_key,
self.cpp_kernel_overload_name,
def create(cls, x, y, w, b, attr):
x = cls.require_contiguous(cls.realize_input(x))
y = cls.require_contiguous(cls.realize_input(y))
w = cls.require_contiguous(cls.realize_input(w))
*m, ic = x.get_size()
oc, ic = w.get_size()
constant_args = [attr]
b = cls.require_contiguous(cls.realize_input(b))
constant_args.insert(0, b)
return LinearBinary(
layout=FlexibleLayout(
device=x.get_device(),
dtype=x.get_dtype(),
size=list(m) + [oc],
constant_args=constant_args,
def apply_constraint(self):
class ConvolutionTransposeUnary(ExternKernelAlloc):
python_kernel_name="torch.ops.mkldnn._convolution_transpose_pointwise",
cpp_kernel_name="mkldnn::_convolution_transpose_pointwise",
self.cpp_kernel_key = "convolution_transpose_pointwise"
self.cpp_op_schema = """
const at::Tensor& input_t,
const at::Tensor& weight_t,
const c10::optional<at::Tensor>& bias_opt,
at::IntArrayRef padding,
at::IntArrayRef output_padding,
at::IntArrayRef stride,
at::IntArrayRef dilation,
c10::string_view attr,
torch::List<c10::optional<at::Scalar>> scalars,
c10::optional<c10::string_view> algorithm)"""
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_kernel_name(),
self.codegen_args(),
self.cpp_kernel_key,
weight: "TensorBox",
padding_: List[int],
output_padding_: List[int],
dilation_: List[int],
scalars: Optional[List[Any]],
) = _prepare_convolution_fusion_create(
constant_args = constant_args + [
may_convert_to_optional(scalars),
return ConvolutionTransposeUnary(
layout=kernel_layout,
constant_args=constant_args,
class MkldnnRnnLayer(ExternKernelAlloc):
python_kernel_name="aten.mkldnn_rnn_layer",
cpp_kernel_name="at::mkldnn_rnn_layer",
batch_sizes: List[int],
bidirectional: bool,
x = cls.require_stride1(cls.realize_input(x))
# If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer.
# Make sure x is contiguous in batch_first case.
w0 = cls.require_stride1(cls.realize_input(w0))
w1 = cls.require_stride1(cls.realize_input(w1))
w2 = cls.require_stride1(cls.realize_input(w2))
w3 = cls.require_stride1(cls.realize_input(w3))
hx = cls.require_stride1(cls.realize_input(hx))
cx = cls.require_stride1(cls.realize_input(cx))
input_size = x.get_size()
assert len(input_size) == 3, "Expect lstm input to be 3D"
# batch_first is handled in the lstm OP. When entering
# rnn_layer here, we'll always have batch_first = False
seq_length, mini_batch, input_size = input_size
output_shape = [seq_length, mini_batch, hidden_size]
hy_shape = hx.get_size()
cy_shape = cx.get_size()
res: List[IRNode] = []
inputs = [x, w0, w1, w2, w3, hx, cx]
packed = MkldnnRnnLayer(
MultiOutputLayout(x.get_device()),
constant_args=constant_args,
def get_strides_of_lstm_output(output_shape, batch_first):
assert len(output_shape) == 3, "Expect output_shape to be 3D"
return make_contiguous_strides_for(output_shape)
output_sizes = [output_shape, hy_shape, cy_shape]
get_strides_of_lstm_output(output_shape, batch_first),
make_contiguous_strides_for(hy_shape),
make_contiguous_strides_for(cy_shape),
for i, (output_size, output_stride) in enumerate(
zip(output_sizes, output_strides)
class QConvPointWisePT2E(ExternKernelAlloc):
- inputs = [x, w, b, weight_scale, weight_zp]
- const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp,
fp32_output, unary_attr, unary_scalars, unary_algorithm]
- inputs = [x, w, weight_scale, weight_zp]
- const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp,
fp32_output, unary_attr, unary_scalars, unary_algorithm]
self.has_bias = len(inputs) == 5
python_kernel_name="torch.ops.onednn.qconv2d_pointwise",
cpp_kernel_name="onednn::qconv2d_pointwise",
self.cpp_kernel_key = "qconv2d_pointwise"
self.cpp_op_schema = """
int64_t act_zero_point,
at::Tensor weight_scales,
at::Tensor weight_zero_points,
c10::optional<at::Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
double inv_output_scale,
int64_t output_zero_point,
c10::optional<c10::ScalarType> output_dtype,
c10::string_view attr,
torch::List<c10::optional<at::Scalar>> scalars,
c10::optional<c10::string_view> algorithm)"""
def codegen(self, wrapper):
# Parse the inputs and constant args
args = [x.codegen_reference() for x in self.inputs]
const_args.extend(self.codegen_const_args())
packed_weight = args[1]
bias = args[2] if self.has_bias else const_args[0]
w_scale, w_zp = args[-2], args[-1]
) = const_args[-12:]
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_kernel_name(),
self.cpp_kernel_key,
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
weight: "TensorBox", # packed_weight
w_scale: "TensorBox",
padding_: List[int],
dilation_: List[int],
output_zero_point: int,
output_padding = None
(inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
# swap padding and stride to align with functional conv arg order
constant_args[1], constant_args[2] = constant_args[2], constant_args[1]
constant_args[0], constant_args[1] = constant_args[1], constant_args[0]
inputs = inputs + [w_scale, w_zp]
constant_args = constant_args + [
may_convert_to_optional(unary_scalars),
if output_dtype is not None:
assert output_dtype in [torch.float32, torch.bfloat16]
# in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout
# if output_dtype is not None, the output buf should be output_dtype instead of uint8.
kernel_layout.dtype = output_dtype
return QConvPointWisePT2E(
layout=kernel_layout,
constant_args=constant_args,
class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
Needs input/weight/output qparams
- inputs = [x, w, b, accum, w_scale, w_zp]
- const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, o_zp,
fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm]
- inputs = [x, w, accum, w_scale, w_zp]
- const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale,
accum_zp, o_inv_scale, o_zp, fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm]
self.has_bias = len(inputs) == 6
self.idx_for_inplace_sum = 3 if self.has_bias else 2
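# i.e. the position of `accum` within inputs: [x, w, b, accum, ...] with a
# bias, [x, w, accum, ...] without one (see the docstring above).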
python_kernel_name="torch.ops.onednn.qconv2d_pointwise.binary",
cpp_kernel_name="onednn::qconv2d_pointwise",
self.cpp_kernel_overload_name = "binary"
self.cpp_kernel_key = "qconv2d_pointwise_binary"
self.cpp_op_schema = """
int64_t act_zero_point,
int64_t accum_zero_point,
at::Tensor weight_scales,
at::Tensor weight_zero_points,
c10::optional<at::Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
double inv_output_scale,
int64_t output_zero_point,
c10::optional<c10::ScalarType> output_dtype,
c10::string_view binary_attr,
c10::optional<at::Scalar> alpha,
c10::optional<c10::string_view> attr,
torch::List<c10::optional<at::Scalar>> scalars,
c10::optional<c10::string_view> algorithm)"""
def codegen(self, wrapper):
# Parse the inputs and constant args
args = [x.codegen_reference() for x in self.inputs]
const_args.extend(self.codegen_const_args())
packed_weight = args[1]
bias = args[2] if self.has_bias else const_args[0]
accum, w_scale, w_zp = args[-3], args[-2], args[-1]
) = const_args[-16:]
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_kernel_name(),
self.cpp_kernel_key,
self.cpp_kernel_overload_name,
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
def get_mutation_names(self):
return [self.inputs[self.idx_for_inplace_sum].get_name()]
def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
weight: "TensorBox", # packed_weight
padding_: List[int],
dilation_: List[int],
o_inv_scale: "TensorBox",
output_zero_point: "TensorBox",
output_padding = None
) = _prepare_convolution_fusion_create(
accum = cls.require_stride_order(accum, req_stride_order)
inputs.append(accum)
# swap padding and stride to align with functional conv arg order
constant_args[1], constant_args[2] = constant_args[2], constant_args[1]
constant_args[0], constant_args[1] = constant_args[1], constant_args[0]
inputs = inputs + [w_scale, w_zp]
constant_args = constant_args + [
may_convert_to_optional(unary_scalars),
binary_attr == "sum"
), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
packed = QConvPointWiseBinaryPT2E(
layout=NoneLayout(accum.get_device()),
constant_args=constant_args,
mark_node_as_mutating(packed, accum)
# Return accum since it has been changed in place.
return packed.inputs[packed.idx_for_inplace_sum]
class QLinearPointwisePT2E(ExternKernelAlloc):
x_scale_zp_are_tensors=False,
- inputs = [x, w, b, weight_scale, weight_zp]
- const_args is: [x_scale, x_zp, o_inv_scale, o_zp,
fp32_output, unary_attr, unary_scalars, unary_algorithm]
- inputs = [x, w, weight_scale, weight_zp]
- const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp,
fp32_output, unary_attr, unary_scalars, unary_algorithm]
self.has_bias = has_bias
self.x_scale_zp_are_tensors = x_scale_zp_are_tensors
python_kernel_name=(
"torch.ops.onednn.qlinear_pointwise.tensor"
if x_scale_zp_are_tensors
else "torch.ops.onednn.qlinear_pointwise.default"
cpp_kernel_name="onednn::qlinear_pointwise",
self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else ""
self.cpp_kernel_key = "qlinear_pointwise"
x_scale_type_str, x_zp_type_str = (
("at::Tensor", "at::Tensor")
if x_scale_zp_are_tensors
else ("double", "int64_t")
self.cpp_op_schema = f"""
{x_scale_type_str} act_scale,
{x_zp_type_str} act_zero_point,
at::Tensor weight_scales,
at::Tensor weight_zero_points,
c10::optional<at::Tensor> bias,
double inv_output_scale,
int64_t output_zero_point,
c10::optional<c10::ScalarType> output_dtype,
std::string post_op_name,
torch::List<c10::optional<at::Scalar>> post_op_args,
std::string post_op_algorithm)"""
def codegen(self, wrapper):
# Parse the inputs and constant args
args = [x.codegen_reference() for x in self.inputs]
const_args.extend(self.codegen_const_args())
packed_weight = args[1]
bias = args[2] if self.has_bias else const_args[0]
w_scale, w_zp = args[-2], args[-1]
if self.x_scale_zp_are_tensors:
assert len(args) >= 4
x_scale, x_zp = args[-4], args[-3]
assert len(const_args) >= 8
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_kernel_name(),
self.cpp_kernel_key,
self.cpp_kernel_overload_name,
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
weight: "TensorBox", # packed_weight
w_scale: "TensorBox",
output_zero_point: int,
(inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create(
if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox):
inputs = inputs + [x_scale, x_zp]
x_scale_zp_are_tensors = True
assert isinstance(x_scale, float) and isinstance(x_zp, int)
constant_args = constant_args + [x_scale, x_zp]
x_scale_zp_are_tensors = False
inputs = inputs + [w_scale, w_zp]
constant_args = constant_args + [
may_convert_to_optional(unary_scalars),
if output_dtype is not None:
assert output_dtype in [torch.float32, torch.bfloat16]
# in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout
# if we set fp32_output, the output buf should be dtype float32 instead of uint8.
kernel_layout.dtype = output_dtype
return QLinearPointwisePT2E(
layout=kernel_layout,
constant_args=constant_args,
has_bias=(bias is not None),
x_scale_zp_are_tensors=x_scale_zp_are_tensors,
@dataclasses.dataclass
class MutableBox(IRNode):
TensorBox / StorageBox allow in-place mutation of Tensors
def __getattr__(self, name):
fn = getattr(self.data, name)
raise AttributeError(f"{type(self.data).__name__}.{name} not callable")
return self.data.realize()
def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
return self.data.get_unbacked_symbol_uses()
def codegen_reference(self, writer=None):
return self.data.codegen_reference(writer)
return self.data.layout # type: ignore[attr-defined]
def get_layout(self):
return self.data.get_size()
return self.data.dtype
if isinstance(self.data, MutableBox):
line0 = f"{type(self).__name__}({type(self.data).__name__}("
inner = self.data.data
line0 = f"{type(self).__name__}("
return "\n".join(lines)
class TensorBox(MutableBox):
return TensorBox(StorageBox(data))
class StorageBox(MutableBox):
def is_input_buffer(self):
if isinstance(self.data, (InputBuffer, ReinterpretView)):
return self.data.get_name() in V.graph.graph_inputs
return self.data.get_name()
assert isinstance(self.data, (Pointwise, Reduction, Scan)), type(self.data)
origin_node = self.data.get_origin_node()
traceback = self.data.get_traceback()
self.data = ComputedBuffer(
layout=FlexibleLayout(
device=self.data.get_device(),
dtype=self.data.get_dtype(),
size=self.data.get_size(),
self.data.name = V.graph.register_buffer(self.data)
self.data.origins = self.origins
self.data.origin_node = origin_node
self.data.traceback = traceback
return self.data.name
def realize_hint(self):
Called on buffers we expect to be forced to realize later.
isinstance(self.data, (Pointwise, Reduction))
and self.num_reads() > 1
and self.is_pointwise_non_scalar_tensor_num_reads_larger_than_one()
def has_exceeded_max_reads(self):
return isinstance(self.data, Pointwise) and (
self.num_reads() > config.realize_acc_reads_threshold
or self.has_large_inner_fn()
def mark_reuse(self, users):
A heuristic to decide if we should realize a tensor
that is used multiple times.
def should_realize_on_cpu(loops: Union[Pointwise, Reduction]):
The heuristic for realizing the reused result of heavy ops on CPU
heavy_ops = ["exp"] # a list of heavy ops
fn_str = loops.inner_fn_str()
return any((op + "(") in fn_str for op in heavy_ops)
and isinstance(self.data, (Pointwise, Reduction))
self.num_reads() > config.realize_reads_threshold
or self.has_large_inner_fn()
or (is_cpu(self.data) and should_realize_on_cpu(self.data))
def num_reads(self):
if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)):
if isinstance(data, ComputedBuffer):
read_writes = data.get_read_writes()
assert isinstance(data, (Pointwise, Reduction)), type(data)
read_writes = ComputedBuffer(
layout=FlexibleLayout(
device=data.get_device(),
dtype=data.get_dtype(),
size=data.get_size(),
return len(read_writes.reads)
def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self):
# Skip the check for non-Pointwise instances
(sum(read.index != 0 for read in self.data.get_reads()) > 1)
if isinstance(self.data, Pointwise)
not isinstance(read, dependencies.StarDep)
for read in self.data.get_reads()
@dataclasses.dataclass
class Subgraph(IRNode):
graph_module: torch.fx.GraphModule
graph: Optional["GraphLowering"] = None
@dataclasses.dataclass
class Conditional(ExternKernel):
predicate: Optional[DynamicScalar] = None
operands: Optional[List[TensorBox]] = None
true_subgraph: Optional[Subgraph] = None
false_subgraph: Optional[Subgraph] = None
outputs: Optional[List[MultiOutput]] = None
predicate: DynamicScalar,
operands: List[TensorBox],
true_subgraph: Subgraph,
false_subgraph: Subgraph,
layout: MultiOutputLayout,
self.predicate = predicate
self.operands = operands
self.true_subgraph = true_subgraph
self.false_subgraph = false_subgraph
layout=layout, # type: ignore[arg-type]
inputs=[predicate, *operands], # type: ignore[list-item]
self.name = V.graph.register_buffer(self)
predicate: TensorBox,
operands: List[TensorBox],
predicate = cls.realize_input(predicate)
operands = [cls.realize_input(x) for x in operands]
fx_operands = V.graph.current_node.args[-1]
fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr]
for subgraph in (true_fn, false_fn):
if subgraph.graph is None:
# create and lower subgraphs
subgraph.graph = V.graph.make_subgraph(
gm=subgraph.graph_module,
example_inputs=fake_operands,
subgraph_name=subgraph.name,
with V.set_graph_handler(subgraph.graph):
subgraph.graph.run(*fake_operands)
true_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr]
false_outputs = false_fn.graph.graph_outputs # type: ignore[union-attr]
def _aliased_buffers(outputs):
output.unwrap_view() if isinstance(output, ReinterpretView) else output
for output in outputs
# assuming the same buffer is represented by the same IRNode object
return len({id(buffer) for buffer in buffers}) < len(outputs)
for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)):
if _aliased_buffers(outputs):
raise AssertionError(
"Output aliasing is currently not supported in compiled torch.cond. "
f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}"
# make sure true and false outputs are structurally equivalent
assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs)
for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)):
assert to.get_size() == fo.get_size(), (i, to, fo)
assert to.get_stride() == fo.get_stride(), (i, to, fo)
assert to.get_device() == fo.get_device(), (i, to, fo)
assert to.get_dtype() == fo.get_dtype(), (i, to, fo)
assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo)
conditional = Conditional(
predicate=predicate,
true_subgraph=true_fn,
false_subgraph=false_fn,
# use predicate device for consistent codegen
layout=MultiOutputLayout(predicate.get_device()),
device=output.get_device(),
dtype=output.get_dtype(),
size=output.get_size(),
stride=output.get_stride(),
offset=output.get_layout().offset,
# as the true and false outputs are equivalent,
# we can use either of them here as a "template"
for i, output in enumerate(true_outputs)
conditional.outputs = outputs
def codegen(self, wrapper):
wrapper.codegen_conditional(self)
class InterpreterShim(torch.fx.Interpreter):
@functools.lru_cache(None)
return torch.fx.symbolic_trace(identity)
def __init__(self, graph, submodules):
# call super() with a placeholder to avoid constructing a
# GraphModule which is very expensive (it does codegen).
super().__init__(self._dummy_gm(), garbage_collect_values=False)
self.module = self # type: ignore[assignment]
self.submodules = submodules
self.extra_traceback = False
self.fetch_attr = submodules.__getitem__
self.current_node = None
def run_node(self, n: torch.fx.Node) -> Any:
self.current_node = n
return super().run_node(n)
def run(self, *args, **kwargs):
with V.set_interpreter_handler(self):
return super().run(*args, **kwargs)
Captures the body of a Loops subclass into an FX graph. Persists any
indexing simplifications and makes it easier to analyze loop bodies.
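# For intuition, a (hypothetical) body such as
#     fn = lambda i, j: ops.store("out", i * 8 + j, ops.load("in", i * 8 + j))
# traced over var_ranges {i: 4, j: 8} would record a single indexing
# expression (index0 = i*8 + j, shared by the read and the write) and an FX
# graph whose nodes call into the current ops handler via the "ops"
# placeholder. Illustrative sketch only; names are not from this file.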
6979
def __init__(self, fn, args, var_ranges):
6981
self.var_ranges = var_ranges
6982
self.indexing_exprs = {}
6983
self.indexing_exprs_name = {}
6986
self.reads_name2expr = {}
6987
self.writes_name2expr = {}
6989
self.submodules = {"get_index": self.get_index}
6991
self.indirect_vars = []
6992
self.root_block = LoopBodyBlock(self, fn, args)
6993
self.indexing = None
6996
def get_nodes(self):
6997
all_graphs = itertools.chain(
6998
(self.root_block.graph,),
6999
(block.graph for block in self.subblocks.values()),
7001
return [node for graph in all_graphs for node in graph.nodes]
7005
# Doing a local import to avoid dumping all the code here
7006
from .bounds import BoundVars
7008
return BoundVars(self)
7010
def debug_str(self):
7011
lines = [f"var_ranges = {dict(self.var_ranges)}"]
7012
lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
7015
block.debug_str(name)
7016
for name, block in itertools.chain(
7017
[("body", self.root_block)], self.subblocks.items()
7021
return "\n".join(lines)
7023
def add_index_expr(self, expr: sympy.Expr, category, buf_name):
7024
getattr(self, category).append(expr)
7025
if buf_name is not None:
7026
getattr(self, f"{category}_name2expr")[buf_name] = expr
7027
if expr not in self.indexing_exprs_name:
7028
name = f"index{len(self.indexing_exprs)}"
7029
self.indexing_exprs_name[expr] = name
7030
self.indexing_exprs[name] = expr
7031
return self.indexing_exprs_name[expr]
7033
def add_submodule(self, block, prefix):
7034
"""Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
7035
if prefix[-1].isnumeric() and prefix not in self.submodules:
7038
name = f"{prefix}{len(self.submodules)}"
7039
self.submodules[name] = block
7042
def add_indirect(self, size):
7043
name = f"indirect{len(self.indirect_vars)}"
7044
var = sympy_index_symbol(name)
7045
self.indirect_vars.append(var)
7048
def replace_indirect(self, old, new):
7049
"""Swap in a variable used in indirect indexing"""
7050
if str(old) == str(new):
7052
assert self.indexing is not None
7053
self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}
7055
def get_index(self, name):
7056
assert self.indexing is not None
7057
return self.indexing[name]
7059
def __call__(self, *indices):
7060
index = list(itertools.chain.from_iterable(indices))
7061
assert len(index) == len(self.var_ranges), (index, self.var_ranges)
7062
assert all(v not in self.var_ranges for v in index)
7063
replacements = dict(zip(self.var_ranges.keys(), index))
7065
name: sympy_subs(expr, replacements)
7066
for name, expr in self.indexing_exprs.items()
7068
result = self.root_block()
7069
self.indexing = None
7075
Captures the body of a Loops subclass into an FX graph.
7076
In normal cases there will be a 1:1 mapping between LoopBody and
7077
LoopBodyBlock, hower in the case of ops.masked() the masked out
7078
operations will manifest as an extra LoopBodyBlock.
def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]):
def add_index(expr, category, buf_name=None):
return tracer.create_proxy(
(self.body.add_index_expr(expr, category, buf_name),),
class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined]
self.name = "CaptureIndexing"
def load(self, name: str, index: sympy.Expr):
index = add_index(index, "reads", name)
return self._inner.load(name, index)
def store(self, name, index, value, mode=None):
index = add_index(index, "writes", name)
return self._inner.store(name, index, value, mode)
def store_reduction(self, name, index, value):
index = add_index(index, "writes", name)
return self._inner.store_reduction(name, index, value)
def reduction(self, dtype, src_dtype, reduction_type, value):
result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
if "welford" in reduction_type:
return tuple(result[i] for i in range(3))
def index_expr(self, index, dtype):
if isinstance(index, (int, sympy.Integer)):
return self._inner.constant(int(index), dtype)
index = add_index(index, "other")
return self._inner.index_expr(index, dtype)
offsets_size: sympy.Expr,
indexing_dtype: torch.dtype,
offsets_size = add_index(offsets_size, "other")
return self._inner.bucketize(
values, offsets_name, offsets_size, indexing_dtype, right
def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
Recursively capture the masked out body in another LoopBodyBlock
subblock: LoopBodyBlock
def shim(mask, other):
return V.ops.masked(mask, subblock, other)
name = self.body.add_submodule(shim, "masked_subblock")
subblock = LoopBodyBlock(self.body, masked_body, [])
self.body.subblocks[name] = subblock
return tracer.create_proxy(
"call_module", name, (mask_proxy, other_proxy), {}
dtype_proxy, combine_fn: Callable[..., Any], value_proxy, init_proxy
def shim(dtype, value, init):
return V.ops.scan(dtype, combine_fn, value, init)
name = self.body.add_submodule(shim, "scan")
return tracer.create_proxy(
"call_module", name, (dtype_proxy, value_proxy, init_proxy), {}
def frexp(self, value_proxy):
result = self._inner.frexp(value_proxy)
# Proxies are iterable, but some methods expect tuples/lists
return (result[0], result[1])
def indirect_indexing(index_proxy, size, check=True):
Flow data from tensors into indexing formulas.
Introduce a call_module to update the indexing.
var = self.body.add_indirect(size)
def set_indirect(new_var):
self.body.replace_indirect(
var, V.ops.indirect_indexing(new_var, size, check)
tracer.create_proxy(
self.body.add_submodule(set_indirect, f"set_{var}"),
tracer.create_proxy("output", "output", (result,), {})
tracer = torch.fx.Tracer()
tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
from .index_propagation import IndexPropagation
from .sizevars import SimplifyIndexing
handler: Any = SimplifyIndexing(
CaptureIndexing(proxy_ops), self.body.var_ranges
if config.constant_and_index_propagation:
handler = IndexPropagation(handler)
with V.set_ops_handler(handler):
# This indirection is just a cute way to get IndexPropagation to
# unwrap the return value.
ops.output(fn(*args))
self.graph = tracer.graph
submodules = self.body.submodules
return InterpreterShim(graph, submodules).run(V.get_ops_handler())
def debug_str(self, name="block"):
code = torch.fx.GraphModule(self.body.submodules, self.graph).code
# strip `; del var0` suffixes to make output prettier
code.strip().replace("def forward(", f"def {name}("),
class Wait(ExternKernelAlloc):
Wait should not be used by itself. It should always be constructed in tandem
with a collective op that produces a 'work' object to wait on.
super().__init__(layout, inputs, constant_args)
def should_allocate(self):
def codegen(self, wrapper):
from .codegen.wrapper import ReuseLine
wrapper.add_import_once(
"from torch.distributed._functional_collectives_impl import _wait_tensor"
(input_collective,) = (t.codegen_reference() for t in self.inputs)
wrapper.writeline(f"{input_collective} = _wait_tensor({input_collective})")
# wait op still needs to produce a 'buffer' that represents the tensor output.
# this is a symbolic gesture, and it gets handled by WrapperCodegen.
# codegen outputs a '# reuse' line that assigns the input buffer here ('input_collective')
# to a new name (`self.get_name()`) and `del`s the old name.
wrapper.writeline(ReuseLine(wrapper, self.inputs[0], self, delete_old=False))
def create(cls, collective_op: "TensorBox"):
# TODO(whc) I'm not sure what's going on here, this probably means I missed something upstream
collective_op.decide_layout()
layout=AliasedLayout(collective_op),
inputs=[collective_op],
def get_alias_names(self):
# Signal to codegen that our output buffer isn't safe to reuse
return [self.inputs[0].codegen_reference()]
def get_mutation_names(self):
# The generated `_wait_tensor` op mutates the input tensor
return [self.inputs[0].codegen_reference()]
7278
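# For illustration, Wait.codegen above emits wrapper code along these lines
# (buffer names are hypothetical):
#
#     buf0 = _wait_tensor(buf0)
#     buf1 = buf0  # reuse
#
# The second line is the ReuseLine handled by WrapperCodegen; passing
# delete_old=False keeps the old name alive instead of `del`ing it.
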
class CollectiveKernel(ExternKernel):
    """
    Each collective should follow the pattern:
    - extend InPlaceCollectiveKernel or OutOfPlaceCollectiveKernel.
    - the kernel delegates into c10d processgroup, which returns a 'work' obj
    - the work obj is registered via _register_tensor_work so it can be waited on later
    """

    def __init__(self, layout, inputs, constant_args):
        super().__init__(None, layout, inputs, constant_args)
        self.name = V.graph.register_buffer(self)

    def should_emit_register_tensor_work(self):
        return True

    def should_emit_find_or_create_pg(self):
        return True

    def codegen_collective(self, wrapper, output_name, input_names):
        # factored out so the boilerplate can be handled in CollectiveKernel.codegen
        raise NotImplementedError("Must implement")

    def codegen_output(self, wrapper, output_name, input_names):
        # factored out so the boilerplate can be handled in CollectiveKernel.codegen
        raise NotImplementedError("Must implement")

    @classmethod
    def wrap_inputs_as_inplace(cls, inputs):
        def wrap_input(var):
            op = InPlaceHint(
                FlexibleLayout(var.get_device(), var.get_dtype(), var.get_size()), var
            )
            return TensorBox.create(op)

        return list(map(wrap_input, inputs))

    def codegen(self, wrapper):
        wrapper.add_import_once("import torch.distributed as dist")
        wrapper.add_import_once("import torch.distributed.distributed_c10d as c10d")
        wrapper.add_import_once(
            "import torch.distributed._functional_collectives_impl as fun_col_impl"
        )
        # extract references to our args in string form for codegen output
        input_names = [t.codegen_reference() for t in self.inputs]
        output_name = self.get_name()
        tag, ranks, group_size = self.constant_args

        if self.should_emit_find_or_create_pg():
            # TODO: avoid more than one ref of the same pg (even though they are cached inside the api)
            wrapper.writeline(
                f"{output_name}_pg = c10d._find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})"
            )

        self.codegen_output(wrapper, output_name, input_names)
        self.codegen_collective(wrapper, output_name, input_names)
        if self.should_emit_register_tensor_work():
            wrapper.writeline(
                f"fun_col_impl._register_tensor_work({output_name}, {output_name}_work)"
            )

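# For illustration, the boilerplate emitted by CollectiveKernel.codegen for a
# node named `buf0` (name and tag/ranks/group_size values are hypothetical)
# looks roughly like:
#
#     buf0_pg = c10d._find_or_create_pg_by_ranks_and_tag('', [0, 1], 2)
#     ... lines from codegen_output() and codegen_collective() ...
#     fun_col_impl._register_tensor_work(buf0, buf0_work)
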
class InPlaceCollectiveKernel(CollectiveKernel):
    """
    InPlaceCollectiveKernels are those with in-out arguments such as all_reduce.
    Extend this kernel if your collective needs to modify its inputs in-place.
    """

    def __init__(self, layout, inputs, constant_args):
        super().__init__(layout, inputs, constant_args)

    def should_allocate(self):
        return False

    def has_side_effects(self):
        return True

    def codegen_output(self, wrapper, output_name, input_names):
        if len(input_names) > 1:
            wrapper.writeline(f"{output_name} = [{','.join(input_names)}]")
        else:
            wrapper.writeline(f"{output_name} = {input_names[0]}")

class OutOfPlaceCollectiveKernel(CollectiveKernel):
    """
    OutOfPlaceCollectiveKernels are those that allocate their
    outputs and leave their inputs in place, such as all_gather.
    """

    def __init__(self, layout, inputs, outputs, constant_args):
        super().__init__(layout, inputs + outputs, constant_args)
        self.outputs = outputs
        self.original_inputs = inputs
        # NOTE: As seen in issue #108780, output buffers of out-of-place collectives
        # could be incorrectly reused. As a safety measure, here we just ban the reuse of them.
        # TODO: A better fix is to figure out how to propagate the aliases properly,
        # so that the buffer is only reused after all its users have consumed it.
        for x in self.outputs:
            V.graph.never_reuse_buffers.add(x.name)

    def should_allocate(self):
        return False

    def has_side_effects(self):
        return True

    def codegen_output(self, wrapper, output_name, input_names):
        input_names = [t.codegen_reference() for t in self.original_inputs]
        wrapper.writeline(f"{output_name}_inputs = [{','.join(input_names)}]")
        wrapper.writeline(f"{output_name} = [{','.join(x.name for x in self.outputs)}]")

    @classmethod
    def create_output_buffers(cls, inputs, size_cb=None):
        outputs = []
        for input in inputs:
            new_size = input.get_size()
            if size_cb is not None:
                size_cb(new_size)
            # new_size[0] *= group_size

            buff = OutputBuffer(
                layout=FlexibleLayout(
                    device=input.get_device(),
                    dtype=input.get_dtype(),
                    size=new_size,
                ),
            )
            outputs.append(buff)
        return outputs

    @classmethod
    def create_output_nodes(cls, coll, output_buffers):
        return [
            MultiOutputNoSizeAssert(
                out_t.layout,
                coll,
                f"[{i}]",
            )
            for i, out_t in enumerate(output_buffers)
        ]

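# For illustration, codegen_output above emits wrapper lines of this shape
# (buffer names are hypothetical):
#
#     buf1_inputs = [arg0_1]
#     buf1 = [buf0]
#
# so a subclass's codegen_collective can pass `{output_name}_inputs[i]` as the
# input and `{output_name}[i]` as the allocated output of the c10d call.
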
class InPlaceHint(ExternKernel):
    """
    Helper op to encode an in/out argument that tries to make it inplace whenever possible.
    Wrap the input of your inplace op to enable this behavior.

    The design is based on two key decisions:
    - this node is responsible for allocating the in/out buffer used by the collective.
      This is controlled by the ``should_allocate`` method, which returns True here
      and False for the collective node.
    - the scheduler special-cases this node and enables it to reuse its input.
    """

    def codegen(self, wrapper):
        input_name = self.inputs[0].codegen_reference()
        output_name = self.get_name()
        if not wrapper.did_reuse(self, self.inputs[0]):
            wrapper.writeline(f"{output_name}.copy_({input_name}) #no reuse")

    def __init__(self, layout, input):
        input = self.realize_input(input)
        super().__init__(None, layout, self.unwrap_storage([input]), ())
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return True

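# For illustration: when the scheduler manages to reuse the wrapped input's
# buffer for this node, InPlaceHint.codegen above emits nothing; otherwise it
# falls back to an explicit copy (hypothetical names):
#
#     buf0.copy_(arg0_1) #no reuse
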
class OutputBuffer(ExternKernel):
    """
    Represents an output buffer used by ops that require multiple output buffers.
    """

    def __init__(self, layout):
        super().__init__(name=None, layout=layout, inputs=[])
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return True

    def codegen(self, wrapper):
        wrapper.writeline(f"# collective out buffer {self.name}")

class MultiOutputNoSizeAssert(MultiOutput):
    """
    Extract partial output from a multi-output op.
    Works like MultiOutput but doesn't assert size. This must be a property
    guaranteed by the op emitting this.
    """

    def __init__(self, layout, input, index):
        super().__init__(layout, input, [])
        self.index = index

    def codegen(self, wrapper):
        wrapper.writeline(
            f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}"
        )

class Broadcast(InPlaceCollectiveKernel):
    def __init__(self, layout, inputs, constant_args, src):
        super().__init__(layout, inputs, constant_args)
        self.src = src

    def get_mutation_names(self):
        return [self.inputs[0].get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    @classmethod
    def create(
        cls, x: "TensorBox", src: int, tag: str, ranks: List[int], group_size: int
    ):
        inplace_inputs = cls.wrap_inputs_as_inplace([x])
        packed = Broadcast(
            layout=NoneLayout(inplace_inputs[0].get_device()),  # type: ignore[arg-type]
            inputs=inplace_inputs,
            constant_args=[tag, ranks, group_size],
            src=src,
        )
        mark_node_as_mutating(packed, inplace_inputs[0])
        return inplace_inputs[0]

    def codegen_collective(self, wrapper, output_name, input_names):
        wrapper.writeline(
            f"{output_name}_work = dist.broadcast("
            f"{output_name}, async_op=True, group={output_name}_pg, src={self.src})"
        )

class AllReduceCoalesced(InPlaceCollectiveKernel):
    def __init__(self, layout, inputs, constant_args, reduce_op):
        super().__init__(layout, inputs, constant_args)
        self.reduce_op = reduce_op

    def should_allocate(self):
        return False

    def get_mutation_names(self):
        return [self.inputs[0].get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    @classmethod
    def create(
        cls,
        inputs: List["TensorBox"],
        reduce_op: str,
        tag: str,
        ranks: List[int],
        group_size: int,
    ):
        inplace_inputs = cls.wrap_inputs_as_inplace(inputs)
        packed = AllReduceCoalesced(
            layout=NoneLayout(inplace_inputs[0].get_device()),  # type: ignore[arg-type]
            inputs=inplace_inputs,
            constant_args=[tag, ranks, group_size],
            reduce_op=reduce_op,
        )
        mark_node_as_mutating(packed, inplace_inputs[0])
        return inplace_inputs

    def codegen_collective(self, wrapper, output_name, input_names):
        wrapper.writeline(
            f"{output_name}_work = dist.all_reduce_coalesced("
            f"{output_name}, "
            f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'), "
            f"group={output_name}_pg, "
            f"async_op=True)"
        )

class AllReduce(InPlaceCollectiveKernel):
    def __init__(self, layout, inputs, constant_args, reduce_op):
        super().__init__(layout, inputs, constant_args)
        self.reduce_op = reduce_op

    def get_mutation_names(self):
        return [self.inputs[0].get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    @classmethod
    def create(
        cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int
    ):
        inplace_inputs = cls.wrap_inputs_as_inplace([x])
        packed = AllReduce(
            layout=NoneLayout(inplace_inputs[0].get_device()),  # type: ignore[arg-type]
            inputs=inplace_inputs,
            constant_args=[tag, ranks, group_size],
            reduce_op=reduce_op,
        )
        mark_node_as_mutating(packed, inplace_inputs[0])
        return inplace_inputs[0]

    def codegen_collective(self, wrapper, output_name, input_names):
        wrapper.writeline(
            f"{output_name}_work = dist.all_reduce("
            f"{output_name}, async_op=True, group={output_name}_pg, op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'))"
        )

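# Putting the pieces together, an in-place all_reduce over a hypothetical
# buffer `buf0` on ranks [0, 1] would codegen roughly as:
#
#     buf0_pg = c10d._find_or_create_pg_by_ranks_and_tag('', [0, 1], 2)
#     buf0 = arg0_1  # codegen_output for a single in-place input
#     buf0_work = dist.all_reduce(buf0, async_op=True, group=buf0_pg,
#         op=fun_col_impl._str_to_reduce_op('sum'))
#     fun_col_impl._register_tensor_work(buf0, buf0_work)
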
class AllGatherIntoTensor(OutOfPlaceCollectiveKernel):
    def __init__(self, layout, inputs, outputs, constant_args):
        super().__init__(layout, inputs, outputs, constant_args)

    @classmethod
    def create(cls, x: "TensorBox", tag: str, ranks: List[int], group_size: int):
        inputs = [cls.realize_input(x)]

        def compute_size(new_size):
            new_size[0] *= group_size

        outputs = cls.create_output_buffers(inputs, compute_size)
        layout = MultiOutputLayout(inputs[0].get_device())

        packed = AllGatherIntoTensor(
            layout=layout,
            inputs=inputs,
            outputs=outputs,
            constant_args=[tag, ranks, group_size],
        )
        return cls.create_output_nodes(packed, outputs)[0]

    def codegen_collective(self, wrapper, output_name, input_names):
        wrapper.writeline(
            f"{output_name}_work = dist.all_gather_into_tensor("
            f"{output_name}[0], {output_name}_inputs[0], async_op=True, group={output_name}_pg)"
        )

class ReduceScatterTensor(OutOfPlaceCollectiveKernel):
    def __init__(self, layout, inputs, outputs, constant_args, reduce_op):
        super().__init__(layout, inputs, outputs, constant_args)
        self.reduce_op = reduce_op

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        reduce_op: str,
        tag: str,
        ranks: List[int],
        group_size: int,
    ):
        inputs = [cls.realize_input(x)]

        def compute_size(new_size):
            new_size[0] //= group_size

        outputs = cls.create_output_buffers(inputs, compute_size)
        layout = MultiOutputLayout(inputs[0].get_device())

        packed = ReduceScatterTensor(
            layout=layout,
            inputs=inputs,
            outputs=outputs,
            constant_args=[tag, ranks, group_size],
            reduce_op=reduce_op,
        )
        return cls.create_output_nodes(packed, outputs)[0]

    def codegen_collective(self, wrapper, output_name, input_names):
        wrapper.writeline(
            f"{output_name}_work = dist.reduce_scatter_tensor("
            f"{output_name}[0], {output_name}_inputs[0], "
            f"async_op=True, group={output_name}_pg, op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'))"
        )

class AllGatherIntoTensorCoalesced(OutOfPlaceCollectiveKernel):
    def __init__(self, layout, inputs, outputs, constant_args):
        super().__init__(layout, inputs, outputs, constant_args)

    @classmethod
    def create(
        cls,
        inputs: List["TensorBox"],
        tag: str,
        ranks: List[int],
        group_size: int,
    ):
        inputs = [cls.realize_input(x) for x in inputs]

        def compute_size(new_size):
            new_size[0] *= group_size

        outputs = cls.create_output_buffers(inputs, compute_size)
        layout = MultiOutputLayout(inputs[0].get_device())

        packed = AllGatherIntoTensorCoalesced(
            layout=layout,
            inputs=inputs,
            outputs=outputs,
            constant_args=[tag, ranks, group_size],
        )

        return outputs
        # return cls.create_output_nodes(packed, outputs)

    def codegen_collective(self, wrapper, output_name, input_names):
        wrapper.writeline(
            f"{output_name}_work = fun_col_impl._all_gather_into_tensor_coalesced_fallback("
            f"output_tensors={output_name}, "
            f"input_tensors={output_name}_inputs, "
            f"group={output_name}_pg, "
            f"async_op=True)"
        )

class ReduceScatterTensorCoalesced(OutOfPlaceCollectiveKernel):
    def __init__(self, layout, inputs, outputs, constant_args, reduce_op):
        super().__init__(layout, inputs, outputs, constant_args)
        self.reduce_op = reduce_op

    @classmethod
    def create(
        cls,
        inputs: List["TensorBox"],
        reduce_op: str,
        tag: str,
        ranks: List[int],
        group_size: int,
    ):
        inputs = [cls.realize_input(x) for x in inputs]

        def compute_size(new_size):
            new_size[0] //= group_size

        outputs = cls.create_output_buffers(inputs, compute_size)
        layout = MultiOutputLayout(inputs[0].get_device())

        _ = ReduceScatterTensorCoalesced(
            layout=layout,
            inputs=inputs,
            outputs=outputs,
            constant_args=[tag, ranks, group_size],
            reduce_op=reduce_op,
        )

        return outputs

    def codegen_collective(self, wrapper, output_name, input_names):
        wrapper.writeline(
            f"{output_name}_work = fun_col_impl._reduce_scatter_tensor_coalesced_fallback("
            f"output_tensors={output_name}, "
            f"input_tensors={output_name}_inputs, "
            f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'), "
            f"group={output_name}_pg, "
            f"async_op=True)"
        )

# TODO(yifu): replace the CollectiveKernel IR hierarchy with _CollectiveKernel.
class _CollectiveKernel(FallbackKernel):
    def should_allocate(self):
        return False

    def has_side_effects(self):
        return True

    # This is identical to FallbackKernel.set_cpp_kernel(), minus the
    # part that checks against input aliasing and mutation.
    def set_cpp_kernel(self, kernel):
        from .codegen.wrapper import get_cpp_op_schema

        self.cpp_kernel_name = kernel._schema.name
        self.cpp_kernel_overload_name = kernel._schema.overload_name
        self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}"  # type: ignore[union-attr]

        self.cpp_op_schema = get_cpp_op_schema(kernel)
        self.ordered_kwargs_for_cpp_kernel = [
            x.name for x in kernel._schema.arguments if x.kwarg_only
        ]
    # NOTE: [In-Place Collective Safety]
    # Between the initiation and completion of an in-place collective, the
    # input buffers are subject to both volatile reads and volatile writes.
    # They must not be read, written to or reused by another kernel. To enforce
    # these constraints, we model collective -> wait_tensor as a two-step
    # mutation of the input buffers.
    @classmethod
    def create_inplace(
        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
    ) -> None:
        cpp_kernel_name = kernel._name
        python_kernel_name = cpp_kernel_name.replace("::", ".")
        with V.graph.fake_mode:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
        for tensor_arg in tensor_args:
            tensor_arg.realize()

        packed = cls(
            NoneLayout(tensor_args[0].get_device()),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        )
        packed.cpp_kernel_name = cpp_kernel_name
        packed.python_kernel_name = python_kernel_name

        def mark_mutation(x):
            if isinstance(x.data, BaseView):
                x = x.data.unwrap_view()
            MutationOutput(x.layout, x, packed)

        pytree.tree_map(lambda inp: mark_mutation(inp), inputs)

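    # For illustration of the two-step mutation model described above (a sketch
    # with hypothetical names): for an in-place collective op `kernel` over
    # input `x`,
    #
    #     _CollectiveKernel.create_inplace(kernel, x, *args)   # 1st mutation of x
    #     _WaitKernel.create_wait(wait_kernel, x)              # 2nd mutation of x
    #
    # Scheduling both as mutations of `x` forces any other reader/writer of `x`
    # to be ordered either before the collective starts or after the wait
    # completes.
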
    # NOTE: [Out-of-Place Collective Safety]
    # Between the initiation and completion of an out-of-place collective:
    #
    # Input buffers:
    # - Are subject to volatile reads
    # - Can be read by another kernel
    # - Must not be written to or reused by another kernel
    #
    # Output buffers:
    # - Are subject to volatile writes
    # - Must not be read, written to or reused by another kernel
    #
    # To ensure the safety of input buffers without sacrificing read
    # availability, we add input buffers as read deps of wait_tensor kernels.
    #
    # To ensure the safety of output buffers, we model wait_tensor as a
    # mutation to the output buffer. Note that we also assume the user program
    # is correct and the output buffer is not consumed by kernels other than
    # wait_tensor.
    #
    # TODO(yifu): add a pre-grad pass to validate the correctness of collective
    # usage in the user program.
    @classmethod
    def create_out_of_place(
        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
    ):
        cpp_kernel_name = kernel._name
        python_kernel_name = cpp_kernel_name.replace("::", ".")
        with V.graph.fake_mode:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
        for tensor_arg in tensor_args:
            tensor_arg.realize()

        if isinstance(example_output, list):
            device = cls.find_device(tensor_args, example_output)
            packed = cls(
                MultiOutputLayout(device),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )
            packed.cpp_kernel_name = cpp_kernel_name
            packed.python_kernel_name = python_kernel_name
            packed.outputs = [
                MultiOutput(
                    cls.tensor_to_layout(tensor),
                    packed,
                    [(list, i)],
                )
                for i, tensor in enumerate(example_output)
            ]
            return packed.outputs
        else:
            packed = cls(
                cls.tensor_to_layout(example_output),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )
            packed.cpp_kernel_name = cpp_kernel_name
            packed.python_kernel_name = python_kernel_name
            packed.outputs = [packed]
            return packed

class _WaitKernel(_CollectiveKernel):
    def get_volatile_reads(self):
        inp = self.inputs[0]
        if isinstance(inp, _CollectiveKernel):
            # Out-of-place single-output
            return [inp.inputs[0]]
        elif isinstance(inp, MultiOutput):
            # Out-of-place multi-output
            coll = inp.inputs[0]
            assert isinstance(coll, _CollectiveKernel)
            _, idx = inp.indices[0]
            return [coll.inputs[idx]]
        else:
            # In-place requires no additional deps handling for volatile
            # reads since the inputs are mutated.
            return []

    @classmethod
    def create_wait(cls, kernel, inp: TensorBox) -> None:
        with V.graph.fake_mode:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            ) = cls.process_kernel(kernel, inp)
        packed = cls(
            NoneLayout(inp.get_device()),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        )
        if isinstance(inp.data, BaseView):
            inp = inp.data.unwrap_view()
        MutationOutput(inp.layout, inp, packed)

    def get_read_writes(self):
        read_writes = super().get_read_writes()
        # See [Out-of-Place Collective Safety].
        volatile_reads = self.get_volatile_reads()
        for vr in volatile_reads:
            read_writes.reads.add(dependencies.StarDep(vr.get_name()))
        return read_writes

# NB: recursive structure here reflects val_to_arg_str, avoid
# calling free_unbacked_symbols on "exotic" types that don't get pexpr
# treatment
def maybe_free_unbacked_symbols(s):
    if isinstance(s, (SymTypes, sympy.Expr)):
        # This branch should be impossible in return position
        return free_unbacked_symbols(s)
    elif isinstance(s, (tuple, list)):
        r = set()
        for t in s:
            r |= maybe_free_unbacked_symbols(t)
        return r
    elif isinstance(s, torch.Tensor):
        # This branch is impossible in constant-args position
        return free_unbacked_symbols(s)
    else:
        return set()

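# For illustration (hypothetical values): for constant args such as
# (2, [s0 + u0, 3], "tag"), maybe_free_unbacked_symbols recurses into the
# list, collects the unbacked symbol u0 from the sympy expression, and
# contributes nothing for the int/str leaves, mirroring how val_to_arg_str
# walks the same structure.
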
class AllToAllSingle(OutOfPlaceCollectiveKernel):
    def __init__(
        self,
        layout,
        inputs,
        outputs,
        constant_args,
        output_split_sizes,
        input_split_sizes,
    ):
        super().__init__(layout, inputs, outputs, constant_args)
        self.output_split_sizes = output_split_sizes
        self.input_split_sizes = input_split_sizes

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        r = set()
        if self.output_split_sizes is not None:
            r |= free_unbacked_symbols(self.output_split_sizes)
        if self.input_split_sizes is not None:
            r |= free_unbacked_symbols(self.input_split_sizes)
        return r

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        output_split_sizes: Optional[List[Expr]],
        input_split_sizes: Optional[List[Expr]],
        tag: str,
        ranks: List[int],
        group_size: int,
    ):
        inputs = [cls.realize_input(x)]

        def compute_size(new_size):
            if output_split_sizes is not None:
                new_size[0] = sum(output_split_sizes)

        outputs = cls.create_output_buffers(inputs, compute_size)
        layout = MultiOutputLayout(inputs[0].get_device())

        packed = AllToAllSingle(
            layout=layout,
            inputs=inputs,
            outputs=outputs,
            constant_args=[tag, ranks, group_size],
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
        )
        return cls.create_output_nodes(packed, outputs)[0]

    def codegen_collective(self, wrapper, output_name, input_names):
        tag, ranks, group_size = self.constant_args

        # TODO: might be necessary to do some pretty printing on
        # output_split_sizes/input_split_sizes
        wrapper.writeline(
            f"{output_name}_work = dist.all_to_all_single("
            f"{output_name}[0], {output_name}_inputs[0], "
            f"output_split_sizes={self.output_split_sizes}, "
            f"input_split_sizes={self.input_split_sizes}, "
            f"group={output_name}_pg, async_op=True)"
        )
