pytorch

Форк
0
/
_meta_registrations.py 
6636 строк · 198.5 Кб
1
# mypy: allow-untyped-decorators
2
# mypy: allow-untyped-defs
3
import math
4
from enum import Enum
5
from typing import List, Optional, Sequence, Tuple, Union
6

7
import torch
8
import torch._prims_common as utils
9
from torch import SymBool, SymFloat, Tensor
10
from torch._decomp import (
11
    _add_op_to_registry,
12
    _convert_out_params,
13
    global_decomposition_table,
14
    meta_table,
15
)
16
from torch._ops import OpOverload
17
from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
18
from torch._prims_common import (
19
    corresponding_complex_dtype,
20
    corresponding_real_dtype,
21
    elementwise_dtypes,
22
    ELEMENTWISE_TYPE_PROMOTION_KIND,
23
    IntLike,
24
    make_contiguous_strides_for,
25
    Number,
26
    TensorLike,
27
)
28
from torch._prims_common.wrappers import (
29
    _maybe_convert_to_dtype,
30
    _maybe_resize_out,
31
    _resize_output_check,
32
    _safe_copy_out,
33
    out_wrapper,
34
)
35
from torch._refs import _broadcast_shapes, _maybe_broadcast
36
from torch.utils import _pytree as pytree
37

38

39
# Shorthand for the ATen operator namespace; the registrations below refer to ops through it.
aten = torch.ops.aten
40

41
# torch.library handle for "IMPL" registrations in the "aten" namespace under the
# "Meta" dispatch key. As its name says, prefer register_meta() over using it directly.
_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
42

43

44
def register_meta(op):
45
    def wrapper(fn):
46
        fn = _convert_out_params(fn)
47

48
        def register(op):
49
            _add_op_to_registry(meta_table, op, fn)
50

51
        pytree.tree_map_(register, op)
52
        return fn
53

54
    return wrapper
55

56

57
def elementwise_meta(
58
    *args,
59
    type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND,
60
):
61
    # Perform type promotion, as this is expected from prim_metafunction
62
    _, result_dtype = utils.elementwise_dtypes(
63
        *args,
64
        type_promotion_kind=type_promotion,
65
    )
66
    args = [_maybe_convert_to_dtype(x, result_dtype) for x in args]
67

68
    # Broadcast
69
    args = _maybe_broadcast(*args)
70

71
    # Perform prim checks
72
    return _prim_elementwise_meta(
73
        *args, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
74
    )
75

76

77
def toRealValueType(dtype):
    """Return the real counterpart of a complex dtype; any other dtype is returned unchanged.

    complex32 -> half, cfloat -> float, cdouble -> double.
    """
    real_of = dict(
        (
            (torch.complex32, torch.half),
            (torch.cfloat, torch.float),
            (torch.cdouble, torch.double),
        )
    )
    return real_of[dtype] if dtype in real_of else dtype
84

85

86
def check_inplace_broadcast(self_shape, *args_shape):
87
    broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape))
88
    torch._check(
89
        broadcasted_shape == self_shape,
90
        lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}",
91
    )
92

93

94
@register_meta([aten.linspace, aten.logspace])
@out_wrapper()
def meta_linspace_logspace(
    start,
    end,
    steps,
    base=None,
    dtype=None,
    device=None,
    layout=torch.strided,
    pin_memory=False,
    requires_grad=False,
):
    """Meta kernel shared by linspace and logspace.

    Returns an empty 1-D meta tensor with ``steps`` elements. Tensor
    ``start``/``end`` must be 0-dimensional. If any of start/end/steps is a
    Python complex, the dtype defaults to the complex counterpart of the
    default dtype, and an explicit ``dtype`` must itself be complex.
    ``steps`` must be int-like and non-negative. ``base`` and ``device`` are
    not consulted; the result is always on the meta device.
    """
    for bound in (start, end):
        if isinstance(bound, torch.Tensor):
            torch._check(
                bound.dim() == 0,
                lambda: "linspace only supports 0-dimensional start and end tensors",
            )

    has_complex_arg = any(isinstance(value, complex) for value in (start, end, steps))
    if not has_complex_arg:
        dtype = dtype or torch.get_default_dtype()
    else:
        inferred_complex = utils.corresponding_complex_dtype(
            torch.get_default_dtype()
        )
        if dtype is not None:
            torch._check(
                utils.is_complex_dtype(dtype),
                lambda: f"linspace(): inferred dtype {inferred_complex} can't be safely cast to passed dtype {dtype}",
            )
        else:
            dtype = inferred_complex
    assert isinstance(dtype, torch.dtype)

    # steps does not participate in the computation of the dtype
    torch._check_type(
        isinstance(steps, IntLike),
        lambda: (
            "received an invalid combination of arguments - got "
            f"({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})"
        ),
    )
    assert isinstance(steps, IntLike)  # for mypy
    torch._check(steps >= 0, lambda: "number of steps must be non-negative")

    return torch.empty(
        (steps,),  # type: ignore[arg-type]
        dtype=dtype,
        layout=layout,
        device="meta",
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
150

151

152
@register_meta([aten.take.default, aten.take.out])
@out_wrapper()
def meta_take(self, index):
    """Meta kernel for take: the result has ``index``'s shape and ``self``'s dtype.

    ``index`` must be int64, and a non-empty ``index`` cannot gather from an
    empty ``self``.
    """
    index_dtype = index.dtype
    torch._check(
        index_dtype == torch.long,
        lambda: f"take(): Expected a long tensor for index, but got {index_dtype}",
    )
    source_empty = self.numel() == 0
    index_nonempty = index.numel() != 0
    torch._check_index(
        not source_empty or not index_nonempty,
        lambda: "take(): tried to take from an empty tensor",
    )
    return self.new_empty(index.shape)
166

167

168
@register_meta([aten.linalg_cross.default, aten.linalg_cross.out])
@out_wrapper()
def linalg_cross(self, other, *, dim=-1):
    """Meta kernel for linalg.cross.

    Both inputs must have the same rank and length 3 along ``dim``. The result
    takes the broadcast of the two shapes and ``self``'s dtype.
    """
    torch._check(
        self.ndim == other.ndim,
        lambda: "linalg.cross: inputs must have the same number of dimensions.",
    )
    self_len, other_len = self.size(dim), other.size(dim)
    torch._check(
        self_len == 3 and other_len == 3,
        lambda: (
            f"linalg.cross: inputs dimension {dim} must have length 3. "
            f"Got {self_len} and {other_len}"
        ),
    )
    return self.new_empty(_broadcast_shapes(self.shape, other.shape))
186

187

188
@register_meta(aten.linalg_matrix_exp)
@out_wrapper()
def linalg_matrix_exp(self):
    """Meta kernel for linalg.matrix_exp.

    Validates a (batch of) square floating/complex matrices and returns an
    uninitialized contiguous tensor with ``self``'s shape and dtype.
    """
    op_name = "linalg.matrix_exp"
    squareCheckInputs(self, op_name)
    checkFloatingOrComplex(self, op_name)
    return torch.empty_like(self, memory_format=torch.contiguous_format)
194

195

196
@register_meta(
    [aten.cummax.default, aten.cummax.out, aten.cummin.default, aten.cummin.out]
)
@out_wrapper("values", "indices")
def cummaxmin(self, dim):
    """Meta kernel shared by cummax and cummin.

    Returns ``(values, indices)``, both with ``self``'s shape and device.
    Values keep ``self``'s dtype and indices are int64. ``dim`` is
    bounds-checked only when ``self`` is non-empty and not 0-dimensional.
    """
    shape, device = self.shape, self.device
    values = torch.empty(shape, device=device, dtype=self.dtype)
    indices = torch.empty(shape, device=device, dtype=torch.int64)
    if self.numel() != 0 and self.ndim != 0:
        maybe_wrap_dim(dim, self.ndim)  # bounds check on dim only; result unused
    return values, indices
207

208

209
@register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out])
@out_wrapper()
def logcumsumexp(self, dim):
    """Meta kernel for logcumsumexp: bounds-check ``dim``, then return a contiguous tensor like ``self``."""
    maybe_wrap_dim(dim, self.ndim)  # validation only; the wrapped value is unused
    result = torch.empty_like(self)
    return result.contiguous()
215

216

217
# Stride-related code from _exec_fft in aten/src/ATen/native/cuda/SpectralOps.cpp
218
def _exec_fft(out, self, out_sizes, dim, forward):
219
    ndim = self.ndim
220
    signal_ndim = len(dim)
221
    batch_dims = ndim - signal_ndim
222

223
    # Permute dimensions so batch dimensions come first, and in stride order
224
    dim_permute = list(range(ndim))
225

226
    is_transformed_dim = [False for _ in range(ndim)]
227
    for d in dim:
228
        is_transformed_dim[d] = True
229

230
    # std::partition
231
    left, right = [], []
232
    for d in dim_permute:
233
        if not is_transformed_dim[d]:
234
            left.append(d)
235
        else:
236
            right.append(d)
237
    dim_permute = left + right
238
    batch_end = len(left)
239

240
    self_strides = self.stride()
241
    tmp = dim_permute[:batch_end]
242
    tmp.sort(key=lambda x: self_strides[x], reverse=True)
243
    dim_permute = tmp + dim_permute[batch_end:]
244
    input = self.permute(dim_permute)
245

246
    # Collapse batch dimensions into a single dimension
247
    batched_sizes = [-1] + list(input.shape[batch_dims:])
248
    input = input.reshape(batched_sizes)
249

250
    batch_size = input.size(0)
251
    batched_sizes[0] = batch_size
252
    batched_out_sizes = batched_sizes
253
    for i in range(len(dim)):
254
        batched_out_sizes[i + 1] = out_sizes[dim[i]]
255
    out = out.reshape(batched_out_sizes)
256

257
    # Reshaping to original batch shape and inverting the dimension permutation
258
    out_strides = [0 for _ in range(ndim)]
259
    batch_numel = 1
260
    i = batch_dims - 1
261
    while i >= 0:
262
        out_strides[dim_permute[i]] = batch_numel * out.stride(0)
263
        batch_numel *= out_sizes[dim_permute[i]]
264
        i -= 1
265
    for i in range(batch_dims, ndim):
266
        out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
267
    return out.as_strided(out_sizes, out_strides, out.storage_offset())
268

269

270
# See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
271
# and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp
272
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
@out_wrapper()
def meta_fft_c2c(self, dim, normalization, forward):
    """Meta kernel for complex-to-complex FFT.

    The output has ``self``'s shape and complex dtype. When ``dim`` is
    non-empty, its strides come from ``_exec_fft``, with the transformed dims
    ordered by decreasing input stride. ``normalization`` does not affect the
    metadata.
    """
    assert self.dtype.is_complex

    out_sizes = self.shape
    output = self.new_empty(out_sizes)

    if not dim:
        return output

    # sorted() accepts any sequence, so ``dim`` may be a list or a tuple.
    # The previous ``dim[:]`` + ``.sort()`` only worked when ``dim`` was a list.
    self_strides = self.stride()
    sorted_dims = sorted(dim, key=lambda d: self_strides[d], reverse=True)
    return _exec_fft(output, self, out_sizes, sorted_dims, forward)
289

290

291
@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
    """Meta kernel for real-to-complex FFT.

    The output uses the complex counterpart of ``self``'s dtype. With
    ``onesided``, the last transformed dim shrinks from ``n`` to ``n // 2 + 1``.
    """
    assert self.dtype.is_floating_point
    sizes = list(self.size())
    if onesided:
        halved_dim = dim[-1]
        sizes[halved_dim] = sizes[halved_dim] // 2 + 1
    complex_dtype = utils.corresponding_complex_dtype(self.dtype)
    return self.new_empty(sizes, dtype=complex_dtype)
305

306

307
@register_meta(aten.randperm.generator_out)
def meta_randperm(n, *, generator=None, out):
    """Meta kernel for randperm.generator_out: resize ``out`` to hold ``n`` elements."""
    target_shape = torch.Size([n])
    return _maybe_resize_out(out, target_shape)
310

311

312
@register_meta(aten.randperm.default)
def meta_randperm_default(
    n,
    *,
    dtype=torch.long,
    layout=None,
    device=None,
    pin_memory=None,
):
    """Meta kernel for randperm: an uninitialized tensor of length ``n`` with the given factory options."""
    factory_kwargs = dict(
        dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
    return torch.empty(n, **factory_kwargs)
324

325

326
@register_meta([aten.randint.default, aten.randint.out])
@out_wrapper()
def meta_randint(
    high,
    size,
    *,
    dtype=torch.long,
    layout=None,
    device=None,
    pin_memory=None,
):
    """Meta kernel for randint(high, size).

    Only ``size`` and the factory options shape the result; ``high`` is unused.
    """
    factory_kwargs = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return torch.empty(size, **factory_kwargs)
340

341

342
@register_meta([aten.randint.low, aten.randint.low_out])
@out_wrapper()
def meta_randint_low(
    low,
    high,
    size,
    *,
    dtype=torch.long,
    layout=None,
    device=None,
    pin_memory=None,
):
    """Meta kernel for randint(low, high, size).

    Only ``size`` and the factory options shape the result; ``low`` and
    ``high`` are unused.
    """
    factory_kwargs = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return torch.empty(size, **factory_kwargs)
357

358

359
@register_meta([aten.rand.default, aten.rand.out])
@out_wrapper()
def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Meta kernel for rand: an uninitialized tensor of ``size`` with the given factory options."""
    factory_kwargs = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return torch.empty(size, **factory_kwargs)
365

366

367
@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self, dim, normalization, lastdim):
    """Meta kernel for complex-to-real FFT.

    The last transformed dim is set to ``lastdim``, and the dtype becomes the
    real counterpart of ``self``'s complex dtype.
    """
    assert self.dtype.is_complex
    real_dtype = toRealValueType(self.dtype)
    sizes = list(self.size())
    sizes[dim[-1]] = lastdim
    return self.new_empty(sizes, dtype=real_dtype)
374

375

376
@register_meta(aten.copy_.default)
def meta_copy_(self, src, non_blocking=False):
    """Meta kernel for copy_: run the checks a real copy would perform, then return ``self``.

    This simulates the original decomp from inductor, which runs most of the
    meta checks that we care about. In theory it should be made more robust by
    auditing the C++ copy_() kernel and porting its checks here.
    """
    from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

    # TODO: Ideally, we'd insert a deferred runtime assert here, but if we are
    # calling an actual copy_, you'll get that automatically
    # https://github.com/pytorch/pytorch/issues/122477
    # The overlap query is skipped when ``self`` has unbacked symbols.
    # A result of 1 means MemOverlap::Yes.
    overlapping = (
        not free_unbacked_symbols(self)
        and torch._debug_has_internal_overlap(self) == 1
    )
    if overlapping:
        raise RuntimeError(
            "more than one element of the written-to tensor refers to a single memory location"
        )

    if not isinstance(src, Tensor):
        return self
    converted = src.to(self, non_blocking)
    if self.size() != converted.size():
        # Called only for its checks; the expanded result is discarded.
        aten.expand_copy.default(converted, self.size())
    return self
399

400

401
def inferUnsqueezeGeometry(tensor, dim):
    """Return the ``(sizes, strides)`` lists ``tensor`` would have after unsqueezing at ``dim``.

    The inserted size-1 dim gets stride 1 when it is appended past the last
    dim (``dim >= tensor.dim()``). Otherwise it gets ``size[dim] * stride[dim]``.
    NOTE(review): the caller in this file (meta_unsqueeze_) wraps ``dim`` to a
    non-negative index first; negative ``dim`` is not handled specially here.
    """
    sizes = list(tensor.size())
    strides = list(tensor.stride())
    if dim < tensor.dim():
        inserted_stride = sizes[dim] * strides[dim]
    else:
        inserted_stride = 1
    sizes.insert(dim, 1)
    strides.insert(dim, inserted_stride)
    return sizes, strides
408

409

410
@register_meta(aten.unsqueeze_.default)
def meta_unsqueeze_(self, dim):
    """Meta kernel for in-place unsqueeze_: restride ``self`` in place with a new size-1 dim at ``dim``."""
    wrapped = maybe_wrap_dim(dim, self.dim() + 1)
    new_sizes, new_strides = inferUnsqueezeGeometry(self, wrapped)
    self.as_strided_(new_sizes, new_strides)
    return self
416

417

418
@register_meta(aten._sparse_semi_structured_linear)
def meta_sparse_structured_linear(
    input: Tensor,
    weight: Tensor,
    _meta: Tensor,
    bias: Optional[Tensor] = None,
    _activation_opt: Optional[str] = None,
    out_dtype: Optional[torch.dtype] = None,
):
    """Meta kernel for _sparse_semi_structured_linear.

    ``weight`` holds half of the reduction dim
    (``weight.size(1) * 2 == input.size(-1)``). Only the squashed 2-D
    ``input`` case is handled. The result has shape
    ``(input.size(0), weight.size(0))`` with transposed strides
    ``(1, input.size(0))``. ``out_dtype`` is only accepted for int8 input
    producing int32.
    """
    output_sizes = list(input.shape)
    if bias is not None:
        assert weight.size(0) == bias.size(0), "output size mismatch"
    # Integer form of ``weight.size(1) == input.size(-1) / 2``. True division
    # yields a float (a SymFloat for symbolic sizes); this stays integral.
    assert weight.size(1) * 2 == input.size(-1)
    output_sizes[-1] = weight.size(0)

    # see: https://github.com/pytorch/pytorch/pull/114477#issuecomment-1830121375
    # We assume that we have already squashed the inputs into a 2-D tensor
    # Then, as the output is transposed, we need to propagate the transposed
    # stride information to the output tensor
    assert len(input.shape) == 2, "we can only handle the squashed input case"
    transposed_strides = (1, input.size(0))

    if out_dtype is not None:
        assert (
            input.dtype == torch.int8 and out_dtype == torch.int32
        ), "out_dtype is only supported for i8i8->i32 linear operator"
    output = input.new_empty(
        output_sizes,
        dtype=input.dtype if out_dtype is None else out_dtype,
    ).as_strided(output_sizes, transposed_strides)

    return output
450

451

452
@register_meta(aten._sparse_semi_structured_mm)
def meta_sparse_structured_mm(
    mat1: Tensor,
    mat1_meta: Tensor,
    mat2: Tensor,
    out_dtype: Optional[torch.dtype] = None,
):
    """Meta kernel for _sparse_semi_structured_mm.

    All operands must be 2-D, and ``mat1``'s second dim must be half of
    ``mat2``'s first dim. The result is ``(mat1.size(0), mat2.size(1))``.
    ``out_dtype`` is only accepted for int8 ``mat2`` producing int32.
    """
    assert len(mat1.shape) == 2
    assert len(mat1_meta.shape) == 2
    assert len(mat2.shape) == 2
    # Integer form of ``mat1.size(1) == mat2.size(0) / 2``. This avoids true
    # division, which yields a float (a SymFloat for symbolic sizes).
    assert mat1.size(1) * 2 == mat2.size(0)
    output_sizes = [mat1.size(0), mat2.size(1)]

    if out_dtype is not None:
        assert (
            mat2.dtype == torch.int8 and out_dtype == torch.int32
        ), "out_dtype is only supported for i8i8->i32 linear operator"
    output = mat2.new_empty(
        output_sizes,
        dtype=mat2.dtype if out_dtype is None else out_dtype,
    )

    return output
475

476

477
@register_meta(aten._sparse_semi_structured_addmm)
def meta_sparse_structured_addmm(
    input: Tensor,
    mat1: Tensor,
    mat1_meta: Tensor,
    mat2: Tensor,
    *,
    alpha=1,
    beta=1,
    out_dtype: Optional[torch.dtype] = None,
):
    """Meta kernel for _sparse_semi_structured_addmm.

    ``input`` must be 1-D with ``mat1.size(0)`` elements. The matrices must be
    2-D, and ``mat1``'s second dim must be half of ``mat2``'s first dim. The
    result is ``(mat1.size(0), mat2.size(1))``. ``alpha``/``beta`` do not
    affect the metadata. ``out_dtype`` is only accepted for int8 ``mat2``
    producing int32.
    """
    assert (
        len(input.shape) == 1
    ), "only input broadcasted to columns of mat1 * mat2 product is supported"
    assert len(mat1.shape) == 2
    assert len(mat1_meta.shape) == 2
    assert len(mat2.shape) == 2
    assert input.size(0) == mat1.size(
        0
    ), "only input broadcasted to columns of mat1 * mat2 product is supported"
    # Integer form of ``mat1.size(1) == mat2.size(0) / 2``. This avoids true
    # division, which yields a float (a SymFloat for symbolic sizes).
    assert mat1.size(1) * 2 == mat2.size(0)
    output_sizes = [mat1.size(0), mat2.size(1)]

    if out_dtype is not None:
        assert (
            mat2.dtype == torch.int8 and out_dtype == torch.int32
        ), "out_dtype is only supported for i8i8->i32 linear operator"
    output = mat2.new_empty(
        output_sizes,
        dtype=mat2.dtype if out_dtype is None else out_dtype,
    )

    return output
510

511

512
@register_meta(aten._cslt_sparse_mm)
def meta__cslt_sparse_mm(
    compressed_A: torch.Tensor,
    dense_B: torch.Tensor,
    bias: Optional[Tensor] = None,
    alpha: Optional[Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    transpose_result: bool = False,
):
    """Meta kernel for _cslt_sparse_mm.

    ``m`` is recovered from ``compressed_A.numel()`` and ``k = dense_B.size(0)``
    as ``numel * 16 // (factor * k)``, with factor 10 for int8 inputs and 9
    otherwise. The result is ``(m, n)``, or ``(n, m)`` with
    ``transpose_result``. Its dtype is ``out_dtype`` if given, else
    ``dense_B``'s dtype.
    """
    # The message now lists every accepted dtype; float32 was accepted but
    # previously missing from the text.
    assert dense_B.dtype in {
        torch.float32,
        torch.float16,
        torch.bfloat16,
        torch.int8,
    }, "_cslt_sparse_mm only supports fp32, fp16, bf16, and int8"
    assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
    assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"

    is_int8_input_type = compressed_A.dtype == torch.int8
    compression_factor = 10 if is_int8_input_type else 9
    k = dense_B.size(0)
    n = dense_B.size(1)
    m = (compressed_A.numel() * 16) // (compression_factor * k)
    if bias is not None:
        assert m == bias.size(0)

    if out_dtype is not None:
        assert is_int8_input_type and out_dtype in {
            torch.float16,
            torch.bfloat16,
            torch.int32,
        }, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul"
    output_shape = (n, m) if transpose_result else (m, n)
    result = dense_B.new_empty(output_shape, dtype=out_dtype)
    return result
547

548

549
@register_meta(aten.index_reduce.default)
def meta_index_reduce(
    self: Tensor,
    dim: int,
    index: Tensor,
    source: torch.Tensor,
    reduce: str,
    *,
    include_self: bool = True,
) -> Tensor:
    """Meta kernel for out-of-place index_reduce.

    Returns a fresh contiguous tensor with ``self``'s shape and dtype. No
    argument validation is done here.
    """
    result = torch.empty_like(self, memory_format=torch.contiguous_format)
    return result
560

561

562
@register_meta(aten.index_reduce_.default)
def meta_index_reduce_(
    self: Tensor,
    dim: int,
    index: Tensor,
    source: torch.Tensor,
    reduce: str,
    *,
    include_self: bool = True,
) -> Tensor:
    """Meta kernel for in-place index_reduce_: return ``self`` as-is, with no validation."""
    modified_in_place = self
    return modified_in_place
573

574

575
# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
576
@out_wrapper()
@register_meta(aten.index_select.default)
def meta_index_select(self, dim, index):
    """Meta kernel for index_select: ``self``'s shape with ``dim`` resized to ``index.numel()``.

    A 0-dimensional ``self`` keeps its (empty) shape. Note the decorator order:
    register_meta runs first, so the kernel registered for
    aten.index_select.default is not wrapped by out_wrapper. Only the
    module-level name is.
    """
    shape = list(self.size())
    if self.dim() == 0:
        return self.new_empty(shape)
    shape[dim] = index.numel()
    return self.new_empty(shape)
583

584

585
@register_meta(aten.segment_reduce.default)
def meta_segment_reduce(
    data: Tensor,
    reduce: str,
    *,
    lengths: Optional[Tensor] = None,
    indices: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
    axis: int = 0,
    unsafe: bool = False,
    initial=None,
) -> Tensor:
    """Meta kernel for segment_reduce.

    The result shape is the per-segment shape (``lengths.shape``, or
    ``offsets.shape`` with its last dim reduced by one), followed by the dims
    of ``data`` after ``axis``. ``axis`` may be negative.

    Raises:
        NotImplementedError: for ``indices``-based reduction.
        RuntimeError: when neither ``lengths`` nor ``offsets`` is given.
    """
    if indices is not None:
        raise NotImplementedError(
            "segment_reduce(): indices based reduction is not supported yet."
        )

    def segment_reduce_lengths_tensor(lengths_shape):
        # Normalize a negative axis first. With axis=-1, the raw slice
        # ``data.shape[0:]`` would wrongly keep every dim of ``data``.
        wrapped_axis = utils.canonicalize_dim(data.ndim, axis)
        return torch.empty(
            lengths_shape + data.shape[wrapped_axis + 1 :],
            dtype=data.dtype,
            device="meta",
            memory_format=torch.contiguous_format,
        )

    if lengths is not None:
        return segment_reduce_lengths_tensor(lengths.shape)
    # FIXME should probably check that lengths and offset aren't both set, but
    # the ATen implementation neglects this too
    if offsets is not None:
        # lengths == torch.diff(offsets)
        lengths_shape = offsets.shape[:-1] + (offsets.shape[-1] - 1,)
        return segment_reduce_lengths_tensor(lengths_shape)
    raise RuntimeError("segment_reduce(): Either lengths or offsets must be defined.")
619

620

621
@register_meta([aten.max.default, aten.max.unary_out])
@out_wrapper()
def meta_max(self):
    """Meta kernel for full-reduction max: a 0-dim tensor with ``self``'s dtype and device."""
    scalar_shape = ()
    return self.new_empty(scalar_shape)
625

626

627
@register_meta(aten.max.dim)
def meta_max_dim(self, dim, keepdim=False):
    """Meta kernel for max along ``dim``; returns ``(values, indices)``.

    Both outputs share the reduced shape (dim kept as size 1 when ``keepdim``).
    Values keep ``self``'s dtype and indices are torch.long.
    """
    reduced = utils.reduction_dims(self.shape, (dim,))
    out_shape = _compute_reduction_shape(self, reduced, keepdim)
    values = self.new_empty(out_shape)
    indices = self.new_empty(out_shape, dtype=torch.long)
    return values, indices
635

636

637
@register_meta([aten.min.default, aten.min.unary_out])
@out_wrapper()
def meta_min(self):
    """Meta kernel for full-reduction min: a 0-dim tensor with ``self``'s dtype and device."""
    scalar_shape = ()
    return self.new_empty(scalar_shape)
641

642

643
@register_meta(aten.min.dim)
def meta_min_dim(self, dim, keepdim=False):
    """Meta kernel for min along ``dim``; returns ``(values, indices)``.

    Both outputs share the reduced shape (dim kept as size 1 when ``keepdim``).
    Values keep ``self``'s dtype and indices are torch.long.
    """
    reduced = utils.reduction_dims(self.shape, (dim,))
    out_shape = _compute_reduction_shape(self, reduced, keepdim)
    values = self.new_empty(out_shape)
    indices = self.new_empty(out_shape, dtype=torch.long)
    return values, indices
651

652

653
@register_meta(aten.angle.default)
def meta_angle(self):
    """Meta kernel for angle.

    Complex inputs produce their real counterpart dtype. Other inputs follow
    INT_TO_FLOAT elementwise promotion.
    """
    result_dtype = (
        corresponding_real_dtype(self.dtype)
        if self.is_complex()
        else elementwise_dtypes(
            self,
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        )[1]
    )
    return torch.empty_like(self, dtype=result_dtype)
663

664

665
@register_meta(aten.angle.out)
def meta_angle_out(self, out):
    """Meta kernel for angle.out: resize ``out`` to ``self``'s size, then copy the angle result into it."""
    torch._resize_output_(out, self.size(), self.device)
    angles = torch.angle(self)
    return out.copy_(angles)
669

670

671
@register_meta(aten._assert_async.default)
def assert_async(val):
    """Meta kernel for _assert_async: a no-op that returns None."""
    return None
674

675

676
@register_meta(aten._assert_async.msg)
def assert_async_meta(val, assert_msg):
    """Meta kernel for _assert_async.msg: a no-op that returns None."""
    return None
679

680

681
@register_meta(aten._print.default)
def print_meta(s):
    """Meta kernel for _print: prints nothing and returns None."""
    return None
684

685

686
@register_meta(aten._make_dep_token.default)
def make_dep_token(
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Meta kernel for _make_dep_token: an empty 1-D meta tensor.

    All factory options are accepted but ignored.
    """
    token = torch.empty(0, device="meta")
    return token
696

697

698
@register_meta(aten.sym_constrain_range.default)
def sym_constrain_range(size, min=None, max=None):
    """Meta kernel for sym_constrain_range: apply ``min``/``max`` bounds to ``size``.

    SymFloat and SymBool values are rejected with ValueError. Returns None.
    """
    # Imported here to avoid importing sympy at module level.
    from torch.fx.experimental.symbolic_shapes import constrain_range

    unsupported = (SymFloat, SymBool)
    if isinstance(size, unsupported):
        raise ValueError("Constraining SymFloat or Symbool is nyi")
    constrain_range(size, min=min, max=max)
706

707

708
@register_meta(aten._functional_sym_constrain_range.default)
def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):
    """Functional variant of sym_constrain_range: apply the constraint, then pass ``dep_token`` through."""
    bounds = {"min": min, "max": max}
    aten.sym_constrain_range(size, **bounds)
    return dep_token
712

713

714
@register_meta(aten.sym_constrain_range_for_size.default)
def sym_constrain_range_for_size(size, min=None, max=None):
    """Meta kernel for sym_constrain_range_for_size: constrain ``size`` as a size-like value.

    SymFloat and SymBool values are rejected with ValueError. Returns None.
    """
    # Imported here to avoid importing sympy at module level.
    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    unsupported = (SymFloat, SymBool)
    if isinstance(size, unsupported):
        raise ValueError("Constraining SymFloat or Symbool is nyi")
    _constrain_range_for_size(size, min=min, max=max)
722

723

724
@register_meta(aten._functional_sym_constrain_range_for_size.default)
def functional_sym_constrain_range_for_size(size, min, max, dep_token):
    """Functional variant of sym_constrain_range_for_size: apply the constraint, then pass ``dep_token`` through."""
    bounds = {"min": min, "max": max}
    aten.sym_constrain_range_for_size(size, **bounds)
    return dep_token
728

729

730
@register_meta(aten._functional_assert_async.msg)
def functional_assert_async_meta(val, assert_msg, dep_token):
    """Meta kernel for _functional_assert_async.msg: nothing is checked; ``dep_token`` is returned."""
    passthrough = dep_token
    return passthrough
733

734

735
# From aten/src/ATen/native/LinearAlgebraUtils.h
736
def squareCheckInputs(self: Tensor, f_name: str):
    """Validate that ``self`` is a (batch of) square matrices.

    Raises:
        AssertionError: if ``self`` has fewer than 2 dims, or if its last two
            dims differ. The error is raised explicitly rather than via
            ``assert``, so the checks still run under ``python -O`` while the
            exception type and messages stay the same.
    """
    if self.dim() < 2:
        raise AssertionError(
            f"{f_name}: The input tensor must have at least 2 dimensions."
        )
    if not (self.size(-1) == self.size(-2)):
        raise AssertionError(
            f"{f_name}: A must be batches of square matrices, "
            f"but they are {self.size(-2)} by {self.size(-1)} matrices"
        )
743

744

745
# Validates input shapes and devices
746
# for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
747
# From aten/src/ATen/native/LinearAlgebraUtils.h
748
def linearSolveCheckInputs(self: Tensor, A: Tensor, name: str):
749
    torch._check(
750
        self.device == A.device,
751
        lambda: (
752
            f"Expected b and A to be on the same device, but found b on "
753
            f"{self.device} and A on {A.device} instead."
754
        ),
755
    )
756

757
    torch._check(
758
        self.dtype == A.dtype,
759
        lambda: (
760
            f"Expected b and A to have the same dtype, but found b of type "
761
            f"{self.dtype} and A of type {A.dtype} instead."
762
        ),
763
    )
764

765
    torch._check(
766
        A.size(-1) == A.size(-2),
767
        lambda: (
768
            f"A must be batches of square matrices, "
769
            f"but they are {A.size(-2)} by {A.size(-1)} matrices"
770
        ),
771
    )
772

773
    torch._check(
774
        A.size(-1) == self.size(-2),
775
        lambda: (
776
            f"Incompatible matrix sizes for {name}: each A "
777
            f"matrix is {A.size(-1)} by {A.size(-1)}"
778
            f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
779
        ),
780
    )
781

782

783
# From aten/src/ATen/native/LinearAlgebraUtils.h
784
def checkFloatingOrComplex(
785
    t: Tensor,
786
    f_name: str,
787
    allow_low_precision_dtypes: bool = True,
788
):
789
    dtype = t.dtype
790
    torch._check(
791
        t.is_floating_point() or t.is_complex(),
792
        lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
793
    )
794
    if not allow_low_precision_dtypes:
795
        torch._check(
796
            dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
797
            lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
798
        )
799

800

801
# From aten/src/ATen/native/LinearAlgebraUtils.h
802
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
803
    torch._check(
804
        A.dim() >= 2,
805
        lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
806
    )
807

808

809
def checkInputsSolver(A: Tensor, B: Tensor, left: bool, f_name: str):
    """Validate operands of ``AX = B`` (``left=True``) or ``XA = B`` (``left=False``).

    ``A`` must be square, and ``B`` must have at least 2 dims. They must agree
    on ``size(-2)`` when ``left`` is set, or on ``size(-1)`` otherwise.
    """
    squareCheckInputs(A, f_name)
    checkIsMatrix(B, f_name)
    if left:
        shapes_compatible = A.size(-2) == B.size(-2)
    else:
        shapes_compatible = A.size(-1) == B.size(-1)
    equation = "AX = B" if left else "XA = B"
    torch._check(
        shapes_compatible,
        lambda: (
            f"{f_name}: Incompatible shapes of A and B for the equation "
            f"{equation}"
            f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
        ),
    )
820

821

822
def checkSameDevice(
823
    fn_name: str,
824
    result: Tensor,
825
    input: Tensor,
826
    result_name: str = "result",
827
):
828
    torch._check(
829
        result.device == input.device,
830
        lambda: (
831
            f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got "
832
            f"{result_name} on {result.device} and input on {input.device}"
833
        ),
834
    )
835

836

837
def checkUplo(UPLO: str):
838
    UPLO_uppercase = UPLO.upper()
839
    torch._check(
840
        len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"),
841
        lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}",
842
    )
843

844

845
@register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues])
@out_wrapper("eigenvalues", "eigenvectors")
def meta__linalg_eigh(A: Tensor, UPLO: str = "L", compute_v: bool = True):
    """Meta kernel for _linalg_eigh; returns ``(eigenvalues, eigenvectors)``.

    Eigenvalues drop the last dim of ``A`` and use its real dtype. With
    ``compute_v``, eigenvectors take ``A``'s shape with column-major
    (``row_major=False``) strides; otherwise they are an empty 1-D tensor.
    """
    squareCheckInputs(A, "linalg.eigh")
    checkUplo(UPLO)

    full_shape = list(A.shape)
    if not compute_v:
        vecs = A.new_empty([0])
    else:
        vecs = A.new_empty(full_shape)
        col_major = make_contiguous_strides_for(full_shape, row_major=False)
        vecs.as_strided_(full_shape, col_major)

    vals = A.new_empty(full_shape[:-1], dtype=toRealValueType(A.dtype))
    return vals, vecs
862

863

864
@register_meta([aten._linalg_eigvals.default, aten.linalg_eigvals.out])
@out_wrapper()
def meta__linalg_eigvals(input: Tensor) -> Tensor:
    """Meta kernel for linalg.eigvals.

    Returns a complex tensor shaped ``input.shape[:-1]``. Real inputs map to
    their complex counterpart dtype.
    """
    squareCheckInputs(input, "linalg.eigvals")
    dtype = input.dtype
    if not utils.is_complex_dtype(dtype):
        dtype = utils.corresponding_complex_dtype(dtype)
    return input.new_empty(input.shape[:-1], dtype=dtype)
874

875

876
@register_meta([aten.linalg_eig])
@out_wrapper("eigenvalues", "eigenvectors")
def meta_linalg_eig(input: Tensor):
    """Meta kernel for linalg.eig; returns complex ``(eigenvalues, eigenvectors)``.

    Eigenvalues are shaped ``input.shape[:-1]`` and eigenvectors
    ``input.shape``. Real inputs map to their complex counterpart dtype.
    """
    squareCheckInputs(input, "linalg.eig")
    dtype = input.dtype
    if not utils.is_complex_dtype(dtype):
        dtype = utils.corresponding_complex_dtype(dtype)
    values = input.new_empty(input.shape[:-1], dtype=dtype)
    vectors = input.new_empty(input.shape, dtype=dtype)
    return values, vectors
888

889

890
def cloneBatchedColumnMajor(src: Tensor) -> Tensor:
    """Return a copy of ``src`` whose matrices (last two dims) are stored column-major.

    The matrix transpose is cloned contiguously and then transposed back. The
    result therefore has ``src``'s shape, with strides ``1`` and ``rows`` on
    the last two dims.
    """
    transposed_copy = src.mT.clone(memory_format=torch.contiguous_format)
    return transposed_copy.transpose(-2, -1)
892

893

894
@register_meta(aten._cholesky_solve_helper)
@out_wrapper()
def _cholesky_solve_helper(self: Tensor, A: Tensor, upper: bool) -> Tensor:
    """Meta kernel for _cholesky_solve_helper: a column-major clone of ``self``.

    ``A`` and ``upper`` are not consulted here.
    """
    # Same layout trick as cloneBatchedColumnMajor: clone the matrix transpose
    # contiguously, then transpose back.
    return self.mT.clone(memory_format=torch.contiguous_format).transpose(-2, -1)
898

899

900
@register_meta(aten.cholesky_solve)
@out_wrapper()
def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor:
    """Meta kernel for cholesky_solve.

    Both ``self`` (b) and ``A`` (u) need at least 2 dims. Their batch dims are
    broadcast, and the broadcast pair is forwarded to _cholesky_solve_helper.
    """
    for operand, label in ((self, "b"), (A, "u")):
        # torch._check evaluates the message immediately on failure, so the
        # loop variables are still the current operand/label.
        torch._check(
            operand.ndim >= 2,
            lambda: f"{label} should have at least 2 dimensions, but has {operand.ndim} dimensions instead",
        )
    b_bcast, A_bcast = _linalg_broadcast_batch_dims_name(self, A, "cholesky_solve")
    return _cholesky_solve_helper(b_bcast, A_bcast, upper)
915

916

917
@register_meta(aten.cholesky)
@out_wrapper()
def cholesky(self: Tensor, upper: bool = False) -> Tensor:
    """Meta kernel for cholesky.

    An empty ``self`` short-circuits to ``empty_like`` in legacy-contiguous
    format without validation. Otherwise ``self`` must be square, and the
    result is a column-major clone of it.
    """
    if self.numel() == 0:
        empty_result = torch.empty_like(
            self, memory_format=torch.legacy_contiguous_format
        )
        return empty_result
    squareCheckInputs(self, "cholesky")
    return cloneBatchedColumnMajor(self)
924

925

926
@register_meta(aten.cholesky_inverse)
@out_wrapper()
def cholesky_inverse(self: Tensor, upper: bool = False) -> Tensor:
    """Meta kernel for cholesky_inverse: validate a square ``self`` and return a column-major clone."""
    op_name = "cholesky_inverse"
    squareCheckInputs(self, op_name)
    result = cloneBatchedColumnMajor(self)
    return result
931

932

933
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_cholesky_ex.default)
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
    """Meta kernel for ``torch.linalg.cholesky_ex``.

    Returns ``(L, infos)``:
    - ``L`` has ``A``'s shape with column-major strides.
    - ``infos`` is an int32 tensor over ``A``'s batch dimensions (all but
      the last two).
    """
    squareCheckInputs(A, "linalg.cholesky")
    checkFloatingOrComplex(A, "linalg.cholesky")

    shape = A.shape

    factor = A.new_empty(shape)
    factor.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))

    infos = A.new_empty(shape[:-2], dtype=torch.int32)
    return factor, infos
950

951

952
@register_meta(
    [aten.linalg_householder_product.default, aten.linalg_householder_product.out]
)
@out_wrapper()
def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor:
    """Meta kernel for ``torch.linalg.householder_product``.

    Checks are applied in this order:
    - ``input`` is at least 2-D with ``shape[-2] >= shape[-1]``.
    - ``tau.shape[-1]`` does not exceed ``input.shape[-1]``.
    - ``tau`` has exactly one dimension fewer than ``input``, with matching
      batch dimensions.
    - ``tau`` and ``input`` share a dtype and device.

    The result has ``input``'s shape, dtype and device with column-major
    strides.
    """
    torch._check(
        input.ndim >= 2,
        lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.",
    )
    n_rows = input.size(-2)
    n_cols = input.size(-1)
    torch._check(
        n_rows >= n_cols,
        lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]",
    )
    torch._check(
        n_cols >= tau.size(-1),
        lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]",
    )

    torch._check(
        input.ndim - tau.ndim == 1,
        lambda: (
            f"torch.linalg.householder_product: Expected tau to have one dimension less than input, "
            f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
        ),
    )
    if input.ndim > 2:
        batch_expected = input.shape[:-2]
        batch_tau = tau.shape[:-1]
        torch._check(
            batch_tau == batch_expected,
            lambda: (
                f"torch.linalg.householder_product: Expected batch dimensions of tau to be "
                f"equal to input.shape[:-2], but got {batch_tau}"
            ),
        )

    torch._check(
        tau.dtype == input.dtype,
        lambda: (
            f"torch.linalg.householder_product: tau dtype {tau.dtype}"
            f" does not match input dtype {input.dtype}"
        ),
    )
    checkSameDevice("torch.linalg.householder_product", tau, input, "tau")

    shape = input.shape
    return torch.empty_strided(
        size=shape,
        stride=make_contiguous_strides_for(shape, row_major=False),
        dtype=input.dtype,
        device=input.device,
    )
1003

1004

1005
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_inv_ex.default)
def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
    """Meta kernel for ``torch.linalg.inv_ex``.

    Returns ``(inverse, info)``:
    - ``inverse`` has ``A``'s shape with column-major strides.
    - ``info`` is an int32 tensor over the batch dimensions.

    The dtype check is called with ``allow_low_precision_dtypes=False``.
    """
    squareCheckInputs(A, "linalg.inv_ex")
    checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)

    shape = A.shape
    inverse = A.new_empty(shape)
    inverse.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))

    info = A.new_empty(shape[:-2], dtype=torch.int32)
    return inverse, info
1016

1017

1018
@register_meta([aten.linalg_ldl_factor_ex.default, aten.linalg_ldl_factor_ex.out])
@out_wrapper("LD", "pivots", "info")
def linalg_ldl_factor_ex_meta(
    self: Tensor,
    *,
    hermitian: bool = False,
    check_errors: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Meta kernel for ``torch.linalg.ldl_factor_ex``.

    Returns ``(LD, pivots, info)``:
    - ``LD`` matches ``self`` in shape, dtype and device, with column-major
      strides.
    - ``pivots`` drops the last dimension of ``self``.
    - ``info`` drops the last two dimensions.

    Both ``pivots`` and ``info`` use ``torch.int``.
    """
    op_name = "torch.linalg.ldl_factor_ex"
    squareCheckInputs(self, op_name)
    checkFloatingOrComplex(self, op_name)

    shape = self.shape
    factor = torch.empty_strided(
        size=shape,
        stride=make_contiguous_strides_for(shape, row_major=False),
        dtype=self.dtype,
        device=self.device,
    )
    pivots = self.new_empty(shape[:-1], dtype=torch.int)
    info = self.new_empty(shape[:-2], dtype=torch.int)
    return factor, pivots, info
1037

1038

1039
@register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out])
@out_wrapper()
def linalg_ldl_solve_meta(
    LD: Tensor,
    pivots: Tensor,
    B: Tensor,
    *,
    hermitian: bool = False,
) -> Tensor:
    """Meta kernel for ``torch.linalg.ldl_solve``.

    Validation, in order:
    - ``LD`` goes through the shared square, dtype and solve-input helpers.
    - ``B`` is at least 2-D.
    - ``pivots`` has shape ``LD.shape[:-1]`` and an integer dtype.
    - ``LD`` and ``B`` share a dtype.

    The result has ``B``'s shape broadcast against ``LD``'s batch dimensions,
    with column-major strides.
    """
    op_name = "torch.linalg.ldl_solve"
    squareCheckInputs(LD, op_name)
    checkFloatingOrComplex(LD, op_name)
    linearSolveCheckInputs(B, LD, op_name)
    torch._check(
        B.ndim >= 2,
        lambda: (
            f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
            f"but it has {B.ndim} dimensions instead"
        ),
    )
    torch._check(
        LD.shape[:-1] == pivots.shape,
        lambda: (
            f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
            f"but got pivots with shape {pivots.shape} instead"
        ),
    )
    torch._check(
        utils.is_integer_dtype(pivots.dtype),
        lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
    )
    torch._check(
        LD.dtype == B.dtype,
        lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
    )

    out_size, _ = _linalg_broadcast_batch_dims(B, LD)
    return torch.empty_strided(
        size=out_size,
        stride=make_contiguous_strides_for(out_size, row_major=False),
        dtype=B.dtype,
        device=B.device,
    )
1081

1082

1083
@register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
@out_wrapper("P", "L", "U")
def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
    """Meta kernel for ``torch.linalg.lu``.

    For an input of shape ``(*, m, n)`` with ``k = min(m, n)``, returns:
    - ``P`` of shape ``(*, m, m)``, or an empty 1-D tensor when ``pivot``
      is False.
    - ``L`` of shape ``(*, m, k)``.
    - ``U`` of shape ``(*, k, n)``.
    """
    torch._check(
        A.ndim >= 2,
        lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
    )

    batch = list(A.shape[:-2])
    m, n = A.shape[-2], A.shape[-1]
    k = min(m, n)

    P = A.new_empty(batch + [m, m]) if pivot else A.new_empty([0])
    L = A.new_empty(batch + [m, k])
    U = A.new_empty(batch + [k, n])
    return P, L, U
1109

1110

1111
@register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out])
@out_wrapper("LU", "pivots", "info")
def linalg_lu_factor_ex_meta(
    A: Tensor,
    *,
    pivot: bool = True,
    check_errors: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Meta kernel for ``torch.linalg.lu_factor_ex``.

    For an input of shape ``(*, m, n)``, returns:
    - ``LU`` with the same shape, column-major strides, and ``A``'s dtype
      and device.
    - ``pivots`` of shape ``(*, min(m, n))``.
    - ``info`` of shape ``(*)``.

    Both ``pivots`` and ``info`` use ``torch.int``.
    """
    torch._check(
        A.ndim >= 2,
        lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
    )

    batch = list(A.shape[:-2])
    m, n = A.shape[-2], A.shape[-1]
    full_size = batch + [m, n]

    LU = torch.empty_strided(
        size=full_size,
        stride=make_contiguous_strides_for(full_size, row_major=False),
        dtype=A.dtype,
        device=A.device,
    )
    pivots = A.new_empty(batch + [min(m, n)], dtype=torch.int)
    info = A.new_empty(batch, dtype=torch.int)
    return LU, pivots, info
1145

1146

1147
@register_meta([aten.linalg_lu_solve.default, aten.linalg_lu_solve.out])
@out_wrapper()
def linalg_lu_solve_meta(
    LU: Tensor,
    pivots: Tensor,
    B: Tensor,
    *,
    left: bool = True,
    adjoint: bool = False,
) -> Tensor:
    """Meta kernel for ``torch.linalg.lu_solve``.

    Validates dtypes, matrix shapes and that ``pivots`` matches
    ``LU.shape[:-1]``. The result has ``B``'s shape broadcast against
    ``LU``'s batch dimensions. It is column-major when ``left`` is True and
    row-major otherwise. A non-empty complex result with ``left=False`` is
    returned as a ``conj()`` view.
    """
    # dtype
    checkFloatingOrComplex(LU, "torch.linalg.lu_solve")
    torch._check(
        LU.dtype == B.dtype,
        lambda: (
            f"linalg.lu_solve: Expected LU and B to have the same dtype, "
            f"but found LU of type {LU.dtype} and B of type {B.dtype} instead"
        ),
    )
    torch._check(
        pivots.dtype == torch.int,
        lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32",
    )

    # matrix shapes
    squareCheckInputs(LU, "torch.linalg.lu_solve")
    checkInputsSolver(LU, B, left, "linalg.lu_solve")
    torch._check(
        LU.size(-1) == pivots.size(-1),
        lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix",
    )

    # batches
    torch._check(
        LU.shape[:-1] == pivots.shape,
        lambda: (
            f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, "
            f"but got pivots with shape {pivots.shape} instead"
        ),
    )

    out_size, _ = _linalg_broadcast_batch_dims(B, LU)
    solution = torch.empty_strided(
        size=out_size,
        stride=make_contiguous_strides_for(out_size, row_major=not left),
        dtype=B.dtype,
        device=B.device,
    )

    if solution.numel() != 0 and not left and solution.is_complex():
        solution = solution.conj()
    return solution
1202

1203

1204
@register_meta(aten.lu_unpack)
@out_wrapper("P", "L", "U")
def lu_unpack_meta(
    LU: Tensor,
    pivots: Tensor,
    unpack_data: bool = True,
    unpack_pivots: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Meta kernel for ``torch.lu_unpack``.

    For ``LU`` of shape ``(*, m, n)`` with ``k = min(m, n)``, returns:
    - ``P`` of shape ``(*, m, m)`` when ``unpack_pivots`` is True, otherwise
      an empty 1-D tensor.
    - ``L`` of shape ``(*, m, k)`` and ``U`` of shape ``(*, k, n)`` when
      ``unpack_data`` is True, otherwise empty 1-D tensors.

    ``pivots`` must be int32 when ``unpack_pivots`` is set.
    """
    torch._check(
        LU.ndim >= 2,
        lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: {LU.shape} instead",
    )
    if unpack_pivots:
        torch._check(
            pivots.dtype == torch.int32,
            lambda: (
                "torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n"
                "Note: this function is intended to be used with the output produced by torch.linalg.lu_factor"
            ),
        )

    batch = list(LU.shape[:-2])
    m, n = LU.shape[-2], LU.shape[-1]
    k = min(m, n)

    P = LU.new_empty(batch + [m, m]) if unpack_pivots else LU.new_empty([0])
    if not unpack_data:
        L = LU.new_empty([0])
        U = LU.new_empty([0])
    else:
        L = LU.new_empty(batch + [m, k])
        U = LU.new_empty(batch + [k, n])
    return P, L, U
1243

1244

1245
# parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
1246
def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
1247
    if mode == "reduced":
1248
        compute_q = True
1249
        reduced = True
1250
    elif mode == "complete":
1251
        compute_q = True
1252
        reduced = False
1253
    elif mode == "r":
1254
        compute_q = False
1255
        reduced = True  # this is actually irrelevant in this mode
1256
    else:
1257
        torch._check(
1258
            False,
1259
            lambda: (
1260
                f"qr received unrecognized mode '{mode}' "
1261
                f"but expected one of 'reduced' (default), 'r', or 'complete'"
1262
            ),
1263
        )
1264
    return compute_q, reduced  # type: ignore[possibly-undefined]
1265

1266

1267
@register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
@out_wrapper("Q", "R")
def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> Tuple[Tensor, Tensor]:
    """Meta kernel for ``torch.linalg.qr``.

    For ``A`` of shape ``(*, m, n)`` with ``k = min(m, n)``:
    - ``Q`` is ``(*, m, k)`` in reduced mode, ``(*, m, m)`` in complete
      mode, and an empty 1-D tensor in ``"r"`` mode.
    - ``R`` is ``(*, k, n)`` unless the mode is complete, where it is
      ``(*, m, n)``.

    Both outputs use column-major strides.
    """
    checkIsMatrix(A, "linalg.qr")
    checkFloatingOrComplex(A, "linalg.qr")

    compute_q, reduced_mode = _parse_qr_mode(mode)

    batch = list(A.shape[:-2])
    m, n = A.shape[-2], A.shape[-1]
    k = min(m, n)

    if not compute_q:
        Q = A.new_empty([0])
    else:
        q_shape = batch + [m, k if reduced_mode else m]
        Q = A.new_empty(q_shape)
        Q.as_strided_(q_shape, make_contiguous_strides_for(q_shape, row_major=False))

    r_rows = k if (reduced_mode or not compute_q) else m
    r_shape = batch + [r_rows, n]
    R = A.new_empty(r_shape)
    R.as_strided_(r_shape, make_contiguous_strides_for(r_shape, row_major=False))
    return Q, R
1293

1294

1295
@register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign])
@out_wrapper("sign", "logabsdet", "LU", "pivots")
def _linalg_slogdet(A: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Meta kernel for ``_linalg_slogdet``.

    Returns ``(sign, logabsdet, LU, pivots)``:
    - ``sign`` has ``A``'s dtype over the batch dims.
    - ``logabsdet`` uses the real counterpart dtype over the batch dims.
    - ``LU`` matches ``A`` with column-major strides.
    - ``pivots`` is int32 of shape ``A.shape[:-1]``.
    """
    squareCheckInputs(A, "linalg.slogdet")
    checkFloatingOrComplex(A, "linalg.slogdet", False)

    shape = A.shape
    batch = shape[:-2]
    sign = A.new_empty(batch)
    logabsdet = A.new_empty(batch, dtype=toRealValueType(A.dtype))
    LU = torch.empty_strided(
        size=shape,
        stride=make_contiguous_strides_for(shape, row_major=False),
        dtype=A.dtype,
        device=A.device,
    )
    pivots = A.new_empty(shape[:-1], dtype=torch.int32)
    return sign, logabsdet, LU, pivots
1311

1312

1313
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
@register_meta(aten._linalg_svd.default)
def _linalg_svd_meta(
    A: Tensor,
    full_matrices: bool = False,
    compute_uv: bool = True,
    driver: Optional[str] = None,
):
    """Meta kernel for ``_linalg_svd`` returning ``(U, S, V)``.

    For ``A`` of shape ``(*, m, n)`` with ``k = min(m, n)``:
    - ``U`` is ``(*, m, m)`` with ``full_matrices``, else ``(*, m, k)``;
      column-major.
    - ``V`` is ``(*, n, n)`` with ``full_matrices``, else ``(*, k, n)``;
      row-major when ``device_hint(A) == "cuda"`` and column-major
      otherwise.
    - ``S`` is ``(*, k)`` with the real counterpart of ``A``'s dtype.

    When ``compute_uv`` is False, ``U`` and ``V`` are empty 1-D tensors.
    """
    checkIsMatrix(A, "linalg.svd")
    checkFloatingOrComplex(A, "linalg.svd")

    batch = list(A.shape[:-2])
    m, n = A.shape[-2], A.shape[-1]
    k = min(m, n)

    if not compute_uv:
        # doesn't matter
        U = A.new_empty([0])
        V = A.new_empty([0])
    else:
        u_shape = batch + [m, m if full_matrices else k]
        U = A.new_empty(u_shape)
        U.as_strided_(u_shape, make_contiguous_strides_for(u_shape, row_major=False))

        v_shape = batch + [n if full_matrices else k, n]
        V = A.new_empty(v_shape)
        # NB: This checks for CUDA since there is no way to check for cuSolver.
        # Also, this might not work correctly on CPU when fake_device is not
        # available as device_hint just defaults to CUDA in that case. See
        # _linalg_svd meta in core.
        v_row_major = device_hint(A) == "cuda"
        V.as_strided_(v_shape, make_contiguous_strides_for(v_shape, row_major=v_row_major))

    # S is always real, even when A is complex.
    S = A.new_empty(batch + [k], dtype=toRealValueType(A.dtype))
    return U, S, V
1351

1352

1353
def _linalg_broadcast_batch_dims(
1354
    arg1: Tensor,
1355
    arg2: Tensor,
1356
) -> Tuple[List[int], List[int]]:
1357
    # broadcast the batch dimensions of arg1 and arg2.
1358
    arg1_batch_sizes = arg1.shape[:-2]
1359
    arg2_batch_sizes = arg2.shape[:-2]
1360
    expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes)
1361

1362
    arg1_expand_size = list(expand_batch_portion)
1363
    arg1_expand_size += [arg1.size(-2), arg1.size(-1)]
1364

1365
    arg2_expand_size = list(expand_batch_portion)
1366
    arg2_expand_size += [arg2.size(-2), arg2.size(-1)]
1367
    return arg1_expand_size, arg2_expand_size
1368

1369

1370
def _linalg_broadcast_batch_dims_name(
    arg1: Tensor,
    arg2: Tensor,
    name: Optional[str],
) -> Tuple[Tensor, Tensor]:
    """Broadcast the batch dimensions of two (batched) matrices.

    When ``name`` is truthy, ``linearSolveCheckInputs`` is run first with that
    name; a falsy ``name`` skips the checks. Each tensor is returned as-is
    when its computed size compares equal to its ``shape``; otherwise an
    ``expand`` of it is returned.
    """
    # If there's no name we assume we don't want to check the errors
    if name:
        linearSolveCheckInputs(arg1, arg2, name)

    size1, size2 = _linalg_broadcast_batch_dims(arg1, arg2)

    broadcasted = []
    for tensor, size in ((arg1, size1), (arg2, size2)):
        broadcasted.append(tensor if size == tensor.shape else tensor.expand(size))
    return broadcasted[0], broadcasted[1]
1388

1389

1390
def linalg_solve_is_vector_rhs(input: Tensor, other: Tensor) -> bool:
    """Return whether ``other`` should be treated as a (batched) vector RHS.

    True when ``other`` is 1-D, or when it has exactly one dimension fewer
    than ``input`` and its shape equals ``input.shape[:-1]``.
    """
    if other.ndim == 1:
        return True
    return input.ndim - 1 == other.ndim and other.shape == input.shape[:-1]
1396

1397

1398
@register_meta(aten._linalg_solve_ex)
def _linalg_solve_ex(
    A: Tensor,
    B: Tensor,
    *,
    left: bool = True,
    check_errors: bool = False,
    result: Optional[Tensor] = None,
    LU: Optional[Tensor] = None,
    pivots: Optional[Tensor] = None,
    info: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Meta kernel for ``_linalg_solve_ex`` returning ``(result, LU, pivots, info)``.

    When ``linalg_solve_is_vector_rhs`` classifies ``B`` as a vector RHS, it
    is unsqueezed for the solver checks and broadcasting, and the trailing
    unit dimension is dropped from the result shape.

    This kernel is not wrapped by ``out_wrapper``, so it handles out tensors
    itself: when all four are supplied, each is resized, restrided and
    copied from the freshly computed metadata.
    """
    checkFloatingOrComplex(A, "linalg.solve")
    torch._check(
        A.dtype == B.dtype,
        lambda: (
            f"linalg.solve: Expected A and B to have the same dtype, but found A of type "
            f"{A.dtype} and B of type {B.dtype} instead"
        ),
    )
    is_vector = linalg_solve_is_vector_rhs(A, B)
    rhs = B.unsqueeze(-1) if is_vector else B
    checkInputsSolver(A, rhs, left, "linalg.solve")
    rhs_size, _ = _linalg_broadcast_batch_dims(rhs, A)
    torch._check(
        left or not is_vector,
        lambda: (
            "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. "
            "In this case linalg.solve is equivalent to B / A.squeeze(-1)"
        ),
    )

    solution_size = rhs_size[:-1] if is_vector else rhs_size
    solution_meta = torch.empty_strided(
        size=solution_size,
        stride=make_contiguous_strides_for(solution_size, not left),
        dtype=B.dtype,
        device=B.device,
    )
    a_shape = A.shape
    lu_meta = torch.empty_strided(
        size=a_shape,
        stride=make_contiguous_strides_for(a_shape, False),
        dtype=A.dtype,
        device=A.device,
    )
    pivots_meta = A.new_empty(a_shape[:-1], dtype=torch.int32)
    info_meta = A.new_empty(a_shape[:-2], dtype=torch.int32)

    computed = (solution_meta, lu_meta, pivots_meta, info_meta)
    provided = (result, LU, pivots, info)
    if all(t is not None for t in provided):
        for src, dst in zip(computed, provided):
            # resize and copy operations are done in-place
            _maybe_resize_out(dst, src.shape)  # type: ignore[arg-type]
            # strides are not copied in out_wrapper
            dst.as_strided_(src.shape, src.stride())  # type: ignore[union-attr]
            _safe_copy_out(copy_from=src, copy_to=dst, exact_dtype=False)  # type: ignore[arg-type]
    return computed
1456

1457

1458
@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
def linalg_solve_triangular_meta(
    A: Tensor,
    B: Tensor,
    *,
    upper: bool,
    left: bool = True,
    unitriangular: bool = False,
    out: Optional[Tensor] = None,
) -> Tensor:
    """Meta kernel for ``torch.linalg.solve_triangular``.

    Resizes ``out`` (or a fresh empty tensor when none is given) to ``B``'s
    shape broadcast against ``A``'s batch dimensions.

    - If the broadcast ``A`` is conjugated and its transpose is contiguous,
      a plain resize is used.
    - Otherwise, when a resize is needed, the output is made column-major by
      resizing to the transposed shape and transposing back in place.
    """
    out = A.new_empty([0]) if out is None else out
    assert isinstance(out, TensorLike)
    checkInputsSolver(A, B, left, "linalg.solve_triangular")
    B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)

    if A_.transpose(-2, -1).is_contiguous() and A_.is_conj():
        return _maybe_resize_out(out, B_.shape)  # type: ignore[return-value]

    # reimplementation of resize_output with result F-contig
    if _resize_output_check(out, B_.shape):
        out.resize_(B_.transpose(-2, -1).shape)
        out.transpose_(-2, -1)
    return out  # type: ignore[return-value]
1482

1483

1484
@register_meta(aten.triangular_solve)
@out_wrapper("solution", "cloned_coefficient")
def triangular_solve_meta(
    self: Tensor,
    A: Tensor,
    upper: bool = True,
    transpose: bool = False,
    unitriangular: bool = False,
) -> Tuple[Tensor, Tensor]:
    """Meta kernel for ``torch.triangular_solve``.

    Both ``self`` (b) and ``A`` must be at least 2-D.
    - For strided ``A``, both outputs are column-major and use the broadcast
      batch shapes.
    - For ``sparse_csr`` / ``sparse_bsr`` ``A``, the solution is
      ``empty_like(self)`` and the cloned coefficient is an empty 1-D tensor.
    - Any other layout fails a ``torch._check``.
    """
    torch._check(
        self.ndim >= 2,
        lambda: (
            f"torch.triangular_solve: Expected b to have at least 2 dimensions, "
            f"but it has {self.ndim} dimensions instead"
        ),
    )
    torch._check(
        A.ndim >= 2,
        lambda: (
            f"torch.triangular_solve: Expected A to have at least 2 dimensions, "
            f"but it has {A.ndim} dimensions instead"
        ),
    )

    linearSolveCheckInputs(self, A, "triangular_solve")

    layout = A.layout
    if layout == torch.strided:
        b_size, a_size = _linalg_broadcast_batch_dims(self, A)
        solution = torch.empty_strided(
            size=b_size,
            stride=make_contiguous_strides_for(b_size, row_major=False),
            dtype=self.dtype,
            device=self.device,
        )
        cloned_coefficient = torch.empty_strided(
            size=a_size,
            stride=make_contiguous_strides_for(a_size, row_major=False),
            dtype=A.dtype,
            device=A.device,
        )
    elif layout in (torch.sparse_csr, torch.sparse_bsr):
        solution = torch.empty_like(self)
        cloned_coefficient = self.new_empty([0])
    else:
        torch._check(False, lambda: "triangular_solve: Got an unexpected layout.")
    return solution, cloned_coefficient  # type: ignore[possibly-undefined]
1530

1531

1532
# From aten/src/ATen/native/LinearAlgebra.cpp
@register_meta(aten._linalg_det.default)
def _linalg_det_meta(A):
    """Meta kernel for ``_linalg_det`` returning ``(det, LU, pivots)``.

    - ``det`` covers the batch dimensions.
    - ``LU`` matches ``A``'s shape with column-major strides.
    - ``pivots`` is int32 of shape ``A.shape[:-1]``.
    """
    squareCheckInputs(A, "linalg.det")
    checkFloatingOrComplex(A, "linalg.det")

    shape = A.shape
    det = A.new_empty(shape[:-2])

    LU = A.new_empty(shape)
    LU.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))

    pivots = A.new_empty(shape[:-1], dtype=torch.int32)
    return det, LU, pivots
1545

1546

1547
@register_meta(aten.ormqr)
@out_wrapper()
def ormqr(
    input: Tensor,
    tau: Tensor,
    other: Tensor,
    left: bool = True,
    transpose: bool = False,
) -> Tensor:
    """Meta kernel for ``torch.ormqr``.

    Checks rank, the size of the side of ``other`` selected by ``left``
    against ``tau`` and ``input``, batch dimensions, dtypes and devices, in
    that order. The result matches ``other`` in shape, dtype and device with
    column-major strides.
    """
    torch._check(
        input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions."
    )
    torch._check(
        other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions."
    )

    # Dimension of ``other`` that is multiplied by Q: rows when left, else columns.
    side_dim = -2 if left else -1
    side_len = other.shape[side_dim]
    torch._check(
        side_len >= tau.shape[-1],
        lambda: f"torch.ormqr: other.shape[{side_dim}] must be greater than or equal to tau.shape[-1]",
    )
    torch._check(
        side_len == input.shape[-2],
        lambda: f"torch.ormqr: other.shape[{side_dim}] must be equal to input.shape[-2]",
    )

    torch._check(
        tau.shape[-1] <= input.shape[-1],
        lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]",
    )

    torch._check(
        input.ndim - tau.ndim == 1,
        lambda: (
            f"torch.ormqr: Expected tau to have one dimension less than input, "
            f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
        ),
    )
    torch._check(
        input.ndim == other.ndim,
        lambda: (
            f"torch.ormqr: Expected other to have the same number of dimensions as input, "
            f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}"
        ),
    )

    if input.ndim > 2:
        batch_expected = input.shape[:-2]

        batch_tau = tau.shape[:-1]
        torch._check(
            batch_tau == batch_expected,
            lambda: (
                f"torch.ormqr: Expected batch dimensions of tau to be "
                f"equal to input.shape[:-2], but got {batch_tau}"
            ),
        )

        batch_other = other.shape[:-2]
        torch._check(
            batch_other == batch_expected,
            lambda: (
                f"torch.ormqr: Expected batch dimensions of other to be "
                f"equal to input.shape[:-2], but got {batch_other}"
            ),
        )

    torch._check(
        tau.dtype == input.dtype,
        lambda: (
            f"torch.ormqr: Expected input and tau to have the same dtype, "
            f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}"
        ),
    )
    torch._check(
        other.dtype == input.dtype,
        lambda: (
            f"torch.ormqr: Expected input and other to have the same dtype, "
            f"but input has dtype {input.dtype} and other has dtype {other.dtype}"
        ),
    )

    checkSameDevice("torch.ormqr", tau, input, "tau")
    checkSameDevice("torch.ormqr", other, input, "other")

    out_shape = other.shape
    return torch.empty_strided(
        size=out_shape,
        stride=make_contiguous_strides_for(out_shape, row_major=False),
        dtype=other.dtype,
        device=other.device,
    )
1637

1638

1639
def _padding_check_valid_input(input, padding, *, dim):
1640
    torch._check(
1641
        len(padding) == 2 * dim,
1642
        lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}",
1643
    )
1644

1645
    input_dim = input.ndim
1646

1647
    is_batch_mode = input_dim == (dim + 2)
1648

1649
    valid_batch_mode = is_batch_mode
1650
    valid_non_batch_mode = not is_batch_mode
1651

1652
    if is_batch_mode:
1653
        # allow batch size of 0-dim.
1654
        for d in range(1, input_dim):
1655
            valid_batch_mode = valid_batch_mode and input.size(d) != 0
1656
    else:
1657
        for d in range(0, input_dim):
1658
            valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0
1659

1660
    # allow empty batch size but not other dimensions.
1661
    torch._check(
1662
        valid_batch_mode or valid_non_batch_mode,
1663
        lambda: (
1664
            f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size "
1665
            f"and other non-zero dimensions for input, but got: {input.shape}"
1666
        ),
1667
    )
1668

1669

1670
def _pad1d_common(input, padding, *, is_reflection):
    """Shared meta logic for 1-D reflection/replication padding.

    ``input`` is 2-D or 3-D; a 3-D input has its leading dim treated as the
    batch (validated by ``_padding_check_valid_input``). ``padding`` is
    ``(left, right)``. Reflection padding additionally requires each pad to
    be smaller than the input width. Returns an empty tensor with the padded
    width.
    """
    batched = input.ndim == 3
    nbatch = input.size(0) if batched else 1
    plane_dim = 1 if batched else 0
    w_dim = plane_dim + 1

    _padding_check_valid_input(input, padding, dim=1)

    pad_l, pad_r = padding

    nplane = input.size(plane_dim)
    in_w = input.size(w_dim)
    out_w = in_w + pad_l + pad_r

    if is_reflection:
        torch._check(
            pad_l < in_w and pad_r < in_w,
            lambda: (
                f"Argument #4: Padding size should be less than the corresponding input dimension, "
                f"but got: padding ({pad_l}, {pad_r}) at dimension {w_dim} of input {input.shape}"
            ),
        )

    torch._check(
        out_w >= 1,
        lambda: f"input (W: {in_w}) is too small. Calculated output W: {out_w}",
    )

    if input.ndim == 2:
        return input.new_empty((nplane, out_w))
    return input.new_empty((nbatch, nplane, out_w))
1706

1707

1708
@register_meta(aten.reflection_pad1d)
@out_wrapper()
def meta_reflection_pad1d(input, padding):
    """Meta kernel for ``reflection_pad1d``.

    Delegates to ``_pad1d_common`` with the reflection-only pad-size checks
    enabled.
    """
    return _pad1d_common(input, padding=padding, is_reflection=True)
1712

1713

1714
@register_meta(aten.replication_pad1d)
@out_wrapper()
def meta_replication_pad1d(input, padding):
    """Meta kernel for ``replication_pad1d``.

    Delegates to ``_pad1d_common`` without the reflection-only pad-size
    checks.
    """
    return _pad1d_common(input, padding=padding, is_reflection=False)
1718

1719

1720
def _pad1d_backward_common(grad_output, input, padding, *, is_reflection):
1721
    dim_w = 1
1722
    if not is_reflection:
1723
        torch._check(len(padding) == 2, lambda: "padding size is expected to be 2")
1724

1725
    if input.ndim == 3:
1726
        dim_w += 1
1727

1728
    pad_l, pad_r = padding
1729

1730
    input_w = input.size(dim_w)
1731
    output_w = input_w + pad_l + pad_r
1732

1733
    if is_reflection:
1734
        torch._check(
1735
            pad_l < input_w and pad_r < input_w,
1736
            lambda: (
1737
                f"Argument #4: Padding size should be less than the corresponding input dimension, "
1738
                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1739
            ),
1740
        )
1741

1742
    torch._check(
1743
        output_w == grad_output.size(dim_w),
1744
        lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1745
    )
1746

1747
    return input.new_empty(input.shape)
1748

1749

1750
@register_meta(aten.reflection_pad1d_backward)
@out_wrapper("grad_input")
def meta_reflection_pad1d_backward(grad_output, input, padding):
    """Meta kernel for ``reflection_pad1d_backward``.

    Delegates to ``_pad1d_backward_common`` with reflection checks enabled.
    """
    return _pad1d_backward_common(
        grad_output, input, padding=padding, is_reflection=True
    )
1754

1755

1756
@register_meta(aten.replication_pad1d_backward)
@out_wrapper("grad_input")
def meta_replication_pad1d_backward(grad_output, input, padding):
    """Meta kernel for ``replication_pad1d_backward``.

    Delegates to ``_pad1d_backward_common`` in replication mode.
    """
    return _pad1d_backward_common(
        grad_output, input, padding=padding, is_reflection=False
    )
1760

1761

1762
def _pad2d_common(input, padding, *, is_reflection):
    """Shared meta logic for 2-D reflection/replication padding.

    ``input`` is 3-D or 4-D; a 4-D input has its leading dim treated as the
    batch (validated first by ``_padding_check_valid_input``). ``padding`` is
    ``(left, right, top, bottom)``. Reflection padding additionally requires
    each pad to be smaller than the dimension it pads. Returns an empty
    tensor with the padded height and width.
    """
    _padding_check_valid_input(input, padding, dim=2)

    batched = input.ndim == 4
    offset = 1 if batched else 0
    plane_dim, h_dim, w_dim = offset, offset + 1, offset + 2
    nbatch = input.size(0) if batched else 1

    pad_l, pad_r, pad_t, pad_b = padding

    nplane = input.size(plane_dim)
    in_h = input.size(h_dim)
    in_w = input.size(w_dim)
    out_h = in_h + pad_t + pad_b
    out_w = in_w + pad_l + pad_r

    if is_reflection:
        torch._check(
            pad_l < in_w and pad_r < in_w,
            lambda: (
                f"Argument #4: Padding size should be less than the corresponding input dimension, "
                f"but got: padding ({pad_l}, {pad_r}) at dimension {w_dim} of input {input.shape}"
            ),
        )
        torch._check(
            pad_t < in_h and pad_b < in_h,
            lambda: (
                f"Argument #6: Padding size should be less than the corresponding input dimension, "
                f"but got: padding ({pad_t}, {pad_b}) at dimension {h_dim} of input {input.shape}"
            ),
        )

    # NOTE(review): this passes if *either* output dim is >= 1; kept as-is so
    # the meta kernel agrees with the existing behavior.
    torch._check(
        out_w >= 1 or out_h >= 1,
        lambda: (
            f"input (H: {in_h} W: {in_w}) is too small. "
            f"Calculated output H: {out_h} W: {out_w}"
        ),
    )

    if input.ndim == 3:
        return input.new_empty((nplane, out_h, out_w))
    return input.new_empty((nbatch, nplane, out_h, out_w))
1813

1814

1815
@register_meta(aten.reflection_pad2d)
@out_wrapper()
def meta_reflection_pad2d(input, padding):
    """Meta kernel for ``aten.reflection_pad2d``.

    Delegates to ``_pad2d_common`` with the reflection-specific pad-size checks
    enabled.
    """
    result = _pad2d_common(input, padding, is_reflection=True)
    return result
1819

1820

1821
@register_meta(aten.replication_pad2d)
@out_wrapper()
def meta_replication_pad2d(input, padding):
    """Meta kernel for ``aten.replication_pad2d``.

    Delegates to ``_pad2d_common`` without the reflection-only pad-size checks.
    """
    result = _pad2d_common(input, padding, is_reflection=False)
    return result
1825

1826

1827
@register_meta(
    [
        aten.reflection_pad2d_backward.default,
        aten.reflection_pad2d_backward.grad_input,
        aten.replication_pad2d_backward.default,
        aten.replication_pad2d_backward.grad_input,
    ]
)
@out_wrapper("grad_input")
def meta_pad2d_backward(grad_output, self, padding):
    """Meta kernel for reflection/replication pad2d backward.

    Checks that ``grad_output``'s width and height equal ``self``'s padded
    width and height, then returns an uninitialized tensor shaped like
    ``self`` (the gradient with respect to the unpadded input).

    Args:
        grad_output: gradient of the padded output.
        self: the original (unpadded) input, 3-D or 4-D.
        padding: ``(left, right, top, bottom)`` pad amounts.
    """
    # A 4-D input has a leading batch dimension that shifts H and W by one.
    shift = 1 if self.dim() == 4 else 0
    dim_h = 1 + shift
    dim_w = 2 + shift

    pad_l, pad_r, pad_t, pad_b = padding

    self_shape = self.shape
    output_h = self_shape[dim_h] + pad_t + pad_b
    output_w = self_shape[dim_w] + pad_l + pad_r

    torch._check(
        output_w == grad_output.size(dim_w),
        lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
    )
    torch._check(
        output_h == grad_output.size(dim_h),
        lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
    )
    return self.new_empty(self.shape)
1866

1867

1868
def _pad3d_common(input, padding, *, is_reflection):
    """Compute the meta output of 3-D reflection/replication padding.

    Args:
        input: a 4-D ``(C, D, H, W)`` or 5-D ``(N, C, D, H, W)`` tensor.
        padding: ``(left, right, top, bottom, front, back)`` pad amounts.
        is_reflection: when True, additionally require every pad amount to be
            strictly smaller than the input extent it pads.

    Returns:
        An uninitialized tensor with the padded shape.
    """
    _padding_check_valid_input(input, padding, dim=3)

    batch_mode = input.ndim == 5
    # A leading batch dimension shifts every other dimension index by one.
    base = 1 if batch_mode else 0
    dim_plane, dim_d, dim_h, dim_w = base, base + 1, base + 2, base + 3

    pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding

    nplane = input.size(dim_plane)
    in_d = input.size(dim_d)
    in_h = input.size(dim_h)
    in_w = input.size(dim_w)
    out_d = in_d + pad_f + pad_bk
    out_h = in_h + pad_t + pad_b
    out_w = in_w + pad_l + pad_r

    if is_reflection:
        torch._check(
            pad_l < in_w and pad_r < in_w,
            lambda: (
                f"Argument #4: Padding size should be less than the corresponding input dimension, "
                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
            ),
        )
        torch._check(
            pad_t < in_h and pad_b < in_h,
            lambda: (
                f"Argument #6: Padding size should be less than the corresponding input dimension, "
                f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
            ),
        )
        torch._check(
            pad_f < in_d and pad_bk < in_d,
            lambda: (
                f"Argument #8: Padding size should be less than the corresponding input dimension, "
                f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}"
            ),
        )

    # Note: this only fails when *all* output extents are below 1.
    torch._check(
        out_w >= 1 or out_h >= 1 or out_d >= 1,
        lambda: (
            f"input (D: {in_d} H: {in_h} W: {in_w}) is too small. "
            f"Calculated output D: {out_d} H: {out_h} W: {out_w}"
        ),
    )

    out_shape = (nplane, out_d, out_h, out_w)
    if batch_mode:
        out_shape = (input.size(0), *out_shape)
    return input.new_empty(out_shape)
1929

1930

1931
@register_meta(aten.reflection_pad3d)
@out_wrapper()
def meta_reflection_pad3d(input, padding):
    """Meta kernel for ``aten.reflection_pad3d``.

    Delegates to ``_pad3d_common`` with the reflection-specific pad-size checks
    enabled.
    """
    result = _pad3d_common(input, padding, is_reflection=True)
    return result
1935

1936

1937
@register_meta(aten.replication_pad3d)
@out_wrapper()
def meta_replication_pad3d(input, padding):
    """Meta kernel for ``aten.replication_pad3d``.

    Delegates to ``_pad3d_common`` without the reflection-only pad-size checks.
    """
    result = _pad3d_common(input, padding, is_reflection=False)
    return result
1941

1942

1943
@register_meta(
    [
        aten.reflection_pad3d_backward.default,
        aten.reflection_pad3d_backward.grad_input,
        aten.replication_pad3d_backward.default,
        aten.replication_pad3d_backward.grad_input,
    ]
)
@out_wrapper("grad_input")
def meta_pad3d_backward(grad_output, input, padding):
    """Meta kernel for reflection/replication pad3d backward.

    Checks that ``grad_output``'s width, height and depth (in that order)
    equal ``input``'s padded extents, then returns an uninitialized tensor
    shaped like ``input``.

    Args:
        grad_output: gradient of the padded output; same rank as ``input``.
        input: the original (unpadded) input, at least 4-D.
        padding: ``(left, right, top, bottom, front, back)`` pad amounts.
    """
    torch._check(len(padding) == 6, lambda: "padding size is expected to be 6")
    assert input.ndim > 3
    assert grad_output.ndim == input.ndim

    # D/H/W start at index 1, or at index 2 when a 5-D input carries a batch dim.
    first_spatial = 2 if input.ndim == 5 else 1
    dim_d, dim_h, dim_w = first_spatial, first_spatial + 1, first_spatial + 2

    pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding

    checks = (
        ("width", dim_w, pad_l, pad_r),
        ("height", dim_h, pad_t, pad_b),
        ("depth", dim_d, pad_f, pad_bk),
    )
    for label, dim, before, after in checks:
        expected = input.size(dim) + before + after
        torch._check(
            expected == grad_output.size(dim),
            lambda label=label, expected=expected, dim=dim: (
                f"grad_output {label} unexpected. Expected: {expected}, Got: {grad_output.size(dim)}"
            ),
        )

    return input.new_empty(input.shape)
1989

1990

1991
@register_meta(aten._pdist_forward)
@out_wrapper()
def meta__pdist_forward(self: Tensor, p: float = 2) -> Tensor:
    """Meta kernel for ``aten._pdist_forward``.

    Requires a contiguous ``self``. For ``n = self.size(0)`` rows, the result
    is a 1-D tensor of ``n * (n - 1) // 2`` elements (0 when ``n <= 1``) in
    legacy contiguous format.
    """
    torch._check(
        self.is_contiguous(), lambda: "_pdist_forward requires contiguous input"
    )
    n = self.size(0)
    num_pairs = 0 if n <= 1 else n * (n - 1) // 2
    out = self.new_empty((num_pairs,))
    return out.to(memory_format=torch.legacy_contiguous_format)  # type: ignore[call-overload]
2004

2005

2006
@register_meta(aten._pdist_backward)
@out_wrapper()
def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> Tensor:
    """Meta kernel for ``aten._pdist_backward``.

    Requires contiguous ``self`` and ``pdist`` (checked in that order) and
    returns an uninitialized tensor like ``self`` in legacy contiguous format.
    """
    for tensor, what in ((self, "self"), (pdist, "pdist")):
        torch._check(
            tensor.is_contiguous(),
            lambda what=what: f"_pdist_backward requires {what} to be contiguous",
        )
    return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
2016

2017

2018
@register_meta([aten.baddbmm.default, aten.baddbmm.out])
@out_wrapper()
def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
    """Meta kernel for ``aten.baddbmm``.

    ``batch1`` must be ``(b, n, m)`` and ``batch2`` ``(b, m, p)``; ``self`` is
    expanded to ``(b, n, p)``. All three tensors must share one dtype. Returns
    an uninitialized tensor of the expanded ``self`` shape.
    """
    # Validate ranks before indexing into the shapes so that a malformed input
    # yields the intended error message instead of an IndexError from size().
    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")

    dim1 = batch1.size(0)
    dim2 = batch1.size(1)
    dim3 = batch2.size(2)
    self = self.expand((dim1, dim2, dim3))
    torch._check(
        self.dtype == batch1.dtype == batch2.dtype,
        lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
    )
    batch1_sizes = batch1.shape
    batch2_sizes = batch2.shape
    bs = batch1_sizes[0]
    contraction_size = batch1_sizes[2]
    torch._check(
        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
        lambda: (
            f"Expected size for first two dimensions of batch2 tensor to be: "
            f"[{bs}, {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}]."
        ),
    )
    return self.new_empty(self.size())
2043

2044

2045
@register_meta([aten.bernoulli.default, aten.bernoulli.out])
@out_wrapper()
def meta_bernoulli(self, *, generator=None):
    """Meta kernel for ``aten.bernoulli``: a contiguous tensor shaped like ``self``."""
    # https://github.com/pytorch/pytorch/issues/88612
    result = torch.empty_like(self)
    return result.contiguous()
2050

2051

2052
@register_meta(aten.bernoulli_.float)
def meta_bernoulli_(self, p=0.5, generator=None):
    """Meta kernel for the in-place ``bernoulli_.float``: returns ``self`` unchanged."""
    result = self
    return result
2055

2056

2057
@register_meta(aten.bernoulli.p)
def meta_bernoulli_p(self, p=0.5, generator=None):
    """Meta kernel for ``aten.bernoulli.p``: a contiguous tensor shaped like ``self``."""
    # https://github.com/pytorch/pytorch/issues/88612
    result = torch.empty_like(self)
    return result.contiguous()
2061

2062

2063
@register_meta([aten.poisson.default, aten.poisson.out])
@out_wrapper()
def meta_poisson(self, generator=None):
    """Meta kernel for ``aten.poisson``: an uninitialized tensor like ``self``."""
    result = torch.empty_like(self)
    return result
2067

2068

2069
@register_meta(aten._fused_moving_avg_obs_fq_helper.default)
def meta__fused_moving_avg_obs_fq_helper(
    self,
    observer_on,
    fake_quant_on,
    running_min,
    running_max,
    scale,
    zero_point,
    averaging_const,
    quant_min,
    quant_max,
    ch_axis,
    per_row_fake_quant=False,
    symmetric_quant=False,
):
    """Meta kernel for ``aten._fused_moving_avg_obs_fq_helper``.

    Requires ``ch_axis < self.dim()``. Returns ``(output, mask)``: an
    uninitialized tensor like ``self`` and a bool tensor of the same shape.
    """
    torch._check(
        ch_axis < self.dim(),
        lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
    )
    mask = torch.empty_like(self, dtype=torch.bool)
    output = torch.empty_like(self)
    return output, mask
2091

2092

2093
@register_meta(aten.mm)
@out_wrapper()
def meta_mm(a, b):
    """Meta kernel for ``aten.mm``: ``(N, M) @ (M, P) -> (N, P)``."""
    torch._check(a.dim() == 2, lambda: "a must be 2D")
    torch._check(b.dim() == 2, lambda: "b must be 2D")
    rows, inner_a = a.shape
    inner_b, cols = b.shape
    torch._check(
        inner_a == inner_b,
        lambda: f"a and b must have same reduction dim, but got [{rows}, {inner_a}] X [{inner_b}, {cols}].",
    )
    return a.new_empty(rows, cols)
2105

2106

2107
def _compute_reduction_shape(self, dims, keepdim):
2108
    if keepdim:
2109
        return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
2110

2111
    return utils.compute_reduction_output_shape(self.shape, dims)
2112

2113

2114
# FakeTensors (meta tensors with a device) report their device as meta while
# meta kernels run. Read the FakeTensor's "fake device" instead, when there is
# one, so meta kernels whose behavior differs per device stay accurate when run
# with FakeTensors.
def device_hint(tensor) -> "str":
    """Return the device type a meta kernel should assume for ``tensor``.

    FakeTensors yield ``tensor.fake_device.type``; anything else yields
    ``"cuda"`` as the default guess.
    """
    if not isinstance(tensor, torch._subclasses.FakeTensor):
        return "cuda"  # default to cuda
    return tensor.fake_device.type
2123

2124

2125
def calc_conv_nd_return_shape(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    stride: Union[List[int], int],
    padding: Union[List[int], int],
    dilation: Union[List[int], int],
    is_transposed: bool,
    groups: int,
    output_padding: Optional[Union[List[int], int]] = None,
):
    """Compute the output shape ``[N, C_out, *spatial]`` of an N-d convolution.

    ``stride``, ``padding``, ``dilation`` and ``output_padding`` each accept a
    single int, a length-1 sequence (broadcast to every spatial dimension), or
    one value per spatial dimension. ``is_transposed`` selects the spatial
    formula. ``output_padding`` is only used for transposed convolutions, where
    ``None`` or an empty sequence means zero output padding.

    Raises:
        RuntimeError: for a non-transposed convolution whose
            ``weight.shape[1] * groups`` differs from ``input_tensor.shape[1]``.
    """

    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output

        See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
        Returns:
            The output length
        """
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1

    def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output
        if transposed convolution is used.
        See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
            op: output padding in that dim

        Returns:
            The output length
        """
        return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

    def _per_dim(value, ndims: int):
        # Broadcast an int or a length-1 sequence to one entry per spatial dim.
        if isinstance(value, IntLike):
            return [value] * ndims
        if len(value) == 1:
            return [value[0]] * ndims
        return value

    kernel_size = weight.shape[2:]
    dims = input_tensor.shape[2:]
    ndims = len(dims)

    if is_transposed:
        # For transposed weights, dim 1 holds the per-group output channels.
        out_channels = groups * weight.shape[1]
    else:
        out_channels = weight.shape[0]
        if weight.shape[1] * groups != input_tensor.shape[1]:
            raise RuntimeError("Invalid channel dimensions")

    stride = _per_dim(stride, ndims)
    padding = _per_dim(padding, ndims)
    dilation = _per_dim(dilation, ndims)

    ret_shape = [input_tensor.shape[0], out_channels]
    if is_transposed:
        # Select the formula from ``is_transposed``, not from the truthiness of
        # ``output_padding``: an int 0 or a missing value must still yield the
        # transposed output size.
        if output_padding is None or (
            not isinstance(output_padding, IntLike) and len(output_padding) == 0
        ):
            output_padding = 0
        output_padding_list = _per_dim(output_padding, ndims)
        for i in range(ndims):
            ret_shape.append(
                _formula_transposed(
                    dims[i],
                    padding[i],
                    dilation[i],
                    kernel_size[i],
                    stride[i],
                    output_padding_list[i],
                )
            )
    else:
        for i in range(ndims):
            ret_shape.append(
                _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
            )

    return ret_shape
2224

2225

2226
def is_channels_last(ten):
    """Return True when the suggested memory format of ``ten`` is channels_last."""
    suggested = torch._prims_common.suggest_memory_format(ten)
    return suggested == torch.channels_last
2228

2229

2230
@register_meta(aten.convolution.default)
def meta_conv(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    is_transposed: bool,
    output_padding: List[int],
    groups: int,
):
    """Meta kernel for ``aten.convolution``.

    The output shape comes from ``calc_conv_nd_return_shape``
    (``output_padding`` is forwarded only for transposed convolutions). The
    channel dimension is forced to 0 when the input has zero channels. The
    output memory format is channels_last if the input is channels_last, or if
    the device hint is ``"cuda"`` and the weight is channels_last. Otherwise it
    follows the input's contiguity.
    """

    def pick_memory_format():
        on_cuda = device_hint(input_tensor) == "cuda"
        # Only the CUDA path also considers the weight's layout.
        if is_channels_last(input_tensor) or (on_cuda and is_channels_last(weight)):
            return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        if input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format
        return None

    shape_out = calc_conv_nd_return_shape(
        input_tensor,
        weight,
        stride,
        padding,
        dilation,
        is_transposed,
        groups,
        output_padding if is_transposed else None,
    )

    channels_dim = 1  # same index for input and output channels
    if input_tensor.size(channels_dim) == 0:
        shape_out[channels_dim] = 0

    return input_tensor.new_empty(shape_out).to(memory_format=pick_memory_format())  # type: ignore[call-overload]
2273

2274

2275
if torch._C._has_mkldnn:
    # Module-level handles to the "Meta" IMPL libraries of the mkldnn / mkl /
    # onednn / quantized namespaces. As their names say, kernels are added
    # through ``register_meta`` rather than through these objects.
    _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
        "mkldnn", "IMPL", "Meta"
    )

    @register_meta(torch.ops.mkldnn._convolution_pointwise.default)
    def meta_mkldnn_convolution_default(
        input_tensor,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        attr,
        scalars,
        algorithm,
    ):
        """Meta kernel for ``mkldnn._convolution_pointwise``.

        Uses the non-transposed convolution shape. The result is channels_last,
        or channels_last_3d for 5-D inputs.
        """
        shape_out = calc_conv_nd_return_shape(
            input_tensor, weight, stride, padding, dilation, False, groups, []
        )
        out = input_tensor.new_empty(shape_out)
        if input_tensor.dim() == 5:
            fmt = torch.channels_last_3d
        else:
            fmt = torch.channels_last
        return out.to(memory_format=fmt)  # type: ignore[call-overload]

    @register_meta(torch.ops.mkldnn._linear_pointwise.default)
    def meta_linear_pointwise_default(
        input_tensor, weight, bias, attr, scalars, algorithm
    ):
        """Meta kernel for ``mkldnn._linear_pointwise``: last dim becomes ``weight.shape[0]``."""
        leading = tuple(input_tensor.shape[:-1])
        return input_tensor.new_empty(leading + (weight.shape[0],))

    if torch._C.has_mkl:
        _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
            "mkl", "IMPL", "Meta"
        )

        @register_meta(torch.ops.mkl._mkl_linear)
        def meta_mkl_linear(input_tensor, packed_weight, orig_weight, bias, batch_size):
            """Meta kernel for ``mkl._mkl_linear``: last dim becomes ``orig_weight.shape[0]``."""
            leading = tuple(input_tensor.shape[:-1])
            return input_tensor.new_empty(leading + (orig_weight.shape[0],))

    _meta_lib_dont_use_me_use_register_meta_for_onednn = torch.library.Library(
        "onednn", "IMPL", "Meta"
    )

    @register_meta(torch.ops.onednn.qconv2d_pointwise.default)
    def meta_qconv2d_pointwise(
        x,
        x_scale,
        x_zp,
        w,  # prepacked_weight
        w_scale,
        w_zp,
        bias,
        stride,
        padding,
        dilation,
        groups,
        output_scale,
        output_zero_point,
        output_dtype,
        attr,
        scalars,
        algorithm,
    ):
        """Meta kernel for ``onednn.qconv2d_pointwise``.

        Returns a channels_last tensor with the non-transposed convolution
        shape and dtype ``output_dtype`` (only float32 / bfloat16 accepted).
        """
        shape_out = calc_conv_nd_return_shape(
            x, w, stride, padding, dilation, False, groups, None
        )
        assert output_dtype in [torch.float32, torch.bfloat16]
        return x.new_empty(shape_out, dtype=output_dtype).to(
            memory_format=torch.channels_last
        )

    @register_meta(torch.ops.onednn.qlinear_pointwise.default)
    @register_meta(torch.ops.onednn.qlinear_pointwise.tensor)
    def meta_qlinear_pointwise(
        x,
        x_scale,
        x_zp,
        w,
        w_scale,
        w_zp,
        bias,
        output_scale,
        output_zero_point,
        output_dtype,
        post_op_name,
        post_op_args,
        post_op_algorithm,
    ):
        """Meta kernel for ``onednn.qlinear_pointwise`` (default and tensor overloads).

        Replaces the last dim of ``x`` with ``w.shape[1]``; dtype is
        ``output_dtype`` (only float32 / bfloat16 accepted).
        """
        out_shape = list(x.shape)
        # The weight has been transposed during the qlinear weight prepack
        # process, so the output feature count is its second dimension.
        out_shape[-1] = w.shape[1]
        assert output_dtype in [torch.float32, torch.bfloat16]
        return x.new_empty(out_shape, dtype=output_dtype)

    _meta_lib_dont_use_me_use_register_meta_for_quantized = torch.library.Library(
        "quantized", "IMPL", "Meta"
    )

    @register_meta(torch.ops.quantized.max_pool2d)
    def meta_quantized_max_pool2d(
        input,
        kernel_size,
        stride=(),
        padding=(0,),
        dilation=(1,),
        ceil_mode=False,
    ):
        """Meta kernel for ``quantized.max_pool2d``: a channels_last pooled tensor."""
        n_planes, out_h, out_w = max_pool2d_checks_and_compute_shape(
            input, kernel_size, stride, padding, dilation, ceil_mode
        )
        size = [n_planes, out_h, out_w]
        if input.dim() != 3:
            nbatch = input.size(-4) if input.dim() == 4 else 1
            size.insert(0, nbatch)
        return torch.empty(
            size,
            dtype=input.dtype,
            device=input.device,
            memory_format=torch.channels_last,
        )
2415

2416

2417
# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
2418
def check_dim_size(tensor, dim, dim_size, size):
2419
    torch._check(
2420
        tensor.dim() == dim and tensor.shape[dim_size] == size,
2421
        lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
2422
        + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
2423
    )
2424

2425

2426
@register_meta(aten.avg_pool2d.default)
def meta_avg_pool2d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Meta kernel for ``aten.avg_pool2d``.

    ``kernel_size`` and ``padding`` take one or two ints. ``stride`` takes
    zero, one or two ints; an empty stride falls back to the kernel size.
    Returns an uninitialized ``(C, oH, oW)`` tensor for 3-D input, or
    ``(N, C, oH, oW)`` otherwise, in the input's suggested memory format.
    ``count_include_pad`` is not used here.
    """

    def as_pair(label, values):
        # One value is reused for both H and W; two values are (H, W).
        torch._check(
            len(values) in [1, 2],
            lambda: f"avg_pool2d: {label} must either be a single int, or a tuple of two ints",
        )
        return values[0], values[-1]

    kH, kW = as_pair("kernel_size", kernel_size)

    torch._check(
        len(stride) in [0, 1, 2],
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    dH, dW = (kH, kW) if len(stride) == 0 else as_pair("stride", stride)

    padH, padW = as_pair("padding", padding)

    torch._check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nbatch = input.size(-4) if input.dim() == 4 else 1
    nInputPlane, inputHeight, inputWidth = (
        input.size(-3),
        input.size(-2),
        input.size(-1),
    )

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    memory_format = utils.suggest_memory_format(input)
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    out_size = [nInputPlane, outputHeight, outputWidth]
    if input.dim() != 3:
        out_size.insert(0, nbatch)
    return torch.empty(
        out_size,
        dtype=input.dtype,
        device=input.device,
        memory_format=memory_format,
    )
2501

2502

2503
# from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
def avg_pool2d_backward_shape_check(
    input,
    gradOutput,
    nbatch,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    mem_format,
):
    """Validate the forward pooling geometry, then the shape of ``gradOutput``.

    ``gradOutput`` must have the same rank as ``input``, and its trailing three
    dims must be ``(nInputPlane, outputHeight, outputWidth)``. ``nbatch`` is
    accepted for interface parity but not used.
    """
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    ndim = input.dim()
    # Average pooling keeps the plane count, so grad_output has nInputPlane planes.
    expected_trailing = ((3, nInputPlane), (2, outputHeight), (1, outputWidth))
    for offset_from_end, expected in expected_trailing:
        check_dim_size(gradOutput, ndim, ndim - offset_from_end, expected)
2545

2546

2547
# Don't override the C++ registration.
@register_meta(aten.avg_pool2d_backward.default)
def meta_avg_pool2d_backward(
    gradOutput_,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Meta kernel for ``aten.avg_pool2d_backward``.

    Re-derives the forward output size from ``input`` and the pooling
    parameters, and validates ``gradOutput_`` against it. Returns an
    uninitialized tensor shaped like ``input``, in ``input``'s suggested
    memory format.
    """
    # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func.

    def first_and_last(label, values):
        # A single value applies to both H and W; two values are (H, W).
        torch._check(
            len(values) in (1, 2),
            lambda: f"avg_pool2d: {label} must either be a single int, or a tuple of two ints",
        )
        return values[0], values[-1]

    kH, kW = first_and_last("kernel_size", kernel_size)

    torch._check(
        len(stride) in (0, 1, 2),
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    # An omitted stride falls back to the kernel size.
    dH, dW = (kH, kW) if len(stride) == 0 else (stride[0], stride[-1])

    padH, padW = first_and_last("padding", padding)

    torch._check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    input_size = input.shape
    nbatch = input_size[-4] if input.dim() == 4 else 1
    nInputPlane, inputHeight, inputWidth = (
        input_size[-3],
        input_size[-2],
        input_size[-1],
    )

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    mem_format = utils.suggest_memory_format(input)

    avg_pool2d_backward_shape_check(
        input,
        gradOutput_,
        nbatch,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    return torch.empty(
        input_size,
        dtype=input.dtype,
        device=input.device,
        memory_format=mem_format,
    )
2619

2620

2621
@register_meta(aten.avg_pool3d)
@out_wrapper()
def meta_avg_pool3d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Meta kernel for ``aten.avg_pool3d``.

    ``kernel_size`` and ``padding`` take one or three ints. ``stride`` may also
    be empty, in which case it falls back to the kernel size. ``input`` must be
    4-D ``(C, T, H, W)`` or 5-D ``(N, C, T, H, W)``. ``divisor_override``, when
    given, must be non-zero. Returns an uninitialized tensor of the pooled
    shape. ``count_include_pad`` does not affect the shape and is unused.
    """
    torch._check(
        len(kernel_size) in (1, 3),
        lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
    )
    kT = kernel_size[0]
    kH = kT if len(kernel_size) == 1 else kernel_size[1]
    kW = kT if len(kernel_size) == 1 else kernel_size[2]

    torch._check(
        not stride or len(stride) in (1, 3),
        lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
    )
    dT = kT if not stride else stride[0]
    dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
    dW = kW if not stride else (dT if len(stride) == 1 else stride[2])

    torch._check(
        len(padding) in (1, 3),
        lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
    )
    padT = padding[0]
    padH = padT if len(padding) == 1 else padding[1]
    padW = padT if len(padding) == 1 else padding[2]

    torch._check(
        input.ndim in (4, 5),
        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
    )

    # Test against None explicitly: ``not divisor_override`` is also true for
    # 0, which would let the forbidden zero divisor through.
    torch._check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nbatch = input.size(0)
    nslices = input.size(-4)
    itime = input.size(-3)
    iheight = input.size(-2)
    iwidth = input.size(-1)

    otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
    oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
    owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)

    pool3d_shape_check(
        input,
        nslices,
        kT,
        kH,
        kW,
        dT,
        dH,
        dW,
        padT,
        padH,
        padW,
        1,
        1,
        1,
        itime,
        iheight,
        iwidth,
        otime,
        oheight,
        owidth,
        "avg_pool3d()",
        check_input_size=True,
    )

    if input.ndim == 4:
        return input.new_empty((nslices, otime, oheight, owidth))
    else:
        return input.new_empty((nbatch, nslices, otime, oheight, owidth))
2705

2706

2707
@register_meta(aten.avg_pool3d_backward)
@out_wrapper("grad_input")
def meta_avg_pool3d_backward(
    grad_output,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Meta kernel for ``aten.avg_pool3d_backward``.

    Parses the pooling parameters the same way as the forward kernel,
    recomputes the forward output size, and validates ``grad_output`` against
    it via ``avg_pool3d_backward_shape_check``. Returns an uninitialized tensor
    shaped like ``input``.
    """
    torch._check(
        len(kernel_size) in (1, 3),
        lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
    )
    kT = kernel_size[0]
    kH = kT if len(kernel_size) == 1 else kernel_size[1]
    kW = kT if len(kernel_size) == 1 else kernel_size[2]

    torch._check(
        not stride or len(stride) in (1, 3),
        lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
    )
    dT = kT if not stride else stride[0]
    dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
    dW = kW if not stride else (dT if len(stride) == 1 else stride[2])

    torch._check(
        len(padding) in (1, 3),
        lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
    )
    padT = padding[0]
    padH = padT if len(padding) == 1 else padding[1]
    padW = padT if len(padding) == 1 else padding[2]

    torch._check(
        input.ndim in (4, 5),
        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
    )

    # Test against None explicitly: ``not divisor_override`` is also true for
    # 0, which would let the forbidden zero divisor through.
    torch._check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nslices = input.size(-4)
    itime = input.size(-3)
    iheight = input.size(-2)
    iwidth = input.size(-1)

    otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
    oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
    owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)

    avg_pool3d_backward_shape_check(
        input,
        grad_output,
        nslices,
        kT,
        kH,
        kW,
        dT,
        dH,
        dW,
        padT,
        padH,
        padW,
        itime,
        iheight,
        iwidth,
        otime_for_shape_check,
        oheight_for_shape_check,
        owidth_for_shape_check,
        "avg_pool3d_backward()",
    )

    return input.new_empty(input.shape)
2785

2786

2787
@register_meta(aten._adaptive_avg_pool2d.default)
2788
def meta_adaptive_avg_pool2d(self, output_size):
2789
    torch._check(
2790
        self.ndim == 3 or self.ndim == 4,
2791
        lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
2792
    )
2793
    output_shape = self.shape[:-2] + tuple(output_size)
2794
    memory_format = utils.suggest_memory_format(self)
2795
    # need to set memory_format to preserve the memory format of the input
2796
    # channel last input should have channel last output
2797
    return torch.empty(
2798
        output_shape,
2799
        dtype=self.dtype,
2800
        device=self.device,
2801
        memory_format=memory_format,
2802
    )
2803

2804

2805
@register_meta(aten._adaptive_avg_pool3d.default)
2806
def meta_adaptive_avg_pool3d(self, output_size):
2807
    torch._check(
2808
        self.ndim == 4 or self.ndim == 5,
2809
        lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
2810
    )
2811
    return self.new_empty(self.shape[:-3] + tuple(output_size))
2812

2813

2814
@register_meta(aten._adaptive_avg_pool2d_backward.default)
2815
def meta__adaptive_avg_pool2d_backward(grad_out, self):
2816
    ndim = grad_out.ndim
2817
    for i in range(1, ndim):
2818
        torch._check(
2819
            grad_out.size(i) > 0,
2820
            lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
2821
                      size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
2822
        )
2823
    torch._check(
2824
        ndim == 3 or ndim == 4,
2825
        lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
2826
    )
2827
    torch._check(
2828
        self.dtype == grad_out.dtype,
2829
        lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
2830
    )
2831
    memory_format = torch.contiguous_format
2832
    if is_channels_last(self):
2833
        memory_format = torch.channels_last
2834
    return self.new_empty(self.shape).to(memory_format=memory_format)
2835

2836

2837
@register_meta(aten._adaptive_avg_pool3d_backward)
2838
@out_wrapper("grad_input")
2839
def meta__adaptive_avg_pool3d_backward(grad_output, self):
2840
    _adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool3d_backward")
2841
    return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
2842

2843

2844
def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str):
2845
    ndim = grad_output.ndim
2846
    for i in range(1, ndim):
2847
        torch._check(
2848
            grad_output.size(i) > 0,
2849
            lambda: (
2850
                f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, "
2851
                f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty"
2852
            ),
2853
        )
2854

2855

2856
@register_meta(aten.adaptive_max_pool2d)
2857
@out_wrapper("out", "indices")
2858
def meta_adaptive_max_pool2d(input, output_size):
2859
    ndim = input.ndim
2860
    torch._check(
2861
        ndim in (3, 4),
2862
        lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}",
2863
    )
2864
    for i in range(1, ndim):
2865
        torch._check(
2866
            input.size(i) > 0,
2867
            lambda: (
2868
                f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
2869
                f"but input has sizes {input.shape} with dimension {i} being empty"
2870
            ),
2871
        )
2872

2873
    torch._check(
2874
        len(output_size) == 2,
2875
        lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2",
2876
    )
2877

2878
    dimH = 1
2879
    sizeB = 1
2880
    sizeD = 0
2881

2882
    if input.ndim == 4:
2883
        sizeB = input.size(0)
2884
        dimH += 1
2885

2886
    sizeD = input.size(dimH - 1)
2887
    osizeH, osizeW = output_size
2888

2889
    if input.ndim == 3:
2890
        out_shape = (sizeD, osizeH, osizeW)
2891
        out = input.new_empty(out_shape)
2892
        indices = input.new_empty(out_shape, dtype=torch.int64)
2893
        return out, indices
2894
    else:
2895
        out_shape = (sizeB, sizeD, osizeH, osizeW)  # type: ignore[assignment]
2896
        memory_format = utils.suggest_memory_format(input)
2897
        out = input.new_empty(out_shape).to(memory_format=memory_format)
2898
        indices = input.new_empty(out_shape, dtype=torch.int64).to(
2899
            memory_format=memory_format
2900
        )
2901
        return out, indices
2902

2903

2904
@register_meta(aten.adaptive_max_pool2d_backward)
2905
@out_wrapper("grad_input")
2906
def meta_adaptive_max_pool2d_backward(grad_output, input, indices):
2907
    ndim = grad_output.ndim
2908
    torch._check(
2909
        ndim in (3, 4),
2910
        lambda: f"adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}",
2911
    )
2912

2913
    _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward")
2914

2915
    torch._check(
2916
        input.dtype == grad_output.dtype,
2917
        lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}",
2918
    )
2919

2920
    memory_format = utils.suggest_memory_format(input)
2921
    return input.new_empty(input.shape).to(memory_format=memory_format)
2922

2923

2924
@register_meta(aten.adaptive_max_pool3d)
2925
@out_wrapper("out", "indices")
2926
def meta_adaptive_max_pool3d(input, output_size):
2927
    ndim = input.ndim
2928
    torch._check(
2929
        ndim in (4, 5),
2930
        lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}",
2931
    )
2932
    for i in range(1, ndim):
2933
        torch._check(
2934
            input.size(i) > 0,
2935
            lambda: (
2936
                f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
2937
                f"but input has sizes {input.shape} with dimension {i} being empty"
2938
            ),
2939
        )
2940

2941
    torch._check(
2942
        len(output_size) == 3,
2943
        lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3",
2944
    )
2945

2946
    dimD = 0
2947
    sizeB = 1
2948
    sizeD = 0
2949

2950
    if ndim == 5:
2951
        sizeB = input.size(0)
2952
        dimD += 1
2953

2954
    sizeD = input.size(dimD)
2955
    osizeT, osizeH, osizeW = output_size
2956

2957
    if ndim == 4:
2958
        out_shape = (sizeD, osizeT, osizeH, osizeW)
2959
    else:
2960
        out_shape = (sizeB, sizeD, osizeT, osizeH, osizeW)  # type: ignore[assignment]
2961

2962
    out = input.new_empty(out_shape)
2963
    indices = input.new_empty(out_shape, dtype=torch.int64)
2964

2965
    return out, indices
2966

2967

2968
@register_meta(aten.adaptive_max_pool3d_backward)
2969
@out_wrapper("grad_input")
2970
def meta_adaptive_max_pool3d_backward(grad_output, input, indices):
2971
    _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward")
2972
    return input.new_empty(input.shape)
2973

2974

2975
@register_meta(aten.repeat_interleave.Tensor)
2976
def meta_repeat_interleave_Tensor(repeats, output_size=None):
2977
    if output_size is None:
2978
        raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
2979
    return repeats.new_empty(output_size)
2980

2981

2982
@register_meta([aten.complex.default, aten.complex.out])
2983
@out_wrapper()
2984
def meta_complex(real, imag):
2985
    assert real.dtype.is_floating_point
2986
    assert imag.dtype.is_floating_point
2987
    out_shape = _broadcast_shapes(real.shape, imag.shape)
2988
    return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
2989

2990

2991
@register_meta([aten.nonzero_static.default, aten.nonzero_static.out])
2992
@out_wrapper()
2993
def nonzero_static(self, *, size: int, fill_value: int = -1):
2994
    return self.new_empty((size, self.dim()), dtype=torch.long)
2995

2996

2997
@register_meta([aten.index.Tensor, aten._unsafe_index.Tensor])
2998
def meta_index_Tensor(self, indices):
2999
    torch._check(bool(indices), lambda: "at least one index must be provided")
3000
    # aten::index is the internal advanced indexing implementation
3001
    # checkIndexTensorTypes and expandTensors
3002
    result: List[Optional[Tensor]] = []
3003
    for i, index in enumerate(indices):
3004
        if index is not None:
3005
            torch._check(
3006
                index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
3007
                lambda: "tensors used as indices must be long, int, byte or bool tensors",
3008
            )
3009
            if index.dtype in [torch.int8, torch.bool]:
3010
                nonzero = index.nonzero()
3011
                k = len(result)
3012
                torch._check_index(
3013
                    k + index.ndim <= self.ndim,
3014
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
3015
                )
3016
                for j in range(index.ndim):
3017
                    torch._check_index(
3018
                        index.shape[j] == self.shape[k + j],
3019
                        lambda: f"The shape of the mask {index.shape} at index {i} "
3020
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
3021
                    )
3022
                    result.append(nonzero.select(1, j))
3023
            else:
3024
                result.append(index)
3025
        else:
3026
            result.append(index)
3027
    indices = result
3028
    torch._check(
3029
        len(indices) <= self.ndim,
3030
        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
3031
    )
3032
    # expand_outplace
3033
    import torch._refs as refs  # avoid import cycle in mypy
3034

3035
    indices = list(refs._maybe_broadcast(*indices))
3036
    # add missing null tensors
3037
    while len(indices) < self.ndim:
3038
        indices.append(None)
3039

3040
    # hasContiguousSubspace
3041
    #   true if all non-null tensors are adjacent
3042
    # See:
3043
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
3044
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
3045
    state = 0
3046
    has_contiguous_subspace = False
3047
    for index in indices:
3048
        if state == 0:
3049
            if index is not None:
3050
                state = 1
3051
        elif state == 1:
3052
            if index is None:
3053
                state = 2
3054
        else:
3055
            if index is not None:
3056
                break
3057
    else:
3058
        has_contiguous_subspace = True
3059

3060
    # transposeToFront
3061
    # This is the logic that causes the newly inserted dimensions to show up
3062
    # at the beginning of the tensor, if they're not contiguous
3063
    if not has_contiguous_subspace:
3064
        dims = []
3065
        transposed_indices = []
3066
        for i, index in enumerate(indices):
3067
            if index is not None:
3068
                dims.append(i)
3069
                transposed_indices.append(index)
3070
        for i, index in enumerate(indices):
3071
            if index is None:
3072
                dims.append(i)
3073
                transposed_indices.append(index)
3074
        self = self.permute(dims)
3075
        indices = transposed_indices
3076

3077
    # AdvancedIndex::AdvancedIndex
3078
    # Now we can assume the indices have contiguous subspace
3079
    # This is simplified from AdvancedIndex which goes to more effort
3080
    # to put the input and indices in a form so that TensorIterator can
3081
    # take them.  If we write a ref for this, probably that logic should
3082
    # get implemented
3083
    before_shape: List[int] = []
3084
    after_shape: List[int] = []
3085
    replacement_shape: List[int] = []
3086
    for dim, index in enumerate(indices):
3087
        if index is None:
3088
            if replacement_shape:
3089
                after_shape.append(self.shape[dim])
3090
            else:
3091
                before_shape.append(self.shape[dim])
3092
        else:
3093
            replacement_shape = list(index.shape)
3094
    return self.new_empty(before_shape + replacement_shape + after_shape)
3095

3096

3097
@register_meta([aten.convolution_backward.default])
3098
def meta_convolution_backward(
3099
    grad_output_,
3100
    input_,
3101
    weight_,
3102
    bias_sizes_opt,
3103
    stride,
3104
    padding,
3105
    dilation,
3106
    transposed,
3107
    output_padding,
3108
    groups,
3109
    output_mask,
3110
):
3111
    # High level logic taken from slow_conv3d_backward_cpu which should
3112
    # be representative of all convolution_backward impls
3113
    backend_grad_input = None
3114
    backend_grad_weight = None
3115
    backend_grad_bias = None
3116

3117
    if output_mask[0]:
3118
        backend_grad_input = grad_output_.new_empty(input_.size())
3119
    if output_mask[1]:
3120
        backend_grad_weight = grad_output_.new_empty(weight_.size())
3121
    if output_mask[2]:
3122
        backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
3123

3124
    return (backend_grad_input, backend_grad_weight, backend_grad_bias)
3125

3126

3127
@register_meta([aten.addbmm.default, aten.addbmm.out])
3128
@out_wrapper()
3129
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
3130
    dim1 = batch1.size(1)
3131
    dim2 = batch2.size(2)
3132
    self = self.expand((dim1, dim2))
3133
    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
3134
    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
3135
    torch._check(
3136
        batch1.size(0) == batch2.size(0),
3137
        lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
3138
    )
3139
    torch._check(
3140
        batch1.size(2) == batch2.size(1),
3141
        lambda: (
3142
            f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
3143
            f"and {batch2.size(1)}x{batch2.size(2)})"
3144
        ),
3145
    )
3146
    torch._check(
3147
        self.size(0) == dim1 and self.size(1) == dim2,
3148
        lambda: "self tensor does not match matmul output shape",
3149
    )
3150
    return self.new_empty(self.size())
3151

3152

3153
@register_meta([aten._fused_adam_.default, aten._fused_adamw_.default])
3154
def meta__fused_adam_(
3155
    self,
3156
    grads,
3157
    exp_avgs,
3158
    exp_avg_sqs,
3159
    max_exp_avg_sqs,
3160
    state_steps,
3161
    *,
3162
    lr,
3163
    beta1,
3164
    beta2,
3165
    weight_decay,
3166
    eps,
3167
    amsgrad,
3168
    maximize,
3169
    grad_scale=None,
3170
    found_inf=None,
3171
):
3172
    for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
3173
        torch._check(
3174
            isinstance(l, List),
3175
            lambda: f"exponent must be a tensor list but got {type(l)}",
3176
        )
3177

3178

3179
@register_meta([aten._fused_adam.default])
3180
def meta__fused_adam(
3181
    self,
3182
    grads,
3183
    exp_avgs,
3184
    exp_avg_sqs,
3185
    max_exp_avg_sqs,
3186
    state_steps,
3187
    *,
3188
    lr,
3189
    beta1,
3190
    beta2,
3191
    weight_decay,
3192
    eps,
3193
    amsgrad,
3194
    maximize,
3195
    grad_scale=None,
3196
    found_inf=None,
3197
):
3198
    for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
3199
        torch._check(
3200
            isinstance(l, List),
3201
            lambda: f"exponent must be a tensor list but got {type(l)}",
3202
        )
3203

3204
    def empty_like_list(tensor_list):
3205
        return [torch.empty_like(t) for t in tensor_list]
3206

3207
    return (
3208
        empty_like_list(self),
3209
        empty_like_list(grads),
3210
        empty_like_list(exp_avgs),
3211
        empty_like_list(exp_avg_sqs),
3212
        empty_like_list(max_exp_avg_sqs),
3213
    )
3214

3215

3216
@register_meta([aten._int_mm])
3217
@out_wrapper()
3218
def meta__int_mm(a, b):
3219
    torch._check(a.dim() == 2, lambda: "a must be a 2D tensor")
3220
    torch._check(b.dim() == 2, lambda: "b must be a 2D tensor")
3221
    torch._check(
3222
        a.dtype is torch.int8,
3223
        lambda: f"expected self to be int8, got {a.dtype}",
3224
    )
3225
    torch._check(
3226
        b.dtype is torch.int8,
3227
        lambda: f"expected mat2 to be int8, got {b.dtype}",
3228
    )
3229
    torch._check(
3230
        a.size(1) == b.size(0),
3231
        lambda: (
3232
            f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} "
3233
            f"and {b.size(0)}x{b.size(1)})"
3234
        ),
3235
    )
3236
    return a.new_empty((a.size(0), b.size(1)), dtype=torch.int32)
3237

3238

3239
@register_meta([aten._convert_weight_to_int4pack])
3240
def meta__convert_weight_to_int4pack(w, inner_k_tiles):
3241
    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
3242
    torch._check(
3243
        w.dtype is torch.uint8,
3244
        lambda: f"expected w to be uint8, got {w.dtype}",
3245
    )
3246
    n = w.size(0)
3247
    k = w.size(1) * 2  # w is [n][k / 2] uint8
3248
    return w.new_empty(
3249
        (
3250
            n // 8,
3251
            k // (inner_k_tiles * 16),
3252
            32,
3253
            inner_k_tiles // 2,
3254
        ),
3255
        dtype=torch.int32,
3256
    )
3257

3258

3259
@register_meta([aten._weight_int4pack_mm])
3260
def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
3261
    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
3262
    torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
3263
    torch._check(
3264
        x.dtype in [torch.float32, torch.float16, torch.bfloat16],
3265
        lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
3266
    )
3267
    torch._check(
3268
        w.dtype is torch.int32,
3269
        lambda: f"expected w to be int32, got {w.dtype}",
3270
    )
3271
    return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
3272

3273

3274
@register_meta([aten._weight_int8pack_mm])
3275
def meta__weight_int8pack_mm(x, w, q_scales):
3276
    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
3277
    torch._check(
3278
        x.dtype in [torch.float32, torch.float16, torch.bfloat16],
3279
        lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
3280
    )
3281
    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
3282
    torch._check(
3283
        w.dtype is torch.int8,
3284
        lambda: f"expected w to be int8, got {w.dtype}",
3285
    )
3286
    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
3287

3288

3289
@register_meta(aten._cdist_forward.default)
3290
def meta_cdist_forward(x1, x2, p, compute_mode):
3291
    torch._check(
3292
        x1.dim() >= 2,
3293
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
3294
    )
3295
    torch._check(
3296
        x2.dim() >= 2,
3297
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
3298
    )
3299
    torch._check(
3300
        x1.size(-1) == x2.size(-1),
3301
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
3302
    )
3303
    torch._check(
3304
        utils.is_float_dtype(x1.dtype),
3305
        lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
3306
    )
3307
    torch._check(
3308
        utils.is_float_dtype(x2.dtype),
3309
        lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
3310
    )
3311
    torch._check(p >= 0, lambda: "cdist only supports non-negative p values")
3312
    torch._check(
3313
        compute_mode in (None, 1, 2),
3314
        lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
3315
    )
3316
    r1 = x1.size(-2)
3317
    r2 = x2.size(-2)
3318
    batch_tensor1 = x1.shape[:-2]
3319
    batch_tensor2 = x2.shape[:-2]
3320
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
3321
    output_shape.extend([r1, r2])
3322
    return x1.new_empty(output_shape)
3323

3324

3325
@register_meta(aten._cdist_backward)
3326
@out_wrapper()
3327
def meta_cdist_backward(grad, x1, x2, p, cdist):
3328
    c1 = x1.shape[-1]
3329
    r1 = x1.shape[-2]
3330
    r2 = x2.shape[-2]
3331
    batch_tensor1 = x1.shape[:-2]
3332
    batch_tensor2 = x2.shape[:-2]
3333
    expand_batch_portion = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
3334
    tensor1_expand_size = expand_batch_portion.copy()
3335
    tensor1_expand_size.extend([r1, c1])
3336
    batch_product = math.prod(expand_batch_portion)
3337
    if r1 == 0 or r2 == 0 or c1 == 0 or batch_product == 0:
3338
        return torch.zeros_like(x1)
3339
    if tensor1_expand_size != list(x1.shape):
3340
        x1 = x1.expand(tensor1_expand_size)
3341
    return torch.empty_like(x1, memory_format=torch.contiguous_format)
3342

3343

3344
# NB: This meta function accepts non-meta arguments!  When this behavior
3345
# was originally introduced this was accidental, but it is now load bearing
3346
# as people are using this so that they can conveniently test code involving
3347
# embeddings (feeding CPU tensor inputs with meta device EmbeddingBag module)
3348
@register_meta(aten._embedding_bag.default)
3349
def meta_embedding_bag(
3350
    weight,
3351
    indices,
3352
    offsets,
3353
    scale_grad_by_freq=False,
3354
    mode=0,
3355
    sparse=False,
3356
    per_sample_weights=None,
3357
    include_last_offset=False,
3358
    padding_idx=-1,
3359
):
3360
    torch._check(
3361
        indices.dtype in (torch.long, torch.int),
3362
        lambda: f"expected indices to be long or int, got {indices.dtype}",
3363
    )
3364
    torch._check(
3365
        offsets.dtype in (torch.long, torch.int),
3366
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
3367
    )
3368
    torch._check(
3369
        utils.is_float_dtype(weight.dtype),
3370
        lambda: f"expected weight to be floating point type, got {weight.dtype}",
3371
    )
3372

3373
    num_bags = offsets.size(0)
3374
    if include_last_offset:
3375
        torch._check(
3376
            num_bags >= 1,
3377
            lambda: "include_last_offset: numBags should be at least 1",
3378
        )
3379
        num_bags -= 1
3380

3381
    output = weight.new_empty(num_bags, weight.size(1))
3382
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
3383

3384
    if per_sample_weights is not None:
3385
        torch._check(
3386
            mode == MODE_SUM,
3387
            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
3388
        )
3389
        torch._check(
3390
            per_sample_weights.dtype == weight.dtype,
3391
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
3392
        )
3393
        torch._check(
3394
            per_sample_weights.ndim == 1,
3395
            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
3396
        )
3397
        torch._check(
3398
            per_sample_weights.numel() == indices.numel(),
3399
            lambda: (
3400
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
3401
                f"to be the same as indices.numel() ({indices.numel()})"
3402
            ),
3403
        )
3404

3405
    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
3406
        return (
3407
            is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
3408
        )
3409

3410
    def is_fast_path_index_select(src, output, padding_idx):
3411
        return (
3412
            (src.dtype == torch.float or src.dtype == torch.half)
3413
            and src.stride(1) == 1
3414
            and output.stride(1) == 1
3415
            and padding_idx < 0
3416
        )
3417

3418
    def is_fast_path(src, scale, output, padding_idx):
3419
        if scale is not None:
3420
            return is_fast_path_index_select_scale(src, scale, output, padding_idx)
3421
        else:
3422
            return is_fast_path_index_select(src, output, padding_idx)
3423

3424
    if device_hint(offsets) != "cpu":
3425
        offset2bag = indices.new_empty(indices.size(0))
3426
        bag_size = indices.new_empty(offsets.size())
3427
        if mode == MODE_MAX:
3428
            max_indices = indices.new_empty(num_bags, weight.size(1))
3429
        else:
3430
            max_indices = indices.new_empty(0)
3431
    else:
3432
        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
3433
        if mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum:
3434
            offset2bag = offsets.new_empty(indices.size(0))
3435
        else:
3436
            offset2bag = offsets.new_empty(0)
3437
        bag_size = offsets.new_empty(num_bags)
3438
        # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp
3439
        numBags = offsets.shape[0]
3440
        if mode == MODE_MAX:
3441
            if include_last_offset:
3442
                torch._check(
3443
                    numBags >= 1,
3444
                    lambda: "include_last_offset: numBags should be at least 1",
3445
                )
3446
                numBags -= 1
3447
            max_indices = offsets.new_empty(numBags, weight.shape[1])
3448
        else:
3449
            max_indices = offsets.new_empty(bag_size.size())
3450
    return output, offset2bag, bag_size, max_indices
3451

3452

3453
@register_meta(aten._embedding_bag_forward_only.default)
3454
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
3455
    output, offset2bag, bag_size, max_indices = meta_embedding_bag(
3456
        weight, indices, offsets, *args
3457
    )
3458
    if device_hint(offsets) == "cpu":
3459
        bag_size = offsets.new_empty(offsets.size())
3460
    return output, offset2bag, bag_size, max_indices
3461

3462

3463
def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
3464
    # if specified, dtype takes precedence
3465
    if dtype:
3466
        return dtype
3467

3468
    if input.dtype.is_floating_point or input.dtype.is_complex:
3469
        return input.dtype
3470
    elif promote_int_to_long:
3471
        return torch.long
3472

3473
    return input.dtype
3474

3475

3476
@register_meta([aten.nansum.default, aten.nansum.out])
3477
@out_wrapper()
3478
def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
3479
    output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
3480
    dims = utils.reduction_dims(input.shape, dims)
3481
    output_shape = _compute_reduction_shape(input, dims, keepdim)
3482
    return input.new_empty(output_shape, dtype=output_dtype)
3483

3484

3485
@register_meta([aten.median.default, aten.nanmedian.default])
3486
def meta_median(input):
3487
    output_shape = utils.compute_reduction_output_shape(
3488
        input.shape, tuple(range(input.dim()))
3489
    )
3490
    return input.new_empty(output_shape)
3491

3492

3493
@register_meta(
3494
    [
3495
        aten.median.dim,
3496
        aten.median.dim_values,
3497
        aten.nanmedian.dim,
3498
        aten.nanmedian.dim_values,
3499
        aten.mode.default,
3500
        aten.mode.values,
3501
    ]
3502
)
3503
@out_wrapper("values", "indices")
3504
def meta_median_mode_dim(input, dim=-1, keepdim=False):
3505
    if device_hint(input) == "cuda":
3506
        utils.alert_not_deterministic("median CUDA with indices output")
3507
    dim = utils.reduction_dims(input.shape, (dim,))
3508
    output_shape = _compute_reduction_shape(input, dim, keepdim)
3509
    return (
3510
        input.new_empty(output_shape),
3511
        input.new_empty(output_shape, dtype=torch.long),
3512
    )
3513

3514

3515
@register_meta(aten.logical_not_.default)
3516
def meta_logical_not_(self):
3517
    return self
3518

3519

3520
@register_meta(aten.repeat.default)
3521
def meta_repeat(self, repeats):
3522
    torch._check(
3523
        len(repeats) >= self.dim(),
3524
        lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
3525
    )
3526
    # Add new leading dimensions to the tensor if the
3527
    # number of target dimensions is larger than the
3528
    # number of source dimensions.
3529
    num_new_dimensions = len(repeats) - self.dim()
3530
    padded_size = (1,) * num_new_dimensions + tuple(self.shape)
3531
    target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))]
3532
    return self.new_empty(target_size)
3533

3534

3535
@register_meta(aten.zero_.default)
3536
def meta_zero_(self):
3537
    return self
3538

3539

3540
@register_meta(
3541
    [
3542
        aten.mul_.Scalar,
3543
        aten.div_.Scalar,
3544
        aten.mul_.Tensor,
3545
        aten.div_.Tensor,
3546
        aten.logical_and_.default,
3547
        aten.logical_or_.default,
3548
        aten.logical_xor_.default,
3549
    ],
3550
)
3551
def meta_binop_inplace(self, other):
3552
    if isinstance(other, torch.Tensor):
3553
        check_inplace_broadcast(self.shape, other.shape)
3554
    return self
3555

3556

3557
@register_meta(
3558
    [
3559
        aten.add_.Scalar,
3560
        aten.sub_.Scalar,
3561
        aten.add_.Tensor,
3562
        aten.sub_.Tensor,
3563
    ],
3564
)
3565
def meta_binop_inplace_alpha(self, other, alpha=1):
3566
    if isinstance(other, torch.Tensor):
3567
        check_inplace_broadcast(self.shape, other.shape)
3568
    return self
3569

3570

3571
@register_meta([aten.round.default, aten.round.decimals])
3572
def meta_round(self, **kwargs):
3573
    return elementwise_meta(
3574
        self, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
3575
    )
3576

3577

3578
def shift_dtype_check(fn_name, self, val):
3579
    torch._check(
3580
        utils.is_integer_dtype(self.dtype),
3581
        lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}",
3582
    )
3583
    if isinstance(val, torch.Tensor):
3584
        torch._check(
3585
            utils.is_integer_dtype(val.dtype),
3586
            lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}",
3587
        )
3588
    else:
3589
        torch._check(
3590
            isinstance(val, IntLike),
3591
            lambda: f"{fn_name}: Expected shift value to be an int. Got {val}",
3592
        )
3593

3594

3595
@register_meta([aten.__rshift__.Tensor, aten.__rshift__.Scalar])
3596
def meta_rshifts(self, other):
3597
    shift_dtype_check("rshift", self, other)
3598
    return elementwise_meta(
3599
        self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
3600
    )
3601

3602

3603
@register_meta([aten.__lshift__.Tensor, aten.__lshift__.Scalar])
3604
def meta_lshifts(self, other):
3605
    shift_dtype_check("lshift", self, other)
3606
    return elementwise_meta(
3607
        self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
3608
    )
3609

3610

3611
@register_meta(aten.zero.default)
3612
def meta_zero(self):
3613
    return self.new_empty(self.shape)
3614

3615

3616
@register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
3617
def meta_fill_(self, val):
3618
    return self
3619

3620

3621
@register_meta([aten.fill.Tensor, aten.fill.Scalar])
3622
def meta_fill(self, val):
3623
    return torch.empty_like(self)
3624

3625

3626
@register_meta(aten.relu_.default)
def meta_relu_(self):
    """In-place ``relu_``: metadata is unchanged, so ``self`` is returned."""
    return self
3629

3630

3631
@register_meta([aten.index_put.default, aten._unsafe_index_put.default])
def meta_index_put(self, indices, values, accumulate=False):
    """Out-of-place ``index_put`` / ``_unsafe_index_put``.

    The result has ``self``'s metadata; ``indices``, ``values`` and
    ``accumulate`` are not validated here.
    """
    result = torch.empty_like(self)
    return result
3634

3635

3636
@register_meta(aten.masked_fill_.Scalar)
def meta_masked_fill_(self, mask, value):
    """In-place ``masked_fill_`` with a scalar ``value``.

    The shape compatibility of ``mask`` with ``self`` is delegated to
    ``check_inplace_broadcast``; ``self`` is returned unchanged.
    """
    self_shape, mask_shape = self.shape, mask.shape
    check_inplace_broadcast(self_shape, mask_shape)
    return self
3640

3641

3642
@register_meta(aten._masked_scale.default)
def meta__masked_scale(self, mask, scale):
    """Meta kernel for ``_masked_scale``.

    Returns an empty tensor with ``self``'s size, dtype and device, converted
    to the memory format that ``utils.suggest_memory_format`` picks for
    ``self``. ``mask`` and ``scale`` are not inspected.
    """
    preferred_format = utils.suggest_memory_format(self)
    return self.new_empty(self.size()).to(memory_format=preferred_format)
3648

3649

3650
@register_meta(aten.masked_scatter_)
def meta_masked_scatter_(self, mask, source):
    """In-place ``masked_scatter_``: validate dtypes and return ``self``.

    Args:
        self: destination tensor (modified in place by the real kernel).
        mask: must have dtype ``torch.bool`` or ``torch.uint8``.
        source: must have the same dtype as ``self``.

    Raises:
        RuntimeError: via ``torch._check`` when either dtype check fails.
    """
    torch._check(
        mask.dtype in (torch.bool, torch.uint8), lambda: "Mask must be bool or uint8"
    )
    torch._check(
        self.dtype == source.dtype,
        # Both halves must be f-strings; previously the second half was a
        # plain string and printed the placeholders literally.
        lambda: "masked_scatter: expected self and source to have same "
        f"dtypes but got {self.dtype} and {source.dtype}",
    )
    return self
3661

3662

3663
@register_meta(aten.masked_scatter)
@out_wrapper()
def meta_masked_scatter(self, mask, source):
    """Out-of-place ``masked_scatter``.

    ``self`` and ``mask`` are broadcast together. A contiguous result shaped
    like the broadcast ``self`` is then allocated, and the in-place meta
    kernel is reused for the dtype checks.
    """
    broadcast_self, broadcast_mask = _maybe_broadcast(self, mask)
    result = torch.empty_like(broadcast_self, memory_format=torch.contiguous_format)
    return meta_masked_scatter_(result, broadcast_mask, source)
3669

3670

3671
@register_meta(aten.masked_scatter_backward)
def meta_masked_scatter_backward(self, mask, sizes):
    """Backward of ``masked_scatter``.

    Allocates an empty tensor of shape ``sizes`` with the dtype and device of
    ``self``. ``mask`` is not inspected.
    """
    grad = self.new_empty(sizes)
    return grad
3674

3675

3676
@register_meta(aten.index_put_.default)
def meta_index_put_(self, indices, values, accumulate=False):
    """In-place ``index_put_``: metadata is unchanged, so ``self`` is returned.

    ``indices``, ``values`` and ``accumulate`` are not validated here.
    """
    return self
3679

3680

3681
@register_meta(aten.alias.default)
def meta_alias(self):
    """``alias``: a view of ``self`` with the same shape."""
    same_shape = self.shape
    return self.view(same_shape)
3684

3685

3686
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
3687
    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
3688
    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
3689

3690
    batch1_sizes = batch1.size()
3691
    batch2_sizes = batch2.size()
3692

3693
    bs = batch1_sizes[0]
3694
    contraction_size = batch1_sizes[2]
3695
    res_rows = batch1_sizes[1]
3696
    res_cols = batch2_sizes[2]
3697
    output_size = (bs, res_rows, res_cols)
3698

3699
    torch._check(
3700
        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
3701
        lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
3702
        f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
3703
    )
3704

3705
    # TODO: handle out
3706

3707
    output = batch2.new_empty(output_size)
3708

3709
    if not is_bmm and self_baddbmm is not None:
3710
        torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
3711
        torch._check(
3712
            self_baddbmm.size() == output_size,
3713
            lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
3714
        )
3715

3716
    return output
3717

3718

3719
@register_meta(aten.bmm.default)
def meta_bmm(self, mat2):
    """Meta kernel for ``bmm``; checks and allocation are shared with ``baddbmm``."""
    return common_meta_baddbmm_bmm(self, mat2, is_bmm=True)
3722

3723

3724
def div_rtn(x, y):
    """Integer division of ``x`` by ``y`` rounded toward negative infinity.

    Python's ``//`` already floors. The remainder-sign correction below
    mirrors a truncating-division formulation and is kept as written; it
    only runs when the remainder is non-zero.
    """
    quotient = x // y
    remainder = x % y
    if remainder == 0:
        return quotient
    # WARNING: explicit bool conversion here is necessary;
    # would be fixed by SymBool
    if bool(remainder < 0) != bool(y < 0):
        return quotient - 1
    return quotient
3732

3733

3734
def pooling_output_shape_pad_lr(
    inputSize,
    kernelSize,
    pad_l,
    pad_r,
    stride,
    dilation,
    ceil_mode,
):
    """Pooling output length along one dimension with left/right padding.

    The effective kernel extent is ``dilation * (kernelSize - 1) + 1``.
    With ``ceil_mode`` the division by ``stride`` is rounded up. The last
    window is then dropped if it would start at or past
    ``inputSize + pad_l``, i.e. entirely inside the right padding.
    """
    span = inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1
    if ceil_mode:
        # Bias the numerator so that the floor division rounds up.
        span += stride - 1
    out = div_rtn(span, stride) + 1
    if ceil_mode and (out - 1) * stride >= inputSize + pad_l:
        out -= 1
    return out
3759

3760

3761
def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
    """Pooling output length along one dimension with symmetric padding.

    Validates that ``stride`` is non-zero and that ``pad`` is in
    ``[0, effective_kernel // 2]``. It then defers to
    ``pooling_output_shape_pad_lr`` with ``pad`` on both sides.
    """
    torch._check(stride != 0, lambda: "stride should not be zero")
    torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
    max_pad = ((kernelSize - 1) * dilation + 1) // 2
    torch._check(
        pad <= max_pad,
        lambda: (
            f"pad should be at most half of effective kernel size, but got pad={pad}, "
            f"kernel_size={kernelSize} and dilation={dilation}"
        ),
    )
    left = right = pad
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, left, right, stride, dilation, ceil_mode
    )
3774

3775

3776
def pool2d_shape_check(
3777
    input,
3778
    kH,
3779
    kW,
3780
    dH,
3781
    dW,
3782
    padH,
3783
    padW,
3784
    dilationH,
3785
    dilationW,
3786
    nInputPlane,
3787
    inputHeight,
3788
    inputWidth,
3789
    outputHeight,
3790
    outputWidth,
3791
    memory_format,
3792
):
3793
    ndim = input.dim()
3794
    nOutputPlane = nInputPlane
3795

3796
    torch._check(
3797
        kW > 0 and kH > 0,
3798
        lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
3799
    )
3800
    torch._check(
3801
        dW > 0 and dH > 0,
3802
        lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}",
3803
    )
3804
    torch._check(
3805
        dilationH > 0 and dilationW > 0,
3806
        lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
3807
    )
3808

3809
    valid_dims = input.size(1) != 0 and input.size(2) != 0
3810

3811
    if memory_format == torch.channels_last:
3812
        torch._check(
3813
            ndim == 4 and valid_dims and input.size(3) != 0,
3814
            lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
3815
            " with optional 0 dim batch size for input, but got: {input.size()}",
3816
        )
3817
    else:
3818
        torch._check(
3819
            (ndim == 3 and input.size(0) != 0 and valid_dims)
3820
            or (ndim == 4 and valid_dims and input.size(3) != 0),
3821
            lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
3822
        )
3823

3824
    torch._check(
3825
        kW // 2 >= padW and kH // 2 >= padH,
3826
        lambda: "pad should be smaller than or equal to half of kernel size, but got "
3827
        f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
3828
    )
3829

3830
    torch._check(
3831
        outputWidth >= 1 and outputHeight >= 1,
3832
        lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
3833
        f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
3834
        "Output size is too small",
3835
    )
3836

3837

3838
def pool3d_shape_check(
3839
    input: Tensor,
3840
    nslices: int,
3841
    kT: int,
3842
    kH: int,
3843
    kW: int,
3844
    dT: int,
3845
    dH: int,
3846
    dW: int,
3847
    pT: int,
3848
    pH: int,
3849
    pW: int,
3850
    dilationT: int,
3851
    dilationH: int,
3852
    dilationW: int,
3853
    itime: int,
3854
    iheight: int,
3855
    iwidth: int,
3856
    otime: int,
3857
    oheight: int,
3858
    owidth: int,
3859
    fn_name: str,
3860
    check_input_size: bool = False,
3861
):
3862
    ndim = input.ndim
3863

3864
    torch._check(
3865
        kT > 0 and kW > 0 and kH > 0,
3866
        lambda: (
3867
            f"kernel size should be greater than zero, but got "
3868
            f"kT: {kT}, kH: {kH}, kW: {kW}"
3869
        ),
3870
    )
3871
    torch._check(
3872
        dT > 0 and dW > 0 and dH > 0,
3873
        lambda: (
3874
            f"stride should be greater than zero, but got "
3875
            f"dT: {dT}, dH: {dH}, dW: {dW}"
3876
        ),
3877
    )
3878
    torch._check(
3879
        dilationT > 0 and dilationW > 0 and dilationH > 0,
3880
        lambda: (
3881
            f"dilation should be greater than zero, but got "
3882
            f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}"
3883
        ),
3884
    )
3885

3886
    torch._check(
3887
        ndim in (4, 5),
3888
        lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}",
3889
    )
3890

3891
    for i in range(ndim):
3892
        if ndim == 5 and i == 0:
3893
            # size of batch-dim can be 0.
3894
            continue
3895
        torch._check(
3896
            input.size(i) > 0,
3897
            lambda: (
3898
                f"{fn_name}: Expected input's non-batch dimensions to have positive length,"
3899
                f" but input has a shape of {input.shape}"
3900
                f" and non-batch dimension {input.size(i)} has length zero!"
3901
            ),
3902
        )
3903

3904
    if check_input_size:  # AveragePool3d
3905
        torch._check(
3906
            itime >= kT and iheight >= kH and iwidth >= kW,
3907
            lambda: (
3908
                f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than "
3909
                f"kernel size (kT: {kT} kH: {kH} kW: {kW})"
3910
            ),
3911
        )
3912

3913
    torch._check(
3914
        kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH,
3915
        lambda: (
3916
            f"pad should be smaller than or equal to half of kernel size, but got "
3917
            f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}"
3918
        ),
3919
    )
3920

3921
    torch._check(
3922
        otime >= 1 and owidth >= 1 and oheight >= 1,
3923
        lambda: (
3924
            f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). "
3925
            f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). "
3926
            f"Output size is too small"
3927
        ),
3928
    )
3929

3930

3931
def max_pool3d_backward_shape_check(
    input,
    grad_output,
    indices,
    nslices,
    kT,
    kH,
    kW,
    dT,
    dH,
    dW,
    pT,
    pH,
    pW,
    dilationT,
    dilationH,
    dilationW,
    itime,
    iheight,
    iwidth,
    otime,
    oheight,
    owidth,
    fn_name,
):
    """Shape checks for the ``max_pool3d_with_indices`` backward pass.

    Runs the forward ``pool3d_shape_check``. It then calls ``check_dim_size``
    on dims ``ndim-4 .. ndim-1`` of ``grad_output`` (all four first), and then
    of ``indices``, against ``(nslices, otime, oheight, owidth)``, where
    ``ndim`` is ``input.ndim``.
    """
    pool3d_shape_check(
        input,
        nslices,
        kT, kH, kW,
        dT, dH, dW,
        pT, pH, pW,
        dilationT, dilationH, dilationW,
        itime, iheight, iwidth,
        otime, oheight, owidth,
        fn_name,
    )

    ndim = input.ndim
    expected_trailing = (nslices, otime, oheight, owidth)
    for tensor in (grad_output, indices):
        for offset, size in zip((4, 3, 2, 1), expected_trailing):
            check_dim_size(tensor, ndim, ndim - offset, size)
3991

3992

3993
def avg_pool3d_backward_shape_check(
    input: Tensor,
    grad_output: Tensor,
    nslices: int,
    kT: int,
    kH: int,
    kW: int,
    dT: int,
    dH: int,
    dW: int,
    pT: int,
    pH: int,
    pW: int,
    itime: int,
    iheight: int,
    iwidth: int,
    otime: int,
    oheight: int,
    owidth: int,
    fn_name: str,
):
    """Shape checks for the ``avg_pool3d`` backward pass.

    Runs ``pool3d_shape_check`` with dilation fixed to 1 and the
    input-vs-kernel size check enabled. It then calls ``check_dim_size`` on
    dims ``ndim-4 .. ndim-1`` of ``grad_output`` against
    ``(nslices, otime, oheight, owidth)``, where ``ndim`` is ``input.ndim``.
    """
    pool3d_shape_check(
        input,
        nslices,
        kT, kH, kW,
        dT, dH, dW,
        pT, pH, pW,
        1, 1, 1,  # dilationT / dilationH / dilationW
        itime, iheight, iwidth,
        otime, oheight, owidth,
        fn_name,
        check_input_size=True,
    )

    ndim = input.ndim
    expected_trailing = (nslices, otime, oheight, owidth)
    for offset, size in zip((4, 3, 2, 1), expected_trailing):
        check_dim_size(grad_output, ndim, ndim - offset, size)
4045

4046

4047
def max_pool2d_checks_and_compute_shape(
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Validate ``max_pool2d`` arguments and compute its output geometry.

    Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp

    ``kernel_size``, ``padding`` and ``dilation`` take one value (used for
    both H and W) or two. ``stride`` may also be empty, in which case it
    defaults to the kernel size.

    Returns:
        ``(nInputPlane, outputHeight, outputWidth)``, where ``nInputPlane``
        is ``input.size(-3)``.
    """

    def unpack(name, val):
        # One entry means "same for H and W"; two entries are (H, W).
        torch._check(
            len(val) in [1, 2],
            lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        if len(val) == 1:
            return val[0], val[0]
        return val[0], val[1]

    kH, kW = unpack("kernel_size", kernel_size)

    torch._check(
        len(stride) in [0, 1, 2],
        lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    dH, dW = unpack("stride", stride) if len(stride) else (kH, kW)

    padH, padW = unpack("padding", padding)
    dilationH, dilationW = unpack("dilation", dilation)

    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    memory_format = utils.suggest_memory_format(input)
    if memory_format == torch.channels_last:
        torch._check(
            input.dim() == 4,
            lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
        )
    elif memory_format == torch.contiguous_format:
        torch._check(
            input.dim() in [3, 4],
            lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
        )
    else:
        torch._check(
            False,
            lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
        )

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)

    pool2d_shape_check(
        input,
        kH, kW,
        dH, dW,
        padH, padW,
        dilationH, dilationW,
        nInputPlane, inputHeight, inputWidth,
        outputHeight, outputWidth,
        memory_format,
    )

    return nInputPlane, outputHeight, outputWidth
4121

4122

4123
@register_meta(aten.max_pool2d_with_indices_backward.default)
def meta_max_pool2d_with_indices_backward(
    grad_output,
    self,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
    indices,
):
    """Meta kernel for ``max_pool2d_with_indices_backward``.

    Re-derives the forward output geometry from ``self`` and the pooling
    parameters. It requires ``grad_output`` to have ``self``'s dtype and
    checks the last three dims of ``grad_output`` and ``indices`` against
    that geometry. It returns an empty gradient shaped like ``self`` in
    ``self``'s suggested memory format.
    """
    nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
        self, kernel_size, stride, padding, dilation, ceil_mode
    )

    torch._check(
        self.dtype == grad_output.dtype,
        lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
    )

    ndim = self.ndim
    # Max pooling keeps the plane count, so nOutputPlane == nInputPlane.
    expected_trailing = (nInputPlane, outputHeight, outputWidth)
    for tensor in (grad_output, indices):
        for offset, size in zip((3, 2, 1), expected_trailing):
            check_dim_size(tensor, ndim, ndim - offset, size)

    return torch.empty(
        self.shape,
        dtype=self.dtype,
        device=self.device,
        memory_format=utils.suggest_memory_format(self),
    )
4165

4166

4167
@register_meta(aten.max_pool2d_with_indices.default)
def meta_max_pool2d_with_indices(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    dilation=(1,),
    ceil_mode=False,
):
    """Meta kernel for ``max_pool2d_with_indices``.

    Returns ``(output, indices)``, both shaped
    ``[nInputPlane, outH, outW]`` for 3D input, or with the batch size
    prepended otherwise, in ``input``'s suggested memory format. ``output``
    uses ``input``'s dtype and ``indices`` is int64.
    """
    nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape(
        input, kernel_size, stride, padding, dilation, ceil_mode
    )

    out_dims = [nInputPlane, outputHeight, outputWidth]
    if input.dim() != 3:
        nbatch = input.size(-4) if input.dim() == 4 else 1
        out_dims.insert(0, nbatch)

    memory_format = utils.suggest_memory_format(input)

    def _allocate(dtype):
        return torch.empty(
            out_dims,
            dtype=dtype,
            device=input.device,
            memory_format=memory_format,
        )

    return _allocate(input.dtype), _allocate(torch.int64)
4204

4205

4206
@register_meta(aten.fractional_max_pool2d.default)
def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
    """Meta kernel for ``fractional_max_pool2d``.

    Args:
        self: 3D or 4D input whose last three dims must be non-empty.
        kernel_size: two-element pooling window; entry 0 is the height and
            entry 1 the width.
        output_size: two-element output size (height, width).
        random_samples: 3D tensor with ``self``'s dtype. ``size(0)`` must be at
            least the input batch size (1 for 3D input), ``size(1)`` must equal
            ``self.size(-3)`` and ``size(2)`` must be 2.

    Returns:
        ``(output, indices)``: empty tensors on ``self``'s device shaped
        ``[batch, channels, out_h, out_w]``, or ``[channels, out_h, out_w]`` for
        3D input. ``output`` uses ``self``'s dtype and ``indices`` is int64.

    Raises:
        RuntimeError: via ``torch._check`` when any check fails.
    """
    torch._check(
        self.ndim in (3, 4),
        lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self.ndim}",
    )
    ndim = self.ndim

    # torch._check expects `message` to be a callable, so every message is a
    # lambda. A plain string here would turn a failed check into a TypeError.
    for dim in range(ndim - 3, ndim):
        torch._check(
            self.size(dim) > 0,
            lambda: "fractional_max_pool2d: Expected input to have non-zero "
            f"size for non-batch dimensions, but got {self.size()} with dimension {dim} empty",
        )

    # the check and message are out of sync, but this matches the structured meta
    torch._check(
        len(kernel_size) == 2,
        lambda: "fractional_max_pool2d: kernel_size must "
        "either be a single int or tuple of Ints",
    )
    torch._check(
        len(output_size) == 2,
        lambda: "fractional_max_pool2d: output_size must "
        "either be a single int or tuple of Ints",
    )

    input_channels = self.size(-3)
    input_height = self.size(-2)
    input_width = self.size(-1)
    input_batch = self.size(0) if ndim == 4 else 1

    torch._check(
        self.dtype == random_samples.dtype,
        lambda: "Expect _random_samples to have the same dtype as input",
    )
    torch._check(
        random_samples.ndim == 3,
        lambda: f"Expect _random samples to have 3 dimensions got, {random_samples.ndim}",
    )

    n = random_samples.size(0)
    c = random_samples.size(1)
    d = random_samples.size(2)
    torch._check(
        n >= input_batch,
        lambda: "Expect _random_samples.size(0) no less than input batch size.",
    )
    torch._check(
        c == input_channels,
        lambda: "Expect _random_samples.size(1) equals to input channel size.",
    )
    torch._check(d == 2, lambda: f"Expect _random_samples.size(2) equals to 2 got {d}.")

    torch._check(
        output_size[0] + kernel_size[0] - 1 <= input_height,
        lambda: f"fractional_max_pool2d: kernel height {kernel_size[0]} is too large relative to input height {input_height}",
    )
    torch._check(
        output_size[1] + kernel_size[1] - 1 <= input_width,
        lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}",
    )

    if ndim == 4:
        size = [input_batch, input_channels, output_size[0], output_size[1]]
    else:
        size = [input_channels, output_size[0], output_size[1]]

    return (
        torch.empty(
            size,
            dtype=self.dtype,
            device=self.device,
        ),
        torch.empty(
            size,
            dtype=torch.int64,
            device=self.device,
        ),
    )
4289

4290

4291
@register_meta(aten.max_unpool2d)
@out_wrapper()
def meta_max_unpool2d(self, indices, output_size):
    """Meta kernel for ``max_unpool2d``.

    Flags the op as non-deterministic. It then validates that ``indices`` is
    int64 with ``self``'s shape, that ``output_size`` has exactly two entries,
    that ``self`` is 3D or 4D and that every non-batch dim is non-empty. The
    result keeps ``self``'s leading dims and replaces the last two with
    ``output_size``.
    """
    utils.alert_not_deterministic("max_unpooling2d_forward_out")

    torch._check(
        indices.dtype == torch.int64,
        lambda: f"elements in indices should be type int64 but got: {indices.dtype}",
    )
    torch._check(
        len(output_size) == 2,
        lambda: (
            f"There should be exactly two elements (height, width) in output_size, "
            f"but got {len(output_size)} elements."
        ),
    )
    oheight, owidth = output_size

    torch._check(
        self.ndim in (3, 4),
        lambda: (
            f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
            f"but got a tensor with {self.ndim} dimensions."
        ),
    )
    torch._check(
        self.shape == indices.shape,
        lambda: (
            f"Expected shape of indices to be same as that of the input tensor ({self.shape}) "
            f"but got indices tensor with shape: {indices.shape}"
        ),
    )

    for dim in range(1, self.ndim):
        torch._check(
            self.size(dim) > 0,
            lambda: (
                f"max_unpooling2d(): "
                f"Expected input to have non-zero size for non-batch dimensions, "
                f"but got {self.shape} with dimension {dim} being empty."
            ),
        )

    self = self.contiguous()
    # 3D: (channels,), 4D: (batch, channels) -- followed by (oheight, owidth).
    leading_dims = tuple(self.shape[:-2])
    return self.new_empty((*leading_dims, oheight, owidth))
4346

4347

4348
def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, fn_name):
4349
    torch._check(
4350
        indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
4351
    )
4352
    torch._check(
4353
        input.ndim in (4, 5),
4354
        lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.",
4355
    )
4356
    torch._check(
4357
        len(output_size) == 3,
4358
        lambda: (
4359
            f"There should be exactly three elements (depth, height, width) in output_size, "
4360
            f"but got {len(output_size)} elements."
4361
        ),
4362
    )
4363
    torch._check(
4364
        len(stride) == 3,
4365
        lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.",
4366
    )
4367
    torch._check(
4368
        len(padding) == 3,
4369
        lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.",
4370
    )
4371
    torch._check(
4372
        input.shape == indices.shape,
4373
        lambda: (
4374
            f"Expected shape of indices to be same as that of the input tensor ({input.shape}) "
4375
            f"but got indices tensor with shape: {indices.shape}"
4376
        ),
4377
    )
4378

4379
    for i in range(1, input.ndim):
4380
        torch._check(
4381
            input.size(i) > 0,
4382
            lambda: (
4383
                f"{fn_name}: "
4384
                f"Expected input to have non-zero size for non-batch dimensions, "
4385
                f"but got {input.shape} with dimension {i} being empty."
4386
            ),
4387
        )
4388

4389
    torch._check(
4390
        stride[0] > 0 and stride[1] > 0 and stride[2] > 0,
4391
        lambda: f"strides should be greater than zero, but got stride: {stride}",
4392
    )
4393

4394

4395
@register_meta(aten.max_unpool3d)
@out_wrapper()
def meta_max_unpool3d(self, indices, output_size, stride, padding):
    """Meta kernel for ``max_unpool3d``.

    Flags the op as non-deterministic and validates the arguments via
    ``_max_unpooling3d_shape_check``. The result keeps ``self``'s leading
    dims (one for 4D input, two for 5D) followed by ``output_size``.
    """
    utils.alert_not_deterministic("max_unpooling3d_forward_out")

    _max_unpooling3d_shape_check(
        self, indices, output_size, stride, padding, "max_unpooling3d()"
    )

    self = self.contiguous()
    odepth, oheight, owidth = output_size
    leading_dims = tuple(self.shape[:-3])
    return self.new_empty((*leading_dims, odepth, oheight, owidth))
4417

4418

4419
@register_meta(aten.max_pool3d_with_indices)
@out_wrapper("out", "indices")
def meta_max_pool3d_with_indices(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    dilation=(1,),
    ceil_mode=False,
):
    """Meta kernel for ``max_pool3d_with_indices``.

    ``kernel_size``, ``padding`` and ``dilation`` take one value (used for
    T, H and W) or three. ``stride`` may also be empty, in which case it
    defaults to the kernel size.

    Returns:
        ``(out, indices)``, shaped ``(nslices, oT, oH, oW)`` for 4D input or
        ``(N, nslices, oT, oH, oW)`` for 5D input. ``indices`` is int64. Both
        are converted to ``channels_last_3d`` when the input is laid out
        that way.
    """

    def expand3(values):
        # One entry means "same for T, H and W"; otherwise take three entries.
        if len(values) == 1:
            return values[0], values[0], values[0]
        return values[0], values[1], values[2]

    torch._check(
        len(kernel_size) in (1, 3),
        lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
    )
    kT, kH, kW = expand3(kernel_size)

    torch._check(
        not stride or len(stride) in (1, 3),
        lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
    )
    dT, dH, dW = expand3(stride) if stride else (kT, kH, kW)

    torch._check(
        len(padding) in (1, 3),
        lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
    )
    pT, pH, pW = expand3(padding)

    torch._check(
        len(dilation) in (1, 3),
        lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
    )
    dilationT, dilationH, dilationW = expand3(dilation)

    torch._check(
        input.ndim in (4, 5),
        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
    )

    nslices = input.size(-4)
    itime = input.size(-3)
    iheight = input.size(-2)
    iwidth = input.size(-1)

    otime = pooling_output_shape(itime, kT, pT, dT, dilationT, ceil_mode)
    oheight = pooling_output_shape(iheight, kH, pH, dH, dilationH, ceil_mode)
    owidth = pooling_output_shape(iwidth, kW, pW, dW, dilationW, ceil_mode)

    pool3d_shape_check(
        input,
        nslices,
        kT, kH, kW,
        dT, dH, dW,
        pT, pH, pW,
        dilationT, dilationH, dilationW,
        itime, iheight, iwidth,
        otime, oheight, owidth,
        "max_pool3d_with_indices()",
    )

    if input.ndim == 4:
        # Unbatched input: view it as a batch of one to ask whether it is
        # channels_last_3d (and not also plain-contiguous).
        batched_view = input.unsqueeze(0)
        channels_last = (
            not batched_view.is_contiguous()
        ) and batched_view.is_contiguous(memory_format=torch.channels_last_3d)
        out_shape = (nslices, otime, oheight, owidth)
    else:
        channels_last = (
            utils.suggest_memory_format(input) == torch.channels_last_3d
        )
        out_shape = (input.size(-5), nslices, otime, oheight, owidth)  # type: ignore[assignment]

    out = input.new_empty(out_shape)
    indices = input.new_empty(out_shape, dtype=torch.int64)

    if channels_last:
        out = out.to(memory_format=torch.channels_last_3d)
        indices = indices.to(memory_format=torch.channels_last_3d)

    return out, indices
4522

4523

4524
@register_meta(aten.max_pool3d_with_indices_backward)
@out_wrapper("grad_input")
def meta_max_pool3d_with_indices_backward(
    grad_output,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
    indices,
):
    """Meta kernel for ``max_pool3d_with_indices_backward``.

    Parses the pooling parameters the same way as the forward kernel and
    reads the output sizes from the last three dims of ``grad_output``. It
    validates everything with ``max_pool3d_backward_shape_check`` and
    returns an empty gradient shaped like ``input``, converted to
    ``channels_last_3d`` when the input is laid out that way.
    """

    def expand3(values):
        # One entry means "same for T, H and W"; otherwise take three entries.
        if len(values) == 1:
            return values[0], values[0], values[0]
        return values[0], values[1], values[2]

    torch._check(
        len(kernel_size) in (1, 3),
        lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
    )
    kT, kH, kW = expand3(kernel_size)

    torch._check(
        not stride or len(stride) in (1, 3),
        lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
    )
    dT, dH, dW = expand3(stride) if stride else (kT, kH, kW)

    torch._check(
        len(padding) in (1, 3),
        lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
    )
    pT, pH, pW = expand3(padding)

    torch._check(
        len(dilation) in (1, 3),
        lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
    )
    dilationT, dilationH, dilationW = expand3(dilation)

    torch._check(
        input.ndim in (4, 5),
        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
    )

    nslices = input.size(-4)
    itime = input.size(-3)
    iheight = input.size(-2)
    iwidth = input.size(-1)

    otime = grad_output.size(-3)
    oheight = grad_output.size(-2)
    owidth = grad_output.size(-1)

    max_pool3d_backward_shape_check(
        input,
        grad_output,
        indices,
        nslices,
        kT, kH, kW,
        dT, dH, dW,
        pT, pH, pW,
        dilationT, dilationH, dilationW,
        itime, iheight, iwidth,
        otime, oheight, owidth,
        "max_pool3d_with_indices_backward()",
    )

    if input.ndim == 4:
        # Unbatched input: view it as a batch of one to ask whether it is
        # channels_last_3d (and not also plain-contiguous).
        batched_view = input.unsqueeze(0)
        channels_last = (
            not batched_view.is_contiguous()
        ) and batched_view.is_contiguous(memory_format=torch.channels_last_3d)
    else:
        channels_last = (
            utils.suggest_memory_format(input) == torch.channels_last_3d
        )

    grad_input = input.new_empty(input.shape)
    if channels_last:
        grad_input = grad_input.to(memory_format=torch.channels_last_3d)
    return grad_input
4625

4626

4627
def check_grid_sampler_common(input: Tensor, grid: Tensor):
4628
    torch._check(
4629
        input.device == grid.device,
4630
        lambda: (
4631
            f"grid_sampler(): expected input and grid to be on same device, but input "
4632
            f"is on {input.device} and grid is on {grid.device}"
4633
        ),
4634
    )
4635
    torch._check(
4636
        input.layout == torch.strided and grid.layout == torch.strided,
4637
        lambda: (
4638
            f"grid_sampler(): expected input and grid to have torch.strided layout, but "
4639
            f"input has {input.layout} and grid has {grid.layout}"
4640
        ),
4641
    )
4642
    torch._check(
4643
        input.shape[0] == grid.shape[0],
4644
        lambda: (
4645
            f"grid_sampler(): expected grid and input to have same batch size, but got "
4646
            f"input with sizes {input.shape} and grid with sizes {grid.shape}"
4647
        ),
4648
    )
4649
    torch._check(
4650
        grid.shape[-1] == input.ndim - 2,
4651
        lambda: (
4652
            f"grid_sampler(): expected grid to have size {input.ndim - 2} in last "
4653
            f"dimension, but got grid with sizes {grid.shape}"
4654
        ),
4655
    )
4656

4657
    for i in range(2, input.ndim):
4658
        torch._check(
4659
            input.shape[i] > 0,
4660
            lambda: (
4661
                f"grid_sampler(): expected input to have non-empty spatial dimensions, "
4662
                f"but input has sizes {input.shape} with dimension {i} being empty"
4663
            ),
4664
        )
4665

4666

4667
class GridSamplerInterpolation(Enum):
    """Integer codes for the ``interpolation_mode`` argument of grid_sampler.

    NOTE(review): the numeric values presumably mirror the ATen-side enum;
    keep them in sync if that ever changes (verify).
    """

    # Values are explicit (not auto()) because callers compare against .value.
    BILINEAR = 0
    NEAREST = 1
    BICUBIC = 2
4671

4672

4673
def check_grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: int):
    """Extra checks for the 5-D (volumetric) grid_sampler.

    Requires ``input`` to be 5-D with ``grid`` of the same rank, and rejects
    bicubic interpolation for 5-D input. Failures raise via ``torch._check``.
    """
    rank_ok = input.ndim == 5 and input.ndim == grid.ndim
    torch._check(
        rank_ok,
        lambda: (
            f"grid_sampler(): expected 5D input and grid with same number of "
            f"dimensions, but got input with sizes {input.shape}"
            f" and grid with sizes {grid.shape}"
        ),
    )

    bicubic_code = GridSamplerInterpolation.BICUBIC.value
    torch._check(
        not (input.ndim == 5 and interpolation_mode == bicubic_code),
        lambda: "grid_sampler(): bicubic interpolation only supports 4D input",
    )
4689

4690

4691
@register_meta(aten.grid_sampler_2d_backward.default)
def grid_sampler_2d_backward_meta(
    grad_output,
    input,
    grid,
    interpolation_mode,
    padding_mode,
    align_corners,
    output_mask,
):
    """Meta kernel for ``grid_sampler_2d_backward``.

    Returns ``(grad_input, grad_grid)``. ``grad_input`` is a zero-filled
    contiguous tensor like ``input`` when ``output_mask[0]`` is truthy, else
    ``None``. ``grad_grid`` is an uninitialized contiguous tensor like ``grid``.
    """
    grad_input = (
        torch.zeros_like(input, memory_format=torch.contiguous_format)
        if output_mask[0]
        else None
    )
    grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format)
    return grad_input, grad_grid
4708

4709

4710
@register_meta(aten.grid_sampler_3d)
@out_wrapper()
def grid_sampler_3d(
    input,
    grid,
    interpolation_mode,
    padding_mode,
    align_corners,
):
    """Meta kernel for ``grid_sampler_3d``.

    After validation, returns an uninitialized tensor shaped
    ``(input.shape[0], input.shape[1], grid.shape[1], grid.shape[2], grid.shape[3])``.
    """
    check_grid_sampler_common(input, grid)
    check_grid_sampler_3d(input, grid, interpolation_mode)

    batch, channels = input.shape[0], input.shape[1]
    depth_out, height_out, width_out = grid.shape[1], grid.shape[2], grid.shape[3]
    return input.new_empty((batch, channels, depth_out, height_out, width_out))
4727

4728

4729
@register_meta(aten.grid_sampler_3d_backward)
@out_wrapper("grad_input", "grad_grid")
def grid_sampler_3d_backward(
    grad_output,
    input,
    grid,
    interpolation_mode,
    padding_mode,
    align_corners,
    output_mask,
):
    """Meta kernel for ``grid_sampler_3d_backward``.

    Validates like the forward, then returns ``(grad_input, grad_grid)``.
    ``grad_input`` is zero-filled like ``input`` (legacy contiguous format)
    when ``output_mask[0]`` is truthy, else ``None``. ``grad_grid`` is
    uninitialized like ``grid``.
    """
    check_grid_sampler_common(input, grid)
    check_grid_sampler_3d(input, grid, interpolation_mode)

    fmt = torch.legacy_contiguous_format
    grad_input = torch.zeros_like(input, memory_format=fmt) if output_mask[0] else None
    grad_grid = torch.empty_like(grid, memory_format=fmt)
    return grad_input, grad_grid
4751

4752

4753
@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
    """Meta kernel for ``aten.full``.

    If no ``dtype`` keyword is given (or it is ``None``), the dtype is
    inferred from ``fill_value`` via ``utils.get_dtype``. Everything else is
    forwarded to ``torch.empty`` unchanged.
    """
    dtype = kwargs.get("dtype")
    # Explicit None test: "not supplied" is the condition we mean, rather than
    # relying on the truthiness of a torch.dtype object.
    if dtype is None:
        dtype = utils.get_dtype(fill_value)
    kwargs["dtype"] = dtype
    return torch.empty(size, *args, **kwargs)
4760

4761

4762
# zeros_like is special cased to work for sparse
@register_meta(aten.zeros_like.default)
def zeros_like(
    self,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Meta kernel for ``zeros_like`` that also handles sparse COO output.

    With ``layout=torch.sparse_coo``, a 0-element sparse tensor is created
    (``dtype``/``device`` default to ``self``'s), resized to ``self.size()``
    with all indices cleared, and marked coalesced. For any other layout the
    result comes from ``empty_like`` and is zero-filled.
    """
    if layout == torch.sparse_coo:
        torch._check(
            memory_format is None,
            lambda: "memory format option is only supported by strided tensors",
        )
        out_dtype = dtype if dtype is not None else self.dtype
        out_device = device if device is not None else self.device
        sparse_res = torch.empty(
            0,
            dtype=out_dtype,
            layout=layout,
            device=out_device,
            pin_memory=pin_memory,
        )
        # A sparse source keeps its sparse/dense split; a dense source makes
        # every dimension sparse.
        if self.is_sparse:
            n_sparse, n_dense = self.sparse_dim(), self.dense_dim()
        else:
            n_sparse, n_dense = self.dim(), 0
        sparse_res.sparse_resize_and_clear_(self.size(), n_sparse, n_dense)
        sparse_res._coalesced_(True)
        return sparse_res

    dense_res = aten.empty_like.default(
        self,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        memory_format=memory_format,
    )
    # device may be something other than "meta", so actually fill the zeros.
    dense_res.fill_(0)
    return dense_res
4806

4807

4808
@register_meta(aten.select.int)
def meta_select(self, dim, index):
    """Meta kernel for ``select.int``.

    Returns an ``as_strided`` view of ``self`` with dimension ``dim``
    removed and the storage offset advanced to ``index`` along that
    dimension. Negative ``dim``/``index`` are wrapped. Out-of-range
    indices and 0-dim inputs fail via ``torch._check_index``.
    """
    ndim = self.dim()
    torch._check_index(
        ndim != 0,
        lambda: "select() cannot be applied to a 0-dim tensor.",
    )

    if dim < 0:
        dim += ndim
    size = self.size(dim)

    torch._check_index(
        not (-index > size or index >= size),
        lambda: f"select(): index {index} out of range for tensor of size "
        f"{self.size()} at dimension {dim}",
    )

    if index < 0:
        index += size

    sizes = list(self.size())
    strides = list(self.stride())
    offset = self.storage_offset() + index * strides[dim]
    sizes.pop(dim)
    strides.pop(dim)
    return self.as_strided(sizes, strides, offset)
4835

4836

4837
@register_meta(aten.select_scatter.default)
def meta_select_scatter(self, src, dim, index):
    """Meta kernel for ``select_scatter``.

    The result is a stride-preserving clone of ``self``. ``src``, ``dim``
    and ``index`` are not inspected here.
    """
    result = utils.clone_preserve_strides(self)
    return result
4840

4841

4842
@register_meta(aten.slice_scatter.default)
def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
    """Meta kernel for ``slice_scatter``.

    The result is a stride-preserving clone of ``self``. ``src`` and the
    slice parameters are not inspected here.
    """
    result = utils.clone_preserve_strides(self)
    return result
4845

4846

4847
# TODO: Deduplicate this with canonicalize_dim
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
    """Normalize a possibly negative dimension index into ``[0, dim_post_expr)``.

    Args:
        dim: dimension index; negative values count from the end.
        dim_post_expr: rank of the tensor being indexed. A rank ``<= 0`` is
            treated as rank 1, which is only permitted when ``wrap_scalar``.
        wrap_scalar: whether a 0-dim tensor may be indexed as if 1-dim.

    Returns:
        The equivalent non-negative dimension index.

    Raises:
        AssertionError: (when assertions are enabled) if the rank is ``<= 0``
            and ``wrap_scalar`` is false, or if ``dim`` is out of range.
    """
    if dim_post_expr <= 0:
        assert wrap_scalar
        dim_post_expr = 1
    # Named lo/hi instead of min/max so the builtins are not shadowed.
    lo = -dim_post_expr
    hi = dim_post_expr - 1
    assert lo <= dim <= hi, f"dim {dim} out of bounds ({lo}, {hi})"
    if dim < 0:
        dim += dim_post_expr
    return dim
4858

4859

4860
def ensure_nonempty_size(t, dim):
    """Return ``t.shape[dim]``, treating a 0-dim tensor as having size 1."""
    if t.dim() == 0:
        return 1
    return t.shape[dim]
4862

4863

4864
# From aten/src/ATen/native/ScatterGatherChecks.h
def gather_shape_check(self, dim, index):
    """Validate ``index``'s shape against ``self`` for gather.

    Both tensors must have the same rank (0-dim counted as 1), and for
    every dimension other than ``dim`` the index extent must not exceed
    ``self``'s. Failures raise via ``torch._check``.
    """
    rank = max(self.dim(), 1)
    torch._check(
        rank == max(index.dim(), 1),
        lambda: "Index tensor must have the same number of dimensions as input tensor",
    )
    for d in range(rank):
        if d == dim:
            continue
        torch._check(
            ensure_nonempty_size(index, d) <= ensure_nonempty_size(self, d),
            lambda: f"Size does not match at dimension {d} expected index {index.shape}"
            + f" to be smaller than self {self.shape} apart from dimension {dim}",
        )
4879

4880

4881
@register_meta(aten.gather.default)
def meta_gather(self, dim, index, sparse_grad=False):
    """Meta kernel for ``gather``: the output has ``index``'s shape.

    ``dim`` is always wrapped/validated. The dtype and shape checks are
    skipped when ``index`` is empty (decided size-obliviously).
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    if guard_size_oblivious(index.numel() == 0):
        return self.new_empty(index.shape)

    torch._check(
        index.dtype == torch.long,
        lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
    )
    gather_shape_check(self, wrapped_dim, index)
    return self.new_empty(index.shape)
4894

4895

4896
# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
4897
def get_operator_enum(reduce_, use_new_options=False):
4898
    if use_new_options:
4899
        if reduce_ == "sum":
4900
            return "REDUCE_ADD"
4901
        elif reduce_ == "prod":
4902
            return "REDUCE_MULTIPLY"
4903
        elif reduce_ == "mean":
4904
            return "REDUCE_MEAN"
4905
        elif reduce_ == "amax":
4906
            return "REDUCE_MAXIMUM"
4907
        elif reduce_ == "amin":
4908
            return "REDUCE_MINIMUM"
4909
        torch._check(
4910
            False,
4911
            lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
4912
        )
4913
        return
4914
    else:
4915
        if reduce_ == "add":
4916
            return "REDUCE_ADD"
4917
        elif reduce_ == "multiply":
4918
            return "REDUCE_MULTIPLY"
4919
        torch._check(False, lambda: "reduce argument must be either add or multiply.")
4920
        return
4921

4922

4923
# From aten/src/ATen/native/ScatterGatherChecks.h
4924
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
4925
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
4926

4927
    if guard_size_oblivious(index.numel() != 0):
4928
        torch._check(
4929
            index.dtype == torch.long,
4930
            lambda: f"{method_name}(): Expected dtype int64 for index",
4931
        )
4932

4933
    if src_opt is not None:
4934
        torch._check(
4935
            self.dtype == src_opt.dtype,
4936
            lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
4937
        )
4938

4939

4940
def ensure_nonempty_dim(dim):
    """Return ``dim`` clamped from below to 1 (a 0-dim rank counts as 1)."""
    return 1 if dim < 1 else dim
4942

4943

4944
# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_shape_check(self, dim, index, src_opt=None):
    """Validate the shapes of ``index`` (and optional ``src``) for scatter.

    This is a no-op when ``index`` is empty (decided size-obliviously).
    Otherwise it requires:
      * ``index`` to have the same rank as ``self`` (0-dim counted as 1);
      * ``index.size(d) <= self.size(d)`` for every ``d != dim``;
      * when ``src_opt`` is given, ``index`` to have the same rank as ``src``
        and ``index.size(d) <= src.size(d)`` for every ``d``.
    Failures raise via ``torch._check``.
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if guard_size_oblivious(index.numel() == 0):
        return
    torch._check(
        ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
        lambda: "Index tensor must have the same number of dimensions as self tensor",
    )

    is_wrong_shape = False
    self_dims = ensure_nonempty_dim(self.dim())

    # Check: index.size(d) <= self.size(d) for all d != dim
    for d in range(self_dims):
        index_d_size = ensure_nonempty_size(index, d)
        if d == dim:
            continue
        if index_d_size > ensure_nonempty_size(self, d):
            is_wrong_shape = True
            break

    # Check: index.size(d) <= src.size(d) for all d if src is Tensor
    if not is_wrong_shape and src_opt is not None:
        for d in range(self_dims):
            index_d_size = ensure_nonempty_size(index, d)
            if index_d_size > ensure_nonempty_size(src_opt, d):
                is_wrong_shape = True
                break

    if src_opt is not None:
        # self's rank was already verified above; this check must compare
        # src's rank, otherwise a src with extra dimensions slips through.
        torch._check(
            ensure_nonempty_dim(src_opt.dim()) == ensure_nonempty_dim(index.dim()),
            lambda: "Index tensor must have the same number of dimensions as src tensor",
        )
        torch._check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
        )
    else:
        torch._check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim}",
        )
4991

4992

4993
# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
    """Shared validation for the scatter meta kernels; returns ``None``.

    Wraps ``dim``, checks dtypes and shapes of ``index``/``src`` against
    ``self``, and, if ``reduce_`` is given, checks that it names a known
    reduction.
    """
    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    scatter_gather_dtype_check("scatter", self, index, src)
    scatter_shape_check(self, wrapped_dim, index, src)
    if reduce_ is None:
        return
    # Called only for its validation; the returned reduction name is unused.
    get_operator_enum(reduce_, use_new_options)
5001

5002

5003
@register_meta(aten.scatter_add.default)
def meta_scatter_add(self, dim, index, src):
    """Meta kernel for out-of-place ``scatter_add``.

    Validates as scatter with the ``"add"`` reduction, then returns a fresh
    tensor with ``self``'s shape.
    """
    scatter_meta_impl(self, dim, index, src, reduce_="add")
    result = self.new_empty(self.shape)
    return result
5007

5008

5009
@register_meta(aten.scatter_add_)
def meta_scatter_add_(self, dim, index, src):
    """Meta kernel for in-place ``scatter_add_``: validates, returns ``self``."""
    scatter_meta_impl(self, dim, index, src, reduce_="add")
    return self
5013

5014

5015
@register_meta(
    [
        aten.scatter.src,
        aten.scatter.value,
        aten.scatter.reduce,
        aten.scatter.value_reduce,
    ]
)
@out_wrapper()
def meta_scatter(self, dim, index, src_or_value, reduce=None):
    """Meta kernel for the out-of-place ``scatter`` overloads.

    ``src_or_value`` is treated as ``src`` only when it is a Tensor; scalar
    values skip the src checks. Returns a fresh tensor with ``self``'s shape.
    """
    src = None
    if isinstance(src_or_value, torch.Tensor):
        src = src_or_value
    scatter_meta_impl(self, dim, index, src, reduce_=reduce)
    return self.new_empty(self.shape)
5028

5029

5030
@register_meta(
    [
        aten.scatter_.src,
        aten.scatter_.value,
        aten.scatter_.reduce,
        aten.scatter_.value_reduce,
    ]
)
def meta_scatter_(self, dim, index, src_or_value, reduce=None):
    """Meta kernel for the in-place ``scatter_`` overloads; returns ``self``.

    ``src_or_value`` is treated as ``src`` only when it is a Tensor.
    """
    src = None
    if isinstance(src_or_value, torch.Tensor):
        src = src_or_value
    scatter_meta_impl(self, dim, index, src, reduce_=reduce)
    return self
5042

5043

5044
@register_meta([aten._scaled_dot_product_flash_attention])
def meta__scaled_dot_product_flash_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    scale: Optional[float] = None,
):
    """Meta kernel for ``_scaled_dot_product_flash_attention``.

    Sizes are read positionally: ``query`` dims 0-3 as (batch, heads,
    seqlen_q, head_dim) and ``key`` dim 2 as seqlen_k. Returns the 9-tuple
    ``(attention, logsumexp, None, None, seqlen_q, seqlen_k, seed, offset,
    debug_mask)``.
    """
    batch_size, num_heads = query.size(0), query.size(1)
    max_seqlen_batch_q, head_dim = query.size(2), query.size(3)
    max_seqlen_batch_k = key.size(2)

    # Same shape as query, allocated via a transpose round-trip so that
    # empty_like works on the dim-1/dim-2-swapped view.
    attention = torch.empty_like(query.transpose(1, 2)).transpose(1, 2)
    logsumexp = torch.empty(
        (batch_size, num_heads, max_seqlen_batch_q),
        dtype=torch.float,
        device=query.device,
    )

    if not return_debug_mask:
        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
    else:
        blocksize_c = 128 if head_dim > 64 else 256
        # NOTE(review): the fallback below is ceil(seqlen_q / blocksize_c), a
        # block count derived from the *query* length rather than a key length
        # rounded up to blocksize_c -- looks suspicious; verify against the
        # CUDA kernel before relying on debug_mask's last dimension.
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
            dtype=query.dtype,
            device=query.device,
        )

    # Note [Seed and Offset]: the device for seed and offset below depends on
    # whether we are capturing or not, but at tracing time we don't know
    # whether cudagraphs will be used, so we return meta tensors here.
    # It's possible we'll need some special handling in inductor for sdpa.
    seed = torch.empty((), dtype=torch.long, device="meta")
    offset = torch.empty((), dtype=torch.long, device="meta")

    return (
        attention,
        logsumexp,
        None,
        None,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        seed,
        offset,
        debug_mask,
    )
5099

5100

5101
@register_meta([aten._scaled_dot_product_cudnn_attention])
def meta__scaled_dot_product_cudnn_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_bias: Optional[Tensor],
    compute_log_sumexp: bool,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    scale: Optional[float] = None,
):
    """Meta kernel for ``_scaled_dot_product_cudnn_attention``.

    The output is ``(B, H, S_Q, D_V)`` with B/H/S_Q from ``query`` dims 0-2
    and D_V from ``value``'s last dim. Returns the 9-tuple ``(output,
    logsumexp, None, None, S_Q, S_KV, seed, offset, None)``.
    """
    batch, heads, seq_q = query.size(0), query.size(1), query.size(2)
    seq_kv = key.size(2)
    head_dim_qk = query.size(-1)  # read but not used below
    head_dim_v = value.size(-1)

    output = torch.empty(
        (batch, heads, seq_q, head_dim_v), dtype=query.dtype, device=query.device
    )
    logsumexp = torch.empty((batch, heads, seq_q), dtype=torch.float, device=query.device)

    # See Note [Seed and Offset]
    seed = torch.empty((), dtype=torch.long, device="meta")
    offset = torch.empty((), dtype=torch.long, device="meta")

    return (output, logsumexp, None, None, seq_q, seq_kv, seed, offset, None)
5142

5143

5144
@register_meta(
    [
        aten._scaled_dot_product_flash_attention_backward,
    ]
)
def meta__scaled_dot_product_flash_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: Tensor,
    philox_offset: Tensor,
    scale: Optional[float] = None,
):
    """Meta kernel for the flash-attention SDPA backward.

    Returns ``(grad_q, grad_k, grad_v)``, each shaped like its input and
    allocated via ``empty_like`` on the dim-1/dim-2-transposed view.
    """

    def _grad_like(t):
        return torch.empty_like(t.transpose(1, 2)).transpose(1, 2)

    return _grad_like(query), _grad_like(key), _grad_like(value)
5170

5171

5172
@register_meta(
    [
        aten._scaled_dot_product_flash_attention_for_cpu,
    ]
)
def meta__scaled_dot_product_flash_attention_for_cpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    attn_mask: Optional[Tensor] = None,
    scale: Optional[float] = None,
):
    """Meta kernel for the CPU flash-attention SDPA forward.

    Returns ``(attention, logsumexp)``. ``attention`` is like ``query``.
    ``logsumexp`` has logical shape ``(batch, heads, seqlen_q)`` but is
    allocated as ``(batch, seqlen_q, heads)`` and transposed.
    """
    batch_size, num_heads, max_seqlen_batch_q = query.size(0), query.size(1), query.size(2)
    # Unused below, but reading it still fails fast on inputs with < 4 dims.
    head_dim = query.size(3)

    attention = torch.empty_like(query)
    lse_storage = torch.empty(
        (batch_size, max_seqlen_batch_q, num_heads),
        dtype=torch.float,
        device=query.device,
    )
    logsumexp = lse_storage.transpose(1, 2)
    return attention, logsumexp
5205

5206

5207
@register_meta(
    [
        aten._scaled_dot_product_flash_attention_for_cpu_backward,
    ]
)
def meta__scaled_dot_product_flash_attention_for_cpu_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    dropout_p: float,
    is_causal: bool,
    attn_mask: Optional[Tensor] = None,
    scale: Optional[float] = None,
):
    """Meta kernel for the CPU flash-attention SDPA backward.

    Returns ``(grad_q, grad_k, grad_v)``. Each has the logical shape of its
    own input ``(batch, heads, seq_len, head_dim)``, read from dims 0-3, and
    is allocated with memory order ``(batch, seq_len, heads, head_dim)``
    via ``torch.empty_permuted(..., (0, 2, 1, 3))``.
    """

    # The CPU kernel's gradient layout differs from CUDA's: memory order is
    # (batch_size, seq_len, num_heads, head_dim). Each gradient takes its
    # sizes from its own input (previously grad_k/grad_v reused query's
    # batch/heads/head_dim), matching the efficient-attention backward,
    # which sizes grad_v from value's head dim.
    def _grad_for(t: Tensor) -> Tensor:
        return torch.empty_permuted(
            (t.size(0), t.size(1), t.size(2), t.size(3)),
            (0, 2, 1, 3),
            dtype=t.dtype,
            device=t.device,
        )

    grad_q = _grad_for(query)
    grad_k = _grad_for(key)
    grad_v = _grad_for(value)
    return grad_q, grad_k, grad_v
5252

5253

5254
@register_meta([aten._scaled_dot_product_efficient_attention])
def meta__scaled_dot_product_efficient_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_bias: Optional[Tensor],
    compute_log_sumexp: bool,
    dropout_p=0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
):
    """Meta kernel for the memory-efficient SDPA forward.

    Inputs are transposed on dims 1/2 and read as (B, M, heads, K) for
    ``query`` and (..., Kv) for ``value``. Returns ``(output, logsumexp,
    seed, offset)``. ``output`` is allocated as ``(B, M, heads, Kv)`` then
    transposed back. ``logsumexp`` is ``(B, heads, ceil(M / 32) * 32)``, or
    has an empty last dim when ``compute_log_sumexp`` is false.
    """
    q = query.transpose(1, 2)
    key.transpose(1, 2)  # kept from the original; result unused
    v = value.transpose(1, 2)

    batch, seq_q, num_heads = q.size(0), q.size(1), q.size(-2)
    head_dim_v = v.size(-1)

    output = torch.empty(
        batch, seq_q, num_heads, head_dim_v, dtype=q.dtype, device=q.device
    ).transpose(1, 2)

    # Rounded up to a multiple of 32 when requested.
    lse_len = math.ceil(seq_q / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = torch.empty(
        (batch, num_heads, lse_len),
        dtype=torch.float,
        device=q.device,
    )

    # See Note [Seed and Offset]:
    seed = torch.empty((), dtype=torch.long, device="meta")
    offset = torch.empty((), dtype=torch.long, device="meta")

    return output, logsum_exp, seed, offset
5292

5293

5294
@register_meta(
    [
        aten._scaled_dot_product_efficient_attention_backward,
    ]
)
def meta__scaled_dot_product_efficient_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_bias: Optional[Tensor],
    out: Tensor,
    logsumexp: Tensor,
    philox_seed: Tensor,
    philox_offset: Tensor,
    dropout_p: float,
    grad_input_mask: List[bool],
    is_causal: bool = False,
    scale: Optional[float] = None,
):
    """Meta kernel for the memory-efficient SDPA backward.

    Returns ``(grad_q, grad_k, grad_v, grad_bias)``. The q/k/v gradients are
    allocated with memory order (batch, seq, heads, dim) via
    ``empty_permuted``. ``grad_bias`` is produced only when ``attn_bias`` is
    given and ``grad_input_mask[3]`` is set. Its last dimension's allocation
    is padded up to a multiple of 16 and then sliced back.
    """
    batch_size, num_heads = query.size(0), query.size(1)
    max_q, head_dim = query.size(2), query.size(3)
    head_dim_v = value.size(3)
    max_k = key.size(2)

    def _permuted(sizes, like):
        return torch.empty_permuted(sizes, (0, 2, 1, 3), dtype=like.dtype, device=like.device)

    grad_q = _permuted((batch_size, num_heads, max_q, head_dim), query)
    grad_k = _permuted((batch_size, num_heads, max_k, head_dim), key)
    grad_v = _permuted((batch_size, num_heads, max_k, head_dim_v), value)

    grad_bias = None
    if attn_bias is not None and grad_input_mask[3]:
        last_dim = attn_bias.size(-1)
        padded = last_dim if last_dim % 16 == 0 else last_dim + 16 - last_dim % 16
        padded_sizes = list(attn_bias.size())
        padded_sizes[-1] = padded
        grad_bias = torch.empty(
            padded_sizes, dtype=attn_bias.dtype, device=attn_bias.device
        )[..., :last_dim]

    return grad_q, grad_k, grad_v, grad_bias
5352

5353

5354
@register_meta(
    [
        aten._scaled_dot_product_cudnn_attention_backward,
    ]
)
def meta__scaled_dot_product_cudnn_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    philox_seed: Tensor,
    philox_offset: Tensor,
    attn_bias: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    scale: Optional[float] = None,
):
    """Meta kernel for the cuDNN SDPA backward.

    Returns ``(grad_q, grad_k, grad_v)``, each an ``empty_like`` of its input.
    """
    return tuple(torch.empty_like(t) for t in (query, key, value))
5381

5382

5383
@register_meta(
    [
        aten._flash_attention_forward,
    ]
)
def meta__flash_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cum_seq_q: Optional[Tensor],
    cum_seq_k: Optional[Tensor],
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    return_debug_mask: bool,
    scale: Optional[float] = None,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    seqused_k: Optional[Tensor] = None,
    alibi_slopes: Optional[Tensor] = None,
):
    """Meta kernel for ``_flash_attention_forward``.

    Returns ``(attention, logsumexp, seed, offset, debug_mask)``.
    """
    # NB: there are two underlying paths:
    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
    if cum_seq_q is None:
        batch_size = query.size(0)
    else:
        batch_size = cum_seq_q.numel() - 1
    max_seqlen_batch_q = max_q if cum_seq_q is not None else query.size(1)
    max_seqlen_batch_k = max_k if cum_seq_k is not None else key.size(1)
    num_heads, head_dim = query.size(-2), query.size(-1)

    # Cuda Path
    attention = torch.empty_like(query)
    logsumexp = torch.empty(
        (batch_size, num_heads, max_seqlen_batch_q),
        dtype=torch.float,
        device=query.device,
    )

    if not return_debug_mask:
        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
    else:
        blocksize_c = 128 if head_dim > 64 else 256
        # NOTE(review): the fallback below is ceil(seqlen_q / blocksize_c), a
        # block count derived from the *query* length rather than a key length
        # rounded up to blocksize_c -- looks suspicious; verify against the
        # CUDA kernel before relying on debug_mask's last dimension.
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
            dtype=query.dtype,
            device=query.device,
        )

    # Seed and offset are returned as meta tensors because, at tracing time,
    # we don't know whether cudagraphs will be used (same reasoning as the
    # "Seed and Offset" note on the SDPA flash meta kernel).
    seed = torch.empty((), dtype=torch.long, device="meta")
    offset = torch.empty((), dtype=torch.long, device="meta")
    return attention, logsumexp, seed, offset, debug_mask
5446

5447

5448
@register_meta(
    [
        aten._flash_attention_backward,
    ]
)
def meta__flash_attention_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: Tensor,
    philox_offset: Tensor,
    scale: Optional[float] = None,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
):
    """Meta kernel for ``_flash_attention_backward``.

    Returns ``(grad_query, grad_key, grad_value)``, each an ``empty_like`` of
    its input.
    """
    return tuple(torch.empty_like(t) for t in (query, key, value))
5477

5478

5479
@register_meta(
    [
        aten._efficient_attention_forward,
    ]
)
def meta__efficient_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bias: Optional[Tensor],
    cu_seqlens_q: Optional[Tensor],
    cu_seqlens_k: Optional[Tensor],
    max_seqlen_q: Optional[int],
    max_seqlen_k: Optional[int],
    dropout_p: float,
    custom_mask_type: int,
    compute_log_sumexp: bool = False,
    scale: Optional[float] = None,
    causal_diagonal: Optional[Tensor] = None,
    seqlen_k: Optional[Tensor] = None,
    window_size: Optional[int] = None,
):
    """Meta kernel for ``_efficient_attention_forward``.

    Returns ``(output, logsumexp, seed, offset, max_seqlen_q, max_seqlen_k)``.
    ``output`` is ``(B, M, heads, Kv)`` from ``query`` dims 0/1/-2 and
    ``value``'s last dim. With ``cu_seqlens_q`` the logsumexp batch dim is
    ``len(cu_seqlens_q) - 1`` and ``max_seqlen_q`` (required) replaces M.
    An absent ``max_seqlen_k`` falls back to ``key.size(1)``.
    """
    batch, seq_q = query.size(0), query.size(1)
    seq_k = key.size(1)
    num_heads = query.size(-2)
    head_dim_v = value.size(-1)

    res = torch.empty(batch, seq_q, num_heads, head_dim_v, dtype=query.dtype, device=query.device)

    if cu_seqlens_q is not None:
        lse_batch = cu_seqlens_q.size(0) - 1
        assert max_seqlen_q is not None
        actual_max_seqlen_q = max_seqlen_q
    else:
        lse_batch = batch
        actual_max_seqlen_q = seq_q
    actual_max_seqlen_k = seq_k if max_seqlen_k is None else max_seqlen_k

    # Rounded up to a multiple of 32 when requested.
    lse_len = math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = torch.empty(
        (lse_batch, num_heads, lse_len),
        dtype=torch.float,
        device=query.device,
    )

    # See Note [Seed and Offset]:
    seed = torch.empty((), dtype=torch.long, device="meta")
    offset = torch.empty((), dtype=torch.long, device="meta")

    return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k
5530

5531

5532
@register_meta(
    [
        aten._efficient_attention_backward,
    ]
)
def meta__efficient_attention_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bias: Optional[Tensor],
    cu_seqlens_q: Optional[Tensor],
    cu_seqlens_k: Optional[Tensor],
    max_seqlen_q: torch.SymInt,
    max_seqlen_k: torch.SymInt,
    logsumexp: Tensor,
    dropout_p: float,
    philox_seed: Tensor,
    philox_offset: Tensor,
    custom_mask_type: int,
    bias_requires_grad: bool,
    scale: Optional[float] = None,
    num_splits_key: Optional[int] = None,
    shared_storage_dqdkdv: bool = False,
):
    """Meta kernel for ``_efficient_attention_backward``.

    Returns ``(grad_query, grad_key, grad_value, grad_bias)``.

    With ``shared_storage_dqdkdv`` the three gradients are views into one
    buffer of shape ``(*query.shape[:-2], 3, query.shape[-2], query.shape[-1])``,
    selected along dim -3. This requires query and key to match in dims 1
    and 3. Otherwise each is an ``empty_like`` of its input. ``grad_bias``
    is shaped like ``bias``, with the allocation's last dim padded to a
    multiple of 16 and sliced back. Without a bias it is a 0-dim tensor of
    the default dtype.
    """
    if not shared_storage_dqdkdv:
        grad_query = torch.empty_like(query)
        grad_key = torch.empty_like(key)
        grad_value = torch.empty_like(value)
    else:
        torch._check(
            query.shape[1] == key.shape[1],
            lambda: "seqlen must match for `shared_storage_dqdkdv",
        )
        torch._check(
            query.shape[3] == key.shape[3],
            lambda: "embedding dim must match for `shared_storage_dqdkdv",
        )
        shared = torch.empty(
            (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
            dtype=query.dtype,
            device=query.device,
        )
        grad_query, grad_key, grad_value = (shared.select(-3, i) for i in range(3))

    if bias is None:
        grad_bias = torch.empty((), device=query.device)
    else:
        last_dim = bias.size(-1)
        padded = last_dim if last_dim % 16 == 0 else last_dim + 16 - last_dim % 16
        padded_sizes = list(bias.size())
        padded_sizes[-1] = padded
        grad_bias = torch.empty(padded_sizes, dtype=bias.dtype, device=bias.device)[
            ..., :last_dim
        ]

    return grad_query, grad_key, grad_value, grad_bias
5590

5591

5592
@register_meta([aten._scaled_mm.default])
def meta_scaled_mm(
    self: torch.Tensor,
    mat2: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    scale_result: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    use_fast_accum: bool = False,
):
    """Meta implementation of ``aten._scaled_mm``: validate inputs, return an empty result.

    ``self`` must be a row-major 2D fp8 matrix and ``mat2`` a column-major 2D
    fp8 matrix, with the size constraints checked below. Both scales must be
    fp32 and describe either tensorwise scaling (both single-element) or
    rowwise scaling (``scale_a`` of shape ``(m, 1)``, ``scale_b`` of shape
    ``(1, n)``, both contiguous). ``bias``, ``scale_result`` and
    ``use_fast_accum`` are not inspected here.

    Returns:
        An uninitialized ``(self.size(0), mat2.size(1))`` tensor of dtype
        ``out_dtype`` (``self.dtype`` when None) on ``self.device``.
    """

    def is_row_major(stride):
        return stride[0] > stride[1] and stride[1] == 1

    def is_col_major(stride):
        return stride[0] == 1 and stride[1] > 1

    def is_fp8_type(dtype):
        return dtype in (
            torch.float8_e4m3fn,
            torch.float8_e5m2,
            torch.float8_e4m3fnuz,
            torch.float8_e5m2fnuz,
        )

    torch._check(
        self.dim() == 2 and mat2.dim() == 2,
        lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
    )
    torch._check(
        is_row_major(self.stride()),
        lambda: "self must be row_major",
    )
    torch._check(
        is_col_major(mat2.stride()),
        lambda: "mat2 must be col_major",
    )
    torch._check(
        self.size(1) % 16 == 0,
        lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}",
    )
    torch._check(
        mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
        lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}",
    )
    torch._check(
        is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
        lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
    )

    # determine scaling type and check input dimensions (refer to Blas.cpp op)
    torch._check(
        scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
        lambda: "Both scale_a and scale_b must be float (fp32) tensors.",
    )
    m, k = self.shape
    n = mat2.size(1)
    if scale_a.numel() == 1 and scale_b.numel() == 1:
        # tensorwise scaling
        pass
    else:
        # for non-tensorwise scaling, enforce 2D input tensors
        torch._check(
            scale_a.dim() == 2 and scale_b.dim() == 2,
            lambda: f"For non-tensorwise scaling, scale tensors must be 2D, but got {scale_a.dim()=} and {scale_b.dim()=}",
        )

        if (
            scale_a.size(0) == m
            and scale_a.size(1) == 1
            and scale_b.size(0) == 1
            and scale_b.size(1) == n
        ):
            # rowwise scaling
            torch._check(
                scale_a.is_contiguous() and scale_b.is_contiguous(),
                lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.",
            )
        else:
            # does not match any valid scaling type
            torch._check(
                False,
                lambda: (
                    "Invalid scaling configuration. "
                    "For tensorwise scaling, both scales should be scalar. "
                    f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). "
                    f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) "
                    f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})"
                ),
            )

    _out_dtype = out_dtype if out_dtype is not None else self.dtype
    return torch.empty(self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device)
5685

5686

5687
@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
@out_wrapper()
def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
    """Out-of-place ``scatter_reduce``.

    The arguments are validated by ``scatter_meta_impl``. The result is a
    fresh, uninitialized tensor with ``self``'s shape. ``include_self`` is
    not inspected here.
    """
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    out_shape = self.shape
    return self.new_empty(out_shape)
5692

5693

5694
@register_meta(aten.scatter_reduce_.two)
def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
    """In-place ``scatter_reduce_``: validate the arguments, then hand back ``self``."""
    scatter_meta_impl(
        self,
        dim,
        index,
        src,
        reduce,
        use_new_options=True,
    )
    return self
5698

5699

5700
@register_meta([aten.multinomial.default, aten.multinomial.out])
@out_wrapper()
def meta_multinomial(input, num_samples, replacement=False, *, generator=None):
    """Meta for ``multinomial``: an int64 tensor of sample indices.

    ``input`` must be 1D or 2D. The result has shape ``(num_samples,)`` for a
    1D input and ``(input.size(0), num_samples)`` for a 2D input, on
    ``input.device``. ``replacement`` and ``generator`` are not inspected here.
    """
    torch._check(
        0 < input.dim() <= 2,
        lambda: f"The probability distributions dimensions must be 1 or 2, but got {input.dim()}",
    )
    if input.dim() == 1:
        return torch.empty(num_samples, dtype=torch.long, device=input.device)
    return torch.empty(
        input.size(0), num_samples, dtype=torch.long, device=input.device
    )
5712

5713

5714
def multiply_integers(vs):
    """Return the product of the values in ``vs`` (1 for an empty iterable)."""
    return math.prod(vs)
5719

5720

5721
def upsample_common_check(input_size, output_size, num_spatial_dims):
5722
    torch._check(
5723
        len(output_size) == num_spatial_dims,
5724
        lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
5725
    )
5726
    expected_input_dims = num_spatial_dims + 2  # N, C, ...
5727
    torch._check(
5728
        len(input_size) == expected_input_dims,
5729
        lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
5730
    )
5731

5732
    torch._check(
5733
        all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size),
5734
        lambda: f"Input and output sizes should be greater than 0, but got "
5735
        f"input size {input_size} and output size {output_size}",
5736
    )
5737

5738
    nbatch, channels = input_size[:2]
5739
    return (nbatch, channels, *output_size)
5740

5741

5742
@register_meta(
    [aten.upsample_nearest1d.default, aten._upsample_nearest_exact1d.default]
)
def upsample_nearest1d(input, output_size, scales=None):
    """Meta for 1D nearest(-exact) upsampling of a 3D input.

    Returns an uninitialized ``(N, C, *output_size)`` tensor in the memory
    format suggested by ``input``. ``scales`` is not inspected here.
    """
    # An empty input is accepted when only its leading dim is zero. The
    # condition must be a bool for torch._check, so compare the product with 0
    # rather than passing the raw int product (which raised TypeError).
    torch._check(
        input.numel() != 0 or multiply_integers(input.size()[1:]) != 0,
        lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=1
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )
5756

5757

5758
@register_meta(
    [aten.upsample_nearest2d.default, aten._upsample_nearest_exact2d.default]
)
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
    """Meta for 2D nearest(-exact) upsampling of a 4D input.

    Returns an uninitialized ``(N, C, *output_size)`` tensor. Its memory format
    is the one suggested by ``input``, except that CUDA inputs with fewer than
    4 channels use the contiguous format. ``scales_h``/``scales_w`` are not
    inspected here.
    """
    # An empty input is accepted when only its leading dim is zero. The
    # condition must be a bool for torch._check, so compare the product with 0
    # rather than passing the raw int product (which raised TypeError).
    torch._check(
        input.numel() != 0 or multiply_integers(input.size()[1:]) != 0,
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=2
    )
    output = input.new_empty(full_output_size)

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(input)

    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
    _, n_channels, _, _ = input.shape
    if input.device.type == "cuda" and n_channels < 4:
        memory_format = torch.contiguous_format

    output = output.contiguous(memory_format=memory_format)

    return output
5782

5783

5784
@register_meta(
    [
        aten.upsample_nearest2d_backward.default,
        aten._upsample_nearest_exact2d_backward.default,
    ]
)
def upsample_nearest2d_backward(
    grad_output: Tensor,
    output_size: Sequence[Union[int, torch.SymInt]],
    input_size: Sequence[Union[int, torch.SymInt]],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
):
    """Meta for the backward of 2D nearest(-exact) upsampling.

    Checks that ``grad_output`` is 4D and equals the forward output shape
    ``(input_size[0], input_size[1], *output_size)``. Returns an uninitialized
    gradient of shape ``input_size`` in the memory format suggested by
    ``grad_output``.
    """
    expected_out = upsample_common_check(input_size, output_size, num_spatial_dims=2)
    torch._check(
        grad_output.ndim == 4,
        lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
    )
    # expected_out has exactly 4 entries once upsample_common_check passes.
    for d, want in enumerate(expected_out):
        torch._check(
            grad_output.size(d) == want,
            lambda: (
                f"Expected grad_output to have the same shape as output;"
                f" output.size({d}) = {want}"
                f" but got grad_output.size({d}) = {grad_output.size(d)}"
            ),
        )

    grad_input = grad_output.new_empty(input_size)
    return grad_input.to(
        memory_format=utils.suggest_memory_format(grad_output)
    )  # type: ignore[call-overload]
5817

5818

5819
@register_meta(
    [aten.upsample_nearest3d.default, aten._upsample_nearest_exact3d.default]
)
def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
    """Meta for 3D nearest(-exact) upsampling of a 5D input.

    Returns an uninitialized ``(N, C, *output_size)`` tensor in the memory
    format suggested by ``input``. The scale arguments are not inspected here.
    """
    # An empty input is accepted when only its leading dim is zero. The
    # condition must be a bool for torch._check, so compare the product with 0
    # rather than passing the raw int product (which raised TypeError).
    torch._check(
        input.numel() != 0 or multiply_integers(input.size()[1:]) != 0,
        lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=3
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )
5833

5834

5835
@register_meta(
    [
        aten.sort.default,
        aten.sort.stable,
        aten.sort.values,
        aten.sort.values_stable,
    ]
)
def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=None):
    """Meta for ``sort``: values shaped like ``self`` plus int64 indices.

    When both ``values`` and ``indices`` out tensors are given, they are
    resized and restrided to match the freshly computed results, filled via
    ``_safe_copy_out``, and returned instead.
    """
    sorted_vals = torch.empty_like(self)
    sorted_idx = torch.empty_like(self, dtype=torch.int64)
    if values is None or indices is None:
        return sorted_vals, sorted_idx

    assert isinstance(values, TensorLike)
    assert isinstance(indices, TensorLike)
    # Give values and indices identical strides, even when the out tensors
    # arrive with different shapes, like (5, 10, 5) and (0) in msort.
    shape = sorted_vals.shape
    strides = sorted_vals.stride()
    values = _maybe_resize_out(values, shape)
    indices = _maybe_resize_out(indices, shape)
    values.as_strided_(shape, strides)
    indices.as_strided_(shape, strides)
    _safe_copy_out(copy_from=sorted_vals, copy_to=values)  # type: ignore[arg-type]
    _safe_copy_out(copy_from=sorted_idx, copy_to=indices)  # type: ignore[arg-type]
    return values, indices
5860

5861

5862
def rnn_cell_checkSizes(
5863
    input_gates,
5864
    hidden_gates,
5865
    input_bias,
5866
    hidden_bias,
5867
    factor,
5868
    prev_hidden,
5869
):
5870
    torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
5871
    torch._check(
5872
        input_gates.shape == hidden_gates.shape,
5873
        lambda: f"{input_gates.shape} != {hidden_gates.shape}",
5874
    )
5875
    gates_size = input_gates.size(1)
5876
    if input_bias is not None:
5877
        torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
5878
        torch._check(
5879
            input_bias.numel() == gates_size,
5880
            lambda: f"{input_bias.numel()} != {gates_size}",
5881
        )
5882
        torch._check(
5883
            input_bias.shape == hidden_bias.shape,
5884
            lambda: f"{input_bias.shape} != {hidden_bias.shape}",
5885
        )
5886
    torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
5887
    expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
5888
    torch._check(
5889
        prev_hidden.numel() == expected_prev_hidden_numel,
5890
        lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
5891
    )
5892
    torch._check(
5893
        all(
5894
            x.device == input_gates.device
5895
            for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
5896
        ),
5897
        lambda: "expected all inputs to be same device",
5898
    )
5899

5900

5901
@register_meta(aten._thnn_fused_lstm_cell.default)
def _thnn_fused_lstm_cell_meta(
    input_gates,
    hidden_gates,
    cx,
    input_bias=None,
    hidden_bias=None,
):
    """Meta for the fused LSTM cell: returns ``(hy, cy, workspace)``.

    Inputs are validated by ``rnn_cell_checkSizes`` with 4 gates. ``hy`` and
    ``cy`` are shaped like ``cx``; ``workspace`` is shaped like
    ``input_gates``. All three are contiguous.
    """
    rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
    fmt = torch.contiguous_format
    workspace = torch.empty_like(input_gates, memory_format=fmt)
    hy = torch.empty_like(cx, memory_format=fmt)
    cy = torch.empty_like(cx, memory_format=fmt)
    return hy, cy, workspace
5914

5915

5916
@register_meta(aten._cudnn_rnn.default)
def _cudnn_rnn(
    input,
    weight,
    weight_stride0,
    weight_buf,
    hx,
    cx,
    mode,
    hidden_size,
    proj_size,
    num_layers,
    batch_first,
    dropout,
    train,
    bidirectional,
    batch_sizes,
    dropout_state,
):
    """Meta for ``aten._cudnn_rnn``: compute result shapes without running cuDNN.

    Returns ``(output, hy, cy, reserve, weight_buf)``:

    * ``output`` is ``(sum_of_batches, out_size * num_directions)`` for packed
      input (non-empty ``batch_sizes``). Otherwise it is
      ``(seq, batch, out_size * num_directions)``, or ``(batch, seq, ...)``
      when ``batch_first``.
    * ``hy`` is ``(num_layers * num_directions, batch, out_size)``, where
      ``out_size`` is ``proj_size`` if non-zero else ``hidden_size``.
    * ``cy`` is ``(num_layers * num_directions, batch, hidden_size)`` when
      ``cx`` is given, otherwise an empty tensor with ``hx``'s dtype/device.
    * ``reserve`` is an empty uint8 tensor (see TODO below).
    * ``weight_buf`` is passed through unchanged.
    """
    is_input_packed = len(batch_sizes) != 0
    if is_input_packed:
        seq_length = len(batch_sizes)
        mini_batch = batch_sizes[0]
        batch_sizes_sum = input.shape[0]
    else:
        seq_length = input.shape[1] if batch_first else input.shape[0]
        mini_batch = input.shape[0] if batch_first else input.shape[1]
        batch_sizes_sum = -1

    num_directions = 2 if bidirectional else 1
    out_size = proj_size if proj_size != 0 else hidden_size
    if is_input_packed:
        out_shape = [batch_sizes_sum, out_size * num_directions]
    else:
        out_shape = (
            [mini_batch, seq_length, out_size * num_directions]
            if batch_first
            else [seq_length, mini_batch, out_size * num_directions]
        )
    output = input.new_empty(out_shape)

    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
    if cx is None:
        # The eager kernel (aten/src/ATen/native/cudnn/RNN.cpp) allocates an
        # empty tensor with hx's options here. torch.empty(0) would instead
        # report the global default dtype.
        cy = hx.new_empty(0)
    else:
        cy = cx.new_empty(cell_shape)

    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])

    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python). Until
    # then the reserve buffer is reported as empty whether or not train=True.
    reserve = input.new_empty(0, dtype=torch.uint8)

    return output, hy, cy, reserve, weight_buf
5970

5971

5972
@register_meta(aten.mkldnn_rnn_layer.default)
def mkldnn_rnn_layer(
    input,
    w0,
    w1,
    w2,
    w3,
    hx_,
    cx_,
    reverse,
    batch_sizes,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    bidirectional,
    batch_first,
    train,
):
    """Meta for one mkldnn RNN layer: returns ``(output, hy, cy, workspace)``.

    ``output`` keeps ``input``'s batch/sequence layout (per ``batch_first``)
    with ``hidden_size`` features. ``hy``/``cy`` copy the shapes of
    ``hx_``/``cx_``, or are empty tensors on ``input.device`` when those are
    None. ``workspace`` is an empty uint8 tensor.
    """
    if batch_first:
        mini_batch, seq_length = input.shape[0], input.shape[1]
        out_shape = [mini_batch, seq_length, hidden_size]
    else:
        seq_length, mini_batch = input.shape[0], input.shape[1]
        out_shape = [seq_length, mini_batch, hidden_size]
    output = input.new_empty(out_shape)

    def _state_like(state):
        # A missing state still needs a (0-element) tensor in the result.
        if state is None:
            return torch.empty(0, device=input.device)
        return state.new_empty(state.shape)

    hy = _state_like(hx_)
    cy = _state_like(cx_)
    workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
    return output, hy, cy, workspace
6010

6011

6012
def zero_numel_check_dims(self, dim, fn_name):
6013
    if self.ndim == 0:
6014
        torch._check_index(
6015
            dim == 0 or dim == -1,
6016
            lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
6017
        )
6018
    else:
6019
        torch._check_index(
6020
            self.size(dim) != 0,
6021
            lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
6022
        )
6023

6024

6025
# From aten/src/ATen/native/ReduceOps.cpp
def check_argmax_argmin(name, self, dim):
    """Validate an argmax/argmin reduction.

    Without a ``dim``, ``self`` must be non-empty. With a ``dim``, it is
    wrapped and checked by ``zero_numel_check_dims``. ``name`` prefixes the
    error messages.
    """
    if dim is None:
        torch._check(
            self.numel() != 0,
            lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
        )
        return
    wrapped = maybe_wrap_dim(dim, self.dim())
    zero_numel_check_dims(self, wrapped, name)
6035

6036

6037
@register_meta([aten.argmax.default, aten.argmin.default])
def argmax_argmin_meta(self, dim=None, keepdim=False):
    """Meta shared by argmax and argmin: an int64 tensor of the reduced shape.

    NOTE(review): error messages are labelled "argmax" even when this runs
    for argmin, since one function is registered for both ops.
    """
    check_argmax_argmin("argmax", self, dim)
    reduce_over = None if dim is None else (dim,)
    dims = utils.reduction_dims(self.shape, reduce_over)
    out_shape = _compute_reduction_shape(self, dims, keepdim)
    return self.new_empty(out_shape, dtype=torch.int64)
6043

6044

6045
@register_meta(aten.scalar_tensor.default)
def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
    """Meta for ``scalar_tensor``: an uninitialized 0-dim tensor (``s`` is unused)."""
    factory_kwargs = dict(
        dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
    return torch.empty((), **factory_kwargs)
6050

6051

6052
@register_meta(aten.topk.default)
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
    """Meta for ``topk``: ``(values, int64 indices)`` with ``dim`` resized to ``k``.

    A 0-dim ``self`` is treated as a single-element slice, and the outputs
    stay 0-dim.
    """
    # From aten/src/ATen/native/Sorting.cpp
    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
    slice_size = self.size(dim) if self.dim() != 0 else 1
    torch._check(k >= 0 and k <= slice_size, lambda: "k not in range for dimension")

    out_shape = list(self.shape)
    if out_shape:
        out_shape[dim] = k
    values = self.new_empty(out_shape)
    indices = self.new_empty(out_shape, dtype=torch.int64)
    return values, indices
6063

6064

6065
@register_meta([aten.kthvalue.default, aten.kthvalue.values])
@out_wrapper("values", "indices")
def kthvalue_meta(self, k, dim=-1, keepdim=False):
    """Meta for ``kthvalue``: ``(values, int64 indices)`` with ``dim`` reduced.

    ``k`` must lie in ``[1, size(dim)]`` (a 0-dim input counts as size 1).
    With ``keepdim`` (and a non-scalar input), the reduced dim is kept as size 1.
    """
    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
    is_scalar = self.dim() == 0
    dim_size = 1 if is_scalar else self.size(dim)
    torch._check(
        k >= 1 and k <= dim_size,
        lambda: f"kthvalue(): selected number k out of range for dimension {dim}",
    )

    shape = [s for i, s in enumerate(self.shape) if i != dim]
    if keepdim and not is_scalar:
        shape.insert(dim, 1)
    return self.new_empty(shape), self.new_empty(shape, dtype=torch.int64)
6079

6080

6081
# Module-level alias for torch.contiguous_format.
legacy_contiguous_memory_format: torch.memory_format = torch.contiguous_format
6082

6083

6084
# From aten/src/ATen/native/cuda/RNN.cu
6085
def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
6086
    defined_grad = grad_hy if grad_hy is not None else grad_cy
6087
    torch._check(defined_grad.dim() == 2, lambda: "")
6088
    exp_size = defined_grad.size()
6089
    if grad_hy is not None:
6090
        torch._check(grad_hy.size() == exp_size, lambda: "")
6091
    if grad_cy is not None:
6092
        torch._check(grad_cy.size() == exp_size, lambda: "")
6093
    torch._check(cx.size() == exp_size, lambda: "")
6094
    torch._check(cy.size() == exp_size, lambda: "")
6095
    torch._check(workspace.dim() == 2, lambda: "")
6096
    torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
6097

6098

6099
# From aten/src/ATen/native/cuda/RNN.cu
@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
    """Meta for the fused LSTM cell backward: ``(grad_gates, grad_cx, grad_bias)``.

    Returns three Nones when neither ``grad_hy`` nor ``grad_cy`` is given.
    Otherwise ``grad_gates`` is shaped like ``workspace`` and ``grad_cx`` like
    ``cx``. ``grad_bias`` is the dim-0 sum of ``grad_gates`` when ``has_bias``,
    else None.
    """
    if grad_hy is None and grad_cy is None:
        return None, None, None
    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
    fmt = legacy_contiguous_memory_format
    grad_gates = torch.empty_like(workspace, memory_format=fmt)
    grad_cx = torch.empty_like(cx, memory_format=fmt)
    grad_bias = None
    if has_bias:
        grad_bias = grad_gates.sum(0, keepdim=False)
    return grad_gates, grad_cx, grad_bias
6111

6112

6113
# From aten/src/ATen/native/mps/operations/Linear.mm
@register_meta(aten.linear_backward.default)
def linear_backward(input_, grad_output_, weight_, output_mask):
    """Meta for ``linear_backward``: ``(grad_input, grad_weight, grad_bias)``.

    ``grad_input`` (shaped like ``input_``) is produced when ``output_mask[0]``.
    ``grad_weight`` ``(out_features, in_features)`` and ``grad_bias``
    ``(out_features,)`` are produced together whenever ``output_mask[1]`` or
    ``output_mask[2]`` is set. Slots not produced are None.
    """
    grad_input = grad_output_.new_empty(input_.size()) if output_mask[0] else None
    grad_weight = grad_bias = None
    if output_mask[1] or output_mask[2]:
        out_features = grad_output_.size(-1)
        grad_weight = grad_output_.new_empty((out_features, input_.size(-1)))
        grad_bias = grad_output_.new_empty(out_features)
    return (grad_input, grad_weight, grad_bias)
6125

6126

6127
@register_meta(aten.pixel_shuffle.default)
def meta_pixel_shuffle(self, upscale_factor):
    """Meta for ``pixel_shuffle``: ``(*, C*r*r, H, W) -> (*, C, H*r, W*r)``, ``r = upscale_factor``.

    Output memory format:
    * channels_last inputs give channels_last, except on CUDA, where they
      give contiguous;
    * contiguous inputs give contiguous;
    * anything else is left as allocated.

    Raises:
        RuntimeError: if ``self`` has fewer than 3 dims or its channel dim is
            not divisible by ``upscale_factor ** 2``.
    """
    # torch._check instead of assert: asserts are stripped under `python -O`,
    # which would silently skip this input validation.
    torch._check(
        len(self.shape) > 2
        and self.shape[-3] % (upscale_factor * upscale_factor) == 0,
        lambda: f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}",
    )

    def is_channels_last(ten):
        return torch._prims_common.suggest_memory_format(ten) == torch.channels_last

    def pick_memory_format():
        if is_channels_last(self):
            if device_hint(self) == "cuda":
                return torch.contiguous_format
            else:
                return torch.channels_last
        elif self.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif self.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    C = self.shape[-3] // (upscale_factor * upscale_factor)
    Hr = self.shape[-2] * upscale_factor
    Wr = self.shape[-1] * upscale_factor
    out_shape = (*self.shape[:-3], C, Hr, Wr)

    out = self.new_empty(out_shape)
    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
    return out
6155

6156

6157
@register_meta(aten.mkldnn_rnn_layer_backward.default)
def mkldnn_rnn_layer_backward(
    input,
    weight0,
    weight1,
    weight2,
    weight3,
    hx_,
    cx_tmp,
    output,
    hy_,
    cy_,
    grad_output_r_opt,
    grad_hy_r_opt,
    grad_cy_r_opt,
    reverse,
    mode,
    hidden_size,
    num_layers,
    has_biases,
    train,
    bidirectional,
    batch_sizes,
    batch_first,
    workspace,
):
    """Meta for one mkldnn RNN layer's backward.

    Returns ``(diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx)``.
    Each gradient is an uninitialized tensor shaped like its source:
    ``input``, ``weight0``, ``weight1``, ``weight2``, ``hx_``, ``cx_tmp``.
    """

    def _grad_like(t):
        return t.new_empty(t.shape)

    diff_x = _grad_like(input)
    diff_hx = _grad_like(hx_)
    diff_cx = _grad_like(cx_tmp)
    diff_w1 = _grad_like(weight0)
    diff_w2 = _grad_like(weight1)
    diff_b = _grad_like(weight2)
    # The one bias-gradient tensor fills both bias slots of the result.
    return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx
6190

6191

6192
@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
@out_wrapper()
def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
    """Meta for ``bucketize``: contiguous bucket indices shaped like ``self``.

    Indices are int32 when ``out_int32`` is set, else int64.
    """
    index_dtype = torch.int32 if out_int32 else torch.int64
    result = torch.empty_like(self, dtype=index_dtype)
    return result.contiguous()
6198

6199

6200
@register_meta([aten.histc])
@out_wrapper()
def meta_histc(input, bins=100, min=0, max=0):
    """Meta for ``histc``: validate arguments, return an empty ``(bins,)`` histogram.

    The histogram has ``input``'s dtype and device. On CPU, ``input`` must be
    floating point.

    Raises:
        RuntimeError: if ``bins`` is not a positive int, ``min``/``max`` are
            not numbers, or ``max < min``.
    """
    fn_name = "histc()"
    if device_hint(input) == "cpu":
        torch._check(
            input.is_floating_point(),
            lambda: f"\"histogram_cpu\" not implemented for '{input.dtype}'",
        )
    torch._check(
        isinstance(bins, IntLike),
        lambda: f"{fn_name}: argument 'bins' must be int, not {type(bins)}",
    )
    torch._check(bins > 0, lambda: f"{fn_name}: bins must be > 0, but got {bins}")
    torch._check(
        isinstance(min, Number),
        lambda: f"{fn_name}: argument 'min' must be Number, not {type(min)}",
    )
    torch._check(
        isinstance(max, Number),
        lambda: f"{fn_name}: argument 'max' must be Number, not {type(max)}",
    )
    # The f-prefix was missing, so the message showed a literal "{fn_name}".
    torch._check(max >= min, lambda: f"{fn_name}: max must be larger than min")
    return torch.empty(bins, device=input.device, dtype=input.dtype)
6224

6225

6226
@register_meta(
    [aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default]
)
def meta_upsample_bimode2d_aa(
    input,
    output_size,
    align_corners,
    scales_h=None,
    scales_w=None,
):
    """Meta for anti-aliased bilinear/bicubic 2D upsampling.

    Returns an uninitialized ``(N, C, *output_size)`` tensor in the memory
    format suggested by ``input``. ``align_corners`` and the scales are not
    inspected here.
    """
    out_shape = upsample_common_check(input.size(), output_size, num_spatial_dims=2)
    torch._check(
        input.numel() != 0 or all(size > 0 for size in input.size()[1:]),
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
    )
    result = input.new_empty(out_shape)
    return result.to(memory_format=utils.suggest_memory_format(input))
6246

6247

6248
# From aten/src/ATen/native/cuda/AmpKernels.cu
@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
    """Validate ``found_inf`` and ``inv_scale`` for the AMP unscale op.

    Both must be 1-element floating-point tensors. ``self`` is not inspected,
    and nothing is returned.
    """
    size_checks = (
        (found_inf, "found_inf must be a 1-element tensor."),
        (inv_scale, "inv_scale must be a 1-element tensor."),
    )
    for tensor, msg in size_checks:
        torch._check(tensor.numel() == 1, lambda msg=msg: msg)

    dtype_checks = (
        (found_inf, "found_inf must be a float tensor."),
        (inv_scale, "inv_scale must be a float tensor."),
    )
    for tensor, msg in dtype_checks:
        torch._check(tensor.dtype.is_floating_point, lambda msg=msg: msg)
6265

6266

6267
# From aten/src/ATen/native/UnaryOps.cpp
@register_meta([aten.nan_to_num.default, aten.nan_to_num.out])
@out_wrapper()
def nan_to_num(self, nan=None, posinf=None, neginf=None):
    """Meta for ``nan_to_num``: an uninitialized tensor with ``self``'s size.

    The replacement values are not inspected here.
    """
    return self.new_empty([*self.size()])
6273

6274

6275
@register_meta(torch.ops.aten.transpose_)
def transpose_(self, dim0, dim1):
    """In-place transpose: swap the sizes and strides of ``dim0`` and ``dim1``.

    ``self`` is restrided in place with ``as_strided_`` and returned. When the
    wrapped dims are equal it is returned untouched.

    Raises:
        RuntimeError: for sparse CSR/CSC/BSR/BSC layouts, where in-place
            transposition is not supported.
    """
    # torch._check instead of assert: asserts are stripped under `python -O`,
    # which would silently skip this validation.
    torch._check(
        self.layout
        not in {
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        },
        lambda: f"torch.transpose_: in-place transposition is not supported for {self.layout} layout",
    )

    ndims = self.ndim

    dim0 = maybe_wrap_dim(dim0, ndims)
    dim1 = maybe_wrap_dim(dim1, ndims)

    if dim0 == dim1:
        return self

    size = list(self.size())
    stride = list(self.stride())

    stride[dim0], stride[dim1] = stride[dim1], stride[dim0]
    size[dim0], size[dim1] = size[dim1], size[dim0]

    self.as_strided_(size, stride)
    return self
6303

6304

6305
@register_meta(torch.ops.aten.t_)
def t_(self):
    """In-place 2D transpose, implemented via ``transpose_``.

    Tensors with fewer than 2 dims are returned unchanged (dims 0 and 0 are
    "swapped").

    Raises:
        RuntimeError: for a dense tensor with more than 2 dims, or a sparse
            tensor with more than 2 sparse dims or any dense dims.
    """
    ndims = self.ndim

    # torch._check instead of assert: asserts are stripped under `python -O`,
    # which would silently skip this validation.
    if self.is_sparse:
        sparse_dim = self.sparse_dim()
        dense_dim = self.dense_dim()
        torch._check(
            sparse_dim <= 2 and dense_dim == 0,
            lambda: f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions",  # noqa: B950
        )
    else:
        torch._check(
            self.dim() <= 2,
            lambda: f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D",
        )

    return transpose_(self, 0, 0 if ndims < 2 else 1)
6321

6322

6323
@register_meta(aten.searchsorted)
@out_wrapper()
def meta_searchsorted(
    sorted_sequence,
    self,
    *,
    out_int32=False,
    right=False,
    side=None,
    sorter=None,
):
    """Meta for ``searchsorted``: insertion indices for ``self``.

    A tensor ``self`` yields a contiguous index tensor of the same shape. A
    scalar ``self`` yields a 0-dim tensor on ``sorted_sequence``'s device.
    Indices are int32 when ``out_int32`` is set, else int64.
    """
    dtype = torch.int32 if out_int32 else torch.int64
    if not isinstance(self, torch.Tensor):  # Scalar
        return torch.empty((), dtype=dtype, device=sorted_sequence.device)
    return torch.empty_like(self, dtype=dtype).contiguous()
6339

6340

6341
def _check_for_unsupported_isin_dtype(dtype):
6342
    torch._check(
6343
        dtype not in [torch.bool, torch.bfloat16, torch.complex128, torch.complex64],
6344
        lambda: f"Unsupported input type encountered for isin(): {dtype}",
6345
    )
6346

6347

6348
@register_meta(aten._embedding_bag_backward)
def meta_embedding_bag_backward(
    grad,
    indices,
    offsets,
    offset2bag,
    bag_size,
    maximum_indices,
    num_weights,
    scale_grad_by_freq,
    mode,
    per_sample_weights,
    padding_idx=-1,
):
    """Meta for ``_embedding_bag_backward``: dispatch on ``sparse``.

    The dense path goes to ``meta_embedding_bag_dense_backward``, which does
    not take ``offsets``. The sparse path calls
    ``aten._embedding_bag_sparse_backward``, which does not take
    ``maximum_indices``.
    """
    if not sparse:
        return meta_embedding_bag_dense_backward(
            grad,
            indices,
            offset2bag,
            bag_size,
            maximum_indices,
            num_weights,
            scale_grad_by_freq,
            mode,
            per_sample_weights,
            padding_idx,
        )
    return aten._embedding_bag_sparse_backward(
        grad,
        indices,
        offsets,
        offset2bag,
        bag_size,
        num_weights,
        scale_grad_by_freq,
        mode,
        per_sample_weights,
        padding_idx,
    )
6389

6390

6391
@register_meta(aten._embedding_bag_dense_backward)
def meta_embedding_bag_dense_backward(
    grad,
    indices,
    offset2bag,
    bag_size,
    maximum_indices,
    num_weights,
    scale_grad_by_freq,
    mode,
    per_sample_weights,
    padding_idx=-1,
):
    """Meta for the dense ``embedding_bag`` backward.

    ``grad`` must be float16/bfloat16/float32/float64, and max mode requires
    ``maximum_indices``. Returns an uninitialized
    ``(num_weights, grad.size(1))`` weight gradient created from ``grad``.
    """
    supported = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    torch._check(
        grad.dtype in supported,
        lambda: f"Unsupported input type encountered: {grad.dtype}",
    )
    MODE_MAX = 2  # embedding_bag modes: 0 = sum, 1 = mean, 2 = max
    if mode == MODE_MAX:
        torch._check(maximum_indices is not None)
    return grad.new_empty((num_weights, grad.size(1)))
6413

6414

6415
@register_meta(aten._embedding_bag_per_sample_weights_backward)
def meta_embedding_bag_per_sample_weights_backward(
    grad,
    weight,
    indices,
    offsets,
    offset2bag,
    mode,
    padding_idx=-1,
):
    """Meta for the ``per_sample_weights`` gradient of ``embedding_bag``.

    Only sum mode (``mode == 0``) is supported. ``grad`` and ``weight`` must be
    2D with the same number of columns, and ``indices`` must be 1D. Returns an
    uninitialized ``(indices.size(0),)`` tensor created from ``grad``.
    ``offsets``, ``offset2bag`` and ``padding_idx`` are not inspected here.

    Raises:
        RuntimeError: if any of the checks above fails.
    """
    MODE_SUM = 0  # embedding_bag modes: 0 = sum, 1 = mean, 2 = max
    # torch._check only accepts a callable message. The bare string used
    # before made this failure raise TypeError("message must be a callable")
    # instead of the intended RuntimeError.
    torch._check(
        mode == MODE_SUM,
        lambda: "embedding_bag_backward: per_sample_weights only supported for mode='sum'",
    )
    torch._check(grad.dim() == 2)
    # Read the feature count only once grad is known to be 2D.
    embedding_features = grad.size(1)
    torch._check(indices.dim() == 1)
    num_samples = indices.size(0)
    torch._check(weight.dim() == 2)
    torch._check(weight.size(1) == embedding_features)
    output = grad.new_empty((num_samples,))
    return output
6438

6439

6440
@register_meta(aten.isin)
@out_wrapper()
def meta_isin(elements, test_elements, *, assume_unique=False, invert=False):
    """Meta for ``isin``: a bool tensor shaped like ``elements``.

    At least one argument must be a Tensor. A non-Tensor argument is converted
    with ``torch.tensor`` on the other's device. Both dtypes are then checked
    by ``_check_for_unsupported_isin_dtype``.
    """
    torch._check(
        any(isinstance(x, Tensor) for x in (elements, test_elements)),
        lambda: "At least one of elements and test_elements must be a Tensor.",
    )
    if not isinstance(elements, Tensor):
        elements = torch.tensor(elements, device=test_elements.device)

    if not isinstance(test_elements, Tensor):
        test_elements = torch.tensor(test_elements, device=elements.device)

    for t in (elements, test_elements):
        _check_for_unsupported_isin_dtype(t.dtype)
    return torch.empty_like(elements, dtype=torch.bool)
6456

6457

6458
@register_meta(aten.polygamma)
@out_wrapper()
def meta_polygamma(n: int, self: Tensor) -> Tensor:
    """Meta kernel for ``aten.polygamma``.

    Rejects negative ``n``. The result is an uninitialized tensor like
    ``self``, whose dtype comes from ``INT_TO_FLOAT`` elementwise type
    promotion.
    """
    torch._check(n >= 0, lambda: "polygamma(n, x) does not support negative n.")
    _, result_dtype = elementwise_dtypes(
        self,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )
    return torch.empty_like(self, dtype=result_dtype)
6467

6468

6469
@register_meta(aten._local_scalar_dense)
def meta_local_scalar_dense(self: Tensor):
    """Meta kernel for ``aten._local_scalar_dense``.

    Meta tensors carry no data, so extracting a Python scalar is impossible;
    this always raises.

    Raises:
        RuntimeError: unconditionally.
    """
    raise RuntimeError("Tensor.item() cannot be called on meta tensors")
6472

6473

6474
@register_meta(aten._jagged_to_padded_dense_forward.default)
def meta__jagged_to_padded_dense_forward(
    values: Tensor,
    offsets: List[Tensor],
    max_lengths: List[int],
    padding_value: float = 0.0,
):
    """Meta kernel for ``aten._jagged_to_padded_dense_forward``.

    Only one jagged dimension is supported, so ``offsets`` and
    ``max_lengths`` must each hold exactly one entry.

    Returns:
        An uninitialized tensor of shape ``(B, S, *values.shape[1:])``, where
        ``B = offsets[0].shape[0] - 1`` and ``S = max_lengths[0]``.
        ``padding_value`` does not affect metadata and is unused here.
    """
    # Validate with torch._check (not assert) so the checks survive `python -O`.
    torch._check(
        len(offsets) == 1,
        lambda: f"only one jagged dim is supported for now, got {len(offsets)} offsets",
    )
    torch._check(
        len(max_lengths) == 1,
        lambda: f"only one jagged dim is supported for now, got {len(max_lengths)} max_lengths",
    )

    B = offsets[0].shape[0] - 1
    S = max_lengths[0]
    output_shape = (B, S, *values.shape[1:])
    return values.new_empty(output_shape)
6489

6490

6491
@register_meta(aten._padded_dense_to_jagged_forward.default)
def meta__padded_dense_to_jagged_forward(
    padded: Tensor,
    offsets: List[Tensor],
    total_L: Optional[int] = None,
):
    """Meta kernel for ``aten._padded_dense_to_jagged_forward``.

    Only one jagged dimension is supported (``offsets`` must have exactly one
    entry).

    When ``total_L`` is not given, the output length is unknown at trace
    time. It is then represented by a fresh unbacked SymInt from ``padded``'s
    fake-mode shape environment, marked as a size with ``min=0``. In that
    case ``padded`` must be a ``FakeTensor`` with a non-None ``shape_env``.

    Returns:
        An uninitialized tensor of shape ``(total_L, *padded.shape[2:])``.
    """
    torch._check(
        len(offsets) == 1,
        lambda: f"only one jagged dim is supported for now, got {len(offsets)} offsets",
    )

    # Compare against None explicitly: an explicit total_L of 0 is a valid
    # (empty) length, and calling bool() on a SymInt would force a guard.
    if total_L is None:
        assert isinstance(padded, torch._subclasses.FakeTensor)
        shape_env = padded.fake_mode.shape_env
        assert shape_env is not None
        total_L = shape_env.create_unbacked_symint()
        torch.fx.experimental.symbolic_shapes._constrain_range_for_size(
            total_L, min=0, max=None
        )

    output_shape = (total_L, *padded.shape[2:])
    return padded.new_empty(output_shape)
6511

6512

6513
def _create_unary_float_meta_func(func):
    """Register a meta kernel for the unary elementwise op ``func``.

    The kernel applies ``INT_TO_FLOAT`` type promotion through
    ``elementwise_meta`` and is wrapped with ``out_wrapper`` so that ``out=``
    overloads work.

    Returns:
        The function that was registered.
    """

    @register_meta(func)
    @out_wrapper()
    def _f(x):
        return elementwise_meta(
            x, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )

    return _f
6522

6523

6524
def _create_binary_float_meta_func(func):
    """Register a meta kernel for the binary elementwise op ``func``.

    The kernel applies ``INT_TO_FLOAT`` type promotion (with broadcasting)
    through ``elementwise_meta`` and is wrapped with ``out_wrapper`` so that
    ``out=`` overloads work.

    Returns:
        The function that was registered.
    """

    @register_meta(func)
    @out_wrapper()
    def _f(x, y):
        return elementwise_meta(
            x, y, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )

    return _f
6533

6534

6535
# Unary special functions: INT_TO_FLOAT elementwise meta kernels.
_create_unary_float_meta_func(aten.special_airy_ai)
_create_unary_float_meta_func(aten.special_bessel_y0)
_create_unary_float_meta_func(aten.special_bessel_y1)
_create_unary_float_meta_func(aten.special_modified_bessel_i0)
_create_unary_float_meta_func(aten.special_modified_bessel_i1)
_create_unary_float_meta_func(aten.special_modified_bessel_k0)
_create_unary_float_meta_func(aten.special_modified_bessel_k1)
_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k0)
_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k1)


# Binary special functions (orthogonal polynomials): INT_TO_FLOAT
# elementwise meta kernels with broadcasting.
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_t)
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_u)
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_v)
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_w)
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_t)
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_u)
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_v)
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_w)
_create_binary_float_meta_func(aten.special_hermite_polynomial_h)
_create_binary_float_meta_func(aten.special_hermite_polynomial_he)
_create_binary_float_meta_func(aten.special_laguerre_polynomial_l)
_create_binary_float_meta_func(aten.special_legendre_polynomial_p)
6558

6559

6560
# We must also trigger meta registrations from PrimTorch ref
6561
# decompositions
6562
import torch._refs
6563
import torch._refs.nn.functional
6564
import torch._refs.special
6565

6566

6567
def activate_meta():
    """Install the collected decompositions as Meta-dispatch kernels.

    For every operator in ``global_decomposition_table``, the most specific
    entry is chosen, in precedence order
    ``meta`` > ``post_autograd`` > ``pre_autograd``.

    Each chosen function (HigherOrderOperators excepted) is registered as the
    operator's Python Meta implementation via ``py_impl``. It is also
    registered with the matching ``torch.library`` Meta library, unless the
    operator:

    - has a CompositeImplicitAutograd kernel,
    - is a view, or
    - is on a denylist of known-problematic ops.

    Raises:
        RuntimeError: if an op that has a CompositeImplicitAutograd kernel
            also has an explicit entry in the ``meta`` decomposition table.
    """
    activate_meta_table = {}

    # For a given op, pick the most specific decomp function from
    # global_decomposition_table in the precedence order
    # meta > post_autograd > pre_autograd. setdefault keeps the entry from
    # the first table that defines the op.
    for table_name in ("meta", "post_autograd", "pre_autograd"):
        registry = global_decomposition_table[table_name]
        for opo, decomp_fn in registry.items():
            activate_meta_table.setdefault(opo, decomp_fn)

    # Ops whose decomps must not be installed in the Meta torch.library,
    # annotated with the failure each one caused.
    library_denylist = {
        "aten::empty_strided",  # causing infinite recursion, test_meta.py
        "aten::clone",  # causing infinite recursion
        "aten::_to_copy",  # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite  # noqa: B950
        "aten::copy_",  # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64  # noqa: B950
        "aten::constant_pad_nd",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32  # noqa: B950
        "aten::rot90",  # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32  # noqa: B950
        "aten::as_strided_scatter",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32  # noqa: B950
    }

    for op_overload, fn in activate_meta_table.items():
        # Don't register meta for HigherOrderOp's decomp.
        # We can reconsider this in the future, but in general,
        # the way you do a meta for a HigherOrderOp is different from
        # OpOverload.
        if isinstance(op_overload, torch._ops.HigherOrderOperator):
            continue
        assert isinstance(op_overload, OpOverload)

        # Every OpOverload gets the Python Meta implementation, regardless of
        # whether it is also registered with a torch.library below.
        op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)

        op_name = op_overload.name()
        if torch._C._dispatch_has_kernel_for_dispatch_key(
            op_name, "CompositeImplicitAutograd"
        ):
            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if op_overload in global_decomposition_table["meta"]:
                raise RuntimeError(
                    f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
                    "register meta function for it. Instead, we should let the decomposition run and write "
                    "meta kernels for the base operators."
                )
        elif op_overload.is_view:
            # Attempting to register a python meta kernel for a view operator.
            # We shouldn't do this, because the output will report as not having aliased storages.
            # All view ops have meta kernels in C++ today, so we should use those instead.
            pass
        elif op_name in library_denylist:
            pass
        elif "mkldnn::" in op_name:
            _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
        elif "mkl::" in op_name:
            _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
        elif "onednn::" in op_name:
            _meta_lib_dont_use_me_use_register_meta_for_onednn.impl(op_overload, fn)
        elif "quantized::" in op_name:
            _meta_lib_dont_use_me_use_register_meta_for_quantized.impl(
                op_overload, fn
            )
        else:
            _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
6634

6635

6636
activate_meta()
6637

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.