# pytorch (fork): torch/_meta_registrations.py (6240 lines, 187.6 KB)
import math
2
from enum import Enum
3
from functools import partial
4
from typing import List, Optional, Sequence, Tuple, Union
5

6
import torch
7
import torch._prims_common as utils
8
from torch import SymBool, SymFloat, Tensor
9
from torch._decomp import (
10
    _add_op_to_registry,
11
    _convert_out_params,
12
    global_decomposition_table,
13
    meta_table,
14
)
15
from torch._ops import OpOverload
16
from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
17
from torch._prims_common import (
18
    corresponding_complex_dtype,
19
    corresponding_real_dtype,
20
    elementwise_dtypes,
21
    ELEMENTWISE_TYPE_PROMOTION_KIND,
22
    IntLike,
23
    make_contiguous_strides_for,
24
    TensorLike,
25
)
26

27
from torch._prims_common.wrappers import (
28
    _maybe_convert_to_dtype,
29
    _maybe_resize_out,
30
    _resize_output_check,
31
    _safe_copy_out,
32
    out_wrapper,
33
)
34
from torch._refs import _broadcast_shapes, _maybe_broadcast
35
from torch.utils import _pytree as pytree
36

37

38
aten = torch.ops.aten
39

40
_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
41

42

43
def register_meta(op):
44
    def wrapper(fn):
45
        fn = _convert_out_params(fn)
46

47
        def register(op):
48
            _add_op_to_registry(meta_table, op, fn)
49

50
        pytree.tree_map_(register, op)
51
        return fn
52

53
    return wrapper
54

55

56
def elementwise_meta(
57
    *args,
58
    type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND,
59
):
60
    # Perform type promotion, as this is expected from prim_metafunction
61
    _, result_dtype = utils.elementwise_dtypes(
62
        *args,
63
        type_promotion_kind=type_promotion,
64
    )
65
    args = [_maybe_convert_to_dtype(x, result_dtype) for x in args]
66

67
    # Broadcast
68
    args = _maybe_broadcast(*args)
69

70
    # Perform prim checks
71
    return _prim_elementwise_meta(
72
        *args, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
73
    )
74

75

76
def toRealValueType(dtype):
    """Return the real component dtype of a complex ``dtype``.

    complex32 -> half, cfloat -> float, cdouble -> double; any other value
    (including non-complex dtypes) is returned unchanged.
    """
    if dtype == torch.complex32:
        return torch.half
    if dtype == torch.cfloat:
        return torch.float
    if dtype == torch.cdouble:
        return torch.double
    return dtype
83

84

85
def check_inplace_broadcast(self_shape, *args_shape):
86
    broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape))
87
    torch._check(
88
        broadcasted_shape == self_shape,
89
        lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}",
90
    )
91

92

93
@register_meta([aten.linspace, aten.logspace])
@out_wrapper()
def meta_linspace_logspace(
    start,
    end,
    steps,
    base=None,
    dtype=None,
    device=None,
    layout=torch.strided,
    pin_memory=False,
    requires_grad=False,
):
    """Shared meta kernel for linspace/logspace: an uninitialized 1-D tensor of ``steps`` elements.

    ``base`` and ``device`` are accepted for signature compatibility only; the
    result is always allocated on the meta device. The dtype is complex when
    any of start/end/steps is a Python complex, else ``dtype`` or the default.
    """
    # Tensor bounds are only accepted as 0-d tensors (start checked first).
    for bound in (start, end):
        if isinstance(bound, torch.Tensor):
            torch._check(
                bound.dim() == 0,
                lambda: "linspace only supports 0-dimensional start and end tensors",
            )

    has_complex_arg = any(isinstance(a, complex) for a in (start, end, steps))
    if not has_complex_arg:
        dtype = dtype or torch.get_default_dtype()
    else:
        default_complex_dtype = utils.corresponding_complex_dtype(
            torch.get_default_dtype()
        )
        if dtype is None:
            dtype = default_complex_dtype
        else:
            torch._check(
                utils.is_complex_dtype(dtype),
                lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
            )
    assert isinstance(dtype, torch.dtype)

    # steps is not part of dtype inference, but it must be an integer.
    torch._check_type(
        isinstance(steps, IntLike),
        lambda: (
            "received an invalid combination of arguments - got "
            f"({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})"
        ),
    )
    assert isinstance(steps, IntLike)  # for mypy
    torch._check(steps >= 0, lambda: "number of steps must be non-negative")

    return torch.empty(
        (steps,),  # type: ignore[arg-type]
        dtype=dtype,
        layout=layout,
        device="meta",
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
149

150

151
@register_meta([aten.take.default, aten.take.out])
@out_wrapper()
def meta_take(self, index):
    """Meta kernel for ``take``: result has ``index``'s shape and ``self``'s dtype/device."""
    torch._check(
        index.dtype == torch.long,
        lambda: f"take(): Expected a long tensor for index, but got {index.dtype}",
    )
    # A non-empty selection from an empty source can never be in range.
    takes_from_empty = self.numel() == 0 and index.numel() != 0
    torch._check_index(
        not takes_from_empty,
        lambda: "take(): tried to take from an empty tensor",
    )
    return self.new_empty(index.shape)
165

166

167
@register_meta([aten.linalg_cross.default, aten.linalg_cross.out])
@out_wrapper()
def linalg_cross(self, other, *, dim=-1):
    """Meta kernel for ``linalg.cross``.

    Both inputs must have the same rank and length 3 along ``dim``; the
    result takes the broadcast shape of the two inputs.
    """
    same_rank = self.ndim == other.ndim
    torch._check(
        same_rank,
        lambda: "linalg.cross: inputs must have the same number of dimensions.",
    )
    torch._check(
        self.size(dim) == 3 and other.size(dim) == 3,
        lambda: (
            f"linalg.cross: inputs dimension {dim} must have length 3. "
            f"Got {self.size(dim)} and {other.size(dim)}"
        ),
    )
    return self.new_empty(_broadcast_shapes(self.shape, other.shape))
185

186

187
@register_meta(aten.linalg_matrix_exp)
@out_wrapper()
def linalg_matrix_exp(self):
    """Meta kernel for ``linalg.matrix_exp``: square float/complex input, contiguous output."""
    fn_name = "linalg.matrix_exp"
    squareCheckInputs(self, fn_name)
    checkFloatingOrComplex(self, fn_name)
    return torch.empty_like(self, memory_format=torch.contiguous_format)
193

194

195
@register_meta(
    [aten.cummax.default, aten.cummax.out, aten.cummin.default, aten.cummin.out]
)
@out_wrapper("values", "indices")
def cummaxmin(self, dim):
    """Meta kernel for cummax/cummin: values shaped like ``self`` plus int64 indices."""
    shape, device = self.shape, self.device
    values = torch.empty(shape, device=device, dtype=self.dtype)
    indices = torch.empty(shape, device=device, dtype=torch.int64)
    # dim is only validated for non-empty, non-scalar inputs.
    if self.numel() != 0 and self.ndim != 0:
        maybe_wrap_dim(dim, self.ndim)
    return values, indices
206

207

208
@register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out])
@out_wrapper()
def logcumsumexp(self, dim):
    """Meta kernel for ``logcumsumexp``: contiguous tensor shaped like ``self``."""
    maybe_wrap_dim(dim, self.ndim)  # only validates that dim is in range
    result = torch.empty_like(self)
    return result.contiguous()
214

215

216
# Stride-related code from _exec_fft in aten/src/ATen/native/cuda/SpectralOps.cpp
217
def _exec_fft(out, self, out_sizes, dim, forward):
218
    ndim = self.ndim
219
    signal_ndim = len(dim)
220
    batch_dims = ndim - signal_ndim
221

222
    # Permute dimensions so batch dimensions come first, and in stride order
223
    dim_permute = list(range(ndim))
224

225
    is_transformed_dim = [False for _ in range(ndim)]
226
    for d in dim:
227
        is_transformed_dim[d] = True
228

229
    # std::partition
230
    left, right = [], []
231
    for d in dim_permute:
232
        if not is_transformed_dim[d]:
233
            left.append(d)
234
        else:
235
            right.append(d)
236
    dim_permute = left + right
237
    batch_end = len(left)
238

239
    self_strides = self.stride()
240
    tmp = dim_permute[:batch_end]
241
    tmp.sort(key=lambda x: self_strides[x], reverse=True)
242
    dim_permute = tmp + dim_permute[batch_end:]
243
    input = self.permute(dim_permute)
244

245
    # Collapse batch dimensions into a single dimension
246
    batched_sizes = [-1] + list(input.shape[batch_dims:])
247
    input = input.reshape(batched_sizes)
248

249
    batch_size = input.size(0)
250
    batched_sizes[0] = batch_size
251
    batched_out_sizes = batched_sizes
252
    for i in range(len(dim)):
253
        batched_out_sizes[i + 1] = out_sizes[dim[i]]
254
    out = out.reshape(batched_out_sizes)
255

256
    # Reshaping to original batch shape and inverting the dimension permutation
257
    out_strides = [0 for _ in range(ndim)]
258
    batch_numel = 1
259
    i = batch_dims - 1
260
    while i >= 0:
261
        out_strides[dim_permute[i]] = batch_numel * out.stride(0)
262
        batch_numel *= out_sizes[dim_permute[i]]
263
        i -= 1
264
    for i in range(batch_dims, ndim):
265
        out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
266
    return out.as_strided(out_sizes, out_strides, out.storage_offset())
267

268

269
# See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
# and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
@out_wrapper()
def meta_fft_c2c(self, dim, normalization, forward):
    """Meta kernel for complex-to-complex FFT.

    The output has ``self``'s shape and dtype. When there are dims to
    transform, it is restrided by ``_exec_fft`` with the dims ordered by
    decreasing input stride.
    """
    assert self.dtype.is_complex

    out_sizes = self.shape
    output = self.new_empty(out_sizes)
    if not dim:
        return output

    in_strides = self.stride()
    sorted_dims = sorted(dim, key=lambda d: in_strides[d], reverse=True)
    return _exec_fft(output, self, out_sizes, sorted_dims, forward)
288

289

290
@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
    """Meta kernel for real-to-complex FFT.

    The output has the complex dtype matching ``self``. With ``onesided``,
    the last transformed dim shrinks from ``n`` to ``n // 2 + 1``.
    """
    assert self.dtype.is_floating_point
    sizes = list(self.size())
    if onesided:
        halved_dim = dim[-1]
        sizes[halved_dim] = sizes[halved_dim] // 2 + 1
    complex_dtype = utils.corresponding_complex_dtype(self.dtype)
    return self.new_empty(sizes, dtype=complex_dtype)
304

305

306
@register_meta(aten.randperm.generator_out)
def meta_randperm(n, *, generator=None, out):
    """``randperm.generator_out``: resize ``out`` to ``(n,)`` via ``_maybe_resize_out``."""
    return _maybe_resize_out(out, torch.Size([n]))


@register_meta(aten.randperm.default)
def meta_randperm_default(
    n, *, dtype=torch.long, layout=None, device=None, pin_memory=None
):
    """``randperm``: uninitialized tensor of size ``n``."""
    return torch.empty(
        n, device=device, dtype=dtype, layout=layout, pin_memory=pin_memory
    )


@register_meta(aten.randint.default)
def meta_randint(
    high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
):
    """``randint(high, size)``: uninitialized tensor of shape ``size``; ``high`` is unused."""
    return torch.empty(
        size, device=device, dtype=dtype, layout=layout, pin_memory=pin_memory
    )


@register_meta(aten.randint.low)
def meta_randint_low(
    low,
    high,
    size,
    *,
    dtype=torch.long,
    layout=None,
    device=None,
    pin_memory=None,
):
    """``randint(low, high, size)``: uninitialized tensor of shape ``size``; bounds are unused."""
    return torch.empty(
        size, device=device, dtype=dtype, layout=layout, pin_memory=pin_memory
    )


@register_meta(aten.rand.default)
def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """``rand(size)``: uninitialized tensor of shape ``size``."""
    return torch.empty(
        size, device=device, dtype=dtype, layout=layout, pin_memory=pin_memory
    )
350

351

352
@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
@out_wrapper()
def meta_fft_c2r(self, dim, normalization, lastdim):
    """Meta kernel for complex-to-real FFT.

    The last transformed dim becomes ``lastdim`` and the dtype becomes the
    real counterpart of ``self``'s complex dtype.
    """
    assert self.dtype.is_complex
    sizes = list(self.size())
    sizes[dim[-1]] = lastdim
    real_dtype = toRealValueType(self.dtype)
    return self.new_empty(sizes, dtype=real_dtype)
359

360

361
@register_meta(aten.copy_.default)
def meta_copy_(self, src, non_blocking=False):
    """Meta kernel for ``copy_``; returns ``self`` after running validation checks.

    This mirrors inductor's original decomposition, which exercises most of
    the meta checks that matter. Ideally the checks of the C++ ``copy_``
    kernel would be audited and reproduced here.
    """
    # A return value of 1 corresponds to MemOverlap::Yes.
    if torch._debug_has_internal_overlap(self) == 1:
        raise RuntimeError(
            "more than one element of the written-to tensor refers to a single memory location"
        )

    if not isinstance(src, Tensor):
        return self

    converted = src.to(self, non_blocking)
    # The result is discarded; expand_copy is called only for its shape checks.
    if converted.size() != self.size():
        aten.expand_copy.default(converted, self.size())
    return self
378

379

380
def inferUnsqueezeGeometry(tensor, dim):
    """Return the (sizes, strides) lists ``tensor`` would have with a size-1 dim at ``dim``.

    The inserted stride is ``size[dim] * stride[dim]``, or 1 when appending
    past the last dim. The caller in this file passes an already-wrapped dim.
    """
    sizes = list(tensor.size())
    strides = list(tensor.stride())
    if dim < tensor.dim():
        inserted_stride = sizes[dim] * strides[dim]
    else:
        inserted_stride = 1
    sizes.insert(dim, 1)
    strides.insert(dim, inserted_stride)
    return sizes, strides
387

388

389
@register_meta(aten.unsqueeze_.default)
def meta_unsqueeze_(self, dim):
    """In-place unsqueeze: restride ``self`` with a new size-1 dim and return it."""
    wrapped_dim = maybe_wrap_dim(dim, self.dim() + 1)
    new_sizes, new_strides = inferUnsqueezeGeometry(self, wrapped_dim)
    self.as_strided_(new_sizes, new_strides)
    return self
395

396

397
@register_meta(aten._sparse_semi_structured_linear)
def meta_sparse_structured_linear(
    input: Tensor,
    weight: Tensor,
    _meta: Tensor,
    bias: Optional[Tensor] = None,
    _activation_opt: Optional[str] = None,
    out_dtype: Optional[torch.dtype] = None,
):
    """Meta kernel for the semi-structured sparse linear op.

    The output matches ``input``'s shape except for the last dim, which
    becomes ``weight.size(0)``. It is returned with transposed strides
    ``(1, input.size(0))``. ``weight.size(1)`` must equal half of
    ``input.size(-1)``; only 2-D ``input`` is handled.
    """
    output_sizes = list(input.shape)
    if bias is not None:
        assert weight.size(0) == bias.size(0), "output size mismatch"
    assert weight.size(1) == input.size(-1) / 2
    output_sizes[-1] = weight.size(0)

    # see: https://github.com/pytorch/pytorch/pull/114477#issuecomment-1830121375
    # Inputs are assumed to be already squashed to 2-D. Because the output
    # is transposed, the transposed strides are propagated to it.
    assert len(input.shape) == 2, "we can only handle the squashed input case"
    transposed_strides = (1, input.size(0))

    if out_dtype is not None:
        assert (
            input.dtype == torch.int8 and out_dtype == torch.int32
        ), "out_dtype is only supported for i8i8->i32 linear operator"
    result_dtype = input.dtype if out_dtype is None else out_dtype

    empty = input.new_empty(output_sizes, dtype=result_dtype)
    return empty.as_strided(output_sizes, transposed_strides)
429

430

431
@register_meta(aten._cslt_sparse_mm)
def meta__cslt_sparse_mm(
    compressed_A: torch.Tensor,
    dense_B: torch.Tensor,
    bias: Optional[Tensor] = None,
    alpha: Optional[Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    transpose_result: bool = False,
):
    """Meta kernel for ``_cslt_sparse_mm`` (compressed sparse A times dense 2-D B).

    The row count ``m`` of A is recovered from the size of its compressed
    buffer: by the formula below, ``compressed_A.numel()`` is assumed to
    equal ``compression_factor * m * k / 16``.

    Returns:
        An uninitialized ``(m, n)`` tensor, or ``(n, m)`` when
        ``transpose_result`` is set. Its dtype is ``out_dtype`` if given,
        else ``dense_B``'s. ``alpha`` does not affect the metadata.
    """
    # float32 is accepted as well, so the message lists it too.
    assert dense_B.dtype in {
        torch.float32,
        torch.float16,
        torch.bfloat16,
        torch.int8,
    }, "_cslt_sparse_mm only supports fp32, fp16, bf16, and int8"
    assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
    assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"

    is_int8_input_type = compressed_A.dtype == torch.int8
    # Storage multiplier (in 1/16ths of m * k) of the compressed A buffer.
    compression_factor = 10 if is_int8_input_type else 9
    k = dense_B.size(0)
    n = dense_B.size(1)
    m = (compressed_A.numel() * 16) // (compression_factor * k)
    if bias is not None:
        assert m == bias.size(0)

    if out_dtype is not None:
        assert is_int8_input_type and out_dtype in {
            torch.float16,
            torch.bfloat16,
            torch.int32,
        }, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul"
    output_shape = (n, m) if transpose_result else (m, n)
    # new_empty falls back to dense_B's dtype when out_dtype is None.
    return dense_B.new_empty(output_shape, dtype=out_dtype)
466

467

468
@register_meta(aten.index_reduce.default)
def meta_index_reduce(
    self: Tensor,
    dim: int,
    index: Tensor,
    source: torch.Tensor,
    reduce: str,
    *,
    include_self: bool = True,
) -> Tensor:
    """Out-of-place ``index_reduce``: contiguous tensor shaped like ``self`` (no argument checks)."""
    result = torch.empty_like(self, memory_format=torch.contiguous_format)
    return result


@register_meta(aten.index_reduce_.default)
def meta_index_reduce_(
    self: Tensor,
    dim: int,
    index: Tensor,
    source: torch.Tensor,
    reduce: str,
    *,
    include_self: bool = True,
) -> Tensor:
    """In-place ``index_reduce_``: metadata is unchanged, so ``self`` is returned as-is."""
    return self
492

493

494
# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
495
@out_wrapper()
@register_meta(aten.index_select.default)
def meta_index_select(self, dim, index):
    """Meta kernel for ``index_select``.

    ``dim``'s extent becomes ``index.numel()``; a 0-d ``self`` stays 0-d.
    Decorators apply bottom-up, so ``register_meta`` registers the unwrapped
    function and ``out_wrapper`` only wraps the module-level name.
    """
    sizes = list(self.size())
    if self.dim() > 0:
        sizes[dim] = index.numel()
    return self.new_empty(sizes)
502

503

504
@register_meta(aten.segment_reduce.default)
def meta_segment_reduce(
    data: Tensor,
    reduce: str,
    *,
    lengths: Optional[Tensor] = None,
    indices: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
    axis: int = 0,
    unsafe: bool = False,
    initial=None,
) -> Tensor:
    """Meta kernel for ``segment_reduce``.

    The output shape is the shape of ``lengths`` (or the lengths implied by
    ``offsets``: ``offsets.shape`` with its last dim reduced by one),
    followed by ``data``'s dims after ``axis``. The output is contiguous, on
    the meta device, with ``data``'s dtype. ``reduce``, ``unsafe`` and
    ``initial`` do not affect the metadata.

    Raises:
        NotImplementedError: if ``indices`` is given.
        RuntimeError: if neither ``lengths`` nor ``offsets`` is given.
        IndexError: (from ``utils.canonicalize_dim``) if ``axis`` is out of
            range for ``data``.
    """
    if indices is not None:
        raise NotImplementedError(
            "segment_reduce(): indices based reduction is not supported yet."
        )

    def segment_reduce_lengths_tensor(lengths_shape):
        # Wrap a negative axis (e.g. -1) the way the ATen kernel does.
        # Otherwise data.shape[axis + 1 :] would keep the wrong trailing dims.
        wrapped_axis = utils.canonicalize_dim(data.dim(), axis)
        return torch.empty(
            lengths_shape + data.shape[wrapped_axis + 1 :],
            dtype=data.dtype,
            device="meta",
            memory_format=torch.contiguous_format,
        )

    if lengths is not None:
        return segment_reduce_lengths_tensor(lengths.shape)
    # FIXME should probably check that lengths and offset aren't both set, but
    # the ATen implementation neglects this too
    if offsets is not None:
        # lengths == torch.diff(offsets)
        lengths_shape = offsets.shape[:-1] + (offsets.shape[-1] - 1,)
        return segment_reduce_lengths_tensor(lengths_shape)
    raise RuntimeError("segment_reduce(): Either lengths or offsets must be defined.")
538

539

540
@register_meta([aten.max.default, aten.max.unary_out])
@out_wrapper()
def meta_max(self):
    """Full ``max`` reduction: a 0-d tensor with ``self``'s dtype."""
    return self.new_empty(())


@register_meta(aten.max.dim)
def meta_max_dim(self, dim, keepdim=False):
    """``max`` along ``dim``: (values, int64 indices) with the reduced shape."""
    reduced_dims = utils.reduction_dims(self.shape, (dim,))
    out_shape = _compute_reduction_shape(self, reduced_dims, keepdim)
    values = self.new_empty(out_shape)
    indices = self.new_empty(out_shape, dtype=torch.long)
    return values, indices


@register_meta([aten.min.default, aten.min.unary_out])
@out_wrapper()
def meta_min(self):
    """Full ``min`` reduction: a 0-d tensor with ``self``'s dtype."""
    return self.new_empty(())


@register_meta(aten.min.dim)
def meta_min_dim(self, dim, keepdim=False):
    """``min`` along ``dim``: (values, int64 indices) with the reduced shape."""
    reduced_dims = utils.reduction_dims(self.shape, (dim,))
    out_shape = _compute_reduction_shape(self, reduced_dims, keepdim)
    values = self.new_empty(out_shape)
    indices = self.new_empty(out_shape, dtype=torch.long)
    return values, indices
570

571

572
@register_meta(aten.angle.default)
def meta_angle(self):
    """Meta kernel for ``angle``.

    Complex inputs map to the matching real dtype; other inputs use
    INT_TO_FLOAT elementwise promotion. The shape follows ``self``.
    """
    if not self.is_complex():
        result_dtype = elementwise_dtypes(
            self,
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        )[1]
    else:
        result_dtype = corresponding_real_dtype(self.dtype)
    return torch.empty_like(self, dtype=result_dtype)
582

583

584
@register_meta(aten.angle.out)
def meta_angle_out(self, out):
    """``angle.out``: resize ``out`` to ``self``'s size, then copy ``angle(self)`` into it."""
    torch._resize_output_(out, self.size(), self.device)
    result = torch.angle(self)
    return out.copy_(result)
588

589

590
@register_meta(aten._assert_async.default)
def assert_async(val):
    """No-op meta kernel: async asserts produce no output metadata."""


@register_meta(aten._assert_async.msg)
def assert_async_meta(val, assert_msg):
    """No-op meta kernel for the message-carrying async assert."""


@register_meta(aten._print.default)
def print_meta(s):
    """No-op meta kernel: printing is skipped for meta tensors."""


@register_meta(aten._make_dep_token.default)
def make_dep_token(
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Return a fresh 0-d meta tensor used as a dependency token; all arguments are ignored."""
    token = torch.empty([], device="meta")
    return token
615

616

617
@register_meta(aten.sym_constrain_range.default)
def sym_constrain_range(size, min=None, max=None):
    """Constrain symbolic ``size`` using the given ``min``/``max`` via ``constrain_range``.

    Raises:
        ValueError: if ``size`` is a SymFloat or SymBool (not implemented).
    """
    # Imported lazily to avoid importing sympy at module level.
    from torch.fx.experimental.symbolic_shapes import constrain_range

    unsupported = isinstance(size, (SymFloat, SymBool))
    if unsupported:
        raise ValueError("Constraining SymFloat or Symbool is nyi")
    constrain_range(size, min=min, max=max)


@register_meta(aten._functional_sym_constrain_range.default)
def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):
    """Functional form: apply ``sym_constrain_range`` and pass ``dep_token`` through."""
    aten.sym_constrain_range(size, min=min, max=max)
    return dep_token


@register_meta(aten.sym_constrain_range_for_size.default)
def sym_constrain_range_for_size(size, min=None, max=None):
    """Size-oriented variant of ``sym_constrain_range`` (via ``_constrain_range_for_size``).

    Raises:
        ValueError: if ``size`` is a SymFloat or SymBool (not implemented).
    """
    # Imported lazily to avoid importing sympy at module level.
    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    unsupported = isinstance(size, (SymFloat, SymBool))
    if unsupported:
        raise ValueError("Constraining SymFloat or Symbool is nyi")
    _constrain_range_for_size(size, min=min, max=max)


@register_meta(aten._functional_sym_constrain_range_for_size.default)
def functional_sym_constrain_range_for_size(size, min, max, dep_token):
    """Functional form of ``sym_constrain_range_for_size``; passes ``dep_token`` through."""
    aten.sym_constrain_range_for_size(size, min=min, max=max)
    return dep_token


@register_meta(aten._functional_assert_async.msg)
def functional_assert_async_meta(val, assert_msg, dep_token):
    """Functional async assert: nothing to check at meta time; returns ``dep_token``."""
    return dep_token
652

653

654
# From aten/src/ATen/native/LinearAlgebraUtils.h
def squareCheckInputs(self: Tensor, f_name: str):
    """Assert that ``self`` is a (batch of) square matrices.

    Requires at least 2 dims with ``size(-1) == size(-2)``. Uses ``assert``,
    so failures raise AssertionError.
    """
    assert (
        self.dim() >= 2
    ), f"{f_name}: The input tensor must have at least 2 dimensions."
    rows, cols = self.size(-2), self.size(-1)
    assert (
        cols == rows
    ), f"{f_name}: A must be batches of square matrices, but they are {rows} by {cols} matrices"
662

663

664
# Validates input shapes and devices
665
# for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
666
# From aten/src/ATen/native/LinearAlgebraUtils.h
667
def linearSolveCheckInputs(
668
    self: Tensor,
669
    A: Tensor,
670
    name: str,
671
):
672
    torch._check(
673
        self.device == A.device,
674
        lambda: (
675
            f"Expected b and A to be on the same device, but found b on "
676
            f"{self.device} and A on {A.device} instead."
677
        ),
678
    )
679

680
    torch._check(
681
        self.dtype == A.dtype,
682
        lambda: (
683
            f"Expected b and A to have the same dtype, but found b of type "
684
            f"{self.dtype} and A of type {A.dtype} instead."
685
        ),
686
    )
687

688
    torch._check(
689
        A.size(-1) == A.size(-2),
690
        lambda: (
691
            f"A must be batches of square matrices, "
692
            f"but they are {A.size(-2)} by {A.size(-1)} matrices"
693
        ),
694
    )
695

696
    torch._check(
697
        A.size(-1) == self.size(-2),
698
        lambda: (
699
            f"Incompatible matrix sizes for {name}: each A "
700
            f"matrix is {A.size(-1)} by {A.size(-1)}"
701
            f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
702
        ),
703
    )
704

705

706
# From aten/src/ATen/native/LinearAlgebraUtils.h
707
def checkFloatingOrComplex(
708
    t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
709
):
710
    dtype = t.dtype
711
    torch._check(
712
        t.is_floating_point() or t.is_complex(),
713
        lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
714
    )
715
    if not allow_low_precision_dtypes:
716
        torch._check(
717
            dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
718
            lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
719
        )
720

721

722
# From aten/src/ATen/native/LinearAlgebraUtils.h
723
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
724
    torch._check(
725
        A.dim() >= 2,
726
        lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
727
    )
728

729

730
def checkInputsSolver(
    A: Tensor,
    B: Tensor,
    left: bool,
    f_name: str,
):
    """Validate operands of ``AX = B`` (``left``) or ``XA = B``.

    ``A`` must be square and ``B`` a matrix. The shared dim is -2 for
    ``left`` and -1 otherwise.
    """
    squareCheckInputs(A, f_name)
    checkIsMatrix(B, f_name)
    shared_dim = -2 if left else -1
    torch._check(
        A.size(shared_dim) == B.size(shared_dim),
        lambda: (
            f"{f_name}: Incompatible shapes of A and B for the equation "
            f"{'AX = B' if left else 'XA = B'}"
            f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
        ),
    )
746

747

748
def checkSameDevice(
749
    fn_name: str, result: Tensor, input: Tensor, result_name: str = "result"
750
):
751
    torch._check(
752
        result.device == input.device,
753
        lambda: (
754
            f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got "
755
            f"{result_name} on {result.device} and input on {input.device}"
756
        ),
757
    )
758

759

760
def checkUplo(UPLO: str):
761
    UPLO_uppercase = UPLO.upper()
762
    torch._check(
763
        len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"),
764
        lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}",
765
    )
766

767

768
@register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues])
@out_wrapper("eigenvalues", "eigenvectors")
def meta__linalg_eigh(
    A: Tensor,
    UPLO: str = "L",
    compute_v: bool = True,
):
    """Meta kernel for ``_linalg_eigh``.

    Returns ``(vals, vecs)``. ``vals`` has shape ``A.shape[:-1]`` and the
    real dtype of ``A``. ``vecs`` is shaped like ``A`` with column-major
    strides when ``compute_v`` is set; otherwise it is an empty ``[0]`` tensor.
    """
    squareCheckInputs(A, "linalg.eigh")
    checkUplo(UPLO)

    shape = list(A.shape)
    if not compute_v:
        vecs = A.new_empty([0])
    else:
        vecs = A.new_empty(shape)
        vecs.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))

    vals = A.new_empty(shape[:-1], dtype=toRealValueType(A.dtype))
    return vals, vecs
789

790

791
def cloneBatchedColumnMajor(src: Tensor) -> Tensor:
    """Copy ``src`` (same shape and values) so each matrix is stored column-major."""
    # A row-major copy of the transpose is a column-major copy of the original.
    transposed_copy = src.mT.clone(memory_format=torch.contiguous_format)
    return transposed_copy.transpose(-2, -1)
793

794

795
@register_meta(aten._cholesky_solve_helper)
@out_wrapper()
def _cholesky_solve_helper(self: Tensor, A: Tensor, upper: bool) -> Tensor:
    """Result has ``self``'s shape with batched column-major layout; ``A``/``upper`` unused."""
    result = cloneBatchedColumnMajor(self)
    return result
799

800

801
@register_meta(aten.cholesky_solve)
@out_wrapper()
def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor:
    """Meta kernel for ``cholesky_solve``.

    Both operands need at least 2 dims. Their batch dims are broadcast
    before the result layout is computed.
    """
    # b (self) is checked before u (A); each lambda runs immediately on failure.
    for label, operand in (("b", self), ("u", A)):
        torch._check(
            operand.ndim >= 2,
            lambda: f"{label} should have at least 2 dimensions, but has {operand.ndim} dimensions instead",
        )
    b_bcast, A_bcast = _linalg_broadcast_batch_dims_name(self, A, "cholesky_solve")
    return _cholesky_solve_helper(b_bcast, A_bcast, upper)
816

817

818
@register_meta(aten.cholesky)
@out_wrapper()
def cholesky(self: Tensor, upper: bool = False) -> Tensor:
    """Meta kernel for ``cholesky``.

    An empty input skips validation and gets a legacy-contiguous empty
    result. Otherwise the input must be square, and the result is a
    batched column-major copy.
    """
    if self.numel() != 0:
        squareCheckInputs(self, "cholesky")
        return cloneBatchedColumnMajor(self)
    return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
825

826

827
@register_meta(aten.cholesky_inverse)
@out_wrapper()
def cholesky_inverse(self: Tensor, upper: bool = False) -> Tensor:
    """Meta kernel for ``cholesky_inverse``: square input, batched column-major result."""
    squareCheckInputs(self, "cholesky_inverse")
    result = cloneBatchedColumnMajor(self)
    return result
832

833

834
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_cholesky_ex.default)
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
    """Meta kernel for ``linalg.cholesky_ex``.

    Returns ``(L, infos)``. ``L`` is shaped like ``A`` with column-major
    strides; ``infos`` is int32 with ``A``'s batch shape (``A.shape[:-2]``).
    """
    squareCheckInputs(A, "linalg.cholesky")
    checkFloatingOrComplex(A, "linalg.cholesky")

    shape = A.shape
    L = A.new_empty(shape)
    L.as_strided_(shape, make_contiguous_strides_for(shape, False))

    infos = A.new_empty(shape[:-2], dtype=torch.int32)
    return L, infos
851

852

853
@register_meta(
    [aten.linalg_householder_product.default, aten.linalg_householder_product.out]
)
@out_wrapper()
def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor:
    """Meta kernel for ``linalg.householder_product``.

    Validates the ranks and shapes of ``input`` and ``tau``, their batch
    dims, dtype and device. Returns a column-major tensor shaped like
    ``input``.
    """
    prefix = "torch.linalg.householder_product"
    torch._check(
        input.ndim >= 2,
        lambda: f"{prefix}: input must have at least 2 dimensions.",
    )
    torch._check(
        input.size(-2) >= input.size(-1),
        lambda: f"{prefix}: input.shape[-2] must be greater than or equal to input.shape[-1]",
    )
    torch._check(
        input.size(-1) >= tau.size(-1),
        lambda: f"{prefix}: input.shape[-1] must be greater than or equal to tau.shape[-1]",
    )
    torch._check(
        input.ndim - tau.ndim == 1,
        lambda: (
            f"{prefix}: Expected tau to have one dimension less than input, "
            f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
        ),
    )

    if input.ndim > 2:
        tau_batch_shape = tau.shape[:-1]
        torch._check(
            tau_batch_shape == input.shape[:-2],
            lambda: (
                f"{prefix}: Expected batch dimensions of tau to be "
                f"equal to input.shape[:-2], but got {tau_batch_shape}"
            ),
        )

    torch._check(
        tau.dtype == input.dtype,
        lambda: (
            f"{prefix}: tau dtype {tau.dtype}"
            f" does not match input dtype {input.dtype}"
        ),
    )
    checkSameDevice(prefix, tau, input, "tau")

    return torch.empty_strided(
        size=input.shape,
        stride=make_contiguous_strides_for(input.shape, row_major=False),
        dtype=input.dtype,
        device=input.device,
    )
904

905

906
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_inv_ex.default)
def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
    """Meta kernel for ``linalg.inv_ex``.

    Requires a square input in float32/float64/complex64/complex128.
    Returns a column-major inverse shaped like ``A`` plus int32 ``infos``
    with ``A``'s batch shape.
    """
    fn_name = "linalg.inv_ex"
    squareCheckInputs(A, fn_name)
    checkFloatingOrComplex(A, fn_name, allow_low_precision_dtypes=False)

    shape = A.shape
    inverse = A.new_empty(shape)
    inverse.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))

    infos = A.new_empty(shape[:-2], dtype=torch.int32)
    return inverse, infos
917

918

919
@register_meta([aten.linalg_ldl_factor_ex.default, aten.linalg_ldl_factor_ex.out])
@out_wrapper("LD", "pivots", "info")
def linalg_ldl_factor_ex_meta(
    self: Tensor,
    *,
    hermitian: bool = False,
    check_errors: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Meta kernel for ``linalg.ldl_factor_ex``.

    Returns ``(LD, pivots, info)``. ``LD`` is column-major and shaped like
    ``self``; ``pivots`` is int with shape ``self.shape[:-1]``; ``info`` is
    int with shape ``self.shape[:-2]``.
    """
    op_name = "torch.linalg.ldl_factor_ex"
    squareCheckInputs(self, op_name)
    checkFloatingOrComplex(self, op_name)

    shape = self.shape
    LD = torch.empty_strided(
        size=shape,
        stride=make_contiguous_strides_for(shape, row_major=False),
        dtype=self.dtype,
        device=self.device,
    )
    pivots = self.new_empty(shape[:-1], dtype=torch.int)
    info = self.new_empty(shape[:-2], dtype=torch.int)
    return LD, pivots, info
938

939

940
@register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out])
941
@out_wrapper()
942
def linalg_ldl_solve_meta(
943
    LD: Tensor, pivots: Tensor, B: Tensor, *, hermitian: bool = False
944
) -> Tensor:
945
    squareCheckInputs(LD, "torch.linalg.ldl_solve")
946
    checkFloatingOrComplex(LD, "torch.linalg.ldl_solve")
947
    linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve")
948
    torch._check(
949
        B.ndim >= 2,
950
        lambda: (
951
            f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
952
            f"but it has {B.ndim} dimensions instead"
953
        ),
954
    )
955
    expected_pivots_shape = LD.shape[:-1]
956
    torch._check(
957
        expected_pivots_shape == pivots.shape,
958
        lambda: (
959
            f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
960
            f"but got pivots with shape {pivots.shape} instead"
961
        ),
962
    )
963
    torch._check(
964
        utils.is_integer_dtype(pivots.dtype),
965
        lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
966
    )
967
    torch._check(
968
        LD.dtype == B.dtype,
969
        lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
970
    )
971
    B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LD)
972
    return torch.empty_strided(
973
        size=B_broadcast_size,
974
        stride=make_contiguous_strides_for(B_broadcast_size, row_major=False),
975
        dtype=B.dtype,
976
        device=B.device,
977
    )
978

979

980
@register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
981
@out_wrapper("P", "L", "U")
982
def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
983
    torch._check(
984
        A.ndim >= 2,
985
        lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
986
    )
987

988
    sizes = list(A.shape)
989
    m = sizes[-2]
990
    n = sizes[-1]
991
    k = min(m, n)
992

993
    sizes[-1] = m
994
    if pivot:
995
        P = A.new_empty(sizes)
996
    else:
997
        P = A.new_empty([0])
998

999
    sizes[-1] = k
1000
    L = A.new_empty(sizes)
1001

1002
    sizes[-2] = k
1003
    sizes[-1] = n
1004
    U = A.new_empty(sizes)
1005
    return P, L, U
1006

1007

1008
@register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out])
1009
@out_wrapper("LU", "pivots", "info")
1010
def linalg_lu_factor_ex_meta(
1011
    A: Tensor, *, pivot: bool = True, check_errors: bool = False
1012
) -> Tuple[Tensor, Tensor, Tensor]:
1013
    torch._check(
1014
        A.ndim >= 2,
1015
        lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
1016
    )
1017

1018
    sizes = list(A.shape)
1019
    m = sizes[-2]
1020
    n = sizes[-1]
1021

1022
    LU = torch.empty_strided(
1023
        size=sizes,
1024
        stride=make_contiguous_strides_for(sizes, row_major=False),
1025
        dtype=A.dtype,
1026
        device=A.device,
1027
    )
1028

1029
    # Sets sizes to the size of pivots
1030
    sizes.pop()
1031
    sizes[-1] = min(m, n)
1032
    pivots = A.new_empty(sizes, dtype=torch.int)
1033

1034
    # Sets sizes to the size of info
1035
    sizes.pop()
1036
    info = A.new_empty(sizes, dtype=torch.int)
1037

1038
    return LU, pivots, info
1039

1040

1041
@register_meta([aten.linalg_lu_solve.default, aten.linalg_lu_solve.out])
1042
@out_wrapper()
1043
def linalg_lu_solve_meta(
1044
    LU: Tensor,
1045
    pivots: Tensor,
1046
    B: Tensor,
1047
    *,
1048
    left: bool = True,
1049
    adjoint: bool = False,
1050
) -> Tensor:
1051
    # dtype
1052
    checkFloatingOrComplex(LU, "torch.linalg.lu_solve")
1053
    torch._check(
1054
        LU.dtype == B.dtype,
1055
        lambda: (
1056
            f"linalg.lu_solve: Expected LU and B to have the same dtype, "
1057
            f"but found LU of type {LU.dtype} and B of type {B.dtype} instead"
1058
        ),
1059
    )
1060
    torch._check(
1061
        pivots.dtype == torch.int,
1062
        lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32",
1063
    )
1064

1065
    # matrix shapes
1066
    squareCheckInputs(LU, "torch.linalg.lu_solve")
1067
    checkInputsSolver(LU, B, left, "linalg.lu_solve")
1068
    torch._check(
1069
        LU.size(-1) == pivots.size(-1),
1070
        lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix",
1071
    )
1072

1073
    # batches
1074
    torch._check(
1075
        LU.shape[:-1] == pivots.shape,
1076
        lambda: (
1077
            f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, "
1078
            f"but got pivots with shape {pivots.shape} instead"
1079
        ),
1080
    )
1081

1082
    B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LU)
1083

1084
    result = torch.empty_strided(
1085
        size=B_broadcast_size,
1086
        stride=make_contiguous_strides_for(B_broadcast_size, row_major=not left),
1087
        dtype=B.dtype,
1088
        device=B.device,
1089
    )
1090

1091
    if result.numel() != 0 and not left:
1092
        if result.is_complex():
1093
            result = result.conj()
1094

1095
    return result
1096

1097

1098
@register_meta(aten.lu_unpack)
1099
@out_wrapper("P", "L", "U")
1100
def lu_unpack_meta(
1101
    LU: Tensor,
1102
    pivots: Tensor,
1103
    unpack_data: bool = True,
1104
    unpack_pivots: bool = True,
1105
) -> Tuple[Tensor, Tensor, Tensor]:
1106
    torch._check(
1107
        LU.ndim >= 2,
1108
        lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: {LU.shape} instead",
1109
    )
1110
    if unpack_pivots:
1111
        torch._check(
1112
            pivots.dtype == torch.int32,
1113
            lambda: (
1114
                "torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n"
1115
                "Note: this function is intended to be used with the output produced by torch.linalg.lu_factor"
1116
            ),
1117
        )
1118
    sizes = list(LU.shape)
1119
    m = sizes[-2]
1120
    n = sizes[-1]
1121
    k = min(m, n)
1122
    sizes[-1] = m
1123
    if unpack_pivots:
1124
        P = LU.new_empty(sizes)
1125
    else:
1126
        P = LU.new_empty([0])
1127
    if unpack_data:
1128
        sizes[-1] = k
1129
        L = LU.new_empty(sizes)
1130
        sizes[-2] = k
1131
        sizes[-1] = n
1132
        U = LU.new_empty(sizes)
1133
    else:
1134
        L = LU.new_empty([0])
1135
        U = LU.new_empty([0])
1136
    return P, L, U
1137

1138

1139
# parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
1140
def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
1141
    if mode == "reduced":
1142
        compute_q = True
1143
        reduced = True
1144
    elif mode == "complete":
1145
        compute_q = True
1146
        reduced = False
1147
    elif mode == "r":
1148
        compute_q = False
1149
        reduced = True  # this is actually irrelevant in this mode
1150
    else:
1151
        torch._check(
1152
            False,
1153
            lambda: (
1154
                f"qr received unrecognized mode '{mode}' "
1155
                f"but expected one of 'reduced' (default), 'r', or 'complete'"
1156
            ),
1157
        )
1158
    return compute_q, reduced  # type: ignore[possibly-undefined]
1159

1160

1161
@register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
1162
@out_wrapper("Q", "R")
1163
def linalg_qr_meta(
1164
    A: Tensor,
1165
    mode: str = "reduced",
1166
) -> Tuple[Tensor, Tensor]:
1167
    checkIsMatrix(A, "linalg.qr")
1168
    checkFloatingOrComplex(A, "linalg.qr")
1169

1170
    compute_q, reduced_mode = _parse_qr_mode(mode)
1171

1172
    m = A.shape[-2]
1173
    n = A.shape[-1]
1174
    k = min(m, n)
1175

1176
    if compute_q:
1177
        Q_shape = list(A.shape)
1178
        Q_shape[-1] = k if reduced_mode else m
1179
        Q = A.new_empty(Q_shape)
1180
        Q.as_strided_(Q_shape, make_contiguous_strides_for(Q_shape, row_major=False))
1181
    else:
1182
        Q = A.new_empty([0])
1183

1184
    # For readability
1185
    R_shape = list(A.shape)
1186
    R_shape[-2] = k if reduced_mode or not compute_q else m
1187
    R = A.new_empty(R_shape)
1188
    R.as_strided_(R_shape, make_contiguous_strides_for(R_shape, row_major=False))
1189
    return Q, R
1190

1191

1192
@register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign])
1193
@out_wrapper("sign", "logabsdet", "LU", "pivots")
1194
def _linalg_slogdet(A: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
1195
    squareCheckInputs(A, "linalg.slogdet")
1196
    checkFloatingOrComplex(A, "linalg.slogdet", False)
1197
    shape = A.shape
1198
    sign = A.new_empty(shape[:-2])
1199
    logabsdet = A.new_empty(shape[:-2], dtype=toRealValueType(A.dtype))
1200
    LU = torch.empty_strided(
1201
        size=shape,
1202
        stride=make_contiguous_strides_for(shape, False),
1203
        dtype=A.dtype,
1204
        device=A.device,
1205
    )
1206
    pivots = A.new_empty(shape[:-1], dtype=torch.int32)
1207
    return sign, logabsdet, LU, pivots
1208

1209

1210
# From aten/src/ATen/native/BatchLinearAlgebra.cpp
1211
# NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
1212
@register_meta(aten._linalg_svd.default)
1213
def _linalg_svd_meta(
1214
    A: Tensor,
1215
    full_matrices: bool = False,
1216
    compute_uv: bool = True,
1217
    driver: Optional[str] = None,
1218
):
1219
    checkIsMatrix(A, "linalg.svd")
1220
    checkFloatingOrComplex(A, "linalg.svd")
1221

1222
    batch_dims = list(A.shape[:-2])
1223
    m = A.shape[-2]
1224
    n = A.shape[-1]
1225
    k = min(m, n)
1226

1227
    if compute_uv:
1228
        U_shape = batch_dims + [m, m if full_matrices else k]
1229
        U = A.new_empty(U_shape)
1230
        U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))
1231

1232
        V_shape = batch_dims + [n if full_matrices else k, n]
1233
        V = A.new_empty(V_shape)
1234
        # NB: This checks for CUDA since there is no way to check for cuSolver.
1235
        # Also, this might not work correctly on CPU when fake_device is not
1236
        # available as device_hint just defaults to CUDA in that case. See
1237
        # _linalg_svd meta in core.
1238
        is_cuda = device_hint(A) == "cuda"
1239
        V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=is_cuda))
1240
    else:
1241
        # doesn't matter
1242
        U = A.new_empty([0])
1243
        V = A.new_empty([0])
1244

1245
    # S is always real, even when A is complex.
1246
    S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
1247
    return U, S, V
1248

1249

1250
def _linalg_broadcast_batch_dims(
1251
    arg1: Tensor, arg2: Tensor
1252
) -> Tuple[List[int], List[int]]:
1253
    # broadcast the batch dimensions of arg1 and arg2.
1254
    arg1_batch_sizes = arg1.shape[:-2]
1255
    arg2_batch_sizes = arg2.shape[:-2]
1256
    expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes)
1257

1258
    arg1_expand_size = list(expand_batch_portion)
1259
    arg1_expand_size += [arg1.size(-2), arg1.size(-1)]
1260

1261
    arg2_expand_size = list(expand_batch_portion)
1262
    arg2_expand_size += [arg2.size(-2), arg2.size(-1)]
1263
    return arg1_expand_size, arg2_expand_size
1264

1265

1266
def _linalg_broadcast_batch_dims_name(
1267
    arg1: Tensor, arg2: Tensor, name: Optional[str]
1268
) -> Tuple[Tensor, Tensor]:
1269
    # If there's no name we assume we don't want to check the errors
1270
    if name:
1271
        linearSolveCheckInputs(arg1, arg2, name)
1272

1273
    arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2)
1274

1275
    arg1_broadcasted = (
1276
        arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size)
1277
    )
1278
    arg2_broadcasted = (
1279
        arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size)
1280
    )
1281
    return arg1_broadcasted, arg2_broadcasted
1282

1283

1284
def linalg_solve_is_vector_rhs(input: Tensor, other: Tensor) -> bool:
1285
    expected_batched_rhs_shape = input.shape[:-1]
1286
    vector_case = other.ndim == 1 or (
1287
        input.ndim - 1 == other.ndim and other.shape == expected_batched_rhs_shape
1288
    )
1289
    return vector_case
1290

1291

1292
@register_meta(aten._linalg_solve_ex)
1293
def _linalg_solve_ex(
1294
    A: Tensor,
1295
    B: Tensor,
1296
    *,
1297
    left: bool = True,
1298
    check_errors: bool = False,
1299
    result: Optional[Tensor] = None,
1300
    LU: Optional[Tensor] = None,
1301
    pivots: Optional[Tensor] = None,
1302
    info: Optional[Tensor] = None,
1303
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
1304
    checkFloatingOrComplex(A, "linalg.solve")
1305
    torch._check(
1306
        A.dtype == B.dtype,
1307
        lambda: (
1308
            f"linalg.solve: Expected A and B to have the same dtype, but found A of type "
1309
            f"{A.dtype} and B of type {B.dtype} instead"
1310
        ),
1311
    )
1312
    vector_case = linalg_solve_is_vector_rhs(A, B)
1313
    B_ = B.unsqueeze(-1) if vector_case else B
1314
    checkInputsSolver(A, B_, left, "linalg.solve")
1315
    B_broad_shape, _ = _linalg_broadcast_batch_dims(B_, A)
1316
    torch._check(
1317
        left or not vector_case,
1318
        lambda: (
1319
            "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. "
1320
            "In this case linalg.solve is equivalent to B / A.squeeze(-1)"
1321
        ),
1322
    )
1323
    result_shape = B_broad_shape[:-1] if vector_case else B_broad_shape
1324
    result_ = torch.empty_strided(
1325
        size=result_shape,
1326
        stride=make_contiguous_strides_for(result_shape, not left),
1327
        dtype=B.dtype,
1328
        device=B.device,
1329
    )
1330
    shape = A.shape
1331
    ndim = A.ndim
1332
    LU_ = torch.empty_strided(
1333
        size=shape,
1334
        stride=make_contiguous_strides_for(shape, False),
1335
        dtype=A.dtype,
1336
        device=A.device,
1337
    )
1338
    pivots_ = A.new_empty(shape[:-1], dtype=torch.int32)
1339
    info_ = A.new_empty(shape[:-2], dtype=torch.int32)
1340
    out = (result, LU, pivots, info)
1341
    res = (result_, LU_, pivots_, info_)
1342
    if all(x is not None for x in out):
1343
        for r, o in zip(res, out):
1344
            # resize and copy operations are done in-place
1345
            _maybe_resize_out(o, r.shape)  # type: ignore[arg-type]
1346
            # strides are not copied in out_wrapper
1347
            o.as_strided_(r.shape, r.stride())  # type: ignore[union-attr]
1348
            _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=False)  # type: ignore[arg-type]
1349
    return res
1350

1351

1352
@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
1353
def linalg_solve_triangular_meta(
1354
    A: Tensor,
1355
    B: Tensor,
1356
    *,
1357
    upper: bool,
1358
    left: bool = True,
1359
    unitriangular: bool = False,
1360
    out: Optional[Tensor] = None,
1361
) -> Tensor:
1362
    if out is None:
1363
        out = A.new_empty([0])
1364
    assert isinstance(out, TensorLike)
1365
    checkInputsSolver(A, B, left, "linalg.solve_triangular")
1366
    B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)
1367
    avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj()
1368
    if avoid_copy_A:
1369
        out = _maybe_resize_out(out, B_.shape)
1370
    else:
1371
        # reimplementation of resize_output with result F-contig
1372
        if _resize_output_check(out, B_.shape):
1373
            out.resize_(B_.transpose(-2, -1).shape)
1374
            out.transpose_(-2, -1)
1375
    return out  # type: ignore[return-value]
1376

1377

1378
@register_meta(aten.triangular_solve)
1379
@out_wrapper("solution", "cloned_coefficient")
1380
def triangular_solve_meta(
1381
    self: Tensor,
1382
    A: Tensor,
1383
    upper: bool = True,
1384
    transpose: bool = False,
1385
    unitriangular: bool = False,
1386
) -> Tuple[Tensor, Tensor]:
1387
    torch._check(
1388
        self.ndim >= 2,
1389
        lambda: (
1390
            f"torch.triangular_solve: Expected b to have at least 2 dimensions, "
1391
            f"but it has {self.ndim} dimensions instead"
1392
        ),
1393
    )
1394
    torch._check(
1395
        A.ndim >= 2,
1396
        lambda: (
1397
            f"torch.triangular_solve: Expected A to have at least 2 dimensions, "
1398
            f"but it has {A.ndim} dimensions instead"
1399
        ),
1400
    )
1401

1402
    linearSolveCheckInputs(self, A, "triangular_solve")
1403

1404
    if A.layout == torch.strided:
1405
        self_broadcast_size, A_broadcast_size = _linalg_broadcast_batch_dims(self, A)
1406
        solution = torch.empty_strided(
1407
            size=self_broadcast_size,
1408
            stride=make_contiguous_strides_for(self_broadcast_size, row_major=False),
1409
            dtype=self.dtype,
1410
            device=self.device,
1411
        )
1412
        cloned_coefficient = torch.empty_strided(
1413
            size=A_broadcast_size,
1414
            stride=make_contiguous_strides_for(A_broadcast_size, row_major=False),
1415
            dtype=A.dtype,
1416
            device=A.device,
1417
        )
1418
    elif A.layout == torch.sparse_csr or A.layout == torch.sparse_bsr:
1419
        solution = torch.empty_like(self)
1420
        cloned_coefficient = self.new_empty([0])
1421
    else:
1422
        torch._check(False, lambda: "triangular_solve: Got an unexpected layout.")
1423
    return solution, cloned_coefficient  # type: ignore[possibly-undefined]
1424

1425

1426
# From aten/src/ATen/native/LinearAlgebra.cpp
1427
@register_meta(aten._linalg_det.default)
1428
def _linalg_det_meta(A):
1429
    squareCheckInputs(A, "linalg.det")
1430
    checkFloatingOrComplex(A, "linalg.det")
1431

1432
    det = A.new_empty(A.shape[:-2])
1433

1434
    LU = A.new_empty(A.shape)
1435
    LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
1436

1437
    pivots = A.new_empty(A.shape[:-1], dtype=torch.int32)
1438
    return det, LU, pivots
1439

1440

1441
@register_meta(aten.ormqr)
1442
@out_wrapper()
1443
def ormqr(
1444
    input: Tensor,
1445
    tau: Tensor,
1446
    other: Tensor,
1447
    left: bool = True,
1448
    transpose: bool = False,
1449
) -> Tensor:
1450
    torch._check(
1451
        input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions."
1452
    )
1453
    torch._check(
1454
        other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions."
1455
    )
1456

1457
    left_size_condition = -2 if left else -1
1458
    torch._check(
1459
        other.shape[left_size_condition] >= tau.shape[-1],
1460
        lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be greater than or equal to tau.shape[-1]",
1461
    )
1462
    torch._check(
1463
        other.shape[left_size_condition] == input.shape[-2],
1464
        lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be equal to input.shape[-2]",
1465
    )
1466

1467
    torch._check(
1468
        tau.shape[-1] <= input.shape[-1],
1469
        lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]",
1470
    )
1471

1472
    torch._check(
1473
        input.ndim - tau.ndim == 1,
1474
        lambda: (
1475
            f"torch.ormqr: Expected tau to have one dimension less than input, "
1476
            f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
1477
        ),
1478
    )
1479
    torch._check(
1480
        input.ndim == other.ndim,
1481
        lambda: (
1482
            f"torch.ormqr: Expected other to have the same number of dimensions as input, "
1483
            f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}"
1484
        ),
1485
    )
1486

1487
    if input.ndim > 2:
1488
        expected_batch_shape = input.shape[:-2]
1489
        actual_batch_tau_shape = tau.shape[:-1]
1490
        torch._check(
1491
            actual_batch_tau_shape == expected_batch_shape,
1492
            lambda: (
1493
                f"torch.ormqr: Expected batch dimensions of tau to be "
1494
                f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
1495
            ),
1496
        )
1497

1498
        actual_batch_other_shape = other.shape[:-2]
1499
        torch._check(
1500
            actual_batch_other_shape == expected_batch_shape,
1501
            lambda: (
1502
                f"torch.ormqr: Expected batch dimensions of other to be "
1503
                f"equal to input.shape[:-2], but got {actual_batch_other_shape}"
1504
            ),
1505
        )
1506

1507
    torch._check(
1508
        tau.dtype == input.dtype,
1509
        lambda: (
1510
            f"torch.ormqr: Expected input and tau to have the same dtype, "
1511
            f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}"
1512
        ),
1513
    )
1514
    torch._check(
1515
        other.dtype == input.dtype,
1516
        lambda: (
1517
            f"torch.ormqr: Expected input and other to have the same dtype, "
1518
            f"but input has dtype {input.dtype} and other has dtype {other.dtype}"
1519
        ),
1520
    )
1521

1522
    checkSameDevice("torch.ormqr", tau, input, "tau")
1523
    checkSameDevice("torch.ormqr", other, input, "other")
1524

1525
    return torch.empty_strided(
1526
        size=other.shape,
1527
        stride=make_contiguous_strides_for(other.shape, row_major=False),
1528
        dtype=other.dtype,
1529
        device=other.device,
1530
    )
1531

1532

1533
def _padding_check_valid_input(input, padding, *, dim):
1534
    torch._check(
1535
        len(padding) == 2 * dim,
1536
        lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}",
1537
    )
1538

1539
    input_dim = input.ndim
1540

1541
    is_batch_mode = input_dim == (dim + 2)
1542

1543
    valid_batch_mode = is_batch_mode
1544
    valid_non_batch_mode = not is_batch_mode
1545

1546
    if is_batch_mode:
1547
        # allow batch size of 0-dim.
1548
        for d in range(1, input_dim):
1549
            valid_batch_mode = valid_batch_mode and input.size(d) != 0
1550
    else:
1551
        for d in range(0, input_dim):
1552
            valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0
1553

1554
    # allow empty batch size but not other dimensions.
1555
    torch._check(
1556
        valid_batch_mode or valid_non_batch_mode,
1557
        lambda: (
1558
            f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size "
1559
            f"and other non-zero dimensions for input, but got: {input.shape}"
1560
        ),
1561
    )
1562

1563

1564
def _pad1d_common(input, padding, *, is_reflection):
1565
    dim_plane = 0
1566
    dim_w = 1
1567
    nbatch = 1
1568

1569
    if input.ndim == 3:
1570
        nbatch = input.size(0)
1571
        dim_w += 1
1572
        dim_plane += 1
1573

1574
    _padding_check_valid_input(input, padding, dim=1)
1575

1576
    pad_l, pad_r = padding
1577

1578
    nplane = input.size(dim_plane)
1579
    input_w = input.size(dim_w)
1580
    output_w = input_w + pad_l + pad_r
1581

1582
    if is_reflection:
1583
        torch._check(
1584
            pad_l < input_w and pad_r < input_w,
1585
            lambda: (
1586
                f"Argument #4: Padding size should be less than the corresponding input dimension, "
1587
                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1588
            ),
1589
        )
1590

1591
    torch._check(
1592
        output_w >= 1,
1593
        lambda: f"input (W: {input_w}) is too small. Calculated output W: {output_w}",
1594
    )
1595

1596
    if input.ndim == 2:
1597
        return input.new_empty((nplane, output_w))
1598
    else:
1599
        return input.new_empty((nbatch, nplane, output_w))
1600

1601

1602
@register_meta(aten.reflection_pad1d)
1603
@out_wrapper()
1604
def meta_reflection_pad1d(input, padding):
1605
    return _pad1d_common(input, padding, is_reflection=True)
1606

1607

1608
@register_meta(aten.replication_pad1d)
1609
@out_wrapper()
1610
def meta_replication_pad1d(input, padding):
1611
    return _pad1d_common(input, padding, is_reflection=False)
1612

1613

1614
def _pad1d_backward_common(grad_output, input, padding, *, is_reflection):
1615
    dim_w = 1
1616
    if not is_reflection:
1617
        torch._check(len(padding) == 2, lambda: "padding size is expected to be 2")
1618

1619
    if input.ndim == 3:
1620
        dim_w += 1
1621

1622
    pad_l, pad_r = padding
1623

1624
    input_w = input.size(dim_w)
1625
    output_w = input_w + pad_l + pad_r
1626

1627
    if is_reflection:
1628
        torch._check(
1629
            pad_l < input_w and pad_r < input_w,
1630
            lambda: (
1631
                f"Argument #4: Padding size should be less than the corresponding input dimension, "
1632
                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1633
            ),
1634
        )
1635

1636
    torch._check(
1637
        output_w == grad_output.size(dim_w),
1638
        lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1639
    )
1640

1641
    return input.new_empty(input.shape)
1642

1643

1644
@register_meta(aten.reflection_pad1d_backward)
1645
@out_wrapper("grad_input")
1646
def meta_reflection_pad1d_backward(grad_output, input, padding):
1647
    return _pad1d_backward_common(grad_output, input, padding, is_reflection=True)
1648

1649

1650
@register_meta(aten.replication_pad1d_backward)
1651
@out_wrapper("grad_input")
1652
def meta_replication_pad1d_backward(grad_output, input, padding):
1653
    return _pad1d_backward_common(grad_output, input, padding, is_reflection=False)
1654

1655

1656
def _pad2d_common(input, padding, *, is_reflection):
1657
    dim_w = 2
1658
    dim_h = 1
1659
    dim_slices = 0
1660
    nbatch = 1
1661

1662
    _padding_check_valid_input(input, padding, dim=2)
1663

1664
    ndim = input.ndim
1665
    if ndim == 4:
1666
        nbatch = input.size(0)
1667
        dim_w += 1
1668
        dim_h += 1
1669
        dim_slices += 1
1670

1671
    pad_l, pad_r, pad_t, pad_b = padding
1672

1673
    nplane = input.size(dim_slices)
1674
    input_h = input.size(dim_h)
1675
    input_w = input.size(dim_w)
1676
    output_h = input_h + pad_t + pad_b
1677
    output_w = input_w + pad_l + pad_r
1678

1679
    if is_reflection:
1680
        torch._check(
1681
            pad_l < input_w and pad_r < input_w,
1682
            lambda: (
1683
                f"Argument #4: Padding size should be less than the corresponding input dimension, "
1684
                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1685
            ),
1686
        )
1687
        torch._check(
1688
            pad_t < input_h and pad_b < input_h,
1689
            lambda: (
1690
                f"Argument #6: Padding size should be less than the corresponding input dimension, "
1691
                f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
1692
            ),
1693
        )
1694

1695
    torch._check(
1696
        output_w >= 1 or output_h >= 1,
1697
        lambda: (
1698
            f"input (H: {input_h} W: {input_w}) is too small. "
1699
            f"Calculated output H: {output_h} W: {output_w}"
1700
        ),
1701
    )
1702

1703
    if input.ndim == 3:
1704
        return input.new_empty((nplane, output_h, output_w))
1705
    else:
1706
        return input.new_empty((nbatch, nplane, output_h, output_w))
1707

1708

1709
@register_meta(aten.reflection_pad2d)
1710
@out_wrapper()
1711
def meta_reflection_pad2d(input, padding):
1712
    return _pad2d_common(input, padding, is_reflection=True)
1713

1714

1715
@register_meta(aten.replication_pad2d)
1716
@out_wrapper()
1717
def meta_replication_pad2d(input, padding):
1718
    return _pad2d_common(input, padding, is_reflection=False)
1719

1720

1721
@register_meta(
1722
    [
1723
        aten.reflection_pad2d_backward.default,
1724
        aten.reflection_pad2d_backward.grad_input,
1725
        aten.replication_pad2d_backward.default,
1726
        aten.replication_pad2d_backward.grad_input,
1727
    ]
1728
)
1729
@out_wrapper("grad_input")
1730
def meta_pad2d_backward(grad_output, self, padding):
1731
    dim_w = 2
1732
    dim_h = 1
1733
    dim_plane = 0
1734
    nbatch = 1
1735

1736
    self_shape = self.shape
1737
    if self.dim() == 4:
1738
        nbatch = self_shape[0]
1739
        dim_w += 1
1740
        dim_h += 1
1741
        dim_plane += 1
1742

1743
    pad_l, pad_r, pad_t, pad_b = padding
1744

1745
    nplane = self_shape[dim_plane]
1746
    input_h = self_shape[dim_h]
1747
    input_w = self_shape[dim_w]
1748
    output_h = input_h + pad_t + pad_b
1749
    output_w = input_w + pad_l + pad_r
1750

1751
    torch._check(
1752
        output_w == grad_output.size(dim_w),
1753
        lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1754
    )
1755
    torch._check(
1756
        output_h == grad_output.size(dim_h),
1757
        lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
1758
    )
1759
    return self.new_empty(self.shape)
1760

1761

1762
def _pad3d_common(input, padding, *, is_reflection):
1763
    dim_w = 3
1764
    dim_h = 2
1765
    dim_d = 1
1766
    dim_plane = 0
1767

1768
    _padding_check_valid_input(input, padding, dim=3)
1769

1770
    batch_mode = input.ndim == 5
1771
    if batch_mode:
1772
        nbatch = input.size(0)
1773
        dim_w += 1
1774
        dim_h += 1
1775
        dim_d += 1
1776
        dim_plane += 1
1777

1778
    pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
1779

1780
    nplane = input.size(dim_plane)
1781
    input_d = input.size(dim_d)
1782
    input_h = input.size(dim_h)
1783
    input_w = input.size(dim_w)
1784
    output_d = input_d + pad_f + pad_bk
1785
    output_h = input_h + pad_t + pad_b
1786
    output_w = input_w + pad_l + pad_r
1787

1788
    if is_reflection:
1789
        torch._check(
1790
            pad_l < input_w and pad_r < input_w,
1791
            lambda: (
1792
                f"Argument #4: Padding size should be less than the corresponding input dimension, "
1793
                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1794
            ),
1795
        )
1796
        torch._check(
1797
            pad_t < input_h and pad_b < input_h,
1798
            lambda: (
1799
                f"Argument #6: Padding size should be less than the corresponding input dimension, "
1800
                f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
1801
            ),
1802
        )
1803
        torch._check(
1804
            pad_f < input_d and pad_bk < input_d,
1805
            lambda: (
1806
                f"Argument #8: Padding size should be less than the corresponding input dimension, "
1807
                f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}"
1808
            ),
1809
        )
1810

1811
    torch._check(
1812
        output_w >= 1 or output_h >= 1 or output_d >= 1,
1813
        lambda: (
1814
            f"input (D: {input_d} H: {input_h} W: {input_w}) is too small. "
1815
            f"Calculated output D: {output_d} H: {output_h} W: {output_w}"
1816
        ),
1817
    )
1818

1819
    if batch_mode:
1820
        return input.new_empty((nbatch, nplane, output_d, output_h, output_w))  # type: ignore[possibly-undefined]
1821
    else:
1822
        return input.new_empty((nplane, output_d, output_h, output_w))
1823

1824

1825
@register_meta(aten.reflection_pad3d)
@out_wrapper()
def meta_reflection_pad3d(input, padding):
    """Meta kernel for ``reflection_pad3d``.

    Delegates to the shared 3-D padding helper in reflection mode, which also
    enforces that every pad is smaller than the input dimension it pads.
    """
    reflect_mode = True
    return _pad3d_common(input, padding, is_reflection=reflect_mode)
1829

1830

1831
@register_meta(aten.replication_pad3d)
@out_wrapper()
def meta_replication_pad3d(input, padding):
    """Meta kernel for ``replication_pad3d``.

    Delegates to the shared 3-D padding helper in non-reflection mode, so the
    reflection-only "pad < input size" checks are skipped.
    """
    reflect_mode = False
    return _pad3d_common(input, padding, is_reflection=reflect_mode)
1835

1836

1837
@register_meta(
    [
        aten.reflection_pad3d_backward.default,
        aten.reflection_pad3d_backward.grad_input,
        aten.replication_pad3d_backward.default,
        aten.replication_pad3d_backward.grad_input,
    ]
)
@out_wrapper("grad_input")
def meta_pad3d_backward(grad_output, input, padding):
    """Meta kernel shared by the reflection/replication 3-D padding backward ops.

    ``padding`` is unpacked as ``(pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk)``
    and added to the width, height and depth of ``input`` respectively; the
    resulting extents must match ``grad_output``. Returns an empty gradient
    with ``input``'s shape.
    """
    torch._check(len(padding) == 6, lambda: "padding size is expected to be 6")
    assert input.ndim > 3
    assert grad_output.ndim == input.ndim

    # A 5-D input has a leading batch dim that shifts every spatial dim by one.
    shift = 1 if input.ndim == 5 else 0
    dim_d, dim_h, dim_w = 1 + shift, 2 + shift, 3 + shift

    pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
    expected_d = input.size(dim_d) + pad_f + pad_bk
    expected_h = input.size(dim_h) + pad_t + pad_b
    expected_w = input.size(dim_w) + pad_l + pad_r

    # Validated in width, height, depth order.
    for label, dim, expected in (
        ("width", dim_w, expected_w),
        ("height", dim_h, expected_h),
        ("depth", dim_d, expected_d),
    ):
        torch._check(
            expected == grad_output.size(dim),
            lambda label=label, dim=dim, expected=expected: (
                f"grad_output {label} unexpected. Expected: {expected}, "
                f"Got: {grad_output.size(dim)}"
            ),
        )

    return input.new_empty(input.shape)
1883

1884

1885
@register_meta(aten._pdist_forward)
@out_wrapper()
def meta__pdist_forward(self: Tensor, p: float = 2) -> Tensor:
    """Meta kernel for ``_pdist_forward``.

    For ``n = self.size(0)`` rows the result is 1-D with ``n * (n - 1) // 2``
    elements, or empty when ``n <= 1``.
    """
    torch._check(
        self.is_contiguous(), lambda: "_pdist_forward requires contiguous input"
    )
    num_rows = self.size(0)
    num_pairs = 0 if num_rows <= 1 else num_rows * (num_rows - 1) // 2
    result = self.new_empty((num_pairs,))
    return result.to(memory_format=torch.legacy_contiguous_format)  # type: ignore[call-overload]
1898

1899

1900
@register_meta(aten._pdist_backward)
@out_wrapper()
def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> Tensor:
    """Meta kernel for ``_pdist_backward``.

    Requires ``self`` and then ``pdist`` to be contiguous, and returns an
    empty gradient shaped like ``self`` in legacy contiguous layout.
    """
    for operand, message in (
        (self, "_pdist_backward requires self to be contiguous"),
        (pdist, "_pdist_backward requires pdist to be contiguous"),
    ):
        torch._check(operand.is_contiguous(), lambda message=message: message)
    return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
1910

1911

1912
@register_meta([aten.baddbmm.default, aten.baddbmm.out])
@out_wrapper()
def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
    """Meta kernel for ``baddbmm``.

    ``batch1`` must be ``(b, n, m)`` and ``batch2`` must be ``(b, m, p)``.
    ``self`` is expanded to ``(b, n, p)`` and an empty tensor of that shape is
    returned. All three tensors must share a dtype. ``beta`` and ``alpha`` are
    unused here; they do not affect the result's metadata.
    """
    # Validate ranks before touching any dimension: the size lookups below read
    # dims 0-2 of both batches and would otherwise fail with an IndexError
    # instead of these messages.
    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")

    dim1 = batch1.size(0)
    dim2 = batch1.size(1)
    dim3 = batch2.size(2)
    self = self.expand((dim1, dim2, dim3))
    torch._check(
        self.dtype == batch1.dtype == batch2.dtype,
        lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
    )
    batch1_sizes = batch1.shape
    batch2_sizes = batch2.shape
    bs = batch1_sizes[0]
    contraction_size = batch1_sizes[2]
    torch._check(
        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
        lambda: (
            f"Expected size for first two dimensions of batch2 tensor to be: "
            f"[{bs}, {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}]."
        ),
    )
    return self.new_empty(self.size())
1937

1938

1939
@register_meta([aten.bernoulli.default, aten.bernoulli.out])
@out_wrapper()
def meta_bernoulli(self, *, generator=None):
    """Meta kernel for ``bernoulli``: a contiguous empty tensor like ``self``."""
    # The result is made contiguous regardless of self's strides;
    # see https://github.com/pytorch/pytorch/issues/88612
    sample = torch.empty_like(self)
    return sample.contiguous()
1944

1945

1946
@register_meta(aten.bernoulli_.float)
def meta_bernoulli_(self, p=0.5, generator=None):
    """Meta kernel for in-place ``bernoulli_``; hands ``self`` straight back."""
    result = self
    return result
1949

1950

1951
@register_meta(aten.bernoulli.p)
def meta_bernoulli_p(self, p=0.5, generator=None):
    """Meta kernel for ``bernoulli.p``: a contiguous empty tensor like ``self``."""
    # Always contiguous; see https://github.com/pytorch/pytorch/issues/88612
    sample = torch.empty_like(self)
    return sample.contiguous()
1955

1956

1957
@register_meta(aten._fused_moving_avg_obs_fq_helper.default)
def meta__fused_moving_avg_obs_fq_helper(
    self,
    observer_on,
    fake_quant_on,
    running_min,
    running_max,
    scale,
    zero_point,
    averaging_const,
    quant_min,
    quant_max,
    ch_axis,
    per_row_fake_quant=False,
    symmetric_quant=False,
):
    """Meta kernel for ``_fused_moving_avg_obs_fq_helper``.

    Returns ``(output, mask)``: an empty tensor like ``self`` and a boolean
    tensor of the same shape. ``ch_axis`` must be less than ``self.dim()``.
    """
    torch._check(
        ch_axis < self.dim(),
        lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
    )
    bool_mask = torch.empty_like(self, dtype=torch.bool)
    fq_output = torch.empty_like(self)
    return (fq_output, bool_mask)
1979

1980

1981
@register_meta(aten.mm)
@out_wrapper()
def meta_mm(a, b):
    """Meta kernel for ``mm``: ``(N, M) @ (M, P) -> (N, P)``."""
    for label, operand in (("a", a), ("b", b)):
        torch._check(operand.dim() == 2, lambda label=label: f"{label} must be 2D")
    rows, inner_a = a.shape
    inner_b, cols = b.shape
    torch._check(
        inner_a == inner_b,
        lambda: f"a and b must have same reduction dim, but got [{rows}, {inner_a}] X [{inner_b}, {cols}].",
    )
    return a.new_empty(rows, cols)
1993

1994

1995
def _compute_reduction_shape(self, dims, keepdim):
1996
    if keepdim:
1997
        return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
1998

1999
    return utils.compute_reduction_output_shape(self.shape, dims)
2000

2001

2002
# FakeTensors (meta tensors that carry a device) report their device as "meta"
# while meta kernels run. Read the FakeTensor's "fake device" instead, so meta
# kernels whose behavior differs per device stay accurate under FakeTensor.
# Anything that is not a FakeTensor is assumed to be on CUDA.
def device_hint(tensor) -> "str":
    if not isinstance(tensor, torch._subclasses.FakeTensor):
        return "cuda"  # default to cuda
    return tensor.fake_device.type
2011

2012

2013
def calc_conv_nd_return_shape(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    stride: Union[List[int], int],
    padding: Union[List[int], int],
    dilation: Union[List[int], int],
    is_transposed: bool,
    groups: int,
    output_padding: Optional[Union[List[int], int]] = None,
):
    """Compute the output shape of an N-d (optionally transposed) convolution.

    ``stride``, ``padding``, ``dilation`` and ``output_padding`` may each be a
    single int, a one-element list (broadcast to every spatial dim), or a list
    with one entry per spatial dim.

    The transposed formula is used when ``is_transposed`` is true, or when a
    truthy ``output_padding`` is passed (kept for backward compatibility).
    A missing, zero or empty ``output_padding`` counts as 0 in every spatial
    dim.

    Returns:
        ``[batch, out_channels, *spatial_out]`` as a list.

    Raises:
        RuntimeError: for a non-transposed convolution whose
            ``weight.shape[1] * groups`` differs from the input channel count.
    """

    def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output

        See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
        Returns:
            The output length
        """
        return (ln + 2 * p - d * (k - 1) - 1) // s + 1

    def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
        """
        Formula to apply to calculate the length of some dimension of the output
        if transposed convolution is used.
        See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

        Args:
            ln: length of the dimension
            p: padding in that dim
            d: dilation in that dim
            k: kernel size in that dim
            s: stride in that dim
            op: output padding in that dim

        Returns:
            The output length
        """
        return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1

    kernel_size = weight.shape[2:]
    dims = input_tensor.shape[2:]
    num_spatial = len(dims)

    if is_transposed:
        out_channels = groups * weight.shape[1]
    else:
        out_channels = weight.shape[0]
        if weight.shape[1] * groups != input_tensor.shape[1]:
            raise RuntimeError("Invalid channel dimensions")

    ret_shape = [input_tensor.shape[0], out_channels]

    def _expand_to_dims(value):
        # Broadcast an int or one-element list to one entry per spatial dim;
        # longer lists are returned unchanged.
        if isinstance(value, IntLike):
            return [value] * num_spatial
        if len(value) == 1:
            return [value[0]] * num_spatial
        return value

    stride = _expand_to_dims(stride)
    padding = _expand_to_dims(padding)
    dilation = _expand_to_dims(dilation)

    # Historically the transposed formula was selected only by a truthy
    # output_padding, which silently produced a forward-conv shape for
    # transposed convs called with output_padding None/0. Honor is_transposed
    # too, while keeping the output_padding trigger for existing callers.
    has_output_padding = bool(output_padding)
    use_transposed = is_transposed or has_output_padding
    if has_output_padding:
        output_padding_list = _expand_to_dims(output_padding)
    else:
        output_padding_list = [0] * num_spatial

    for i in range(num_spatial):
        if use_transposed:
            ret_shape.append(
                _formula_transposed(
                    dims[i],
                    padding[i],
                    dilation[i],
                    kernel_size[i],
                    stride[i],
                    output_padding_list[i],
                )
            )
        else:
            ret_shape.append(
                _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
            )

    return ret_shape
2112

2113

2114
def is_channels_last(ten):
    """True when ``suggest_memory_format`` picks ``torch.channels_last`` for ``ten``."""
    suggested = torch._prims_common.suggest_memory_format(ten)
    return suggested == torch.channels_last
2116

2117

2118
@register_meta(aten.convolution.default)
def meta_conv(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    is_transposed: bool,
    output_padding: List[int],
    groups: int,
):
    """Meta kernel for ``aten.convolution``.

    The output shape comes from ``calc_conv_nd_return_shape``; a zero-channel
    input forces a zero-channel output. The memory format follows the input
    layout, and on CUDA a channels-last weight also yields channels-last.
    """

    def _choose_memory_format():
        if device_hint(input_tensor) == "cuda":
            wants_channels_last = is_channels_last(input_tensor) or is_channels_last(
                weight
            )
        else:
            wants_channels_last = is_channels_last(input_tensor)
        if wants_channels_last:
            return torch.channels_last
        if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        if input_tensor.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format
        return None

    transposed_output_padding = output_padding if is_transposed else None
    shape_out = calc_conv_nd_return_shape(
        input_tensor,
        weight,
        stride,
        padding,
        dilation,
        is_transposed,
        groups,
        transposed_output_padding,
    )

    # Dim 1 is the channel dim of both input and output.
    if input_tensor.size(1) == 0:
        shape_out[1] = 0

    result = input_tensor.new_empty(shape_out)
    return result.to(memory_format=_choose_memory_format())  # type: ignore[call-overload]
2161

2162

2163
if torch._C._has_mkldnn:
    # Library handle for the "mkldnn" namespace at the Meta key, bound to a
    # module-level name (presumably so it stays alive). Register meta kernels
    # through register_meta, not through this object.
    _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
        "mkldnn", "IMPL", "Meta"
    )

    @register_meta(torch.ops.mkldnn._convolution_pointwise.default)
    def meta_mkldnn_convolution_default(
        input_tensor,
        weight,
        bias,
        padding,
        stride,
        dilation,
        groups,
        attr,
        scalars,
        algorithm,
    ):
        """Meta kernel for ``mkldnn._convolution_pointwise`` (channels-last output).

        Note the parameter order: ``padding`` comes before ``stride``.
        """
        shape_out = calc_conv_nd_return_shape(
            input_tensor, weight, stride, padding, dilation, False, groups, []
        )
        out = input_tensor.new_empty(shape_out)
        out_memory_format = torch.channels_last
        out = out.to(memory_format=out_memory_format)  # type: ignore[call-overload]
        return out

    @register_meta(torch.ops.mkldnn._linear_pointwise.default)
    def meta_linear_pointwise_default(
        input_tensor, weight, bias, attr, scalars, algorithm
    ):
        """Meta kernel for ``mkldnn._linear_pointwise``: last dim becomes ``weight.shape[0]``."""
        return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))

    if torch._C.has_mkl:
        _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
            "mkl", "IMPL", "Meta"
        )

        @register_meta(torch.ops.mkl._mkl_linear)
        def meta_mkl_linear(
            input_tensor,
            packed_weight,
            orig_weight,
            bias,
            batch_size,
        ):
            """Meta kernel for ``mkl._mkl_linear``: last dim becomes ``orig_weight.shape[0]``."""
            return input_tensor.new_empty(
                (*input_tensor.shape[:-1], orig_weight.shape[0])
            )

    _meta_lib_dont_use_me_use_register_meta_for_onednn = torch.library.Library(
        "onednn", "IMPL", "Meta"
    )

    @register_meta(torch.ops.onednn.qconv2d_pointwise.default)
    def meta_qconv2d_pointwise(
        x,
        x_scale,
        x_zp,
        w,  # prepacked_weight
        w_scale,
        w_zp,
        bias,
        stride,
        padding,
        dilation,
        groups,
        output_scale,
        output_zero_point,
        output_dtype,
        attr,
        scalars,
        algorithm,
    ):
        """Meta kernel for ``onednn.qconv2d_pointwise``.

        Only float32/bfloat16 ``output_dtype`` is accepted here; the result is
        channels-last.
        """
        shape_out = calc_conv_nd_return_shape(
            x,
            w,
            stride,
            padding,
            dilation,
            False,
            groups,
            None,
        )
        assert output_dtype in [torch.float32, torch.bfloat16]
        out = x.new_empty(shape_out, dtype=output_dtype)
        out = out.to(memory_format=torch.channels_last)
        return out

    @register_meta(torch.ops.onednn.qlinear_pointwise.default)
    @register_meta(torch.ops.onednn.qlinear_pointwise.tensor)
    def meta_qlinear_pointwise(
        x,
        x_scale,
        x_zp,
        w,
        w_scale,
        w_zp,
        bias,
        output_scale,
        output_zero_point,
        output_dtype,
        post_op_name,
        post_op_args,
        post_op_algorithm,
    ):
        """Meta kernel for ``onednn.qlinear_pointwise`` (both overloads)."""
        output_shape = list(x.shape)
        # The weight has been transposed during the qlinear weight prepack process.
        output_shape[-1] = w.shape[1]
        assert output_dtype in [torch.float32, torch.bfloat16]
        out = x.new_empty(output_shape, dtype=output_dtype)
        return out

    _meta_lib_dont_use_me_use_register_meta_for_quantized = torch.library.Library(
        "quantized", "IMPL", "Meta"
    )

    @register_meta(torch.ops.quantized.max_pool2d)
    def meta_quantized_max_pool2d(
        input,
        kernel_size,
        stride=(),
        padding=(0,),
        dilation=(1,),
        ceil_mode=False,
    ):
        """Meta kernel for ``quantized.max_pool2d``.

        Batched (4-D) results are channels-last. An unbatched (3-D) result
        uses the default contiguous layout, because ``torch.channels_last``
        is only valid for rank-4 tensors and ``torch.empty`` would reject it.
        """
        (
            nInputPlane,
            outputHeight,
            outputWidth,
        ) = max_pool2d_checks_and_compute_shape(
            input, kernel_size, stride, padding, dilation, ceil_mode
        )
        nbatch = input.size(-4) if input.dim() == 4 else 1
        if input.dim() == 3:
            size = [nInputPlane, outputHeight, outputWidth]
            memory_format = torch.contiguous_format
        else:
            size = [nbatch, nInputPlane, outputHeight, outputWidth]
            memory_format = torch.channels_last
        return torch.empty(
            size,
            dtype=input.dtype,
            device=input.device,
            memory_format=memory_format,
        )
2307

2308

2309
# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
2310
def check_dim_size(tensor, dim, dim_size, size):
2311
    torch._check(
2312
        tensor.dim() == dim and tensor.shape[dim_size] == size,
2313
        lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
2314
        + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
2315
    )
2316

2317

2318
@register_meta(aten.avg_pool2d.default)
def meta_avg_pool2d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Meta kernel for ``avg_pool2d``.

    ``kernel_size`` and ``padding`` take one value (shared by H and W) or two;
    an empty ``stride`` defaults to the kernel size. Returns an empty pooled
    tensor in the memory format suggested for ``input``.
    """

    def _pair(name, values):
        torch._check(
            len(values) in [1, 2],
            lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
        )
        first = values[0]
        return (first, first) if len(values) == 1 else (first, values[1])

    kH, kW = _pair("kernel_size", kernel_size)
    torch._check(
        len(stride) in [0, 1, 2],
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if len(stride) == 0:
        dH, dW = kH, kW
    else:
        dH, dW = _pair("stride", stride)

    padH, padW = _pair("padding", padding)

    torch._check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nbatch = input.size(-4) if input.dim() == 4 else 1
    nInputPlane = input.size(-3)
    inputHeight = input.size(-2)
    inputWidth = input.size(-1)

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    memory_format = utils.suggest_memory_format(input)
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        memory_format,
    )

    spatial = [outputHeight, outputWidth]
    size = [nInputPlane, *spatial] if input.dim() == 3 else [nbatch, nInputPlane, *spatial]
    return torch.empty(
        size,
        dtype=input.dtype,
        device=input.device,
        memory_format=memory_format,
    )
2393

2394

2395
# Python port of avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
def avg_pool2d_backward_shape_check(
    input,
    gradOutput,
    nbatch,
    kH,
    kW,
    dH,
    dW,
    padH,
    padW,
    nInputPlane,
    inputHeight,
    inputWidth,
    outputHeight,
    outputWidth,
    mem_format,
):
    """Validate the shapes involved in an avg_pool2d backward pass.

    Runs ``pool2d_shape_check`` with ``1, 1`` in the two slots after the
    padding (presumably dilation). Then requires ``gradOutput`` to have the
    same rank as ``input`` and trailing dims
    ``(nInputPlane, outputHeight, outputWidth)``. ``nbatch`` is accepted for
    signature parity but is not used.
    """
    pool2d_shape_check(
        input,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        1,
        1,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    rank = input.dim()
    # Average pooling keeps the plane count, so the expected output plane
    # count equals nInputPlane.
    trailing_expected = (nInputPlane, outputHeight, outputWidth)
    for back_offset, expected in zip((3, 2, 1), trailing_expected):
        check_dim_size(gradOutput, rank, rank - back_offset, expected)
2437

2438

2439
# Meta kernel for avg_pool2d_backward, registered in the Python meta table.
@register_meta(aten.avg_pool2d_backward.default)
def meta_avg_pool2d_backward(
    gradOutput_,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Validate avg_pool2d backward arguments and return an empty grad_input.

    Argument handling follows the structured-kernel meta function in
    aten/src/ATen/native/AveragePool2d.cpp. The result has ``input``'s shape,
    dtype and device, in the memory format suggested for ``input``.
    """
    torch._check(
        len(kernel_size) == 1 or len(kernel_size) == 2,
        lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
    )
    # Index -1 is element 0 for a one-element sequence, so this covers both forms.
    kH, kW = kernel_size[0], kernel_size[-1]

    torch._check(
        len(stride) == 0 or len(stride) == 1 or len(stride) == 2,
        lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
    )
    if len(stride) == 0:
        dH, dW = kH, kW
    else:
        dH, dW = stride[0], stride[-1]

    torch._check(
        len(padding) == 1 or len(padding) == 2,
        lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
    )
    padH, padW = padding[0], padding[-1]

    torch._check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    input_size = input.shape
    nbatch = input_size[-4] if input.dim() == 4 else 1
    nInputPlane, inputHeight, inputWidth = input_size[-3], input_size[-2], input_size[-1]

    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)

    mem_format = utils.suggest_memory_format(input)

    avg_pool2d_backward_shape_check(
        input,
        gradOutput_,
        nbatch,
        kH,
        kW,
        dH,
        dW,
        padH,
        padW,
        nInputPlane,
        inputHeight,
        inputWidth,
        outputHeight,
        outputWidth,
        mem_format,
    )

    return torch.empty(
        input_size,
        dtype=input.dtype,
        device=input.device,
        memory_format=mem_format,
    )
2511

2512

2513
@register_meta(aten.avg_pool3d)
@out_wrapper()
def meta_avg_pool3d(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Meta kernel for ``avg_pool3d``.

    ``kernel_size`` and ``padding`` take one value (shared by T/H/W) or three;
    an empty ``stride`` defaults to the kernel size. ``input`` must be 4-D
    ``(nslices, time, height, width)`` or 5-D with a leading batch dim.

    Raises:
        RuntimeError: (via ``torch._check``) on malformed arguments, including
            an explicit ``divisor_override`` of 0.
    """
    torch._check(
        len(kernel_size) in (1, 3),
        lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
    )
    kT = kernel_size[0]
    kH = kT if len(kernel_size) == 1 else kernel_size[1]
    kW = kT if len(kernel_size) == 1 else kernel_size[2]

    torch._check(
        not stride or len(stride) in (1, 3),
        lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
    )
    dT = kT if not stride else stride[0]
    dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
    dW = kW if not stride else (dT if len(stride) == 1 else stride[2])

    torch._check(
        len(padding) in (1, 3),
        lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
    )
    padT = padding[0]
    padH = padT if len(padding) == 1 else padding[1]
    padW = padT if len(padding) == 1 else padding[2]

    torch._check(
        input.ndim in (4, 5),
        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
    )

    # Compare against None explicitly: a falsy test ("not divisor_override")
    # treats 0 as "no override" and would let a zero divisor through.
    torch._check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nbatch = input.size(0)
    nslices = input.size(-4)
    itime = input.size(-3)
    iheight = input.size(-2)
    iwidth = input.size(-1)

    otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
    oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
    owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)

    pool3d_shape_check(
        input,
        nslices,
        kT,
        kH,
        kW,
        dT,
        dH,
        dW,
        padT,
        padH,
        padW,
        1,
        1,
        1,
        itime,
        iheight,
        iwidth,
        otime,
        oheight,
        owidth,
        "avg_pool3d()",
        check_input_size=True,
    )

    if input.ndim == 4:
        return input.new_empty((nslices, otime, oheight, owidth))
    else:
        return input.new_empty((nbatch, nslices, otime, oheight, owidth))
2597

2598

2599
@register_meta(aten.avg_pool3d_backward)
@out_wrapper("grad_input")
def meta_avg_pool3d_backward(
    grad_output,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    """Meta kernel for ``avg_pool3d_backward``.

    Parses pooling arguments the same way as the forward meta kernel,
    validates ``grad_output`` against the recomputed pooled shape, and returns
    an empty gradient with ``input``'s shape.

    Raises:
        RuntimeError: (via ``torch._check``) on malformed arguments, including
            an explicit ``divisor_override`` of 0.
    """
    torch._check(
        len(kernel_size) in (1, 3),
        lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
    )
    kT = kernel_size[0]
    kH = kT if len(kernel_size) == 1 else kernel_size[1]
    kW = kT if len(kernel_size) == 1 else kernel_size[2]

    torch._check(
        not stride or len(stride) in (1, 3),
        lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
    )
    dT = kT if not stride else stride[0]
    dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
    dW = kW if not stride else (dT if len(stride) == 1 else stride[2])

    torch._check(
        len(padding) in (1, 3),
        lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
    )
    padT = padding[0]
    padH = padT if len(padding) == 1 else padding[1]
    padW = padT if len(padding) == 1 else padding[2]

    torch._check(
        input.ndim in (4, 5),
        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
    )

    # Compare against None explicitly: a falsy test ("not divisor_override")
    # treats 0 as "no override" and would let a zero divisor through.
    torch._check(
        divisor_override is None or divisor_override != 0,
        lambda: "divisor must be not zero",
    )

    nslices = input.size(-4)
    itime = input.size(-3)
    iheight = input.size(-2)
    iwidth = input.size(-1)

    otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
    oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
    owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)

    avg_pool3d_backward_shape_check(
        input,
        grad_output,
        nslices,
        kT,
        kH,
        kW,
        dT,
        dH,
        dW,
        padT,
        padH,
        padW,
        itime,
        iheight,
        iwidth,
        otime_for_shape_check,
        oheight_for_shape_check,
        owidth_for_shape_check,
        "avg_pool3d_backward()",
    )

    return input.new_empty(input.shape)
2677

2678

2679
@register_meta(aten._adaptive_avg_pool2d.default)
def meta_adaptive_avg_pool2d(self, output_size):
    """Meta kernel for ``_adaptive_avg_pool2d``.

    Replaces the last two dims with ``output_size``. The input's suggested
    memory format is kept, so a channels-last input yields a channels-last
    output.
    """
    torch._check(
        self.ndim in (3, 4),
        lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
    )
    leading_dims = self.shape[:-2]
    return torch.empty(
        leading_dims + tuple(output_size),
        dtype=self.dtype,
        device=self.device,
        memory_format=utils.suggest_memory_format(self),
    )
2695

2696

2697
@register_meta(aten._adaptive_avg_pool3d.default)
def meta_adaptive_avg_pool3d(self, output_size):
    """Meta kernel for ``_adaptive_avg_pool3d``: last three dims become ``output_size``."""
    torch._check(
        self.ndim in (4, 5),
        lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
    )
    leading_dims = self.shape[:-3]
    return self.new_empty(leading_dims + tuple(output_size))
2704

2705

2706
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta__adaptive_avg_pool2d_backward(grad_out, self):
    """Meta kernel for ``_adaptive_avg_pool2d_backward``.

    Requires every non-batch dim of ``grad_out`` to be non-empty, a 3-D or
    4-D ``grad_out``, and matching dtypes. Returns an empty gradient shaped
    like ``self``: channels-last when ``self`` is channels-last, otherwise
    contiguous.
    """
    ndim = grad_out.ndim
    # Shared helper (also used by the 3-D variant) so the empty-dimension
    # error message is well formed and consistent.
    _adaptive_pool_empty_output_check(grad_out, "adaptive_avg_pool2d_backward")
    torch._check(
        ndim == 3 or ndim == 4,
        lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
    )
    torch._check(
        self.dtype == grad_out.dtype,
        lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
    )
    memory_format = torch.contiguous_format
    if is_channels_last(self):
        memory_format = torch.channels_last
    return self.new_empty(self.shape).to(memory_format=memory_format)
2727

2728

2729
@register_meta(aten._adaptive_avg_pool3d_backward)
@out_wrapper("grad_input")
def meta__adaptive_avg_pool3d_backward(grad_output, self):
    """Meta kernel for ``_adaptive_avg_pool3d_backward``.

    Rejects empty non-batch dims in ``grad_output``, then returns an empty
    gradient like ``self`` in legacy contiguous layout.
    """
    _adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool3d_backward")
    grad_input = torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
    return grad_input
2734

2735

2736
def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str):
2737
    ndim = grad_output.ndim
2738
    for i in range(1, ndim):
2739
        torch._check(
2740
            grad_output.size(i) > 0,
2741
            lambda: (
2742
                f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, "
2743
                f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty"
2744
            ),
2745
        )
2746

2747

2748
@register_meta(aten.adaptive_max_pool2d)
@out_wrapper("out", "indices")
def meta_adaptive_max_pool2d(input, output_size):
    """Meta kernel for ``adaptive_max_pool2d``.

    Returns ``(out, indices)`` with the last two dims replaced by
    ``output_size``; ``indices`` is int64. A 4-D result follows the input's
    suggested memory format; a 3-D result uses the default layout.
    """
    ndim = input.ndim
    torch._check(
        ndim in (3, 4),
        lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}",
    )
    for i in range(1, ndim):
        torch._check(
            input.size(i) > 0,
            lambda i=i: (
                f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
                f"but input has sizes {input.shape} with dimension {i} being empty"
            ),
        )

    torch._check(
        len(output_size) == 2,
        lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2",
    )

    # The plane dim sits third from the end for both 3-D and 4-D inputs.
    sizeD = input.size(-3)
    osizeH, osizeW = output_size

    if ndim == 3:
        out_shape = (sizeD, osizeH, osizeW)
        return input.new_empty(out_shape), input.new_empty(out_shape, dtype=torch.int64)

    out_shape = (input.size(0), sizeD, osizeH, osizeW)  # type: ignore[assignment]
    memory_format = utils.suggest_memory_format(input)
    out = input.new_empty(out_shape).to(memory_format=memory_format)
    indices = input.new_empty(out_shape, dtype=torch.int64).to(
        memory_format=memory_format
    )
    return out, indices
2794

2795

2796
@register_meta(aten.adaptive_max_pool2d_backward)
@out_wrapper("grad_input")
def meta_adaptive_max_pool2d_backward(grad_output, input, indices):
    """Meta kernel for ``aten.adaptive_max_pool2d_backward``.

    Requires a 3D/4D ``grad_output`` with non-empty non-batch dims and the
    same dtype as ``input``. Returns an uninitialized tensor shaped like
    ``input`` in ``utils.suggest_memory_format(input)`` layout.
    """
    torch._check(
        grad_output.ndim in (3, 4),
        lambda: f"adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}",
    )
    _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward")
    torch._check(
        input.dtype == grad_output.dtype,
        lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}",
    )
    layout = utils.suggest_memory_format(input)
    grad_input = input.new_empty(input.shape)
    return grad_input.to(memory_format=layout)
2814

2815

2816
@register_meta(aten.adaptive_max_pool3d)
@out_wrapper("out", "indices")
def meta_adaptive_max_pool3d(input, output_size):
    """Meta kernel for ``aten.adaptive_max_pool3d``.

    Accepts a 4D (unbatched) or 5D (batched) ``input`` whose non-batch
    dimensions are all non-empty. Returns ``(out, indices)``: every dim before
    the last three is kept, and the last three are replaced by ``output_size``.
    ``indices`` is int64.
    """
    ndim = input.ndim
    torch._check(
        ndim in (4, 5),
        lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}",
    )
    for dim in range(1, ndim):
        torch._check(
            input.size(dim) > 0,
            lambda: (
                f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
                f"but input has sizes {input.shape} with dimension {dim} being empty"
            ),
        )

    torch._check(
        len(output_size) == 3,
        lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3",
    )

    # Leading dims are (D,) for 4D input or (B, D) for 5D input.
    osizeT, osizeH, osizeW = output_size
    out_shape = (*input.shape[:-3], osizeT, osizeH, osizeW)
    return input.new_empty(out_shape), input.new_empty(out_shape, dtype=torch.int64)
2858

2859

2860
@register_meta(aten.adaptive_max_pool3d_backward)
@out_wrapper("grad_input")
def meta_adaptive_max_pool3d_backward(grad_output, input, indices):
    """Meta kernel for ``aten.adaptive_max_pool3d_backward``.

    Rejects a ``grad_output`` with an empty non-batch dimension. Otherwise
    returns an uninitialized tensor with ``input``'s shape.
    """
    _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward")
    grad_input = input.new_empty(input.shape)
    return grad_input
2865

2866

2867
@register_meta(aten.repeat_interleave.Tensor)
def meta_repeat_interleave_Tensor(repeats, output_size=None):
    """Meta kernel for ``aten.repeat_interleave.Tensor``.

    The output length depends on the values stored in ``repeats``. The caller
    must therefore pass ``output_size``; without it a RuntimeError is raised.
    """
    if output_size is not None:
        return repeats.new_empty(output_size)
    raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
2872

2873

2874
@register_meta([aten.complex.default, aten.complex.out])
@out_wrapper()
def meta_complex(real, imag):
    """Meta kernel for ``aten.complex``.

    Args:
        real: floating-point tensor holding the real parts.
        imag: floating-point tensor with the same dtype as ``real``.

    Returns:
        An uninitialized tensor with the broadcast shape of ``real`` and
        ``imag`` and dtype ``corresponding_complex_dtype(real.dtype)``.

    Raises:
        RuntimeError: if either input is not floating point, or if their
            dtypes differ.
    """
    # torch._check rather than ``assert``: asserts are stripped under
    # ``python -O`` and carry no diagnostic message.
    torch._check(
        real.dtype.is_floating_point and imag.dtype.is_floating_point,
        lambda: (
            "complex(): expected both inputs to be floating point tensors, "
            f"but got {real.dtype} and {imag.dtype}"
        ),
    )
    # The output dtype is derived from ``real`` alone, so a mismatched
    # ``imag`` dtype would otherwise be silently ignored.
    torch._check(
        real.dtype == imag.dtype,
        lambda: (
            f"Expected object of scalar type {real.dtype} but got scalar type "
            f"{imag.dtype} for second argument"
        ),
    )
    out_shape = _broadcast_shapes(real.shape, imag.shape)
    return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
2881

2882

2883
@register_meta([aten.nonzero_static.default, aten.nonzero_static.out])
@out_wrapper()
def nonzero_static(self, *, size: int, fill_value: int = -1):
    """Meta kernel for ``aten.nonzero_static``.

    The result is a ``torch.long`` tensor with exactly ``size`` rows and one
    column per dimension of ``self``. ``fill_value`` is accepted to match the
    op signature but is unused here.
    """
    num_rows, num_cols = size, self.dim()
    return self.new_empty((num_rows, num_cols), dtype=torch.long)
2887

2888

2889
@register_meta([aten.index.Tensor, aten._unsafe_index.Tensor])
def meta_index_Tensor(self, indices):
    """Meta kernel for advanced indexing (``aten.index.Tensor`` / ``aten._unsafe_index.Tensor``).

    ``indices`` holds one optional entry per leading dimension of ``self``;
    a ``None`` entry leaves that dimension as-is. Mask entries (``torch.bool``
    or byte, i.e. ``torch.uint8``) are expanded with ``nonzero()`` into one
    index per masked dimension. The remaining index tensors are broadcast
    together, and the output shape follows NumPy's rules for mixing basic and
    advanced indexing (see links below).

    Raises (via ``torch._check`` / ``torch._check_index``) when ``indices`` is
    empty, an index has an unsupported dtype, a mask's shape disagrees with
    ``self``, or there are more indices than dimensions.
    """
    torch._check(bool(indices), lambda: "at least one index must be provided")
    # aten::index is the internal advanced indexing implementation
    # checkIndexTensorTypes and expandTensors
    # "byte" in the message below means torch.uint8 (torch.ByteTensor);
    # torch.int8 is the signed "char" dtype and is not a valid index/mask dtype.
    mask_dtypes = (torch.uint8, torch.bool)
    result: List[Optional[Tensor]] = []
    for i, index in enumerate(indices):
        if index is not None:
            torch._check(
                index.dtype in (torch.long, torch.int, *mask_dtypes),
                lambda: "tensors used as indices must be long, int, byte or bool tensors",
            )
            if index.dtype in mask_dtypes:
                # A mask with index.ndim dims covers that many consecutive dims
                # of self, starting at dim k; each column of nonzero() becomes
                # an ordinary index for one of those dims.
                nonzero = index.nonzero()
                k = len(result)
                torch._check_index(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                )
                for j in range(index.ndim):
                    torch._check_index(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                    )
                    result.append(nonzero.select(1, j))
            else:
                result.append(index)
        else:
            result.append(index)
    indices = result
    torch._check(
        len(indices) <= self.ndim,
        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
    )
    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors
    while len(indices) < self.ndim:
        indices.append(None)

    # hasContiguousSubspace
    #   true if all non-null tensors are adjacent
    # See:
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
    # State machine: 0 = before any index, 1 = inside the run of indices,
    # 2 = after the run; seeing another index in state 2 means non-contiguous.
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # transposeToFront
    # This is the logic that causes the newly inserted dimensions to show up
    # at the beginning of the tensor, if they're not contiguous
    if not has_contiguous_subspace:
        dims = []
        transposed_indices = []
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        self = self.permute(dims)
        indices = transposed_indices

    # AdvancedIndex::AdvancedIndex
    # Now we can assume the indices have contiguous subspace
    # This is simplified from AdvancedIndex which goes to more effort
    # to put the input and indices in a form so that TensorIterator can
    # take them.  If we write a ref for this, probably that logic should
    # get implemented
    before_shape: List[int] = []
    after_shape: List[int] = []
    replacement_shape: List[int] = []
    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                after_shape.append(self.shape[dim])
            else:
                before_shape.append(self.shape[dim])
        else:
            # All (broadcast) indices share one shape, which replaces the
            # indexed dims in the output.
            replacement_shape = list(index.shape)
    return self.new_empty(before_shape + replacement_shape + after_shape)
2987

2988

2989
@register_meta([aten.convolution_backward.default])
def meta_convolution_backward(
    grad_output_,
    input_,
    weight_,
    bias_sizes_opt,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    """Meta kernel for ``aten.convolution_backward``.

    Returns ``(grad_input, grad_weight, grad_bias)``. Entry ``i`` is an
    uninitialized tensor (created from ``grad_output_``) when
    ``output_mask[i]`` is truthy, else ``None``. The shapes are
    ``input_.size()``, ``weight_.size()`` and ``bias_sizes_opt`` respectively.
    """
    # High level logic taken from slow_conv3d_backward_cpu which should
    # be representative of all convolution_backward impls
    grad_input = grad_output_.new_empty(input_.size()) if output_mask[0] else None
    grad_weight = grad_output_.new_empty(weight_.size()) if output_mask[1] else None
    grad_bias = grad_output_.new_empty(bias_sizes_opt) if output_mask[2] else None
    return (grad_input, grad_weight, grad_bias)
3017

3018

3019
@register_meta([aten.addbmm.default, aten.addbmm.out])
@out_wrapper()
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
    """Meta kernel for ``aten.addbmm``.

    Args:
        self: tensor broadcastable to ``(batch1.size(1), batch2.size(2))``.
        batch1: 3D tensor ``(b, n, m)``.
        batch2: 3D tensor ``(b, m, p)``.

    Returns:
        An uninitialized tensor of shape ``(n, p)``.
    """
    # Validate ranks before reading batch1.size(1) / batch2.size(2).  That way
    # a wrongly-ranked operand reports the message below instead of a
    # dimension-out-of-range error from size().
    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
    torch._check(
        batch1.size(0) == batch2.size(0),
        lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
    )
    torch._check(
        batch1.size(2) == batch2.size(1),
        lambda: (
            f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
            f"and {batch2.size(1)}x{batch2.size(2)})"
        ),
    )
    dim1 = batch1.size(1)
    dim2 = batch2.size(2)
    self = self.expand((dim1, dim2))
    torch._check(
        self.size(0) == dim1 and self.size(1) == dim2,
        lambda: "self tensor does not match matmul output shape",
    )
    return self.new_empty(self.size())
3043

3044

3045
def register_meta_foreach(ops):
    """Return a decorator that registers ``fn`` as the meta kernel of each op in ``ops``.

    ``ops`` may be any pytree of ops. For each one, the per-tensor aten op is
    found by taking the second dot-separated component of ``str(op)`` and
    removing ``"_foreach_"``. That op is bound to ``fn`` as ``_scalar_op``.
    The decorator returns ``fn`` unchanged.
    """

    def wrapper(fn):
        def register(op):
            # e.g. "aten._foreach_add_" -> "_foreach_add_" -> "add_"
            base_name = str(op).split(".")[1].replace("_foreach_", "")
            kernel = partial(fn, _scalar_op=getattr(aten, base_name))
            _add_op_to_registry(meta_table, op, kernel)

        pytree.tree_map_(register, ops)
        return fn

    return wrapper
3064

3065

3066
@register_meta_foreach(
    [
        aten._foreach_abs,
        aten._foreach_acos,
        aten._foreach_asin,
        aten._foreach_atan,
        aten._foreach_ceil,
        aten._foreach_cos,
        aten._foreach_cosh,
        aten._foreach_erf,
        aten._foreach_erfc,
        aten._foreach_exp,
        aten._foreach_expm1,
        aten._foreach_frac,
        aten._foreach_floor,
        aten._foreach_lgamma,
        aten._foreach_log,
        aten._foreach_log10,
        aten._foreach_log1p,
        aten._foreach_log2,
        aten._foreach_neg,
        aten._foreach_norm,
        aten._foreach_reciprocal,
        aten._foreach_round,
        aten._foreach_sigmoid,
        aten._foreach_sign,
        aten._foreach_sin,
        aten._foreach_sinh,
        aten._foreach_sqrt,
        aten._foreach_tan,
        aten._foreach_tanh,
        aten._foreach_trunc,
        aten._foreach_zero,
        aten._foreach_add,
        aten._foreach_sub,
        aten._foreach_mul,
        aten._foreach_div,
        aten._foreach_clamp_min,
        aten._foreach_clamp_max,
        aten._foreach_lerp,
    ],
)
def _meta_foreach_out_of_place(*args, _scalar_op=None, **kwargs):
    """Shared meta kernel for the out-of-place ``_foreach_*`` ops.

    ``args[0]`` must be a non-empty ``list``. The following positional
    arguments are scanned in order:

    - every ``list`` must match ``args[0]`` in length and increments the
      count ``k`` of per-element operands (``k`` starts at 1);
    - every ``Tensor`` must be 0-dim with one element;
    - scanning stops at the first argument that is neither.

    Returns a list whose i-th item is
    ``_scalar_op(args[0][i], ..., args[k-1][i], *args[k:], **kwargs)``.
    """
    first = args[0]
    torch._check(
        isinstance(first, list),
        lambda: f"The first argument must be List[Tensor], but got {type(first)}.",
    )
    nelem = len(first)
    torch._check(nelem > 0, lambda: "Tensor list must have at least one tensor.")

    nlists = 1
    for position, arg in enumerate(args[1:], start=2):
        if isinstance(arg, Tensor):
            torch._check(
                arg.dim() == 0 and arg.numel() == 1,
                lambda: (
                    "scalar tensor expected to be 0 dim but it has "
                    f"{arg.dim()} dimensions and {arg.numel()} elements."
                ),
            )
        elif isinstance(arg, list):
            nlists += 1
            torch._check(
                len(arg) == nelem,
                lambda: (
                    f"self and argument-{position} must match in length, "
                    f"but got {nelem} and {len(arg)}."
                ),
            )
        else:
            break

    per_elem_operands = args[:nlists]
    trailing = args[nlists:]
    return [
        _scalar_op(*(operand[i] for operand in per_elem_operands), *trailing, **kwargs)
        for i in range(nelem)
    ]
3148

3149

3150
@register_meta_foreach(
    [
        aten._foreach_abs_,
        aten._foreach_acos_,
        aten._foreach_asin_,
        aten._foreach_atan_,
        aten._foreach_ceil_,
        aten._foreach_cos_,
        aten._foreach_cosh_,
        aten._foreach_erf_,
        aten._foreach_erfc_,
        aten._foreach_exp_,
        aten._foreach_expm1_,
        aten._foreach_frac_,
        aten._foreach_floor_,
        aten._foreach_lgamma_,
        aten._foreach_log_,
        aten._foreach_log10_,
        aten._foreach_log1p_,
        aten._foreach_log2_,
        aten._foreach_neg_,
        aten._foreach_reciprocal_,
        aten._foreach_round_,
        aten._foreach_sigmoid_,
        aten._foreach_sign_,
        aten._foreach_sin_,
        aten._foreach_sinh_,
        aten._foreach_sqrt_,
        aten._foreach_tan_,
        aten._foreach_tanh_,
        aten._foreach_trunc_,
        aten._foreach_zero_,
        aten._foreach_add_,
        aten._foreach_sub_,
        aten._foreach_mul_,
        aten._foreach_div_,
        aten._foreach_clamp_min_,
        aten._foreach_clamp_max_,
        aten._foreach_lerp_,
        aten._foreach_copy_,
    ]
)
def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs):
    """Meta kernel for the in-place ``_foreach_*_`` ops.

    Runs the shared out-of-place validation and per-element ``_scalar_op``
    calls for their checks only. The result is discarded and ``None`` is
    returned.
    """
    _meta_foreach_out_of_place(*args, _scalar_op=_scalar_op, **kwargs)
    return None
3195

3196

3197
@register_meta([aten._foreach_pow.ScalarAndTensor])
def meta__foreach_pow_scalar_and_tensor(self, exponent):
    """Meta kernel for ``_foreach_pow.ScalarAndTensor`` (scalar base, list of exponents).

    Returns one uninitialized tensor like each entry of ``exponent``.
    """
    # Only foreach_pow has a ScalarAndTensor overload. Its scalar-first
    # argument order does not fit _meta_foreach_out_of_place.
    torch._check(
        isinstance(exponent, List),
        lambda: f"exponent must be a tensor list but got {type(exponent)}",
    )
    return list(map(torch.empty_like, exponent))
3206

3207

3208
def _check_foreach_binop_tensor_lists(self, other):
3209
    torch._check(
3210
        isinstance(self, List) and isinstance(other, List),
3211
        lambda: (
3212
            "The first two arguments of must be List[Tensor], "
3213
            f"but got {type(self)} and {type(other)}."
3214
        ),
3215
    )
3216
    torch._check(
3217
        len(self) > 0 and len(self) == len(other),
3218
        lambda: (
3219
            "self and other must be non-empty and match in length, "
3220
            f"but got {len(self)} and {len(other)}."
3221
        ),
3222
    )
3223

3224

3225
@register_meta([aten._foreach_maximum, aten._foreach_minimum])
def meta__foreach_binop_scalar(*args):
    """Meta kernel for out-of-place ``_foreach_maximum`` / ``_foreach_minimum``.

    Delegates to ``_meta_foreach_out_of_place`` with ``aten.clamp_min`` as the
    per-element op and returns its list of results.
    """
    # aten.maximum(Tensor, Scalar) does not exist.
    # NOTE(review): clamp_min also stands in for _foreach_minimum; presumably
    # acceptable because only output metadata is produced here -- confirm.
    per_element_op = aten.clamp_min
    return _meta_foreach_out_of_place(*args, _scalar_op=per_element_op)
3234

3235

3236
@register_meta([aten._foreach_maximum_, aten._foreach_minimum_])
def meta__foreach_binop__scalar(*args):
    """Meta kernel for in-place ``_foreach_maximum_`` / ``_foreach_minimum_``.

    Runs the in-place foreach checks with ``aten.clamp_min_`` as the
    per-element op and returns ``None``.
    """
    # aten.maximum(Tensor, Scalar) does not exist
    per_element_op = aten.clamp_min_
    _meta_foreach_inplace(*args, _scalar_op=per_element_op)
3246

3247

3248
@register_meta([aten._foreach_addcdiv.Scalar, aten._foreach_addcmul.Scalar])
def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1):
    """Meta kernel for out-of-place ``_foreach_addcdiv`` / ``_foreach_addcmul`` (Scalar).

    Requires three non-empty lists of equal length. Returns one uninitialized
    tensor like each entry of ``self``.
    """
    # _foreach_addc* and addc* have different signatures, so
    # _meta_foreach_out_of_place cannot be reused here.
    operands = (self, tensor1, tensor2)
    torch._check(
        all(isinstance(operand, List) for operand in operands),
        lambda: (
            "All arguments must be List[Tensor], "
            f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
        ),
    )
    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
    torch._check(
        len(tensor1) == len(self) == len(tensor2),
        lambda: "All input tensor lists must have the same length",
    )
    return list(map(torch.empty_like, self))
3271

3272

3273
@register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor])
def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
    """Meta kernel for in-place ``_foreach_addcdiv_`` / ``_foreach_addcmul_`` (Tensor scalars).

    Requires three non-empty, equal-length lists plus a ``torch.Tensor`` of
    scalars. Returns ``None``.
    """
    all_lists = all(isinstance(operand, List) for operand in (self, tensor1, tensor2))
    torch._check(
        all_lists and isinstance(scalars, torch.Tensor),
        lambda: (
            "_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], tensor, "
            f"but got: {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}"
        ),
    )
    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
    torch._check(
        len(tensor1) == len(self) == len(tensor2),
        lambda: "All input tensor lists must have the same length",
    )
3288

3289

3290
@register_meta([aten._foreach_addcdiv_.Scalar, aten._foreach_addcmul_.Scalar])
def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
    """Meta kernel for in-place ``_foreach_addcdiv_`` / ``_foreach_addcmul_`` (Scalar).

    Requires three non-empty lists of equal length. Returns ``None``.
    """
    operands = (self, tensor1, tensor2)
    torch._check(
        all(isinstance(operand, List) for operand in operands),
        lambda: (
            "All arguments of _foreach_addc*_ must be List[Tensor], "
            f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
        ),
    )
    torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
    torch._check(
        len(tensor1) == len(self) == len(tensor2),
        lambda: "All input tensor lists must have the same length",
    )
3309

3310

3311
@register_meta([aten._fused_adam_.default])
def meta__fused_adam_(
    self,
    grads,
    exp_avgs,
    exp_avg_sqs,
    max_exp_avg_sqs,
    state_steps,
    *,
    lr,
    beta1,
    beta2,
    weight_decay,
    eps,
    amsgrad,
    maximize,
    grad_scale=None,
    found_inf=None,
):
    """Meta kernel for the in-place ``aten._fused_adam_``.

    Only checks that each tensor-list argument is a list; returns ``None``.

    Raises:
        RuntimeError: naming the first argument that is not a list.
    """
    named_lists = {
        "self": self,
        "grads": grads,
        "exp_avgs": exp_avgs,
        "exp_avg_sqs": exp_avg_sqs,
        "max_exp_avg_sqs": max_exp_avg_sqs,
        "state_steps": state_steps,
    }
    for name, value in named_lists.items():
        # Name the offending argument. The old text always said "exponent",
        # which is not a parameter of this op.
        torch._check(
            isinstance(value, List),
            lambda: f"{name} must be a tensor list but got {type(value)}",
        )
3335

3336

3337
@register_meta([aten._fused_adam.default])
def meta__fused_adam(
    self,
    grads,
    exp_avgs,
    exp_avg_sqs,
    max_exp_avg_sqs,
    state_steps,
    *,
    lr,
    beta1,
    beta2,
    weight_decay,
    eps,
    amsgrad,
    maximize,
    grad_scale=None,
    found_inf=None,
):
    """Meta kernel for the functional ``aten._fused_adam``.

    Returns:
        A 5-tuple of lists of uninitialized tensors, one ``empty_like`` per
        entry of ``self``, ``grads``, ``exp_avgs``, ``exp_avg_sqs`` and
        ``max_exp_avg_sqs`` (``state_steps`` is not returned).

    Raises:
        RuntimeError: naming the first tensor-list argument that is not a list.
    """
    named_lists = {
        "self": self,
        "grads": grads,
        "exp_avgs": exp_avgs,
        "exp_avg_sqs": exp_avg_sqs,
        "max_exp_avg_sqs": max_exp_avg_sqs,
        "state_steps": state_steps,
    }
    for name, value in named_lists.items():
        # Name the offending argument. The old text always said "exponent",
        # which is not a parameter of this op.
        torch._check(
            isinstance(value, List),
            lambda: f"{name} must be a tensor list but got {type(value)}",
        )

    def empty_like_list(tensor_list):
        return [torch.empty_like(t) for t in tensor_list]

    return (
        empty_like_list(self),
        empty_like_list(grads),
        empty_like_list(exp_avgs),
        empty_like_list(exp_avg_sqs),
        empty_like_list(max_exp_avg_sqs),
    )
3372

3373

3374
@register_meta([aten._int_mm])
@out_wrapper()
def meta__int_mm(a, b):
    """Meta kernel for ``aten._int_mm``.

    Multiplies two 2D int8 matrices ``(m, k) @ (k, n)`` and returns an
    uninitialized int32 tensor of shape ``(m, n)``.
    """
    torch._check(a.dim() == 2, lambda: "a must be a 2D tensor")
    torch._check(b.dim() == 2, lambda: "b must be a 2D tensor")
    for operand, label in ((a, "self"), (b, "mat2")):
        torch._check(
            operand.dtype is torch.int8,
            lambda: f"expected {label} to be int8, got {operand.dtype}",
        )
    rows, inner_a = a.size(0), a.size(1)
    inner_b, cols = b.size(0), b.size(1)
    torch._check(
        inner_a == inner_b,
        lambda: f"Incompatible matrix sizes for _int_mm ({rows}x{inner_a} and {inner_b}x{cols})",
    )
    return a.new_empty((rows, cols), dtype=torch.int32)
3395

3396

3397
@register_meta([aten._convert_weight_to_int4pack])
def meta__convert_weight_to_int4pack(w, inner_k_tiles):
    """Meta kernel for ``aten._convert_weight_to_int4pack``.

    ``w`` is a 2D int32 ``(n, k)`` tensor. On CPU (per ``device_hint``) the
    result is uint8 ``(n, k // 2)``. Otherwise the result is the tiled int32
    layout ``(n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2)``.
    """
    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
    torch._check(
        w.dtype is torch.int32,
        lambda: f"expected w to be int32, got {w.dtype}",
    )
    n, k = w.size(0), w.size(1)
    if device_hint(w) != "cpu":
        # cuda path
        tiled_shape = (n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2)
        return w.new_empty(tiled_shape, dtype=torch.int32)
    # CPU path: presumably two 4-bit values per uint8 along k -- confirm.
    return w.new_empty((n, k // 2), dtype=torch.uint8)
3421

3422

3423
@register_meta([aten._weight_int4pack_mm])
def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
    """Meta kernel for ``aten._weight_int4pack_mm``.

    Args:
        x: 2D bfloat16 activation ``(m, k)``.
        w: weight packed by ``_convert_weight_to_int4pack``. On CPU this is
            uint8 ``(n, k // 2)``; otherwise it is int32
            ``(n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2)``.

    Returns:
        An uninitialized ``(m, n)`` tensor with ``x.dtype``.
    """
    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
    torch._check(
        x.dtype is torch.bfloat16,
        lambda: f"expected x to be bf16, got {x.dtype}",
    )
    if device_hint(w) == "cpu":
        torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
        torch._check(
            w.dtype is torch.uint8,
            lambda: f"expected w to be uint8, got {w.dtype}",
        )
        # CPU packing keeps one row per output feature.
        n = w.size(0)
    else:
        torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
        torch._check(
            w.dtype is torch.int32,
            lambda: f"expected w to be int32, got {w.dtype}",
        )
        # The tiled packing stores n // 8 rows in dim 0, so the number of
        # output features is 8x that (previously w.size(0) was used directly).
        n = w.size(0) * 8
    return x.new_empty(x.size(0), n, dtype=x.dtype)
3443

3444

3445
@register_meta([aten._weight_int8pack_mm])
def meta__weight_int8pack_mm(x, w, q_scales):
    """Meta kernel for ``aten._weight_int8pack_mm``.

    ``x`` must be 2D bfloat16 and ``w`` 2D int8. The result is an
    uninitialized ``(x.size(0), w.size(0))`` tensor with ``x.dtype``.
    """
    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
    torch._check(
        x.dtype is torch.bfloat16,
        lambda: f"expected x to be bf16, got {x.dtype}",
    )
    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
    torch._check(
        w.dtype is torch.int8,
        lambda: f"expected w to be int8, got {w.dtype}",
    )
    out_rows, out_cols = x.size(0), w.size(0)
    return x.new_empty(out_rows, out_cols, dtype=x.dtype)
3458

3459

3460
@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
    """Meta kernel for ``aten._cdist_forward``.

    Args:
        x1: floating-point tensor ``(*B1, r1, c)``.
        x2: floating-point tensor ``(*B2, r2, c)``.
        p: non-negative norm order.
        compute_mode: ``None``, ``1`` or ``2``.

    Returns:
        An uninitialized tensor of shape ``broadcast(B1, B2) + (r1, r2)``.
    """
    torch._check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    torch._check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    torch._check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    # These two messages were missing the ``f`` prefix and printed the
    # literal text "{x1.dtype}" / "{x2.dtype}".
    torch._check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    torch._check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    torch._check(p >= 0, lambda: "cdist only supports non-negative p values")
    torch._check(
        compute_mode in (None, 1, 2),
        lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
    )
    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
3494

3495

3496
@register_meta(aten._cdist_backward)
@out_wrapper()
def meta_cdist_backward(grad, x1, x2, p, cdist):
    """Meta kernel for ``aten._cdist_backward``; produces the gradient w.r.t. ``x1``.

    If any of ``r1``, ``r2``, ``c1`` or the broadcast batch size is zero, the
    result is ``zeros_like(x1)``. Otherwise ``x1`` is expanded to the
    broadcast batch shape (when needed) and an uninitialized contiguous tensor
    of that shape is returned.
    """
    c1 = x1.shape[-1]
    r1 = x1.shape[-2]
    r2 = x2.shape[-2]
    batch_shape = list(torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]))
    expanded_x1_shape = [*batch_shape, r1, c1]
    num_batches = math.prod(batch_shape)
    if r1 == 0 or r2 == 0 or c1 == 0 or num_batches == 0:
        return torch.zeros_like(x1)
    if expanded_x1_shape != list(x1.shape):
        x1 = x1.expand(expanded_x1_shape)
    return torch.empty_like(x1, memory_format=torch.contiguous_format)
3513

3514

3515
# NB: This meta function accepts non-meta arguments!  When this behavior
# was originally introduced this was accidental, but it is now load bearing
# as people are using this so that they can conveniently test code involving
# embeddings (feeding CPU tensor inputs with meta device EmbeddingBag module)
@register_meta(aten._embedding_bag.default)
def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    """Meta kernel for ``aten._embedding_bag``.

    Returns ``(output, offset2bag, bag_size, max_indices)``. ``output`` has
    shape ``(num_bags, weight.size(1))``, where ``num_bags`` is
    ``offsets.size(0)``, minus one when ``include_last_offset`` is set.
    ``mode`` is 0/1/2 for sum/mean/max. The auxiliary outputs are shaped
    differently depending on whether ``device_hint(offsets)`` is ``"cpu"``.
    """
    torch._check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    torch._check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    torch._check(
        utils.is_float_dtype(weight.dtype),
        lambda: f"expected weight to be floating point type, got {weight.dtype}",
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        torch._check(
            num_bags >= 1,
            lambda: "include_last_offset: numBags should be at least 1",
        )
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        torch._check(
            mode == MODE_SUM,
            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        torch._check(
            per_sample_weights.dtype == weight.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        torch._check(
            per_sample_weights.ndim == 1,
            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        torch._check(
            per_sample_weights.numel() == indices.numel(),
            lambda: (
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
                f"to be the same as indices.numel() ({indices.numel()})"
            ),
        )

    # The fast-path predicates decide whether the CPU offset2bag is empty.
    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
        return (
            is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
        )

    def is_fast_path_index_select(src, output, padding_idx):
        return (
            (src.dtype == torch.float or src.dtype == torch.half)
            and src.stride(1) == 1
            and output.stride(1) == 1
            and padding_idx < 0
        )

    def is_fast_path(src, scale, output, padding_idx):
        if scale is not None:
            return is_fast_path_index_select_scale(src, scale, output, padding_idx)
        else:
            return is_fast_path_index_select(src, output, padding_idx)

    if device_hint(offsets) != "cpu":
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    else:
        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
        if mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp.
        # num_bags above already accounts for (and has validated)
        # include_last_offset, so it is reused rather than recomputed.
        if mode == MODE_MAX:
            max_indices = offsets.new_empty(num_bags, weight.shape[1])
        else:
            max_indices = offsets.new_empty(bag_size.size())
    return output, offset2bag, bag_size, max_indices
3622

3623

3624
@register_meta(aten._embedding_bag_forward_only.default)
3625
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
3626
    output, offset2bag, bag_size, max_indices = meta_embedding_bag(
3627
        weight, indices, offsets, *args
3628
    )
3629
    if device_hint(offsets) == "cpu":
3630
        bag_size = offsets.new_empty(offsets.size())
3631
    return output, offset2bag, bag_size, max_indices
3632

3633

3634
def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
3635
    # if specified, dtype takes precedence
3636
    if dtype:
3637
        return dtype
3638

3639
    if input.dtype.is_floating_point or input.dtype.is_complex:
3640
        return input.dtype
3641
    elif promote_int_to_long:
3642
        return torch.long
3643

3644
    return input.dtype
3645

3646

3647
@register_meta([aten.nansum.default, aten.nansum.out])
3648
@out_wrapper()
3649
def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
3650
    output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
3651
    dims = utils.reduction_dims(input.shape, dims)
3652
    output_shape = _compute_reduction_shape(input, dims, keepdim)
3653
    return input.new_empty(output_shape, dtype=output_dtype)
3654

3655

3656
@register_meta([aten.median.default, aten.nanmedian.default])
3657
def meta_median(input):
3658
    output_shape = utils.compute_reduction_output_shape(
3659
        input.shape, tuple(range(input.dim()))
3660
    )
3661
    return input.new_empty(output_shape)
3662

3663

3664
@register_meta(
3665
    [
3666
        aten.median.dim,
3667
        aten.median.dim_values,
3668
        aten.nanmedian.dim,
3669
        aten.nanmedian.dim_values,
3670
        aten.mode.default,
3671
        aten.mode.values,
3672
    ]
3673
)
3674
@out_wrapper("values", "indices")
3675
def meta_median_mode_dim(input, dim=-1, keepdim=False):
3676
    if device_hint(input) == "cuda":
3677
        utils.alert_not_deterministic("median CUDA with indices output")
3678
    dim = utils.reduction_dims(input.shape, (dim,))
3679
    output_shape = _compute_reduction_shape(input, dim, keepdim)
3680
    return (
3681
        input.new_empty(output_shape),
3682
        input.new_empty(output_shape, dtype=torch.long),
3683
    )
3684

3685

3686
@register_meta(aten.logical_not_.default)
3687
def meta_logical_not_(self):
3688
    return self
3689

3690

3691
@register_meta(aten.repeat.default)
3692
def meta_repeat(self, repeats):
3693
    torch._check(
3694
        len(repeats) >= self.dim(),
3695
        lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
3696
    )
3697
    # Add new leading dimensions to the tensor if the
3698
    # number of target dimensions is larger than the
3699
    # number of source dimensions.
3700
    num_new_dimensions = len(repeats) - self.dim()
3701
    padded_size = (1,) * num_new_dimensions + tuple(self.shape)
3702
    target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))]
3703
    return self.new_empty(target_size)
3704

3705

3706
@register_meta(aten.zero_.default)
3707
def meta_zero_(self):
3708
    return self
3709

3710

3711
@register_meta(
3712
    [
3713
        aten.mul_.Scalar,
3714
        aten.div_.Scalar,
3715
        aten.mul_.Tensor,
3716
        aten.div_.Tensor,
3717
        aten.logical_and_.default,
3718
        aten.logical_or_.default,
3719
        aten.logical_xor_.default,
3720
    ],
3721
)
3722
def meta_binop_inplace(self, other):
3723
    if isinstance(other, torch.Tensor):
3724
        check_inplace_broadcast(self.shape, other.shape)
3725
    return self
3726

3727

3728
@register_meta(
3729
    [
3730
        aten.add_.Scalar,
3731
        aten.sub_.Scalar,
3732
        aten.add_.Tensor,
3733
        aten.sub_.Tensor,
3734
    ],
3735
)
3736
def meta_binop_inplace_alpha(self, other, alpha=1):
3737
    if isinstance(other, torch.Tensor):
3738
        check_inplace_broadcast(self.shape, other.shape)
3739
    return self
3740

3741

3742
@register_meta([aten.round.default, aten.round.decimals])
3743
def meta_round(self, **kwargs):
3744
    return elementwise_meta(
3745
        self, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
3746
    )
3747

3748

3749
def shift_dtype_check(fn_name, self, val):
3750
    torch._check(
3751
        utils.is_integer_dtype(self.dtype),
3752
        lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}",
3753
    )
3754
    if isinstance(val, torch.Tensor):
3755
        torch._check(
3756
            utils.is_integer_dtype(val.dtype),
3757
            lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}",
3758
        )
3759
    else:
3760
        torch._check(
3761
            isinstance(val, IntLike),
3762
            lambda: f"{fn_name}: Expected shift value to be an int. Got {val}",
3763
        )
3764

3765

3766
@register_meta([aten.__rshift__.Tensor, aten.__rshift__.Scalar])
3767
def meta_rshifts(self, other):
3768
    shift_dtype_check("rshift", self, other)
3769
    return elementwise_meta(
3770
        self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
3771
    )
3772

3773

3774
@register_meta([aten.__lshift__.Tensor, aten.__lshift__.Scalar])
3775
def meta_lshifts(self, other):
3776
    shift_dtype_check("lshift", self, other)
3777
    return elementwise_meta(
3778
        self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
3779
    )
3780

3781

3782
@register_meta(aten.zero.default)
3783
def meta_zero(self):
3784
    return self.new_empty(self.shape)
3785

3786

3787
@register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
3788
def meta_fill_(self, val):
3789
    return self
3790

3791

3792
@register_meta([aten.fill.Tensor, aten.fill.Scalar])
3793
def meta_fill(self, val):
3794
    return torch.empty_like(self)
3795

3796

3797
@register_meta(aten.relu_.default)
3798
def meta_relu_(self):
3799
    return self
3800

3801

3802
@register_meta([aten.index_put.default, aten._unsafe_index_put.default])
3803
def meta_index_put(self, indices, values, accumulate=False):
3804
    return torch.empty_like(self)
3805

3806

3807
@register_meta(aten.masked_fill_.Scalar)
3808
def meta_masked_fill_(self, mask, value):
3809
    check_inplace_broadcast(self.shape, mask.shape)
3810
    return self
3811

3812

3813
@register_meta(aten.masked_scatter_)
3814
def meta_masked_scatter_(self, mask, source):
3815
    torch._check(
3816
        mask.dtype in (torch.bool, torch.uint8), lambda: "Mask must be bool or uint8"
3817
    )
3818
    torch._check(
3819
        self.dtype == source.dtype,
3820
        lambda: "masked_scatter: expected self and source to have same "
3821
        "dtypes but got {self.dtype} and {source.dtype}",
3822
    )
3823
    return self
3824

3825

3826
@register_meta(aten.masked_scatter)
3827
@out_wrapper()
3828
def meta_masked_scatter(self, mask, source):
3829
    self, mask = _maybe_broadcast(self, mask)
3830
    output = torch.empty_like(self, memory_format=torch.contiguous_format)
3831
    return meta_masked_scatter_(output, mask, source)
3832

3833

3834
@register_meta(aten.masked_scatter_backward)
3835
def meta_masked_scatter_backward(self, mask, sizes):
3836
    return self.new_empty(sizes)
3837

3838

3839
@register_meta(aten.index_put_.default)
3840
def meta_index_put_(self, indices, values, accumulate=False):
3841
    return self
3842

3843

3844
@register_meta(aten.alias.default)
3845
def meta_alias(self):
3846
    return self.view(self.shape)
3847

3848

3849
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
3850
    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
3851
    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
3852

3853
    batch1_sizes = batch1.size()
3854
    batch2_sizes = batch2.size()
3855

3856
    bs = batch1_sizes[0]
3857
    contraction_size = batch1_sizes[2]
3858
    res_rows = batch1_sizes[1]
3859
    res_cols = batch2_sizes[2]
3860
    output_size = (bs, res_rows, res_cols)
3861

3862
    torch._check(
3863
        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
3864
        lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
3865
        f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
3866
    )
3867

3868
    # TODO: handle out
3869

3870
    output = batch2.new_empty(output_size)
3871

3872
    if not is_bmm and self_baddbmm is not None:
3873
        torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
3874
        torch._check(
3875
            self_baddbmm.size() == output_size,
3876
            lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
3877
        )
3878

3879
    return output
3880

3881

3882
@register_meta(aten.bmm.default)
3883
def meta_bmm(self, mat2):
3884
    return common_meta_baddbmm_bmm(self, mat2, True)
3885

3886

3887
def div_rtn(x, y):
3888
    q = x // y
3889
    r = x % y
3890
    # WARNING: explicit bool conversion here is necessary;
3891
    # would be fixed by SymBool
3892
    if r != 0 and (bool(r < 0) != bool(y < 0)):
3893
        q -= 1
3894
    return q
3895

3896

3897
def pooling_output_shape_pad_lr(
3898
    inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
3899
):
3900
    outputSize = (
3901
        div_rtn(
3902
            inputSize
3903
            + pad_l
3904
            + pad_r
3905
            - dilation * (kernelSize - 1)
3906
            - 1
3907
            + (stride - 1 if ceil_mode else 0),
3908
            stride,
3909
        )
3910
        + 1
3911
    )
3912
    if ceil_mode:
3913
        if (outputSize - 1) * stride >= inputSize + pad_l:
3914
            outputSize -= 1
3915
    return outputSize
3916

3917

3918
def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
3919
    torch._check(stride != 0, lambda: "stride should not be zero")
3920
    torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
3921
    torch._check(
3922
        pad <= ((kernelSize - 1) * dilation + 1) // 2,
3923
        lambda: (
3924
            f"pad should be at most half of effective kernel size, but got pad={pad}, "
3925
            f"kernel_size={kernelSize} and dilation={dilation}"
3926
        ),
3927
    )
3928
    return pooling_output_shape_pad_lr(
3929
        inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
3930
    )
3931

3932

3933
def pool2d_shape_check(
3934
    input,
3935
    kH,
3936
    kW,
3937
    dH,
3938
    dW,
3939
    padH,
3940
    padW,
3941
    dilationH,
3942
    dilationW,
3943
    nInputPlane,
3944
    inputHeight,
3945
    inputWidth,
3946
    outputHeight,
3947
    outputWidth,
3948
    memory_format,
3949
):
3950
    ndim = input.dim()
3951
    nOutputPlane = nInputPlane
3952

3953
    torch._check(
3954
        kW > 0 and kH > 0,
3955
        lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
3956
    )
3957
    torch._check(
3958
        dW > 0 and dH > 0,
3959
        lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}",
3960
    )
3961
    torch._check(
3962
        dilationH > 0 and dilationW > 0,
3963
        lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
3964
    )
3965

3966
    valid_dims = input.size(1) != 0 and input.size(2) != 0
3967

3968
    if memory_format == torch.channels_last:
3969
        torch._check(
3970
            ndim == 4 and valid_dims and input.size(3) != 0,
3971
            lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
3972
            " with optional 0 dim batch size for input, but got: {input.size()}",
3973
        )
3974
    else:
3975
        torch._check(
3976
            (ndim == 3 and input.size(0) != 0 and valid_dims)
3977
            or (ndim == 4 and valid_dims and input.size(3) != 0),
3978
            lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
3979
        )
3980

3981
    torch._check(
3982
        kW // 2 >= padW and kH // 2 >= padH,
3983
        lambda: "pad should be smaller than or equal to half of kernel size, but got "
3984
        f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
3985
    )
3986

3987
    torch._check(
3988
        outputWidth >= 1 and outputHeight >= 1,
3989
        lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
3990
        f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
3991
        "Output size is too small",
3992
    )
3993

3994

3995
def pool3d_shape_check(
3996
    input: Tensor,
3997
    nslices: int,
3998
    kT: int,
3999
    kH: int,
4000
    kW: int,
4001
    dT: int,
4002
    dH: int,
4003
    dW: int,
4004
    pT: int,
4005
    pH: int,
4006
    pW: int,
4007
    dilationT: int,
4008
    dilationH: int,
4009
    dilationW: int,
4010
    itime: int,
4011
    iheight: int,
4012
    iwidth: int,
4013
    otime: int,
4014
    oheight: int,
4015
    owidth: int,
4016
    fn_name: str,
4017
    check_input_size: bool = False,
4018
):
4019
    ndim = input.ndim
4020

4021
    torch._check(
4022
        kT > 0 and kW > 0 and kH > 0,
4023
        lambda: (
4024
            f"kernel size should be greater than zero, but got "
4025
            f"kT: {kT}, kH: {kH}, kW: {kW}"
4026
        ),
4027
    )
4028
    torch._check(
4029
        dT > 0 and dW > 0 and dH > 0,
4030
        lambda: (
4031
            f"stride should be greater than zero, but got "
4032
            f"dT: {dT}, dH: {dH}, dW: {dW}"
4033
        ),
4034
    )
4035
    torch._check(
4036
        dilationT > 0 and dilationW > 0 and dilationH > 0,
4037
        lambda: (
4038
            f"dilation should be greater than zero, but got "
4039
            f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}"
4040
        ),
4041
    )
4042

4043
    torch._check(
4044
        ndim in (4, 5),
4045
        lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}",
4046
    )
4047

4048
    for i in range(ndim):
4049
        if ndim == 5 and i == 0:
4050
            # size of batch-dim can be 0.
4051
            continue
4052
        torch._check(
4053
            input.size(i) > 0,
4054
            lambda: (
4055
                f"{fn_name}: Expected input's non-batch dimensions to have positive length,"
4056
                f" but input has a shape of {input.shape}"
4057
                f" and non-batch dimension {input.size(i)} has length zero!"
4058
            ),
4059
        )
4060

4061
    if check_input_size:  # AveragePool3d
4062
        torch._check(
4063
            itime >= kT and iheight >= kH and iwidth >= kW,
4064
            lambda: (
4065
                f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than "
4066
                f"kernel size (kT: {kT} kH: {kH} kW: {kW})"
4067
            ),
4068
        )
4069

4070
    torch._check(
4071
        kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH,
4072
        lambda: (
4073
            f"pad should be smaller than or equal to half of kernel size, but got "
4074
            f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}"
4075
        ),
4076
    )
4077

4078
    torch._check(
4079
        otime >= 1 and owidth >= 1 and oheight >= 1,
4080
        lambda: (
4081
            f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). "
4082
            f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). "
4083
            f"Output size is too small"
4084
        ),
4085
    )
4086

4087

4088
def max_pool3d_backward_shape_check(
4089
    input,
4090
    grad_output,
4091
    indices,
4092
    nslices,
4093
    kT,
4094
    kH,
4095
    kW,
4096
    dT,
4097
    dH,
4098
    dW,
4099
    pT,
4100
    pH,
4101
    pW,
4102
    dilationT,
4103
    dilationH,
4104
    dilationW,
4105
    itime,
4106
    iheight,
4107
    iwidth,
4108
    otime,
4109
    oheight,
4110
    owidth,
4111
    fn_name,
4112
):
4113
    ndim = input.ndim
4114

4115
    pool3d_shape_check(
4116
        input,
4117
        nslices,
4118
        kT,
4119
        kH,
4120
        kW,
4121
        dT,
4122
        dH,
4123
        dW,
4124
        pT,
4125
        pH,
4126
        pW,
4127
        dilationT,
4128
        dilationH,
4129
        dilationW,
4130
        itime,
4131
        iheight,
4132
        iwidth,
4133
        otime,
4134
        oheight,
4135
        owidth,
4136
        fn_name,
4137
    )
4138

4139
    check_dim_size(grad_output, ndim, ndim - 4, nslices)
4140
    check_dim_size(grad_output, ndim, ndim - 3, otime)
4141
    check_dim_size(grad_output, ndim, ndim - 2, oheight)
4142
    check_dim_size(grad_output, ndim, ndim - 1, owidth)
4143

4144
    check_dim_size(indices, ndim, ndim - 4, nslices)
4145
    check_dim_size(indices, ndim, ndim - 3, otime)
4146
    check_dim_size(indices, ndim, ndim - 2, oheight)
4147
    check_dim_size(indices, ndim, ndim - 1, owidth)
4148

4149

4150
def avg_pool3d_backward_shape_check(
4151
    input: Tensor,
4152
    grad_output: Tensor,
4153
    nslices: int,
4154
    kT: int,
4155
    kH: int,
4156
    kW: int,
4157
    dT: int,
4158
    dH: int,
4159
    dW: int,
4160
    pT: int,
4161
    pH: int,
4162
    pW: int,
4163
    itime: int,
4164
    iheight: int,
4165
    iwidth: int,
4166
    otime: int,
4167
    oheight: int,
4168
    owidth: int,
4169
    fn_name: str,
4170
):
4171
    ndim = input.ndim
4172

4173
    pool3d_shape_check(
4174
        input,
4175
        nslices,
4176
        kT,
4177
        kH,
4178
        kW,
4179
        dT,
4180
        dH,
4181
        dW,
4182
        pT,
4183
        pH,
4184
        pW,
4185
        1,
4186
        1,
4187
        1,
4188
        itime,
4189
        iheight,
4190
        iwidth,
4191
        otime,
4192
        oheight,
4193
        owidth,
4194
        fn_name,
4195
        True,
4196
    )
4197

4198
    check_dim_size(grad_output, ndim, ndim - 4, nslices)
4199
    check_dim_size(grad_output, ndim, ndim - 3, otime)
4200
    check_dim_size(grad_output, ndim, ndim - 2, oheight)
4201
    check_dim_size(grad_output, ndim, ndim - 1, owidth)
4202

4203

4204
def max_pool2d_checks_and_compute_shape(
4205
    input, kernel_size, stride, padding, dilation, ceil_mode
4206
):
4207
    # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
4208
    def unpack(name, val):
4209
        torch._check(
4210
            len(val) in [1, 2],
4211
            lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
4212
        )
4213
        H = val[0]
4214
        W = H if len(val) == 1 else val[1]
4215
        return H, W
4216

4217
    kH, kW = unpack("kernel_size", kernel_size)
4218

4219
    torch._check(
4220
        len(stride) in [0, 1, 2],
4221
        lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
4222
    )
4223
    if len(stride) == 0:
4224
        dH, dW = kH, kW
4225
    else:
4226
        dH, dW = unpack("stride", stride)
4227

4228
    padH, padW = unpack("padding", padding)
4229
    dilationH, dilationW = unpack("dilation", dilation)
4230
    nInputPlane = input.size(-3)
4231
    inputHeight = input.size(-2)
4232
    inputWidth = input.size(-1)
4233

4234
    memory_format = utils.suggest_memory_format(input)
4235
    if memory_format == torch.channels_last:
4236
        torch._check(
4237
            input.dim() == 4,
4238
            lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
4239
        )
4240
    elif memory_format == torch.contiguous_format:
4241
        torch._check(
4242
            input.dim() in [3, 4],
4243
            lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
4244
        )
4245
    else:
4246
        torch._check(
4247
            False,
4248
            lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
4249
        )
4250

4251
    outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
4252
    outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
4253

4254
    pool2d_shape_check(
4255
        input,
4256
        kH,
4257
        kW,
4258
        dH,
4259
        dW,
4260
        padH,
4261
        padW,
4262
        dilationH,
4263
        dilationW,
4264
        nInputPlane,
4265
        inputHeight,
4266
        inputWidth,
4267
        outputHeight,
4268
        outputWidth,
4269
        memory_format,
4270
    )
4271

4272
    return nInputPlane, outputHeight, outputWidth
4273

4274

4275
@register_meta(aten.max_pool2d_with_indices_backward.default)
4276
def meta_max_pool2d_with_indices_backward(
4277
    grad_output,
4278
    self,
4279
    kernel_size,
4280
    stride,
4281
    padding,
4282
    dilation,
4283
    ceil_mode,
4284
    indices,
4285
):
4286
    (
4287
        nInputPlane,
4288
        outputHeight,
4289
        outputWidth,
4290
    ) = max_pool2d_checks_and_compute_shape(
4291
        self, kernel_size, stride, padding, dilation, ceil_mode
4292
    )
4293

4294
    torch._check(
4295
        self.dtype == grad_output.dtype,
4296
        lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
4297
    )
4298

4299
    nOutputPlane = nInputPlane
4300
    ndim = self.ndim
4301

4302
    def _check_dim_size(t):
4303
        check_dim_size(t, ndim, ndim - 3, nOutputPlane)
4304
        check_dim_size(t, ndim, ndim - 2, outputHeight)
4305
        check_dim_size(t, ndim, ndim - 1, outputWidth)
4306

4307
    _check_dim_size(grad_output)
4308
    _check_dim_size(indices)
4309

4310
    memory_format = utils.suggest_memory_format(self)
4311
    return torch.empty(
4312
        self.shape,
4313
        dtype=self.dtype,
4314
        device=self.device,
4315
        memory_format=memory_format,
4316
    )
4317

4318

4319
@register_meta(aten.max_pool2d_with_indices.default)
4320
def meta_max_pool2d_with_indices(
4321
    input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
4322
):
4323
    (
4324
        nInputPlane,
4325
        outputHeight,
4326
        outputWidth,
4327
    ) = max_pool2d_checks_and_compute_shape(
4328
        input, kernel_size, stride, padding, dilation, ceil_mode
4329
    )
4330

4331
    nbatch = input.size(-4) if input.dim() == 4 else 1
4332
    memory_format = utils.suggest_memory_format(input)
4333
    if input.dim() == 3:
4334
        size = [nInputPlane, outputHeight, outputWidth]
4335
    else:
4336
        size = [nbatch, nInputPlane, outputHeight, outputWidth]
4337
    return (
4338
        torch.empty(
4339
            size,
4340
            dtype=input.dtype,
4341
            device=input.device,
4342
            memory_format=memory_format,
4343
        ),
4344
        torch.empty(
4345
            size,
4346
            dtype=torch.int64,
4347
            device=input.device,
4348
            memory_format=memory_format,
4349
        ),
4350
    )
4351

4352

4353
@register_meta(aten.fractional_max_pool2d.default)
4354
def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
4355
    torch._check(
4356
        self_.ndim in (3, 4),
4357
        lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self_.ndim}",
4358
    )
4359
    ndim = self_.ndim
4360

4361
    for d in range(ndim - 3, ndim):
4362
        torch._check(
4363
            self_.size(d) > 0,
4364
            f"fractional_max_pool2d: Expected input to have non-zero "
4365
            f" size for non-batch dimenions, but got {self_.size()} with dimension {d} empty",
4366
        )
4367

4368
    # the check and message are out of sync, but this matches the structured meta
4369
    torch._check(
4370
        len(kernel_size) == 2,
4371
        lambda: "fractional_max_pool2d: kernel_size must"
4372
        "either be a single int or tuple of Ints",
4373
    )
4374
    torch._check(
4375
        len(output_size) == 2,
4376
        lambda: "fractional_max_pool2d: output_size must "
4377
        "either be a single int or tuple of Ints",
4378
    )
4379

4380
    input_channels = self_.size(-3)
4381
    input_height = self_.size(-2)
4382
    input_width = self_.size(-1)
4383
    if ndim == 4:
4384
        input_batch = self_.size(0)
4385
    else:
4386
        input_batch = 1
4387

4388
    torch._check(
4389
        self_.dtype == random_samples.dtype,
4390
        lambda: "Expect _random_samples to have the same dtype as input",
4391
    )
4392
    torch._check(
4393
        random_samples.ndim == 3,
4394
        lambda: f"Expect _random samples to have 3 dimensions got, {random_samples.ndim}",
4395
    )
4396

4397
    n = random_samples.size(0)
4398
    c = random_samples.size(1)
4399
    d = random_samples.size(2)
4400
    torch._check(
4401
        n >= input_batch,
4402
        "Expect _random_samples.size(0) no less then input batch size.",
4403
    )
4404
    torch._check(
4405
        c == input_channels,
4406
        lambda: "Expect _random_samples.size(1) equals to input channel size.",
4407
    )
4408
    torch._check(d == 2, lambda: f"Expect _random_samples.size(2) equals to 2 got {d}.")
4409

4410
    torch._check(
4411
        output_size[0] + kernel_size[0] - 1 <= input_height,
4412
        lambda: f"fractional_max_pool2d: kernel height {kernel_size[0]} is too large relative to input height {input_height}",
4413
    )
4414
    torch._check(
4415
        output_size[1] + kernel_size[1] - 1 <= input_width,
4416
        lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}",
4417
    )
4418

4419
    if self_.dim() == 4:
4420
        size = [input_batch, input_channels, output_size[0], output_size[1]]
4421
    else:
4422
        size = [input_channels, output_size[0], output_size[1]]
4423

4424
    return (
4425
        torch.empty(
4426
            size,
4427
            dtype=self_.dtype,
4428
            device=self_.device,
4429
        ),
4430
        torch.empty(
4431
            size,
4432
            dtype=torch.int64,
4433
            device=self_.device,
4434
        ),
4435
    )
4436

4437

4438
@register_meta(aten.max_unpool2d)
4439
@out_wrapper()
4440
def meta_max_unpool2d(self_, indices, output_size):
4441
    utils.alert_not_deterministic("max_unpooling2d_forward_out")
4442

4443
    torch._check(
4444
        indices.dtype == torch.int64,
4445
        lambda: f"elements in indices should be type int64 but got: {indices.dtype}",
4446
    )
4447
    torch._check(
4448
        len(output_size) == 2,
4449
        lambda: (
4450
            f"There should be exactly two elements (height, width) in output_size, "
4451
            f"but got {len(output_size)} elements."
4452
        ),
4453
    )
4454

4455
    oheight, owidth = output_size
4456

4457
    torch._check(
4458
        self_.ndim in (3, 4),
4459
        lambda: (
4460
            f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
4461
            f"but got a tensor with {self_.ndim} dimensions."
4462
        ),
4463
    )
4464
    torch._check(
4465
        self_.shape == indices.shape,
4466
        lambda: (
4467
            f"Expected shape of indices to be same as that of the input tensor ({self_.shape}) "
4468
            f"but got indices tensor with shape: {indices.shape}"
4469
        ),
4470
    )
4471

4472
    for i in range(1, self_.ndim):
4473
        torch._check(
4474
            self_.size(i) > 0,
4475
            lambda: (
4476
                f"max_unpooling2d(): "
4477
                f"Expected input to have non-zero size for non-batch dimensions, "
4478
                f"but got {self_.shape} with dimension {i} being empty."
4479
            ),
4480
        )
4481

4482
    self = self_.contiguous()
4483

4484
    if self_.ndim == 3:
4485
        nchannels = self.size(0)
4486
        result = self.new_empty((nchannels, oheight, owidth))
4487
    else:
4488
        nbatch = self.size(0)
4489
        nchannels = self.size(1)
4490
        result = self.new_empty((nbatch, nchannels, oheight, owidth))
4491

4492
    return result
4493

4494

4495
def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, fn_name):
4496
    torch._check(
4497
        indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
4498
    )
4499
    torch._check(
4500
        input.ndim in (4, 5),
4501
        lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.",
4502
    )
4503
    torch._check(
4504
        len(output_size) == 3,
4505
        lambda: (
4506
            f"There should be exactly three elements (depth, height, width) in output_size, "
4507
            f"but got {len(output_size)} elements."
4508
        ),
4509
    )
4510
    torch._check(
4511
        len(stride) == 3,
4512
        lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.",
4513
    )
4514
    torch._check(
4515
        len(padding) == 3,
4516
        lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.",
4517
    )
4518
    torch._check(
4519
        input.shape == indices.shape,
4520
        lambda: (
4521
            f"Expected shape of indices to be same as that of the input tensor ({input.shape}) "
4522
            f"but got indices tensor with shape: {indices.shape}"
4523
        ),
4524
    )
4525

4526
    for i in range(1, input.ndim):
4527
        torch._check(
4528
            input.size(i) > 0,
4529
            lambda: (
4530
                f"{fn_name}: "
4531
                f"Expected input to have non-zero size for non-batch dimensions, "
4532
                f"but got {input.shape} with dimension {i} being empty."
4533
            ),
4534
        )
4535

4536
    torch._check(
4537
        stride[0] > 0 and stride[1] > 0 and stride[2] > 0,
4538
        lambda: f"strides should be greater than zero, but got stride: {stride}",
4539
    )
4540

4541

4542
@register_meta(aten.max_unpool3d)
4543
@out_wrapper()
4544
def meta_max_unpool3d(self_, indices, output_size, stride, padding):
4545
    utils.alert_not_deterministic("max_unpooling3d_forward_out")
4546

4547
    _max_unpooling3d_shape_check(
4548
        self_, indices, output_size, stride, padding, "max_unpooling3d()"
4549
    )
4550

4551
    self = self_.contiguous()
4552

4553
    odepth, oheight, owidth = output_size
4554

4555
    if self_.ndim == 4:
4556
        nchannels = self.size(0)
4557
        result = self.new_empty((nchannels, odepth, oheight, owidth))
4558
    else:
4559
        nbatch = self.size(0)
4560
        nchannels = self.size(1)
4561
        result = self.new_empty((nbatch, nchannels, odepth, oheight, owidth))
4562

4563
    return result
4564

4565

4566
@register_meta(aten.max_pool3d_with_indices)
@out_wrapper("out", "indices")
def meta_max_pool3d_with_indices(
    input,
    kernel_size,
    stride=(),
    padding=(0,),
    dilation=(1,),
    ceil_mode=False,
):
    """Meta kernel for ``max_pool3d_with_indices``.

    ``kernel_size``, ``padding`` and ``dilation`` hold either one value (used for
    time, height and width alike) or three. ``stride`` may also be empty, in which
    case it defaults to the kernel size. The expanded arguments and the 4D
    (unbatched) or 5D (batched) ``input`` are validated with ``pool3d_shape_check``.

    Returns:
        ``(out, indices)`` with the pooled shape. ``indices`` is int64. Both use a
        channels-last-3d layout when ``input`` is laid out that way.
    """
    torch._check(
        len(kernel_size) in (1, 3),
        lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
    )
    kT = kernel_size[0]
    kH = kT if len(kernel_size) == 1 else kernel_size[1]
    kW = kT if len(kernel_size) == 1 else kernel_size[2]

    torch._check(
        not stride or len(stride) in (1, 3),
        lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
    )
    dT = kT if not stride else stride[0]
    dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
    dW = kW if not stride else (dT if len(stride) == 1 else stride[2])

    torch._check(
        len(padding) in (1, 3),
        lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
    )
    pT = padding[0]
    pH = pT if len(padding) == 1 else padding[1]
    pW = pT if len(padding) == 1 else padding[2]

    torch._check(
        len(dilation) in (1, 3),
        lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
    )
    dilationT = dilation[0]
    dilationH = dilationT if len(dilation) == 1 else dilation[1]
    dilationW = dilationT if len(dilation) == 1 else dilation[2]

    torch._check(
        input.ndim in (4, 5),
        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
    )

    nbatch = input.size(-5) if input.ndim == 5 else 1
    nslices = input.size(-4)
    itime = input.size(-3)
    iheight = input.size(-2)
    iwidth = input.size(-1)

    otime = pooling_output_shape(itime, kT, pT, dT, dilationT, ceil_mode)
    oheight = pooling_output_shape(iheight, kH, pH, dH, dilationH, ceil_mode)
    owidth = pooling_output_shape(iwidth, kW, pW, dW, dilationW, ceil_mode)

    pool3d_shape_check(
        input,
        nslices,
        kT,
        kH,
        kW,
        dT,
        dH,
        dW,
        pT,
        pH,
        pW,
        dilationT,
        dilationH,
        dilationW,
        itime,
        iheight,
        iwidth,
        otime,
        oheight,
        owidth,
        "max_pool3d_with_indices()",
    )

    channels_last = (
        input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
    )
    if input.ndim == 4:
        # For an unbatched input, detect a channels-last-3d layout by probing a
        # batched (unsqueezed) view of it.
        input_channels_last_check = input.unsqueeze(0)
        channels_last = (
            not input_channels_last_check.is_contiguous()
        ) and input_channels_last_check.is_contiguous(
            memory_format=torch.channels_last_3d
        )
        out_shape = (nslices, otime, oheight, owidth)
    else:
        out_shape = (nbatch, nslices, otime, oheight, owidth)  # type: ignore[assignment]

    def _new_result(dtype=None):
        # Allocate one result of `out_shape` in the layout chosen above
        # (dtype=None keeps input's dtype).
        if not channels_last:
            return input.new_empty(out_shape, dtype=dtype)
        if input.ndim == 4:
            # channels_last_3d is only defined for 5D tensors, so converting a 4D
            # tensor directly raises. Lay out a batch of one, then drop that dim.
            return (
                input.new_empty((1, *out_shape), dtype=dtype)
                .to(memory_format=torch.channels_last_3d)
                .squeeze(0)
            )
        return input.new_empty(out_shape, dtype=dtype).to(
            memory_format=torch.channels_last_3d
        )

    return _new_result(), _new_result(torch.int64)
4669

4670

4671
@register_meta(aten.max_pool3d_with_indices_backward)
@out_wrapper("grad_input")
def meta_max_pool3d_with_indices_backward(
    grad_output,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
    indices,
):
    """Meta kernel for ``max_pool3d_with_indices_backward``.

    Expands the pooling arguments the same way as the forward meta kernel. It takes
    the output extent from the last three dims of ``grad_output`` and validates
    everything with ``max_pool3d_backward_shape_check``.

    Returns:
        ``grad_input`` with ``input``'s shape, laid out channels-last-3d when
        ``input`` is.
    """
    torch._check(
        len(kernel_size) in (1, 3),
        lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
    )
    kT = kernel_size[0]
    kH = kT if len(kernel_size) == 1 else kernel_size[1]
    kW = kT if len(kernel_size) == 1 else kernel_size[2]

    torch._check(
        not stride or len(stride) in (1, 3),
        lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
    )
    dT = kT if not stride else stride[0]
    dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
    dW = kW if not stride else (dT if len(stride) == 1 else stride[2])

    torch._check(
        len(padding) in (1, 3),
        lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
    )
    pT = padding[0]
    pH = pT if len(padding) == 1 else padding[1]
    pW = pT if len(padding) == 1 else padding[2]

    torch._check(
        len(dilation) in (1, 3),
        lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
    )
    dilationT = dilation[0]
    dilationH = dilationT if len(dilation) == 1 else dilation[1]
    dilationW = dilationT if len(dilation) == 1 else dilation[2]

    torch._check(
        input.ndim in (4, 5),
        lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
    )

    nslices = input.size(-4)
    itime = input.size(-3)
    iheight = input.size(-2)
    iwidth = input.size(-1)

    otime = grad_output.size(-3)
    oheight = grad_output.size(-2)
    owidth = grad_output.size(-1)

    max_pool3d_backward_shape_check(
        input,
        grad_output,
        indices,
        nslices,
        kT,
        kH,
        kW,
        dT,
        dH,
        dW,
        pT,
        pH,
        pW,
        dilationT,
        dilationH,
        dilationW,
        itime,
        iheight,
        iwidth,
        otime,
        oheight,
        owidth,
        "max_pool3d_with_indices_backward()",
    )

    channels_last = (
        input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
    )
    if input.ndim == 4:
        # For an unbatched input, detect a channels-last-3d layout by probing a
        # batched (unsqueezed) view of it.
        input_channels_last_check = input.unsqueeze(0)
        channels_last = (
            not input_channels_last_check.is_contiguous()
        ) and input_channels_last_check.is_contiguous(
            memory_format=torch.channels_last_3d
        )

    if not channels_last:
        grad_input = input.new_empty(input.shape)
    elif input.ndim == 4:
        # channels_last_3d is only defined for 5D tensors, so converting a 4D
        # tensor directly raises. Lay out a batch of one, then drop that dim.
        grad_input = (
            input.new_empty((1, *input.shape))
            .to(memory_format=torch.channels_last_3d)
            .squeeze(0)
        )
    else:
        grad_input = input.new_empty(input.shape).to(
            memory_format=torch.channels_last_3d
        )

    return grad_input
4772

4773

4774
def check_grid_sampler_common(input: Tensor, grid: Tensor):
4775
    torch._check(
4776
        input.device == grid.device,
4777
        lambda: (
4778
            f"grid_sampler(): expected input and grid to be on same device, but input "
4779
            f"is on {input.device} and grid is on {grid.device}"
4780
        ),
4781
    )
4782
    torch._check(
4783
        input.layout == torch.strided and grid.layout == torch.strided,
4784
        lambda: (
4785
            f"grid_sampler(): expected input and grid to have torch.strided layout, but "
4786
            f"input has {input.layout} and grid has {grid.layout}"
4787
        ),
4788
    )
4789
    torch._check(
4790
        input.shape[0] == grid.shape[0],
4791
        lambda: (
4792
            f"grid_sampler(): expected grid and input to have same batch size, but got "
4793
            f"input with sizes {input.shape} and grid with sizes {grid.shape}"
4794
        ),
4795
    )
4796
    torch._check(
4797
        grid.shape[-1] == input.ndim - 2,
4798
        lambda: (
4799
            f"grid_sampler(): expected grid to have size {input.ndim - 2} in last "
4800
            f"dimension, but got grid with sizes {grid.shape}"
4801
        ),
4802
    )
4803

4804
    for i in range(2, input.ndim):
4805
        torch._check(
4806
            input.shape[i] > 0,
4807
            lambda: (
4808
                f"grid_sampler(): expected input to have non-empty spatial dimensions, "
4809
                f"but input has sizes {input.shape} with dimension {i} being empty"
4810
            ),
4811
        )
4812

4813

4814
class GridSamplerInterpolation(Enum):
    """Integer codes of the grid sampler ``interpolation_mode`` argument.

    Meta kernels compare the raw ``interpolation_mode`` int against ``.value``.
    """

    # Members and their integer codes.
    BILINEAR = 0
    NEAREST = 1
    BICUBIC = 2
4818

4819

4820
def check_grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: int):
    """Validate the extra requirements of the volumetric (3D) grid sampler.

    ``input`` and ``grid`` must both be 5D. Bicubic interpolation is rejected
    for 5D input.

    Raises:
        RuntimeError: via ``torch._check`` when a requirement fails.
    """
    ranks_ok = input.ndim == 5 and input.ndim == grid.ndim
    torch._check(
        ranks_ok,
        lambda: (
            f"grid_sampler(): expected 5D input and grid with same number of "
            f"dimensions, but got input with sizes {input.shape}"
            f" and grid with sizes {grid.shape}"
        ),
    )

    wants_bicubic = interpolation_mode == GridSamplerInterpolation.BICUBIC.value
    torch._check(
        input.ndim != 5 or not wants_bicubic,
        lambda: "grid_sampler(): bicubic interpolation only supports 4D input",
    )
4836

4837

4838
@register_meta(aten.grid_sampler_2d_backward.default)
def grid_sampler_2d_backward_meta(
    grad_output,
    input,
    grid,
    interpolation_mode,
    padding_mode,
    align_corners,
    output_mask,
):
    """Meta kernel for ``grid_sampler_2d_backward``.

    Returns ``(grad_input, grad_grid)``. ``grad_input`` is a contiguous zeros
    tensor shaped like ``input``, or ``None`` when ``output_mask[0]`` is false.
    ``grad_grid`` is an uninitialized contiguous tensor shaped like ``grid``.
    """
    grad_input = (
        torch.zeros_like(input, memory_format=torch.contiguous_format)
        if output_mask[0]
        else None
    )
    grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format)
    return grad_input, grad_grid
4855

4856

4857
@register_meta(aten.grid_sampler_3d)
@out_wrapper()
def grid_sampler_3d(
    input,
    grid,
    interpolation_mode,
    padding_mode,
    align_corners,
):
    """Meta kernel for ``grid_sampler_3d``.

    After validation, the output takes (N, C) from ``input`` and (D, H, W)
    from dims 1..3 of ``grid``.
    """
    check_grid_sampler_common(input, grid)
    check_grid_sampler_3d(input, grid, interpolation_mode)

    batch, channels = input.shape[0], input.shape[1]
    depth, height, width = grid.shape[1], grid.shape[2], grid.shape[3]
    return input.new_empty((batch, channels, depth, height, width))
4874

4875

4876
@register_meta(aten.grid_sampler_3d_backward)
@out_wrapper("grad_input", "grad_grid")
def grid_sampler_3d_backward(
    grad_output,
    input,
    grid,
    interpolation_mode,
    padding_mode,
    align_corners,
    output_mask,
):
    """Meta kernel for ``grid_sampler_3d_backward``.

    Validates ``input``/``grid`` like the forward kernel and returns
    ``(grad_input, grad_grid)``. ``grad_input`` is zeros shaped like ``input``
    (``None`` when ``output_mask[0]`` is false). ``grad_grid`` is uninitialized
    and shaped like ``grid``. Both use ``torch.legacy_contiguous_format``.
    """
    check_grid_sampler_common(input, grid)
    check_grid_sampler_3d(input, grid, interpolation_mode)

    grad_input = None
    if output_mask[0]:
        grad_input = torch.zeros_like(
            input, memory_format=torch.legacy_contiguous_format
        )
    grad_grid = torch.empty_like(grid, memory_format=torch.legacy_contiguous_format)
    return grad_input, grad_grid
4898

4899

4900
@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
    """Meta kernel for ``aten.full``: allocate an uninitialized tensor of ``size``.

    If the ``dtype`` keyword is absent or ``None``, the dtype is inferred from
    ``fill_value`` via ``utils.get_dtype``. The remaining positional and keyword
    arguments are forwarded unchanged to ``torch.empty``.
    """
    # Explicit None test: dtype is either None or a torch.dtype, so this states
    # the intent instead of relying on truthiness.
    dtype = kwargs.get("dtype")
    if dtype is None:
        dtype = utils.get_dtype(fill_value)
    kwargs["dtype"] = dtype
    return torch.empty(size, *args, **kwargs)
4907

4908

4909
# zeros_like is special cased to work for sparse
@register_meta(aten.zeros_like.default)
def zeros_like(
    self,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Meta kernel for ``aten.zeros_like`` with dedicated sparse COO handling.

    For ``layout=torch.sparse_coo`` it builds a coalesced sparse tensor of
    ``self``'s size; ``memory_format`` must then be ``None``, and ``dtype`` and
    ``device`` fall back to ``self``'s. Any other layout goes through
    ``aten.empty_like`` and is zero-filled.
    """
    if layout != torch.sparse_coo:
        dense = aten.empty_like.default(
            self,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            memory_format=memory_format,
        )
        # `device` need not be "meta", so write the zeros explicitly.
        dense.fill_(0)
        return dense

    torch._check(
        memory_format is None,
        lambda: "memory format option is only supported by strided tensors",
    )
    sparse = torch.empty(
        0,
        dtype=dtype if dtype is not None else self.dtype,
        layout=layout,
        device=device if device is not None else self.device,
        pin_memory=pin_memory,
    )

    if self.is_sparse:
        n_sparse, n_dense = self.sparse_dim(), self.dense_dim()
    else:
        # Dense source: every dimension becomes a sparse dimension.
        n_sparse, n_dense = self.dim(), 0
    sparse.sparse_resize_and_clear_(self.size(), n_sparse, n_dense)

    sparse._coalesced_(True)
    return sparse
4953

4954

4955
@register_meta(aten.select.int)
def meta_select(self, dim, index):
    """Meta kernel for ``aten.select.int``: view of ``self`` with ``dim`` removed at ``index``.

    Negative ``dim`` and ``index`` count from the end.

    Raises:
        IndexError: via ``torch._check_index``, for a 0-dim ``self`` or an
            out-of-range ``dim`` or ``index``.
    """
    ndim = self.dim()
    torch._check_index(
        ndim != 0,
        lambda: "select() cannot be applied to a 0-dim tensor.",
    )

    # Validate before wrapping. Otherwise e.g. dim=-3 on a 2D tensor would wrap
    # to -1 and silently select along the last dimension.
    torch._check_index(
        -ndim <= dim < ndim,
        lambda: (
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        ),
    )
    dim = dim if dim >= 0 else dim + ndim
    size = self.size(dim)

    torch._check_index(
        not (-index > size or index >= size),
        lambda: f"select(): index {index} out of range for tensor of size "
        f"{self.size()} at dimension {dim}",
    )

    index = index if index >= 0 else index + size

    new_size = list(self.size())
    new_stride = list(self.stride())

    # The selected slice starts `index` strides into `dim`. Dropping `dim` from
    # sizes and strides yields the view.
    new_storage_offset = self.storage_offset() + index * new_stride[dim]
    del new_size[dim]
    del new_stride[dim]

    return self.as_strided(new_size, new_stride, new_storage_offset)
4982

4983

4984
@register_meta(aten.select_scatter.default)
def meta_select_scatter(self, src, dim, index):
    """Meta kernel for ``select_scatter``.

    Returns ``utils.clone_preserve_strides(self)``. ``src``, ``dim`` and
    ``index`` are not inspected.
    """
    result = utils.clone_preserve_strides(self)
    return result
4987

4988

4989
@register_meta(aten.slice_scatter.default)
def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
    """Meta kernel for ``slice_scatter``.

    Returns ``utils.clone_preserve_strides(self)``. ``src`` and the slice
    parameters are not inspected.
    """
    result = utils.clone_preserve_strides(self)
    return result
4992

4993

4994
# TODO: Deduplicate this with canonicalize_dim
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True) -> int:
    """Wrap a possibly negative ``dim`` into ``[0, max(dim_post_expr, 1))``.

    Args:
        dim: dimension index; negative values count from the end.
        dim_post_expr: rank to wrap against. A rank <= 0 is treated as rank 1,
            so a scalar accepts ``dim`` in ``[-1, 0]``, but only if
            ``wrap_scalar`` is true.
        wrap_scalar: whether a non-positive ``dim_post_expr`` may be wrapped.

    Returns:
        The non-negative dimension index.

    Raises:
        AssertionError: if scalar wrapping is disallowed or ``dim`` is out of
            range. Raised explicitly rather than via ``assert`` so the checks
            still run under ``python -O``.
    """
    if dim_post_expr <= 0:
        if not wrap_scalar:
            raise AssertionError(
                f"dim {dim} cannot be wrapped for a 0-dim tensor when wrap_scalar=False"
            )
        dim_post_expr = 1
    # Inclusive bounds; named so they don't shadow builtins min/max.
    lower = -dim_post_expr
    upper = dim_post_expr - 1
    if dim < lower or dim > upper:
        raise AssertionError(f"dim {dim} out of bounds ({lower}, {upper})")
    if dim < 0:
        dim += dim_post_expr
    return dim
5005

5006

5007
def ensure_nonempty_size(t, dim):
    """Return ``t.shape[dim]``, treating a 0-dim ``t`` as having size 1."""
    if t.dim() == 0:
        return 1
    return t.shape[dim]
5009

5010

5011
# From aten/src/ATen/native/ScatterGatherChecks.h
def gather_shape_check(self, dim, index):
    """Validate ``index``'s shape for ``gather`` along ``dim``.

    ``index`` must have the same rank as ``self`` (0-dim counted as 1-dim).
    ``index.size(d) <= self.size(d)`` must hold for every ``d`` other than ``dim``.

    Raises:
        RuntimeError: via ``torch._check`` when a requirement fails.
    """
    rank = max(self.dim(), 1)
    torch._check(
        rank == max(index.dim(), 1),
        lambda: "Index tensor must have the same number of dimensions as input tensor",
    )
    for d in range(rank):
        if d == dim:
            continue
        torch._check(
            ensure_nonempty_size(index, d) <= ensure_nonempty_size(self, d),
            lambda: f"Size does not match at dimension {d} expected index {index.shape}"
            + f" to be smaller than self {self.shape} apart from dimension {dim}",
        )
5026

5027

5028
@register_meta(aten.gather.default)
def meta_gather(self, dim, index, sparse_grad=False):
    """Meta kernel for ``aten.gather``: the result has ``index``'s shape.

    ``dim`` is wrapped against ``self``'s rank. When ``index`` is non-empty, its
    dtype must be int64 and its shape must pass ``gather_shape_check``.
    """
    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    if index.numel() != 0:
        torch._check(
            index.dtype == torch.long,
            lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
        )
        gather_shape_check(self, wrapped_dim, index)
    return self.new_empty(index.shape)
5039

5040

5041
# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
5042
def get_operator_enum(reduce_, use_new_options=False):
5043
    if use_new_options:
5044
        if reduce_ == "sum":
5045
            return "REDUCE_ADD"
5046
        elif reduce_ == "prod":
5047
            return "REDUCE_MULTIPLY"
5048
        elif reduce_ == "mean":
5049
            return "REDUCE_MEAN"
5050
        elif reduce_ == "amax":
5051
            return "REDUCE_MAXIMUM"
5052
        elif reduce_ == "amin":
5053
            return "REDUCE_MINIMUM"
5054
        torch._check(
5055
            False,
5056
            lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
5057
        )
5058
        return
5059
    else:
5060
        if reduce_ == "add":
5061
            return "REDUCE_ADD"
5062
        elif reduce_ == "multiply":
5063
            return "REDUCE_MULTIPLY"
5064
        torch._check(False, lambda: "reduce argument must be either add or multiply.")
5065
        return
5066

5067

5068
# From aten/src/ATen/native/ScatterGatherChecks.h
5069
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
5070
    if index.numel() != 0:
5071
        torch._check(
5072
            index.dtype == torch.long,
5073
            lambda: f"{method_name}(): Expected dtype int64 for index",
5074
        )
5075

5076
    if src_opt is not None:
5077
        torch._check(
5078
            self.dtype == src_opt.dtype,
5079
            lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
5080
        )
5081

5082

5083
def ensure_nonempty_dim(dim):
    """Return ``dim``, raised to 1 if smaller (a 0-dim tensor counts as 1-dim)."""
    return dim if dim >= 1 else 1
5085

5086

5087
# From aten/src/ATen/native/ScatterGatherChecks.h
def scatter_shape_check(self, dim, index, src_opt=None):
    """Validate ``index`` (and ``src_opt``) shapes for scatter-style ops.

    Nothing is checked when ``index`` is empty. Otherwise (0-dim tensors count
    as 1-dim):
    - ``index`` must have the same rank as ``self``;
    - ``src_opt``, when given, must have the same rank as ``index``;
    - ``index.size(d) <= self.size(d)`` for every ``d != dim``;
    - ``index.size(d) <= src_opt.size(d)`` for every ``d``, when ``src_opt`` is given.

    Raises:
        RuntimeError: via ``torch._check`` when a requirement fails.
    """
    if index.numel() == 0:
        return
    torch._check(
        ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
        lambda: "Index tensor must have the same number of dimensions as self tensor",
    )
    if src_opt is not None:
        # Compare src's rank with index's (the self check above already ran).
        # Doing it before reading src's sizes below turns a rank mismatch into
        # a clear error instead of an out-of-range shape lookup.
        torch._check(
            ensure_nonempty_dim(src_opt.dim()) == ensure_nonempty_dim(index.dim()),
            lambda: "Index tensor must have the same number of dimensions as src tensor",
        )

    self_dims = ensure_nonempty_dim(self.dim())

    # Check: index.size(d) <= self.size(d) for all d != dim
    is_wrong_shape = any(
        ensure_nonempty_size(index, d) > ensure_nonempty_size(self, d)
        for d in range(self_dims)
        if d != dim
    )

    # Check: index.size(d) <= src.size(d) for all d if src is Tensor
    if not is_wrong_shape and src_opt is not None:
        is_wrong_shape = any(
            ensure_nonempty_size(index, d) > ensure_nonempty_size(src_opt, d)
            for d in range(self_dims)
        )

    if src_opt is not None:
        torch._check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
        )
    else:
        torch._check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim}",
        )
5132

5133

5134
# From aten/src/ATen/native/TensorAdvancedIndexing.cpp
def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
    """Run the validation shared by the scatter meta kernels.

    It wraps ``dim`` against ``self``'s rank, then checks dtypes and shapes of
    ``index``/``src``. If ``reduce_`` is given, it also checks that the name is
    accepted. Returns ``None``.
    """
    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    scatter_gather_dtype_check("scatter", self, index, src)
    scatter_shape_check(self, wrapped_dim, index, src)
    if reduce_ is None:
        return
    # Called only for validation: it rejects unknown reduce names.
    get_operator_enum(reduce_, use_new_options)
5142

5143

5144
@register_meta(aten.scatter_add.default)
def meta_scatter_add(self, dim, index, src):
    """Out-of-place ``scatter_add``: validate, then return a new tensor shaped like ``self``."""
    scatter_meta_impl(self, dim, index, src, reduce_="add")
    result = self.new_empty(self.shape)
    return result
5148

5149

5150
@register_meta(aten.scatter_add_)
def meta_scatter_add_(self, dim, index, src):
    """In-place ``scatter_add_``: validate, then return ``self`` itself."""
    scatter_meta_impl(self, dim, index, src, reduce_="add")
    return self
5154

5155

5156
@register_meta(
    [
        aten.scatter.src,
        aten.scatter.value,
        aten.scatter.reduce,
        aten.scatter.value_reduce,
    ]
)
@out_wrapper()
def meta_scatter(self, dim, index, src_or_value, reduce=None):
    """Out-of-place ``scatter`` overloads: validate, then return a new tensor shaped like ``self``.

    ``src_or_value`` is treated as ``src`` only when it is a Tensor. A scalar
    value takes no part in the shape/dtype checks.
    """
    src = None
    if isinstance(src_or_value, torch.Tensor):
        src = src_or_value
    scatter_meta_impl(self, dim, index, src, reduce)
    return self.new_empty(self.shape)
5169

5170

5171
@register_meta(
    [
        aten.scatter_.src,
        aten.scatter_.value,
        aten.scatter_.reduce,
        aten.scatter_.value_reduce,
    ]
)
def meta_scatter_(self, dim, index, src_or_value, reduce=None):
    """In-place ``scatter_`` overloads: validate, then return ``self`` itself.

    ``src_or_value`` is treated as ``src`` only when it is a Tensor.
    """
    src = None
    if isinstance(src_or_value, torch.Tensor):
        src = src_or_value
    scatter_meta_impl(self, dim, index, src, reduce)
    return self
5183

5184

5185
@register_meta([aten._scaled_dot_product_flash_attention_backward])
def meta__scaled_dot_product_flash_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: Tensor,
    philox_offset: Tensor,
    scale: Optional[float] = None,
):
    """Meta kernel for the flash-attention SDPA backward.

    Returns ``(grad_q, grad_k, grad_v)`` shaped like ``query``, ``key`` and
    ``value``. No other argument is inspected.
    """

    def _grad_for(t: Tensor) -> Tensor:
        # Allocate against the (1, 2)-transposed view, then transpose back so
        # the result has t's shape.
        return torch.empty_like(t.transpose(1, 2)).transpose(1, 2)

    return _grad_for(query), _grad_for(key), _grad_for(value)
5211

5212

5213
@register_meta([aten._scaled_dot_product_flash_attention_for_cpu])
def meta__scaled_dot_product_flash_attention_for_cpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    attn_mask: Optional[Tensor] = None,
    scale: Optional[float] = None,
):
    """Meta kernel for the CPU flash-attention SDPA forward.

    With ``query`` read as (batch, heads, q_len, head_dim), it returns:
    - ``attention``: logically (batch, heads, q_len, head_dim), allocated as
      (batch, q_len, heads, head_dim) and transposed;
    - ``logsumexp``: float32, logically (batch, heads, q_len), allocated as
      (batch, q_len, heads) and transposed.
    """
    batch, heads, q_len, head_dim = (query.size(i) for i in range(4))
    where = query.device

    attention = torch.empty(
        (batch, q_len, heads, head_dim), dtype=query.dtype, device=where
    ).transpose(1, 2)
    logsumexp = torch.empty(
        (batch, q_len, heads), dtype=torch.float, device=where
    ).transpose(1, 2)
    return attention, logsumexp
5250

5251

5252
@register_meta([aten._scaled_dot_product_flash_attention_for_cpu_backward])
def meta__scaled_dot_product_flash_attention_for_cpu_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    dropout_p: float,
    is_causal: bool,
    attn_mask: Optional[Tensor] = None,
    scale: Optional[float] = None,
):
    """Meta kernel for the CPU flash-attention SDPA backward.

    The CPU kernel's gradients use a different memory layout from CUDA's.
    Each is logically (batch, heads, seq_len, head_dim) but physically ordered
    (batch, seq_len, heads, head_dim). ``grad_q`` uses the query length;
    ``grad_k``/``grad_v`` use the key length.
    """
    batch, heads = query.size(0), query.size(1)
    q_len, head_dim = query.size(2), query.size(3)
    k_len = key.size(2)

    def _grad_buffer(seq_len: int, like: Tensor) -> Tensor:
        # physical_layout (0, 2, 1, 3): seq_len is stored outside heads.
        return torch.empty_permuted(
            (batch, heads, seq_len, head_dim),
            (0, 2, 1, 3),
            dtype=like.dtype,
            device=like.device,
        )

    return (
        _grad_buffer(q_len, query),
        _grad_buffer(k_len, key),
        _grad_buffer(k_len, value),
    )
5297

5298

5299
@register_meta([aten._scaled_dot_product_efficient_attention_backward])
def meta__scaled_dot_product_efficient_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_bias: Optional[Tensor],
    out: Tensor,
    logsumexp: Tensor,
    philox_seed: Tensor,
    philox_offset: Tensor,
    dropout_p: float,
    grad_input_mask: List[bool],
    is_causal: bool = False,
    scale: Optional[float] = None,
):
    """Meta kernel for the memory-efficient SDPA backward.

    ``grad_q``, ``grad_k`` and ``grad_v`` are logically
    (batch, heads, seq_len, dim) with physical order (batch, seq_len, heads, dim).
    ``grad_k``/``grad_v`` use the key length; ``grad_v`` uses value's last dim.
    ``grad_bias`` is produced only when ``attn_bias`` is given and
    ``grad_input_mask[3]`` is set; otherwise it is ``None``.
    """
    batch, heads = query.size(0), query.size(1)
    q_len, head_dim = query.size(2), query.size(3)
    head_dim_v = value.size(3)
    k_len = key.size(2)

    def _grad_buffer(seq_len: int, last: int, like: Tensor) -> Tensor:
        return torch.empty_permuted(
            (batch, heads, seq_len, last),
            (0, 2, 1, 3),
            dtype=like.dtype,
            device=like.device,
        )

    grad_q = _grad_buffer(q_len, head_dim, query)
    grad_k = _grad_buffer(k_len, head_dim, key)
    grad_v = _grad_buffer(k_len, head_dim_v, value)

    if attn_bias is None or not grad_input_mask[3]:
        return grad_q, grad_k, grad_v, None

    # Allocate with the last dim rounded up to a multiple of 16, then slice
    # back to the bias width.
    width = attn_bias.size(-1)
    padded_width = width if width % 16 == 0 else width + 16 - width % 16
    padded_shape = list(attn_bias.size())
    padded_shape[-1] = padded_width
    grad_bias = torch.empty(
        padded_shape, dtype=attn_bias.dtype, device=attn_bias.device
    )[..., :width]

    return grad_q, grad_k, grad_v, grad_bias
5357

5358

5359
@register_meta([aten._flash_attention_backward])
def meta__flash_attention_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    out: Tensor,
    logsumexp: Tensor,
    cum_seq_q: Tensor,
    cum_seq_k: Tensor,
    max_q: int,
    max_k: int,
    dropout_p: float,
    is_causal: bool,
    philox_seed: Tensor,
    philox_offset: Tensor,
    scale: Optional[float] = None,
):
    """Meta kernel for ``_flash_attention_backward``.

    Returns ``torch.empty_like`` of ``query``, ``key`` and ``value``, in that order.
    """
    return tuple(torch.empty_like(t) for t in (query, key, value))
5386

5387

5388
@register_meta([aten._efficient_attention_backward])
def meta__efficient_attention_backward(
    grad_out: Tensor,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bias: Optional[Tensor],
    cu_seqlens_q: Optional[Tensor],
    cu_seqlens_k: Optional[Tensor],
    max_seqlen_q: int,
    max_seqlen_k: int,
    logsumexp: Tensor,
    dropout_p: float,
    philox_seed: Tensor,
    philox_offset: Tensor,
    custom_mask_type: int,
    bias_requires_grad: bool,
    scale: Optional[float] = None,
    num_splits_key: Optional[int] = None,
):
    """Meta kernel for ``_efficient_attention_backward``.

    ``grad_query``/``grad_key``/``grad_value`` are ``empty_like`` their inputs.
    With a ``bias``, ``grad_bias`` is allocated with the last dim rounded up to
    a multiple of 16 and sliced back to the bias width. Without one, it is a
    0-dim tensor on ``query``'s device. ``bias_requires_grad`` is not inspected.
    """
    grad_query, grad_key, grad_value = (
        torch.empty_like(t) for t in (query, key, value)
    )

    if bias is None:
        grad_bias = torch.empty((), device=query.device)
    else:
        width = bias.size(-1)
        padded_width = width if width % 16 == 0 else width + 16 - width % 16
        padded_shape = list(bias.size())
        padded_shape[-1] = padded_width
        grad_bias = torch.empty(padded_shape, dtype=bias.dtype, device=bias.device)[
            ..., :width
        ]

    return grad_query, grad_key, grad_value, grad_bias
5427

5428

5429
@register_meta([aten._scaled_mm.default])
def meta_scaled_mm(
    self: torch.Tensor,
    mat2: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    scale_a: Optional[torch.Tensor] = None,
    scale_b: Optional[torch.Tensor] = None,
    scale_result: Optional[torch.Tensor] = None,
    use_fast_accum: bool = False,
):
    """Meta kernel for ``aten._scaled_mm``: validate fp8 inputs and allocate outputs.

    Requirements:
    - ``self`` is a 2D row-major fp8 matrix whose ``size(1)`` is a multiple of 16;
    - ``mat2`` is a 2D column-major fp8 matrix whose sizes are both multiples of 16.
    ``bias``, the ``scale_*`` tensors and ``use_fast_accum`` are not inspected.

    Returns:
        A tuple. The first element has shape ``(self.size(0), mat2.size(1))`` and
        dtype ``out_dtype`` (default ``self.dtype``). The second is a 0-dim float32
        tensor on ``self``'s device.
    """

    def is_row_major(stride):
        return stride[0] > stride[1] and stride[1] == 1

    def is_col_major(shape, stride):
        return stride[0] == 1 and stride[1] == shape[0]

    def is_fp8_type(dtype):
        return dtype in (
            torch.float8_e4m3fn,
            torch.float8_e5m2,
            torch.float8_e4m3fnuz,
            torch.float8_e5m2fnuz,
        )

    torch._check(
        self.dim() == 2 and mat2.dim() == 2,
        lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
    )
    torch._check(
        is_row_major(self.stride()),
        lambda: "self must be row_major",
    )
    torch._check(
        is_col_major(mat2.shape, mat2.stride()),
        lambda: "mat2 must be col_major",
    )
    # The message previously named self.size(0) although size(1) is checked.
    torch._check(
        self.size(1) % 16 == 0,
        lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}",
    )
    torch._check(
        mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
        lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}",
    )
    torch._check(
        is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
        lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
    )
    _out_dtype = out_dtype if out_dtype is not None else self.dtype
    return torch.empty(
        self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device
    ), torch.empty((), dtype=torch.float32, device=self.device)
5482

5483

5484
@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
@out_wrapper()
def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
    """Out-of-place ``scatter_reduce.two``: validate, then return a new tensor shaped like ``self``.

    ``reduce`` is checked against the sum/prod/mean/amax/amin names.
    ``include_self`` does not affect the meta result.
    """
    scatter_meta_impl(self, dim, index, src, reduce_=reduce, use_new_options=True)
    result = self.new_empty(self.shape)
    return result
5489

5490

5491
@register_meta(aten.scatter_reduce_.two)
def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
    """In-place ``scatter_reduce_.two``: validate, then return ``self`` itself."""
    scatter_meta_impl(self, dim, index, src, reduce_=reduce, use_new_options=True)
    return self
5495

5496

5497
@register_meta([aten.multinomial.default, aten.multinomial.out])
@out_wrapper()
def meta_multinomial(input, num_samples, replacement=False, *, generator=None):
    """Meta kernel for ``multinomial``: an int64 tensor of sampled indices.

    A 1D ``input`` yields ``(num_samples,)``; a 2D ``input`` yields
    ``(input.size(0), num_samples)``. Any other rank is rejected.
    """
    torch._check(
        0 < input.dim() <= 2,
        lambda: f"The probabilty distributions dimensions must be 1 or 2, but got {input.dim()}",
    )
    if input.dim() == 1:
        shape = (num_samples,)
    else:
        shape = (input.size(0), num_samples)
    return torch.empty(shape, dtype=torch.long, device=input.device)
5509

5510

5511
def multiply_integers(vs):
5512
    r = 1
5513
    for v in vs:
5514
        r *= v
5515
    return r
5516

5517

5518
def upsample_common_check(input_size, output_size, num_spatial_dims):
5519
    torch._check(
5520
        len(output_size) == num_spatial_dims,
5521
        lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
5522
    )
5523
    expected_input_dims = num_spatial_dims + 2  # N, C, ...
5524
    torch._check(
5525
        len(input_size) == expected_input_dims,
5526
        lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
5527
    )
5528

5529
    torch._check(
5530
        all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size),
5531
        lambda: f"Input and output sizes should be greater than 0, but got "
5532
        f"input size {input_size} and output size {output_size}",
5533
    )
5534

5535
    nbatch, channels = input_size[:2]
5536
    return (nbatch, channels, *output_size)
5537

5538

5539
@register_meta(
5540
    [aten.upsample_nearest1d.default, aten._upsample_nearest_exact1d.default]
5541
)
5542
def upsample_nearest1d(input, output_size, scales=None):
5543
    torch._check(
5544
        input.numel() != 0 or multiply_integers(input.size()[1:]),
5545
        lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
5546
    )
5547
    full_output_size = upsample_common_check(
5548
        input.size(), output_size, num_spatial_dims=1
5549
    )
5550
    return input.new_empty(full_output_size).to(
5551
        memory_format=utils.suggest_memory_format(input)
5552
    )
5553

5554

5555
@register_meta(
5556
    [aten.upsample_nearest2d.default, aten._upsample_nearest_exact2d.default]
5557
)
5558
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
5559
    torch._check(
5560
        input.numel() != 0 or multiply_integers(input.size()[1:]),
5561
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
5562
    )
5563
    full_output_size = upsample_common_check(
5564
        input.size(), output_size, num_spatial_dims=2
5565
    )
5566
    output = input.new_empty(full_output_size)
5567

5568
    # convert output to correct memory format, if necessary
5569
    memory_format = utils.suggest_memory_format(input)
5570

5571
    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
5572
    _, n_channels, _, _ = input.shape
5573
    if input.device.type == "cuda" and n_channels < 4:
5574
        memory_format = torch.contiguous_format
5575

5576
    output = output.contiguous(memory_format=memory_format)
5577

5578
    return output
5579

5580

5581
@register_meta(
5582
    [
5583
        aten.upsample_nearest2d_backward.default,
5584
        aten._upsample_nearest_exact2d_backward.default,
5585
    ]
5586
)
5587
def upsample_nearest2d_backward(
5588
    grad_output: Tensor,
5589
    output_size: Sequence[Union[int, torch.SymInt]],
5590
    input_size: Sequence[Union[int, torch.SymInt]],
5591
    scales_h: Optional[float] = None,
5592
    scales_w: Optional[float] = None,
5593
):
5594
    full_output_size = upsample_common_check(
5595
        input_size, output_size, num_spatial_dims=2
5596
    )
5597
    torch._check(
5598
        grad_output.ndim == 4,
5599
        lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
5600
    )
5601
    for i in range(4):
5602
        torch._check(
5603
            grad_output.size(i) == full_output_size[i],
5604
            lambda: (
5605
                f"Expected grad_output to have the same shape as output;"
5606
                f" output.size({i}) = {full_output_size[i]}"
5607
                f" but got grad_output.size({i}) = {grad_output.size(i)}"
5608
            ),
5609
        )
5610

5611
    return grad_output.new_empty(input_size).to(
5612
        memory_format=utils.suggest_memory_format(grad_output)
5613
    )  # type: ignore[call-overload]
5614

5615

5616
@register_meta(
5617
    [aten.upsample_nearest3d.default, aten._upsample_nearest_exact3d.default]
5618
)
5619
def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
5620
    torch._check(
5621
        input.numel() != 0 or multiply_integers(input.size()[1:]),
5622
        lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
5623
    )
5624
    full_output_size = upsample_common_check(
5625
        input.size(), output_size, num_spatial_dims=3
5626
    )
5627
    return input.new_empty(full_output_size).to(
5628
        memory_format=utils.suggest_memory_format(input)
5629
    )
5630

5631

5632
@register_meta(
5633
    [
5634
        aten.sort.default,
5635
        aten.sort.stable,
5636
        aten.sort.values,
5637
        aten.sort.values_stable,
5638
    ]
5639
)
5640
def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=None):
5641
    v, i = torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)
5642
    if values is not None and indices is not None:
5643
        assert isinstance(values, TensorLike)
5644
        assert isinstance(indices, TensorLike)
5645
        # Makes sure values and indices have the same strides. For cases where
5646
        # these have different shapes, like (5, 10, 5) and (0) in msort.
5647
        out_shape = v.shape
5648
        out_stride = v.stride()
5649
        values = _maybe_resize_out(values, out_shape)
5650
        indices = _maybe_resize_out(indices, out_shape)
5651
        values.as_strided_(out_shape, out_stride)
5652
        indices.as_strided_(out_shape, out_stride)
5653
        _safe_copy_out(copy_from=v, copy_to=values)  # type: ignore[arg-type]
5654
        _safe_copy_out(copy_from=i, copy_to=indices)  # type: ignore[arg-type]
5655
        return values, indices
5656
    return v, i
5657

5658

5659
@register_meta(aten.argsort.stable)
5660
def meta_argsort(self, *, stable, dim=-1, descending=False):
5661
    return meta_sort(self, stable=stable, dim=dim, descending=descending)[1]
5662

5663

5664
def rnn_cell_checkSizes(
5665
    input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
5666
):
5667
    torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
5668
    torch._check(
5669
        input_gates.shape == hidden_gates.shape,
5670
        lambda: f"{input_gates.shape} != {hidden_gates.shape}",
5671
    )
5672
    gates_size = input_gates.size(1)
5673
    if input_bias is not None:
5674
        torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
5675
        torch._check(
5676
            input_bias.numel() == gates_size,
5677
            lambda: f"{input_bias.numel()} != {gates_size}",
5678
        )
5679
        torch._check(
5680
            input_bias.shape == hidden_bias.shape,
5681
            lambda: f"{input_bias.shape} != {hidden_bias.shape}",
5682
        )
5683
    torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
5684
    expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
5685
    torch._check(
5686
        prev_hidden.numel() == expected_prev_hidden_numel,
5687
        lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
5688
    )
5689
    torch._check(
5690
        all(
5691
            x.device == input_gates.device
5692
            for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
5693
        ),
5694
        lambda: "expected all inputs to be same device",
5695
    )
5696

5697

5698
@register_meta(aten._thnn_fused_lstm_cell.default)
5699
def _thnn_fused_lstm_cell_meta(
5700
    input_gates, hidden_gates, cx, input_bias=None, hidden_bias=None
5701
):
5702
    rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
5703
    workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format)
5704
    hy = torch.empty_like(cx, memory_format=torch.contiguous_format)
5705
    cy = torch.empty_like(cx, memory_format=torch.contiguous_format)
5706
    return (hy, cy, workspace)
5707

5708

5709
@register_meta(aten._cudnn_rnn.default)
5710
def _cudnn_rnn(
5711
    input,
5712
    weight,
5713
    weight_stride0,
5714
    weight_buf,
5715
    hx,
5716
    cx,
5717
    mode,
5718
    hidden_size,
5719
    proj_size,
5720
    num_layers,
5721
    batch_first,
5722
    dropout,
5723
    train,
5724
    bidirectional,
5725
    batch_sizes,
5726
    dropout_state,
5727
):
5728
    is_input_packed = len(batch_sizes) != 0
5729
    if is_input_packed:
5730
        seq_length = len(batch_sizes)
5731
        mini_batch = batch_sizes[0]
5732
        batch_sizes_sum = input.shape[0]
5733
    else:
5734
        seq_length = input.shape[1] if batch_first else input.shape[0]
5735
        mini_batch = input.shape[0] if batch_first else input.shape[1]
5736
        batch_sizes_sum = -1
5737

5738
    num_directions = 2 if bidirectional else 1
5739
    out_size = proj_size if proj_size != 0 else hidden_size
5740
    if is_input_packed:
5741
        out_shape = [batch_sizes_sum, out_size * num_directions]
5742
    else:
5743
        out_shape = (
5744
            [mini_batch, seq_length, out_size * num_directions]
5745
            if batch_first
5746
            else [seq_length, mini_batch, out_size * num_directions]
5747
        )
5748
    output = input.new_empty(out_shape)
5749

5750
    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
5751
    if cx is None:
5752
        cy = torch.empty(0, device=input.device)
5753
    else:
5754
        cy = cx.new_empty(cell_shape)
5755

5756
    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
5757

5758
    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
5759
    reserve_shape = 0 if train else 0
5760
    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
5761

5762
    return output, hy, cy, reserve, weight_buf
5763

5764

5765
@register_meta(aten.mkldnn_rnn_layer.default)
5766
def mkldnn_rnn_layer(
5767
    input,
5768
    w0,
5769
    w1,
5770
    w2,
5771
    w3,
5772
    hx_,
5773
    cx_,
5774
    reverse,
5775
    batch_sizes,
5776
    mode,
5777
    hidden_size,
5778
    num_layers,
5779
    has_biases,
5780
    bidirectional,
5781
    batch_first,
5782
    train,
5783
):
5784
    seq_length = input.shape[1] if batch_first else input.shape[0]
5785
    mini_batch = input.shape[0] if batch_first else input.shape[1]
5786
    output_chanels = hidden_size
5787
    out_shape = (
5788
        [mini_batch, seq_length, output_chanels]
5789
        if batch_first
5790
        else [seq_length, mini_batch, output_chanels]
5791
    )
5792
    output = input.new_empty(out_shape)
5793
    if hx_ is None:
5794
        hy = torch.empty(0, device=input.device)
5795
    else:
5796
        hy = hx_.new_empty(hx_.shape)
5797
    if cx_ is None:
5798
        cy = torch.empty(0, device=input.device)
5799
    else:
5800
        cy = cx_.new_empty(cx_.shape)
5801
    workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
5802
    return output, hy, cy, workspace
5803

5804

5805
def zero_numel_check_dims(self, dim, fn_name):
5806
    if self.ndim == 0:
5807
        torch._check_index(
5808
            dim == 0 or dim == -1,
5809
            lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
5810
        )
5811
    else:
5812
        torch._check_index(
5813
            self.size(dim) != 0,
5814
            lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
5815
        )
5816

5817

5818
# From aten/src/ATen/native/ReduceOps.cpp
5819
def check_argmax_argmin(name, self, dim):
5820
    if dim is not None:
5821
        dim = maybe_wrap_dim(dim, self.dim())
5822
        zero_numel_check_dims(self, dim, name)
5823
    else:
5824
        torch._check(
5825
            self.numel() != 0,
5826
            lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
5827
        )
5828

5829

5830
@register_meta([aten.argmax.default, aten.argmin.default])
5831
def argmax_argmin_meta(self, dim=None, keepdim=False):
5832
    check_argmax_argmin("argmax", self, dim)
5833
    dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None)
5834
    shape = _compute_reduction_shape(self, dims, keepdim)
5835
    return self.new_empty(shape, dtype=torch.int64)
5836

5837

5838
@register_meta(aten.scalar_tensor.default)
5839
def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
5840
    return torch.empty(
5841
        (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
5842
    )
5843

5844

5845
@register_meta(aten.topk.default)
5846
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
5847
    # From aten/src/ATen/native/Sorting.cpp
5848
    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
5849
    torch._check(
5850
        k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1),
5851
        lambda: "selected index k out of range",
5852
    )
5853
    sliceSize = 1 if self.dim() == 0 else self.size(dim)
5854
    torch._check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")
5855

5856
    topKSize = list(self.shape)
5857
    if len(topKSize) > 0:
5858
        topKSize[dim] = k
5859
    return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
5860

5861

5862
legacy_contiguous_memory_format = torch.contiguous_format
5863

5864

5865
# From aten/src/ATen/native/cuda/RNN.cu
5866
def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
5867
    defined_grad = grad_hy if grad_hy is not None else grad_cy
5868
    torch._check(defined_grad.dim() == 2, lambda: "")
5869
    exp_size = defined_grad.size()
5870
    if grad_hy is not None:
5871
        torch._check(grad_hy.size() == exp_size, lambda: "")
5872
    if grad_cy is not None:
5873
        torch._check(grad_cy.size() == exp_size, lambda: "")
5874
    torch._check(cx.size() == exp_size, lambda: "")
5875
    torch._check(cy.size() == exp_size, lambda: "")
5876
    torch._check(workspace.dim() == 2, lambda: "")
5877
    torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
5878

5879

5880
# From aten/src/ATen/native/cuda/RNN.cu
5881
@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
5882
def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
5883
    if grad_hy is None and grad_cy is None:
5884
        return None, None, None
5885
    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
5886
    grad_gates = torch.empty_like(
5887
        workspace, memory_format=legacy_contiguous_memory_format
5888
    )
5889
    grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
5890
    grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
5891
    return grad_gates, grad_cx, grad_bias
5892

5893

5894
# From aten/src/ATen/native/mps/operations/Linear.mm
5895
@register_meta(aten.linear_backward.default)
5896
def linear_backward(input_, grad_output_, weight_, output_mask):
5897
    grad_input = None
5898
    grad_weight = None
5899
    grad_bias = None
5900
    if output_mask[0]:
5901
        grad_input = grad_output_.new_empty(input_.size())
5902
    if output_mask[1] or output_mask[2]:
5903
        grad_weight = grad_output_.new_empty((grad_output_.size(-1), input_.size(-1)))
5904
        grad_bias = grad_output_.new_empty(grad_output_.size(-1))
5905
    return (grad_input, grad_weight, grad_bias)
5906

5907

5908
@register_meta(aten.pixel_shuffle.default)
5909
def meta_pixel_shuffle(self, upscale_factor):
5910
    assert (
5911
        len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
5912
    ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"
5913

5914
    def is_channels_last(ten):
5915
        return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
5916

5917
    def pick_memory_format():
5918
        if is_channels_last(self):
5919
            if device_hint(self) == "cuda":
5920
                return torch.contiguous_format
5921
            else:
5922
                return torch.channels_last
5923
        elif self.is_contiguous(memory_format=torch.contiguous_format):
5924
            return torch.contiguous_format
5925
        elif self.is_contiguous(memory_format=torch.preserve_format):
5926
            return torch.preserve_format
5927

5928
    C = self.shape[-3] // (upscale_factor * upscale_factor)
5929
    Hr = self.shape[-2] * upscale_factor
5930
    Wr = self.shape[-1] * upscale_factor
5931
    out_shape = (*self.shape[:-3], C, Hr, Wr)
5932

5933
    out = self.new_empty(out_shape)
5934
    out = out.to(memory_format=pick_memory_format())  # type: ignore[call-overload]
5935
    return out
5936

5937

5938
@register_meta(aten.mkldnn_rnn_layer_backward.default)
5939
def mkldnn_rnn_layer_backward(
5940
    input,
5941
    weight0,
5942
    weight1,
5943
    weight2,
5944
    weight3,
5945
    hx_,
5946
    cx_tmp,
5947
    output,
5948
    hy_,
5949
    cy_,
5950
    grad_output_r_opt,
5951
    grad_hy_r_opt,
5952
    grad_cy_r_opt,
5953
    reverse,
5954
    mode,
5955
    hidden_size,
5956
    num_layers,
5957
    has_biases,
5958
    train,
5959
    bidirectional,
5960
    batch_sizes,
5961
    batch_first,
5962
    workspace,
5963
):
5964
    diff_x = input.new_empty(input.shape)
5965
    diff_hx = hx_.new_empty(hx_.shape)
5966
    diff_cx = cx_tmp.new_empty(cx_tmp.shape)
5967
    diff_w1 = weight0.new_empty(weight0.shape)
5968
    diff_w2 = weight1.new_empty(weight1.shape)
5969
    diff_b = weight2.new_empty(weight2.shape)
5970
    return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx
5971

5972

5973
@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
5974
@out_wrapper()
5975
def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
5976
    return torch.empty_like(
5977
        self, dtype=torch.int32 if out_int32 else torch.int64
5978
    ).contiguous()
5979

5980

5981
@register_meta(
5982
    [aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default]
5983
)
5984
def meta_upsample_bimode2d_aa(
5985
    input, output_size, align_corners, scales_h=None, scales_w=None
5986
):
5987
    full_output_size = upsample_common_check(
5988
        input.size(), output_size, num_spatial_dims=2
5989
    )
5990
    torch._check(
5991
        input.numel() != 0 or all(size > 0 for size in input.size()[1:]),
5992
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
5993
    )
5994
    return input.new_empty(full_output_size).to(
5995
        memory_format=utils.suggest_memory_format(input)
5996
    )
5997

5998

5999
# From aten/src/ATen/native/cuda/AmpKernels.cu
6000
@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
6001
def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
6002
    torch._check(
6003
        found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor."
6004
    )
6005
    torch._check(
6006
        inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor."
6007
    )
6008
    torch._check(
6009
        found_inf.dtype.is_floating_point,
6010
        lambda: "found_inf must be a float tensor.",
6011
    )
6012
    torch._check(
6013
        inv_scale.dtype.is_floating_point,
6014
        lambda: "inv_scale must be a float tensor.",
6015
    )
6016

6017

6018
# From aten/src/ATen/native/UnaryOps.cpp
6019
@register_meta([aten.nan_to_num.default, aten.nan_to_num.out])
6020
@out_wrapper()
6021
def nan_to_num(self, nan=None, posinf=None, neginf=None):
6022
    result_size = list(self.size())
6023
    return self.new_empty(result_size)
6024

6025

6026
@register_meta(torch.ops.aten.transpose_)
6027
def transpose_(self, dim0, dim1):
6028
    assert self.layout not in {
6029
        torch.sparse_csr,
6030
        torch.sparse_csc,
6031
        torch.sparse_bsr,
6032
        torch.sparse_bsc,
6033
    }, f"torch.transpose_: in-place transposition is not supported for {self.layout} layout"
6034

6035
    ndims = self.ndim
6036

6037
    dim0 = maybe_wrap_dim(dim0, ndims)
6038
    dim1 = maybe_wrap_dim(dim1, ndims)
6039

6040
    if dim0 == dim1:
6041
        return self
6042

6043
    size = list(self.size())
6044
    stride = list(self.stride())
6045

6046
    stride[dim0], stride[dim1] = stride[dim1], stride[dim0]
6047
    size[dim0], size[dim1] = size[dim1], size[dim0]
6048

6049
    self.as_strided_(size, stride)
6050
    return self
6051

6052

6053
@register_meta(torch.ops.aten.t_)
6054
def t_(self):
6055
    ndims = self.ndim
6056

6057
    if self.is_sparse:
6058
        sparse_dim = self.sparse_dim()
6059
        dense_dim = self.dense_dim()
6060
        assert (
6061
            sparse_dim <= 2 and dense_dim == 0
6062
        ), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions"  # noqa: B950
6063
    else:
6064
        assert (
6065
            self.dim() <= 2
6066
        ), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D"
6067

6068
    return transpose_(self, 0, 0 if ndims < 2 else 1)
6069

6070

6071
@register_meta(aten.searchsorted)
6072
@out_wrapper()
6073
def meta_searchsorted(
6074
    sorted_sequence, self, *, out_int32=False, right=False, side=None, sorter=None
6075
):
6076
    dtype = torch.int32 if out_int32 else torch.int64
6077
    if isinstance(self, torch.Tensor):
6078
        return torch.empty_like(self, dtype=dtype).contiguous()
6079
    else:  # Scalar
6080
        return torch.empty((), dtype=dtype, device=sorted_sequence.device)
6081

6082

6083
def _check_for_unsupported_isin_dtype(dtype):
6084
    torch._check(
6085
        dtype not in [torch.bool, torch.bfloat16, torch.complex128, torch.complex64],
6086
        lambda: f"Unsupported input type encountered for isin(): {dtype}",
6087
    )
6088

6089

6090
@register_meta(aten.isin)
6091
@out_wrapper()
6092
def meta_isin(elements, test_elements, *, assume_unique=False, invert=False):
6093
    torch._check(
6094
        isinstance(elements, Tensor) or isinstance(test_elements, Tensor),
6095
        lambda: "At least one of elements and test_elements must be a Tensor.",
6096
    )
6097
    if not isinstance(elements, Tensor):
6098
        elements = torch.tensor(elements, device=test_elements.device)
6099

6100
    if not isinstance(test_elements, Tensor):
6101
        test_elements = torch.tensor(test_elements, device=elements.device)
6102

6103
    _check_for_unsupported_isin_dtype(elements.dtype)
6104
    _check_for_unsupported_isin_dtype(test_elements.dtype)
6105
    return torch.empty_like(elements, dtype=torch.bool)
6106

6107

6108
@register_meta(aten.polygamma)
6109
@out_wrapper()
6110
def meta_polygamma(n: int, self: Tensor) -> Tensor:
6111
    torch._check(n >= 0, lambda: "polygamma(n, x) does not support negative n.")
6112
    _, result_dtype = elementwise_dtypes(
6113
        self,
6114
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
6115
    )
6116
    return torch.empty_like(self, dtype=result_dtype)
6117

6118

6119
def _create_unary_float_meta_func(func):
6120
    @register_meta(func)
6121
    @out_wrapper()
6122
    def _f(x):
6123
        return elementwise_meta(
6124
            x, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
6125
        )
6126

6127
    return _f
6128

6129

6130
def _create_binary_float_meta_func(func):
6131
    @register_meta(func)
6132
    @out_wrapper()
6133
    def _f(x, y):
6134
        return elementwise_meta(
6135
            x, y, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
6136
        )
6137

6138
    return _f
6139

6140

6141
_create_unary_float_meta_func(aten.special_airy_ai)
6142
_create_unary_float_meta_func(aten.special_bessel_y0)
6143
_create_unary_float_meta_func(aten.special_bessel_y1)
6144
_create_unary_float_meta_func(aten.special_modified_bessel_i0)
6145
_create_unary_float_meta_func(aten.special_modified_bessel_i1)
6146
_create_unary_float_meta_func(aten.special_modified_bessel_k0)
6147
_create_unary_float_meta_func(aten.special_modified_bessel_k1)
6148
_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k0)
6149
_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k1)
6150

6151

6152
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_t)
6153
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_u)
6154
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_v)
6155
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_w)
6156
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_t)
6157
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_u)
6158
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_v)
6159
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_w)
6160
_create_binary_float_meta_func(aten.special_hermite_polynomial_h)
6161
_create_binary_float_meta_func(aten.special_hermite_polynomial_he)
6162
_create_binary_float_meta_func(aten.special_laguerre_polynomial_l)
6163
_create_binary_float_meta_func(aten.special_legendre_polynomial_p)
6164

6165

6166
# We must also trigger meta registrations from PrimTorch ref
6167
# decompositions
6168
import torch._refs
6169
import torch._refs.nn.functional
6170
import torch._refs.special
6171

6172

6173
def activate_meta():
6174
    activate_meta_table = {}
6175

6176
    # For a given op, we pick the most specific decomp function from
6177
    # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd
6178
    for type in ["meta", "post_autograd", "pre_autograd"]:
6179
        registry = global_decomposition_table[type]
6180

6181
        for opo in registry:
6182
            if opo not in activate_meta_table:
6183
                activate_meta_table[opo] = registry[opo]
6184

6185
    for op_overload, fn in activate_meta_table.items():
6186
        # Don't register meta for HigherOrderOp's decomp.
6187
        # We can reconsider this in the future, but in general,
6188
        # the way you do a meta for a HigherOrderOp is different from
6189
        # OpOverload.
6190
        if isinstance(op_overload, torch._ops.HigherOrderOperator):
6191
            continue
6192
        assert isinstance(op_overload, OpOverload)
6193

6194
        op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)
6195

6196
        if torch._C._dispatch_has_kernel_for_dispatch_key(
6197
            op_overload.name(), "CompositeImplicitAutograd"
6198
        ):
6199
            # Internally, we shouldn't be registering meta kernels for any operators that
6200
            # have CompositeImplicitAutograd kernels.
6201
            # Instead, we should be letting those decompositions run, and writing meta kernels
6202
            # only for the base operators.
6203
            if op_overload in global_decomposition_table["meta"]:
6204
                raise RuntimeError(
6205
                    f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
6206
                    "register meta function for it. Instead, we should let the decomposition run and write "
6207
                    "meta kernels for the base operators."
6208
                )
6209
            pass
6210
        elif op_overload.is_view:
6211
            # Attempting to register a python meta kernel for a view operator.
6212
            # We shouldn't do this, because the output will report as not having aliased storages.
6213
            # All view ops have meta kernels in C++ today, so we should use those instead.
6214
            pass
6215
        elif op_overload.name() in {
6216
            "aten::empty_strided",  # causing infinite recursion, test_meta.py
6217
            "aten::clone",  # causing infinite recursion
6218
            "aten::_to_copy",  # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite  # noqa: B950
6219
            "aten::copy_",  # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64  # noqa: B950
6220
            "aten::constant_pad_nd",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32  # noqa: B950
6221
            "aten::rot90",  # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32  # noqa: B950
6222
            "aten::as_strided_scatter",  # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32  # noqa: B950
6223
        }:
6224
            pass
6225
        else:
6226
            if "mkldnn::" in op_overload.name():
6227
                _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
6228
            elif "mkl::" in op_overload.name():
6229
                _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
6230
            elif "onednn::" in op_overload.name():
6231
                _meta_lib_dont_use_me_use_register_meta_for_onednn.impl(op_overload, fn)
6232
            elif "quantized::" in op_overload.name():
6233
                _meta_lib_dont_use_me_use_register_meta_for_quantized.impl(
6234
                    op_overload, fn
6235
                )
6236
            else:
6237
                _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
6238

6239

6240
activate_meta()
6241

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.