# GitHub file-view residue: repository "pytorch"
#
# Fork 0
# 8003 lines · 262.8 KB
1
import collections
2
import contextlib
3
import dataclasses
4
import functools
5
import itertools
6
import logging
7
import re
8
import textwrap
9
import traceback
10
from contextlib import nullcontext
11
from enum import Enum
12
from functools import partial
13
from typing import (
14
    Any,
15
    Callable,
16
    ClassVar,
17
    Dict,
18
    Iterable,
19
    List,
20
    Optional,
21
    Sequence,
22
    Set,
23
    Tuple,
24
    TYPE_CHECKING,
25
    Union,
26
)
27
from unittest.mock import patch
28

29
import sympy
30
from sympy import Expr, Integer
31

32
import torch._export.serde.schema as export_schema
33

34
import torch._logging
35

36
import torch.fx
37
import torch.utils._pytree as pytree
38
from torch._dynamo.device_interface import get_interface_for_device
39
from torch._dynamo.utils import identity
40
from torch._export.serde.serialize import GraphModuleSerializer
41
from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
42
from torch._prims_common import (
43
    compute_required_storage_length,
44
    is_boolean_dtype,
45
    is_float_dtype,
46
    make_channels_last_strides_for,
47
    make_contiguous_strides_for,
48
    StrideType,
49
)
50
from torch._subclasses.fake_tensor import get_schema_info
51
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymTypes
52
from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing
53

54
from . import config, dependencies
55
from .codegen.common import index_prevent_reordering
56
from .dependencies import (
57
    extract_free_unbacked_symbols,
58
    extract_input_node_reduction_ranges,
59
    extract_read_writes,
60
    var_builder,
61
)
62
from .ops_handler import OpCounterCSE
63
from .utils import (
64
    argsort,
65
    cache_on_self,
66
    convert_shape_to_inductor,
67
    convert_shape_to_symint,
68
    developer_warning,
69
    get_kernel_metadata,
70
    is_dynamic,
71
    pad_listlike,
72
    sympy_dot,
73
    sympy_index_symbol,
74
    sympy_product,
75
    sympy_subs,
76
)
77
from .virtualized import ops, V
78

79
if TYPE_CHECKING:
80
    from .graph import GraphLowering
81

82
# Logger for this module.
log = logging.getLogger(__name__)

# Prefixes every non-blank line of a string with two spaces; used by
# IRNode.str_helper to nest the fields of an IR node's repr.
indent = functools.partial(textwrap.indent, prefix="  ")

# Shorthand for the ATen operator namespace.
aten = torch.ops.aten
85

86
""" [Note: Inductor IR]
87

88
Inductor's IR is produced by executing 'lowering' code (see lowering.py).  Each
89
lowering is registered to a particular aten operator, and expects inputs that
90
correspond to the aten schema.  However, in place of torch Tensor inputs, lowerings
91
expect Inductor TensorBox inputs.
92

93
TensorBox IR represents torch tensors.  Tensors are sometimes single objects owning
94
storage, and sometimes views of another Tensor's storage.  Mutating tensor operations
95
(such as add_()) affect the underlying storage and any associated views.  Other operations
96
(such as .t_()) update metadata about the current view but don't modify the underlying storage.
97

98
To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.
99

100
TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
101
output from an operation.  But just as torch.Tensors take different forms, TensorBox IR can
102
reference View IR or directly reference StorageBox IRs.
103

104
Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
105
may take an existing TensorBox and point it to a new underlying View IR.
106

107
Tensors that directly own storage are represented as a chain of:
108
TensorBox -> StorageBox -> Buffer
109
where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.
110

111
If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
112
(leaving the old buffer unmodified and functionalizing the operation).
113

114
Tensors backed by views add one more indirection to the IR.
115
TensorBox -> View -> StorageBox -> Buffer
116
In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
117
"""
118

119

120
def validate_ir(node_or_nodes):
    """Assert that a lowering produced only supported top-level IR values.

    Lists and tuples are walked element by element and dicts by value.
    Every leaf must be an ExpandView, DynamicScalar, AssertScalar, TensorBox,
    sympy Boolean or sympy Expr; anything else fails an ``assert``.
    See [Note: Inductor IR].
    """

    def _check_tensorbox(nodes):
        # Could expand this to check deeper properties
        # (e.g. TensorBox points to View or StorageBox)
        if isinstance(nodes, dict):
            children = nodes.values()
        elif isinstance(nodes, (list, tuple)):
            children = nodes
        else:
            assert isinstance(
                nodes,
                (
                    torch._inductor.ir.ExpandView,
                    DynamicScalar,
                    AssertScalar,
                    TensorBox,
                    sympy.logic.boolalg.Boolean,
                    Expr,
                ),
            ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
            return

        for child in children:
            _check_tensorbox(child)

    # Be picky about the accepted data structure (don't use pytree here)
    _check_tensorbox(node_or_nodes)
145

146

147
def ops_wrapper(name):
    """Return a callable that forwards its arguments to ``ops.<name>``.

    The ``ops`` attribute is looked up on every call, not once when the
    wrapper is created.
    """
    assert isinstance(name, str)

    def fn(*args, **kwargs):
        op = getattr(ops, name)
        return op(*args, **kwargs)

    return fn
154

155

156
def inverse_reorder(order):
    """Return a reindex function that undoes the permutation ``order``.

    For an index previously permuted as ``[x[order[i]] for i ...]``, the
    returned function restores the original ordering.
    """
    position_of = {dim: pos for pos, dim in enumerate(order)}

    def reindex(index):
        assert len(index) == len(position_of)
        return [index[position_of[dim]] for dim in range(len(index))]

    return reindex
164

165

166
def same_reorder(order):
    """Return a reindex function that permutes an index by ``order``.

    Output position ``pos`` takes the element at ``index[order[pos]]``.
    """

    def reindex(index):
        assert len(index) == len(order)
        return [index[order[pos]] for pos in range(len(index))]

    return reindex
172

173

174
def fuse_reindexing(reindex1, reindex2):
    """Compose two reindex functions: ``reindex2`` runs first, then ``reindex1``."""

    def reindex(index):
        intermediate = reindex2(index)
        return reindex1(intermediate)

    return reindex
179

180

181
# Stride order of a 4-D channels-last tensor with dims indexed as (N, C, H, W):
# rank 0 is the smallest stride (C) and rank 3 the largest (N).
NHWC_STRIDE_ORDER = [3, 0, 2, 1]
182

183

184
def stride_order2fill_order(order):
    """Convert a stride order into a fill order.

    ``order[dim]`` is the stride rank of ``dim``.  The result lists the dims
    sorted by rank, i.e. ``result[rank]`` is the dim holding that rank.
    For channels-last, stride order ``[3, 0, 2, 1]`` gives fill order
    ``[1, 3, 2, 0]``.
    """
    dim_with_rank = dict((rank, dim) for dim, rank in enumerate(order))
    return [dim_with_rank[rank] for rank in range(len(order))]
193

194

195
def get_stride_order(seq: Sequence[int]) -> List[int]:
    """Convert a sequence of strides into a stride order.

    ``result[dim]`` is the position of ``dim`` in ``argsort(seq)``, i.e. the
    rank of that dimension's stride.
    """
    dims_by_stride: List[int] = argsort(seq)
    ranks = [0] * len(seq)
    for rank, dim in enumerate(dims_by_stride):
        ranks[dim] = rank
    return ranks
204

205

206
def ir_node_to_tensor(x, guard_shape=True):
    """Build a zero-filled real tensor with the size, strides, dtype and device of ``x``.

    Strides come from ``x.get_layout()`` when ``is_storage_and_layout(x)``
    holds; otherwise contiguous strides are derived from the size.  With
    ``guard_shape=False`` every size/stride expression is replaced by
    ``V.graph.sizevars.size_hint``; with the default it is passed through
    ``identity``.  Sizes and strides are then run through
    ``convert_shape_to_symint``.  Returns None when ``x`` is None.
    """
    if x is None:
        return None

    shape_fn: Callable[[Expr], Union[int, Expr]] = (
        identity if guard_shape else V.graph.sizevars.size_hint
    )
    dims = [shape_fn(d) for d in x.get_size()]
    strides: StrideType
    if is_storage_and_layout(x):
        strides = [shape_fn(s) for s in x.get_layout().stride]  # type: ignore[misc]
    else:
        strides = make_contiguous_strides_for(dims)  # type: ignore[arg-type]
    dtype = x.get_dtype()
    device = x.get_device()
    return torch.empty_strided(
        size=convert_shape_to_symint(dims),
        stride=convert_shape_to_symint(strides),
        dtype=dtype,
        device=device,
    ).zero_()
229

230

231
def may_convert_to_optional(value):
    """Replace an empty list with ``[None]``; return anything else unchanged."""
    if not isinstance(value, list) or value:
        return value
    # [None] makes sure the cpp wrapper codegen will generate something like
    # {c10::nullopt} instead of {}
    return [None]
237

238

239
def get_device_type(x):
    """Return the device type string of ``x`` (e.g. ``"cuda"``), or None.

    Objects with a truthy ``get_device`` attribute are unwrapped by calling
    it, recursively.  A ``torch.device`` yields its ``.type``.  Anything else
    yields None.
    """
    if getattr(x, "get_device", None):
        return get_device_type(x.get_device())
    return x.type if isinstance(x, torch.device) else None
245

246

247
def is_triton(x):
    """Return True when ``get_device_type(x)`` is ``"cuda"``."""
    device_type = get_device_type(x)
    return device_type == "cuda"
249

250

251
def is_cpu(x):
    """Return True when ``get_device_type(x)`` is ``"cpu"``."""
    device_type = get_device_type(x)
    return device_type == "cpu"
253

254

255
class IRNode:
    """Base class of Inductor IR nodes.

    Each node records ``origins`` (a snapshot of the origins active when it
    was initialized; see ``current_origins``) and ``traceback`` (the Python
    stack when ``config.debug_ir_traceback`` is set, else None).  Most
    accessors raise ``NotImplementedError`` and are provided by subclasses.
    """

    # Origins attached to every node initialized inside ``current_origins``.
    _current_origins: ClassVar[Set[Any]] = set()

    @staticmethod
    @contextlib.contextmanager
    def current_origins(origins: Set[torch.fx.Node]):
        """Add ``origins`` to the origins recorded on nodes created in this block."""
        saved = IRNode._current_origins
        IRNode._current_origins = saved | origins
        try:
            yield
        finally:
            # Restore even when the body raises.
            IRNode._current_origins = saved

    def __post_init__(self):
        self.origins = set(self._current_origins)
        if config.debug_ir_traceback:
            self.traceback = traceback.format_stack()
        else:
            self.traceback = None

    def get_traceback(self):
        return self.traceback

    def common_repr(self):
        summary = f"origins={getattr(self, 'origins', '')}"
        # Origin sets can get *very* long; clip to 64 characters incl. "...".
        if len(summary) <= 64:
            return [summary]
        return [f"{summary[:61]}..."]

    def str_helper(self, lines):
        fields = lines + self.common_repr()
        body = indent(",\n".join(str(field) for field in fields))
        return f"{type(self).__name__}(\n{body}\n)"

    def is_user_of(self, name):
        return name in self.get_read_names()

    @cache_on_self
    def get_read_names(self):
        return {read.name for read in self.get_reads()}

    def get_dtype(self):
        return self.dtype

    def get_layout(self):
        raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!")

    def get_size(self):
        raise NotImplementedError(f"get_size() is not implemented by {type(self)}!")

    def get_numel(self):
        return sympy_product(self.get_size())

    def is_zero_elements(self):
        numel = self.get_numel()
        return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(numel, 0))  # type: ignore[arg-type]

    def realize(self):
        """Materialize this node's data into physical memory.

        Nodes such as Pointwise/Reduction may still be unmaterialized and open
        to having more compute fused into them.  Realizing one ends that
        possibility but lets, e.g., multiple users read the data without
        recomputing it.  StorageBox.realize is a particularly notable
        implementation.

        TODO(ezyang): in principle every IRNode should implement this (a
        no-op is fine for most), but each one must be audited first, so the
        base class raises for now.  Some code in graph.py catches this error
        and downgrades it to a warning.
        """
        raise NotImplementedError(f"realize NYI on {type(self)}")

    def codegen_reference(self, writer=None):
        raise NotImplementedError(f"codegen_reference NYI on {type(self)}")

    # Declarations only: they tell mypy every IRNode provides these members
    # without creating them at runtime.  Stub implementations would be wrong
    # because other code checks dynamically whether these attributes exist.
    get_device: Callable[[], torch.device]
    dtype: torch.dtype
    get_name: Callable[[], str]
    get_reads: Callable[[], Any]
    get_stride: Callable[[], Any]
    get_storage_numel: Callable[[], Any]
    has_exceeded_max_reads: Callable[[], bool]
    make_loader: Callable[[], Callable[[Any], Any]]
    make_indexer: Callable[[], Callable[[Any], Any]]
    mark_reuse: Callable[[int], None]
    realize_hint: Callable[[], None]
    get_unbacked_symbol_uses: Callable[[], Set[sympy.Symbol]]
345

346

347
@dataclasses.dataclass
class Loops(IRNode):
    """IR node whose values are defined by ``inner_fn`` over a loop nest.

    Fields:
        device: device the node is computed on.
        dtype: dtype of the produced values.
        inner_fn: called with the symbolic indices from ``inner_fn_args()``;
            it expresses its computation through ``ops``.
        ranges: sizes of the loop nest (``get_size()``).

    ``create`` builds a node and wraps it in a ``TensorBox``.
    """

    device: torch.device
    dtype: torch.dtype
    inner_fn: Callable[..., Any]
    ranges: List[Expr]

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        """Unbacked symbols referenced by ``ranges`` or by ``inner_fn``."""
        return set().union(
            *(free_unbacked_symbols(e) for e in self.ranges),
            self.inner_fn_free_unbacked_symbols(),
        )

    def __str__(self, names=("ranges",)):
        return self.str_helper(
            [
                f"'{self.device.type}'",
                str(self.dtype),
                self.inner_fn_str(),
            ]
            + [f"{name}={getattr(self, name)}" for name in names]
            + [f"origin_node={self.origin_node!r}"]
        )

    def __post_init__(self):
        super().__post_init__()
        # Set by ``create``; nodes constructed directly start without one.
        self.origin_node = None

    __repr__ = __str__

    def get_device(self):
        return self.device

    def get_origin_node(self):
        return self.origin_node

    def get_size(self):
        return self.ranges

    def get_pointwise_size(self):
        return self.ranges

    def is_extern(self):
        return False

    @classmethod
    def create(cls, *args, **kwargs):
        """Construct ``cls(*args, **kwargs)`` and wrap it in a TensorBox.

        Optional keyword-only extras (removed before construction):
            origin_node: stored on the new node.
            traceback: reused as the node's traceback when
                ``config.debug_ir_traceback`` is enabled.
        """
        origin_node = kwargs.pop("origin_node", None)
        tb = kwargs.pop("traceback", None)
        r = cls(*args, **kwargs)
        r.origin_node = origin_node
        # Explicit parentheses: this is how Python already parsed the
        # unparenthesized form (the conditional binds loosest).
        r.traceback = (
            (tb or traceback.format_stack()) if config.debug_ir_traceback else None
        )
        return TensorBox.create(r)

    @staticmethod
    def _index(ranges, prefix="i"):
        """One symbolic index per dim: ``0`` for size-1 dims, else ``{prefix}{n}``."""
        return [
            sympy.Integer(0) if s == 1 else sympy_index_symbol(f"{prefix}{n}")
            for n, s in enumerate(ranges)
        ]

    @cache_on_self
    def inner_fn_opcount(self):
        """Count the ops ``inner_fn`` emits, by running it under an op-counting handler."""
        from .ir import FlexibleLayout

        opcounter = OpCounterCSE(V.MockHandler())

        with V.set_ops_handler(opcounter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            # Only the counter's side effects matter; the returned value is unused.
            self.inner_fn(*self.inner_fn_args())
            return opcounter.op_count

    def inner_fn_args(self):
        return (self._index(self.ranges),)

    def inner_fn_str(self):
        return V.KernelFormatterHandler.ir_to_string(
            self.inner_fn, *self.inner_fn_args()
        )

    def has_large_inner_fn(self):
        return self.inner_fn_opcount() > config.realize_opcount_threshold

    def inner_fn_free_unbacked_symbols(self):
        index = self._index(self.ranges)
        return extract_free_unbacked_symbols(self.inner_fn, index)

    def get_reads(self):
        """Read dependencies of ``make_loader()``, traced over this node's sizes.

        The reduction size is included when ``get_reduction_type()`` is truthy.
        ``FlexibleLayout.allow_indexing`` is enabled while tracing.
        """
        with patch.object(FlexibleLayout, "allow_indexing", True):
            if self.get_reduction_type():
                return extract_read_writes(
                    self.make_loader(),
                    self.get_size(),
                    self.get_reduction_size(),
                ).reads
            return extract_read_writes(
                self.make_loader(),
                self.get_size(),
            ).reads

    def get_reduction_size(self):
        raise NotImplementedError(
            f"get_reduction_size() is not implemented by {type(self)}!"
        )

    def get_reduction_type(self):
        raise NotImplementedError(
            f"get_reduction_type() is not implemented by {type(self)}!"
        )

    def constant_to_device(self, device):
        raise NotImplementedError(
            f"constant_to_device() is not implemented by {type(self)}!"
        )
465

466

467
def nop_loader_fn(idx, *, dtype):
    """Loader for zero-element loops: ignores ``idx`` and returns a constant.

    The constant is NaN for floating-point ``dtype`` and 0 otherwise.
    """
    fill = float("nan") if dtype.is_floating_point else 0
    return ops.constant(fill, dtype)
472

473

474
class Pointwise(Loops):
    """Elementwise Loops node: ``inner_fn`` yields one value per output index."""

    def make_loader(self):
        # Zero-element loops get a no-op loader.
        if not self.is_zero_elements():
            return self.inner_fn
        return partial(nop_loader_fn, dtype=self.dtype)

    def get_reduction_size(self):
        return []

    def get_reduction_type(self):
        return None

    def store_output(self, output_name, indexer, vars):
        load = self.make_loader()
        return ops.store(output_name, indexer(vars), load(vars))

    def constant_to_device(self, device):
        """Rebuild this node on ``device``; every read must be to a constant."""
        base_loader = self.make_loader()
        patcher = patch.object(ConstantBuffer, "override_device", device)
        return Pointwise(device, self.dtype, patcher(base_loader), self.ranges)
497

498

499
@dataclasses.dataclass
class Scatter(Pointwise):
    """Pointwise node that stores each value at ``output_indexer(index)``.

    ``scatter_mode`` is forwarded to ``ops.store`` as ``mode=``.
    """

    output_indexer: Callable[[List[Expr]], Expr]
    scatter_mode: Optional[str] = None

    def constant_to_device(self, device):
        """Rebuild this node on ``device``; every read must be to a constant."""
        base_loader = self.make_loader()
        patcher = patch.object(ConstantBuffer, "override_device", device)
        return Scatter(
            device,
            self.dtype,
            patcher(base_loader),
            self.ranges,
            self.output_indexer,
            self.scatter_mode,
        )

    def store_output(self, output_name, indexer, vars):
        load = self.make_loader()
        return ops.store(
            output_name,
            indexer(self.output_indexer(vars)),
            load(vars),
            mode=self.scatter_mode,
        )
525

526

527
class ReductionHint(Enum):
    """Hint attached to a reduction (e.g. returned by ``Reduction.num_splits``).

    ``num_splits`` returns INNER or OUTER depending on how the reduction's
    reads are strided, and DEFAULT when it does not classify them.
    """

    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3
532

533

534
class TileHint(Enum):
    """Tiling hint: SQUARE or DEFAULT."""

    SQUARE = 0
    DEFAULT = 1
537

538

539
# Reduction type -> binary ``ops`` function merging two partial results.
REDUCTION_COMBINE_FN = {
    reduction_type: ops_wrapper(op_name)
    for reduction_type, op_name in (
        ("any", "logical_or"),
        ("max", "maximum"),
        ("min", "minimum"),
        ("prod", "mul"),
        ("sum", "add"),
        ("xor_sum", "bitwise_xor"),
    )
}
547

548

549
def get_reduction_combine_fn(reduction_type, dtype):
    """Return a function that merges two partial results of ``reduction_type``.

    * Types in ``REDUCTION_COMBINE_FN`` map directly to a binary op.
    * ``argmax``/``argmin`` merge ``(value, index)`` pairs.  The better value
      wins and ties go to the smaller index.  For floating-point ``dtype``, a
      NaN wins over any non-NaN value, and two NaNs are treated as equal.
    * ``welford_combine`` merges ``(mean, m2, weight)`` triples with the
      standard parallel (Chan et al.) mean/M2 update.

    Raises:
        NotImplementedError: for any other ``reduction_type``.
    """
    if reduction_type in REDUCTION_COMBINE_FN:
        return REDUCTION_COMBINE_FN[reduction_type]

    if reduction_type in {"argmax", "argmin"}:
        prefer_smaller = reduction_type == "argmin"

        def combine_fn(a, b):
            a_value, a_index = a
            b_value, b_index = b

            if prefer_smaller:
                mask = ops.lt(a_value, b_value)
            else:
                mask = ops.gt(a_value, b_value)

            equal = ops.eq(a_value, b_value)
            if is_float_dtype(dtype):
                # Select ``a`` when it is NaN and ``b`` is not; NaN == NaN here.
                a_isnan = ops.ne(a_value, a_value)
                b_isnan = ops.ne(b_value, b_value)
                mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan))
                equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan))

            # On equal values, keep the smaller index.
            tie_break = ops.logical_and(equal, ops.lt(a_index, b_index))
            mask = ops.logical_or(mask, tie_break)
            return (
                ops.where(mask, a_value, b_value),
                ops.where(mask, a_index, b_index),
            )

        return combine_fn

    if reduction_type == "welford_combine":

        def combine_fn(a, b):
            a_mean, a_m2, a_weight = a
            b_mean, b_m2, b_weight = b

            delta = b_mean - a_mean
            new_weight = a_weight + b_weight
            w2_over_w = b_weight / new_weight
            return (
                a_mean + delta * w2_over_w,
                a_m2 + b_m2 + delta * delta * a_weight * w2_over_w,
                new_weight,
            )

        return combine_fn

    raise NotImplementedError(f"unknown reduction_type={reduction_type}")
597

598

599
@dataclasses.dataclass
600
class Reduction(Loops):
601
    reduction_ranges: List[Expr]
602
    reduction_type: str
603
    # self.dtype represents the dst dtype
604
    src_dtype: torch.dtype
605
    reduction_hint: ReductionHint
606

607
    def __str__(self):
        # Generic Loops formatting, also listing the reduction-specific fields.
        shown_fields = ("ranges", "reduction_ranges", "reduction_type")
        return Loops.__str__(self, names=shown_fields)  # type: ignore[call-arg]
611

612
    def __repr__(self):
        """Same text as ``__str__`` (dispatched dynamically)."""
        return self.__str__()
614

615
    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        """Unbacked symbols of the Loops part plus those in ``reduction_ranges``."""
        loop_symbols = super().get_unbacked_symbol_uses()
        reduction_symbols = set().union(
            *(free_unbacked_symbols(extent) for extent in self.reduction_ranges)
        )
        return loop_symbols | reduction_symbols
619

620
    def get_reduction_size(self):
        """Sizes of the reduced dimensions."""
        return self.reduction_ranges
622

623
    def get_reduction_type(self):
        """The reduction kind string (e.g. ``"sum"``, ``"argmax"``)."""
        return self.reduction_type
625

626
    def store_reduction(self, output_name, indexer, vars, reduction_vars):
        """Emit ``ops.reduction`` over ``inner_fn`` and store it at ``indexer(vars)``."""
        reduced = ops.reduction(
            self.dtype,
            self.src_dtype,
            self.reduction_type,
            self.inner_fn(vars, reduction_vars),
        )
        return ops.store_reduction(output_name, indexer(vars), reduced)
634

635
    def index_length(self):
        """Total number of loop dimensions: output dims plus reduction dims."""
        return sum(map(len, (self.ranges, self.reduction_ranges)))
637

638
    def inner_fn_args(self):
        """Symbolic arguments for ``inner_fn``: output indices, then ``r``-prefixed reduction indices."""
        return self._index(self.ranges), self._index(self.reduction_ranges, "r")
642

643
    def inner_fn_free_unbacked_symbols(self):
        """Unbacked symbols ``inner_fn`` references when traced with symbolic indices."""
        output_index = self._index(self.ranges)
        reduction_index = self._index(self.reduction_ranges, "r")
        return extract_free_unbacked_symbols(
            self.inner_fn, output_index, reduction_index
        )
647

648
    def constant_to_device(self, device):
        """Rebuild this reduction on ``device``; every read must be to a constant.

        The copy uses ``ReductionHint.DEFAULT`` rather than this node's hint.
        """
        base_loader = self.make_loader()
        patcher = patch.object(ConstantBuffer, "override_device", device)
        return Reduction(
            device,
            self.dtype,
            patcher(base_loader),
            self.ranges,
            self.reduction_ranges,
            self.reduction_type,
            self.src_dtype,
            ReductionHint.DEFAULT,
        )
662

663
    @staticmethod
664
    def num_splits(
        device,
        dst_dtype,
        src_dtype,
        inner_fn,
        ranges,
        reduction_ranges,
        reduction_type,
        reduction_numel,
        input_node: Optional[IRNode] = None,
    ):
        """Choose a ReductionHint and a split count for a reduction.

        Returns ``(hint, split)``:
          * ``(DEFAULT, 1)`` when splitting is not considered.  Splitting
            requires a CUDA device (``is_triton``), a reduction type other
            than argmax/argmin, ``config.split_reductions``, and static
            integer hints for both the reduction numel and the output numel.
          * ``(INNER, split)`` for a single-output reduction
            (``numel_hint == 1``).  ``split`` is ``-1`` when there are no output
            ranges and ``input_node`` (a TensorBox) yields ranges whose total
            numel matches.  Per the comment at that return, the caller should
            then reuse those sizes instead of splitting.
          * otherwise INNER or OUTER, chosen by a majority vote over how the
            full-size reads are strided along the reduction vars, together with
            the matching split heuristic.
        """

        def _is_static(x):
            return isinstance(x, (int, sympy.Integer))

        # Hints for the reduced extent and for the number of outputs.
        reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel)
        numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges))

        should_split = (
            is_triton(device)
            and reduction_type
            not in {
                "argmax",
                "argmin",
            }
            and config.split_reductions
            # We don't support unbacked symints
            and _is_static(reduction_numel_hint)
            and _is_static(numel_hint)
        )
        if not should_split:
            return ReductionHint.DEFAULT, 1

        device_interface = get_interface_for_device(get_device_type(device))
        num_sm = device_interface.Worker.get_device_properties(
            device
        ).multi_processor_count
        min_elements_per_thread = 32
        max_elements_per_thread = 512
        threads_per_sm = 2048
        # Whole-device element budgets built from the per-thread bounds above.
        min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm
        max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm

        def inner_reduction_splits(reduction_numel_hint, numel_hint):
            # do heuristics that's close to eager mode for split inner reduction
            # we leak reduction autotune configs here, and will need to refactor to avoid this later
            num_warps = 8
            num_threads = 32 * num_warps
            if numel_hint >= 2 * num_sm:  # don't split if there are enough outputs
                return 1
            if reduction_numel_hint <= 8192:
                return 1
            if reduction_numel_hint * numel_hint <= min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (2 * num_threads)
                blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
                tmp_split_size = (
                    reduction_numel_hint + num_threads * blocks_per_output - 1
                ) // (num_threads * blocks_per_output)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(closest - tmp_split_size) < 30:
                    # prefer even splits, but never smaller than min_elements_per_thread
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread
            # ceil(reduction_numel_hint / (split_size * num_threads))
            return (reduction_numel_hint + split_size * num_threads - 1) // (
                split_size * num_threads
            )

        def outer_reduction_splits(reduction_numel_hint, numel_hint):
            # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128
            # extend to even smaller number of outputs
            num_warps = 8
            num_threads = num_warps * 32
            rvals_per_thread = 4  # comes from heuristics, refactor to not leak here
            xvals_per_block = 128
            xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
            if reduction_numel_hint * numel_hint < min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (num_threads)
                target_blocks = (target_blocks + xblocks - 1) // xblocks
                tmp_split_size = (
                    reduction_numel_hint + rvals_per_thread * target_blocks - 1
                ) // (rvals_per_thread * target_blocks)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(tmp_split_size - closest) < 20:
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread

            # ceil(reduction_numel_hint / (rvals_per_thread * split_size))
            return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
                rvals_per_thread * split_size
            )

        # easy cases
        if numel_hint == 1:
            split = inner_reduction_splits(reduction_numel_hint, numel_hint)
            if split == 1:
                # No need to split.
                return ReductionHint.INNER, split
            if (
                len(ranges) == 0
                and input_node is not None
                and isinstance(input_node, TensorBox)
            ):
                # Only handles the case where keep_dim = False.
                # Otherwise, we need to propagate reduction dim info to the stage where
                # the intermediate loader of the first Reduction is generated.
                new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
                    input_node
                )
                if new_ranges is not None and new_reduction_ranges is not None:
                    extracted_numel_hint = V.graph.sizevars.symbolic_hint(
                        sympy_product(new_ranges + new_reduction_ranges)
                    )
                    if reduction_numel_hint == extracted_numel_hint:
                        log.debug(
                            "Use previous IRNode's range and reduction_ranges instead of split. "
                            "current ranges: %s, current reduction ranges: %s, current split: %d, "
                            "new ranges: %s, new reduction ranges: %s",
                            ranges,
                            reduction_ranges,
                            split,
                            new_ranges,
                            new_reduction_ranges,
                        )
                        # If the input_node or its dependent nodes are also Reduction nodes,
                        # use reduction_sizes of this node or its dependent nodes directly.
                        return ReductionHint.INNER, -1
            return ReductionHint.INNER, split
        if (
            reduction_numel_hint <= min_elements_per_thread
            or numel_hint >= num_sm * 2 * 32
        ):
            return ReductionHint.DEFAULT, 1

        # Temporary Reduction built only to inspect its read indices below.
        r = Reduction(
            device,
            dst_dtype,
            inner_fn,
            ranges,
            reduction_ranges,
            reduction_type,
            src_dtype,
            ReductionHint.DEFAULT,
        )

        def get_read_indices(r):
            # Wraps r in a nameless ComputedBuffer to collect its reads.  Returns
            # the indices of reads that mention every non-constant range var,
            # plus whether decide_layout() changed any read buffer's stride.
            cb = ComputedBuffer(
                name=None,
                layout=FlexibleLayout(
                    device=r.get_device(),
                    dtype=r.get_dtype(),
                    size=r.get_size(),
                ),
                data=r,
            )
            read_writes = cb.get_read_writes()
            # try finding the full size producer
            # TODO this will fail for something like ((1, N) * (N, 1)).sum()
            # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare
            # (the comprehension's `r` is local to it; it does not rebind the argument)
            range_vars = [
                r
                for r in read_writes.range_vars
                if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number)
            ]
            indices = []
            changed = False
            for md in sorted(read_writes.reads, key=lambda x: x.name):
                if all(r in md.index.free_symbols for r in range_vars):
                    indices.append(md.index)
                    if md.name in V.graph.name_to_buffer:
                        buf = V.graph.name_to_buffer[md.name]
                        original_stride = buf.layout.stride
                        buf.decide_layout()
                        if buf.layout.stride != original_stride:
                            changed = True
            return indices, changed

        indices, changed = get_read_indices(r)
        if changed:
            # A layout was fixed while collecting reads; recompute the indices.
            indices, _ = get_read_indices(r)

        if len(indices) == 0:
            # TODO determine splits when all inputs are broadcast
            return ReductionHint.DEFAULT, 1

        # NOTE: `ranges` is rebound here to the var -> range mapping returned by
        # index_vars_squeeze; the `ranges` argument is not used after this point.
        (_, reduction_vars), ranges = dependencies.index_vars_squeeze(
            r.get_size(), r.get_reduction_size()
        )
        num_outer = 0
        num_inner = 0
        for i in indices:
            i = V.graph.sizevars.simplify_with_ranges(i, ranges)
            strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys())
            # "outer" read: every stride hint along the reduction vars exceeds 1.
            outer = all(s > 1 for s in strides)
            if outer:
                num_outer += 1
            else:
                num_inner += 1
        # Majority vote; ties go to OUTER.
        if num_inner > num_outer:
            return ReductionHint.INNER, inner_reduction_splits(
                reduction_numel_hint, numel_hint
            )
        else:
            return ReductionHint.OUTER, outer_reduction_splits(
                reduction_numel_hint, numel_hint
            )
891

892
    @staticmethod
    def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type, src_dtype):
        """Convert a small reduction into a pointwise loader.

        The reduction ranges are evaluated to static integers and every point
        of the reduction domain is enumerated in Python. The loaded values are
        folded together with the reduction's combine function, so no
        reduction loop is emitted.

        For ``argmin``/``argmax`` each loaded value is paired with its
        flattened position in the reduction domain. The returned loader yields
        only that position, i.e. element 1 of the folded pair.
        """
        static_ranges = [
            V.graph.sizevars.evaluate_static_shape(size) for size in reduction_ranges
        ]
        combine_fn = get_reduction_combine_fn(reduction_type, src_dtype)
        is_arg_reduction = reduction_type in ("argmin", "argmax")

        if is_arg_reduction:
            # Maps a multi-dimensional reduction index to a flat offset using
            # contiguous strides over the static reduction ranges.
            flatten_index = FixedLayout(
                None,  # type: ignore[arg-type]
                None,  # type: ignore[arg-type]
                static_ranges,
                FlexibleLayout.contiguous_strides(static_ranges),
            ).make_indexer()

            def load_value(index, rindex):
                expanded = [sympy.expand(r) for r in rindex]
                return (
                    inner_fn(index, expanded),
                    ops.index_expr(flatten_index(expanded), torch.int64),
                )

        else:
            load_value = inner_fn

        def unrolled(index):
            all_rindices = itertools.product(*(range(size) for size in static_ranges))
            return functools.reduce(
                combine_fn, (load_value(index, rindex) for rindex in all_rindices)
            )

        if is_arg_reduction:
            return lambda index: unrolled(index)[1]
        return unrolled
931

932
    @classmethod
    def create(  # type: ignore[override]
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable[..., Any],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
        input_node: Optional[IRNode] = None,
    ):
        """Lower a reduction of ``inner_fn`` over ``reduction_ranges`` to IR.

        Special cases are tried in order:

        * Empty reduction domain: a pointwise constant holding the identity of
          ``reduction_type``. Only sum/xor_sum/prod/any are supported.
        * Single-element domain: a pointwise copy, or constant 0 for
          argmin/argmax.
        * Statically small domain (below ``config.unroll_reductions_threshold``)
          whose output is not a single element: a fully unrolled pointwise.

        Otherwise ``num_splits`` decides between:

        * a two-layer reduction reusing ``input_node``'s ranges (split == -1);
        * an evenly split two-layer reduction (split > 1);
        * a single ``Reduction`` wrapped in a ``TensorBox``.

        ``reduction_hint`` is replaced by the computed hint only when it is
        ``ReductionHint.DEFAULT``.
        """
        reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))

        if reduction_numel == 0:
            # N.B. This is a hack to generate the literal of the given type
            # Ideally, we should be fixing `def constant` in triton.py
            # but it breaks due to hardcoded dtypes in other places
            def py_cnst(val):
                return (
                    bool(val)
                    if dst_dtype == torch.bool
                    else float(val)
                    if dst_dtype.is_floating_point
                    else int(val)
                )

            rtypes_to_inits = {
                "sum": py_cnst(0),
                "xor_sum": py_cnst(0),
                "prod": py_cnst(1),
                "any": py_cnst(0),
                # "all" is desugared to `!any(!val)`
            }

            assert (
                reduction_type in rtypes_to_inits.keys()
            ), f"{reduction_type} not supported for zero-dimension tensors!"

            def const_fn(index):
                return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)

            # The reduction's result dtype is dst_dtype, as in every other branch
            # below, and const_fn emits a dst_dtype constant. Tagging the node
            # with src_dtype would mislabel the output whenever the two differ.
            return Pointwise.create(
                device=device,
                dtype=dst_dtype,
                inner_fn=const_fn,
                ranges=list(ranges),
            )

        if reduction_numel == 1:
            # this reduction is actually a pointwise op
            if reduction_type in ("argmin", "argmax"):

                def fn(index):
                    return ops.constant(0, dst_dtype)

            else:

                def fn(index):
                    reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
                    return inner_fn(index, reduction_index)

            return Pointwise.create(device, dst_dtype, fn, ranges)

        if (
            isinstance(reduction_numel, sympy.Integer)
            and V.graph.sizevars.size_hint(reduction_numel)
            < config.unroll_reductions_threshold
            and sympy_product(ranges) != 1
        ):
            # Small static reduction: enumerate it fully inside a pointwise op.
            return Pointwise.create(
                device,
                dst_dtype,
                cls._unroll_reduction_fn(
                    inner_fn, reduction_ranges, reduction_type, src_dtype
                ),
                ranges,
            )

        # triton doesn't support reduce to single element well, so break it up
        hint, split = cls.num_splits(
            device,
            dst_dtype,
            src_dtype,
            inner_fn,
            ranges,
            reduction_ranges,
            reduction_type,
            reduction_numel,
            input_node,
        )
        # intermediate reduction in split can contain complex indexing,
        # and num_splits will fail to correctly set the hint
        # reuse the passed hint if available
        if reduction_hint == ReductionHint.DEFAULT:
            reduction_hint = hint
        if split == -1:
            # Reuse the iteration ranges of the reduction that feeds this one.
            assert input_node is not None
            new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
                input_node  # type: ignore[arg-type]
            )
            assert new_ranges is not None
            assert new_reduction_ranges is not None
            return cls.create_multilayer_existing_ranges(
                device,
                dst_dtype,
                src_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                new_ranges,
                new_reduction_ranges,
                reduction_type,
                reduction_hint,
            )
        elif split > 1:
            # triton doesn't support reduce to single element well, so break it up
            return cls.create_multilayer(
                device,
                dst_dtype,
                src_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                split,
                reduction_hint,
            )

        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                src_dtype,
                reduction_hint,
            )
        )
1074

1075
    @staticmethod
    def default_accumulator(reduction_type, dtype):
        """Return the initial accumulator value for ``reduction_type``.

        max/argmax start from the smallest representable value and min/argmin
        from the largest:

        * floating dtypes: ``-inf`` / ``inf``;
        * bool: ``0`` / ``1``;
        * other dtypes: the ``torch.iinfo`` limits.

        sum, xor_sum and any start at 0, prod at 1, and both Welford variants
        at ``(0, 0, 0)``. Any other reduction type raises ``KeyError``.
        """
        is_max = reduction_type in {"max", "argmax"}
        is_min = reduction_type in {"min", "argmin"}
        if is_max or is_min:
            if is_float_dtype(dtype):
                return float("-inf") if is_max else float("inf")
            if is_boolean_dtype(dtype):
                return 0 if is_max else 1
            limits = torch.iinfo(dtype)
            return limits.min if is_max else limits.max

        fixed_identities = {
            "sum": 0,
            "prod": 1,
            "xor_sum": 0,
            "any": 0,
            "welford_reduce": (0, 0, 0),
            "welford_combine": (0, 0, 0),
        }
        return fixed_identities[reduction_type]
1100

1101
    @staticmethod
    def default_value(reduction_type, dtype):
        """Value substituted for masked-out elements of a reduction's input.

        This matches ``default_accumulator``, except for ``welford_reduce``.
        That reduction consumes a single input value (which it turns into
        three outputs), so a plain ``0`` is returned instead of the
        ``(0, 0, 0)`` accumulator triple.
        """
        return (
            0
            if reduction_type == "welford_reduce"
            else Reduction.default_accumulator(reduction_type, dtype)
        )
1106

1107
    @staticmethod
    def _multilayer_second_step_hint(
        split: int, numel_hint: int, reduction_hint: ReductionHint
    ) -> ReductionHint:
        """Choose the reduction hint for the second layer of a split reduction.

        An OUTER reduction whose second layer is small is downgraded to
        OUTER_TINY. Small means either split <= 512 with numel_hint <= 512,
        or split <= 1024 with numel_hint <= 256.

        ``split == -1`` (a split that reuses existing ranges) and every other
        case keep ``reduction_hint`` unchanged.
        """
        if split == -1 or reduction_hint != ReductionHint.OUTER:
            return reduction_hint
        second_layer_is_tiny = (split <= 512 and numel_hint <= 512) or (
            split <= 1024 and numel_hint <= 256
        )
        return ReductionHint.OUTER_TINY if second_layer_is_tiny else reduction_hint
1123

1124
    @classmethod
    def _multilayer_wrap_loader(
        cls,
        loader,
        reduction_ranges,
        reduction_numel,
        split,
        block_size,
        default,
    ):
        """Adapt ``loader`` to the first layer of an evenly split reduction.

        The returned ``wrapper_fn(index, reduction_index)`` expects:

        * ``index``: the original pointwise index plus one trailing
          component selecting the reduction block;
        * ``reduction_index``: a single offset within that block.

        The flat position ``block_size * block + offset`` is mapped back onto
        ``reduction_ranges`` before ``loader`` is called.

        ``reduction_numel`` may not be statically known to be divisible by
        ``split``. In that case positions at or beyond ``reduction_numel``
        are masked and produce ``default`` instead.
        """
        reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])
        need_mask = not V.graph.sizevars.is_expr_static_and_true(
            sympy.Eq(reduction_numel % split, 0)  # type: ignore[arg-type]
        )
        # Flat positions reach block_size * split - 1 <= reduction_numel + split - 2,
        # which can exceed the int32 range for very large (or unbounded dynamic)
        # reductions. Compare in int32 only when that bound provably fits;
        # otherwise use int64 so wrap-around cannot corrupt the mask.
        mask_dtype = torch.int64
        if need_mask and V.graph.sizevars.is_expr_static_and_true(
            sympy.Lt(  # type: ignore[arg-type]
                reduction_numel + split, torch.iinfo(torch.int32).max
            )
        ):
            mask_dtype = torch.int32

        def wrapper_fn(index, reduction_index):
            (reduction_index,) = reduction_index
            *new_index, reduction_block = index
            indices = block_size * reduction_block + reduction_index

            def body():
                return loader(new_index, reindex([indices]))

            if need_mask:
                mask = ops.lt(
                    ops.index_expr(indices, mask_dtype),
                    ops.index_expr(reduction_numel, mask_dtype),
                )
                return ops.masked(mask, body, default)
            else:
                return body()

        return wrapper_fn
1157

1158
    @classmethod
    def _multilayer_wrap_loader_existing_ranges(
        cls,
        loader,
        original_ranges,
        original_reduction_ranges,
        new_ranges,
        new_reduction_ranges,
        default,
    ):
        """Adapt ``loader`` to a first-layer reduction over borrowed ranges.

        Only full reductions (``original_ranges`` empty) are supported. The
        first layer iterates ``new_ranges`` x ``new_reduction_ranges``. The
        concatenated index is reshaped back onto ``original_reduction_ranges``,
        and ``loader`` is then called with an empty pointwise index.

        ``default`` is accepted but not used: no masking is applied here.
        """
        assert len(original_ranges) == 0, f"{original_ranges}= is not equal to []"
        combined_ranges = tuple(new_ranges) + tuple(new_reduction_ranges)
        reindex = View.dynamic_reshape_indexer(
            original_reduction_ranges, combined_ranges
        )

        def wrapper_fn(index, reduction_index):
            flat_index = (*index, *reduction_index)
            return loader([], reindex(flat_index))

        return wrapper_fn
1177

1178
    @classmethod
    def create_multilayer_helper(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        wrapper_fn: Callable[..., Any],
        original_ranges: List[Expr],
        original_reduction_ranges: List[Expr],
        new_ranges: List[Expr],
        new_reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """Build a two-layer reduction from an already-wrapped first-layer loader.

        Layer 1 reduces ``wrapper_fn`` over ``new_ranges`` x
        ``new_reduction_ranges`` and is realized into a buffer.

        Layer 2 reduces that buffer down to ``original_ranges``. Its reduction
        ranges are the trailing entries of ``new_ranges`` beyond
        ``original_ranges``.
        """
        # triton will automatically compute reductions in fp32 if reducing over
        # fp16/bf16 within the kernel. Keep the intermediate in fp32 as well, so
        # breaking the kernel into multiple layers does not reduce precision.
        if dst_dtype in (torch.float16, torch.bfloat16):
            intermediate_dtype = torch.float
        else:
            intermediate_dtype = dst_dtype

        first_layer = Reduction.create(
            device,
            intermediate_dtype,
            src_dtype,
            wrapper_fn,
            new_ranges,
            new_reduction_ranges,
            reduction_type,
            reduction_hint,
        )
        first_layer.realize()
        load_first_layer = first_layer.make_loader()

        def second_layer_fn(index, reduction_index):
            return load_first_layer([*index, *reduction_index])

        numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges))
        second_layer_hint = cls._multilayer_second_step_hint(
            split, numel_hint, reduction_hint
        )

        num_kept = len(original_ranges)
        assert original_ranges == new_ranges[:num_kept]
        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                second_layer_fn,
                original_ranges,
                new_ranges[num_kept:],
                reduction_type,
                src_dtype,
                second_layer_hint,
            )
        )
1239

1240
    @classmethod
    def create_multilayer(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable[..., Any],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """Split a large reduction into ``split`` blocks reduced in two layers.

        Layer 1 reduces each contiguous block of ceil(numel / split) flattened
        elements, masked where needed. Layer 2 reduces the ``split`` partial
        results down to ``ranges``.
        """
        # TODO(jansel): realize the reduction so we can do dynamic indexing
        total_numel = sympy_product(reduction_ranges)
        per_block = FloorDiv(total_numel + (split - 1), split)  # ceil division
        first_layer_fn = cls._multilayer_wrap_loader(
            inner_fn,
            reduction_ranges,
            total_numel,
            split,
            per_block,
            cls.default_value(reduction_type, dst_dtype),
        )
        return cls.create_multilayer_helper(
            device,
            dst_dtype,
            src_dtype,
            first_layer_fn,
            ranges,
            reduction_ranges,
            [*ranges, split],  # type: ignore[list-item]
            [per_block],
            reduction_type,
            split,
            reduction_hint,
        )
1278

1279
    @classmethod
    def create_multilayer_existing_ranges(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable[..., Any],
        original_ranges: List[Expr],
        original_reduction_ranges: List[Expr],
        new_ranges: List[Expr],
        new_reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint,
    ):
        """Two-layer reduction that borrows its first-layer iteration ranges.

        The first layer iterates ``new_ranges`` x ``new_reduction_ranges``
        rather than an evenly sized split. The second layer reduces over the
        entries of ``new_ranges`` beyond ``original_ranges``.
        """
        first_layer_fn = cls._multilayer_wrap_loader_existing_ranges(
            inner_fn,
            original_ranges,
            original_reduction_ranges,
            new_ranges,
            new_reduction_ranges,
            cls.default_value(reduction_type, dst_dtype),
        )
        # split=-1 marks an existing-ranges split; the helper then keeps
        # reduction_hint unchanged for the second layer.
        return cls.create_multilayer_helper(
            device=device,
            dst_dtype=dst_dtype,
            src_dtype=src_dtype,
            wrapper_fn=first_layer_fn,
            original_ranges=original_ranges,
            original_reduction_ranges=original_reduction_ranges,
            new_ranges=new_ranges,
            new_reduction_ranges=new_reduction_ranges,
            reduction_type=reduction_type,
            split=-1,
            reduction_hint=reduction_hint,
        )
1319

1320

1321
def num_reduction_outputs(reduction_type):
    """Return how many values a reduction of ``reduction_type`` produces.

    Welford reductions yield a (mean, m2, weight) triple; every other
    reduction yields a single value.
    """
    if "welford" in reduction_type:
        return 3
    return 1
1323

1324

1325
class WelfordReduction(Reduction):
    """One output of a Welford (mean, m2, weight) reduction.

    ``create`` builds three of these nodes from the same loaders and
    reduction settings. They differ only in ``output_index``, which selects
    the element of the ``ops.reduction`` result that the node stores.
    """

    # Index (0, 1 or 2) into the Welford result tuple stored by this node.
    output_index: int

    def __init__(
        self,
        device,
        dtype,
        inner_fns,
        ranges,
        reduction_ranges,
        reduction_type,
        reduction_hint,
        output_index,
    ):
        # A single loader is used as-is. Several loaders (the partial
        # mean/m2/weight inputs of welford_combine) are fused into one loader
        # that returns a tuple of their values.
        if len(inner_fns) == 1:
            loader = inner_fns[0]
        else:

            def loader(idx, reduction_idx):
                return tuple(fn(idx, reduction_idx) for fn in inner_fns)

        super().__init__(
            device,
            dtype,
            loader,
            ranges,
            reduction_ranges,
            reduction_type,
            dtype,  # src_dtype: the Welford inputs share the output dtype
            reduction_hint,
        )
        self.output_index = output_index

    def store_reduction(self, output_name, indexer, vars, reduction_vars):
        # ops.reduction returns every Welford output; store only ours.
        values = ops.reduction(
            self.dtype,
            self.src_dtype,
            self.reduction_type,
            self.inner_fn(vars, reduction_vars),
        )
        value = values[self.output_index]
        return ops.store_reduction(output_name, indexer(vars), value)

    @classmethod
    def create(  # type: ignore[override]
        cls,
        device: torch.device,
        dtype: torch.dtype,
        inner_fns: Sequence[Callable[..., Any]],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
    ):
        """Create the three (mean, m2, weight) outputs of a Welford reduction.

        ``welford_reduce`` takes one loader of raw values. ``welford_combine``
        takes three loaders of partial (mean, m2, weight) values to merge.

        Outcomes by reduction size:

        * empty: three zero constants;
        * single element: pointwise copies (``m2 = 0`` and ``weight = 1``
          for ``welford_reduce``);
        * large enough to split: a two-layer reduction via
          ``create_multilayer``;
        * otherwise: three realized ``WelfordReduction`` nodes.
        """
        assert reduction_type in {"welford_reduce", "welford_combine"}

        reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))

        def const(val):
            def inner_fn(idx):
                return ops.constant(
                    val,
                    dtype,
                )

            return Pointwise.create(
                device=device,
                dtype=dtype,
                inner_fn=inner_fn,
                ranges=list(ranges),
            )

        if reduction_numel == 0:
            mean = const(0)
            m2 = const(0)
            weight = const(0)
            return mean, m2, weight

        if reduction_numel == 1:

            def copy(loader):
                def inner_fn(idx):
                    reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
                    return loader(idx, reduction_index)

                return Pointwise.create(
                    device=device,
                    dtype=dtype,
                    inner_fn=inner_fn,
                    ranges=list(ranges),
                )

            if reduction_type == "welford_reduce":
                return copy(inner_fns[0]), const(0), const(1)
            else:
                return tuple(copy(fn) for fn in inner_fns)

        # TODO: unroll small reductions into a pointwise op, as Reduction.create
        # does via _unroll_reduction_fn.

        # triton doesn't support reduce to single element well, so break it up
        hint, split = Reduction.num_splits(
            device,
            dtype,
            dtype,
            inner_fns[0],
            ranges,
            reduction_ranges,
            reduction_type=reduction_type,
            reduction_numel=reduction_numel,
        )
        # intermediate reduction in split can contain complex indexing,
        # and num_splits will fail to correctly set the hint
        # reuse the passed hint if available
        if reduction_hint == ReductionHint.DEFAULT:
            reduction_hint = hint
        if split > 1:
            # triton doesn't support reduce to single element well, so break it up
            return cls.create_multilayer(
                device,
                dtype,
                inner_fns,
                ranges,
                reduction_ranges,
                reduction_type,
                split,
                reduction_hint,
            )

        results = [
            TensorBox.create(
                WelfordReduction(
                    device,
                    dtype,
                    inner_fns,
                    ranges,
                    reduction_ranges,
                    reduction_type,
                    reduction_hint,
                    output_idx,
                )
            )
            for output_idx in range(3)
        ]
        for t in results:
            t.realize()
        return results

    @staticmethod
    def default_value(reduction_type, dtype):
        # Welford accumulators always start from an all-zero triple.
        return (0, 0, 0)

    @classmethod
    def create_multilayer(  # type: ignore[override]
        cls,
        device: torch.device,
        dtype: torch.dtype,
        inner_fns: Sequence[Callable[..., Any]],
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        recursively
        """
        reduction_numel = sympy_product(reduction_ranges)
        need_mask = not V.graph.sizevars.is_expr_static_and_true(
            sympy.Eq(reduction_numel % split, 0)  # type: ignore[arg-type]
        )

        if need_mask and reduction_type != "welford_combine":
            # If we need mask, then "welford_reduce" doesn't work because
            # masked inputs shouldn't count towards the welford weight.
            # Instead, feed each element as a (value, m2=0, weight=1) triple to
            # welford_combine; masked elements then load as weight 0.

            def constant(idx, reduction_idx, value):
                return ops.constant(value, dtype)

            return cls.create_multilayer(
                device=device,
                dtype=dtype,
                inner_fns=(
                    inner_fns[0],
                    partial(constant, value=0),
                    partial(constant, value=1),
                ),
                ranges=ranges,
                reduction_ranges=reduction_ranges,
                reduction_type="welford_combine",
                split=split,
                reduction_hint=reduction_hint,
            )

        block_size = FloorDiv(reduction_numel + (split - 1), split)
        intermediates = WelfordReduction.create(
            device,
            dtype,
            tuple(
                cls._multilayer_wrap_loader(
                    loader,
                    reduction_ranges,
                    reduction_numel,
                    split,
                    block_size,
                    default=0,
                )
                for loader in inner_fns
            ),
            [*ranges, split],  # type: ignore[list-item]
            [block_size],
            reduction_type,
            reduction_hint,
        )
        for i in intermediates:
            i.realize()

        # One loader per realized partial output. Each intermediate's trailing
        # `split` dimension becomes the second layer's reduction dimension.
        i_loaders = [i.make_loader() for i in intermediates]

        def intermediate_loader_fn(index, reduction_index, loader):
            return loader([*index, *reduction_index])

        numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
        reduction_hint = cls._multilayer_second_step_hint(
            split, numel_hint, reduction_hint
        )
        return WelfordReduction.create(
            device,
            dtype,
            tuple(
                partial(intermediate_loader_fn, loader=loader) for loader in i_loaders
            ),
            ranges,
            [split],  # type: ignore[list-item]
            # welford_reduce turns one input into three outputs, which are combined with welford_combine
            "welford_combine",
            reduction_hint,
        )
1578

1579

1580
@dataclasses.dataclass
class Scan(Loops):
    """Scan of ``inner_fn`` along one axis, combined with ``combine_fn``.

    The fields are:

    * ``ranges``: the pointwise (non-scanned) dimensions;
    * ``scan_ranges``: the scanned dimension;
    * ``size``: the full output shape;
    * ``reindex``: merges a pointwise index and a scan index back into an
      index over ``size``.

    ``init`` is passed to ``ops.scan`` together with ``combine_fn``.
    """

    scan_ranges: List[Expr]
    size: List[Expr]
    combine_fn: Callable[..., Any]
    reindex: Callable[[List[Expr], List[Expr]], List[Expr]]
    reduction_hint: ReductionHint
    init: int

    # HACK: we mimic a reduction (store_reduction, get_reduction_type,
    # get_reduction_size, get_pointwise_size).

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
        # need to explicitly represent the closure so we can pull out unbacked
        # symbols here
        return (
            super().get_unbacked_symbol_uses()
            | set().union(*(free_unbacked_symbols(e) for e in self.scan_ranges))
            | set().union(*(free_unbacked_symbols(e) for e in self.size))
        )

    def __post_init__(self):
        # Every output dimension is either pointwise or scanned.
        assert len(self.ranges) + len(self.scan_ranges) == len(self.size)
        super().__post_init__()

    def store_reduction(self, output_name, indexer, vars, scan_vars):
        idx = self.reindex(vars, scan_vars)
        value = self.inner_fn(idx)
        result = ops.scan(self.dtype, self.combine_fn, value, self.init)
        return ops.store(output_name, indexer(idx), result)

    def get_reduction_type(self):
        # return self.scan_op
        return "custom"

    def get_reduction_size(self):
        return self.scan_ranges

    def get_size(self):
        return self.size

    def get_pointwise_size(self):
        return self.ranges

    def index_length(self):
        return len(self.ranges) + len(self.scan_ranges)

    def inner_fn_args(self):
        index = self._index(self.ranges)
        rindex = self._index(self.scan_ranges, "r")
        idx = self.reindex(index, rindex)
        return (idx,)

    def inner_fn_free_unbacked_symbols(self):
        # Same symbolic index that inner_fn_args builds.
        return extract_free_unbacked_symbols(self.inner_fn, *self.inner_fn_args())

    @classmethod
    def create(
        cls,
        device: torch.device,
        dtype: torch.dtype,
        inner_fn: Callable[[List[Expr]], Any],
        size: List[Expr],
        axis: int,
        combine_fn: Callable[..., Any],
        init: Any,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
    ) -> Optional["TensorBox"]:
        """Build a realized scan of ``inner_fn`` along ``size[axis]``.

        Returns ``None`` for non-CUDA devices and for split scans on ROCm, so
        the caller must fall back. A scan over at most one element becomes a
        pointwise copy.

        ``reduction_hint`` is kept when passed explicitly. Left at its default,
        it is replaced by the hint ``num_splits`` computes, matching
        ``Reduction.create``.
        """
        pointwise_ranges = [*size[:axis], *size[axis + 1 :]]
        scan_ranges = [size[axis]]

        if device.type != "cuda":
            # TODO: CPU support
            return None

        sizevars = V.graph.sizevars
        scan_numel = sizevars.simplify(sympy_product(scan_ranges))

        # Scan with a single element is just a copy
        if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)):  # type: ignore[arg-type]
            return Pointwise.create(
                device=device,
                dtype=dtype,
                inner_fn=inner_fn,
                ranges=size,
            )

        hint, num_splits = cls.num_splits(
            device=device,
            dtype=dtype,
            inner_fn=inner_fn,
            axis=axis,
            pointwise_ranges=pointwise_ranges,
            scan_ranges=scan_ranges,
            combine_fn=combine_fn,
            scan_numel=scan_numel,
        )
        # Previously the caller's hint was silently discarded; honour it unless
        # it is the default, as Reduction.create does.
        if reduction_hint == ReductionHint.DEFAULT:
            reduction_hint = hint
        scan_type = Scan if num_splits <= 1 else SplitScan

        if num_splits > 1 and torch.version.hip is not None:
            # Fallback for split-scan on ROCm
            return None

        def reindex(index, scan_index):
            # Re-insert the scanned dimension at position `axis`.
            assert len(scan_index) == len(scan_ranges)
            assert len(index) == len(pointwise_ranges)
            return [*index[:axis], *scan_index, *index[axis:]]

        result = TensorBox.create(
            scan_type(
                device=device,
                dtype=dtype,
                inner_fn=inner_fn,
                size=size,
                ranges=pointwise_ranges,
                scan_ranges=scan_ranges,
                combine_fn=combine_fn,
                reindex=reindex,
                init=init,
                reduction_hint=reduction_hint,
            )
        )
        result.realize()
        return result

    @classmethod
    def num_splits(
        cls,
        device: torch.device,
        dtype: torch.dtype,
        inner_fn: Callable[[List[Expr]], Any],
        axis: int,
        pointwise_ranges: List[Expr],
        scan_ranges: List[Expr],
        combine_fn: Callable[..., Any],
        scan_numel: Expr,
    ):
        """Return ``(reduction_hint, num_splits)`` for this scan.

        The scan is presented to ``Reduction.num_splits`` as a "sum" reduction
        over ``scan_ranges``. ``combine_fn`` is currently unused.
        """

        # TODO: custom splitting heuristic for scan
        def wrapper_fn(idx, reduction_idx):
            return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]])

        return Reduction.num_splits(
            device=device,
            dst_dtype=dtype,
            src_dtype=dtype,
            inner_fn=wrapper_fn,
            ranges=pointwise_ranges,
            reduction_ranges=scan_ranges,
            reduction_type="sum",
            reduction_numel=scan_numel,
        )
1734

1735

1736
@dataclasses.dataclass
class SplitScan(Scan):
    """A scan op that should go through TritonSplitScanKernel codegen on CUDA.

    It is structurally identical to ``Scan``. ``Scan.create`` selects this
    subclass when ``num_splits`` reports more than one split.
    """
1740

1741

1742
def is_storage_and_layout(x):
1743
    try:
1744
        as_storage_and_layout(x, freeze=False)
1745
        return True
1746
    except NotImplementedError:
1747
        return False
1748

1749

1750
def is_contiguous_storage_and_layout(x):
1751
    try:
1752
        buffer, layout = as_storage_and_layout(x, freeze=False)
1753
        return layout.is_contiguous()
1754
    except NotImplementedError:
1755
        return False
1756

1757

1758
def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=None):
1759
    """Try to simplify x into a StorageBox and a Layout"""
1760
    if isinstance(x, TensorBox):
1761
        return as_storage_and_layout(
1762
            x.data,
1763
            freeze=freeze,
1764
            want_contiguous=want_contiguous,
1765
            stride_order=stride_order,
1766
        )
1767
    if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
1768
        if freeze:
1769
            if want_contiguous:
1770
                x.data.freeze_layout()
1771
                assert x.data.layout.is_contiguous()
1772
            elif stride_order is not None:
1773
                x.data.freeze_layout_with_stride_order(stride_order)
1774
            else:
1775
                x.data.decide_layout()
1776
        return x, x.data.layout
1777
    if isinstance(x, ReinterpretView):
1778
        # making the base of x contiguous or stride_ordered will not necessarily make
1779
        # the ReinterpretView either, so don't pass along those arguments
1780
        buffer, _ = as_storage_and_layout(
1781
            x.data,
1782
            freeze=freeze,
1783
        )
1784
        return buffer, x.layout
1785
    raise NotImplementedError
1786

1787

1788
as_contiguous_storage_and_layout = functools.partial(
1789
    as_storage_and_layout, want_contiguous=True
1790
)
1791

1792

1793
def is_stride_order_storage_and_layout(x, stride_order):
1794
    try:
1795
        buffer, layout = as_storage_and_layout(x, freeze=False)
1796
        return layout.is_stride_ordered(stride_order)
1797
    except NotImplementedError:
1798
        return False
1799

1800

1801
@dataclasses.dataclass
1802
class BaseView(IRNode):
1803
    data: IRNode
1804

1805
    def get_unbacked_symbol_uses(self):
1806
        return self.data.get_unbacked_symbol_uses()
1807

1808
    def make_reindexer(self):
1809
        raise NotImplementedError(f"make_reindexer NYI on {self}")
1810

1811
    def make_indexer(self):
1812
        inner = self.data.make_indexer()
1813
        reindex = self.make_reindexer()
1814

1815
        def indexer(idx):
1816
            return inner(reindex(idx))
1817

1818
        return indexer
1819

1820
    def make_loader(self):
1821
        inner = self.data.make_loader()
1822
        reindex = self.make_reindexer()
1823

1824
        def loader(idx):
1825
            return inner(reindex(idx))
1826

1827
        return loader
1828

1829
    @property
1830
    def dtype(self):
1831
        return self.data.dtype
1832

1833
    def get_layout(self):
1834
        return self.data.get_layout()
1835

1836
    def get_device(self):
1837
        return self.data.get_device()
1838

1839
    def get_origin_node(self):
1840
        return None
1841

1842
    def get_name(self):
1843
        return self.data.get_name()
1844

1845
    def get_pointwise_size(self):
1846
        return self.get_size()
1847

1848
    def mark_reuse(self, users):
1849
        return self.data.mark_reuse(users)
1850

1851
    def has_exceeded_max_reads(self):
1852
        return self.data.has_exceeded_max_reads()
1853

1854
    def realize(self):
1855
        return self.data.realize()
1856

1857
    def realize_hint(self):
1858
        return self.data.realize_hint()
1859

1860
    def get_storage_numel(self):
1861
        return self.data.get_storage_numel()
1862

1863
    def is_extern(self):
1864
        return self.data.is_extern()  # type: ignore[attr-defined]
1865

1866
    def get_reads(self):
1867
        with patch.object(FlexibleLayout, "allow_indexing", True):
1868
            return extract_read_writes(
1869
                self.make_loader(),
1870
                self.get_size(),
1871
            ).reads
1872

1873
    def unwrap_view(self):
1874
        x: IRNode = self
1875
        while isinstance(x, BaseView):
1876
            x = x.data
1877
        return x
1878

1879
    def constant_to_device(self, device):
1880
        """Move this to a given device. Requires that all reads are to constants."""
1881
        loader = self.make_loader()
1882
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
1883
        return Pointwise(device, self.get_dtype(), loader, self.get_size())
1884

1885

1886
@dataclasses.dataclass
1887
class ExpandView(BaseView):
1888
    size: List[Expr]
1889

1890
    @staticmethod
1891
    def _normalize_size(x, new_size):
1892
        """Replace `-1` with correct sizes"""
1893
        new_size = list(map(sympy.expand, new_size))
1894
        old_size = x.get_size()
1895
        old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
1896
        assert len(new_size) == len(old_size)
1897
        for i in range(len(new_size)):
1898
            if new_size[i] == -1:
1899
                assert old_size[i] is not None
1900
                new_size[i] = old_size[i]
1901
            elif old_size[i] is None or old_size[i] == 1:
1902
                pass
1903
            else:
1904
                # Expect broadcast compatibility
1905
                new_size[i] = V.graph.sizevars.expect_equals(
1906
                    new_size[i],
1907
                    old_size[i],
1908
                    msg=f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}",
1909
                )
1910
        return new_size
1911

1912
    @classmethod
1913
    def create(cls, x, new_size):
1914
        new_size = cls._normalize_size(x, new_size)
1915

1916
        if is_storage_and_layout(x):
1917
            storage, old_layout = as_storage_and_layout(x)
1918
            skip = len(new_size) - len(old_layout.size)
1919
            assert skip >= 0
1920
            new_stride = [sympy.Integer(0)] * skip
1921
            for stride, size in zip(old_layout.stride, old_layout.size):
1922
                new_stride.append(stride if size != 1 else sympy.Integer(0))
1923
            new_layout = FixedLayout(
1924
                old_layout.device,
1925
                old_layout.dtype,
1926
                list(new_size),
1927
                new_stride,
1928
                old_layout.offset,
1929
            )
1930
            return ReinterpretView(storage, new_layout)
1931

1932
        return ExpandView(x, new_size)
1933

1934
    def get_size(self):
1935
        return self.size
1936

1937
    def make_reindexer(self):
1938
        target = self.get_size()
1939
        actual = self.data.get_size()
1940
        skip = len(target) - len(actual)
1941

1942
        def reindex(index):
1943
            index = list(index[skip:])
1944
            assert len(index) == len(actual)
1945
            for i in range(len(actual)):
1946
                if actual[i] == 1:
1947
                    # zero out broadcast dimension
1948
                    index[i] = sympy.Integer(0)
1949
            return index
1950

1951
        return reindex
1952

1953

1954
@dataclasses.dataclass
1955
class PermuteView(BaseView):
1956
    dims: List[Expr]
1957

1958
    @classmethod
1959
    def create(cls, x, dims):
1960
        dims = cls._map_neg_dims(dims)
1961
        assert set(dims) == set(range(len(dims)))
1962

1963
        if is_storage_and_layout(x):
1964
            storage, old_layout = as_storage_and_layout(x)
1965
            new_layout = FixedLayout(
1966
                old_layout.device,
1967
                old_layout.dtype,
1968
                [old_layout.size[i] for i in dims],
1969
                [old_layout.stride[i] for i in dims],
1970
                old_layout.offset,
1971
            )
1972
            return ReinterpretView(storage, new_layout)
1973

1974
        return PermuteView(x, dims)
1975

1976
    @classmethod
1977
    def _map_neg_dims(cls, dims):
1978
        return [dim if dim >= 0 else len(dims) + dim for dim in dims]
1979

1980
    def get_size(self):
1981
        assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims)))
1982
        size = self.data.get_size()
1983
        return [size[i] for i in self.dims]
1984

1985
    def make_reindexer(self):
1986
        inv = {j: i for i, j in enumerate(self.dims)}
1987
        inv = [inv[i] for i in range(len(self.dims))]  # type: ignore[index]
1988
        assert set(inv) == set(range(len(self.dims)))
1989

1990
        def reindex(index):
1991
            return [index[i] for i in inv]
1992

1993
        return reindex
1994

1995

1996
class SqueezeView(BaseView):
1997
    @classmethod
1998
    def create(cls, x, *, dim=None):
1999
        if is_storage_and_layout(x):
2000
            storage, old_layout = as_storage_and_layout(x)
2001
            new_size = []
2002
            new_stride = []
2003
            if dim is not None:
2004
                assert isinstance(dim, int), "expected integer dim argument"
2005
                assert 0 <= dim and dim < len(old_layout.size)
2006

2007
            for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)):
2008
                if dim is None:
2009
                    if size != 1:
2010
                        new_size.append(size)
2011
                        new_stride.append(stride)
2012
                else:
2013
                    if i != dim:
2014
                        new_size.append(size)
2015
                        new_stride.append(stride)
2016
                    else:
2017
                        assert size == 1, "expected squeezed size to be 1"
2018

2019
            new_layout = FixedLayout(
2020
                old_layout.device,
2021
                old_layout.dtype,
2022
                new_size,
2023
                new_stride,
2024
                old_layout.offset,
2025
            )
2026
            return ReinterpretView(storage, new_layout)
2027

2028
        if dim is None:
2029
            # redirect to a generic view
2030
            return View.create(x, [s for s in x.get_size() if s != 1])
2031
        else:
2032
            assert x.get_size()[dim] == 1
2033
            return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim])
2034

2035
    @staticmethod
2036
    def squeezer(size: Tuple[sympy.Expr, ...]):
2037
        new_size = [s for s in size if s != 1]
2038
        not_one = [i for i, s in enumerate(size) if s != 1]
2039
        length = len(size)
2040

2041
        def reindex(index: List[sympy.Expr]) -> Tuple[sympy.Expr, ...]:
2042
            assert len(index) == len(not_one), f"{index} {not_one}"
2043
            new_index = [sympy.Integer(0)] * length
2044
            for idx, s in zip(not_one, index):
2045
                new_index[idx] = s
2046
            return tuple(new_index)
2047

2048
        return new_size, reindex
2049

2050
    def __init__(self, data):
2051
        raise AssertionError("use SqueezeView.create()")
2052

2053

2054
@dataclasses.dataclass
2055
class GenericView(BaseView):
2056
    size: List[Expr]
2057
    reindex: Callable[..., Any]
2058

2059
    def make_reindexer(self):
2060
        return self.reindex
2061

2062
    def reindex_str(self):
2063
        index_old = [sympy_index_symbol(f"i{n}") for n in range(len(self.size))]
2064
        index_new = list(self.reindex(index_old))
2065
        return f"lambda {', '.join(map(str, index_old))}: {index_new}"
2066

2067
    def __str__(self):
2068
        return self.str_helper(
2069
            [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
2070
        )
2071

2072
    __repr__ = __str__
2073

2074
    @classmethod
2075
    def create(cls, x, new_size, reindex):
2076
        return cls(x, list(new_size), reindex)
2077

2078
    def get_size(self):
2079
        return self.size
2080

2081

2082
@dataclasses.dataclass
2083
class View(GenericView):
2084
    @staticmethod
2085
    def handle_negative_index(idx, size):
2086
        idx = sympy.expand(idx)
2087
        size = sympy.expand(size)
2088
        evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr
2089
        if evaluate_expr(sympy.Lt(idx, 0)):
2090
            idx = idx + size
2091
        return idx
2092

2093
    @classmethod
2094
    def create(cls, x, new_size):
2095
        assert isinstance(new_size, (tuple, list))
2096
        old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)
2097

2098
        # Skip pointless views
2099
        if V.graph.sizevars.statically_known_list_equals(old_size, new_size):
2100
            return x
2101

2102
        unbacked_symbols_in_sizes = False
2103
        if (
2104
            len(free_unbacked_symbols(old_size)) > 0
2105
            or len(free_unbacked_symbols(new_size)) > 0
2106
        ):
2107
            unbacked_symbols_in_sizes = True
2108

2109
        if 0 in new_size:
2110

2111
            def fake_reindex(index):
2112
                return tuple([0] * len(old_size))
2113

2114
            return cls(x, list(new_size), fake_reindex)
2115
        # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout
2116
        elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes:
2117
            if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)):
2118
                # realize x; otherwise, the dynamic_reshape_indexer below will fail
2119
                # due to the size_hint's inability to process unbacked SymInts
2120
                x = ExternKernel.realize_input(x)
2121

2122
            storage, old_layout = as_contiguous_storage_and_layout(x)
2123
            new_layout = FixedLayout(
2124
                old_layout.device,
2125
                old_layout.dtype,
2126
                new_size,
2127
                FlexibleLayout.contiguous_strides(new_size),
2128
                old_layout.offset,
2129
            )
2130
            return ReinterpretView(storage, new_layout)
2131

2132
        reindex = cls.dynamic_reshape_indexer(old_size, new_size)
2133
        return cls(x, list(new_size), reindex)
2134

2135
    @staticmethod
2136
    def resolve_negative_size(old_size, new_size):
2137
        new_size = [V.graph.sizevars.simplify(x) for x in new_size]
2138
        old_size = [V.graph.sizevars.simplify(x) for x in old_size]
2139

2140
        new_size = list(new_size)
2141
        for i in range(len(new_size)):
2142
            if new_size[i] == -1:
2143
                new_size[i] = sympy.Integer(1)
2144
                new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
2145
                break
2146

2147
        V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size))
2148
        return old_size, new_size
2149

2150
    @classmethod
2151
    def dynamic_reshape_indexer(cls, old_size, new_size):
2152
        try:
2153
            reindex = cls._dynamic_reshape_indexer(old_size, new_size)
2154
        except (AssertionError, IndexError):
2155
            # optimistic algorithm failed, lets do a fallback
2156
            flat = [sympy_product(old_size)]
2157
            reindex1 = cls._dynamic_reshape_indexer(old_size, flat)
2158
            reindex2 = cls._dynamic_reshape_indexer(flat, new_size)
2159
            reindex = fuse_reindexing(reindex1, reindex2)
2160
        return reindex
2161

2162
    @staticmethod
2163
    def _dynamic_reshape_indexer(old_size, new_size):
2164
        """
2165
        Perform a reshape entirely by modifying indexing math
2166
        """
2167
        size_hint = V.graph.sizevars.size_hint
2168
        vars = [sympy_index_symbol(f"view{i}") for i in range(len(new_size))]
2169

2170
        stack_new = list(zip(vars, new_size))
2171
        stack_old = list(old_size)
2172

2173
        view_expr = []
2174
        while stack_new and stack_old:
2175
            size_old = stack_old.pop()
2176
            var, size_new = stack_new.pop()
2177
            if size_old == 1:
2178
                view_expr.append(sympy.Integer(0))
2179
                stack_new.append((var, size_new))  # re-add
2180
            elif size_new == 1:
2181
                stack_old.append(size_old)  # re-add
2182
            elif size_hint(size_new) == size_hint(size_old):
2183
                view_expr.append(var)
2184
                V.graph.sizevars.guard_equals(size_new, size_old)
2185
            elif size_hint(size_new) < size_hint(size_old):
2186
                while size_hint(size_new) < size_hint(size_old):
2187
                    var2, size_new2 = stack_new.pop()
2188
                    var = var2 * size_new + var
2189
                    size_new = size_new * size_new2
2190
                view_expr.append(var)
2191
                V.graph.sizevars.guard_equals(size_new, size_old)
2192
            elif size_hint(size_new) > size_hint(size_old):
2193
                divisor = sympy.Integer(1)
2194
                modulus = size_old
2195
                view_expr.append(ModularIndexing(var, divisor, modulus))
2196
                divisor = divisor * modulus
2197
                while size_hint(size_new) > size_hint(size_old):
2198
                    modulus = stack_old.pop()
2199
                    view_expr.append(ModularIndexing(var, divisor, modulus))
2200
                    divisor = divisor * modulus
2201
                    size_old = size_old * modulus
2202
                V.graph.sizevars.guard_equals(size_new, size_old)
2203
            else:
2204
                raise AssertionError()
2205

2206
        while stack_old:
2207
            size_old = stack_old.pop()
2208
            V.graph.sizevars.guard_equals(size_old, 1)  # type: ignore[arg-type]
2209
            view_expr.append(sympy.Integer(0))
2210

2211
        while stack_new:
2212
            var, size_new = stack_new.pop()
2213
            V.graph.sizevars.guard_equals(size_new, 1)  # type: ignore[arg-type]
2214

2215
        view_expr = list(reversed(view_expr))
2216
        assert len(view_expr) == len(old_size)
2217

2218
        def reindex(index):
2219
            assert len(index) == len(vars), (len(index), len(vars))
2220
            replacements = dict(zip(vars, index))
2221
            return tuple(sympy_subs(x, replacements) for x in view_expr)  # type: ignore[arg-type]
2222

2223
        return reindex
2224

2225

2226
@dataclasses.dataclass
2227
class ReinterpretView(BaseView):
2228
    """Pretend our storage has a different layout"""
2229

2230
    layout: "Layout"
2231

2232
    def __post_init__(self):
2233
        super().__post_init__()
2234
        if isinstance(self.data, BaseView):
2235
            self.data = self.data.unwrap_view()
2236

2237
    def __str__(self):
2238
        return self.str_helper(
2239
            [
2240
                self.data,
2241
                self.layout,
2242
            ]
2243
        )
2244

2245
    __repr__ = __str__
2246

2247
    def get_name(self):
2248
        return self.data.get_name()
2249

2250
    def get_device(self):
2251
        return self.layout.device
2252

2253
    def get_origin_node(self):
2254
        return None
2255

2256
    @property
2257
    def dtype(self):
2258
        return self.layout.dtype
2259

2260
    def get_size(self):
2261
        return list(self.layout.size)
2262

2263
    def get_stride(self):
2264
        return list(self.layout.stride)
2265

2266
    def make_loader(self):
2267
        def loader(index):
2268
            indexer = self.layout.make_indexer()
2269
            return ops.load(self.get_name(), indexer(index))
2270

2271
        return loader
2272

2273
    def make_indexer(self):
2274
        return self.layout.make_indexer()
2275

2276
    def get_layout(self):
2277
        return self.layout
2278

2279
    def freeze_layout(self):
2280
        pass
2281

2282
    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
2283
        return (
2284
            free_unbacked_symbols(self.layout.size)
2285
            | free_unbacked_symbols(self.layout.stride)
2286
            | free_unbacked_symbols(self.layout.offset)
2287
        )
2288

2289
    def codegen_reference(self, writer=None):
2290
        # reinterpret_tensor is similar to as_strided except:
2291
        # - offset is added to the existing offset (rather than replacing it)
2292
        # - view tracking is disabled similar to unsafe_view
2293
        return V.graph.wrapper_code.codegen_reinterpret_view(
2294
            self.data,
2295
            self.layout.size,
2296
            self.layout.stride,
2297
            self.layout.offset,
2298
            writer,
2299
        )
2300

2301

2302
class SliceView(View):
2303
    @classmethod
2304
    def normalize_start_end(cls, x, dim, start, end):
2305
        """
2306
        Normalize start and end such that both are in the range
2307
        [0, x.get_size()[dim]] and start <= end.
2308
        """
2309
        sizevars = V.graph.sizevars
2310
        dim_size = x.get_size()[dim]
2311

2312
        if any(free_unbacked_symbols(x) for x in (start, end, dim_size)):
2313

2314
            def clamp(x, lower, upper):
2315
                return sympy.Min(sympy.Max(x, lower), upper)
2316

2317
        else:
2318

2319
            def clamp(x, lower, upper):
2320
                return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper)
2321

2322
        def clamp_wrap(val, lower, upper, default):
2323
            if val is None:
2324
                return default
2325
            val = cls.handle_negative_index(val, dim_size)
2326
            return clamp(val, lower, upper)
2327

2328
        start = clamp_wrap(start, 0, dim_size, 0)
2329
        end = clamp_wrap(end, start, dim_size, dim_size)
2330
        return start, end
2331

2332
    @classmethod
2333
    def create(cls, x, dim, start, end, step=1):
2334
        step = sympy.expand(step)
2335
        assert step > 0
2336
        try:
2337
            if start == 0 and end >= 2**63 - 1 and step == 1:
2338
                return x
2339
        except TypeError:
2340
            pass
2341

2342
        sizevars = V.graph.sizevars
2343
        new_size = list(x.get_size())
2344

2345
        start, end = cls.normalize_start_end(x, dim, start, end)
2346

2347
        new_size[dim] = FloorDiv(end - start + (step - 1), step)
2348

2349
        if is_storage_and_layout(x):
2350
            # Fast path
2351
            storage, old_layout = as_storage_and_layout(x)
2352
            new_stride = list(old_layout.stride)
2353
            new_stride[dim] = new_stride[dim] * step
2354
            new_layout = FixedLayout(
2355
                old_layout.device,
2356
                old_layout.dtype,
2357
                new_size,
2358
                new_stride,
2359
                old_layout.offset + old_layout.stride[dim] * start,
2360
            )
2361
            return ReinterpretView(storage, new_layout)
2362

2363
        def reindex(index):
2364
            assert len(index) == len(new_size), f"wrong ndim {index} {new_size}"
2365
            index = list(index)
2366
            index[dim] = index[dim] * step + start
2367
            return index
2368

2369
        # redirect to a generic view
2370
        return SliceView(x, size=new_size, reindex=reindex)
2371

2372

2373
class BaseConstant(IRNode):
2374
    dtype: torch.dtype
2375
    device: torch.device
2376

2377
    def get_size(self):
2378
        return ()
2379

2380
    def get_device(self):
2381
        return self.device
2382

2383
    def get_origin_node(self):
2384
        return None
2385

2386
    def mark_reuse(self, users):
2387
        pass
2388

2389
    def has_exceeded_max_reads(self):
2390
        return False
2391

2392
    def get_reads(self):
2393
        return ()
2394

2395
    def is_extern(self):
2396
        return False
2397

2398

2399
@dataclasses.dataclass
2400
class Constant(BaseConstant):
2401
    value: Any
2402
    dtype: torch.dtype
2403
    device: torch.device
2404

2405
    def make_loader(self):
2406
        def loader(index):
2407
            return ops.constant(self.value, self.dtype)
2408

2409
        return loader
2410

2411
    def realize(self):
2412
        pass
2413

2414
    def constant_to_device(self, device):
2415
        return Constant(self.value, self.dtype, device)
2416

2417

2418
@dataclasses.dataclass
2419
class IndexingConstant(BaseConstant):
2420
    index: Any
2421
    dtype: torch.dtype
2422
    device: torch.device
2423

2424
    def make_loader(self):
2425
        def loader(index):
2426
            return ops.index_expr(self.index, self.dtype)
2427

2428
        return loader
2429

2430
    def constant_to_device(self, device):
2431
        return IndexingConstant(self.index, self.dtype, device)
2432

2433

2434
def is_contiguous_strides_for_shape(stride, shape):
2435
    return all(
2436
        size == 1 or left == right
2437
        for left, right, size in zip(
2438
            stride, FlexibleLayout.contiguous_strides(shape), shape
2439
        )
2440
    )
2441

2442

2443
@dataclasses.dataclass
2444
class Layout(IRNode):
2445
    def __init__(
2446
        self,
2447
        device: torch.device,
2448
        dtype: torch.dtype,
2449
        size: List[Expr],
2450
        stride: Optional[Sequence[Union[Expr, int]]],
2451
        offset: Expr = Integer(0),
2452
    ):
2453
        assert stride is None or len(size) == len(
2454
            stride
2455
        ), f"size={size}, stride={stride}"
2456
        self.device = device
2457
        self.dtype = dtype
2458
        assert all(isinstance(s, (Expr, int)) for s in size)
2459
        self.size = size
2460
        self._stride = stride
2461
        self.offset = offset
2462

2463
    @property
2464
    def stride(self):
2465
        return self._stride
2466

2467
    def __str__(self):
2468
        offset = ""
2469
        if self.offset != 0:
2470
            offset = f", offset={self.offset}"
2471
        return (
2472
            f"{type(self).__name__}('{self.device.type}', {self.dtype}, "
2473
            f"size={self.size}, stride={self.stride}{offset})"
2474
        )
2475

2476
    __repr__ = __str__
2477

2478
    def is_contiguous(self):
2479
        return is_contiguous_strides_for_shape(self.stride, self.size)
2480

2481
    def is_channels_last_contiguous(self):
2482
        ndim = len(self.size)
2483
        if ndim not in [4, 5]:
2484
            return False
2485
        for left, right, size in zip(
2486
            self.stride, make_channels_last_strides_for(self.size), self.size  # type: ignore[arg-type]
2487
        ):
2488
            if size != 1 and left != right:
2489
                return False
2490
        return True
2491

2492
    def is_transposed(self):
2493
        for left, right, size in zip(
2494
            self.stride,
2495
            reversed(FlexibleLayout.contiguous_strides(self.size)),
2496
            self.size,
2497
        ):
2498
            if size != 1 and left != right:
2499
                return False
2500
        return True
2501

2502
    def is_stride_ordered(self, order):
2503
        assert len(self.stride) == len(order)
2504

2505
        # ignore dimensions of size 1, they dont affect layout
2506
        non_1_indices = [
2507
            i
2508
            for i, dim in enumerate(self.size)
2509
            if V.graph.sizevars.size_hint(dim, fallback=2) != 1
2510
        ]
2511

2512
        stride = [self.stride[i] for i in non_1_indices]
2513
        order = [order[i] for i in non_1_indices]
2514

2515
        def sorted_indices(arr):
2516
            sorted_arr = sorted(arr)
2517
            return [sorted_arr.index(element) for element in arr]
2518

2519
        # since we may have removed dimensions, need to re-sort & re-index order
2520
        order = sorted_indices(order)
2521

2522
        # reorder the stride given order
2523
        stride_ordered = [-1] * len(order)
2524
        for i in range(len(order)):
2525
            stride_ordered[order[i]] = V.graph.sizevars.size_hint(stride[i])
2526
        # check if it is in ascending order
2527
        for i in range(len(order) - 1):
2528
            if stride_ordered[i] > stride_ordered[i + 1]:
2529
                return False
2530
        return True
2531

2532
    def is_channels_last_stride_ordered(self):
2533
        # create channels_last order(NCHW, NCDHW, the C is the first order).
2534
        order = [0] + list(reversed(range(1, len(self.stride) - 1)))
2535
        order = [len(order)] + order
2536
        return self.is_stride_ordered(order)
2537

2538
    def as_fixed(self):
2539
        return FixedLayout(
2540
            self.device,
2541
            self.dtype,
2542
            self.size,
2543
            self.stride,
2544
            self.offset,
2545
        )
2546

2547
    def make_indexer(self):
2548
        assert (
2549
            FlexibleLayout.allow_indexing
2550
        ), f"convert {type(self).__name__} to FixedLayout first"
2551
        return self.as_fixed().make_indexer()
2552

2553
    def __eq__(self, other) -> bool:
2554
        return (
2555
            self.device == other.device
2556
            and self.dtype == other.dtype
2557
            and self.size == other.size
2558
            and self.stride == other.stride
2559
            and self.offset == other.offset
2560
        )
2561

2562
    def storage_size(self) -> sympy.Expr:
2563
        return compute_required_storage_length(self.size, self.stride, self.offset)  # type: ignore[arg-type, return-value]
2564

2565

2566
class FixedLayout(Layout):
2567
    """A Tensor layout we cannot change"""
2568

2569
    def __init__(
2570
        self,
2571
        device: torch.device,
2572
        dtype: torch.dtype,
2573
        size: Union[List[Expr], List[int]],
2574
        stride: Optional[Sequence[Union[Expr, int]]] = None,
2575
        offset: Union[Expr, int] = Integer(0),
2576
    ):
2577
        if stride is None:
2578
            stride = FlexibleLayout.contiguous_strides(size)
2579
        super().__init__(
2580
            device,
2581
            dtype,
2582
            size,  # type: ignore[arg-type]
2583
            stride,
2584
            offset,  # type: ignore[arg-type]
2585
        )
2586

2587
    def make_indexer(self):
2588
        """A closure containing math to read a given element"""
2589

2590
        def indexer(index):
2591
            assert len(index) == len(self.stride) == len(self.size)
2592
            result = self.offset
2593
            for idx, stride, sz in zip(index, self.stride, self.size):
2594
                if sz != 1:
2595
                    result = result + idx * stride
2596
            return result
2597

2598
        return indexer
2599

2600

2601
class FlexibleLayout(Layout):
    """A Tensor layout we are allowed to change"""

    # False by default; temporarily patched to True (e.g. in
    # Buffer.get_read_writes) so a not-yet-frozen layout may be indexed.
    allow_indexing = False

    @staticmethod
    def contiguous_strides(sizes):
        """Row-major strides for ``sizes``: the innermost dim gets stride 1."""
        if len(sizes) == 0:
            return []
        running = sympy.Integer(1)
        strides_inner_first = [running]
        for dim_size in sizes[:0:-1]:
            running = dim_size * running
            strides_inner_first.append(running)
        return strides_inner_first[::-1]

    @staticmethod
    def fill_ordered(sizes, order):
        """
        Create a stride based on the order the dimensions should be filled in.

        ``order`` lists dimensions from fastest-varying (stride 1) to slowest.
        In this format, channels last would be:
            [1, 3, 2, 0]
        """
        assert set(range(len(sizes))) == set(order)
        strides = [None] * len(order)
        acc = sympy.Integer(1)
        for dim in order:
            strides[dim] = acc
            acc = acc * sizes[dim]
        return strides

    @staticmethod
    def stride_ordered(sizes, order):
        """
        Create a stride based on the sorted order of a permuted range.

        ``order`` is converted to a fill order and passed to ``fill_ordered``.
        In this format, channels last would be:
            [3, 0, 2, 1]
        """
        assert set(range(len(sizes))) == set(order)
        return FlexibleLayout.fill_ordered(sizes, stride_order2fill_order(order))

    @staticmethod
    def same_ordered(sizes, stride):
        """
        Create a stride that has the same stride order as given stride

        Dimensions are ranked by the size hint of their stride.
        For example, if given stride is [1000, 1, 100, 10],
        the fill order should be [1, 3, 2, 0]
        """
        assert len(sizes) == len(stride)
        hints = [V.graph.sizevars.size_hint(x) for x in stride]
        fill_order = sorted(range(len(hints)), key=hints.__getitem__)
        return FlexibleLayout.fill_ordered(sizes, fill_order)

    def _fixed_with_stride(self, new_stride):
        # Shared body of the as_*_order methods: a FixedLayout identical to
        # this one except for its stride.
        return FixedLayout(
            self.device,
            self.dtype,
            self.size,
            new_stride,
            self.offset,
        )

    def as_stride_order(self, order):
        return self._fixed_with_stride(self.stride_ordered(self.size, order))

    def as_fill_order(self, order):
        return self._fixed_with_stride(self.fill_ordered(self.size, order))

    def as_same_order(self, stride):
        return self._fixed_with_stride(self.same_ordered(self.size, stride))

    def __init__(self, device, dtype, size, stride_order=None):
        # NOTE(review): despite its name, a truthy ``stride_order`` is handed to
        # ``fill_ordered`` (not ``stride_ordered``), i.e. it is interpreted as a
        # fill order. Preserved as-is; confirm intent against callers.
        strides = (
            FlexibleLayout.fill_ordered(size, stride_order)
            if stride_order
            else FlexibleLayout.contiguous_strides(size)
        )
        super().__init__(device, dtype, size, strides)
2690

2691

2692
class AliasedLayout(Layout):
    """Shares the same storage as another tensor"""

    def __init__(self, view: Union[BaseView, "TensorBox"]):
        # Mirror device/dtype/size/stride of the aliased view's layout.
        src_layout = view.get_layout()
        super().__init__(
            src_layout.device,
            src_layout.dtype,
            src_layout.size,
            src_layout.stride,
        )
        self.view = view

    def make_indexer(self):
        """Index through a FixedLayout copy of this layout."""
        fixed_layout = self.as_fixed()
        return fixed_layout.make_indexer()

    def maybe_guard_aligned(self):
        """Return True if the aliased view's offset is known to be aligned.

        A zero offset is trivially aligned; otherwise ask the size-var
        machinery whether the offset is statically a multiple of ALIGNMENT.
        """
        view_offset = self.view.get_layout().offset
        if view_offset != 0:
            from .compile_fx import ALIGNMENT

            return V.graph.sizevars.statically_known_multiple_of(view_offset, ALIGNMENT)  # type: ignore[arg-type]
        return True
2715

2716

2717
class NoneLayout(IRNode):
    """Layout stand-in for IR nodes that have no real tensor layout."""

    # This is janky, I figured out what fields to populate by just running
    # the model I was interested in and adding properties/methods as needed.
    # This doesn't inherit from Layout because Layout assumes you have stuff
    # like sizes, but I don't really have anything here.
    #
    # If you have an ir.Node with NoneLayout, you probably need to setup
    # dependencies manually in scheduler

    def __init__(self, device):
        # Placeholder fields, populated only so callers that read them work
        # (see comment above).
        self.device, self.size, self.stride = device, [0], [0]

    def storage_size(self):
        # No storage backs a NoneLayout.
        return 0

    def as_fixed(self):
        # Nothing to freeze: the layout is returned unchanged.
        return self
2736

2737

2738
class MutationLayout(Layout):
    """Layout of a node whose output is written in place into ``target``.

    Device, dtype and size are taken from ``target``; stride and storage size
    are read from the layout of the Buffer that ``target`` ultimately wraps
    (see ``get_buffer``). Constructing one marks that buffer as mutated in the
    current graph.
    """

    def __init__(self, target: IRNode):
        super().__init__(
            target.get_device(),
            target.get_dtype(),
            target.get_size(),
            None,  # stride is not stored here; see the `stride` getter below
        )
        self.target = target
        name = self.get_buffer().get_name()
        V.graph.mark_buffer_mutated(name)

    @Layout.stride.getter  # type: ignore[attr-defined]
    def stride(self):
        # Always defer to the layout of the underlying buffer.
        return self.real_layout().stride

    def storage_size(self) -> sympy.Expr:
        return self.real_layout().storage_size()

    def get_buffer(self) -> "Buffer":
        """Return the Buffer behind ``target``.

        Recursively unwraps nested MutationLayouts, views and MutableBoxes;
        asserts that the result is a Buffer.
        """

        def unwrap_views(target):
            if isinstance(target, MutationLayout):
                return unwrap_views(target.target)
            if isinstance(target, BaseView):
                return unwrap_views(target.unwrap_view())
            if isinstance(target, MutableBox):
                return unwrap_views(target.data)
            return target

        result = unwrap_views(self.target)
        assert isinstance(result, Buffer), "MutationLayout must refer to a buffer"
        return result

    def real_layout(self):
        """Layout of the underlying Buffer being mutated."""
        return self.get_buffer().layout

    @classmethod
    def realize_into(cls, src, dst, unsafe_alias=False):
        """Realize ``src`` so that its result is written into ``dst``'s storage.

        Unless ``unsafe_alias`` is True, ``src`` is first wrapped in a fresh
        Pointwise copy (see the NOTE below on why aliasing is unsafe). Returns
        ``src.data`` after realization; its layout has been replaced by a
        MutationLayout targeting ``dst``.
        """
        dst.realize()
        # NOTE: We must realize users of `dst` before we realize `src`, since
        # realization order determines scheduling order. Otherwise, src's
        # mutation would be scheduled before the existing users of dst!
        V.graph.mark_buffer_mutated(dst.get_name())

        if isinstance(src, TensorBox):
            src = src.data

        # We copy the contents of src into dst. In most cases this should
        # be fused into a single kernel by the scheduler.
        # NOTE: We cannot change src's layout to mutate dst directly as this
        # would alias src to dst, which is not correct as further mutations to
        # dst would affect users of src. However if there are no more users of
        # dst, we can alias src to dst.
        src.realize_hint()

        if not unsafe_alias:
            # Elementwise copy of src; the ranges come from guard_equals on
            # each (src, dst) size pair.
            src = Pointwise.create(
                device=src.get_device(),
                dtype=src.get_dtype(),
                inner_fn=src.make_loader(),
                ranges=[
                    V.graph.sizevars.guard_equals(a, b)
                    for a, b in zip(src.get_size(), dst.get_size())
                ],
            ).data

        src.realize()
        # The freshly realized node must still have a FlexibleLayout so that
        # replacing it with a MutationLayout is allowed.
        assert isinstance(src.data.layout, FlexibleLayout)
        src.data.layout = MutationLayout(dst)
        return src.data

    def as_fixed(self):
        return self

    def make_indexer(self):
        # Index exactly as the mutation target does.
        return self.target.make_indexer()
2814

2815

2816
@dataclasses.dataclass
class Buffer(IRNode):
    """An IR node backed by a named region of memory described by ``layout``.

    Many accessors below are thin forwards to ``self.layout`` (device, size,
    stride, offset). The layout may be aliased (AliasedLayout) or point at
    another buffer being mutated (MutationLayout).
    """

    # Name is sometimes None; e.g., ForceInPlace, where there isn't
    # a meaningful name
    name: Optional[str]
    layout: Layout

    # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly,
    # MultiOutput does NOT define this!

    def __post_init__(self):
        super().__post_init__()
        self.origin_node = None

    def make_indexer(self):
        return self.layout.make_indexer()

    def get_name(self) -> str:
        assert self.name
        return self.name

    def get_device(self):
        return self.layout.device

    def get_origin_node(self):
        return self.origin_node

    @property
    def dtype(self):
        # getattr with a default: some layouts (e.g. NoneLayout) have no dtype.
        return getattr(self.layout, "dtype", None)

    def get_size(self):
        return list(self.layout.size)

    def get_stride(self):
        return list(self.layout.stride)

    def get_offset(self):
        return self.layout.offset

    def get_layout(self):
        return self.layout

    def get_storage_numel(self):
        return self.get_numel()

    def is_extern(self):
        return False

    def freeze_layout(self):
        """Replace the layout with its fixed form.

        Multi-output and aliased layouts are left untouched.
        """
        if not isinstance(self.layout, (MultiOutputLayout, AliasedLayout)):
            self.layout = self.layout.as_fixed()

    def freeze_layout_with_stride_order(self, order):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_stride_order(order)

    def freeze_layout_with_fill_order(self, order):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_fill_order(order)

    def freeze_layout_with_same_order(self, stride):
        assert isinstance(self.layout, FlexibleLayout)
        self.layout = self.layout.as_same_order(stride)

    def is_zero_elements(self):
        """True only if the element count is statically known to be 0."""
        return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))  # type: ignore[arg-type]

    def make_loader(self):
        # Loading from a zero-element buffer is a no-op
        if self.is_zero_elements():
            return partial(nop_loader_fn, dtype=self.get_dtype())

        def loader(index):
            # The indexer is built on every call rather than hoisted, so it
            # reflects self.layout at load time (freeze_layout* may replace the
            # layout after this loader is created).
            indexer = self.layout.make_indexer()
            return ops.load(self.name, indexer(index))

        return loader

    def is_no_op(self):
        return False

    def codegen_reference(self, writer=None):
        return self.get_name()

    def decide_layout(self):
        pass

    def get_alias_names(self):
        """Names of buffers this buffer aliases (via AliasedLayout), if any."""
        if isinstance(self.layout, AliasedLayout):
            return [self.layout.view.get_name()]
        return ()

    def get_mutation_names(self):
        """Names of buffers this buffer mutates (via MutationLayout), if any."""
        if isinstance(self.layout, MutationLayout):
            return [self.layout.target.get_name()]
        return ()

    def get_read_writes(self):
        """Trace ``make_loader`` over this buffer's size to extract deps.

        Indexing of still-flexible layouts is temporarily allowed while tracing.
        """
        with patch.object(FlexibleLayout, "allow_indexing", True):
            return extract_read_writes(
                self.make_loader(),
                self.get_size(),
            )

    def get_reads(self):
        return self.get_read_writes().reads

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        """
        Returns the unbacked symbols which are defined by this IR node,
        because this is a data-dependent IR node, or item()
        """
        # So this is a little unusual.  In principle, you could imagine
        # defining a MultiOutputLayout buffer so that it DOES define
        # unbacked symints.  However, we can't easily tell what symints
        # such a buffer defines, because MultiOutputLayout doesn't actually
        # define any useful information about what it returns.
        #
        # An easier and better approach is to delay the symint allocation
        # to the MultiOutput IR nodes, which are when we actually extract
        # out the buffers and know what their sizes are.
        #
        # There are two subtleties here:
        #
        # 1. Suppose you have a kernel that produces out1: (i0,), out2: (i0,)
        #    Both of these actually count as defs!  The scheduler will just
        #    arbitrarily pick one of these as the canonical definer and
        #    ensure it stays live.  It's not a big deal if we pick the
        #    wrong one because tuple accesses are cheap, and all this means
        #    is we accidentally keep a MultiOutput node live when it wasn't
        #    strictly necessary.
        #
        # 2. Suppose you have a MultiOutput buffer whose size is (i0,), but
        #    the MultiOutputLayout buffer it is projecting from isn't actually
        #    dynamic; it has i0 as one of the arguments.  We cannot tell this
        #    directly from MultiOutput, we have to look at the input buffer's
        #    uses to work this out.  No big deal.
        if isinstance(self.layout, (NoneLayout, MultiOutputLayout)):
            return set()

        # This kernel defines all unbacked symbols... that it didn't get in as
        # arguments!
        defs = (
            free_unbacked_symbols(self.get_size())
            | free_unbacked_symbols(self.get_stride())
            | free_unbacked_symbols(self.get_offset())
        )
        return defs - self.get_unbacked_symbol_uses()

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        """
        Returns the unbacked symbols which are required to be in scope in
        order to successfully perform codegen for this buffer.  For example,
        a buffer that corresponds to an extern kernel call that takes i0 as
        an argument would return {i0} here.  This is used to generate necessary
        dependencies that ensure we actually bind i0 in codegen before you
        try to use it.

        Note that this is NOT transitive; in particular, if this buffer takes
        in as input another buffer with dynamic shape (e.g., (i0,)), we will
        not report it here, because you will already have a dependency
        on that buffer, which will eventually have a dependency on i0 if
        necessary.
        """
        return set()

    def codegen_unbacked_symbol_defs(self, wrapper):
        """Emit wrapper code binding each unbacked symbol this buffer defines.

        Each symbol is bound from the first size, stride or storage-offset
        entry that is exactly that symbol.
        """
        # NB: If it is possible for other ir node types to return unbacked
        # symints, you need to make sure their codegen calls this method.
        # Don't forget to update get_unbacked_symbol_defs too.
        symbols_to_define = self.get_unbacked_symbol_defs()
        for i, s in enumerate(self.get_size()):
            if s in symbols_to_define:
                wrapper.writeline(
                    f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.size({i}){wrapper.ending}"
                )
                symbols_to_define.remove(s)
        for i, s in enumerate(self.get_stride()):
            if s in symbols_to_define:
                wrapper.writeline(
                    f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.stride({i}){wrapper.ending}"
                )
                symbols_to_define.remove(s)
        if (offset := self.get_offset()) in symbols_to_define:
            wrapper.writeline(
                f"{wrapper.codegen_unbacked_symbol_decl(offset)} = {self.get_name()}.storage_offset(){wrapper.ending}"
            )
            symbols_to_define.remove(offset)
        # Anything left only appears inside a compound size/stride/offset
        # expression, which the simple matching above cannot bind. Report the
        # leftover symbols themselves (previously the message showed the offset).
        assert (
            not symbols_to_define
        ), f"unbacked symint(s) {symbols_to_define} not written out, check comment above"

    def realize(self):
        pass

    def get_workspace_size(self):
        """
        Gets extra global memory size needed by this buffer.
        Some algorithms (e.g. group gemm) may require extra global memory in the generated code.
        """
        return 0

    def should_allocate(self):
        # Returns False by default.
        return False
3022

3023

3024
class InputBuffer(Buffer):
    """Marker subclass of Buffer that adds no behavior of its own.

    Presumably used for graph inputs; confirm against where it is constructed.
    """
3026

3027

3028
class ConstantBuffer(InputBuffer):
    """An InputBuffer whose loads go through ``V.graph.constant_name``."""

    # Device passed to V.graph.constant_name when loading; patched via
    # patch.object (e.g. in ComputedBuffer.simplify_and_reorder).
    override_device: Optional[torch.device] = None

    def make_loader(self):
        def loader(index):
            # Build the indexer at load time, from the current layout.
            indexer = self.layout.make_indexer()
            source_name = V.graph.constant_name(self.get_name(), self.override_device)
            return ops.load(source_name, indexer(index))

        return loader

    def constant_to_device(self, device):
        """Return a ConstantBuffer referring to this constant's name for ``device``."""
        renamed = V.graph.constant_name(self.get_name(), device)
        return ConstantBuffer(renamed, self.layout)
3045

3046

3047
class NoneAsConstantBuffer(IRNode):
    """IR node that codegens as the wrapper's spelling of ``None``."""

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        # A None constant never references symbols.
        return set()

    def codegen_reference(self, writer=None):
        wrapper_code = V.graph.wrapper_code
        return wrapper_code.none_str
3053

3054

3055
class ShapeAsConstantBuffer(IRNode):
    """IR node wrapping a (possibly symbolic) shape expression.

    It codegens as the printed, simplified expression.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        # Any unbacked symbols inside the expression must be bound first.
        return free_unbacked_symbols(self.shape)

    def codegen_reference(self, writer=None):
        printer = V.graph.wrapper_code.expr_printer
        return printer(V.graph.sizevars.simplify(self.shape))
3065

3066

3067
@dataclasses.dataclass
class ComputedBuffer(Buffer):
    """A Buffer whose contents are computed by the loop body in ``data``.

    ``data`` is expected to be a Pointwise, Reduction or Scan node (see
    ``get_store_function``).
    """

    data: Loops

    def get_computed_buffer_name(self):
        """
        Returns self.name if it exists, otherwise returns the name of the data node if that exists.
        If neither exist, returns None.
        """
        if self.name is not None:
            return self.name
        if hasattr(self.data, "name"):
            return self.data.name
        return None

    @cache_on_self
    def num_reads(self):
        """Number of read dependencies of this buffer (cached after first call)."""
        return len(self.get_read_writes().reads)

    def get_read_writes(self):
        """Trace the store function to extract read/write dependencies.

        Reductions are traced over (pointwise size, reduction size); everything
        else over ``data.get_size()``. Indexing of still-flexible layouts is
        temporarily allowed while tracing.
        """
        with patch.object(FlexibleLayout, "allow_indexing", True):
            if self.data.get_reduction_type():
                return extract_read_writes(
                    self.get_store_function(),
                    self.data.get_pointwise_size(),
                    self.data.get_reduction_size(),
                )
            else:
                return extract_read_writes(
                    self.get_store_function(),
                    self.data.get_size(),
                )

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        # Ordinarily, we'd like to just peek at the arguments list,
        # but ComputedBuffers have no argument list.
        #
        # Morally, this logic needs to be synchronized with the
        # KernelArgs.size calls, which are responsible for making symbols make
        # their way as kernel arguments (and it is precisely passing in one of
        # those symbols that establishes a dependency).  However, we haven't
        # started codegen yet so we can't directly reuse that logic.
        #
        # For now, I'm just yoloing with the size of the buffer.  Not sure if
        # it is enough.
        #
        # One thing you might wonder is if this is enough for a ComputedBuffer
        # denoting a reduction over i0.  Empirically, it is enough, but for an
        # unusual reason: we only need accurate dependencies for item() call,
        # but it's impossible to end up with a reduction over i0 from an
        # item() call without a regular non-reduction buffer first.
        return (
            free_unbacked_symbols(self.get_size())
            | free_unbacked_symbols(self.get_stride())
            | free_unbacked_symbols(self.get_offset())
            | self.data.get_unbacked_symbol_uses()
        )

    def make_loader(self):
        # Inline constants and index_expressions
        if (
            hasattr(self.data, "make_loader")
            and self.name not in V.graph.mutated_buffers
            and self.num_reads() == 0
        ):
            # can be inlined
            return self.data.make_loader()
        return super().make_loader()

    def get_store_function(self):
        """Return a callable that stores ``data``'s output into this buffer.

        Reductions and scans use ``store_reduction``; pointwise nodes use
        ``store_output``. Both are bound to this buffer's name and to an
        indexer over the fixed form of its layout.
        """
        indexer = self.layout.as_fixed().make_indexer()
        if isinstance(self.data, (Reduction, Scan)):
            return partial(self.data.store_reduction, self.name, indexer)
        else:
            assert isinstance(self.data, Pointwise)
            return partial(self.data.store_output, self.name, indexer)

    def get_fill_order(self):
        """
        If our layout is still flexible, try to determine the stride order based on stride orders of reads.

        Returns a fill order from ``pick_loop_order``, or None when the layout
        is already fixed or there are no MemoryDep reads to learn from.

        TODO(jansel): A better algorithm here would look at downstream consumers of this
                      value and try to do global graph-level layout optimization.
                      This is also something just begging to be autotuned.
        """
        if isinstance(self.layout, FlexibleLayout):
            (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(
                self.data.get_pointwise_size(), self.data.get_reduction_size()
            )
            reads = self.get_read_writes().reads
            # Ignore StarDeps because they don't contribute stride information.
            # NOTE(review): no filter on the read buffer's size is applied;
            # every MemoryDep read contributes.
            assert all(
                isinstance(r, (dependencies.StarDep, dependencies.MemoryDep))
                for r in reads
            )
            # Zero out the reduction variables so only pointwise strides remain.
            reads = [
                sympy_subs(
                    r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0}
                )
                for r in reads
                if isinstance(r, dependencies.MemoryDep)
            ]

            if reads:
                if isinstance(self.data, Scan):
                    indices = self.data.reindex(index_vars, reduction_vars)
                else:
                    indices = index_vars
                stride_lengths = [
                    V.graph.sizevars.stride_hints(expr, indices) for expr in reads  # type: ignore[arg-type]
                ]
                from .scheduler import pick_loop_order

                return pick_loop_order(stride_lengths, self.get_size())

        return None

    def decide_layout(self):
        """Freeze a flexible layout, using the read-derived fill order if any."""
        if isinstance(self.layout, FlexibleLayout):
            order = self.get_fill_order()
            if order:
                self.freeze_layout_with_fill_order(order)
            else:
                self.freeze_layout()

    def simplify_and_reorder(self):
        """
        This is a main place where we do loop transformations in a
        backend-agnostic way.

        Here we:
            1) Remove any 1 dimensions
            2) Fuse contiguous dimensions together
            3) Reorder dimensions based on stride orders

        Returns ``((iter_ranges, reduce_ranges), body)`` where ``body`` is the
        LoopBody retraced with the simplification and reordering applied.
        """
        args, var_ranges = dependencies.index_vars_squeeze(
            self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q"
        )
        with patch.object(ConstantBuffer, "override_device", self.get_device()):
            body = LoopBody(
                self.get_store_function(),
                (args if self.get_reduction_type() else args[:1]),
                var_ranges,
            )
        index_formulas = [*body.indexing_exprs.values()]
        reads_bufs = [
            V.graph.name_to_buffer[reads_name]
            if reads_name in V.graph.name_to_buffer
            else None
            for reads_name in body.reads_name2expr.keys()
        ]
        memory_addrs = [
            *body.reads_name2expr.values(),
            *body.writes_name2expr.values(),
        ]
        index_vars = []
        reduce_vars: List[Any] = []
        index_size = []
        reduce_size = []
        # Split the squeezed variables into pointwise and reduction groups;
        # all pointwise variables must come before any reduction variable.
        for v, s in var_ranges.items():
            if v in args[0]:
                assert not reduce_vars
                index_vars.append(v)
                index_size.append(s)
            else:
                assert v in args[1]
                reduce_vars.append(v)
                reduce_size.append(s)

        # the reordering_reindex in reads' simplify_reorder_and_tile
        # Start from the identity reordering for every memory access; reads of
        # ComputedBuffers that recorded their own reordering reuse it.
        reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs)
        for i, reads_buf in enumerate(reads_bufs):
            if isinstance(reads_buf, ComputedBuffer) and hasattr(
                reads_buf, "iter_reordering_reindex"
            ):
                reordering_reindex[i] = reads_buf.iter_reordering_reindex  # type: ignore[has-type]

        def _reorder_then_simplify(x_vars, support_vars, sizes, reordering_reindex=None):
            # Reorder loops by stride, then fuse/prune dimensions. Returns the
            # new sizes, the combined reindex, and the reordering-only reindex.
            sizes, reindex0, reindex1 = self._apply_loop_reordering(
                x_vars, support_vars, sizes, memory_addrs, reordering_reindex
            )
            # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1]
            x_vars = reindex0(x_vars)
            sizes, reindex2, prune = V.graph.sizevars._simplify_loops(
                x_vars,
                sizes,
                index_prevent_reordering(index_formulas, x_vars, sizes),
            )
            x_vars = prune(x_vars)
            reindex = fuse_reindexing(reindex1, reindex2)
            return sizes, reindex, reindex1

        support_vars = index_vars + reduce_vars
        iter_ranges, iter_reindex, iter_reordering_reindex = _reorder_then_simplify(
            index_vars, support_vars, index_size, reordering_reindex
        )
        reduce_ranges, reduce_reindex, _ = _reorder_then_simplify(
            reduce_vars, support_vars, reduce_size
        )

        # remember the reordering if not have loop collapse.
        if len(iter_ranges) == len(index_vars):
            self.iter_reordering_reindex = iter_reordering_reindex
        # retrace the loop body with simplification and reordering applied
        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
            iter_ranges, reduce_ranges, prefix="z"
        )
        body = LoopBody(
            body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges
        )
        return (iter_ranges, reduce_ranges), body

    @staticmethod
    def _apply_loop_reordering(
        index_vars,
        support_vars,
        sizes,
        memory_addrs,
        reordering_reindex=None,
        priority_idx=None,
    ):
        """
        Shuffle the order of loops around to hopefully improve performance.

        Returns ``(sizes, same_reorder(order), inverse_reorder(order))`` with
        ``sizes`` permuted by ``order``. If stride analysis fails for any
        reason, the identity order is used instead.
        """
        from .scheduler import pick_loop_order

        if priority_idx is None:
            priority_idx = []

        try:
            strides = [
                V.graph.sizevars.stride_hints(expr, index_vars, support_vars)
                for expr in memory_addrs
            ]
            assert len(strides) == len(memory_addrs) and len(strides[0]) == len(
                index_vars
            )
            # consider both layout(strides) and reordering(reordering_reindex)
            if reordering_reindex is not None:
                for i in range(len(memory_addrs)):
                    try:
                        strides[i] = reordering_reindex[i](strides[i])
                    # if len(order) != len(strides), do not reorder
                    except AssertionError:
                        pass
            order = list(reversed(pick_loop_order(strides, sizes, priority_idx)))
        except Exception:
            # Deliberate best effort: fall back to the original loop order.
            if config.debug:
                log.warning(
                    "Did not simplify complex index:\n%s\n%s",
                    dict(zip(index_vars, sizes)),
                    memory_addrs,
                )
            order = list(range(len(sizes)))
        sizes = [sizes[i] for i in order]
        return sizes, same_reorder(order), inverse_reorder(order)

    def get_reduction_size(self):
        return self.data.get_reduction_size()

    def get_reduction_type(self):
        return self.data.get_reduction_type()

    def is_no_op(self):
        return self.data.is_zero_elements()

    def should_allocate(self):
        return True

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        return self.data.constant_to_device(device)
3349

3350

3351
class TemplateBuffer(Buffer):
    """
    Buffer produced by a template kernel (Triton today, with other backends
    such as CUDA templates also subclassing it) onto which an epilogue can be
    fused.
    """

    def __init__(self, layout, inputs, make_kernel_render):
        super().__init__(name=None, layout=layout)
        self.inputs = InputsKernel.unwrap_storage(inputs)
        self.make_kernel_render = make_kernel_render
        # register_buffer returns the name assigned to this buffer.
        self.name = V.graph.register_buffer(self)

    def get_read_writes(self):
        return self.normalized_read_writes()

    def normalized_read_writes(self):
        """Dependencies: a traced fake store for writes, StarDeps on every input for reads."""
        buf_name = self.get_name()
        index_fn = self.layout.make_indexer()

        def _fake_store(index, rindex):
            # Templates have no reduction dims.
            assert len(rindex) == 0
            return ops.store(buf_name, index_fn(index), "fake")

        rw = dependencies.extract_read_writes(
            _fake_store, self.get_size(), (), normalize=True
        )
        rw.reads = {dependencies.StarDep(inp.get_name()) for inp in self.inputs}
        return rw

    def get_reduction_size(self):
        return 1

    def get_reduction_type(self):
        return None

    def is_no_op(self):
        return False

    def should_allocate(self):
        return True

    def simplify_and_reorder(self):
        # No loop transformation: the full size as pointwise ranges, no
        # reduction ranges, and no LoopBody.
        ranges = (self.get_size(), ())
        return ranges, None
3400

3401

3402
class TritonTemplateBuffer(TemplateBuffer):
    """TemplateBuffer for Triton templates; adds no behavior of its own."""
3404

3405

3406
class CUDATemplateBuffer(TemplateBuffer):
    """TemplateBuffer for CUDA templates that may need a global-memory workspace.

    Args:
        layout, inputs, make_kernel_render: forwarded to ``TemplateBuffer``.
        workspace_size: global memory (in bytes) needed by the template, or
            ``None``, which ``get_workspace_size`` reports as 0.
        template: the CUDA template object that produced this buffer.
    """

    def __init__(
        self,
        layout,
        inputs,
        make_kernel_render,
        workspace_size: Optional[int],
        template: "CUDATemplate",  # type: ignore[name-defined]  # noqa: F821
    ):
        super().__init__(layout, inputs, make_kernel_render)
        # Global memory (in bytes) needed for this template; may be None.
        self.workspace_size = workspace_size
        self.template = template

    def get_workspace_size(self) -> int:
        """Return the workspace size in bytes, treating ``None`` as 0."""
        return self.workspace_size if self.workspace_size is not None else 0
3422

3423

3424
@dataclasses.dataclass
class InputsKernel(Buffer):
    """Buffer produced by a kernel whose dependencies are whole input buffers.

    Entries of ``inputs`` are individual buffers or lists of buffers.
    """

    inputs: List[Buffer]

    def get_read_writes_input(self, x):
        return dependencies.StarDep(x.get_name())

    def get_read_writes(self):
        """Read a StarDep per input (flattening one list level); write a StarDep on self."""
        reads = []
        for item in self.inputs:
            members = item if isinstance(item, list) else [item]
            reads.extend(self.get_read_writes_input(member) for member in members)

        return dependencies.ReadWrites(
            set(reads),
            {dependencies.StarDep(self.get_name())},
            set(),
            [],
            None,
            op_counts=collections.Counter(),
        )

    @classmethod
    def unwrap_storage_for_input(cls, x):
        """Strip TensorBox/StorageBox wrappers and realize non-reinterpret views.

        Returns a ``Buffer`` or a ``ReinterpretView``.
        """
        if isinstance(x, TensorBox):
            x = x.data
        if isinstance(x, StorageBox):
            x = x.data
        if isinstance(x, BaseView) and not isinstance(x, ReinterpretView):
            x = ExternKernel.realize_input(x)
        if isinstance(x, TensorBox):
            # If realize_input could not produce a ReinterpretView it falls
            # back to copy_input, which yields a fresh TensorBox/StorageBox
            # pair; unwrap that recursively.
            return cls.unwrap_storage_for_input(x)
        assert isinstance(x, (Buffer, ReinterpretView)), x
        return x

    @staticmethod
    def unwrap_storage(inputs):
        """Unwrap every input, descending one level into list-valued entries."""
        unwrap = InputsKernel.unwrap_storage_for_input
        return [
            [unwrap(elem) for elem in x] if isinstance(x, list) else unwrap(x)
            for x in inputs
        ]

    def is_extern(self):
        return True
3478

3479

3480
class NopKernel(InputsKernel):
    """InputsKernel that always reports itself as a no-op."""

    def is_no_op(self):
        no_op = True
        return no_op
3483

3484

3485
class ConcatKernel(NopKernel):
    """
    There isn't actually a real kernel for concat, we just change the
    storage for the upstream data.

    Each input is realized into a slice of one shared output buffer; a
    pointwise copy is introduced only when the input cannot be aliased.
    """

    @classmethod
    def create(cls, inputs, dim):
        """Concatenate ``inputs`` along ``dim``.

        All inputs must share rank, dtype and device, and every non-concat
        dimension must match (``guard_equals`` is used for symbolic sizes).
        ``dim`` must satisfy ``0 <= dim < rank``.

        Returns a ``StorageBox`` wrapping the new ``ConcatKernel``.
        """
        first = inputs[0]
        device = first.get_device()
        dtype = first.get_dtype()
        new_size = list(first.get_size())
        # Validate before indexing, so an out-of-range dim trips this
        # assertion instead of an IndexError.
        assert 0 <= dim < len(new_size), (
            f"invalid concat dim {dim} for rank {len(new_size)}"
        )
        offsets_start = [0]
        offsets_end = [new_size[dim]]
        for other in inputs[1:]:
            input_size = other.get_size()
            offsets_start.append(new_size[dim])
            assert len(input_size) == len(new_size)
            assert other.get_dtype() == dtype
            assert other.get_device() == device
            for j in range(len(new_size)):
                if j == dim:
                    new_size[j] = new_size[j] + input_size[j]
                else:
                    new_size[j] = V.graph.sizevars.guard_equals(
                        new_size[j], input_size[j]
                    )
            offsets_end.append(new_size[dim])

        output_stride = FlexibleLayout.contiguous_strides(new_size)
        # If any of the inputs is in CL format, use CL format for the output
        for x in inputs:
            if is_storage_and_layout(x):
                layout = x.get_layout()
                if (
                    isinstance(layout, FixedLayout)
                    and layout.is_channels_last_contiguous()
                ):
                    # use CL stride for the output
                    output_stride = make_channels_last_strides_for(new_size)
                    break

        concat_kernel = ConcatKernel(
            name=None,
            layout=FixedLayout(
                device=device,
                dtype=dtype,
                size=new_size,
                stride=output_stride,
            ),
            inputs=[],
        )
        kernel = StorageBox(concat_kernel)
        buffer_names = []
        for inp, start, end in zip(inputs, offsets_start, offsets_end):
            # Write each input into its own slice of the shared output.
            input_buffer = cls.realize_into(
                inp,
                SliceView.create(kernel, dim, start, end),
            )
            concat_kernel.inputs.append(input_buffer)

            if isinstance(inp.data, BaseView):
                input_unwrapped = inp.data.unwrap_view()
            else:
                input_unwrapped = inp.data

            if (
                input_unwrapped.is_input_buffer()
                and inp.get_device().type == "cuda"
                and not is_dynamic(input_buffer)
            ):
                buffer_names.append(input_buffer.get_name())

        # Register the non-dynamic CUDA graph-input buffers as a group when
        # there are at least two of them.
        if len(buffer_names) > 1:
            V.graph.register_list(buffer_names)

        concat_kernel.name = V.graph.register_buffer(concat_kernel)
        concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs)

        return kernel

    @classmethod
    def can_realize_into_without_copy(cls, src):
        """Return True if ``src`` can be aliased into the output without a copy.

        Requires the underlying node to still have a ``FlexibleLayout`` and
        not be an ``ExternKernelAlloc``.
        """
        if isinstance(src, TensorBox):
            # unwrap a TensorBox
            return cls.can_realize_into_without_copy(src.data)

        return isinstance(src.data.layout, FlexibleLayout) and not isinstance(
            src.data, ExternKernelAlloc
        )

    @classmethod
    def realize_into(cls, src, dst):
        """Realize ``src`` so that its storage is the view ``dst``.

        Aliases ``src`` onto ``dst`` (via ``AliasedLayout``) when possible;
        otherwise introduces a pointwise copy and realizes that into ``dst``.
        """
        # Attempt to turn this into a ReinterpretView rather than assert.
        # This has concessions around layout, as as_storage_and_layout
        # can cause us to go from flexible to fixed layout.
        if not isinstance(dst, ReinterpretView):
            if is_storage_and_layout(dst):
                storage, layout = as_storage_and_layout(dst)
                dst = ReinterpretView(storage, layout)
        assert isinstance(dst, ReinterpretView), dst
        if isinstance(src, TensorBox):
            # unwrap a TensorBox
            return cls.realize_into(src.data, dst)
        if isinstance(src, StorageBox):
            src.realize()
            # ExternKernelAlloc has specific requirements for output layout, should create a copy
            assert hasattr(src.data, "layout")
            if cls.can_realize_into_without_copy(src):
                src.data.layout = AliasedLayout(dst)
                return src.data
        # introduce a copy
        pw = Pointwise.create(
            device=src.get_device(),
            dtype=src.get_dtype(),
            inner_fn=src.make_loader(),
            ranges=[
                V.graph.sizevars.guard_equals(a, b)
                for a, b in zip(src.get_size(), dst.get_size())
            ],
        )
        return cls.realize_into(pw, dst)

    def should_allocate(self):
        return True
3611

3612

3613
@dataclasses.dataclass
class ExternKernel(InputsKernel):
    """Base class for IR nodes implemented by calling an external kernel.

    Holds the tensor ``inputs``, the non-tensor ``constant_args`` and
    ``kwargs``, and the kernel names used by the Python and C++ wrappers.
    It can also hold an ``op_overload``; when that is an ``OpOverload``,
    its schema supplies per-argument metadata for C++ wrapper codegen.
    """

    constant_args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
    output_view: Optional[ReinterpretView] = None
    python_kernel_name: Optional[str] = None
    cpp_kernel_name: Optional[str] = None
    # FIXME: in some cases we still need to explicitly pass in ordered_kwargs_for_cpp_kernel
    # We shouldn't need to do this since the information can be retrieved from op_overload._schema.
    ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field(
        default_factory=list
    )
    op_overload: Optional[
        Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
    ] = None
    arg_properties: Optional[List[Dict[str, Any]]] = None
    kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None

    def __init__(
        self,
        name,
        layout,
        inputs,
        constant_args=(),
        kwargs=None,
        output_view=None,
        python_kernel_name=None,
        cpp_kernel_name=None,
        ordered_kwargs_for_cpp_kernel=(),
        op_overload=None,
    ):
        super().__init__(
            name,
            layout,
            inputs,
        )
        self.constant_args = constant_args
        self.kwargs = kwargs if kwargs else {}
        self.output_view = output_view
        self.python_kernel_name = python_kernel_name
        self.cpp_kernel_name = cpp_kernel_name
        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
        self.op_overload = op_overload
        # Derive per-argument metadata from the op schema, if there is one.
        self.collect_arg_kwarg_properties()

    def collect_arg_kwarg_properties(self):
        """Populate ``arg_properties`` / ``kwarg_properties`` from the op schema.

        If ``self.op_overload`` is a ``torch._ops.OpOverload``, its schema
        provides the following, used by the C++ wrapper codegen:

        * name, type and default value for every positional
          (non-kwarg-only) argument;
        * type and default value for every kwarg-only argument.

        If ``ordered_kwargs_for_cpp_kernel`` is empty, it is filled with the
        schema's kwarg-only argument names, in order.

        Without an ``OpOverload``, ``arg_properties`` is one empty dict per
        input and ``kwarg_properties`` is empty.
        """
        if not isinstance(self.op_overload, torch._ops.OpOverload):
            self.arg_properties = [{} for _ in self.inputs]
            self.kwarg_properties = {}
            return

        schema_args = self.op_overload._schema.arguments
        if not self.ordered_kwargs_for_cpp_kernel:
            self.ordered_kwargs_for_cpp_kernel = [
                x.name for x in schema_args if x.kwarg_only
            ]
        self.arg_properties = [
            {
                "name": x.name,
                "type": x.real_type,
                "default_value": x.default_value,
            }
            for x in schema_args
            if not x.kwarg_only
        ]
        self.kwarg_properties = {
            x.name: {"type": x.real_type, "default_value": x.default_value}
            for x in schema_args
            if x.kwarg_only
        }

    def decide_layout(self):
        """Freeze a FlexibleLayout, applying any subclass constraint first."""
        if isinstance(self.layout, FlexibleLayout):
            self.apply_constraint()
            self.freeze_layout()

    def codegen_comment(self, wrapper):
        """Write the kernel's origin comment (if any) into ``wrapper``."""
        origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper)
        if origin_str:
            wrapper.writeline(origin_str)

    def codegen(self, wrapper):
        raise NotImplementedError()

    def get_kernel_name(self):
        """Return the C++ kernel name under the C++ wrapper, else the Python one."""
        return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name

    @staticmethod
    def copy_input(x):
        """Materialize ``x`` into a new, realized Pointwise buffer."""
        pw = Pointwise.create(
            device=x.get_device(),
            dtype=x.get_dtype(),
            inner_fn=x.make_loader(),
            ranges=x.get_size(),
            origin_node=x.get_origin_node(),
            traceback=x.get_traceback(),
        )
        pw.realize()
        return pw

    @classmethod
    def process_kernel(cls, kernel, *args, **kwargs):
        """Prepare arguments for an extern kernel and infer its output.

        Tensor (``IRNode``) arguments are realized and their layouts frozen.
        ``kernel`` is then called on example tensors to obtain an example
        output. For each tensor argument, the graph constant is used when one
        exists; otherwise the tensor comes from ``ir_node_to_tensor``.

        Returns ``(example_output, tensor_args, non_tensor_args,
        unflatten_args)``. ``unflatten_args(new_tensor_args,
        new_non_tensor_args)`` rebuilds the original ``(args, kwargs)``
        structure from the two flat lists.
        """
        bound_args = {"args": args, "kwargs": kwargs}

        args_flat, args_spec = pytree.tree_flatten(bound_args)

        is_arg_tensor = []
        tensor_args = []
        non_tensor_args: List[Any] = []
        for arg in args_flat:
            is_arg_tensor.append(isinstance(arg, IRNode))
            if is_arg_tensor[-1]:
                tensor_args.append(arg)
            else:
                if isinstance(arg, sympy.Expr):
                    # Turn sympy expressions into SymInt nodes for the example call.
                    arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
                non_tensor_args.append(arg)

        def unflatten_args(new_tensor_args, new_non_tensor_args):
            # Interleave the two flat lists back in their original positions.
            result = []
            it_tensors = iter(new_tensor_args)
            it_non_tensors = iter(new_non_tensor_args)
            for is_tensor in is_arg_tensor:
                if is_tensor:
                    result.append(next(it_tensors))
                else:
                    result.append(next(it_non_tensors))
            r = pytree.tree_unflatten(result, args_spec)
            return r.get("args", []), r.get("kwargs", {})

        tensor_args = [cls.realize_input(x) for x in tensor_args]

        # freeze layout otherwise our output stride calculation might
        # become incorrect
        for x in tensor_args:
            if is_storage_and_layout(x):
                as_storage_and_layout(x, freeze=True)

        # We don't have generic shape formulas, so just burn in the
        # shapes and run an example input.
        # TODO(jansel): replace this with dynamic shape formulas
        example_args = []

        # We need to retain the constant values of fake tensors that we originally
        # propagated the graph with, because for some operators running without a
        # constant would trigger an error / DataDependentException
        for x in tensor_args:
            if x.get_name() in V.graph.constants:
                example_args.append(V.graph.constants[x.get_name()])
            else:
                example_args.append(ir_node_to_tensor(x, guard_shape=True))

        new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
        example_output = kernel(*new_args, **new_kwargs)

        example_out_li = (
            [example_output]
            if not isinstance(example_output, (list, tuple))
            else example_output
        )
        # Record a reason to disable cudagraphs if any example output is sparse.
        for t in example_out_li:
            if isinstance(t, torch.Tensor) and t.is_sparse:
                msg = "sparsity not handled. Please file issue for sparse inference weights."
                if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
                    msg = f"{msg} Found from : \n {stack_trace}"
                V.graph.disable_cudagraphs_reason = msg

        # TODO: Unconditionally do this, not just when example_output has
        # unbacked symbols
        if maybe_free_unbacked_symbols(example_output):
            example_output = V.graph.current_node.meta["val"]

        return example_output, tensor_args, non_tensor_args, unflatten_args

    @classmethod
    def convert_to_reinterpret_view(cls, x):
        """
        In order to pass this to an extern kernel we need a
        ReinterpretView not a View.  This allows us to avoid some
        unneeded copies.

        Raises NotImplementedError when the view's index is not an affine
        (stride * var + offset) function of its iteration variables.
        """
        assert isinstance(x, BaseView)
        if isinstance(x, ReinterpretView):
            return x

        # NOTE: Don't use extract_read_writes here as it fails when
        # make_loader() inlines the computation
        x.unwrap_view().freeze_layout()
        index_args, var_ranges = dependencies.index_vars_squeeze(
            x.get_size(), prefix="r"
        )
        range_vars = index_args[0]
        index = x.make_indexer()(range_vars)

        index = V.graph.sizevars.simplify_with_ranges(index, var_ranges)
        strides = V.graph.sizevars.stride_vars(index, range_vars)
        offset = V.graph.sizevars.offset_var(index, range_vars)
        expected = sympy_dot(range_vars, strides) + offset

        if index != expected:
            log.debug(
                "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s",
                strides,
                offset,
                index,
            )
            raise NotImplementedError()

        return ReinterpretView(
            data=x.data,
            layout=FixedLayout(
                device=x.get_device(),
                dtype=x.get_dtype(),
                size=x.get_size(),
                stride=strides,
                offset=offset,
            ),
        )

    @classmethod
    def realize_input(cls, x):
        """Convert ``x`` into something an extern kernel can take as input.

        * ``None`` becomes ``NoneAsConstantBuffer``.
        * sympy / int scalars become ``ShapeAsConstantBuffer``.
        * A ``Constant`` becomes a registered tensor constant.
        * Views become ``ReinterpretView``s where possible.
        * A ``StorageBox`` is realized.
        * Anything else is copied via ``copy_input``.
        """
        if x is None:
            return NoneAsConstantBuffer()
        if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)):
            return ShapeAsConstantBuffer(x)
        if isinstance(x, Constant):
            return V.graph.add_tensor_constant(
                torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
            )
        if isinstance(x, ConstantBuffer):
            return x
        if isinstance(x, TensorBox):
            return cls.realize_input(x.data)
        if isinstance(x, ReinterpretView):
            return ReinterpretView(cls.realize_input(x.data), x.get_layout())
        if isinstance(x, BaseView):
            x.realize()
            if is_storage_and_layout(x.unwrap_view()):
                try:
                    return cls.convert_to_reinterpret_view(x)
                except NotImplementedError:
                    pass
        if isinstance(x, StorageBox):
            # TODO(jansel): impose layout preference on realized buffer
            x.realize()
            return x
        return cls.copy_input(x)

    @classmethod
    def require_stride1(cls, x):
        """Return ``x`` if it is rank-0 or has some stride equal to 1, else a copy."""
        if is_storage_and_layout(x):
            strides = x.get_stride()
            if len(strides) == 0 or any(stride == 1 for stride in strides):
                return x
        return cls.copy_input(x)

    @classmethod
    def require_stride_order(cls, x, order):
        """Return ``x`` (or a copy of it) whose strides follow ``order``.

        Tries three things, in order:

        1. Fix a ``FlexibleLayout`` in place with the requested order.
        2. Accept a layout that is already ordered correctly.
        3. Convert a view to a ``ReinterpretView`` and retry.

        If none of these works, it falls back to a copy with the requested
        order.
        """
        if x.get_numel() == 0:  # Layout doesn't matter
            return x

        # require x to have the layout as strided_ordered as order
        if is_storage_and_layout(x):
            while isinstance(x.get_layout(), AliasedLayout):
                x = x.get_layout().view
            if isinstance(x.get_layout(), FlexibleLayout):
                # fix flexiblelayout to be FixedLayout with stride_order
                as_storage_and_layout(
                    x, freeze=True, want_contiguous=False, stride_order=order
                )
                return x
            elif isinstance(
                x.get_layout(), FixedLayout
            ) and x.get_layout().is_stride_ordered(order):
                return x
            elif isinstance(x.get_layout(), MutationLayout):
                if isinstance(x.get_layout().real_layout(), FlexibleLayout):
                    raise AssertionError(
                        "the MutationLayout's real layout shouldn't be FlexibleLayout"
                    )
                elif isinstance(
                    x.get_layout().real_layout(), FixedLayout
                ) and x.get_layout().real_layout().is_stride_ordered(order):
                    return x

        # TODO - Storage to InputBuffer
        if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order):
            return x
        if (
            isinstance(x, TensorBox)
            and isinstance(x.data, BaseView)
            and not isinstance(x.data, ReinterpretView)
            and is_storage_and_layout(x.unwrap_view())
            and not isinstance(x.unwrap_view().data, ExternKernelAlloc)
        ):
            try:
                x.data = cls.convert_to_reinterpret_view(x.data)
                return cls.require_stride_order(x, order)
            except NotImplementedError:
                pass
        x = cls.copy_input(x)
        as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order)
        assert is_stride_order_storage_and_layout(x, order)
        return x

    @classmethod
    def require_channels_last(cls, x):
        return cls.require_stride_order(x, NHWC_STRIDE_ORDER)

    @classmethod
    def require_contiguous(cls, x):
        return cls.require_stride_order(x, list(reversed(range(len(x.get_size())))))

    def apply_constraint(self):
        pass

    def codegen_const_args(self):
        return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args)

    def codegen_args(self):
        """Render positional arguments: inputs first, then constant args."""
        args = []
        for i, x in enumerate(self.inputs):
            if isinstance(x, list):
                # Use a distinct name so the enumerate index ``i`` is not shadowed.
                names = [item.codegen_reference() for item in x]
                codegen_reference = f'[{", ".join(names)}]'
                args.append(codegen_reference)
            else:
                if V.graph.cpp_wrapper:
                    assert self.arg_properties and i < len(
                        self.arg_properties
                    ), "Invalid arg_properties accessing"
                    type_ = self.arg_properties[i].get("type")
                    args.append(
                        V.graph.wrapper_code.val_to_cpp_arg_str(  # type: ignore[arg-type]
                            type_, x, self.is_legacy_abi_kernel()
                        )
                    )
                else:
                    args.append(x.codegen_reference())
        args.extend(self.codegen_const_args())
        return args

    def get_kwargs_value(self, arg_name):
        """Return an explicit kwarg value, else its schema default.

        Raises AssertionError if the name is in neither place.
        """
        if arg_name in self.kwargs:
            return self.kwargs[arg_name]
        if self.kwarg_properties and self.kwarg_properties.get(arg_name):
            return self.kwarg_properties.get(arg_name).get("default_value")  # type: ignore[union-attr]
        else:
            raise AssertionError(f"{arg_name} not in self.kwarg_properties")

    def is_legacy_abi_kernel(self):
        return False

    def codegen_kwargs(self):
        """Render keyword arguments for the active wrapper.

        The C++ wrapper gets positional values in
        ``ordered_kwargs_for_cpp_kernel`` order. The Python wrapper gets
        ``name=value`` strings.
        """
        if V.graph.cpp_wrapper:
            kwargs = []
            for arg_name in self.ordered_kwargs_for_cpp_kernel:
                v = self.get_kwargs_value(arg_name)
                if isinstance(v, sympy.Expr):
                    kwargs.append(v)
                else:
                    type_ = (
                        self.kwarg_properties.get(arg_name).get("type")  # type: ignore[union-attr]
                        if self.kwarg_properties and arg_name in self.kwarg_properties
                        else None
                    )
                    kwargs.append(
                        V.graph.wrapper_code.val_to_cpp_arg_str(  # type: ignore[arg-type]
                            type_, v, self.is_legacy_abi_kernel()
                        )
                    )
        else:
            kwargs = [
                f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"  # type: ignore[misc]
                for k, v in self.kwargs.items()
            ]
        return kwargs

    def codegen_size_asserts(self, wrapper):
        """Emit an ``assert_size_stride`` check (Python wrapper only, if enabled)."""
        if config.size_asserts and not V.graph.cpp_wrapper:
            size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
            stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
            wrapper.writeline(
                f"assert_size_stride({self.get_name()}, {size}, {stride})"
            )

    def get_group_stride(self):
        """
        get output sizes and strides, for template_codegen
        """
        _size = self.get_size()
        _stride = self.get_stride()
        # iter_ranges = _size of output tensor, reduce_range = [] because no reduction
        return [_size, []], _stride

    def canonicalize(self):
        """
        Manually get canonicalization of the output index
        """
        # manually generate index formula for conv
        sizevars = V.graph.sizevars
        sizes = self.get_size()
        strides = self.get_stride()
        strides = [sizevars.size_hint(x) for x in strides]
        index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))]
        # reorder index vars according to stride
        index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
        lookup = {pos: idx for idx, pos in enumerate(index_order)}
        order = [lookup[i] for i in range(len(lookup))]
        index_vars = [index_vars[i] for i in order]
        indexer = self.make_indexer()
        index = indexer(index_vars)

        new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops(
            index_vars, sizes, [index]
        )

        # assign new variables each dimension to deal with numbering mismatches
        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
        _, add_var = var_builder("c")
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))

        index = sympy_subs(sympy.expand(index), replacement)  # type: ignore[arg-type]
        return index, tuple(new_sizes)

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        # NB: It's not necessary to check regular inputs as we automatically
        # have dependencies on them
        r = set()
        for arg in self.constant_args:
            r |= maybe_free_unbacked_symbols(arg)
        for arg in self.kwargs.values():
            r |= maybe_free_unbacked_symbols(arg)
        return r

    def __str__(self):
        kernel_name = getattr(self, "python_kernel_name", None)
        lines = [
            f"python_kernel_name={kernel_name!r}",
        ]
        lines += [
            f"{field.name}={getattr(self, field.name)}"
            for field in dataclasses.fields(self)
        ]
        lines.append(f"origin_node={self.origin_node!r}")
        return self.str_helper(lines)

    __repr__ = __str__
4070

4071

4072
@dataclasses.dataclass
class ExternKernelOut(ExternKernel):
    """ExternKernel that writes its result into a wrapper-allocated buffer.

    ``should_allocate`` is True; the call is emitted through
    ``generate_extern_kernel_out``.
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kwargs=None,
        output_view=None,
        python_kernel_name=None,
        cpp_kernel_name=None,
        ordered_kwargs_for_cpp_kernel=(),
        op_overload=None,
    ):
        unwrapped_inputs = self.unwrap_storage(inputs)
        # NOTE(review): the ``output_view`` argument is accepted but not
        # forwarded -- None is always passed on, so ``self.output_view`` ends
        # up None. Confirm whether callers rely on this before changing it.
        super().__init__(
            name=None,
            layout=layout,
            inputs=unwrapped_inputs,
            constant_args=constant_args,
            kwargs=kwargs or {},
            output_view=None,
            python_kernel_name=python_kernel_name,
            cpp_kernel_name=cpp_kernel_name,
            ordered_kwargs_for_cpp_kernel=ordered_kwargs_for_cpp_kernel,
            op_overload=op_overload,
        )
        self.name = V.graph.register_buffer(self)

    def codegen(self, wrapper):
        self.codegen_comment(wrapper)
        call_args = list(self.codegen_args())
        call_args.extend(self.codegen_kwargs())
        wrapper.generate_extern_kernel_out(
            self.output_view,
            self.codegen_reference(),
            call_args,
            self.get_kernel_name(),
        )

    def should_allocate(self):
        return True
4112

4113

4114
class RandomSeeds(ExternKernelOut):
    """Generate ``count`` int64 seeds on ``device`` using ``randint`` over the full int64 range."""

    def __init__(self, count: int, device: torch.device):
        int64_info = torch.iinfo(torch.int64)
        seeds_layout = FixedLayout(device=device, dtype=torch.int64, size=[count])
        super().__init__(
            layout=seeds_layout,
            inputs=[],
            constant_args=[int64_info.min, int64_info.max, [count]],
            python_kernel_name="aten.randint.low_out",
            cpp_kernel_name="at::randint_out",
        )
4128

4129

4130
class ExternKernelAlloc(ExternKernel):
    """ExternKernel whose output is not preallocated by the wrapper.

    ``should_allocate`` is False, and the call is emitted through
    ``generate_extern_kernel_alloc``.
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args=(),
        kwargs=None,
        python_kernel_name=None,
        cpp_kernel_name=None,
        ordered_kwargs_for_cpp_kernel=(),
        op_overload=None,
    ):
        unwrapped_inputs = self.unwrap_storage(inputs)
        super().__init__(
            name=None,
            layout=layout,
            inputs=unwrapped_inputs,
            constant_args=constant_args,
            kwargs=kwargs or {},
            output_view=None,
            python_kernel_name=python_kernel_name,
            cpp_kernel_name=cpp_kernel_name,
            ordered_kwargs_for_cpp_kernel=ordered_kwargs_for_cpp_kernel,
            op_overload=op_overload,
        )
        self.name = V.graph.register_buffer(self)

    def codegen(self, wrapper):
        self.codegen_comment(wrapper)
        call_args = list(self.codegen_args())
        call_args.extend(self.codegen_kwargs())
        V.graph.wrapper_code.generate_extern_kernel_alloc(self, call_args)
        # Size/stride asserts are only emitted when the layout is a Layout.
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)

    def should_allocate(self):
        return False

    def apply_constraint(self):
        raise NotImplementedError
4168

4169

4170
class UserDefinedTritonKernel(ExternKernel):
    """
    Extern node that calls a user-defined Triton kernel.

    The kernel object lives in ``kernel_side_table`` and is referenced by
    ``kernel_idx``. Tensor arguments are realized and become this node's
    inputs. Every tensor argument is marked as mutated, because the kernel
    writes into its arguments instead of returning a value.
    """

    def get_kernel_and_configs(self):
        """
        Return ``(kernel, configs)`` for ``self.kernel_idx``.

        If the stored kernel is a Triton ``Autotuner``, it is unwrapped: the
        returned kernel is its ``fn`` and ``configs`` are its configs.
        Otherwise ``configs`` is an empty list.
        """
        from triton.runtime.autotuner import Autotuner

        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        kernel = kernel_side_table.get_kernel(self.kernel_idx)
        configs = []
        if isinstance(kernel, Autotuner):
            configs = kernel.configs
            kernel = kernel.fn
        return kernel, configs

    def codegen(self, wrapper):
        kernel, configs = self.get_kernel_and_configs()

        # Definition of kernel
        new_name, triton_meta = wrapper.define_user_defined_triton_kernel(
            kernel, configs, self.kwargs
        )

        args = self.codegen_kwargs()
        if V.graph.cpp_wrapper:
            # In the C++ wrapper we don't pass constexpr args, as they don't
            # get added as parameters to the PTX code compiled from the
            # user-defined Triton kernel (only non-constexpr args do).
            #
            # ``kernel.constexprs`` holds positions within ``kernel.arg_names``.
            # ``args`` has one entry per name in
            # ``self.ordered_kwargs_for_cpp_kernel``, and that list omits args
            # that are not passed (e.g. ones supplied by autotuning configs).
            # Comparing raw positions would then drop the wrong arguments, so
            # match constexprs by name instead.
            constexpr_names = {kernel.arg_names[i] for i in kernel.constexprs}
            args = [
                arg
                for name, arg in zip(self.ordered_kwargs_for_cpp_kernel, args)
                if name not in constexpr_names
            ]

        # Call to kernel
        self.codegen_comment(wrapper)
        wrapper.generate_user_defined_triton_kernel(
            new_name,
            self.grid,
            configs,
            args,
            triton_meta,
        )

    def should_allocate(self):
        return False

    def has_side_effects(self):
        # UserDefinedTritonKernel does not return anything, but rather
        # modifies input in place, do not let it get DCEd
        return True

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    def get_mutation_names(self):
        return []

    def __init__(self, *, kernel_idx, grid, kernel_args):
        """
        Args:
            kernel_idx: key of the kernel in ``kernel_side_table``.
            grid: launch grid, forwarded to the wrapper at codegen time.
            kernel_args: mapping of kernel argument name to value. ``TensorBox``
                values become inputs; all other values are constant args.
        """
        inputs = []
        kwargs = {}
        constant_args = []
        for k, v in kernel_args.items():
            if isinstance(v, TensorBox):
                t = InputsKernel.unwrap_storage_for_input(self.realize_input(v))
                inputs.append(t)
                kwargs[k] = t
            else:
                constant_args.append(v)
                kwargs[k] = v

        # The node's device is taken from the first tensor argument.
        assert len(inputs) != 0
        device = inputs[0].get_device()

        super().__init__(
            None,
            NoneLayout(device),  # type: ignore[arg-type]
            inputs,
            tuple(constant_args),
            kwargs,
        )
        self.name = V.graph.register_buffer(self)
        self.kernel_idx = kernel_idx
        self.grid = grid

        kernel, _ = self.get_kernel_and_configs()
        # If we are autotuning, not all arguments will be passed
        self.ordered_kwargs_for_cpp_kernel = [
            arg for arg in kernel.arg_names if arg in kernel_args
        ]

        mark_node_as_mutating(
            self, *[a for a in kernel_args.values() if isinstance(a, TensorBox)]
        )

    def get_alias_names(self):
        return [i.get_name() for i in self.inputs]
4261

4262

4263
def mark_node_as_mutating(cur_buffer, *mutated_ops):
    """
    Record that ``cur_buffer`` mutates each node in ``mutated_ops``.

    For every mutated node, its buffer name is registered as mutated on the
    graph, and a ``MutationOutput`` is created over that node and
    ``cur_buffer``. This tells the scheduler that the mutated ops depend on
    ``cur_buffer``.
    """
    for mutated in mutated_ops:
        assert isinstance(mutated, IRNode), mutated
        V.graph.mark_buffer_mutated(mutated.get_name())
        # MutationOutput reuses the mutated node's layout.
        assert hasattr(mutated, "layout")
        MutationOutput(mutated.layout, mutated, cur_buffer)
4273

4274

4275
class MutationOutput(ExternKernel):
    """
    No-op marker node recording that ``input`` is mutated by ``parent``.

    Its inputs are ``[input, parent]``. The first one is reported as both the
    mutated and the aliased buffer. The node never allocates, and it reports
    side effects (presumably so it is not dead-code-eliminated).
    """

    def __init__(self, layout, input, parent):
        super().__init__(None, layout, [input, parent], ())
        self.name = V.graph.register_buffer(self)

    def get_mutation_names(self):
        mutated = self.inputs[0]
        return [mutated.get_name()]

    def get_alias_names(self):
        aliased = self.inputs[0]
        return [aliased.get_name()]

    def should_allocate(self):
        return False

    def is_no_op(self):
        return True

    def has_side_effects(self):
        return True
4294

4295

4296
class InplaceBernoulliFallback(ExternKernel):
    """
    Custom extern node for ``aten.bernoulli_``, needed to handle mutation.

    Its single input is mutated in place. Constant args are emitted via repr.
    """

    def __init__(self, x, *constant_args):
        super().__init__(
            None,
            NoneLayout(x.get_device()),  # type: ignore[arg-type]
            self.unwrap_storage([x]),
            constant_args,
        )
        self.name = V.graph.register_buffer(self)
        self.python_kernel_name = "aten.bernoulli_"
        self.cpp_kernel_name = "at::native::bernoulli_"
        mark_node_as_mutating(self, x)

    def codegen(self, wrapper):
        # Exactly one tensor input is expected; unpacking enforces that.
        [x_ref] = [t.codegen_reference() for t in self.inputs]
        const_text = ", ".join(repr(c) for c in self.constant_args)
        wrapper.writeline(
            f"{self.get_kernel_name()}({x_ref}, {const_text}){wrapper.ending}"
        )

    def should_allocate(self):
        return False

    def get_mutation_names(self):
        target = self.inputs[0]
        return [target.get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()
4327

4328

4329
# Used to deal with torch.complex types
4330
class InplaceCopyFallback(ExternKernel):
    """
    Custom extern node for ``aten.copy_``, needed to handle mutation.

    Inputs are ``[dst, src]``, and ``non_blocking`` is the only constant arg.
    ``dst`` is the mutated buffer.
    """

    def __init__(
        self,
        layout,
        inputs,
        constant_args,
    ):
        if config.abi_compatible:
            cpp_name = "aoti_torch_copy_"
        else:
            cpp_name = "at::_ops::copy_::call"
        super().__init__(
            None,
            layout,
            inputs,
            constant_args,
            python_kernel_name="aten.copy_",
            cpp_kernel_name=cpp_name,
        )
        self.name = V.graph.register_buffer(self)

    @classmethod
    def create(cls, dst, src, non_blocking: bool = False):
        realized = [cls.realize_input(dst), cls.realize_input(src)]
        node = InplaceCopyFallback(
            NoneLayout(dst.get_device()),  # type: ignore[arg-type]
            realized,
            (non_blocking,),
        )
        mark_node_as_mutating(node, dst)
        return node

    def codegen(self, wrapper):
        dst, src, non_blocking = self.codegen_args()
        kernel_name = self.get_kernel_name()
        wrapper.writeline(
            f"{kernel_name}({dst}, {src}, {non_blocking}){wrapper.ending}"
        )

    def should_allocate(self):
        return False

    def get_mutation_names(self):
        dst_node = self.inputs[0]
        return [dst_node.get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()
4379

4380

4381
class MutatingFirstArgExternKernel(ExternKernel):
    """
    Base class for extern kernels that mutate their first input in place.

    The call is emitted as ``kernel(<input refs>, <repr of constant args>)``.
    """

    def codegen(self, wrapper):
        tensor_refs = [t.codegen_reference() for t in self.inputs]
        const_reprs = [repr(c) for c in self.constant_args]
        arg_text = ", ".join(tensor_refs + const_reprs)
        wrapper.writeline(f"{self.get_kernel_name()}({arg_text}){wrapper.ending}")

    def should_allocate(self):
        return False

    def get_mutation_names(self):
        first_input = self.inputs[0]
        return [first_input.get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    def has_side_effects(self):
        return True
4406

4407

4408
class ResizeStorageBytes(MutatingFirstArgExternKernel):
    """
    In-place resize of ``variable``'s storage to ``new_size`` bytes.

    Only integer (static) sizes are supported. The underlying buffer is added
    to ``V.graph.never_reuse_buffers`` (presumably so its memory is never
    handed to another buffer).
    """

    def __init__(self, variable, new_size):
        assert isinstance(new_size, int), "TODO: dynamic shapes"
        device_layout = NoneLayout(variable.get_device())
        super().__init__(
            None,
            device_layout,  # type: ignore[arg-type]
            self.unwrap_storage([variable]),
            constant_args=(new_size,),
        )
        V.graph.mark_buffer_mutated(variable.get_name())
        self.name = V.graph.register_buffer(self)
        self.python_kernel_name = "inductor_ops.resize_storage_bytes_"
        self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_"
        underlying_name = variable.data.get_name()
        V.graph.never_reuse_buffers.add(underlying_name)
        mark_node_as_mutating(self, variable)
4423

4424

4425
class ScatterFallback(ExternKernel):
    """
    Custom extern node for in-place scatter, needed to handle mutation.

    It handles both ``aten.scatter_`` and ``aten.scatter_reduce_``. A tensor
    ``src`` is stored as the third input. A scalar ``src`` is stored as the
    second constant arg, after ``dim``.
    """

    def __init__(
        self,
        op_overload,
        python_kernel_name,
        x,
        dim: int,
        index,
        src,
        *,
        reduce: Optional[str] = None,
        include_self: bool = True,
    ):
        assert python_kernel_name in {"aten.scatter_", "aten.scatter_reduce_"}
        self.src_is_tensor = isinstance(src, TensorBox)

        constant_args: Tuple[Any, ...]
        if self.src_is_tensor:
            operands = [x, index, src]
            constant_args = (dim,)
        else:
            operands = [x, index]
            constant_args = (dim, src)
        tensors = [self.realize_input(t) for t in operands]

        super().__init__(
            None,
            NoneLayout(x.get_device()),  # type: ignore[arg-type]
            self.unwrap_storage(tensors),
            constant_args,
            {"reduce": reduce, "include_self": include_self},
            python_kernel_name=python_kernel_name,
            ordered_kwargs_for_cpp_kernel=["reduce", "include_self"],
            op_overload=op_overload,
        )
        self.cpp_kernel_name = self.get_cpp_kernel()
        self.name = V.graph.register_buffer(self)
        mark_node_as_mutating(self, x)

    def get_cpp_kernel(self):
        reduce = self.kwargs["reduce"]
        if self.python_kernel_name != "aten.scatter_":
            assert (
                reduce is not None
            ), "Expect reduce to be not None for aten.scatter_reduce_"
            return "at::scatter_reduce_out"
        if not self.src_is_tensor:
            assert (
                reduce is None
            ), "Expect reduce to be None for aten.scatter_ with scalar src"
            return "at::scatter_out"
        return "at::scatter_out" if reduce is None else "at::scatter_reduce_out"

    def codegen(self, wrapper):
        reduce = self.kwargs["reduce"]
        if V.graph.cpp_wrapper:
            # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum
            cpp_reduce_names = {"add": "sum", "multiply": "prod"}
            reduce = cpp_reduce_names.get(reduce, reduce)

        refs = [t.codegen_reference() for t in self.inputs]
        if self.src_is_tensor:
            x, index, src = refs
        else:
            x, index = refs
            src = self.constant_args[1]
        dim = self.constant_args[0]
        wrapper.generate_scatter_fallback(
            x,
            [x, dim, index, src],
            self.get_kernel_name(),
            self.python_kernel_name,
            self.src_is_tensor,
            reduce,
            self.codegen_kwargs(),
        )

    def should_allocate(self):
        return False

    def get_mutation_names(self):
        target = self.inputs[0]
        return [target.get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()
4519

4520

4521
class IndexPutFallback(ExternKernel):
    """
    Custom extern node for ``aten.index_put_``, handling mutation and indices.

    ``self.indices`` keeps the original index list, which may contain None.
    Only the non-None entries are realized and stored as inputs, after
    ``x`` and ``values``.
    """

    def codegen(self, wrapper):
        x, values, *valid_indices = (t.codegen_reference() for t in self.inputs)
        # Re-expand the realized (non-None) index refs into the original
        # positions, emitting the wrapper's none string for None entries.
        remaining = iter(valid_indices)
        indices = [
            next(remaining) if idx is not None else V.graph.wrapper_code.none_str
            for idx in self.indices
        ]

        wrapper.generate_index_put_fallback(
            self.get_kernel_name(), x, indices, values, *self.codegen_const_args()
        )

    def should_allocate(self):
        return False

    def get_mutation_names(self):
        return [self.inputs[0].get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    def __init__(self, op_overload, x, indices, values, accumulate):
        self.indices = indices
        valid_indices = [i for i in indices if i is not None]
        # Use a distinct loop name so the ``x`` parameter is not shadowed.
        tensors = [self.realize_input(t) for t in [x, values, *valid_indices]]
        cpp_kernel_name = (
            "aoti_torch_index_put_out" if config.abi_compatible else "at::index_put_out"
        )
        super().__init__(
            None,
            NoneLayout(x.get_device()),  # type: ignore[arg-type]
            self.unwrap_storage(tensors),
            (accumulate,),
            python_kernel_name="aten.index_put_",
            cpp_kernel_name=cpp_kernel_name,
            op_overload=op_overload,
        )
        self.name = V.graph.register_buffer(self)
        mark_node_as_mutating(self, x)
4567

4568

4569
class DeviceCopy(ExternKernelOut):
    """
    Copy of a node's value onto another device.

    If ``x`` is not extern, reads only graph constants (as ``MemoryDep``s),
    and runtime constant folding is disabled, ``create`` short-circuits to
    ``x.constant_to_device(device)`` instead of emitting a copy.
    """

    @classmethod
    def create(cls, x, device):
        foldable = (
            not x.is_extern()
            and all(
                (r.name in V.graph.constants and isinstance(r, dependencies.MemoryDep))
                for r in x.get_reads()
            )
            and not config.aot_inductor.use_runtime_constant_folding
        )
        if foldable:
            return x.constant_to_device(device)

        for dev in (device, x.get_device()):
            V.graph.add_device_info(dev)

        developer_warning("DeviceCopy in input program")
        target_layout = FlexibleLayout(
            device=device,
            dtype=x.get_dtype(),
            size=x.get_size(),
        )
        return DeviceCopy(target_layout, [cls.realize_input(x)])

    def codegen(self, wrapper):
        args = self.codegen_args()
        assert len(args) == 1
        # Copy into the output view when one is present, otherwise into self.
        destination = self.output_view if self.output_view else self
        wrapper.codegen_device_copy(args[0], destination.codegen_reference())
4602

4603

4604
class DynamicScalar(ExternKernel):
    """
    The result of a call to aten._local_scalar_dense.

    It defines the unbacked symbol ``self.sym`` from ``data``. Code generation
    is delegated to ``wrapper.codegen_dynamic_scalar``.
    """

    # TODO: handle bools carefully
    def __init__(self, sym, data):
        data.realize()
        super().__init__(None, NoneLayout(torch.device("cpu")), self.unwrap_storage([data]))  # type: ignore[arg-type]
        if isinstance(sym, sympy.Symbol):
            self.sym = sym
            self.is_bool = False
            return

        # Booleans are not represented directly in sympy. Instead an integer
        # indicator variable i0 is generated and the boolean is expressed as
        # Eq(i0, 1). Recover the indicator symbol here so it can be bound to
        # the integer value matching the runtime boolean.
        assert isinstance(sym, sympy.Eq), sym
        indicator = sym.args[0]
        assert isinstance(indicator, sympy.Symbol), sym
        assert sym.args[1] == 1, sym
        self.sym = indicator
        self.is_bool = True

    def get_reads(self):
        return ()

    def should_allocate(self):
        return False

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return {self.sym}

    def codegen(self, wrapper):
        wrapper.codegen_dynamic_scalar(self)
4640

4641

4642
class AssertScalar(ExternKernel):
    """
    The result of a call to aten._assert_scalar.

    In the Python wrapper it emits a runtime check that raises RuntimeError
    with ``msg`` when ``scalar`` is falsy. Nothing is emitted for the C++
    wrapper.
    """

    def __init__(self, scalar, msg):
        super().__init__(
            # Buffer(name, layout)
            None,
            NoneLayout(torch.device("cpu")),  # type: ignore[arg-type]
            # InputsKernel(inputs)
            [],
        )  # type: ignore[arg-type]
        self.scalar = scalar
        self.msg = msg

    def get_reads(self):
        return ()

    def should_allocate(self):
        return False

    def has_side_effects(self):
        return True

    def get_unbacked_symbol_uses(self):
        return free_unbacked_symbols(self.scalar)

    def codegen(self, wrapper):
        if V.graph.cpp_wrapper:
            return
        condition = V.graph.wrapper_code.codegen_python_sizevar(self.scalar)
        wrapper.writeline(f"if not {condition}:")
        wrapper.writeline(f"    raise RuntimeError({self.msg!r})")
        # No one should ever use this buffer, but for uniformity
        # define the variable and assign it None
        wrapper.writeline(f"{self.get_name()} = None")
4681

4682

4683
@dataclasses.dataclass
4684
class ExternKernelNode:
4685
    name: str
4686
    node: export_schema.Node
4687

4688

4689
# Aten op overloads listed as having a C shim.
# NOTE(review): presumably consulted when emitting ABI-compatible C++
# wrapper calls for fallback kernels; confirm at the use sites.
# Entries are kept in alphabetical order for easier maintenance.
has_c_shim = {
    aten._embedding_bag.default,
    aten._fft_c2c.default,
    aten._scaled_dot_product_efficient_attention.default,
    aten._scaled_dot_product_flash_attention.default,
    aten._scaled_mm.default,
    aten.addmm.out,
    aten.bmm.out,
    aten.copy_.default,
    aten.mm.out,
    aten.nonzero.default,
    aten.repeat_interleave.Tensor,
    aten.view.dtype,
    aten.view_as_real.default,
}
4704

4705

4706
def get_aten_cpp_kernel_name(kernel):
    """
    Return the ``at::_ops::<op>::call`` name for an aten ``OpOverload``.

    Non-default overloads fold the overload name into the op name
    (``repeat_interleave.Tensor`` -> ``repeat_interleave_Tensor``). Calling by
    the default kernel name can be ambiguous, e.g. both of these exist:

        repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
        repeat_interleave(const at::Tensor & self, int64_t repeats,
              c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)

    ``.default`` overloads keep the bare op name.

    Raises:
        AssertionError: if ``kernel`` is not an aten ``OpOverload``.
    """
    is_aten_overload = (
        isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten"
    )
    assert is_aten_overload, "Invalid aten kernel"
    if kernel._overloadname == "default":
        opname = kernel.__name__.split(".")[0]
    else:
        opname = kernel.__name__.replace(".", "_")
    return "at::_ops::" + opname + "::call"
4720

4721

4722
class FallbackKernel(ExternKernelAlloc):
4723
    args_default_value: List[Dict[str, Any]]
4724

4725
    def __init__(
        self,
        layout,
        kernel,
        tensor_args,
        nontensor_args,
        unflatten_args,
        kwargs=None,
    ):
        """
        Build a fallback node that calls ``kernel`` through an extern call.

        Args:
            layout: output layout, forwarded to ExternKernelAlloc.
            kernel: an ``OpOverload`` or ``HigherOrderOperator``; stored as
                ``self.op_overload``.
            tensor_args: tensor arguments, forwarded (as a tuple) as inputs.
            nontensor_args: non-tensor arguments, forwarded (as a tuple) as
                constant args.
            unflatten_args: callable mapping ``(inputs, constant_args)`` back to
                the op's ``(args, kwargs)``.
            kwargs: keyword arguments stored on ``self.kwargs`` (``{}`` if None).

        Populates ``self.alias_names`` / ``self.mutation_names`` from the op
        schema (skipped for HOPs and ``_c10d_functional`` ops).

        Raises:
            NotImplementedError: if the schema is mutable and the op cannot be
                auto-functionalized.
        """
        super().__init__(
            layout,
            tuple(tensor_args),
            tuple(nontensor_args),
            op_overload=kernel,
        )
        # We need output buffers for generating kernel arguments in the
        # abi-compatible mode, where we retrieve outputs by passing each individual
        # output through the abi-compatible interface.
        self.outputs: Sequence[Any] = []
        self.use_runtime_dispatch = False
        self.abi_compatible_kernel = None

        assert isinstance(
            kernel,
            (
                torch._ops.OpOverload,
                torch._ops.HigherOrderOperator,
            ),
        ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported"
        self.op_overload = kernel

        self.unflatten_args = unflatten_args
        self.kwargs = {} if kwargs is None else kwargs
        V.graph.warn_fallback(self.python_kernel_name)

        # args that are aliased
        self.alias_names: List[str] = []
        # args that are mutated AND returned from the op
        self.mutation_names: List[str] = []

        if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
            # We assume here that HOPs with FallbackKernel are functional.
            # This may not always be true! HOPs must individually opt-in to
            # FallbackKernel, so please check this if you opt-in.
            return

        if "_c10d_functional" in self.op_overload.name():
            # _c10d_functional kernels are lowered into _CollectiveKernel which
            # derives from FallbackKernel for the cpp codegen. The kernels
            # don't pass the can_auto_functionalize check, but their mutation
            # is handled properly by _CollectiveKernel.
            return

        schema = self.op_overload._schema

        # NOTE: [FallbackKernel supported operators]
        # We only support four types of operators:
        # - functional ops
        # - view ops
        # - inplace aten ops
        # - mutating ops that are auto-functionalizable. That is,
        # the operator may mutate any number of inputs, but its outputs
        # may not alias any of the inputs.
        #
        # The unsupported cases usually do not show up here (because
        # AOTAutograd functionalized them away); the only way for an in-place
        # op to show up here is if a lowering or pass introduced it.
        if torch._library.utils.mutates_and_returns_first_arg(self.op_overload):
            # In-place op returning its first arg: record that arg as mutated.
            self.mutation_names.append(tensor_args[0].get_name())
            return

        if schema.is_mutable and not can_auto_functionalize(kernel):
            raise NotImplementedError(
                f"NYI: Can't generate FallbackKernel for {kernel}"
            )

        # NOTE(review): schema_args is not used below.
        schema_args = schema.arguments
        # Recover the op-level (args, kwargs) so they can be paired with the
        # schema arguments.
        args, kwargs = self.unflatten_args(self.inputs, self.constant_args)

        def handle_aliasing_and_mutation(info, arg):
            # Record ``arg`` as aliased when its schema entry has alias info,
            # and additionally mark it as mutated when that alias is a write.
            # Assertions to make sure we didn't mismatch args
            if isinstance(info.type, torch.ListType):
                assert isinstance(arg, (list, tuple))
            is_optional_tensor = isinstance(
                info.type, torch.OptionalType
            ) and isinstance(info.type.getElementType(), torch.TensorType)
            if is_optional_tensor or isinstance(info.type, torch.TensorType):
                # PyTorch also accepts None and scalar types for args marked as "Tensor".
                # We're not going to check all of them here.
                assert not isinstance(arg, (tuple, list))

            if arg is None:
                return
            if info.alias_info is None:
                return
            # can_auto_functionalize already filters out mutable List[Tensor].
            # We can support this in the future, but this is very uncommon.
            assert isinstance(info.type, torch.TensorType) or is_optional_tensor
            self.alias_names.append(arg.get_name())
            if info.alias_info.is_write:
                mark_node_as_mutating(self, arg)

        for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
            handle_aliasing_and_mutation(info, arg)
4829

4830
    def set_cpp_kernel(self, kernel):
        """
        Configure this node to call ``kernel`` from the C++ wrapper.

        Sets the C++ kernel name, overload name, lookup key and op schema,
        then records positional-argument defaults from ``kernel._schema``.

        Raises:
            AssertionError: if the op is mutable or any argument/return has a
                write alias, since those are unsupported with cpp_wrapper.
        """
        from .codegen.wrapper import get_cpp_op_schema

        op_schema = kernel._schema
        assert (
            not op_schema.is_mutable
        ), f"mutable {kernel.__name__} is not supported with cpp_wrapper"

        # Ops that return aliasing tensors have return type Tensor& rather than
        # Tensor, but codegen always writes type Tensor on the LHS, so reject
        # any argument or return that carries a write alias.
        def writes_through_alias(arg):
            return arg.alias_info is not None and arg.alias_info.is_write

        assert not any(
            writes_through_alias(a) for a in op_schema.arguments
        ), f"{kernel.__name__} with alias_info arguments is not supported with cpp_wrapper"
        assert not any(
            writes_through_alias(r) for r in op_schema.returns
        ), f"{kernel.__name__} with alias_info returns is not supported with cpp_wrapper"

        qualified_name = op_schema.name
        overload_name = op_schema.overload_name
        self.cpp_kernel_name = qualified_name
        self.cpp_kernel_overload_name = overload_name
        self.cpp_kernel_key = f"{qualified_name.replace('::', '_')}_{overload_name}"

        self.cpp_op_schema = get_cpp_op_schema(kernel)
        self.init_args_default_value(op_schema)
4856

4857
    def is_legacy_abi_kernel(self):
        """True when ``python_kernel_name`` mentions ``_scaled_dot_product_flash_attention``."""
        kernel_name = str(self.python_kernel_name)
        return "_scaled_dot_product_flash_attention" in kernel_name
4859

4860
    def init_args_default_value(self, schema):
        """
        Store name/type/default of every non-kwarg-only argument of ``schema``.

        The result is a list (in schema order) of dicts with keys ``"name"``,
        ``"type"`` (``real_type``) and ``"value"`` (``default_value``).
        """
        positional = (a for a in schema.arguments if not a.kwarg_only)
        self.args_default_value = [
            {"name": a.name, "type": a.real_type, "value": a.default_value}
            for a in positional
        ]
4870

4871
    def get_pos_arg_value(self, pos, kwargs):
        """
        Return the value of positional schema argument ``pos``.

        Positional args may be provided in ``kwargs`` under their schema name;
        if so, that value is returned. Otherwise the default recorded by
        ``init_args_default_value`` is returned.

        Raises:
            AssertionError: if ``self.args_default_value`` is missing or ``pos``
                is out of range.
        """
        # Validate before indexing, so a missing attribute or a bad index
        # reports these messages instead of an AttributeError/IndexError.
        assert hasattr(
            self, "args_default_value"
        ), "self.args_default_value has to be provided"
        assert pos < len(
            self.args_default_value
        ), f"expected the index {pos} to be smaller than len(self.args_default_value): {len(self.args_default_value)}"

        pos_arg_name = self.args_default_value[pos]["name"]
        if pos_arg_name in kwargs:
            log.debug(
                "Found argument %s with value %s from kwargs",
                pos_arg_name,
                kwargs[pos_arg_name],
            )
            return kwargs[pos_arg_name]

        arg_default_value = self.args_default_value[pos]["value"]
        log.debug(
            "Use default value %s for argument %s", arg_default_value, pos_arg_name
        )
        return arg_default_value
4893

4894
    def codegen_args(self):
        """
        Return the positional argument strings for the generated call.

        Tensor inputs are rendered with ``codegen_reference`` and, together
        with ``self.constant_args``, mapped back to ``(args, kwargs)`` by
        ``self.unflatten_args``. Only positional args are returned; kwargs are
        merged into ``self.kwargs`` for ``codegen_kwargs`` to emit. Also sets
        ``self.abi_compatible_kernel``.
        """

        @dataclasses.dataclass
        class Shim:
            # Wraps an already-rendered reference string so that repr()
            # returns it verbatim (without adding quotes).
            ref: Any

            def __repr__(self):
                return self.ref

        tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
        args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
        # Now we setup abi_compatible_kernel after self.python_kernel_name
        # and kwargs are adjusted appropriately.
        # For sdpa, we need the v2 version since v1 didn't consider optional arg
        # FIXME: no need to do this after we switch to the torchgen-ed C shim
        self.abi_compatible_kernel = (
            f"{self.cpp_kernel_name}_v2"
            if self.cpp_kernel_name in {"at::_scaled_dot_product_flash_attention"}
            else self.cpp_kernel_name
        )

        if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
            args = [
                V.graph.wrapper_code.val_to_cpp_arg_str(
                    param.real_type, x, self.is_legacy_abi_kernel()
                )
                for param, x in zip(self.op_overload._schema.arguments, args)
            ]
        else:
            args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]

        # Previously, we want to maintain forward-compatibility by skipping
        # default args in the serialized artifacts in fbcode. However,
        # some of our shim interfaces require default values being set.
        # Discussed with Sherlock offline and we decided to allow serializing
        # default args into the C++ wrapper code for now. We will refine this
        # part if we see real FC requirement. More details related to FC
        # can be found at:
        # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing
        if V.graph.cpp_wrapper and hasattr(self, "args_default_value"):
            # Use the returned list rather than relying on in-place mutation:
            # fill_non_provided_args returns a (possibly new) list.
            args = self.fill_non_provided_args(args, kwargs, convert_val_to_str=True)

        # let self.codegen_kwargs handle kwargs
        self.kwargs.update(kwargs)
        return args
4938

4939
    @staticmethod
    def find_device(tensor_args, example_output):
        """
        Pick a device for a fallback node.

        Priority: the first tensor arg's device, then ``example_output``'s
        device. For list/tuple outputs, the devices of the elements are
        collected (recursively); a single distinct device is returned as-is,
        otherwise a CUDA device is preferred, otherwise the first one seen.
        Returns None when no device can be determined.
        """
        if tensor_args:
            return tensor_args[0].get_device()
        if isinstance(example_output, torch.Tensor):
            return example_output.device
        if isinstance(example_output, (list, tuple)):
            found = (FallbackKernel.find_device(None, x) for x in example_output)
            # dict.fromkeys dedupes while keeping first-seen order, so the
            # fallback choice below is deterministic (a set is unordered).
            devices = list(dict.fromkeys(d for d in found if d is not None))
            if not devices:
                # No element carried a device (e.g. only non-tensor outputs).
                return None
            if len(devices) == 1:
                return devices[0]
            for device in devices:
                if device.type == "cuda":
                    return device
            return devices[0]
        return None
4956

4957
    def has_side_effects(self):
4958
        if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
4959
            return False
4960
        return get_schema_info(self.op_overload).is_mutable()
4961

4962
    def get_alias_names(self):
        """Return the list of aliased buffer names (the same list object, not a copy)."""
        aliased = self.alias_names
        return aliased
4964

4965
    def get_mutation_names(self):
        """Return the (at most one) mutated-and-returned buffer name list."""
        names = self.mutation_names
        assert len(names) <= 1
        return names
4968

4969
    def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False):
        """
        Append values for trailing positional schema args missing from ``args``.

        Each missing value is taken from ``kwargs`` by name, or else from the
        schema default (see ``get_pos_arg_value``). When
        ``convert_val_to_str`` is set, the values are rendered with
        ``val_to_arg_str``. A list ``args`` is extended in place; a tuple is
        first copied into a new list. The (possibly new) list is returned.
        """
        assert isinstance(args, (list, tuple))
        if isinstance(args, tuple):
            args = list(args)
        assert hasattr(self, "args_default_value")
        n_provided = len(args)
        n_positional = len(self.args_default_value)
        if n_provided >= n_positional:
            return args

        # For cpp wrapper, positional args that were not provided must be
        # looked up in kwargs or filled with their default value.
        log.debug(
            "%s has %d unprovided positional arguments. "
            "Will check if they are in the keyword arguments or will use default values.",
            self.op_overload,
            n_positional - n_provided,
        )
        missing = [
            self.get_pos_arg_value(pos, kwargs)
            for pos in range(n_provided, n_positional)
        ]
        if convert_val_to_str:
            missing = [V.graph.wrapper_code.val_to_arg_str(v) for v in missing]
        args.extend(missing)
        return args
4992

4993
    # ProxyExecutor Design Note
4994
    # We export the ExternFallbackNodes (for custom ops) into a serialized file
4995
    # and run it with a host side proxy executor to address the ABI problem
4996
    # This is currently only implemented for fbcode. Eventually, we will also make this work for OSS.
4997
    # Detailed design doc can be found at
4998
    # https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing
4999
    def export_extern_kernel_node(self):
5000
        assert isinstance(self, FallbackKernel)
5001
        args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
5002
        args = self.fill_non_provided_args(args, kwargs)
5003
        ordered_kwargs = [
5004
            kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel
5005
        ]
5006

5007
        serializer = GraphModuleSerializer(None, None)  # type: ignore[arg-type]
5008
        named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs)  # type: ignore[arg-type]
5009

5010
        # serialize_outputs
5011
        def handle_single_output(return_type, output):
5012
            if isinstance(return_type, torch.TensorType):
5013
                # For single Tensor
5014
                out = output
5015
                if isinstance(output, (list, tuple)):
5016
                    assert len(output) == 1
5017
                    out = output[0]
5018
                return export_schema.Argument.create(
5019
                    as_tensor=export_schema.TensorArgument(name=out.get_name())
5020
                )
5021
            elif isinstance(return_type, torch.ListType) and isinstance(
5022
                return_type.getElementType(), torch.TensorType
5023
            ):
5024
                # For single TensorList
5025
                return export_schema.Argument.create(
5026
                    as_tensors=[
5027
                        export_schema.TensorArgument(name=out.get_name())
5028
                        for out in output
5029
                    ]
5030
                )
5031
            else:
5032
                raise RuntimeError(f"Unsupported return type {type(return_type)}")
5033

5034
        target = self.op_overload
5035
        returns = target._schema.returns  # type: ignore[union-attr]
5036
        if len(returns) == 1:
5037
            return_type = returns[0].real_type
5038
            output_arguments = [handle_single_output(return_type, self.outputs)]
5039
        else:
5040
            # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])"
5041
            assert isinstance(self.outputs, tuple)
5042
            assert len(returns) == len(self.outputs)
5043
            output_arguments = [
5044
                handle_single_output(return_schema.real_type, output)
5045
                for return_schema, output in zip(returns, self.outputs)
5046
            ]
5047

5048
        node = ExternKernelNode(
5049
            name=self.get_name(),
5050
            node=export_schema.Node(
5051
                target=self.op_overload.name(),  # type: ignore[union-attr]
5052
                inputs=named_arguments,
5053
                outputs=output_arguments,
5054
                metadata={},
5055
            ),
5056
        )
5057

5058
        V.graph.extern_kernel_nodes.append(node)
5059

5060
        return [*args, *ordered_kwargs]
5061

5062
    def codegen(self, wrapper):
5063
        kernel = self.op_overload
5064
        if kernel.namespace == "aten":  # type: ignore[union-attr]
5065
            # Aten Fallback Ops
5066
            assert isinstance(kernel, torch._ops.OpOverload)
5067
            if V.graph.cpp_wrapper:
5068
                if config.is_fbcode() and kernel not in has_c_shim:
5069
                    log.warning(
5070
                        "%s is missing a c-shim implementation, using proxy executor as fallback",
5071
                        kernel,
5072
                    )
5073
                    self.use_runtime_dispatch = True
5074
                    self.set_cpp_kernel(kernel)
5075
                else:
5076
                    self.cpp_kernel_name = get_aten_cpp_kernel_name(kernel)
5077
                    schema = kernel._schema
5078
                    self.init_args_default_value(schema)
5079
            else:
5080
                self.python_kernel_name = str(kernel)
5081

5082
        elif isinstance(kernel, torch._ops.HigherOrderOperator):
5083
            self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}"
5084
        else:
5085
            # For non-aten OpOverload, i.e. custom ops
5086
            if V.graph.cpp_wrapper:
5087
                self.use_runtime_dispatch = True
5088
                self.set_cpp_kernel(kernel)
5089
            else:
5090
                self.python_kernel_name = f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"  # type: ignore[union-attr]
5091

5092
        if self.use_runtime_dispatch:
5093
            self.codegen_comment(wrapper)
5094

5095
            exported_args = None
5096
            args = None
5097
            if config.is_fbcode() and V.graph.cpp_wrapper:
5098
                exported_args = self.export_extern_kernel_node()
5099
            else:
5100
                args = [*self.codegen_args(), *self.codegen_kwargs()]
5101

5102
            wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
5103
                self.get_name(),
5104
                self.get_kernel_name(),
5105
                args,
5106
                self.cpp_op_schema,
5107
                self.cpp_kernel_key,
5108
                self.cpp_kernel_overload_name,
5109
                self.op_overload,
5110
                exported_args,
5111
                self.outputs,
5112
            )
5113
        else:
5114
            self.codegen_comment(wrapper)
5115
            args = [*self.codegen_args(), *self.codegen_kwargs()]
5116
            V.graph.wrapper_code.generate_fallback_kernel(self, args)
5117
            if isinstance(self.layout, Layout):
5118
                self.codegen_size_asserts(wrapper)
5119

5120
    @staticmethod
5121
    def tensor_to_layout(output: torch.Tensor):
5122
        return FixedLayout(
5123
            output.device,
5124
            output.dtype,
5125
            convert_shape_to_inductor(output.size()),
5126
            convert_shape_to_inductor(output.stride()),
5127
        )
5128

5129
    @classmethod
5130
    def create(cls, kernel, *args, **kwargs):
5131
        fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
5132
        context = (
5133
            V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext()
5134
        )
5135
        with context:
5136
            (
5137
                example_output,
5138
                tensor_args,
5139
                non_tensor_args,
5140
                unflatten_args,
5141
            ) = cls.process_kernel(kernel, *args, **kwargs)
5142

5143
        device = cls.find_device(tensor_args, example_output)
5144
        assert device, "Not sure where to find device info"
5145

5146
        packed = cls(
5147
            MultiOutputLayout(device),
5148
            kernel,
5149
            tensor_args,
5150
            non_tensor_args,
5151
            unflatten_args,
5152
        )
5153

5154
        def generate_output(output, indices):
5155
            if isinstance(output, (list, tuple)):
5156
                return type(output)(
5157
                    generate_output(output[i], indices + [(type(output), i)])
5158
                    for i in range(len(output))
5159
                )
5160
            elif isinstance(output, dict):
5161
                return {
5162
                    key: generate_output(val, indices + [(type(output), key)])
5163
                    for key, val in output.items()
5164
                }
5165
            elif isinstance(output, torch.Tensor):
5166
                return MultiOutput(
5167
                    cls.tensor_to_layout(output),
5168
                    packed,
5169
                    indices,
5170
                )
5171
            elif isinstance(output, int):
5172
                return output
5173
            elif isinstance(output, torch.SymInt):
5174
                return output.node.expr
5175
            else:
5176
                assert (
5177
                    output is None
5178
                ), f"FallbackKernel output type {type(output)} is not supported"
5179
                return None
5180

5181
        outputs = generate_output(example_output, [])
5182
        if isinstance(outputs, (list, tuple, dict)):
5183
            packed.outputs = outputs  # type: ignore[assignment]
5184
        else:
5185
            packed.outputs = [outputs]
5186
        return outputs
5187

5188
    def apply_constraint(self):
5189
        return super().apply_constraint()
5190

5191

5192
@dataclasses.dataclass
5193
class ComplexView(FallbackKernel):
5194
    """View a complex number as two dtyped numbers or vice versa"""
5195

5196
    def should_allocate(self):
5197
        return False
5198

5199
    def get_alias_names(self):
5200
        # Signal to codegen that our output buffer isn't safe to reuse
5201
        return [self.inputs[0].get_name()]
5202

5203
    def __init__(
5204
        self,
5205
        layout,
5206
        kernel,
5207
        tensor_args,
5208
        nontensor_args,
5209
        unflatten_args,
5210
    ):
5211
        super().__init__(
5212
            layout,
5213
            kernel,
5214
            tensor_args,
5215
            nontensor_args,
5216
            unflatten_args,
5217
        )
5218

5219

5220
@dataclasses.dataclass
5221
class MultiOutputLayout(IRNode):
5222
    device: torch.device
5223

5224

5225
class MultiOutput(ExternKernel):
5226
    # Given an input MultiOutputLayout buffer, indexes out an actual buffer
5227
    # from that result.  This doesn't actually produce multiple outputs,
5228
    # that's MultiOutputLayout!
5229
    def codegen_list_tuple_access(self, basename, indices):
5230
        if len(indices) > 0:
5231
            itype, i = indices[0]
5232
            if itype == list:
5233
                return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:])
5234
            elif itype == tuple:
5235
                # cpp wrapper code needs to use std::get<> to access a tuple
5236
                tuple_access = V.graph.wrapper_code.codegen_tuple_access(
5237
                    basename, self.get_name(), str(i)
5238
                )
5239
                return self.codegen_list_tuple_access(tuple_access, indices[1:])
5240
            elif itype == dict:
5241
                return self.codegen_list_tuple_access(f"{basename}['{i}']", indices[1:])
5242
            else:
5243
                raise AssertionError("non supported index type")
5244
        else:
5245
            return basename
5246

5247
    def codegen(self, wrapper):
5248
        wrapper.codegen_multi_output(
5249
            self.get_name(),
5250
            self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
5251
        )
5252
        self.codegen_unbacked_symbol_defs(wrapper)
5253

5254
    def __init__(self, layout, input, indices: List[Tuple[Any, ...]]):
5255
        super().__init__(None, layout, [input], ())
5256
        self.name = V.graph.register_buffer(self)
5257
        self.indices = indices
5258

5259
    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
5260
        return self.inputs[0].get_unbacked_symbol_uses()
5261

5262
    def should_allocate(self):
5263
        return False
5264

5265
    def get_alias_names(self):
5266
        return [
5267
            inp.get_name()
5268
            for inp in self.inputs
5269
            if isinstance(inp, FallbackKernel) and len(inp.get_alias_names()) > 0
5270
        ]
5271

5272

5273
def _prepare_convolution_fusion_create(
5274
    cls,
5275
    x: "TensorBox",
5276
    weight: "TensorBox",
5277
    bias: "TensorBox",
5278
    padding: List[int],
5279
    stride: List[int],
5280
    dilation: List[int],
5281
    groups: int,
5282
    transposed: bool = False,
5283
    output_padding: Optional[List[int]] = None,
5284
):
5285
    """
5286
    This function is a helper function to prepare inputs, layout and constant args
5287
    for convolution post-op fusion's create function, including deciding the output
5288
    layout (channels first or channels last), realizing inputs and make them etc. The
5289
    function only supports the CPU device since conv post-op fusion kernel is only
5290
    supported on CPU right now.
5291
    """
5292

5293
    # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size
5294
    def _conv_input_size(
5295
        output_size, weight_size, padding, output_padding, stride, dilation, groups
5296
    ):
5297
        assert len(output_size) == len(weight_size), "Expect input dim == weight dim"
5298
        dim = len(output_size)
5299
        assert dim > 2, "Expect input dim > 2"
5300

5301
        BATCH_DIM = 0
5302
        WEIGHT_INPUT_CHANNELS_DIM = 1
5303
        input_size = []
5304
        input_size.append(output_size[BATCH_DIM])
5305
        input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups)
5306
        for d in range(2, dim):
5307
            kernel = (weight_size[d] - 1) * dilation[d - 2] + 1
5308
            input_size_d = (
5309
                (output_size[d] - 1) * stride[d - 2]
5310
                - (padding[d - 2] * 2)
5311
                + kernel
5312
                + output_padding[d - 2]
5313
            )
5314
            input_size.append(input_size_d)
5315
        return list(map(int, input_size))
5316

5317
    # The size of prepacked_weight is the prepacked weight size of deconv:
5318
    #   Groups > 1:  [g*o, i/g, ...]
5319
    #   Groups == 1: [o, i, ...]
5320
    # Returns original weight size in [i, o, ...]
5321
    def _original_deconv_weight_size(
5322
        prepacked_weight,
5323
        groups,
5324
    ):
5325
        prepacked_weight_size = prepacked_weight.size()
5326
        dim = len(prepacked_weight_size)
5327
        assert dim > 2, "Expect weight dim > 2"
5328
        if groups > 1:
5329
            weight_size = []
5330
            weight_size.append(prepacked_weight_size[1] * groups)
5331
            weight_size.append(prepacked_weight_size[0] / groups)
5332
            for d in range(2, dim):
5333
                weight_size.append(prepacked_weight_size[d])
5334
        else:
5335
            weight_size = prepacked_weight.transpose(0, 1).size()
5336
        return weight_size
5337

5338
    x.realize()
5339
    weight.realize()
5340
    if bias is not None:
5341
        bias.realize()
5342
    with V.graph.fake_mode:
5343
        # TODO <Leslie> cleaned up the fake_tensor trace as Linear implementation
5344
        x_fake = ir_node_to_tensor(x, guard_shape=True)
5345
        weight_fake = ir_node_to_tensor(weight, guard_shape=True)
5346
        dims = len(x_fake.size()) - 2
5347
        assert 0 < len(padding) <= dims
5348
        assert 0 < len(dilation) <= dims
5349
        assert 0 < len(stride) <= dims
5350
        padding = pad_listlike(padding, dims)
5351
        dilation = pad_listlike(dilation, dims)
5352
        stride = pad_listlike(stride, dims)
5353
        if output_padding is None:
5354
            output_padding = pad_listlike([0], dims)
5355
        else:
5356
            assert 0 < len(output_padding) <= dims
5357
            output_padding = pad_listlike(output_padding, dims)
5358
        assert isinstance(groups, int)
5359
        if transposed:
5360
            # When transposed, the size of the prepacked oneDNN weight is different
5361
            # from the PyTorch weight. We're not able to run aten conv with such
5362
            # size. We infer the output size from the input params here:
5363
            weight_size = _original_deconv_weight_size(weight_fake, groups)
5364
            input_size = x_fake.size()
5365
            output_size = _conv_input_size(
5366
                input_size,
5367
                weight_size,
5368
                padding,
5369
                output_padding,
5370
                stride,
5371
                dilation,
5372
                groups,
5373
            )
5374
        else:
5375
            bias_fake = (
5376
                ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
5377
            )
5378
            output = torch.ops.aten.convolution(
5379
                x_fake,
5380
                weight_fake,
5381
                bias_fake,
5382
                stride,
5383
                padding,
5384
                dilation,
5385
                transposed,
5386
                output_padding,
5387
                groups,
5388
            )
5389
            output_size = output.size()
5390

5391
        req_stride_order = [0] + list(reversed(range(1, len(stride) + 1)))
5392
        req_stride_order = [len(req_stride_order)] + req_stride_order
5393
        output_stride = make_channels_last_strides_for(output_size)
5394

5395
    x = cls.require_stride_order(x, req_stride_order)
5396
    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
5397
    inputs = [x, weight]
5398

5399
    kernel_layout = FixedLayout(
5400
        x.get_device(),
5401
        x.get_dtype(),
5402
        convert_shape_to_inductor(output_size),
5403
        convert_shape_to_inductor(output_stride),
5404
    )
5405
    constant_args = [padding, stride, dilation, groups]
5406
    if transposed:
5407
        constant_args.insert(1, output_padding)
5408

5409
    if bias is not None:
5410
        inputs.append(bias)
5411
    else:
5412
        constant_args.insert(0, bias)
5413
    return inputs, constant_args, kernel_layout, req_stride_order
5414

5415

5416
def _prepare_linear_fusion_create(
5417
    cls,
5418
    x: "TensorBox",
5419
    weight: "TensorBox",
5420
    bias: "TensorBox",
5421
):
5422
    """
5423
    This function is a helper function to prepare inputs, layout and constant args
5424
    for linear post-op fusion's create function. The function only supports the CPU device
5425
    since linear post-op fusion kernel is only supported on CPU right now.
5426
    """
5427
    x.realize()
5428
    weight.realize()
5429
    if bias is not None:
5430
        bias.realize()
5431

5432
    *m, _ = x.get_size()
5433
    # The weight has been transposed during the qlinear weight prepack process.
5434
    # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/
5435
    # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291
5436
    _, oc = weight.get_size()
5437
    output_size = list(m) + [oc]
5438
    req_stride_order = list(reversed(range(len(x.get_size()))))
5439

5440
    x = cls.require_stride_order(x, req_stride_order)
5441
    assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
5442
    inputs = [x, weight]
5443

5444
    output_stride = make_contiguous_strides_for(output_size)
5445
    kernel_layout = FixedLayout(
5446
        x.get_device(),
5447
        x.get_dtype(),
5448
        output_size,
5449
        output_stride,
5450
    )
5451
    constant_args: List[Any] = []
5452

5453
    if bias is not None:
5454
        inputs.append(bias)
5455
    else:
5456
        constant_args.insert(0, bias)
5457
    return inputs, constant_args, kernel_layout, req_stride_order
5458

5459

5460
class ConvolutionUnary(ExternKernelAlloc):
5461
    def __init__(
5462
        self,
5463
        layout,
5464
        inputs,
5465
        constant_args=(),
5466
    ):
5467
        super().__init__(
5468
            layout,
5469
            inputs,
5470
            constant_args,
5471
            None,
5472
            python_kernel_name="torch.ops.mkldnn._convolution_pointwise",
5473
            cpp_kernel_name="mkldnn::_convolution_pointwise",
5474
        )
5475
        self.cpp_kernel_key = "convolution_pointwise"
5476
        self.cpp_op_schema = """
5477
            at::Tensor(
5478
                const at::Tensor& input_t,
5479
                const at::Tensor& weight_t,
5480
                const c10::optional<at::Tensor>& bias_opt,
5481
                at::IntArrayRef padding,
5482
                at::IntArrayRef stride,
5483
                at::IntArrayRef dilation,
5484
                int64_t groups,
5485
                c10::string_view attr,
5486
                torch::List<c10::optional<at::Scalar>> scalars,
5487
                c10::optional<c10::string_view> algorithm)"""
5488

5489
    def codegen(self, wrapper):
5490
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
5491
            self.get_name(),
5492
            self.get_kernel_name(),
5493
            self.codegen_args(),
5494
            self.cpp_op_schema,
5495
            self.cpp_kernel_key,
5496
        )
5497
        if isinstance(self.layout, Layout):
5498
            self.codegen_size_asserts(wrapper)
5499

5500
    @classmethod
5501
    def create(
5502
        cls,
5503
        x: "TensorBox",
5504
        weight: "TensorBox",
5505
        bias: "TensorBox",
5506
        padding_: List[int],
5507
        stride_: List[int],
5508
        dilation_: List[int],
5509
        groups: int,
5510
        attr,
5511
        scalars: Optional[List[Any]],
5512
        algorithm,
5513
    ):
5514
        (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
5515
            cls, x, weight, bias, padding_, stride_, dilation_, groups
5516
        )
5517
        constant_args = constant_args + [
5518
            attr,
5519
            may_convert_to_optional(scalars),
5520
            algorithm,
5521
        ]
5522
        return ConvolutionUnary(
5523
            layout=kernel_layout,
5524
            inputs=inputs,
5525
            constant_args=constant_args,
5526
        )
5527

5528

5529
class ConvolutionBinary(ExternKernelAlloc):
5530
    def __init__(
5531
        self,
5532
        layout,
5533
        inputs,
5534
        constant_args=(),
5535
        cpp_constant_args=(),
5536
    ):
5537
        super().__init__(
5538
            layout,
5539
            inputs,
5540
            constant_args,
5541
            None,
5542
            python_kernel_name="torch.ops.mkldnn._convolution_pointwise.binary",
5543
            cpp_kernel_name="mkldnn::_convolution_pointwise",
5544
        )
5545
        self.cpp_kernel_overload_name = "binary"
5546
        self.cpp_kernel_key = "convolution_pointwise_binary"
5547
        self.cpp_op_schema = """
5548
            at::Tensor(
5549
                const at::Tensor& input_t,
5550
                const at::Tensor& other_t,
5551
                const at::Tensor& weight_t,
5552
                const c10::optional<at::Tensor>& bias_opt,
5553
                at::IntArrayRef padding,
5554
                at::IntArrayRef stride,
5555
                at::IntArrayRef dilation,
5556
                int64_t groups,
5557
                c10::string_view binary_attr,
5558
                c10::optional<at::Scalar> alpha,
5559
                c10::optional<c10::string_view> unary_attr,
5560
                torch::List<c10::optional<at::Scalar>> unary_scalars,
5561
                c10::optional<c10::string_view> unary_algorithm)"""
5562
        self.cpp_constant_args = cpp_constant_args
5563

5564
    def codegen(self, wrapper):
5565
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
5566
            self.get_name(),
5567
            self.get_kernel_name(),
5568
            self.codegen_args(),
5569
            self.cpp_op_schema,
5570
            self.cpp_kernel_key,
5571
            self.cpp_kernel_overload_name,
5572
        )
5573
        if isinstance(self.layout, Layout):
5574
            self.codegen_size_asserts(wrapper)
5575

5576
    @classmethod
5577
    def create(
5578
        cls,
5579
        x: "TensorBox",
5580
        other: "TensorBox",
5581
        weight: "TensorBox",
5582
        bias: "TensorBox",
5583
        padding_: List[int],
5584
        stride_: List[int],
5585
        dilation_: List[int],
5586
        groups: int,
5587
        binary_attr: str,
5588
        binary_alpha: Optional[float],
5589
        unary_attr: Optional[str],
5590
        unary_scalars: Optional[List[Any]],
5591
        unary_algorithm: Optional[str],
5592
    ):
5593
        (
5594
            inputs,
5595
            constant_args,
5596
            kernel_layout,
5597
            req_stride_order,
5598
        ) = _prepare_convolution_fusion_create(
5599
            cls, x, weight, bias, padding_, stride_, dilation_, groups
5600
        )
5601
        other = cls.require_stride_order(other, req_stride_order)
5602
        inputs.insert(1, other)
5603
        constant_args = constant_args + [
5604
            binary_attr,
5605
            binary_alpha,
5606
            unary_attr,
5607
            may_convert_to_optional(unary_scalars),
5608
            unary_algorithm,
5609
        ]
5610
        return ConvolutionBinary(
5611
            layout=kernel_layout,
5612
            inputs=inputs,
5613
            constant_args=constant_args,
5614
        )
5615

5616

5617
class ConvolutionBinaryInplace(ExternKernelAlloc):
5618
    def __init__(
5619
        self,
5620
        kernel_layout,
5621
        inputs,
5622
        constant_args=(),
5623
    ):
5624
        # Due to constrain of op.call, other (Tensor&) should be at input[0]
5625
        reordered_inputs = [inputs[1], inputs[0]] + inputs[2:]
5626

5627
        super().__init__(
5628
            kernel_layout,
5629
            reordered_inputs,
5630
            constant_args,
5631
            None,
5632
            python_kernel_name="torch.ops.mkldnn._convolution_pointwise_.binary",
5633
            cpp_kernel_name="mkldnn::_convolution_pointwise_",
5634
        )
5635
        self.cpp_kernel_overload_name = "binary"
5636
        self.cpp_kernel_key = "convolution_pointwise_binary_"
5637
        # TODO: op.call: input[0] should be at::Tensor&
5638
        self.cpp_op_schema = """
5639
            at::Tensor&(
5640
                at::Tensor& other_t,
5641
                const at::Tensor& input_t,
5642
                const at::Tensor& weight_t,
5643
                const c10::optional<at::Tensor>& bias_opt,
5644
                at::IntArrayRef padding,
5645
                at::IntArrayRef stride,
5646
                at::IntArrayRef dilation,
5647
                int64_t groups,
5648
                c10::string_view binary_attr,
5649
                c10::optional<at::Scalar> alpha,
5650
                c10::optional<c10::string_view> unary_attr,
5651
                torch::List<c10::optional<at::Scalar>> unary_scalars,
5652
                c10::optional<c10::string_view> unary_algorithm)"""
5653

5654
    def codegen(self, wrapper):
5655
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
5656
            self.get_name(),
5657
            self.get_kernel_name(),
5658
            self.codegen_args(),
5659
            self.cpp_op_schema,
5660
            self.cpp_kernel_key,
5661
            self.cpp_kernel_overload_name,
5662
        )
5663

5664
    def get_mutation_names(self):
5665
        return [self.inputs[0].get_name()]
5666

5667
    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
5668
        return set()
5669

5670
    @classmethod
5671
    def create(
5672
        cls,
5673
        x: "TensorBox",
5674
        other: "TensorBox",
5675
        weight: "TensorBox",
5676
        bias: "TensorBox",
5677
        padding_: List[int],
5678
        stride_: List[int],
5679
        dilation_: List[int],
5680
        groups: int,
5681
        binary_attr: str,
5682
        binary_alpha: Optional[float],
5683
        unary_attr: Optional[str],
5684
        unary_scalars: Optional[List[Any]],
5685
        unary_algorithm: Optional[str],
5686
    ):
5687
        (
5688
            inputs,
5689
            constant_args,
5690
            _,
5691
            req_stride_order,
5692
        ) = _prepare_convolution_fusion_create(
5693
            cls, x, weight, bias, padding_, stride_, dilation_, groups
5694
        )
5695
        other = cls.require_stride_order(other, req_stride_order)
5696
        inputs.insert(1, other)
5697
        constant_args = constant_args + [
5698
            binary_attr,
5699
            binary_alpha,
5700
            unary_attr,
5701
            may_convert_to_optional(unary_scalars),
5702
            unary_algorithm,
5703
        ]
5704
        packed = ConvolutionBinaryInplace(
5705
            kernel_layout=NoneLayout(inputs[1].get_device()),  # type: ignore[arg-type]
5706
            inputs=inputs,
5707
            constant_args=constant_args,
5708
        )
5709
        mark_node_as_mutating(packed, inputs[1])
5710
        # This op mutates in place which means that the result is not the
5711
        # target but rather the input that is being mutated
5712
        # init reorders the inputs, so inputs[1] becomes packed.inputs[0]
5713
        return packed.inputs[0]
5714

5715

5716
class MKLPackedLinear(ExternKernelAlloc):
5717
    def __init__(
5718
        self,
5719
        layout,
5720
        inputs,
5721
        constant_args=(),
5722
    ):
5723
        super().__init__(
5724
            layout,
5725
            inputs,
5726
            constant_args,
5727
            None,
5728
            python_kernel_name="torch.ops.mkl._mkl_linear",
5729
            cpp_kernel_name="mkl::_mkl_linear",
5730
        )
5731
        self.cpp_kernel_key = "mkl_linear"
5732
        self.cpp_op_schema = """
5733
            at::Tensor(
5734
                const at::Tensor& self,
5735
                const at::Tensor& mkl_weight_t,
5736
                const at::Tensor& origin_weight_t,
5737
                const c10::optional<at::Tensor>& bias_opt,
5738
                const int64_t prepack_batch_size)"""
5739

5740
    def codegen(self, wrapper):
5741
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
5742
            self.get_name(),
5743
            self.get_kernel_name(),
5744
            self.codegen_args(),
5745
            self.cpp_op_schema,
5746
            self.cpp_kernel_key,
5747
        )
5748

5749
    @classmethod
5750
    def create(cls, x, packed_w, orig_w, batch_size):
5751
        x = cls.require_stride1(cls.realize_input(x))
5752
        orig_w = cls.require_stride1(cls.realize_input(orig_w))
5753
        *m, _ = x.get_size()
5754
        oc, _ = orig_w.get_size()
5755
        output_size = list(m) + [oc]
5756
        output_stride = make_contiguous_strides_for(output_size)
5757
        inputs = [x, packed_w, orig_w]
5758
        constant_args = [None, batch_size]
5759

5760
        return MKLPackedLinear(
5761
            layout=FixedLayout(
5762
                x.get_device(), x.get_dtype(), output_size, output_stride
5763
            ),
5764
            inputs=inputs,
5765
            constant_args=constant_args,
5766
        )
5767

5768

5769
class LinearUnary(ExternKernelAlloc):
5770
    def __init__(
5771
        self,
5772
        layout,
5773
        inputs,
5774
        constant_args=(),
5775
    ):
5776
        super().__init__(
5777
            layout,
5778
            inputs,
5779
            constant_args,
5780
            None,
5781
            python_kernel_name="torch.ops.mkldnn._linear_pointwise",
5782
            cpp_kernel_name="mkldnn::_linear_pointwise",
5783
        )
5784
        self.cpp_kernel_key = "linear_pointwise"
5785
        self.cpp_op_schema = """
5786
            at::Tensor(
5787
                const at::Tensor& input_t,
5788
                const at::Tensor& weight_t,
5789
                const c10::optional<at::Tensor>& bias_opt,
5790
                c10::string_view attr,
5791
                torch::List<c10::optional<at::Scalar>> scalars,
5792
                c10::optional<c10::string_view> algorithm)"""
5793

5794
    def codegen(self, wrapper):
5795
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
5796
            self.get_name(),
5797
            self.get_kernel_name(),
5798
            self.codegen_args(),
5799
            self.cpp_op_schema,
5800
            self.cpp_kernel_key,
5801
        )
5802

5803
    @classmethod
5804
    def create(cls, x, w, b, attr, scalars, algorithm):
5805
        x = cls.require_contiguous(cls.realize_input(x))
5806
        w = cls.require_contiguous(cls.realize_input(w))
5807

5808
        *m, ic = x.get_size()
5809
        oc, ic = w.get_size()
5810
        inputs = [x, w]
5811
        constant_args = [attr, scalars if scalars else [-1], algorithm]
5812
        if b is not None:
5813
            b = cls.require_contiguous(cls.realize_input(b))
5814
            inputs.append(b)
5815
        else:
5816
            constant_args.insert(0, None)
5817

5818
        return LinearUnary(
5819
            layout=FlexibleLayout(
5820
                device=x.get_device(),
5821
                dtype=x.get_dtype(),
5822
                size=list(m) + [oc],
5823
            ),
5824
            inputs=inputs,
5825
            constant_args=constant_args,
5826
        )
5827

5828
    def apply_constraint(self):
5829
        pass
5830

5831

5832
class LinearBinary(ExternKernelAlloc):
5833
    kernel = "torch.ops.mkldnn._linear_pointwise.binary"
5834

5835
    def __init__(
5836
        self,
5837
        layout,
5838
        inputs,
5839
        constant_args=(),
5840
    ):
5841
        super().__init__(
5842
            layout,
5843
            inputs,
5844
            constant_args,
5845
            None,
5846
            python_kernel_name="torch.ops.mkldnn._linear_pointwise.binary",
5847
            cpp_kernel_name="mkldnn::_linear_pointwise",
5848
        )
5849
        self.cpp_kernel_overload_name = "binary"
5850
        self.cpp_kernel_key = "linear_pointwise_binary"
5851
        self.cpp_op_schema = """
5852
            at::Tensor(
5853
                const at::Tensor& input_t,
5854
                const at::Tensor& other_t,
5855
                const at::Tensor& weight_t,
5856
                const c10::optional<at::Tensor>& bias_opt,
5857
                c10::string_view attr)
5858
        """
5859

5860
    def codegen(self, wrapper):
5861
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
5862
            self.get_name(),
5863
            self.get_kernel_name(),
5864
            self.codegen_args(),
5865
            self.cpp_op_schema,
5866
            self.cpp_kernel_key,
5867
            self.cpp_kernel_overload_name,
5868
        )
5869

5870
    @classmethod
5871
    def create(cls, x, y, w, b, attr):
5872
        x = cls.require_contiguous(cls.realize_input(x))
5873
        y = cls.require_contiguous(cls.realize_input(y))
5874
        w = cls.require_contiguous(cls.realize_input(w))
5875

5876
        *m, ic = x.get_size()
5877
        oc, ic = w.get_size()
5878

5879
        inputs = [x, y, w]
5880
        constant_args = [attr]
5881
        if b is not None:
5882
            b = cls.require_contiguous(cls.realize_input(b))
5883
            inputs.append(b)
5884
        else:
5885
            constant_args.insert(0, b)
5886

5887
        return LinearBinary(
5888
            layout=FlexibleLayout(
5889
                device=x.get_device(),
5890
                dtype=x.get_dtype(),
5891
                size=list(m) + [oc],
5892
            ),
5893
            inputs=inputs,
5894
            constant_args=constant_args,
5895
        )
5896

5897
    def apply_constraint(self):
5898
        pass
5899

5900

5901
class ConvolutionTransposeUnary(ExternKernelAlloc):
5902
    def __init__(
5903
        self,
5904
        layout,
5905
        inputs,
5906
        constant_args=(),
5907
    ):
5908
        super().__init__(
5909
            layout,
5910
            inputs,
5911
            constant_args,
5912
            None,
5913
            python_kernel_name="torch.ops.mkldnn._convolution_transpose_pointwise",
5914
            cpp_kernel_name="mkldnn::_convolution_transpose_pointwise",
5915
        )
5916
        self.cpp_kernel_key = "convolution_transpose_pointwise"
5917
        self.cpp_op_schema = """
5918
            at::Tensor(
5919
                const at::Tensor& input_t,
5920
                const at::Tensor& weight_t,
5921
                const c10::optional<at::Tensor>& bias_opt,
5922
                at::IntArrayRef padding,
5923
                at::IntArrayRef output_padding,
5924
                at::IntArrayRef stride,
5925
                at::IntArrayRef dilation,
5926
                int64_t groups,
5927
                c10::string_view attr,
5928
                torch::List<c10::optional<at::Scalar>> scalars,
5929
                c10::optional<c10::string_view> algorithm)"""
5930

5931
    def codegen(self, wrapper):
5932
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
5933
            self.get_name(),
5934
            self.get_kernel_name(),
5935
            self.codegen_args(),
5936
            self.cpp_op_schema,
5937
            self.cpp_kernel_key,
5938
        )
5939

5940
    @classmethod
5941
    def create(
5942
        cls,
5943
        x: "TensorBox",
5944
        weight: "TensorBox",
5945
        bias: "TensorBox",
5946
        padding_: List[int],
5947
        output_padding_: List[int],
5948
        stride_: List[int],
5949
        dilation_: List[int],
5950
        groups_: int,
5951
        attr,
5952
        scalars: Optional[List[Any]],
5953
        algorithm,
5954
    ):
5955
        transposed = True
5956
        (
5957
            inputs,
5958
            constant_args,
5959
            kernel_layout,
5960
            _,
5961
        ) = _prepare_convolution_fusion_create(
5962
            cls,
5963
            x,
5964
            weight,
5965
            bias,
5966
            padding_,
5967
            stride_,
5968
            dilation_,
5969
            groups_,
5970
            transposed,
5971
            output_padding_,
5972
        )
5973
        constant_args = constant_args + [
5974
            attr,
5975
            may_convert_to_optional(scalars),
5976
            algorithm,
5977
        ]
5978
        return ConvolutionTransposeUnary(
5979
            layout=kernel_layout,
5980
            inputs=inputs,
5981
            constant_args=constant_args,
5982
        )
5983

5984

5985
class MkldnnRnnLayer(ExternKernelAlloc):
5986
    def __init__(
5987
        self,
5988
        layout,
5989
        inputs,
5990
        constant_args=(),
5991
    ):
5992
        super().__init__(
5993
            layout,
5994
            inputs,
5995
            constant_args,
5996
            None,
5997
            python_kernel_name="aten.mkldnn_rnn_layer",
5998
            cpp_kernel_name="at::mkldnn_rnn_layer",
5999
        )
6000

6001
    @classmethod
6002
    def create(
6003
        cls,
6004
        x: "TensorBox",
6005
        w0: "TensorBox",
6006
        w1: "TensorBox",
6007
        w2: "TensorBox",
6008
        w3: "TensorBox",
6009
        hx: "TensorBox",
6010
        cx: "TensorBox",
6011
        reverse: bool,
6012
        batch_sizes: List[int],
6013
        mode: int,
6014
        hidden_size: int,
6015
        num_layers: int,
6016
        has_biases: bool,
6017
        bidirectional: bool,
6018
        batch_first: bool,
6019
        train: bool,
6020
    ):
6021
        x = cls.require_stride1(cls.realize_input(x))
6022
        # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer.
6023
        # Make sure x is contiguous in batch_first case.
6024
        x.freeze_layout()
6025
        w0 = cls.require_stride1(cls.realize_input(w0))
6026
        w1 = cls.require_stride1(cls.realize_input(w1))
6027
        w2 = cls.require_stride1(cls.realize_input(w2))
6028
        w3 = cls.require_stride1(cls.realize_input(w3))
6029
        hx = cls.require_stride1(cls.realize_input(hx))
6030
        hx.freeze_layout()
6031
        cx = cls.require_stride1(cls.realize_input(cx))
6032
        cx.freeze_layout()
6033

6034
        input_size = x.get_size()
6035
        assert len(input_size) == 3, "Expect lstm input to be 3D"
6036
        # batch_first is handled in the lstm OP. When entering
6037
        # rnn_layer here, we'll always have batch_first = False
6038
        seq_length, mini_batch, input_size = input_size
6039
        output_shape = [seq_length, mini_batch, hidden_size]
6040

6041
        hy_shape = hx.get_size()
6042
        cy_shape = cx.get_size()
6043

6044
        res: List[IRNode] = []
6045

6046
        inputs = [x, w0, w1, w2, w3, hx, cx]
6047
        constant_args = [
6048
            reverse,
6049
            batch_sizes,
6050
            mode,
6051
            hidden_size,
6052
            num_layers,
6053
            has_biases,
6054
            bidirectional,
6055
            batch_first,
6056
            train,
6057
        ]
6058

6059
        packed = MkldnnRnnLayer(
6060
            MultiOutputLayout(x.get_device()),
6061
            inputs=inputs,
6062
            constant_args=constant_args,
6063
        )
6064

6065
        def get_strides_of_lstm_output(output_shape, batch_first):
6066
            assert len(output_shape) == 3, "Expect output_shape to be 3D"
6067
            return make_contiguous_strides_for(output_shape)
6068

6069
        output_sizes = [output_shape, hy_shape, cy_shape]
6070
        output_strides = [
6071
            get_strides_of_lstm_output(output_shape, batch_first),
6072
            make_contiguous_strides_for(hy_shape),
6073
            make_contiguous_strides_for(cy_shape),
6074
        ]
6075
        output_ir = [
6076
            MultiOutput(
6077
                FixedLayout(
6078
                    x.get_device(),
6079
                    x.get_dtype(),
6080
                    output_size,
6081
                    output_stride,
6082
                ),
6083
                packed,
6084
                [(tuple, i)],
6085
            )
6086
            for i, (output_size, output_stride) in enumerate(
6087
                zip(output_sizes, output_strides)
6088
            )
6089
        ]
6090

6091
        return output_ir
6092

6093

6094
class QConvPointWisePT2E(ExternKernelAlloc):
6095
    def __init__(
6096
        self,
6097
        layout,
6098
        inputs,
6099
        constant_args=(),
6100
    ):
6101
        """
6102
        if bias is not None
6103
            - inputs = [x, w, b, weight_scale, weight_zp]
6104
            - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp,
6105
              fp32_output, unary_attr, unary_scalars, unary_algorithm]
6106
        else
6107
            - inputs = [x, w, weight_scale, weight_zp]
6108
            - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp,
6109
              fp32_output, unary_attr, unary_scalars, unary_algorithm]
6110
        """
6111
        self.has_bias = len(inputs) == 5
6112
        super().__init__(
6113
            layout,
6114
            inputs,
6115
            constant_args,
6116
            None,
6117
            python_kernel_name="torch.ops.onednn.qconv2d_pointwise",
6118
            cpp_kernel_name="onednn::qconv2d_pointwise",
6119
        )
6120
        self.cpp_kernel_key = "qconv2d_pointwise"
6121
        self.cpp_op_schema = """
6122
            at::Tensor(
6123
                at::Tensor act,
6124
                double act_scale,
6125
                int64_t act_zero_point,
6126
                at::Tensor weight,
6127
                at::Tensor weight_scales,
6128
                at::Tensor weight_zero_points,
6129
                c10::optional<at::Tensor> bias,
6130
                torch::List<int64_t> stride,
6131
                torch::List<int64_t> padding,
6132
                torch::List<int64_t> dilation,
6133
                int64_t groups,
6134
                double inv_output_scale,
6135
                int64_t output_zero_point,
6136
                c10::optional<c10::ScalarType> output_dtype,
6137
                c10::string_view attr,
6138
                torch::List<c10::optional<at::Scalar>> scalars,
6139
                c10::optional<c10::string_view> algorithm)"""
6140

6141
    def codegen(self, wrapper):
6142
        # Parser the inputs and constant
6143
        args = [x.codegen_reference() for x in self.inputs]
6144
        const_args = []
6145
        const_args.extend(self.codegen_const_args())
6146

6147
        x = args[0]
6148
        packed_weight = args[1]
6149
        bias = args[2] if self.has_bias else const_args[0]
6150
        w_scale, w_zp = args[-2], args[-1]
6151
        (
6152
            stride,
6153
            padding,
6154
            dilation,
6155
            groups,
6156
            x_scale,
6157
            x_zp,
6158
            o_inv_scale,
6159
            o_zp,
6160
            output_dtype,
6161
            unary_attr,
6162
            unary_scalars,
6163
            unary_algorithm,
6164
        ) = const_args[-12:]
6165

6166
        codegen_args = (
6167
            x,
6168
            x_scale,
6169
            x_zp,
6170
            packed_weight,
6171
            w_scale,
6172
            w_zp,
6173
            bias,
6174
            stride,
6175
            padding,
6176
            dilation,
6177
            groups,
6178
            o_inv_scale,
6179
            o_zp,
6180
            output_dtype,
6181
            unary_attr,
6182
            unary_scalars,
6183
            unary_algorithm,
6184
        )
6185
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6186
            self.get_name(),
6187
            self.get_kernel_name(),
6188
            codegen_args,
6189
            self.cpp_op_schema,
6190
            self.cpp_kernel_key,
6191
        )
6192
        if isinstance(self.layout, Layout):
6193
            self.codegen_size_asserts(wrapper)
6194

6195
    @classmethod
6196
    def create(
6197
        cls,
6198
        x: "TensorBox",
6199
        x_scale: float,
6200
        x_zp: int,
6201
        weight: "TensorBox",  # packed_weight
6202
        w_scale: "TensorBox",
6203
        w_zp: "TensorBox",
6204
        bias: "TensorBox",
6205
        stride_: List[int],
6206
        padding_: List[int],
6207
        dilation_: List[int],
6208
        groups: int,
6209
        o_inv_scale: float,
6210
        output_zero_point: int,
6211
        output_dtype,
6212
        unary_attr,
6213
        unary_scalars,
6214
        unary_algorithm,
6215
    ):
6216
        transposed = False
6217
        output_padding = None
6218
        (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
6219
            cls,
6220
            x,
6221
            weight,
6222
            bias,
6223
            padding_,
6224
            stride_,
6225
            dilation_,
6226
            groups,
6227
            transposed,
6228
            output_padding,
6229
        )
6230
        # swap padding and stride to align with functional conv arg order
6231
        if bias is None:
6232
            constant_args[1], constant_args[2] = constant_args[2], constant_args[1]
6233
        else:
6234
            constant_args[0], constant_args[1] = constant_args[1], constant_args[0]
6235

6236
        w_scale.realize()
6237
        w_zp.realize()
6238
        inputs = inputs + [w_scale, w_zp]
6239
        constant_args = constant_args + [
6240
            x_scale,
6241
            x_zp,
6242
            o_inv_scale,
6243
            output_zero_point,
6244
            output_dtype,
6245
            unary_attr,
6246
            may_convert_to_optional(unary_scalars),
6247
            unary_algorithm,
6248
        ]
6249

6250
        if output_dtype is not None:
6251
            assert output_dtype in [torch.float32, torch.bfloat16]
6252
            # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout
6253
            # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8.
6254
            kernel_layout.dtype = output_dtype
6255

6256
        return QConvPointWisePT2E(
6257
            layout=kernel_layout,
6258
            inputs=inputs,
6259
            constant_args=constant_args,
6260
        )
6261

6262

6263
class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
6264
    def __init__(
6265
        self,
6266
        layout,
6267
        inputs,
6268
        constant_args=(),
6269
    ):
6270
        """
6271
        Needs input/weight/output qparams
6272
        if bias is not None
6273
            - inputs = [x, w, b, accum, w_scale, w_zp]
6274
            - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, o_zp,
6275
            fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm]
6276
        else
6277
            - inputs = [x, w, accum, w_scale, w_zp]
6278
            - const_args = const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale,
6279
            accum_zp, o_inv_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm]
6280
        """
6281
        self.has_bias = len(inputs) == 6
6282
        self.idx_for_inplace_sum = 3 if self.has_bias else 2
6283
        super().__init__(
6284
            layout,
6285
            inputs,
6286
            constant_args,
6287
            None,
6288
            python_kernel_name="torch.ops.onednn.qconv2d_pointwise.binary",
6289
            cpp_kernel_name="onednn::qconv2d_pointwise",
6290
        )
6291
        self.cpp_kernel_overload_name = "binary"
6292
        self.cpp_kernel_key = "qconv2d_pointwise_binary"
6293
        self.cpp_op_schema = """
6294
            at::Tensor(
6295
                at::Tensor act,
6296
                double act_scale,
6297
                int64_t act_zero_point,
6298
                at::Tensor accum,
6299
                double accum_scale,
6300
                int64_t accum_zero_point,
6301
                at::Tensor weight,
6302
                at::Tensor weight_scales,
6303
                at::Tensor weight_zero_points,
6304
                c10::optional<at::Tensor> bias,
6305
                torch::List<int64_t> stride,
6306
                torch::List<int64_t> padding,
6307
                torch::List<int64_t> dilation,
6308
                int64_t groups,
6309
                double inv_output_scale,
6310
                int64_t output_zero_point,
6311
                c10::optional<c10::ScalarType> output_dtype,
6312
                c10::string_view binary_attr,
6313
                c10::optional<at::Scalar> alpha,
6314
                c10::optional<c10::string_view> attr,
6315
                torch::List<c10::optional<at::Scalar>> scalars,
6316
                c10::optional<c10::string_view> algorithm)"""
6317

6318
    def codegen(self, wrapper):
6319
        # Parser the inputs and constant
6320
        args = [x.codegen_reference() for x in self.inputs]
6321
        const_args = []
6322
        const_args.extend(self.codegen_const_args())
6323

6324
        x = args[0]
6325
        packed_weight = args[1]
6326
        bias = args[2] if self.has_bias else const_args[0]
6327
        accum, w_scale, w_zp = args[-3], args[-2], args[-1]
6328
        (
6329
            stride,
6330
            padding,
6331
            dilation,
6332
            groups,
6333
            x_scale,
6334
            x_zp,
6335
            accum_scale,
6336
            accum_zp,
6337
            o_inv_scale,
6338
            o_zp,
6339
            output_dtype,
6340
            binary_attr,
6341
            alpha,
6342
            unary_attr,
6343
            unary_scalars,
6344
            unary_algorithm,
6345
        ) = const_args[-16:]
6346
        conv_args = (
6347
            x,
6348
            x_scale,
6349
            x_zp,
6350
            accum,
6351
            accum_scale,
6352
            accum_zp,
6353
            packed_weight,
6354
            w_scale,
6355
            w_zp,
6356
            bias,
6357
            stride,
6358
            padding,
6359
            dilation,
6360
            groups,
6361
            o_inv_scale,
6362
            o_zp,
6363
            output_dtype,
6364
            binary_attr,
6365
            alpha,
6366
            unary_attr,
6367
            unary_scalars,
6368
            unary_algorithm,
6369
        )
6370
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6371
            self.get_name(),
6372
            self.get_kernel_name(),
6373
            conv_args,
6374
            self.cpp_op_schema,
6375
            self.cpp_kernel_key,
6376
            self.cpp_kernel_overload_name,
6377
        )
6378
        if isinstance(self.layout, Layout):
6379
            self.codegen_size_asserts(wrapper)
6380

6381
    def get_mutation_names(self):
6382
        return [self.inputs[self.idx_for_inplace_sum].get_name()]
6383

6384
    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
6385
        return set()
6386

6387
    @classmethod
6388
    def create(
6389
        cls,
6390
        x: "TensorBox",
6391
        x_scale,
6392
        x_zp,
6393
        accum: "TensorBox",
6394
        accum_scale,
6395
        accum_zp,
6396
        weight: "TensorBox",  # packed_weight
6397
        w_scale,
6398
        w_zp,
6399
        bias: "TensorBox",
6400
        stride_: List[int],
6401
        padding_: List[int],
6402
        dilation_: List[int],
6403
        groups: int,
6404
        o_inv_scale: "TensorBox",
6405
        output_zero_point: "TensorBox",
6406
        output_dtype,
6407
        binary_attr,
6408
        alpha,
6409
        unary_attr,
6410
        unary_scalars,
6411
        unary_algorithm,
6412
    ):
6413
        transposed = False
6414
        output_padding = None
6415
        (
6416
            inputs,
6417
            constant_args,
6418
            kernel_layout,
6419
            req_stride_order,
6420
        ) = _prepare_convolution_fusion_create(
6421
            cls,
6422
            x,
6423
            weight,
6424
            bias,
6425
            padding_,
6426
            stride_,
6427
            dilation_,
6428
            groups,
6429
            transposed,
6430
            output_padding,
6431
        )
6432

6433
        accum = cls.require_stride_order(accum, req_stride_order)
6434
        inputs.append(accum)
6435

6436
        # swap padding and stride to align with functional conv arg order
6437
        if bias is None:
6438
            constant_args[1], constant_args[2] = constant_args[2], constant_args[1]
6439
        else:
6440
            constant_args[0], constant_args[1] = constant_args[1], constant_args[0]
6441

6442
        w_scale.realize()
6443
        w_zp.realize()
6444
        inputs = inputs + [w_scale, w_zp]
6445
        constant_args = constant_args + [
6446
            x_scale,
6447
            x_zp,
6448
            accum_scale,
6449
            accum_zp,
6450
            o_inv_scale,
6451
            output_zero_point,
6452
            output_dtype,
6453
            binary_attr,
6454
            alpha,
6455
            unary_attr,
6456
            may_convert_to_optional(unary_scalars),
6457
            unary_algorithm,
6458
        ]
6459

6460
        assert (
6461
            binary_attr == "sum"
6462
        ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
6463

6464
        packed = QConvPointWiseBinaryPT2E(
6465
            layout=NoneLayout(accum.get_device()),
6466
            inputs=inputs,
6467
            constant_args=constant_args,
6468
        )
6469
        mark_node_as_mutating(packed, accum)
6470

6471
        # Return accum since it has been inplace changed.
6472
        return packed.inputs[packed.idx_for_inplace_sum]
6473

6474

6475
class QLinearPointwisePT2E(ExternKernelAlloc):
6476
    def __init__(
6477
        self,
6478
        layout,
6479
        inputs,
6480
        constant_args=(),
6481
        has_bias=True,
6482
        x_scale_zp_are_tensors=False,
6483
    ):
6484
        """
6485
        if bias is not None
6486
            - inputs = [x, w, b, weight_scale, weight_zp]
6487
            - const_args is: [x_scale, x_zp, o_inv_scale, o_zp,
6488
              fp32_output, unary_attr, unary_scalars, unary_algorithm]
6489
        else
6490
            - inputs = [x, w, weight_scale, weight_zp]
6491
            - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp,
6492
              fp32_output, unary_attr, unary_scalars, unary_algorithm]
6493
        """
6494
        self.has_bias = has_bias
6495
        self.x_scale_zp_are_tensors = x_scale_zp_are_tensors
6496
        super().__init__(
6497
            layout,
6498
            inputs,
6499
            constant_args,
6500
            None,
6501
            python_kernel_name=(
6502
                "torch.ops.onednn.qlinear_pointwise.tensor"
6503
                if x_scale_zp_are_tensors
6504
                else "torch.ops.onednn.qlinear_pointwise.default"
6505
            ),
6506
            cpp_kernel_name="onednn::qlinear_pointwise",
6507
        )
6508
        self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else ""
6509
        self.cpp_kernel_key = "qlinear_pointwise"
6510
        x_scale_type_str, x_zp_type_str = (
6511
            ("at::Tensor", "at::Tensor")
6512
            if x_scale_zp_are_tensors
6513
            else ("double", "int64_t")
6514
        )
6515
        self.cpp_op_schema = f"""
6516
            at::Tensor(
6517
                at::Tensor act,
6518
                {x_scale_type_str} act_scale,
6519
                {x_zp_type_str} act_zero_point,
6520
                at::Tensor weight,
6521
                at::Tensor weight_scales,
6522
                at::Tensor weight_zero_points,
6523
                c10::optional<at::Tensor> bias,
6524
                double inv_output_scale,
6525
                int64_t output_zero_point,
6526
                c10::optional<c10::ScalarType> output_dtype,
6527
                std::string post_op_name,
6528
                torch::List<c10::optional<at::Scalar>> post_op_args,
6529
                std::string post_op_algorithm)"""
6530

6531
    def codegen(self, wrapper):
6532
        # Parser the inputs and constant
6533
        args = [x.codegen_reference() for x in self.inputs]
6534
        const_args = []
6535
        const_args.extend(self.codegen_const_args())
6536

6537
        x = args[0]
6538
        packed_weight = args[1]
6539
        bias = args[2] if self.has_bias else const_args[0]
6540
        w_scale, w_zp = args[-2], args[-1]
6541
        if self.x_scale_zp_are_tensors:
6542
            assert len(args) >= 4
6543
            x_scale, x_zp = args[-4], args[-3]
6544
            (
6545
                o_inv_scale,
6546
                o_zp,
6547
                output_dtype,
6548
                unary_attr,
6549
                unary_scalars,
6550
                unary_algorithm,
6551
            ) = const_args[-6:]
6552
        else:
6553
            assert len(const_args) >= 8
6554
            (
6555
                x_scale,
6556
                x_zp,
6557
                o_inv_scale,
6558
                o_zp,
6559
                output_dtype,
6560
                unary_attr,
6561
                unary_scalars,
6562
                unary_algorithm,
6563
            ) = const_args[-8:]
6564

6565
        codegen_args = (
6566
            x,
6567
            x_scale,
6568
            x_zp,
6569
            packed_weight,
6570
            w_scale,
6571
            w_zp,
6572
            bias,
6573
            o_inv_scale,
6574
            o_zp,
6575
            output_dtype,
6576
            unary_attr,
6577
            unary_scalars,
6578
            unary_algorithm,
6579
        )
6580
        wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
6581
            self.get_name(),
6582
            self.get_kernel_name(),
6583
            codegen_args,
6584
            self.cpp_op_schema,
6585
            self.cpp_kernel_key,
6586
            self.cpp_kernel_overload_name,
6587
        )
6588
        if isinstance(self.layout, Layout):
6589
            self.codegen_size_asserts(wrapper)
6590

6591
    @classmethod
6592
    def create(
6593
        cls,
6594
        x: "TensorBox",
6595
        x_scale: float,
6596
        x_zp: int,
6597
        weight: "TensorBox",  # packed_weight
6598
        w_scale: "TensorBox",
6599
        w_zp: "TensorBox",
6600
        bias: "TensorBox",
6601
        o_inv_scale: float,
6602
        output_zero_point: int,
6603
        output_dtype,
6604
        unary_attr,
6605
        unary_scalars,
6606
        unary_algorithm,
6607
    ):
6608
        (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create(
6609
            cls,
6610
            x,
6611
            weight,
6612
            bias,
6613
        )
6614

6615
        if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox):
6616
            x_scale.realize()
6617
            x_zp.realize()
6618
            inputs = inputs + [x_scale, x_zp]
6619
            x_scale_zp_are_tensors = True
6620
        else:
6621
            assert isinstance(x_scale, float) and isinstance(x_zp, int)
6622
            constant_args = constant_args + [x_scale, x_zp]
6623
            x_scale_zp_are_tensors = False
6624
        w_scale.realize()
6625
        w_zp.realize()
6626
        inputs = inputs + [w_scale, w_zp]
6627
        constant_args = constant_args + [
6628
            o_inv_scale,
6629
            output_zero_point,
6630
            output_dtype,
6631
            unary_attr,
6632
            may_convert_to_optional(unary_scalars),
6633
            unary_algorithm,
6634
        ]
6635

6636
        if output_dtype is not None:
6637
            assert output_dtype in [torch.float32, torch.bfloat16]
6638
            # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout
6639
            # if we set fp32_output, the output buf should be dtype float32 instead of uint8.
6640
            kernel_layout.dtype = output_dtype
6641

6642
        return QLinearPointwisePT2E(
6643
            layout=kernel_layout,
6644
            inputs=inputs,
6645
            constant_args=constant_args,
6646
            has_bias=(bias is not None),
6647
            x_scale_zp_are_tensors=x_scale_zp_are_tensors,
6648
        )
6649

6650

6651
@dataclasses.dataclass
6652
class MutableBox(IRNode):
6653
    """
6654
    TensorBox / StorageBox allow in-place mutation of Tensors
6655
    """
6656

6657
    data: IRNode
6658

6659
    def __getattr__(self, name):
6660
        fn = getattr(self.data, name)
6661
        if callable(fn):
6662
            return fn
6663
        raise AttributeError(f"{type(self.data).__name__}.{name} not callable")
6664

6665
    def realize(self):
6666
        return self.data.realize()
6667

6668
    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
6669
        return self.data.get_unbacked_symbol_uses()
6670

6671
    def codegen_reference(self, writer=None):
6672
        return self.data.codegen_reference(writer)
6673

6674
    @property
6675
    def layout(self):
6676
        return self.data.layout  # type: ignore[attr-defined]
6677

6678
    def get_layout(self):
6679
        return self.layout
6680

6681
    def get_size(self):
6682
        return self.data.get_size()
6683

6684
    @property
6685
    def dtype(self):
6686
        return self.data.dtype
6687

6688
    def __str__(self):
6689
        if isinstance(self.data, MutableBox):
6690
            line0 = f"{type(self).__name__}({type(self.data).__name__}("
6691
            endl = "))"
6692
            inner = self.data.data
6693
        else:
6694
            line0 = f"{type(self).__name__}("
6695
            inner = self.data
6696
            endl = ")"
6697

6698
        lines = [
6699
            line0,
6700
            indent(str(inner)),
6701
            endl,
6702
        ]
6703
        return "\n".join(lines)
6704

6705
    __repr__ = __str__
6706

6707

6708
class TensorBox(MutableBox):
6709
    @staticmethod
6710
    def create(data):
6711
        return TensorBox(StorageBox(data))
6712

6713

6714
class StorageBox(MutableBox):
6715
    def is_input_buffer(self):
6716
        if isinstance(self.data, (InputBuffer, ReinterpretView)):
6717
            return self.data.get_name() in V.graph.graph_inputs
6718
        return False
6719

6720
    def realize(self):
6721
        if isinstance(
6722
            self.data,
6723
            (
6724
                ComputedBuffer,
6725
                InputsKernel,
6726
                InputBuffer,
6727
                ReinterpretView,
6728
                TemplateBuffer,
6729
            ),
6730
        ):
6731
            return self.data.get_name()
6732
        assert isinstance(self.data, (Pointwise, Reduction, Scan)), type(self.data)
6733
        origin_node = self.data.get_origin_node()
6734
        traceback = self.data.get_traceback()
6735
        self.data = ComputedBuffer(
6736
            name=None,
6737
            layout=FlexibleLayout(
6738
                device=self.data.get_device(),
6739
                dtype=self.data.get_dtype(),
6740
                size=self.data.get_size(),
6741
            ),
6742
            data=self.data,
6743
        )
6744
        self.data.name = V.graph.register_buffer(self.data)
6745
        self.data.origins = self.origins
6746
        self.data.origin_node = origin_node
6747
        self.data.traceback = traceback
6748
        return self.data.name
6749

6750
    def realize_hint(self):
6751
        """
6752
        Called on buffers we expect to be forced to realize later.
6753
        """
6754
        if (
6755
            isinstance(self.data, (Pointwise, Reduction))
6756
            and self.num_reads() > 1
6757
            and self.is_pointwise_non_scalar_tensor_num_reads_larger_than_one()
6758
        ):
6759
            self.realize()
6760

6761
    def has_exceeded_max_reads(self):
6762
        return isinstance(self.data, Pointwise) and (
6763
            self.num_reads() > config.realize_acc_reads_threshold
6764
            or self.has_large_inner_fn()
6765
        )
6766

6767
    def mark_reuse(self, users):
6768
        """
6769
        A heuristic to decide if we should realize a tensor
6770
        that is used multiple times.
6771
        """
6772

6773
        def should_realize_on_cpu(loops: Union[Pointwise, Reduction]):
6774
            """
6775
            The heuristic for realizing reused result of heavy ops on cpu
6776
            """
6777
            heavy_ops = ["exp"]  # a list of heavy ops
6778
            fn_str = loops.inner_fn_str()
6779
            return any((op + "(") in fn_str for op in heavy_ops)
6780

6781
        if (
6782
            users > 1
6783
            and isinstance(self.data, (Pointwise, Reduction))
6784
            and (
6785
                self.num_reads() > config.realize_reads_threshold
6786
                or self.has_large_inner_fn()
6787
                or (is_cpu(self.data) and should_realize_on_cpu(self.data))
6788
            )
6789
        ):
6790
            self.realize()
6791

6792
    @cache_on_self
6793
    def num_reads(self):
6794
        data = self.data
6795
        if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)):
6796
            return 1
6797
        if isinstance(data, ComputedBuffer):
6798
            read_writes = data.get_read_writes()
6799
        else:
6800
            assert isinstance(data, (Pointwise, Reduction)), type(data)
6801
            read_writes = ComputedBuffer(
6802
                name=None,
6803
                layout=FlexibleLayout(
6804
                    device=data.get_device(),
6805
                    dtype=data.get_dtype(),
6806
                    size=data.get_size(),
6807
                ),
6808
                data=data,
6809
            ).get_read_writes()
6810
        return len(read_writes.reads)
6811

6812
    @cache_on_self
6813
    def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self):
6814
        # Skip the check for non Pointwise instances
6815
        return (
6816
            (sum(read.index != 0 for read in self.data.get_reads()) > 1)
6817
            if isinstance(self.data, Pointwise)
6818
            and all(
6819
                not isinstance(read, dependencies.StarDep)
6820
                for read in self.data.get_reads()
6821
            )
6822
            else True
6823
        )
6824

6825

6826
@dataclasses.dataclass
6827
class Subgraph(IRNode):
6828
    name: str
6829
    graph_module: torch.fx.GraphModule
6830
    graph: Optional["GraphLowering"] = None
6831

6832

6833
@dataclasses.dataclass
6834
class Conditional(ExternKernel):
6835
    predicate: Optional[DynamicScalar] = None
6836
    operands: Optional[List[TensorBox]] = None
6837
    true_subgraph: Optional[Subgraph] = None
6838
    false_subgraph: Optional[Subgraph] = None
6839
    outputs: Optional[List[MultiOutput]] = None
6840

6841
    def __init__(
6842
        self,
6843
        predicate: DynamicScalar,
6844
        operands: List[TensorBox],
6845
        true_subgraph: Subgraph,
6846
        false_subgraph: Subgraph,
6847
        layout: MultiOutputLayout,
6848
    ):
6849
        self.predicate = predicate
6850
        self.operands = operands
6851
        self.true_subgraph = true_subgraph
6852
        self.false_subgraph = false_subgraph
6853

6854
        super().__init__(
6855
            name=None,
6856
            layout=layout,  # type: ignore[arg-type]
6857
            inputs=[predicate, *operands],  # type: ignore[list-item]
6858
        )
6859

6860
        self.name = V.graph.register_buffer(self)
6861

6862
    @classmethod
6863
    def create(
6864
        cls,
6865
        predicate: TensorBox,
6866
        true_fn: Subgraph,
6867
        false_fn: Subgraph,
6868
        operands: List[TensorBox],
6869
    ):
6870
        predicate = cls.realize_input(predicate)
6871
        operands = [cls.realize_input(x) for x in operands]
6872

6873
        fx_operands = V.graph.current_node.args[-1]
6874
        fake_operands = [x.meta["val"] for x in fx_operands]  # type: ignore[union-attr]
6875

6876
        for subgraph in (true_fn, false_fn):
6877
            if subgraph.graph is None:
6878
                # create and lower subgraphs
6879
                subgraph.graph = V.graph.make_subgraph(
6880
                    gm=subgraph.graph_module,
6881
                    example_inputs=fake_operands,
6882
                    subgraph_name=subgraph.name,
6883
                )
6884
                with V.set_graph_handler(subgraph.graph):
6885
                    subgraph.graph.run(*fake_operands)
6886

6887
        true_outputs = true_fn.graph.graph_outputs  # type: ignore[union-attr]
6888
        false_outputs = true_fn.graph.graph_outputs  # type: ignore[union-attr]
6889

6890
        def _aliased_buffers(outputs):
6891
            buffers = [
6892
                output.unwrap_view() if isinstance(output, ReinterpretView) else output
6893
                for output in outputs
6894
            ]
6895
            # assuming the same buffer is represented by the same IRNode object
6896
            return len({id(buffer) for buffer in buffers}) < len(outputs)
6897

6898
        for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)):
6899
            if _aliased_buffers(true_outputs):
6900
                raise AssertionError(
6901
                    "Output aliasing is currently not supported in compiled torch.cond. "
6902
                    f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}"
6903
                )
6904

6905
        # make sure true and false outputs are structurally equivalent
6906
        assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs)
6907
        for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)):
6908
            assert to.get_size() == fo.get_size(), (i, to, fo)
6909
            assert to.get_stride() == fo.get_stride(), (i, to, fo)
6910
            assert to.get_device() == fo.get_device(), (i, to, fo)
6911
            assert to.get_dtype() == fo.get_dtype(), (i, to, fo)
6912
            assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo)
6913

6914
        conditional = Conditional(
6915
            predicate=predicate,
6916
            operands=operands,
6917
            true_subgraph=true_fn,
6918
            false_subgraph=false_fn,
6919
            # use predicate device for consistent codegen-ing
6920
            layout=MultiOutputLayout(predicate.get_device()),
6921
        )
6922

6923
        outputs = [
6924
            MultiOutput(
6925
                FixedLayout(
6926
                    device=output.get_device(),
6927
                    dtype=output.get_dtype(),
6928
                    size=output.get_size(),
6929
                    stride=output.get_stride(),
6930
                    offset=output.get_layout().offset,
6931
                ),
6932
                conditional,
6933
                [(list, i)],
6934
            )
6935
            # as the true and false outputs are equivalent,
6936
            # we can use either of them here as a "template"
6937
            for i, output in enumerate(true_outputs)
6938
        ]
6939

6940
        conditional.outputs = outputs
6941
        return outputs
6942

6943
    def codegen(self, wrapper):
6944
        wrapper.codegen_conditional(self)
6945

6946

6947
class InterpreterShim(torch.fx.Interpreter):
6948
    @staticmethod
6949
    @functools.lru_cache(None)
6950
    def _dummy_gm():
6951
        return torch.fx.symbolic_trace(identity)
6952

6953
    def __init__(self, graph, submodules):
6954
        # call super() with a placeholder to avoid constructing a
6955
        # GraphModule which is very expensive (it does codegen).
6956
        super().__init__(self._dummy_gm(), garbage_collect_values=False)
6957
        self.module = self  # type: ignore[assignment]
6958
        self.graph = graph
6959
        self.submodules = submodules
6960
        self.extra_traceback = False
6961
        self.fetch_attr = submodules.__getitem__
6962
        self.current_node = None
6963

6964
    def run_node(self, n: torch.fx.Node) -> Any:
6965
        self.current_node = n
6966
        return super().run_node(n)
6967

6968
    def run(self, *args, **kwargs):
6969
        with V.set_interpreter_handler(self):
6970
            return super().run(*args, **kwargs)
6971

6972

6973
class LoopBody:
6974
    """
6975
    Captures the body of a Loops subclass into an FX graph.  Persists any
6976
    indexing simplifications and makes it easier to analyze loop bodies.
6977
    """
6978

6979
    def __init__(self, fn, args, var_ranges):
6980
        super().__init__()
6981
        self.var_ranges = var_ranges
6982
        self.indexing_exprs = {}
6983
        self.indexing_exprs_name = {}
6984
        self.reads = []
6985
        self.writes = []
6986
        self.reads_name2expr = {}
6987
        self.writes_name2expr = {}
6988
        self.other = []
6989
        self.submodules = {"get_index": self.get_index}
6990
        self.subblocks = {}
6991
        self.indirect_vars = []
6992
        self.root_block = LoopBodyBlock(self, fn, args)
6993
        self.indexing = None
6994

6995
    @cache_on_self
6996
    def get_nodes(self):
6997
        all_graphs = itertools.chain(
6998
            (self.root_block.graph,),
6999
            (block.graph for block in self.subblocks.values()),
7000
        )
7001
        return [node for graph in all_graphs for node in graph.nodes]
7002

7003
    @cache_on_self
7004
    def bounds(self):
7005
        # Doing a local import to avoid dumping all the code here
7006
        from .bounds import BoundVars
7007

7008
        return BoundVars(self)
7009

7010
    def debug_str(self):
7011
        lines = [f"var_ranges = {dict(self.var_ranges)}"]
7012
        lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
7013
        lines.extend(
7014
            [
7015
                block.debug_str(name)
7016
                for name, block in itertools.chain(
7017
                    [("body", self.root_block)], self.subblocks.items()
7018
                )
7019
            ]
7020
        )
7021
        return "\n".join(lines)
7022

7023
    def add_index_expr(self, expr: sympy.Expr, category, buf_name):
7024
        getattr(self, category).append(expr)
7025
        if buf_name is not None:
7026
            getattr(self, f"{category}_name2expr")[buf_name] = expr
7027
        if expr not in self.indexing_exprs_name:
7028
            name = f"index{len(self.indexing_exprs)}"
7029
            self.indexing_exprs_name[expr] = name
7030
            self.indexing_exprs[name] = expr
7031
        return self.indexing_exprs_name[expr]
7032

7033
    def add_submodule(self, block, prefix):
7034
        """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
7035
        if prefix[-1].isnumeric() and prefix not in self.submodules:
7036
            name = prefix
7037
        else:
7038
            name = f"{prefix}{len(self.submodules)}"
7039
        self.submodules[name] = block
7040
        return name
7041

7042
    def add_indirect(self, size):
7043
        name = f"indirect{len(self.indirect_vars)}"
7044
        var = sympy_index_symbol(name)
7045
        self.indirect_vars.append(var)
7046
        return var
7047

7048
    def replace_indirect(self, old, new):
7049
        """Swap in a variable used in indirect indexing"""
7050
        if str(old) == str(new):
7051
            return
7052
        assert self.indexing is not None
7053
        self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}
7054

7055
    def get_index(self, name):
7056
        assert self.indexing is not None
7057
        return self.indexing[name]
7058

7059
    def __call__(self, *indices):
7060
        index = list(itertools.chain.from_iterable(indices))
7061
        assert len(index) == len(self.var_ranges), (index, self.var_ranges)
7062
        assert all(v not in self.var_ranges for v in index)
7063
        replacements = dict(zip(self.var_ranges.keys(), index))
7064
        self.indexing = {
7065
            name: sympy_subs(expr, replacements)
7066
            for name, expr in self.indexing_exprs.items()
7067
        }
7068
        result = self.root_block()
7069
        self.indexing = None
7070
        return result
7071

7072

7073
class LoopBodyBlock:
7074
    """
7075
    Captures the body of a Loops subclass into an FX graph.
7076
    In normal cases there will be a 1:1 mapping between LoopBody and
7077
    LoopBodyBlock, hower in the case of ops.masked() the masked out
7078
    operations will manifest as an extra LoopBodyBlock.
7079
    """
7080

7081
    def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]):
7082
        self.body = body
7083

7084
        def add_index(expr, category, buf_name=None):
7085
            return tracer.create_proxy(
7086
                "call_module",
7087
                "get_index",
7088
                (self.body.add_index_expr(expr, category, buf_name),),
7089
                {},
7090
            )
7091

7092
        class CaptureIndexing(V.WrapperHandler):  # type: ignore[name-defined]
7093
            self.name = "CaptureIndexing"
7094

7095
            def load(self, name: str, index: sympy.Expr):
7096
                index = add_index(index, "reads", name)
7097
                return self._inner.load(name, index)
7098

7099
            def store(self, name, index, value, mode=None):
7100
                index = add_index(index, "writes", name)
7101
                return self._inner.store(name, index, value, mode)
7102

7103
            def store_reduction(self, name, index, value):
7104
                index = add_index(index, "writes", name)
7105
                return self._inner.store_reduction(name, index, value)
7106

7107
            def reduction(self, dtype, src_dtype, reduction_type, value):
7108
                result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
7109
                if "welford" in reduction_type:
7110
                    return tuple(result[i] for i in range(3))
7111
                return result
7112

7113
            def index_expr(self, index, dtype):
7114
                if isinstance(index, (int, sympy.Integer)):
7115
                    return self._inner.constant(int(index), dtype)
7116
                index = add_index(index, "other")
7117
                return self._inner.index_expr(index, dtype)
7118

7119
            def bucketize(
7120
                self,
7121
                values,
7122
                offsets_name: str,
7123
                offsets_size: sympy.Expr,
7124
                indexing_dtype: torch.dtype,
7125
                right: bool,
7126
            ):
7127
                offsets_size = add_index(offsets_size, "other")
7128
                return self._inner.bucketize(
7129
                    values, offsets_name, offsets_size, indexing_dtype, right
7130
                )
7131

7132
            @staticmethod
7133
            def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy):
7134
                """
7135
                Recursively capture the masked out body in another LoopBodyBlock
7136
                """
7137

7138
                subblock: LoopBodyBlock
7139

7140
                def shim(mask, other):
7141
                    return V.ops.masked(mask, subblock, other)
7142

7143
                name = self.body.add_submodule(shim, "masked_subblock")
7144
                subblock = LoopBodyBlock(self.body, masked_body, [])
7145
                self.body.subblocks[name] = subblock
7146
                return tracer.create_proxy(
7147
                    "call_module", name, (mask_proxy, other_proxy), {}
7148
                )
7149

7150
            @staticmethod
7151
            def scan(
7152
                dtype_proxy, combine_fn: Callable[..., Any], value_proxy, init_proxy
7153
            ):
7154
                def shim(dtype, value, init):
7155
                    return V.ops.scan(dtype, combine_fn, value, init)
7156

7157
                name = self.body.add_submodule(shim, "scan")
7158
                return tracer.create_proxy(
7159
                    "call_module", name, (dtype_proxy, value_proxy, init_proxy), {}
7160
                )
7161

7162
            def frexp(self, value_proxy):
7163
                result = self._inner.frexp(value_proxy)
7164
                # Proxies are iterable, but some methods expect tuples/lists
7165
                return (result[0], result[1])
7166

7167
            @staticmethod
7168
            def indirect_indexing(index_proxy, size, check=True):
7169
                """
7170
                Flow data from tensors into indexing formulas.
7171
                Introduce a call_module to update the indexing.
7172
                """
7173

7174
                var = self.body.add_indirect(size)
7175

7176
                def set_indirect(new_var):
7177
                    self.body.replace_indirect(
7178
                        var, V.ops.indirect_indexing(new_var, size, check)
7179
                    )
7180

7181
                tracer.create_proxy(
7182
                    "call_module",
7183
                    self.body.add_submodule(set_indirect, f"set_{var}"),
7184
                    (index_proxy,),
7185
                    {},
7186
                )
7187
                return var
7188

7189
            @staticmethod
7190
            def output(result):
7191
                tracer.create_proxy("output", "output", (result,), {})
7192

7193
        tracer = torch.fx.Tracer()
7194
        tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
7195
        proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
7196

7197
        from .index_propagation import IndexPropagation
7198
        from .sizevars import SimplifyIndexing
7199

7200
        handler: Any = SimplifyIndexing(
7201
            CaptureIndexing(proxy_ops), self.body.var_ranges
7202
        )
7203
        if config.constant_and_index_propagation:
7204
            handler = IndexPropagation(handler)
7205

7206
        with V.set_ops_handler(handler):
7207
            # This indirection is just a cute way to get IndexPropagation to
7208
            # unwrap the return value.
7209
            ops.output(fn(*args))
7210
        self.graph = tracer.graph
7211

7212
    def __call__(self):
7213
        graph = self.graph
7214
        submodules = self.body.submodules
7215

7216
        return InterpreterShim(graph, submodules).run(V.get_ops_handler())
7217

7218
    def debug_str(self, name="block"):
7219
        code = torch.fx.GraphModule(self.body.submodules, self.graph).code
7220
        return re.sub(
7221
            # strip `; del var0` suffixes to make output prettier
7222
            r";[^\n]*",
7223
            "",
7224
            code.strip().replace("def forward(", f"def {name}("),
7225
        )
7226

7227

7228
class Wait(ExternKernelAlloc):
7229
    """
7230
    Wait should not be used by itself.  It should always be constructed in tandem
7231
    with a collective op that produces a work to wait on.
7232
    """
7233

7234
    def __init__(
7235
        self,
7236
        layout,
7237
        inputs,
7238
        constant_args=(),
7239
    ):
7240
        super().__init__(layout, inputs, constant_args)
7241

7242
    def should_allocate(self):
7243
        return False
7244

7245
    def codegen(self, wrapper):
7246
        from .codegen.wrapper import ReuseLine
7247

7248
        wrapper.add_import_once(
7249
            "from torch.distributed._functional_collectives_impl import _wait_tensor"
7250
        )
7251
        (input_collective,) = (t.codegen_reference() for t in self.inputs)
7252
        wrapper.writeline(f"{input_collective} = _wait_tensor({input_collective})")
7253

7254
        # wait op still needs to produce a 'buffer' that represents the tensor output.
7255
        # this is a symbolic gesture, and it gets handled by WrapperCodegen.
7256
        # codegen outputs a '# reuse' line that assigns the input buffer here ('input_collective')
7257
        # to a new name (`self.get_name()`) and `del`s the old name.
7258
        wrapper.writeline(ReuseLine(wrapper, self.inputs[0], self, delete_old=False))
7259

7260
    @classmethod
7261
    def create(cls, collective_op: "TensorBox"):
7262
        # TODO(whc) i'm not sure what's going on here, this probably means I missed something upstream
7263
        collective_op.decide_layout()
7264
        return Wait(
7265
            layout=AliasedLayout(collective_op),
7266
            inputs=[collective_op],
7267
        )
7268

7269
    def get_alias_names(self):
7270
        # Signal to codegen that our output buffer isn't safe to reuse
7271
        return [self.inputs[0].codegen_reference()]
7272

7273
    def get_mutation_names(self):
7274
        # The generated `_wait_tensor` op mutates the input tensor
7275
        return [self.inputs[0].codegen_reference()]
7276

7277

7278
class CollectiveKernel(ExternKernel):
7279
    """
7280
    Each collective should follow the pattern:
7281
    - extend InPlaceCollectiveKernel or OutOfPlaceCollectiveKernel.
7282
    - the kernel delegates into c10d processgroup, which returns a 'work' obj
7283
    - the work obj is registered via _register_tensor_work so it can be waited on later
7284
    """
7285

7286
    def __init__(self, layout, inputs, constant_args):
7287
        super().__init__(None, layout, inputs, constant_args)
7288
        self.name = V.graph.register_buffer(self)
7289

7290
    def should_emit_register_tensor_work(self):
7291
        return True
7292

7293
    def should_emit_find_or_create_pg(self):
7294
        return True
7295

7296
    def codegen_collective(self, wrapper, output_name, input_names):
7297
        # factor so the boilerplate can be handled in CollectiveKernel.codegen
7298
        raise NotImplementedError("Must implement")
7299

7300
    def codegen_output(self, wrapper, output_name, input_names):
7301
        # factor so the boilerplate can be handled in CollectiveKernel.codegen
7302
        raise NotImplementedError("Must implement")
7303

7304
    @classmethod
7305
    def wrap_inputs_as_inplace(cls, inputs):
7306
        def wrap_input(var):
7307
            op = InPlaceHint(
7308
                FlexibleLayout(var.get_device(), var.get_dtype(), var.get_size()), var
7309
            )
7310
            return TensorBox.create(op)
7311

7312
        return list(map(wrap_input, inputs))
7313

7314
    def codegen(self, wrapper):
7315
        wrapper.add_import_once("import torch.distributed as dist")
7316
        wrapper.add_import_once("import torch.distributed.distributed_c10d as c10d")
7317
        wrapper.add_import_once(
7318
            "import torch.distributed._functional_collectives_impl as fun_col_impl"
7319
        )
7320
        # extract references to our args in string form for codegen output
7321
        input_names = [t.codegen_reference() for t in self.inputs]
7322
        output_name = self.get_name()
7323
        tag, ranks, group_size = self.constant_args
7324

7325
        if self.should_emit_find_or_create_pg():
7326
            # TODO: avoid more than one ref of the same pg (even though they are cached inside the api)
7327
            wrapper.writeline(
7328
                f"{output_name}_pg = c10d._find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})"
7329
            )
7330

7331
        self.codegen_output(wrapper, output_name, input_names)
7332
        self.codegen_collective(wrapper, output_name, input_names)
7333
        if self.should_emit_register_tensor_work():
7334
            wrapper.writeline(
7335
                f"fun_col_impl._register_tensor_work({output_name}, {output_name}_work)"
7336
            )
7337

7338

7339
class InPlaceCollectiveKernel(CollectiveKernel):
7340
    """
7341
    InPlaceCollectiveKernel are those with in-out arguments such as all_reduce.
7342
    Extend this kernel if your collective needs to modify its inputs in-place.
7343
    """
7344

7345
    def __init__(self, layout, inputs, constant_args):
7346
        super().__init__(layout, inputs, constant_args)
7347

7348
    def should_allocate(self):
7349
        return False
7350

7351
    def has_side_effects(self):
7352
        return True
7353

7354
    def codegen_output(self, wrapper, output_name, input_names):
7355
        if len(input_names) > 1:
7356
            wrapper.writeline(f"{output_name} = [{','.join(input_names)}] ")
7357
        else:
7358
            wrapper.writeline(f"{output_name} = {input_names[0]}")
7359

7360

7361
class OutOfPlaceCollectiveKernel(CollectiveKernel):
7362
    """
7363
    OutOfPlaceCollectiveKernel are those that allocate their
7364
    outputs and leave their inputs inplace, such as all_gather.
7365
    """
7366

7367
    def __init__(self, layout, inputs, outputs, constant_args):
7368
        super().__init__(layout, inputs + outputs, constant_args)
7369
        self.outputs = outputs
7370
        self.original_inputs = inputs
7371
        # NOTE: As seen in issue #108780, output buffers of out-of-place collectives
7372
        # could be incorrectly reused. As a safety measure, here we just ban the reuse of them.
7373
        # TODO: A better fix is to figure out how to propagate the aliases properly,
7374
        # so that the buffer is only reused after all its users have consumed it.
7375
        for x in self.outputs:
7376
            V.graph.never_reuse_buffers.add(x.name)
7377

7378
    def should_allocate(self):
7379
        return False
7380

7381
    def has_side_effects(self):
7382
        return True
7383

7384
    def codegen_output(self, wrapper, output_name, input_names):
7385
        input_names = [t.codegen_reference() for t in self.original_inputs]
7386
        wrapper.writeline(f"{output_name}_inputs = [{','.join(input_names)}]")
7387
        wrapper.writeline(f"{output_name} = [{','.join(x.name for x in self.outputs)}]")
7388

7389
    @classmethod
7390
    def create_output_buffers(cls, inputs, size_cb=None):
7391
        outputs = []
7392
        for input in inputs:
7393
            new_size = input.get_size()
7394
            if size_cb is not None:
7395
                size_cb(new_size)
7396
            # new_size[0] *= group_size
7397

7398
            buff = OutputBuffer(
7399
                layout=FlexibleLayout(
7400
                    device=input.get_device(),
7401
                    dtype=input.get_dtype(),
7402
                    size=new_size,
7403
                ),
7404
            )
7405
            outputs.append(buff)
7406
        return outputs
7407

7408
    @classmethod
7409
    def create_output_nodes(cls, coll, output_buffers):
7410
        return [
7411
            MultiOutputNoSizeAssert(
7412
                out_t.layout,
7413
                coll,
7414
                f"[{i}]",
7415
            )
7416
            for i, out_t in enumerate(output_buffers)
7417
        ]
7418

7419

7420
class InPlaceHint(ExternKernel):
7421
    """
7422
    Helper OP to encode an in/out argument that tries to make it inplace whenever possible.
7423
    Wrap the input of your inplace op to enable this behavior.
7424

7425
    The design is based on two key decisions:
7426
    - this node is responsible for allocating the in/out buffer used by the collective.
7427
        This is controlled by the ``should_allocate`` method that returns True here and
7428
        False for the collective node
7429
    - The scheduler special-case this node and enable it to reuse its input.
7430
    """
7431

7432
    def codegen(self, wrapper):
7433
        input_name = self.inputs[0].codegen_reference()
7434
        output_name = self.get_name()
7435
        if not wrapper.did_reuse(self, self.inputs[0]):
7436
            wrapper.writeline(f"{output_name}.copy_({input_name}) #no reuse")
7437

7438
    def __init__(self, layout, input):
7439
        input = self.realize_input(input)
7440
        super().__init__(None, layout, self.unwrap_storage([input]), ())
7441
        self.name = V.graph.register_buffer(self)
7442

7443
    def should_allocate(self):
7444
        return True
7445

7446

7447
class OutputBuffer(ExternKernel):
    """
    One output buffer of an op that produces several outputs.

    ``should_allocate`` returns True, so this node owns its allocation; its own
    codegen only emits a marker comment naming the buffer.
    """

    def __init__(self, layout):
        super().__init__(name=None, layout=layout, inputs=[])
        self.name = V.graph.register_buffer(self)

    def should_allocate(self):
        return True

    def codegen(self, wrapper):
        marker = f"# collective out buffer {self.name}"
        wrapper.writeline(marker)
7461

7462

7463
class MultiOutputNoSizeAssert(MultiOutput):
    """
    Extract one output of a multi-output op.

    Works like MultiOutput but does not assert the size; the op emitting this
    node must guarantee the size itself.
    """

    def __init__(self, layout, input, index):
        # No structured indices are handed to MultiOutput; ``index`` is a raw
        # subscript string (e.g. "[0]") appended to the input's name at codegen.
        super().__init__(layout, input, [])
        self.index = index

    def codegen(self, wrapper):
        source = self.inputs[0].get_name()
        target = self.get_name()
        wrapper.writeline(f"{target} = {source}{self.index}")
7477

7478

7479
class Broadcast(InPlaceCollectiveKernel):
    """In-place ``dist.broadcast`` of a single tensor from rank ``src``."""

    def __init__(self, layout, inputs, constant_args, src):
        super().__init__(layout, inputs, constant_args)
        self.src = src

    def get_mutation_names(self):
        # Only the single broadcast buffer is written.
        target = self.inputs[0]
        return [target.get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    @classmethod
    def create(
        cls, x: "TensorBox", src: int, tag: str, ranks: List[int], group_size: int
    ):
        """Build the broadcast node and return the wrapped buffer it writes into."""
        wrapped = cls.wrap_inputs_as_inplace([x])
        target = wrapped[0]
        node = Broadcast(
            layout=NoneLayout(target.get_device()),  # type: ignore[arg-type]
            inputs=wrapped,
            constant_args=[tag, ranks, group_size],
            src=src,
        )
        mark_node_as_mutating(node, target)
        return target

    def codegen_collective(self, wrapper, output_name, input_names):
        call = (
            f"{output_name}_work = dist.broadcast("
            f"{output_name}, async_op=True, group={output_name}_pg, src={self.src})"
        )
        wrapper.writeline(call)
7509

7510

7511
class AllReduceCoalesced(InPlaceCollectiveKernel):
    """
    In-place ``dist.all_reduce_coalesced`` over a list of tensors.

    The collective reduces every input tensor in place, so every input (not
    only the first) is reported as mutated and marked as mutated in the graph.
    """

    def __init__(self, layout, inputs, constant_args, reduce_op):
        super().__init__(layout, inputs, constant_args)
        self.reduce_op = reduce_op

    def should_allocate(self):
        # No buffer of its own: results are written into the wrapped inputs
        # produced by ``wrap_inputs_as_inplace``.
        return False

    def get_mutation_names(self):
        # all_reduce_coalesced writes every input tensor. Reporting all of them
        # (rather than only inputs[0]) lets mutation-based dependency tracking
        # order later readers/writers of any input after this collective.
        return [inp.get_name() for inp in self.inputs]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    @classmethod
    def create(
        cls,
        inputs: List["TensorBox"],
        reduce_op: str,
        tag: str,
        ranks: List[int],
        group_size: int,
    ):
        """Build the coalesced all-reduce and return the wrapped inputs.

        Args:
            inputs: tensors to reduce; must be non-empty (the first one's device
                is used for the node's layout).
            reduce_op: reduction name, turned into a reduce op at runtime via
                ``fun_col_impl._str_to_reduce_op``.
            tag, ranks, group_size: process-group description, stored as
                ``constant_args``.

        Returns:
            The list returned by ``wrap_inputs_as_inplace``; these buffers hold
            the reduced values once the collective completes.
        """
        inplace_inputs = cls.wrap_inputs_as_inplace(inputs)
        packed = AllReduceCoalesced(
            layout=NoneLayout(inplace_inputs[0].get_device()),  # type: ignore[arg-type]
            inputs=inplace_inputs,
            constant_args=[tag, ranks, group_size],
            reduce_op=reduce_op,
        )
        # Every wrapped input is mutated by the collective, so mark all of them;
        # marking only the first left consumers of the others unordered.
        mark_node_as_mutating(packed, *inplace_inputs)
        return inplace_inputs

    def codegen_collective(self, wrapper, output_name, input_names):
        wrapper.writeline(
            f"{output_name}_work = dist.all_reduce_coalesced("
            f"{output_name}, "
            f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'), "
            f"group={output_name}_pg, "
            "async_op=True)"
        )
7552

7553

7554
class AllReduce(InPlaceCollectiveKernel):
    """In-place ``dist.all_reduce`` of a single tensor with a named reduce op."""

    def __init__(self, layout, inputs, constant_args, reduce_op):
        super().__init__(layout, inputs, constant_args)
        self.reduce_op = reduce_op

    def get_mutation_names(self):
        # The single reduced buffer is the only thing written.
        target = self.inputs[0]
        return [target.get_name()]

    def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
        return set()

    @classmethod
    def create(
        cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int
    ):
        """Build the all-reduce node and return the wrapped buffer it reduces into."""
        wrapped = cls.wrap_inputs_as_inplace([x])
        target = wrapped[0]
        node = AllReduce(
            layout=NoneLayout(target.get_device()),  # type: ignore[arg-type]
            inputs=wrapped,
            constant_args=[tag, ranks, group_size],
            reduce_op=reduce_op,
        )
        mark_node_as_mutating(node, target)
        return target

    def codegen_collective(self, wrapper, output_name, input_names):
        reduce_fn = f"fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}')"
        wrapper.writeline(
            f"{output_name}_work = dist.all_reduce("
            f"{output_name}, async_op=True, group={output_name}_pg, op={reduce_fn})"
        )
7585

7586

7587
class AllGatherIntoTensor(OutOfPlaceCollectiveKernel):
    """Out-of-place ``dist.all_gather_into_tensor``; output dim 0 is ``group_size`` times the input's."""

    def __init__(self, layout, inputs, outputs, constant_args):
        super().__init__(layout, inputs, outputs, constant_args)

    @classmethod
    def create(cls, x: "TensorBox", tag: str, ranks: List[int], group_size: int):
        """Build the all-gather node and return the node extracting its first output."""
        realized = [cls.realize_input(x)]

        def scale_dim0(size):
            # The gathered result is group_size times larger along dim 0.
            size[0] *= group_size

        out_bufs = cls.create_output_buffers(realized, scale_dim0)
        node = AllGatherIntoTensor(
            layout=MultiOutputLayout(realized[0].get_device()),
            inputs=realized,
            outputs=out_bufs,
            constant_args=[tag, ranks, group_size],
        )
        return cls.create_output_nodes(node, out_bufs)[0]

    def codegen_collective(self, wrapper, output_name, input_names):
        out_ref = f"{output_name}[0]"
        in_ref = f"{output_name}_inputs[0]"
        wrapper.writeline(
            f"{output_name}_work = dist.all_gather_into_tensor("
            f"{out_ref}, {in_ref}, async_op=True, group={output_name}_pg)"
        )
7615

7616

7617
class ReduceScatterTensor(OutOfPlaceCollectiveKernel):
    """Out-of-place ``dist.reduce_scatter_tensor``; output dim 0 is input dim 0 ``//`` ``group_size``."""

    def __init__(self, layout, inputs, outputs, constant_args, reduce_op):
        super().__init__(layout, inputs, outputs, constant_args)
        self.reduce_op = reduce_op

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        reduce_op: str,
        tag: str,
        ranks: List[int],
        group_size: int,
    ):
        """Build the reduce-scatter node and return the node extracting its first output."""
        realized = [cls.realize_input(x)]

        def shrink_dim0(size):
            # Each rank keeps one group_size-th of dim 0.
            size[0] //= group_size

        out_bufs = cls.create_output_buffers(realized, shrink_dim0)
        node = ReduceScatterTensor(
            layout=MultiOutputLayout(realized[0].get_device()),
            inputs=realized,
            outputs=out_bufs,
            constant_args=[tag, ranks, group_size],
            reduce_op=reduce_op,
        )
        return cls.create_output_nodes(node, out_bufs)[0]

    def codegen_collective(self, wrapper, output_name, input_names):
        reduce_fn = f"fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}')"
        wrapper.writeline(
            f"{output_name}_work = dist.reduce_scatter_tensor("
            f"{output_name}[0], {output_name}_inputs[0], "
            f"async_op=True, group={output_name}_pg, op={reduce_fn})"
        )
7655

7656

7657
class AllGatherIntoTensorCoalesced(OutOfPlaceCollectiveKernel):
    """
    Out-of-place coalesced all-gather over a list of tensors.

    Each output buffer's dim 0 is ``group_size`` times that of the matching input.
    """

    def __init__(self, layout, inputs, outputs, constant_args):
        super().__init__(layout, inputs, outputs, constant_args)

    @classmethod
    def create(
        cls,
        inputs: List["TensorBox"],
        tag: str,
        ranks: List[int],
        group_size: int,
    ):
        """Build the coalesced all-gather and return its raw output buffers.

        Unlike the single-tensor collectives, the ``OutputBuffer`` objects are
        returned directly instead of being wrapped with ``create_output_nodes``.
        """
        inputs = [cls.realize_input(x) for x in inputs]

        def compute_size(new_size):
            new_size[0] *= group_size

        outputs = cls.create_output_buffers(inputs, compute_size)

        layout = MultiOutputLayout(inputs[0].get_device())

        # The instance is not referenced afterwards; it is constructed for its
        # effect on the graph (presumably registration in the base __init__).
        _ = AllGatherIntoTensorCoalesced(
            layout=layout,
            inputs=inputs,
            outputs=outputs,
            constant_args=[tag, ranks, group_size],
        )

        return outputs

    def codegen_collective(self, wrapper, output_name, input_names):
        wrapper.writeline(
            f"{output_name}_work = fun_col_impl._all_gather_into_tensor_coalesced_fallback("
            f"output_tensors={output_name}, "
            f"input_tensors={output_name}_inputs, "
            f"group={output_name}_pg, "
            "async_op=True)"
        )
7696

7697

7698
class ReduceScatterTensorCoalesced(OutOfPlaceCollectiveKernel):
    """
    Out-of-place coalesced reduce-scatter over a list of tensors.

    Each output buffer's dim 0 is the matching input's dim 0 ``//`` ``group_size``.
    """

    def __init__(self, layout, inputs, outputs, constant_args, reduce_op):
        super().__init__(layout, inputs, outputs, constant_args)
        self.reduce_op = reduce_op

    @classmethod
    def create(
        cls,
        inputs: List["TensorBox"],
        reduce_op: str,
        tag: str,
        ranks: List[int],
        group_size: int,
    ):
        """Build the coalesced reduce-scatter and return its raw output buffers."""
        realized = [cls.realize_input(t) for t in inputs]

        def shrink_dim0(size):
            size[0] //= group_size

        out_bufs = cls.create_output_buffers(realized, shrink_dim0)
        # Constructed for its effect on the graph; the instance itself is unused.
        _ = ReduceScatterTensorCoalesced(
            layout=MultiOutputLayout(realized[0].get_device()),
            inputs=realized,
            outputs=out_bufs,
            constant_args=[tag, ranks, group_size],
            reduce_op=reduce_op,
        )
        return out_bufs

    def codegen_collective(self, wrapper, output_name, input_names):
        call_args = ", ".join(
            [
                f"output_tensors={output_name}",
                f"input_tensors={output_name}_inputs",
                f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}')",
                f"group={output_name}_pg",
                "async_op=True",
            ]
        )
        wrapper.writeline(
            f"{output_name}_work = fun_col_impl._reduce_scatter_tensor_coalesced_fallback("
            f"{call_args})"
        )
7740

7741

7742
# TODO(yifu): replace the CollectiveKernel IR hierarchy with _CollectiveKernel.
class _CollectiveKernel(FallbackKernel):
    """
    Fallback kernel used for functional-collective ops.

    The node never allocates its own buffer (``should_allocate`` is False) and
    always reports side effects. The latter is presumably so that it is never
    dropped as dead code even when its result looks unused; confirm this
    against the scheduler.
    """

    def should_allocate(self):
        return False

    def has_side_effects(self):
        return True

    # This is identical to FallbackKernel.set_cpp_kernel(), minus the
    # part that checks against input aliasing and mutation.
    def set_cpp_kernel(self, kernel):
        from .codegen.wrapper import get_cpp_op_schema

        self.cpp_kernel_name = kernel._schema.name
        self.cpp_kernel_overload_name = kernel._schema.overload_name
        self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}"  # type: ignore[union-attr]

        self.cpp_op_schema = get_cpp_op_schema(kernel)
        self.ordered_kwargs_for_cpp_kernel = [
            x.name for x in kernel._schema.arguments if x.kwarg_only
        ]

    @classmethod
    def _process_and_realize(cls, kernel, inputs, *args, **kwargs):
        """Run ``process_kernel`` under the graph's fake mode, then realize every tensor arg.

        Returns the ``(example_output, tensor_args, non_tensor_args,
        unflatten_args)`` tuple produced by ``process_kernel``.
        """
        with V.graph.fake_mode:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
        for tensor_arg in tensor_args:
            tensor_arg.realize()
        return example_output, tensor_args, non_tensor_args, unflatten_args

    # NOTE: [In-Place Collective Safety]
    # Between the initiation and completion of an in-place collective, the
    # input buffers are subject to both volatile reads and volatile writes.
    # They must not be read, written to or reused by another kernel. To ensure
    # these constraints, we model collective -> wait_tensor as a two-step
    # mutation of the input buffers.
    @classmethod
    def create_inplace(
        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
    ) -> None:
        """Create an in-place collective node and record a mutation of each input.

        ``inputs`` may be a single TensorBox or a (nested) pytree of them; every
        leaf gets a ``MutationOutput`` tied to the new node.
        """
        cpp_kernel_name = kernel._name
        python_kernel_name = cpp_kernel_name.replace("::", ".")
        (
            _,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        ) = cls._process_and_realize(kernel, inputs, *args, **kwargs)

        packed = cls(
            NoneLayout(tensor_args[0].get_device()),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        )
        packed.cpp_kernel_name = cpp_kernel_name
        packed.python_kernel_name = python_kernel_name

        def mark_mutation(x):
            # Record the mutation on the underlying buffer rather than on a view.
            if isinstance(x.data, BaseView):
                x = x.data.unwrap_view()
            MutationOutput(x.layout, x, packed)

        pytree.tree_map(mark_mutation, inputs)

    # NOTE: [Out-of-Place Collective Safety]
    # Between the initiation and completion of an out-of-place collective:
    #
    # Input buffers:
    # - Are subject to volatile reads
    # - Can be read by another kernel
    # - Must not be written to or reused by another kernel
    #
    # Output buffers:
    # - Are subject to volatile writes
    # - Must not be read, written to or reused by another kernel
    #
    # To ensure the safety of input buffers without sacrificing read
    # availability, we add input buffers as read deps of wait_tensor kernels.
    #
    # To ensure the safety of output buffers, we model wait_tensor as a
    # mutation to the output buffer. Note that we also assume the user program
    # is correct and the output buffer is not consumed by kernels other than
    # wait_tensor.
    #
    # TODO(yifu): add a pre-grad pass to validate the correctness of collective
    # usage in the user program.
    @classmethod
    def create_out_of_place(
        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
    ):
        """Create an out-of-place collective node.

        Returns a list of ``MultiOutput`` nodes (one per element) when the
        kernel's example output is a list; otherwise returns the node itself.
        """
        cpp_kernel_name = kernel._name
        python_kernel_name = cpp_kernel_name.replace("::", ".")
        (
            example_output,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        ) = cls._process_and_realize(kernel, inputs, *args, **kwargs)

        if isinstance(example_output, list):
            device = cls.find_device(tensor_args, example_output)
            packed = cls(
                MultiOutputLayout(device),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )
            packed.cpp_kernel_name = cpp_kernel_name
            packed.python_kernel_name = python_kernel_name
            packed.outputs = [
                MultiOutput(
                    cls.tensor_to_layout(tensor),
                    packed,
                    [(list, i)],
                )
                for i, tensor in enumerate(example_output)
            ]
            return packed.outputs
        else:
            packed = cls(
                cls.tensor_to_layout(example_output),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )
            packed.cpp_kernel_name = cpp_kernel_name
            packed.python_kernel_name = python_kernel_name
            packed.outputs = [packed]
            return packed
7873

7874

7875
class _WaitKernel(_CollectiveKernel):
    """
    Wait node that completes a collective created by a ``_CollectiveKernel``.

    It is recorded as a mutation of its input. For out-of-place collectives it
    also reads the collective's input buffer (see [Out-of-Place Collective Safety]).
    """

    def get_volatile_reads(self):
        """Return the collective input buffers this wait must be treated as reading."""
        target = self.inputs[0]
        if isinstance(target, _CollectiveKernel):
            # Out-of-place collective with a single output.
            return [target.inputs[0]]
        if isinstance(target, MultiOutput):
            # Out-of-place collective with several outputs: map the output's
            # list index back to the matching collective input.
            collective = target.inputs[0]
            assert isinstance(collective, _CollectiveKernel)
            _, position = target.indices[0]
            return [collective.inputs[position]]
        # In-place collectives need no extra deps for volatile reads because
        # their inputs are already marked as mutated.
        return []

    @classmethod
    def create_wait(cls, kernel, inp: TensorBox) -> None:
        """Create a wait node for ``inp`` and record it as mutating ``inp``'s buffer."""
        with V.graph.fake_mode:
            _, tensor_args, non_tensor_args, unflatten_args = cls.process_kernel(
                kernel, inp
            )
        node = cls(
            NoneLayout(inp.get_device()),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        )
        mutated = inp.data.unwrap_view() if isinstance(inp.data, BaseView) else inp
        MutationOutput(mutated.layout, mutated, node)

    def get_read_writes(self):
        rw = super().get_read_writes()
        # See [Out-of-Place Collective Safety].
        for buf in self.get_volatile_reads():
            rw.reads.add(dependencies.StarDep(buf.get_name()))
        return rw
7919

7920

7921
# NB: recursive structure here reflects val_to_arg_str, avoid
7922
# calling free_unbacked_symbols on "exotic" types that don't get pexpr
7923
# treatment
7924
def maybe_free_unbacked_symbols(s):
7925
    if isinstance(s, (SymTypes, sympy.Expr)):
7926
        # This branch should be impossible in return position
7927
        return free_unbacked_symbols(s)
7928
    elif isinstance(s, (tuple, list)):
7929
        r = set()
7930
        for t in s:
7931
            r |= maybe_free_unbacked_symbols(t)
7932
        return r
7933
    elif isinstance(s, torch.Tensor):
7934
        # This branch is impossible in constant-args position
7935
        return free_unbacked_symbols(s)
7936
    else:
7937
        return set()
7938

7939

7940
class AllToAllSingle(OutOfPlaceCollectiveKernel):
    """
    Out-of-place ``dist.all_to_all_single`` collective.

    ``output_split_sizes`` / ``input_split_sizes`` are either ``None`` or lists
    of (possibly symbolic) sizes. Any unbacked symbols they contain are
    reported as uses of this node.
    """

    def __init__(
        self,
        layout,
        inputs,
        outputs,
        constant_args,
        output_split_sizes,
        input_split_sizes,
    ):
        super().__init__(layout, inputs, outputs, constant_args)
        self.output_split_sizes = output_split_sizes
        self.input_split_sizes = input_split_sizes

    def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]:
        r = set()
        if self.output_split_sizes is not None:
            r |= free_unbacked_symbols(self.output_split_sizes)
        if self.input_split_sizes is not None:
            r |= free_unbacked_symbols(self.input_split_sizes)
        return r

    @classmethod
    def create(
        cls,
        x: "TensorBox",
        output_split_sizes: Optional[List[Expr]],
        input_split_sizes: Optional[List[Expr]],
        tag: str,
        ranks: List[int],
        group_size: int,
    ):
        """Build the all-to-all node and return the node extracting its output."""
        inputs = [cls.realize_input(x)]

        def compute_size(new_size):
            # With explicit output splits, dim 0 of the output is their sum;
            # otherwise the size prepared by create_output_buffers is kept.
            if output_split_sizes is not None:
                new_size[0] = sum(output_split_sizes)

        outputs = cls.create_output_buffers(inputs, compute_size)

        layout = MultiOutputLayout(inputs[0].get_device())

        packed = AllToAllSingle(
            layout=layout,
            inputs=inputs,
            outputs=outputs,
            constant_args=[tag, ranks, group_size],
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
        )
        return cls.create_output_nodes(packed, outputs)[0]

    def codegen_collective(self, wrapper, output_name, input_names):
        # The process group is referenced through ``{output_name}_pg``, so the
        # tag/ranks/group_size stored in constant_args are not needed here.
        # TODO: might be necessary to do some pretty printing on
        # split sizes
        wrapper.writeline(
            f"{output_name}_work = dist.all_to_all_single("
            f"{output_name}[0], {output_name}_inputs[0], "
            f"output_split_sizes={self.output_split_sizes}, "
            f"input_split_sizes={self.input_split_sizes}, "
            f"group={output_name}_pg, async_op=True)"
        )
8004

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.