5
from typing import List, Optional, Sequence, Tuple, Union
8
import torch._prims_common as utils
9
from torch import SymBool, SymFloat, Tensor
10
from torch._decomp import (
13
global_decomposition_table,
16
from torch._ops import OpOverload
17
from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
18
from torch._prims_common import (
19
corresponding_complex_dtype,
20
corresponding_real_dtype,
22
ELEMENTWISE_TYPE_PROMOTION_KIND,
24
make_contiguous_strides_for,
28
from torch._prims_common.wrappers import (
29
_maybe_convert_to_dtype,
35
from torch._refs import _broadcast_shapes, _maybe_broadcast
36
from torch.utils import _pytree as pytree
41
_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
46
fn = _convert_out_params(fn)
49
_add_op_to_registry(meta_table, op, fn)
51
pytree.tree_map_(register, op)
59
type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND,
62
_, result_dtype = utils.elementwise_dtypes(
64
type_promotion_kind=type_promotion,
66
args = [_maybe_convert_to_dtype(x, result_dtype) for x in args]
69
args = _maybe_broadcast(*args)
72
return _prim_elementwise_meta(
73
*args, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
77
def toRealValueType(dtype):
    """Map a complex dtype to its corresponding real dtype.

    Non-complex dtypes are returned unchanged. Used by meta functions whose
    outputs are real-valued views of complex inputs (e.g. abs/angle-style ops).
    """
    from_complex = {
        torch.complex32: torch.half,
        torch.cfloat: torch.float,
        torch.cdouble: torch.double,
    }
    # Fall back to the input dtype when it is not complex.
    return from_complex.get(dtype, dtype)
86
def check_inplace_broadcast(self_shape, *args_shape):
87
broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape))
89
broadcasted_shape == self_shape,
90
lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}",
94
@register_meta([aten.linspace, aten.logspace])
96
def meta_linspace_logspace(
103
layout=torch.strided,
107
if isinstance(start, torch.Tensor):
110
lambda: "linspace only supports 0-dimensional start and end tensors",
112
if isinstance(end, torch.Tensor):
115
lambda: "linspace only supports 0-dimensional start and end tensors",
118
if any(isinstance(arg, complex) for arg in (start, end, steps)):
119
default_complex_dtype = utils.corresponding_complex_dtype(
120
torch.get_default_dtype()
123
dtype = default_complex_dtype
126
utils.is_complex_dtype(dtype),
127
lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
130
dtype = dtype or torch.get_default_dtype()
131
assert isinstance(dtype, torch.dtype)
135
isinstance(steps, IntLike),
136
lambda: f"received an invalid combination of arguments - got \
137
({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})",
139
assert isinstance(steps, IntLike)
140
torch._check(steps >= 0, lambda: "number of steps must be non-negative")
147
pin_memory=pin_memory,
148
requires_grad=requires_grad,
152
@register_meta([aten.take.default, aten.take.out])
154
def meta_take(self, index):
157
index.dtype == torch.long,
158
lambda: f"take(): Expected a long tensor for index, but got {index.dtype}",
162
not (self.numel() == 0 and index.numel() != 0),
163
lambda: "take(): tried to take from an empty tensor",
165
return self.new_empty(index.shape)
168
@register_meta([aten.linalg_cross.default, aten.linalg_cross.out])
170
def linalg_cross(self, other, *, dim=-1):
175
lambda: "linalg.cross: inputs must have the same number of dimensions.",
178
self.size(dim) == 3 and other.size(dim) == 3,
180
f"linalg.cross: inputs dimension {dim} must have length 3. "
181
f"Got {self.size(dim)} and {other.size(dim)}"
184
out_shape = _broadcast_shapes(self.shape, other.shape)
185
return self.new_empty(out_shape)
188
@register_meta(aten.linalg_matrix_exp)
190
def linalg_matrix_exp(self):
191
squareCheckInputs(self, "linalg.matrix_exp")
192
checkFloatingOrComplex(self, "linalg.matrix_exp")
193
return torch.empty_like(self, memory_format=torch.contiguous_format)
197
[aten.cummax.default, aten.cummax.out, aten.cummin.default, aten.cummin.out]
199
@out_wrapper("values", "indices")
200
def cummaxmin(self, dim):
201
values = torch.empty(self.shape, device=self.device, dtype=self.dtype)
202
indices = torch.empty(self.shape, device=self.device, dtype=torch.int64)
203
if self.numel() != 0 and self.ndim != 0:
205
maybe_wrap_dim(dim, self.ndim)
206
return values, indices
209
@register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out])
211
def logcumsumexp(self, dim):
213
maybe_wrap_dim(dim, self.ndim)
214
return torch.empty_like(self).contiguous()
218
def _exec_fft(out, self, out_sizes, dim, forward):
220
signal_ndim = len(dim)
221
batch_dims = ndim - signal_ndim
224
dim_permute = list(range(ndim))
226
is_transformed_dim = [False for _ in range(ndim)]
228
is_transformed_dim[d] = True
232
for d in dim_permute:
233
if not is_transformed_dim[d]:
237
dim_permute = left + right
238
batch_end = len(left)
240
self_strides = self.stride()
241
tmp = dim_permute[:batch_end]
242
tmp.sort(key=lambda x: self_strides[x], reverse=True)
243
dim_permute = tmp + dim_permute[batch_end:]
244
input = self.permute(dim_permute)
247
batched_sizes = [-1] + list(input.shape[batch_dims:])
248
input = input.reshape(batched_sizes)
250
batch_size = input.size(0)
251
batched_sizes[0] = batch_size
252
batched_out_sizes = batched_sizes
253
for i in range(len(dim)):
254
batched_out_sizes[i + 1] = out_sizes[dim[i]]
255
out = out.reshape(batched_out_sizes)
258
out_strides = [0 for _ in range(ndim)]
262
out_strides[dim_permute[i]] = batch_numel * out.stride(0)
263
batch_numel *= out_sizes[dim_permute[i]]
265
for i in range(batch_dims, ndim):
266
out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
267
return out.as_strided(out_sizes, out_strides, out.storage_offset())
272
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
274
def meta_fft_c2c(self, dim, normalization, forward):
275
assert self.dtype.is_complex
277
out_sizes = self.shape
278
output = self.new_empty(out_sizes)
284
self_strides = self.stride()
285
sorted_dims.sort(key=lambda x: self_strides[x], reverse=True)
286
output = _exec_fft(output, self, out_sizes, sorted_dims, forward)
291
@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
293
def meta_fft_r2c(self, dim, normalization, onesided):
294
assert self.dtype.is_floating_point
295
output_sizes = list(self.size())
299
last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
300
output_sizes[last_dim] = last_dim_halfsize
302
return self.new_empty(
303
output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
307
@register_meta(aten.randperm.generator_out)
308
def meta_randperm(n, *, generator=None, out):
309
return _maybe_resize_out(out, torch.Size([n]))
312
@register_meta(aten.randperm.default)
313
def meta_randperm_default(
322
n, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
326
@register_meta([aten.randint.default, aten.randint.out])
338
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
342
@register_meta([aten.randint.low, aten.randint.low_out])
355
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
359
@register_meta([aten.rand.default, aten.rand.out])
361
def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
363
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
367
@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
369
def meta_fft_c2r(self, dim, normalization, lastdim):
370
assert self.dtype.is_complex
371
output_sizes = list(self.size())
372
output_sizes[dim[-1]] = lastdim
373
return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))
376
@register_meta(aten.copy_.default)
377
def meta_copy_(self, src, non_blocking=False):
382
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
388
not free_unbacked_symbols(self) and torch._debug_has_internal_overlap(self) == 1
391
"more than one element of the written-to tensor refers to a single memory location"
394
if isinstance(src, Tensor):
395
intermediate = src.to(self, non_blocking)
396
if self.size() != intermediate.size():
397
aten.expand_copy.default(intermediate, self.size())
401
def inferUnsqueezeGeometry(tensor, dim):
    """Compute the (sizes, strides) of ``tensor`` after inserting a size-1 dim at ``dim``.

    Mirrors the eager unsqueeze geometry: the new dimension's stride is chosen
    so the result stays a valid view — 1 when appending at the end, otherwise
    the product of the displaced dimension's size and stride.
    """
    result_sizes = list(tensor.size())
    result_strides = list(tensor.stride())
    # Appending past the last dim gets stride 1; otherwise keep view validity.
    new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
    result_sizes.insert(dim, 1)
    result_strides.insert(dim, new_stride)
    return result_sizes, result_strides
410
@register_meta(aten.unsqueeze_.default)
411
def meta_unsqueeze_(self, dim):
412
dim = maybe_wrap_dim(dim, self.dim() + 1)
413
g_sizes, g_strides = inferUnsqueezeGeometry(self, dim)
414
self.as_strided_(g_sizes, g_strides)
418
@register_meta(aten._sparse_semi_structured_linear)
419
def meta_sparse_structured_linear(
423
bias: Optional[Tensor] = None,
424
_activation_opt: Optional[str] = None,
425
out_dtype: Optional[torch.dtype] = None,
427
output_sizes = list(input.shape)
429
assert weight.size(0) == bias.size(0), "output size mismatch"
430
assert weight.size(1) == input.size(-1) / 2
431
output_sizes[-1] = weight.size(0)
437
assert len(input.shape) == 2, "we can only handle the squashed input case"
438
transposed_strides = (1, input.size(0))
440
if out_dtype is not None:
442
input.dtype == torch.int8 and out_dtype == torch.int32
443
), "out_dtype is only supported for i8i8->i32 linear operator"
444
output = input.new_empty(
446
dtype=input.dtype if out_dtype is None else out_dtype,
447
).as_strided(output_sizes, transposed_strides)
452
@register_meta(aten._sparse_semi_structured_mm)
453
def meta_sparse_structured_mm(
457
out_dtype: Optional[torch.dtype] = None,
459
assert len(mat1.shape) == 2
460
assert len(mat1_meta.shape) == 2
461
assert len(mat2.shape) == 2
462
assert mat1.size(1) == mat2.size(0) / 2
463
output_sizes = [mat1.size(0), mat2.size(1)]
465
if out_dtype is not None:
467
mat2.dtype == torch.int8 and out_dtype == torch.int32
468
), "out_dtype is only supported for i8i8->i32 linear operator"
469
output = mat2.new_empty(
471
dtype=mat2.dtype if out_dtype is None else out_dtype,
477
@register_meta(aten._sparse_semi_structured_addmm)
478
def meta_sparse_structured_addmm(
486
out_dtype: Optional[torch.dtype] = None,
489
len(input.shape) == 1
490
), "only input broadcasted to columns of mat1 * mat2 product is supported"
491
assert len(mat1.shape) == 2
492
assert len(mat1_meta.shape) == 2
493
assert len(mat2.shape) == 2
494
assert input.size(0) == mat1.size(
496
), "only input broadcasted to columns of mat1 * mat2 product is supported"
497
assert mat1.size(1) == mat2.size(0) / 2
498
output_sizes = [mat1.size(0), mat2.size(1)]
500
if out_dtype is not None:
502
mat2.dtype == torch.int8 and out_dtype == torch.int32
503
), "out_dtype is only supported for i8i8->i32 linear operator"
504
output = mat2.new_empty(
506
dtype=mat2.dtype if out_dtype is None else out_dtype,
512
@register_meta(aten._cslt_sparse_mm)
513
def meta__cslt_sparse_mm(
514
compressed_A: torch.Tensor,
515
dense_B: torch.Tensor,
516
bias: Optional[Tensor] = None,
517
alpha: Optional[Tensor] = None,
518
out_dtype: Optional[torch.dtype] = None,
519
transpose_result: bool = False,
521
assert dense_B.dtype in {
526
}, "_cslt_sparse_mm only supports fp16, bf16, and int8"
527
assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
528
assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"
530
is_int8_input_type = compressed_A.dtype == torch.int8
531
compression_factor = 10 if is_int8_input_type else 9
534
m = (compressed_A.numel() * 16) // (compression_factor * k)
536
assert m == bias.size(0)
538
if out_dtype is not None:
539
assert is_int8_input_type and out_dtype in {
543
}, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul"
544
output_shape = (n, m) if transpose_result else (m, n)
545
result = dense_B.new_empty(output_shape, dtype=out_dtype)
549
@register_meta(aten.index_reduce.default)
550
def meta_index_reduce(
554
source: torch.Tensor,
557
include_self: bool = True,
559
return torch.empty_like(self, memory_format=torch.contiguous_format)
562
@register_meta(aten.index_reduce_.default)
563
def meta_index_reduce_(
567
source: torch.Tensor,
570
include_self: bool = True,
577
@register_meta(aten.index_select.default)
578
def meta_index_select(self, dim, index):
579
result_size = list(self.size())
581
result_size[dim] = index.numel()
582
return self.new_empty(result_size)
585
@register_meta(aten.segment_reduce.default)
586
def meta_segment_reduce(
590
lengths: Optional[Tensor] = None,
591
indices: Optional[Tensor] = None,
592
offsets: Optional[Tensor] = None,
594
unsafe: bool = False,
597
if indices is not None:
598
raise NotImplementedError(
599
"segment_reduce(): indices based reduction is not supported yet."
602
def segment_reduce_lengths_tensor(lengths_shape):
604
lengths_shape + data.shape[axis + 1 :],
607
memory_format=torch.contiguous_format,
610
if lengths is not None:
611
return segment_reduce_lengths_tensor(lengths.shape)
614
if offsets is not None:
616
lengths_shape = offsets.shape[:-1] + (offsets.shape[-1] - 1,)
617
return segment_reduce_lengths_tensor(lengths_shape)
618
raise RuntimeError("segment_reduce(): Either lengths or offsets must be defined.")
621
@register_meta([aten.max.default, aten.max.unary_out])
624
return self.new_empty(())
627
@register_meta(aten.max.dim)
628
def meta_max_dim(self, dim, keepdim=False):
629
dim = utils.reduction_dims(self.shape, (dim,))
630
output_shape = _compute_reduction_shape(self, dim, keepdim)
632
self.new_empty(output_shape),
633
self.new_empty(output_shape, dtype=torch.long),
637
@register_meta([aten.min.default, aten.min.unary_out])
640
return self.new_empty(())
643
@register_meta(aten.min.dim)
644
def meta_min_dim(self, dim, keepdim=False):
645
dim = utils.reduction_dims(self.shape, (dim,))
646
output_shape = _compute_reduction_shape(self, dim, keepdim)
648
self.new_empty(output_shape),
649
self.new_empty(output_shape, dtype=torch.long),
653
@register_meta(aten.angle.default)
655
if self.is_complex():
656
result_dtype = corresponding_real_dtype(self.dtype)
658
_, result_dtype = elementwise_dtypes(
660
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
662
return torch.empty_like(self, dtype=result_dtype)
665
@register_meta(aten.angle.out)
666
def meta_angle_out(self, out):
667
torch._resize_output_(out, self.size(), self.device)
668
return out.copy_(torch.angle(self))
671
@register_meta(aten._assert_async.default)
672
def assert_async(val):
676
@register_meta(aten._assert_async.msg)
677
def assert_async_meta(val, assert_msg):
681
@register_meta(aten._print.default)
686
@register_meta(aten._make_dep_token.default)
695
return torch.empty(0, device="meta")
698
@register_meta(aten.sym_constrain_range.default)
699
def sym_constrain_range(size, min=None, max=None):
701
from torch.fx.experimental.symbolic_shapes import constrain_range
703
if isinstance(size, (SymFloat, SymBool)):
704
raise ValueError("Constraining SymFloat or Symbool is nyi")
705
constrain_range(size, min=min, max=max)
708
@register_meta(aten._functional_sym_constrain_range.default)
709
def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):
710
aten.sym_constrain_range(size, min=min, max=max)
714
@register_meta(aten.sym_constrain_range_for_size.default)
715
def sym_constrain_range_for_size(size, min=None, max=None):
717
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
719
if isinstance(size, (SymFloat, SymBool)):
720
raise ValueError("Constraining SymFloat or Symbool is nyi")
721
_constrain_range_for_size(size, min=min, max=max)
724
@register_meta(aten._functional_sym_constrain_range_for_size.default)
725
def functional_sym_constrain_range_for_size(size, min, max, dep_token):
726
aten.sym_constrain_range_for_size(size, min=min, max=max)
730
@register_meta(aten._functional_assert_async.msg)
731
def functional_assert_async_meta(val, assert_msg, dep_token):
736
def squareCheckInputs(self: Tensor, f_name: str):
    """Assert that ``self`` is a (batch of) square matrix: ndim >= 2 and equal last two dims."""
    assert (
        self.dim() >= 2
    ), f"{f_name}: The input tensor must have at least 2 dimensions."
    assert (
        self.size(-1) == self.size(-2)
    ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
748
def linearSolveCheckInputs(self: Tensor, A: Tensor, name: str):
750
self.device == A.device,
752
f"Expected b and A to be on the same device, but found b on "
753
f"{self.device} and A on {A.device} instead."
758
self.dtype == A.dtype,
760
f"Expected b and A to have the same dtype, but found b of type "
761
f"{self.dtype} and A of type {A.dtype} instead."
766
A.size(-1) == A.size(-2),
768
f"A must be batches of square matrices, "
769
f"but they are {A.size(-2)} by {A.size(-1)} matrices"
774
A.size(-1) == self.size(-2),
776
f"Incompatible matrix sizes for {name}: each A "
777
f"matrix is {A.size(-1)} by {A.size(-1)}"
778
f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
784
def checkFloatingOrComplex(
787
allow_low_precision_dtypes: bool = True,
791
t.is_floating_point() or t.is_complex(),
792
lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
794
if not allow_low_precision_dtypes:
796
dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
797
lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
802
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
805
lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
809
def checkInputsSolver(A: Tensor, B: Tensor, left: bool, f_name: str):
    """Validate A and B for solving AX = B (``left``) or XA = B (not ``left``).

    Requires A square, B at least 2-D, and the contracting dimensions of A and
    B to match. Raises via assert / ``torch._check`` on failure.
    """
    squareCheckInputs(A, f_name)
    checkIsMatrix(B, f_name)
    torch._check(
        # AX = B contracts A's rows with B's rows; XA = B contracts columns.
        A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
        lambda: (
            f"{f_name}: Incompatible shapes of A and B for the equation "
            f"{'AX = B' if left else 'XA = B'}"
            f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
        ),
    )
826
result_name: str = "result",
829
result.device == input.device,
831
f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got "
832
f"{result_name} on {result.device} and input on {input.device}"
837
def checkUplo(UPLO: str):
838
UPLO_uppercase = UPLO.upper()
840
len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"),
841
lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}",
845
@register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues])
846
@out_wrapper("eigenvalues", "eigenvectors")
847
def meta__linalg_eigh(A: Tensor, UPLO: str = "L", compute_v: bool = True):
848
squareCheckInputs(A, "linalg.eigh")
851
shape = list(A.shape)
853
vecs = A.new_empty(shape)
854
vecs.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))
856
vecs = A.new_empty([0])
859
vals = A.new_empty(shape, dtype=toRealValueType(A.dtype))
864
@register_meta([aten._linalg_eigvals.default, aten.linalg_eigvals.out])
866
def meta__linalg_eigvals(input: Tensor) -> Tensor:
867
squareCheckInputs(input, "linalg.eigvals")
870
if utils.is_complex_dtype(input.dtype)
871
else utils.corresponding_complex_dtype(input.dtype)
873
return input.new_empty(input.shape[:-1], dtype=complex_dtype)
876
@register_meta([aten.linalg_eig])
877
@out_wrapper("eigenvalues", "eigenvectors")
878
def meta_linalg_eig(input: Tensor):
879
squareCheckInputs(input, "linalg.eig")
882
if utils.is_complex_dtype(input.dtype)
883
else utils.corresponding_complex_dtype(input.dtype)
885
values = input.new_empty(input.shape[:-1], dtype=complex_dtype)
886
vectors = input.new_empty(input.shape, dtype=complex_dtype)
887
return values, vectors
890
def cloneBatchedColumnMajor(src: Tensor) -> Tensor:
    """Clone ``src`` into column-major (Fortran-order) batched-matrix layout.

    Equal in value to ``src`` but with the last two dims laid out column-major,
    as LAPACK-style backends expect: clone the transpose contiguously, then
    transpose the view back.
    """
    return src.mT.clone(memory_format=torch.contiguous_format).transpose(-2, -1)
894
@register_meta(aten._cholesky_solve_helper)
896
def _cholesky_solve_helper(self: Tensor, A: Tensor, upper: bool) -> Tensor:
897
return cloneBatchedColumnMajor(self)
900
@register_meta(aten.cholesky_solve)
902
def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor:
905
lambda: f"b should have at least 2 dimensions, but has {self.ndim} dimensions instead",
909
lambda: f"u should have at least 2 dimensions, but has {A.ndim} dimensions instead",
911
self_broadcasted, A_broadcasted = _linalg_broadcast_batch_dims_name(
912
self, A, "cholesky_solve"
914
return _cholesky_solve_helper(self_broadcasted, A_broadcasted, upper)
917
@register_meta(aten.cholesky)
919
def cholesky(self: Tensor, upper: bool = False) -> Tensor:
920
if self.numel() == 0:
921
return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
922
squareCheckInputs(self, "cholesky")
923
return cloneBatchedColumnMajor(self)
926
@register_meta(aten.cholesky_inverse)
928
def cholesky_inverse(self: Tensor, upper: bool = False) -> Tensor:
929
squareCheckInputs(self, "cholesky_inverse")
930
return cloneBatchedColumnMajor(self)
934
@register_meta(aten.linalg_cholesky_ex.default)
935
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
936
squareCheckInputs(A, "linalg.cholesky")
937
checkFloatingOrComplex(A, "linalg.cholesky")
943
L_strides = make_contiguous_strides_for(A_shape, False)
944
L = A.new_empty(A_shape)
945
L.as_strided_(A_shape, L_strides)
948
infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32)
953
[aten.linalg_householder_product.default, aten.linalg_householder_product.out]
956
def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor:
959
lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.",
962
input.size(-2) >= input.size(-1),
963
lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]",
966
input.size(-1) >= tau.size(-1),
967
lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]",
971
input.ndim - tau.ndim == 1,
973
f"torch.linalg.householder_product: Expected tau to have one dimension less than input, "
974
f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
978
expected_batch_tau_shape = input.shape[:-2]
979
actual_batch_tau_shape = tau.shape[:-1]
981
actual_batch_tau_shape == expected_batch_tau_shape,
983
f"torch.linalg.householder_product: Expected batch dimensions of tau to be "
984
f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
989
tau.dtype == input.dtype,
991
f"torch.linalg.householder_product: tau dtype {tau.dtype}"
992
f" does not match input dtype {input.dtype}"
995
checkSameDevice("torch.linalg.householder_product", tau, input, "tau")
997
return torch.empty_strided(
999
stride=make_contiguous_strides_for(input.shape, row_major=False),
1001
device=input.device,
1006
@register_meta(aten.linalg_inv_ex.default)
1007
def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
1008
squareCheckInputs(A, "linalg.inv_ex")
1009
checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)
1011
L = A.new_empty(A.shape)
1012
L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
1014
infos = A.new_empty(A.shape[:-2], dtype=torch.int32)
1018
@register_meta([aten.linalg_ldl_factor_ex.default, aten.linalg_ldl_factor_ex.out])
1019
@out_wrapper("LD", "pivots", "info")
1020
def linalg_ldl_factor_ex_meta(
1023
hermitian: bool = False,
1024
check_errors: bool = False,
1025
) -> Tuple[Tensor, Tensor, Tensor]:
1026
squareCheckInputs(self, "torch.linalg.ldl_factor_ex")
1027
checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex")
1028
LD = torch.empty_strided(
1030
stride=make_contiguous_strides_for(self.shape, row_major=False),
1034
pivots = self.new_empty(self.shape[:-1], dtype=torch.int)
1035
info = self.new_empty(self.shape[:-2], dtype=torch.int)
1036
return LD, pivots, info
1039
@register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out])
1041
def linalg_ldl_solve_meta(
1046
hermitian: bool = False,
1048
squareCheckInputs(LD, "torch.linalg.ldl_solve")
1049
checkFloatingOrComplex(LD, "torch.linalg.ldl_solve")
1050
linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve")
1054
f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
1055
f"but it has {B.ndim} dimensions instead"
1058
expected_pivots_shape = LD.shape[:-1]
1060
expected_pivots_shape == pivots.shape,
1062
f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
1063
f"but got pivots with shape {pivots.shape} instead"
1067
utils.is_integer_dtype(pivots.dtype),
1068
lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
1071
LD.dtype == B.dtype,
1072
lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
1074
B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LD)
1075
return torch.empty_strided(
1076
size=B_broadcast_size,
1077
stride=make_contiguous_strides_for(B_broadcast_size, row_major=False),
1083
@register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
1084
@out_wrapper("P", "L", "U")
1085
def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
1088
lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
1091
sizes = list(A.shape)
1098
P = A.new_empty(sizes)
1100
P = A.new_empty([0])
1103
L = A.new_empty(sizes)
1107
U = A.new_empty(sizes)
1111
@register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out])
1112
@out_wrapper("LU", "pivots", "info")
1113
def linalg_lu_factor_ex_meta(
1117
check_errors: bool = False,
1118
) -> Tuple[Tensor, Tensor, Tensor]:
1121
lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
1124
sizes = list(A.shape)
1128
LU = torch.empty_strided(
1130
stride=make_contiguous_strides_for(sizes, row_major=False),
1137
sizes[-1] = min(m, n)
1138
pivots = A.new_empty(sizes, dtype=torch.int)
1142
info = A.new_empty(sizes, dtype=torch.int)
1144
return LU, pivots, info
1147
@register_meta([aten.linalg_lu_solve.default, aten.linalg_lu_solve.out])
1149
def linalg_lu_solve_meta(
1155
adjoint: bool = False,
1158
checkFloatingOrComplex(LU, "torch.linalg.lu_solve")
1160
LU.dtype == B.dtype,
1162
f"linalg.lu_solve: Expected LU and B to have the same dtype, "
1163
f"but found LU of type {LU.dtype} and B of type {B.dtype} instead"
1167
pivots.dtype == torch.int,
1168
lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32",
1172
squareCheckInputs(LU, "torch.linalg.lu_solve")
1173
checkInputsSolver(LU, B, left, "linalg.lu_solve")
1175
LU.size(-1) == pivots.size(-1),
1176
lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix",
1181
LU.shape[:-1] == pivots.shape,
1183
f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, "
1184
f"but got pivots with shape {pivots.shape} instead"
1188
B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LU)
1190
result = torch.empty_strided(
1191
size=B_broadcast_size,
1192
stride=make_contiguous_strides_for(B_broadcast_size, row_major=not left),
1197
if result.numel() != 0 and not left:
1198
if result.is_complex():
1199
result = result.conj()
1204
@register_meta(aten.lu_unpack)
1205
@out_wrapper("P", "L", "U")
1209
unpack_data: bool = True,
1210
unpack_pivots: bool = True,
1211
) -> Tuple[Tensor, Tensor, Tensor]:
1214
lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: {LU.shape} instead",
1218
pivots.dtype == torch.int32,
1220
"torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n"
1221
"Note: this function is intended to be used with the output produced by torch.linalg.lu_factor"
1224
sizes = list(LU.shape)
1230
P = LU.new_empty(sizes)
1232
P = LU.new_empty([0])
1235
L = LU.new_empty(sizes)
1238
U = LU.new_empty(sizes)
1240
L = LU.new_empty([0])
1241
U = LU.new_empty([0])
1246
def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
1247
if mode == "reduced":
1250
elif mode == "complete":
1260
f"qr received unrecognized mode '{mode}' "
1261
f"but expected one of 'reduced' (default), 'r', or 'complete'"
1264
return compute_q, reduced
1267
@register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
1268
@out_wrapper("Q", "R")
1269
def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> Tuple[Tensor, Tensor]:
1270
checkIsMatrix(A, "linalg.qr")
1271
checkFloatingOrComplex(A, "linalg.qr")
1273
compute_q, reduced_mode = _parse_qr_mode(mode)
1280
Q_shape = list(A.shape)
1281
Q_shape[-1] = k if reduced_mode else m
1282
Q = A.new_empty(Q_shape)
1283
Q.as_strided_(Q_shape, make_contiguous_strides_for(Q_shape, row_major=False))
1285
Q = A.new_empty([0])
1288
R_shape = list(A.shape)
1289
R_shape[-2] = k if reduced_mode or not compute_q else m
1290
R = A.new_empty(R_shape)
1291
R.as_strided_(R_shape, make_contiguous_strides_for(R_shape, row_major=False))
1295
@register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign])
1296
@out_wrapper("sign", "logabsdet", "LU", "pivots")
1297
def _linalg_slogdet(A: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
1298
squareCheckInputs(A, "linalg.slogdet")
1299
checkFloatingOrComplex(A, "linalg.slogdet", False)
1301
sign = A.new_empty(shape[:-2])
1302
logabsdet = A.new_empty(shape[:-2], dtype=toRealValueType(A.dtype))
1303
LU = torch.empty_strided(
1305
stride=make_contiguous_strides_for(shape, False),
1309
pivots = A.new_empty(shape[:-1], dtype=torch.int32)
1310
return sign, logabsdet, LU, pivots
1315
@register_meta(aten._linalg_svd.default)
1316
def _linalg_svd_meta(
1318
full_matrices: bool = False,
1319
compute_uv: bool = True,
1320
driver: Optional[str] = None,
1322
checkIsMatrix(A, "linalg.svd")
1323
checkFloatingOrComplex(A, "linalg.svd")
1325
batch_dims = list(A.shape[:-2])
1331
U_shape = batch_dims + [m, m if full_matrices else k]
1332
U = A.new_empty(U_shape)
1333
U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))
1335
V_shape = batch_dims + [n if full_matrices else k, n]
1336
V = A.new_empty(V_shape)
1341
is_cuda = device_hint(A) == "cuda"
1342
V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=is_cuda))
1345
U = A.new_empty([0])
1346
V = A.new_empty([0])
1349
S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
1353
def _linalg_broadcast_batch_dims(
1356
) -> Tuple[List[int], List[int]]:
1358
arg1_batch_sizes = arg1.shape[:-2]
1359
arg2_batch_sizes = arg2.shape[:-2]
1360
expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes)
1362
arg1_expand_size = list(expand_batch_portion)
1363
arg1_expand_size += [arg1.size(-2), arg1.size(-1)]
1365
arg2_expand_size = list(expand_batch_portion)
1366
arg2_expand_size += [arg2.size(-2), arg2.size(-1)]
1367
return arg1_expand_size, arg2_expand_size
1370
def _linalg_broadcast_batch_dims_name(
    arg1: Tensor,
    arg2: Tensor,
    name: Optional[str],
) -> Tuple[Tensor, Tensor]:
    """Broadcast batch dims of two batched matrices, returning expanded tensors.

    When ``name`` is given, first run ``linearSolveCheckInputs`` validation
    (a falsy ``name`` skips the error checks). Tensors already at the target
    shape are returned untouched to avoid a needless ``expand``.
    """
    if name:
        linearSolveCheckInputs(arg1, arg2, name)

    arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2)

    arg1_broadcasted = (
        arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size)
    )
    arg2_broadcasted = (
        arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size)
    )
    return arg1_broadcasted, arg2_broadcasted
1390
def linalg_solve_is_vector_rhs(input: Tensor, other: Tensor) -> bool:
    """Return True when ``other`` should be treated as a vector RHS for linalg.solve.

    A vector case is either a plain 1-D ``other``, or an ``other`` whose shape
    matches the batch shape of ``input`` with the last (matrix) dim dropped.
    """
    expected_batched_rhs_shape = input.shape[:-1]
    vector_case = other.ndim == 1 or (
        input.ndim - 1 == other.ndim and other.shape == expected_batched_rhs_shape
    )
    return vector_case
1398
@register_meta(aten._linalg_solve_ex)
1399
def _linalg_solve_ex(
1404
check_errors: bool = False,
1405
result: Optional[Tensor] = None,
1406
LU: Optional[Tensor] = None,
1407
pivots: Optional[Tensor] = None,
1408
info: Optional[Tensor] = None,
1409
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
1410
checkFloatingOrComplex(A, "linalg.solve")
1414
f"linalg.solve: Expected A and B to have the same dtype, but found A of type "
1415
f"{A.dtype} and B of type {B.dtype} instead"
1418
vector_case = linalg_solve_is_vector_rhs(A, B)
1419
B_ = B.unsqueeze(-1) if vector_case else B
1420
checkInputsSolver(A, B_, left, "linalg.solve")
1421
B_broad_shape, _ = _linalg_broadcast_batch_dims(B_, A)
1423
left or not vector_case,
1425
"linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. "
1426
"In this case linalg.solve is equivalent to B / A.squeeze(-1)"
1429
result_shape = B_broad_shape[:-1] if vector_case else B_broad_shape
1430
result_ = torch.empty_strided(
1432
stride=make_contiguous_strides_for(result_shape, not left),
1438
LU_ = torch.empty_strided(
1440
stride=make_contiguous_strides_for(shape, False),
1444
pivots_ = A.new_empty(shape[:-1], dtype=torch.int32)
1445
info_ = A.new_empty(shape[:-2], dtype=torch.int32)
1446
out = (result, LU, pivots, info)
1447
res = (result_, LU_, pivots_, info_)
1448
if all(x is not None for x in out):
1449
for r, o in zip(res, out):
1451
_maybe_resize_out(o, r.shape)
1453
o.as_strided_(r.shape, r.stride())
1454
_safe_copy_out(copy_from=r, copy_to=o, exact_dtype=False)
1458
@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
1459
def linalg_solve_triangular_meta(
1465
unitriangular: bool = False,
1466
out: Optional[Tensor] = None,
1469
out = A.new_empty([0])
1470
assert isinstance(out, TensorLike)
1471
checkInputsSolver(A, B, left, "linalg.solve_triangular")
1472
B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)
1473
avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj()
1475
out = _maybe_resize_out(out, B_.shape)
1478
if _resize_output_check(out, B_.shape):
1479
out.resize_(B_.transpose(-2, -1).shape)
1480
out.transpose_(-2, -1)
1484
@register_meta(aten.triangular_solve)
1485
@out_wrapper("solution", "cloned_coefficient")
1486
def triangular_solve_meta(
1490
transpose: bool = False,
1491
unitriangular: bool = False,
1492
) -> Tuple[Tensor, Tensor]:
1496
f"torch.triangular_solve: Expected b to have at least 2 dimensions, "
1497
f"but it has {self.ndim} dimensions instead"
1503
f"torch.triangular_solve: Expected A to have at least 2 dimensions, "
1504
f"but it has {A.ndim} dimensions instead"
1508
linearSolveCheckInputs(self, A, "triangular_solve")
1510
if A.layout == torch.strided:
1511
self_broadcast_size, A_broadcast_size = _linalg_broadcast_batch_dims(self, A)
1512
solution = torch.empty_strided(
1513
size=self_broadcast_size,
1514
stride=make_contiguous_strides_for(self_broadcast_size, row_major=False),
1518
cloned_coefficient = torch.empty_strided(
1519
size=A_broadcast_size,
1520
stride=make_contiguous_strides_for(A_broadcast_size, row_major=False),
1524
elif A.layout == torch.sparse_csr or A.layout == torch.sparse_bsr:
1525
solution = torch.empty_like(self)
1526
cloned_coefficient = self.new_empty([0])
1528
torch._check(False, lambda: "triangular_solve: Got an unexpected layout.")
1529
return solution, cloned_coefficient
1533
@register_meta(aten._linalg_det.default)
def _linalg_det_meta(A):
    """Meta kernel for ``aten._linalg_det``: allocate outputs without computing.

    Returns the triple ``(det, LU, pivots)`` with the shapes/dtypes the real
    kernel produces:
      - ``det``: batch shape ``A.shape[:-2]``, same dtype as ``A``
      - ``LU``: same shape as ``A``, column-major strides (LAPACK convention)
      - ``pivots``: ``A.shape[:-1]``, int32
    """
    squareCheckInputs(A, "linalg.det")
    checkFloatingOrComplex(A, "linalg.det")

    det = A.new_empty(A.shape[:-2])

    LU = A.new_empty(A.shape)
    # LAPACK-style factorizations expect Fortran (column-major) layout.
    LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))

    pivots = A.new_empty(A.shape[:-1], dtype=torch.int32)
    return det, LU, pivots
1547
@register_meta(aten.ormqr)
1554
transpose: bool = False,
1557
input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions."
1560
other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions."
1563
left_size_condition = -2 if left else -1
1565
other.shape[left_size_condition] >= tau.shape[-1],
1566
lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be greater than or equal to tau.shape[-1]",
1569
other.shape[left_size_condition] == input.shape[-2],
1570
lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be equal to input.shape[-2]",
1574
tau.shape[-1] <= input.shape[-1],
1575
lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]",
1579
input.ndim - tau.ndim == 1,
1581
f"torch.ormqr: Expected tau to have one dimension less than input, "
1582
f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
1586
input.ndim == other.ndim,
1588
f"torch.ormqr: Expected other to have the same number of dimensions as input, "
1589
f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}"
1594
expected_batch_shape = input.shape[:-2]
1595
actual_batch_tau_shape = tau.shape[:-1]
1597
actual_batch_tau_shape == expected_batch_shape,
1599
f"torch.ormqr: Expected batch dimensions of tau to be "
1600
f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
1604
actual_batch_other_shape = other.shape[:-2]
1606
actual_batch_other_shape == expected_batch_shape,
1608
f"torch.ormqr: Expected batch dimensions of other to be "
1609
f"equal to input.shape[:-2], but got {actual_batch_other_shape}"
1614
tau.dtype == input.dtype,
1616
f"torch.ormqr: Expected input and tau to have the same dtype, "
1617
f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}"
1621
other.dtype == input.dtype,
1623
f"torch.ormqr: Expected input and other to have the same dtype, "
1624
f"but input has dtype {input.dtype} and other has dtype {other.dtype}"
1628
checkSameDevice("torch.ormqr", tau, input, "tau")
1629
checkSameDevice("torch.ormqr", other, input, "other")
1631
return torch.empty_strided(
1633
stride=make_contiguous_strides_for(other.shape, row_major=False),
1635
device=other.device,
1639
def _padding_check_valid_input(input, padding, *, dim):
1641
len(padding) == 2 * dim,
1642
lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}",
1645
input_dim = input.ndim
1647
is_batch_mode = input_dim == (dim + 2)
1649
valid_batch_mode = is_batch_mode
1650
valid_non_batch_mode = not is_batch_mode
1654
for d in range(1, input_dim):
1655
valid_batch_mode = valid_batch_mode and input.size(d) != 0
1657
for d in range(0, input_dim):
1658
valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0
1662
valid_batch_mode or valid_non_batch_mode,
1664
f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size "
1665
f"and other non-zero dimensions for input, but got: {input.shape}"
1670
def _pad1d_common(input, padding, *, is_reflection):
1676
nbatch = input.size(0)
1680
_padding_check_valid_input(input, padding, dim=1)
1682
pad_l, pad_r = padding
1684
nplane = input.size(dim_plane)
1685
input_w = input.size(dim_w)
1686
output_w = input_w + pad_l + pad_r
1690
pad_l < input_w and pad_r < input_w,
1692
f"Argument #4: Padding size should be less than the corresponding input dimension, "
1693
f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1699
lambda: f"input (W: {input_w}) is too small. Calculated output W: {output_w}",
1703
return input.new_empty((nplane, output_w))
1705
return input.new_empty((nbatch, nplane, output_w))
1708
@register_meta(aten.reflection_pad1d)
1710
def meta_reflection_pad1d(input, padding):
1711
return _pad1d_common(input, padding, is_reflection=True)
1714
@register_meta(aten.replication_pad1d)
1716
def meta_replication_pad1d(input, padding):
1717
return _pad1d_common(input, padding, is_reflection=False)
1720
def _pad1d_backward_common(grad_output, input, padding, *, is_reflection):
1722
if not is_reflection:
1723
torch._check(len(padding) == 2, lambda: "padding size is expected to be 2")
1728
pad_l, pad_r = padding
1730
input_w = input.size(dim_w)
1731
output_w = input_w + pad_l + pad_r
1735
pad_l < input_w and pad_r < input_w,
1737
f"Argument #4: Padding size should be less than the corresponding input dimension, "
1738
f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1743
output_w == grad_output.size(dim_w),
1744
lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1747
return input.new_empty(input.shape)
1750
@register_meta(aten.reflection_pad1d_backward)
@out_wrapper("grad_input")
def meta_reflection_pad1d_backward(grad_output, input, padding):
    """Meta kernel for ``reflection_pad1d_backward``: delegates shape checks to
    the shared 1-D pad-backward helper and allocates ``grad_input`` shaped like
    ``input``."""
    return _pad1d_backward_common(grad_output, input, padding, is_reflection=True)
1756
@register_meta(aten.replication_pad1d_backward)
@out_wrapper("grad_input")
def meta_replication_pad1d_backward(grad_output, input, padding):
    """Meta kernel for ``replication_pad1d_backward``: delegates shape checks to
    the shared 1-D pad-backward helper and allocates ``grad_input`` shaped like
    ``input``."""
    return _pad1d_backward_common(grad_output, input, padding, is_reflection=False)
1762
def _pad2d_common(input, padding, *, is_reflection):
1768
_padding_check_valid_input(input, padding, dim=2)
1772
nbatch = input.size(0)
1777
pad_l, pad_r, pad_t, pad_b = padding
1779
nplane = input.size(dim_slices)
1780
input_h = input.size(dim_h)
1781
input_w = input.size(dim_w)
1782
output_h = input_h + pad_t + pad_b
1783
output_w = input_w + pad_l + pad_r
1787
pad_l < input_w and pad_r < input_w,
1789
f"Argument #4: Padding size should be less than the corresponding input dimension, "
1790
f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1794
pad_t < input_h and pad_b < input_h,
1796
f"Argument #6: Padding size should be less than the corresponding input dimension, "
1797
f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
1802
output_w >= 1 or output_h >= 1,
1804
f"input (H: {input_h} W: {input_w}) is too small. "
1805
f"Calculated output H: {output_h} W: {output_w}"
1810
return input.new_empty((nplane, output_h, output_w))
1812
return input.new_empty((nbatch, nplane, output_h, output_w))
1815
@register_meta(aten.reflection_pad2d)
1817
def meta_reflection_pad2d(input, padding):
1818
return _pad2d_common(input, padding, is_reflection=True)
1821
@register_meta(aten.replication_pad2d)
1823
def meta_replication_pad2d(input, padding):
1824
return _pad2d_common(input, padding, is_reflection=False)
1829
aten.reflection_pad2d_backward.default,
1830
aten.reflection_pad2d_backward.grad_input,
1831
aten.replication_pad2d_backward.default,
1832
aten.replication_pad2d_backward.grad_input,
1835
@out_wrapper("grad_input")
1836
def meta_pad2d_backward(grad_output, self, padding):
1842
self_shape = self.shape
1844
nbatch = self_shape[0]
1849
pad_l, pad_r, pad_t, pad_b = padding
1851
nplane = self_shape[dim_plane]
1852
input_h = self_shape[dim_h]
1853
input_w = self_shape[dim_w]
1854
output_h = input_h + pad_t + pad_b
1855
output_w = input_w + pad_l + pad_r
1858
output_w == grad_output.size(dim_w),
1859
lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1862
output_h == grad_output.size(dim_h),
1863
lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
1865
return self.new_empty(self.shape)
1868
def _pad3d_common(input, padding, *, is_reflection):
1874
_padding_check_valid_input(input, padding, dim=3)
1876
batch_mode = input.ndim == 5
1878
nbatch = input.size(0)
1884
pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
1886
nplane = input.size(dim_plane)
1887
input_d = input.size(dim_d)
1888
input_h = input.size(dim_h)
1889
input_w = input.size(dim_w)
1890
output_d = input_d + pad_f + pad_bk
1891
output_h = input_h + pad_t + pad_b
1892
output_w = input_w + pad_l + pad_r
1896
pad_l < input_w and pad_r < input_w,
1898
f"Argument #4: Padding size should be less than the corresponding input dimension, "
1899
f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1903
pad_t < input_h and pad_b < input_h,
1905
f"Argument #6: Padding size should be less than the corresponding input dimension, "
1906
f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
1910
pad_f < input_d and pad_bk < input_d,
1912
f"Argument #8: Padding size should be less than the corresponding input dimension, "
1913
f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}"
1918
output_w >= 1 or output_h >= 1 or output_d >= 1,
1920
f"input (D: {input_d} H: {input_h} W: {input_w}) is too small. "
1921
f"Calculated output D: {output_d} H: {output_h} W: {output_w}"
1926
return input.new_empty((nbatch, nplane, output_d, output_h, output_w))
1928
return input.new_empty((nplane, output_d, output_h, output_w))
1931
@register_meta(aten.reflection_pad3d)
1933
def meta_reflection_pad3d(input, padding):
1934
return _pad3d_common(input, padding, is_reflection=True)
1937
@register_meta(aten.replication_pad3d)
1939
def meta_replication_pad3d(input, padding):
1940
return _pad3d_common(input, padding, is_reflection=False)
1945
aten.reflection_pad3d_backward.default,
1946
aten.reflection_pad3d_backward.grad_input,
1947
aten.replication_pad3d_backward.default,
1948
aten.replication_pad3d_backward.grad_input,
1951
@out_wrapper("grad_input")
1952
def meta_pad3d_backward(grad_output, input, padding):
1953
torch._check(len(padding) == 6, lambda: "padding size is expected to be 6")
1954
assert input.ndim > 3
1955
assert grad_output.ndim == input.ndim
1966
pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
1968
input_d = input.size(dim_d)
1969
input_h = input.size(dim_h)
1970
input_w = input.size(dim_w)
1971
output_d = input_d + pad_f + pad_bk
1972
output_h = input_h + pad_t + pad_b
1973
output_w = input_w + pad_l + pad_r
1976
output_w == grad_output.size(dim_w),
1977
lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1980
output_h == grad_output.size(dim_h),
1981
lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
1984
output_d == grad_output.size(dim_d),
1985
lambda: f"grad_output depth unexpected. Expected: {output_d}, Got: {grad_output.size(dim_d)}",
1988
return input.new_empty(input.shape)
1991
@register_meta(aten._pdist_forward)
1993
def meta__pdist_forward(self: Tensor, p: float = 2) -> Tensor:
1995
self.is_contiguous(), lambda: "_pdist_forward requires contiguous input"
1999
return self.new_empty([0]).to(memory_format=torch.legacy_contiguous_format)
2001
return self.new_empty((n * (n - 1) // 2,)).to(
2002
memory_format=torch.legacy_contiguous_format
2006
@register_meta(aten._pdist_backward)
2008
def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> Tensor:
2010
self.is_contiguous(), lambda: "_pdist_backward requires self to be contiguous"
2013
pdist.is_contiguous(), lambda: "_pdist_backward requires pdist to be contiguous"
2015
return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
2018
@register_meta([aten.baddbmm.default, aten.baddbmm.out])
2020
def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
2021
dim1 = batch1.size(0)
2022
dim2 = batch1.size(1)
2023
dim3 = batch2.size(2)
2024
self = self.expand((dim1, dim2, dim3))
2025
torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
2026
torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
2028
self.dtype == batch1.dtype == batch2.dtype,
2029
lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
2031
batch1_sizes = batch1.shape
2032
batch2_sizes = batch2.shape
2033
bs = batch1_sizes[0]
2034
contraction_size = batch1_sizes[2]
2036
batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
2038
f"Expected size for first two dimensions of batch2 tensor to be: "
2039
f"[{bs}, {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}]."
2042
return self.new_empty(self.size())
2045
@register_meta([aten.bernoulli.default, aten.bernoulli.out])
2047
def meta_bernoulli(self, *, generator=None):
2049
return torch.empty_like(self).contiguous()
2052
@register_meta(aten.bernoulli_.float)
2053
def meta_bernoulli_(self, p=0.5, generator=None):
2057
@register_meta(aten.bernoulli.p)
2058
def meta_bernoulli_p(self, p=0.5, generator=None):
2060
return torch.empty_like(self).contiguous()
2063
@register_meta([aten.poisson.default, aten.poisson.out])
2065
def meta_poisson(self, generator=None):
2066
return torch.empty_like(self)
2069
@register_meta(aten._fused_moving_avg_obs_fq_helper.default)
2070
def meta__fused_moving_avg_obs_fq_helper(
2082
per_row_fake_quant=False,
2083
symmetric_quant=False,
2086
ch_axis < self.dim(),
2087
lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
2089
mask = torch.empty_like(self, dtype=torch.bool)
2090
return (torch.empty_like(self), mask)
2093
@register_meta(aten.mm)
2096
torch._check(a.dim() == 2, lambda: "a must be 2D")
2097
torch._check(b.dim() == 2, lambda: "b must be 2D")
2102
lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].",
2104
return a.new_empty(N, P)
2107
def _compute_reduction_shape(self, dims, keepdim):
2109
return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
2111
return utils.compute_reduction_output_shape(self.shape, dims)
2118
def device_hint(tensor) -> "str":
2119
if isinstance(tensor, torch._subclasses.FakeTensor):
2120
return tensor.fake_device.type
2125
def calc_conv_nd_return_shape(
2126
input_tensor: torch.Tensor,
2127
weight: torch.Tensor,
2128
stride: Union[List[int], int],
2129
padding: Union[List[int], int],
2130
dilation: Union[List[int], int],
2131
is_transposed: bool,
2133
output_padding: Optional[Union[List[int], int]] = None,
2135
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
2137
Formula to apply to calculate the length of some dimension of the output
2139
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
2142
ln: length of the dimension
2143
p: padding in that dim
2144
d: dilation in that dim
2145
k: kernel size in that dim
2146
s: stride in that dim
2150
return (ln + 2 * p - d * (k - 1) - 1) // s + 1
2152
def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
2154
Formula to apply to calculate the length of some dimension of the output
2155
if transposed convolution is used.
2156
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
2159
ln: length of the dimension
2160
p: padding in that dim
2161
d: dilation in that dim
2162
k: kernel size in that dim
2163
s: stride in that dim
2164
op: output padding in that dim
2169
return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
2171
kernel_size = weight.shape[2:]
2172
dims = input_tensor.shape[2:]
2174
out_channels = groups * weight.shape[1]
2176
out_channels = weight.shape[0]
2177
if weight.shape[1] * groups != input_tensor.shape[1]:
2178
raise RuntimeError("Invalid channel dimensions")
2180
ret_shape = [input_tensor.shape[0], out_channels]
2181
if isinstance(stride, IntLike):
2182
stride = [stride] * len(dims)
2183
elif len(stride) == 1:
2184
stride = [stride[0]] * len(dims)
2186
if isinstance(padding, IntLike):
2187
padding = [padding] * len(dims)
2188
elif len(padding) == 1:
2189
padding = [padding[0]] * len(dims)
2191
if isinstance(dilation, IntLike):
2192
dilation = [dilation] * len(dims)
2193
elif len(dilation) == 1:
2194
dilation = [dilation[0]] * len(dims)
2196
output_padding_list: Optional[List[int]] = None
2198
if isinstance(output_padding, IntLike):
2199
output_padding_list = [output_padding] * len(dims)
2200
elif len(output_padding) == 1:
2201
output_padding_list = [output_padding[0]] * len(dims)
2203
output_padding_list = output_padding
2205
for i in range(len(dims)):
2207
if output_padding_list:
2209
_formula_transposed(
2215
output_padding_list[i],
2220
_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
2226
def is_channels_last(ten):
    """Return True if *ten*'s suggested memory format is ``torch.channels_last``.

    Uses ``torch._prims_common.suggest_memory_format``, which inspects the
    tensor's strides to pick the most natural memory format.
    """
    return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
2230
@register_meta(aten.convolution.default)
2232
input_tensor: torch.Tensor,
2233
weight: torch.Tensor,
2237
dilation: List[int],
2238
is_transposed: bool,
2239
output_padding: List[int],
2242
def pick_memory_format():
2243
if device_hint(input_tensor) == "cuda":
2244
if is_channels_last(input_tensor) or is_channels_last(weight):
2245
return torch.channels_last
2247
if is_channels_last(input_tensor):
2248
return torch.channels_last
2249
if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
2250
return torch.contiguous_format
2251
elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
2252
return torch.preserve_format
2254
shape_out = calc_conv_nd_return_shape(
2262
output_padding if is_transposed else None,
2265
input_channels_dim = 1
2266
output_channels_dim = 1
2267
if input_tensor.size(input_channels_dim) == 0:
2268
shape_out[output_channels_dim] = 0
2270
out = input_tensor.new_empty(shape_out)
2271
out = out.to(memory_format=pick_memory_format())
2275
if torch._C._has_mkldnn:
2276
_meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
2277
"mkldnn", "IMPL", "Meta"
2280
@register_meta(torch.ops.mkldnn._convolution_pointwise.default)
2281
def meta_mkldnn_convolution_default(
2293
shape_out = calc_conv_nd_return_shape(
2294
input_tensor, weight, stride, padding, dilation, False, groups, []
2296
out = input_tensor.new_empty(shape_out)
2297
out_memory_format = torch.channels_last
2298
if input_tensor.dim() == 5:
2299
out_memory_format = torch.channels_last_3d
2300
out = out.to(memory_format=out_memory_format)
2303
@register_meta(torch.ops.mkldnn._linear_pointwise.default)
2304
def meta_linear_pointwise_default(
2305
input_tensor, weight, bias, attr, scalars, algorithm
2307
return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))
2309
if torch._C.has_mkl:
2310
_meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
2311
"mkl", "IMPL", "Meta"
2314
@register_meta(torch.ops.mkl._mkl_linear)
2315
def meta_mkl_linear(input_tensor, packed_weight, orig_weight, bias, batch_size):
2316
return input_tensor.new_empty(
2317
(*input_tensor.shape[:-1], orig_weight.shape[0])
2320
_meta_lib_dont_use_me_use_register_meta_for_onednn = torch.library.Library(
2321
"onednn", "IMPL", "Meta"
2324
@register_meta(torch.ops.onednn.qconv2d_pointwise.default)
2325
def meta_qconv2d_pointwise(
2344
shape_out = calc_conv_nd_return_shape(
2354
assert output_dtype in [torch.float32, torch.bfloat16]
2355
out = x.new_empty(shape_out, dtype=output_dtype)
2356
out = out.to(memory_format=torch.channels_last)
2359
@register_meta(torch.ops.onednn.qlinear_pointwise.default)
2360
@register_meta(torch.ops.onednn.qlinear_pointwise.tensor)
2361
def meta_qlinear_pointwise(
2376
output_shape = list(x.shape)
2378
output_shape[-1] = w.shape[1]
2379
assert output_dtype in [torch.float32, torch.bfloat16]
2380
out = x.new_empty(output_shape, dtype=output_dtype)
2383
_meta_lib_dont_use_me_use_register_meta_for_quantized = torch.library.Library(
2384
"quantized", "IMPL", "Meta"
2387
@register_meta(torch.ops.quantized.max_pool2d)
2388
def meta_quantized_max_pool2d(
2400
) = max_pool2d_checks_and_compute_shape(
2401
input, kernel_size, stride, padding, dilation, ceil_mode
2403
nbatch = input.size(-4) if input.dim() == 4 else 1
2404
memory_format = torch.channels_last
2405
if input.dim() == 3:
2406
size = [nInputPlane, outputHeight, outputWidth]
2408
size = [nbatch, nInputPlane, outputHeight, outputWidth]
2412
device=input.device,
2413
memory_format=memory_format,
2418
def check_dim_size(tensor, dim, dim_size, size):
2420
tensor.dim() == dim and tensor.shape[dim_size] == size,
2421
lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
2422
+ f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
2426
@register_meta(aten.avg_pool2d.default)
2433
count_include_pad=True,
2434
divisor_override=None,
2436
def unpack(name, val):
2439
lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
2442
W = H if len(val) == 1 else val[1]
2445
kH, kW = unpack("kernel_size", kernel_size)
2447
len(stride) in [0, 1, 2],
2448
lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
2450
if len(stride) == 0:
2452
elif len(stride) == 1:
2453
dH, dW = stride[0], stride[0]
2455
dH, dW = unpack("stride", stride)
2457
padH, padW = unpack("padding", padding)
2460
divisor_override is None or divisor_override != 0,
2461
lambda: "divisor must be not zero",
2464
nbatch = input.size(-4) if input.dim() == 4 else 1
2465
nInputPlane = input.size(-3)
2466
inputHeight = input.size(-2)
2467
inputWidth = input.size(-1)
2469
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
2470
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
2472
memory_format = utils.suggest_memory_format(input)
2491
if input.dim() == 3:
2492
size = [nInputPlane, outputHeight, outputWidth]
2494
size = [nbatch, nInputPlane, outputHeight, outputWidth]
2498
device=input.device,
2499
memory_format=memory_format,
2504
def avg_pool2d_backward_shape_check(
2540
nOutputPlane = nInputPlane
2542
check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
2543
check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
2544
check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)
2548
@register_meta(aten.avg_pool2d_backward.default)
2549
def meta_avg_pool2d_backward(
2561
len(kernel_size) == 1 or len(kernel_size) == 2,
2562
lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
2565
kW = kH if len(kernel_size) == 1 else kernel_size[1]
2567
len(stride) == 0 or len(stride) == 1 or len(stride) == 2,
2568
lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
2570
dH = kH if len(stride) == 0 else stride[0]
2571
dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
2573
len(padding) == 1 or len(padding) == 2,
2574
lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
2577
padW = padH if len(padding) == 1 else padding[1]
2580
divisor_override is None or divisor_override != 0,
2581
lambda: "divisor must be not zero",
2584
input_size = input.shape
2585
nbatch = input_size[-4] if input.dim() == 4 else 1
2586
nInputPlane = input_size[-3]
2587
inputHeight = input_size[-2]
2588
inputWidth = input_size[-1]
2590
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
2591
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
2593
mem_format = utils.suggest_memory_format(input)
2595
avg_pool2d_backward_shape_check(
2616
device=input.device,
2617
memory_format=mem_format,
2621
@register_meta(aten.avg_pool3d)
2629
count_include_pad=True,
2630
divisor_override=None,
2633
len(kernel_size) in (1, 3),
2634
lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
2637
kH = kT if len(kernel_size) == 1 else kernel_size[1]
2638
kW = kT if len(kernel_size) == 1 else kernel_size[2]
2641
not stride or len(stride) in (1, 3),
2642
lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
2644
dT = kT if not stride else stride[0]
2645
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
2646
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
2649
len(padding) in (1, 3),
2650
lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
2653
padH = padT if len(padding) == 1 else padding[1]
2654
padW = padT if len(padding) == 1 else padding[2]
2657
input.ndim in (4, 5),
2658
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
2662
not divisor_override or divisor_override != 0,
2663
lambda: "divisor must be not zero",
2666
nbatch = input.size(0)
2667
nslices = input.size(-4)
2668
itime = input.size(-3)
2669
iheight = input.size(-2)
2670
iwidth = input.size(-1)
2672
otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
2673
oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
2674
owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)
2698
check_input_size=True,
2702
return input.new_empty((nslices, otime, oheight, owidth))
2704
return input.new_empty((nbatch, nslices, otime, oheight, owidth))
2707
@register_meta(aten.avg_pool3d_backward)
2708
@out_wrapper("grad_input")
2709
def meta_avg_pool3d_backward(
2720
len(kernel_size) in (1, 3),
2721
lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
2724
kH = kT if len(kernel_size) == 1 else kernel_size[1]
2725
kW = kT if len(kernel_size) == 1 else kernel_size[2]
2728
not stride or len(stride) in (1, 3),
2729
lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
2731
dT = kT if not stride else stride[0]
2732
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
2733
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
2736
len(padding) in (1, 3),
2737
lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
2740
padH = padT if len(padding) == 1 else padding[1]
2741
padW = padT if len(padding) == 1 else padding[2]
2744
input.ndim in (4, 5),
2745
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
2749
not divisor_override or divisor_override != 0,
2750
lambda: "divisor must be not zero",
2753
nslices = input.size(-4)
2754
itime = input.size(-3)
2755
iheight = input.size(-2)
2756
iwidth = input.size(-1)
2758
otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
2759
oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
2760
owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)
2762
avg_pool3d_backward_shape_check(
2778
otime_for_shape_check,
2779
oheight_for_shape_check,
2780
owidth_for_shape_check,
2781
"avg_pool3d_backward()",
2784
return input.new_empty(input.shape)
2787
@register_meta(aten._adaptive_avg_pool2d.default)
2788
def meta_adaptive_avg_pool2d(self, output_size):
2790
self.ndim == 3 or self.ndim == 4,
2791
lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
2793
output_shape = self.shape[:-2] + tuple(output_size)
2794
memory_format = utils.suggest_memory_format(self)
2801
memory_format=memory_format,
2805
@register_meta(aten._adaptive_avg_pool3d.default)
2806
def meta_adaptive_avg_pool3d(self, output_size):
2808
self.ndim == 4 or self.ndim == 5,
2809
lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
2811
return self.new_empty(self.shape[:-3] + tuple(output_size))
2814
@register_meta(aten._adaptive_avg_pool2d_backward.default)
2815
def meta__adaptive_avg_pool2d_backward(grad_out, self):
2816
ndim = grad_out.ndim
2817
for i in range(1, ndim):
2819
grad_out.size(i) > 0,
2820
lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
2821
size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
2824
ndim == 3 or ndim == 4,
2825
lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
2828
self.dtype == grad_out.dtype,
2829
lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
2831
memory_format = torch.contiguous_format
2832
if is_channels_last(self):
2833
memory_format = torch.channels_last
2834
return self.new_empty(self.shape).to(memory_format=memory_format)
2837
@register_meta(aten._adaptive_avg_pool3d_backward)
@out_wrapper("grad_input")
def meta__adaptive_avg_pool3d_backward(grad_output, self):
    """Meta kernel for ``_adaptive_avg_pool3d_backward``: verify that
    ``grad_output`` has no empty non-batch dimensions, then allocate
    ``grad_input`` like ``self`` in legacy-contiguous format."""
    _adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool3d_backward")
    return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
2844
def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str):
2845
ndim = grad_output.ndim
2846
for i in range(1, ndim):
2848
grad_output.size(i) > 0,
2850
f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, "
2851
f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty"
2856
@register_meta(aten.adaptive_max_pool2d)
@out_wrapper("out", "indices")
def meta_adaptive_max_pool2d(input, output_size):
    """Meta kernel for adaptive_max_pool2d: returns (out, indices) shapes.

    Accepts a 3D (C, H, W) or 4D (N, C, H, W) input; the spatial dims are
    replaced by `output_size` = (osizeH, osizeW).
    """
    ndim = input.ndim
    torch._check(
        ndim in (3, 4),
        lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}",
    )
    for i in range(1, ndim):
        torch._check(
            input.size(i) > 0,
            lambda: (
                f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
                f"but input has sizes {input.shape} with dimension {i} being empty"
            ),
        )

    torch._check(
        len(output_size) == 2,
        lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2",
    )

    dimH = 1
    sizeB = 1
    sizeD = 0

    if input.ndim == 4:
        sizeB = input.size(0)
        dimH += 1

    sizeD = input.size(dimH - 1)
    osizeH, osizeW = output_size

    if input.ndim == 3:
        out_shape = (sizeD, osizeH, osizeW)
        out = input.new_empty(out_shape)
        indices = input.new_empty(out_shape, dtype=torch.int64)
    else:
        out_shape = (sizeB, sizeD, osizeH, osizeW)  # type: ignore[assignment]
        # 4D outputs follow the input's suggested memory format.
        memory_format = utils.suggest_memory_format(input)
        out = input.new_empty(out_shape).to(memory_format=memory_format)
        indices = input.new_empty(out_shape, dtype=torch.int64).to(
            memory_format=memory_format
        )

    return out, indices
@register_meta(aten.adaptive_max_pool2d_backward)
@out_wrapper("grad_input")
def meta_adaptive_max_pool2d_backward(grad_output, input, indices):
    """Meta kernel for adaptive_max_pool2d backward: grad_input matches `input`."""
    ndim = grad_output.ndim
    torch._check(
        ndim in (3, 4),
        lambda: f"adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}",
    )

    _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward")

    torch._check(
        input.dtype == grad_output.dtype,
        lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}",
    )

    memory_format = utils.suggest_memory_format(input)
    return input.new_empty(input.shape).to(memory_format=memory_format)
@register_meta(aten.adaptive_max_pool3d)
@out_wrapper("out", "indices")
def meta_adaptive_max_pool3d(input, output_size):
    """Meta kernel for adaptive_max_pool3d: returns (out, indices) shapes.

    Accepts a 4D (C, T, H, W) or 5D (N, C, T, H, W) input; the last three
    dims are replaced by `output_size` = (osizeT, osizeH, osizeW).
    """
    ndim = input.ndim
    torch._check(
        ndim in (4, 5),
        lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}",
    )
    for i in range(1, ndim):
        torch._check(
            input.size(i) > 0,
            lambda: (
                f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
                f"but input has sizes {input.shape} with dimension {i} being empty"
            ),
        )

    torch._check(
        len(output_size) == 3,
        lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3",
    )

    dimD = 0
    sizeB = 1
    sizeD = 0

    if ndim == 5:
        sizeB = input.size(0)
        dimD += 1

    sizeD = input.size(dimD)
    osizeT, osizeH, osizeW = output_size

    if ndim == 4:
        out_shape = (sizeD, osizeT, osizeH, osizeW)
    else:
        out_shape = (sizeB, sizeD, osizeT, osizeH, osizeW)  # type: ignore[assignment]

    out = input.new_empty(out_shape)
    indices = input.new_empty(out_shape, dtype=torch.int64)

    return out, indices
@register_meta(aten.adaptive_max_pool3d_backward)
@out_wrapper("grad_input")
def meta_adaptive_max_pool3d_backward(grad_output, input, indices):
    """Meta kernel for adaptive_max_pool3d backward: grad_input matches `input`."""
    _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward")
    return input.new_empty(input.shape)
@register_meta(aten.repeat_interleave.Tensor)
def meta_repeat_interleave_Tensor(repeats, output_size=None):
    """Meta kernel for repeat_interleave with a tensor of repeats.

    The result length equals the (data-dependent) sum of `repeats`, so an
    explicit `output_size` is required on meta tensors.
    """
    if output_size is None:
        raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
    return repeats.new_empty(output_size)
@register_meta([aten.complex.default, aten.complex.out])
@out_wrapper()
def meta_complex(real, imag):
    """Meta kernel for torch.complex: broadcast shapes, complex result dtype."""
    assert real.dtype.is_floating_point
    assert imag.dtype.is_floating_point
    out_shape = _broadcast_shapes(real.shape, imag.shape)
    return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
@register_meta([aten.nonzero_static.default, aten.nonzero_static.out])
@out_wrapper()
def nonzero_static(self, *, size: int, fill_value: int = -1):
    """Meta kernel for nonzero_static: fixed-size (size, ndim) long tensor."""
    return self.new_empty((size, self.dim()), dtype=torch.long)
@register_meta([aten.index.Tensor, aten._unsafe_index.Tensor])
def meta_index_Tensor(self, indices):
    """Meta kernel for advanced indexing (aten::index): output shape only.

    Mirrors the eager TensorAdvancedIndexing logic: byte/bool masks are
    expanded to long index tensors, all index tensors are broadcast
    together, and the indexed dims are replaced by the broadcast shape.
    """
    torch._check(bool(indices), lambda: "at least one index must be provided")
    # checkIndexTensorTypes and expandTensors
    result: List[Optional[Tensor]] = []
    for i, index in enumerate(indices):
        if index is not None:
            torch._check(
                index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
                lambda: "tensors used as indices must be long, int, byte or bool tensors",
            )
            if index.dtype in [torch.int8, torch.bool]:
                # A mask consumes as many dims of self as it has itself;
                # each of those dims contributes one long index column.
                nonzero = index.nonzero()
                k = len(result)
                torch._check(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                )
                for j in range(index.ndim):
                    torch._check(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                    )
                    result.append(nonzero.select(1, j))
            else:
                result.append(index)
        else:
            result.append(index)
    indices = result
    torch._check(
        len(indices) <= self.ndim,
        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
    )
    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null Tensors so len(indices) == self.ndim
    while len(indices) < self.ndim:
        indices.append(None)

    # hasContiguousSubspace: true if all non-null index tensors are adjacent.
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # transposeToFront: non-adjacent advanced indices are permuted to the
    # front, matching NumPy's mixed basic/advanced indexing semantics.
    if not has_contiguous_subspace:
        dims = []
        transposed_indices = []
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        self = self.permute(dims)
        indices = transposed_indices

    # AdvancedIndex: indexed dims collapse into the broadcast index shape;
    # plain dims before/after the indexed subspace keep their sizes.
    before_shape: List[int] = []
    after_shape: List[int] = []
    replacement_shape: List[int] = []
    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                after_shape.append(self.shape[dim])
            else:
                before_shape.append(self.shape[dim])
        else:
            replacement_shape = list(index.shape)
    return self.new_empty(before_shape + replacement_shape + after_shape)
@register_meta([aten.convolution_backward.default])
def meta_convolution_backward(
    grad_output_,
    input_,
    weight_,
    bias_sizes_opt,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    """Meta kernel for convolution_backward: shapes mirror the fwd operands.

    `output_mask` selects which of (grad_input, grad_weight, grad_bias)
    are materialized; unselected entries are None.
    """
    backend_grad_input = None
    backend_grad_weight = None
    backend_grad_bias = None

    if output_mask[0]:
        backend_grad_input = grad_output_.new_empty(input_.size())
    if output_mask[1]:
        backend_grad_weight = grad_output_.new_empty(weight_.size())
    if output_mask[2]:
        backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)

    return (backend_grad_input, backend_grad_weight, backend_grad_bias)
@register_meta([aten.addbmm.default, aten.addbmm.out])
@out_wrapper()
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
    """Meta kernel for addbmm: result is (batch1.size(1), batch2.size(2))."""
    dim1 = batch1.size(1)
    dim2 = batch2.size(2)
    # self broadcasts against the matmul result.
    self = self.expand((dim1, dim2))
    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
    torch._check(
        batch1.size(0) == batch2.size(0),
        lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
    )
    torch._check(
        batch1.size(2) == batch2.size(1),
        lambda: (
            f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
            f"and {batch2.size(1)}x{batch2.size(2)})"
        ),
    )
    torch._check(
        self.size(0) == dim1 and self.size(1) == dim2,
        lambda: "self tensor does not match matmul output shape",
    )
    return self.new_empty(self.size())
@register_meta([aten._fused_adam_.default, aten._fused_adamw_.default])
3154
def meta__fused_adam_(
3172
for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
3174
isinstance(l, List),
3175
lambda: f"exponent must be a tensor list but got {type(l)}",
3179
@register_meta([aten._fused_adam.default])
3180
def meta__fused_adam(
3198
for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
3200
isinstance(l, List),
3201
lambda: f"exponent must be a tensor list but got {type(l)}",
3204
def empty_like_list(tensor_list):
3205
return [torch.empty_like(t) for t in tensor_list]
3208
empty_like_list(self),
3209
empty_like_list(grads),
3210
empty_like_list(exp_avgs),
3211
empty_like_list(exp_avg_sqs),
3212
empty_like_list(max_exp_avg_sqs),
3216
@register_meta([aten._int_mm])
@out_wrapper()
def meta__int_mm(a, b):
    """Meta kernel for _int_mm: int8 x int8 matmul producing int32."""
    torch._check(a.dim() == 2, lambda: "a must be a 2D tensor")
    torch._check(b.dim() == 2, lambda: "b must be a 2D tensor")
    torch._check(
        a.dtype is torch.int8,
        lambda: f"expected self to be int8, got {a.dtype}",
    )
    torch._check(
        b.dtype is torch.int8,
        lambda: f"expected mat2 to be int8, got {b.dtype}",
    )
    torch._check(
        a.size(1) == b.size(0),
        lambda: (
            f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} "
            f"and {b.size(0)}x{b.size(1)})"
        ),
    )
    return a.new_empty((a.size(0), b.size(1)), dtype=torch.int32)
@register_meta([aten._convert_weight_to_int4pack])
3240
def meta__convert_weight_to_int4pack(w, inner_k_tiles):
3241
torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
3243
w.dtype is torch.uint8,
3244
lambda: f"expected w to be uint8, got {w.dtype}",
3251
k // (inner_k_tiles * 16),
3259
@register_meta([aten._weight_int4pack_mm])
3260
def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
3261
torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
3262
torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
3264
x.dtype in [torch.float32, torch.float16, torch.bfloat16],
3265
lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
3268
w.dtype is torch.int32,
3269
lambda: f"expected w to be int32, got {w.dtype}",
3271
return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
3274
@register_meta([aten._weight_int8pack_mm])
def meta__weight_int8pack_mm(x, w, q_scales):
    """Meta kernel for _weight_int8pack_mm: (x.size(0), w.size(0)) in x's dtype."""
    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
    torch._check(
        x.dtype in [torch.float32, torch.float16, torch.bfloat16],
        lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
    )
    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
    torch._check(
        w.dtype is torch.int8,
        lambda: f"expected w to be int8, got {w.dtype}",
    )
    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
    """Meta kernel for _cdist_forward: broadcast batch dims + (r1, r2).

    Fix: the two dtype error messages used `{x1.dtype}` / `{x2.dtype}`
    inside plain (non-f) strings, so the placeholders were printed
    literally; they are now f-strings.
    """
    torch._check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    torch._check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    torch._check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    torch._check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    torch._check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    torch._check(p >= 0, lambda: "cdist only supports non-negative p values")
    torch._check(
        compute_mode in (None, 1, 2),
        lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
    )
    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
@register_meta(aten._cdist_backward)
@out_wrapper()
def meta_cdist_backward(grad, x1, x2, p, cdist):
    """Meta kernel for _cdist_backward: grad wrt x1, batch dims expanded."""
    c1 = x1.shape[-1]
    r1 = x1.shape[-2]
    r2 = x2.shape[-2]
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    expand_batch_portion = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    tensor1_expand_size = expand_batch_portion.copy()
    tensor1_expand_size.extend([r1, c1])
    batch_product = math.prod(expand_batch_portion)
    # Degenerate (empty) inputs produce an all-zero gradient.
    if r1 == 0 or r2 == 0 or c1 == 0 or batch_product == 0:
        return torch.zeros_like(x1)
    if tensor1_expand_size != list(x1.shape):
        x1 = x1.expand(tensor1_expand_size)
    return torch.empty_like(x1, memory_format=torch.contiguous_format)
@register_meta(aten._embedding_bag.default)
3349
def meta_embedding_bag(
3353
scale_grad_by_freq=False,
3356
per_sample_weights=None,
3357
include_last_offset=False,
3361
indices.dtype in (torch.long, torch.int),
3362
lambda: f"expected indices to be long or int, got {indices.dtype}",
3365
offsets.dtype in (torch.long, torch.int),
3366
lambda: f"expected offsets to be long or int, got {offsets.dtype}",
3369
utils.is_float_dtype(weight.dtype),
3370
lambda: f"expected weight to be floating point type, got {weight.dtype}",
3373
num_bags = offsets.size(0)
3374
if include_last_offset:
3377
lambda: "include_last_offset: numBags should be at least 1",
3381
output = weight.new_empty(num_bags, weight.size(1))
3382
MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
3384
if per_sample_weights is not None:
3387
lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
3390
per_sample_weights.dtype == weight.dtype,
3391
lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
3394
per_sample_weights.ndim == 1,
3395
lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
3398
per_sample_weights.numel() == indices.numel(),
3400
f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
3401
f"to be the same as indices.numel() ({indices.numel()})"
3405
def is_fast_path_index_select_scale(src, scale, output, padding_idx):
3407
is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
3410
def is_fast_path_index_select(src, output, padding_idx):
3412
(src.dtype == torch.float or src.dtype == torch.half)
3413
and src.stride(1) == 1
3414
and output.stride(1) == 1
3418
def is_fast_path(src, scale, output, padding_idx):
3419
if scale is not None:
3420
return is_fast_path_index_select_scale(src, scale, output, padding_idx)
3422
return is_fast_path_index_select(src, output, padding_idx)
3424
if device_hint(offsets) != "cpu":
3425
offset2bag = indices.new_empty(indices.size(0))
3426
bag_size = indices.new_empty(offsets.size())
3427
if mode == MODE_MAX:
3428
max_indices = indices.new_empty(num_bags, weight.size(1))
3430
max_indices = indices.new_empty(0)
3432
fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
3433
if mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum:
3434
offset2bag = offsets.new_empty(indices.size(0))
3436
offset2bag = offsets.new_empty(0)
3437
bag_size = offsets.new_empty(num_bags)
3439
numBags = offsets.shape[0]
3440
if mode == MODE_MAX:
3441
if include_last_offset:
3444
lambda: "include_last_offset: numBags should be at least 1",
3447
max_indices = offsets.new_empty(numBags, weight.shape[1])
3449
max_indices = offsets.new_empty(bag_size.size())
3450
return output, offset2bag, bag_size, max_indices
3453
@register_meta(aten._embedding_bag_forward_only.default)
3454
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
3455
output, offset2bag, bag_size, max_indices = meta_embedding_bag(
3456
weight, indices, offsets, *args
3458
if device_hint(offsets) == "cpu":
3459
bag_size = offsets.new_empty(offsets.size())
3460
return output, offset2bag, bag_size, max_indices
3463
def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
3468
if input.dtype.is_floating_point or input.dtype.is_complex:
3470
elif promote_int_to_long:
3476
@register_meta([aten.nansum.default, aten.nansum.out])
@out_wrapper()
def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
    """Meta kernel for nansum: reduction shape with int-to-long promotion."""
    output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
    dims = utils.reduction_dims(input.shape, dims)
    output_shape = _compute_reduction_shape(input, dims, keepdim)
    return input.new_empty(output_shape, dtype=output_dtype)
@register_meta([aten.median.default, aten.nanmedian.default])
def meta_median(input):
    """Meta kernel for full-tensor median/nanmedian: 0-dim result."""
    output_shape = utils.compute_reduction_output_shape(
        input.shape, tuple(range(input.dim()))
    )
    return input.new_empty(output_shape)
3496
aten.median.dim_values,
3498
aten.nanmedian.dim_values,
3503
@out_wrapper("values", "indices")
3504
def meta_median_mode_dim(input, dim=-1, keepdim=False):
3505
if device_hint(input) == "cuda":
3506
utils.alert_not_deterministic("median CUDA with indices output")
3507
dim = utils.reduction_dims(input.shape, (dim,))
3508
output_shape = _compute_reduction_shape(input, dim, keepdim)
3510
input.new_empty(output_shape),
3511
input.new_empty(output_shape, dtype=torch.long),
3515
@register_meta(aten.logical_not_.default)
def meta_logical_not_(self):
    """Meta kernel for in-place logical_not_: returns self unchanged."""
    return self
@register_meta(aten.repeat.default)
def meta_repeat(self, repeats):
    """Meta kernel for Tensor.repeat: each dim size multiplied by its repeat."""
    torch._check(
        len(repeats) >= self.dim(),
        lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
    )
    # Extra leading repeat entries add new leading dimensions of size 1
    # before multiplying.
    num_new_dimensions = len(repeats) - self.dim()
    padded_size = (1,) * num_new_dimensions + tuple(self.shape)
    target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))]
    return self.new_empty(target_size)
@register_meta(aten.zero_.default)
def meta_zero_(self):
    """Meta kernel for in-place zero_: returns self unchanged."""
    return self
3546
aten.logical_and_.default,
3547
aten.logical_or_.default,
3548
aten.logical_xor_.default,
3551
def meta_binop_inplace(self, other):
3552
if isinstance(other, torch.Tensor):
3553
check_inplace_broadcast(self.shape, other.shape)
3565
def meta_binop_inplace_alpha(self, other, alpha=1):
3566
if isinstance(other, torch.Tensor):
3567
check_inplace_broadcast(self.shape, other.shape)
3571
@register_meta([aten.round.default, aten.round.decimals])
def meta_round(self, **kwargs):
    """Meta kernel for round (with or without decimals): default promotion."""
    return elementwise_meta(
        self, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
def shift_dtype_check(fn_name, self, val):
3580
utils.is_integer_dtype(self.dtype),
3581
lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}",
3583
if isinstance(val, torch.Tensor):
3585
utils.is_integer_dtype(val.dtype),
3586
lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}",
3590
isinstance(val, IntLike),
3591
lambda: f"{fn_name}: Expected shift value to be an int. Got {val}",
3595
@register_meta([aten.__rshift__.Tensor, aten.__rshift__.Scalar])
def meta_rshifts(self, other):
    """Meta kernel for __rshift__: integral-only, default type promotion."""
    shift_dtype_check("rshift", self, other)
    return elementwise_meta(
        self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
@register_meta([aten.__lshift__.Tensor, aten.__lshift__.Scalar])
def meta_lshifts(self, other):
    """Meta kernel for __lshift__: integral-only, default type promotion."""
    shift_dtype_check("lshift", self, other)
    return elementwise_meta(
        self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
@register_meta(aten.zero.default)
def meta_zero(self):
    """Meta kernel for out-of-place zero: fresh tensor of the same shape."""
    return self.new_empty(self.shape)
@register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
def meta_fill_(self, val):
    """Meta kernel for in-place fill_: returns self unchanged."""
    return self
@register_meta([aten.fill.Tensor, aten.fill.Scalar])
def meta_fill(self, val):
    """Meta kernel for out-of-place fill: tensor like self."""
    return torch.empty_like(self)
@register_meta(aten.relu_.default)
def meta_relu_(self):
    """Meta kernel for in-place relu_: returns self unchanged."""
    return self
@register_meta([aten.index_put.default, aten._unsafe_index_put.default])
def meta_index_put(self, indices, values, accumulate=False):
    """Meta kernel for index_put: result has self's shape and layout."""
    return torch.empty_like(self)
@register_meta(aten.masked_fill_.Scalar)
def meta_masked_fill_(self, mask, value):
    """Meta kernel for in-place masked_fill_: mask must broadcast to self."""
    check_inplace_broadcast(self.shape, mask.shape)
    return self
@register_meta(aten._masked_scale.default)
def meta__masked_scale(self, mask, scale):
    """Meta kernel for _masked_scale: shape of self, suggested memory format."""
    masked_scale = self.new_empty(self.size()).to(
        memory_format=utils.suggest_memory_format(self)
    )
    return masked_scale
@register_meta(aten.masked_scatter_)
def meta_masked_scatter_(self, mask, source):
    """Meta kernel for in-place masked_scatter_: returns self unchanged.

    Fix: the dtype error message used `{self.dtype}` / `{source.dtype}`
    in a plain (non-f) string, printing the placeholders literally; it is
    now an f-string.
    """
    torch._check(
        mask.dtype in (torch.bool, torch.uint8), lambda: "Mask must be bool or uint8"
    )
    torch._check(
        self.dtype == source.dtype,
        lambda: f"masked_scatter: expected self and source to have same "
        f"dtypes but got {self.dtype} and {source.dtype}",
    )
    return self
@register_meta(aten.masked_scatter)
@out_wrapper()
def meta_masked_scatter(self, mask, source):
    """Meta kernel for out-of-place masked_scatter: broadcast self/mask."""
    self, mask = _maybe_broadcast(self, mask)
    output = torch.empty_like(self, memory_format=torch.contiguous_format)
    # Reuse the in-place meta for the dtype/mask validation.
    return meta_masked_scatter_(output, mask, source)
@register_meta(aten.masked_scatter_backward)
def meta_masked_scatter_backward(self, mask, sizes):
    """Meta kernel for masked_scatter backward: grad of `sizes` shape."""
    return self.new_empty(sizes)
@register_meta(aten.index_put_.default)
def meta_index_put_(self, indices, values, accumulate=False):
    """Meta kernel for in-place index_put_: returns self unchanged."""
    return self
@register_meta(aten.alias.default)
def meta_alias(self):
    """Meta kernel for alias: a view with the same shape."""
    return self.view(self.shape)
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
3687
torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
3688
torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
3690
batch1_sizes = batch1.size()
3691
batch2_sizes = batch2.size()
3693
bs = batch1_sizes[0]
3694
contraction_size = batch1_sizes[2]
3695
res_rows = batch1_sizes[1]
3696
res_cols = batch2_sizes[2]
3697
output_size = (bs, res_rows, res_cols)
3700
batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
3701
lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
3702
f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
3707
output = batch2.new_empty(output_size)
3709
if not is_bmm and self_baddbmm is not None:
3710
torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
3712
self_baddbmm.size() == output_size,
3713
lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
3719
@register_meta(aten.bmm.default)
def meta_bmm(self, mat2):
    """Meta kernel for bmm: defers to the shared baddbmm/bmm shape logic."""
    return common_meta_baddbmm_bmm(self, mat2, True)
if r != 0 and (bool(r < 0) != bool(y < 0)):
3734
def pooling_output_shape_pad_lr(
    inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
):
    """Compute a pooled output length with asymmetric left/right padding.

    Uses round-toward-negative division (div_rtn) as in ATen's Pool.h; in
    ceil mode the extra (stride - 1) ensures rounding up, and the final
    clamp guarantees the last window starts inside the (left-padded) input.
    """
    outputSize = (
        div_rtn(
            inputSize
            + pad_l
            + pad_r
            - dilation * (kernelSize - 1)
            - 1
            + (stride - 1 if ceil_mode else 0),
            stride,
        )
        + 1
    )
    if ceil_mode:
        # ensure that the last pooling starts inside the image
        if (outputSize - 1) * stride >= inputSize + pad_l:
            outputSize -= 1
    return outputSize
def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
    """Compute a pooled output length with symmetric padding.

    Validates stride/padding, then delegates to pooling_output_shape_pad_lr
    with equal left and right pads.
    """
    torch._check(stride != 0, lambda: "stride should not be zero")
    torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
    torch._check(
        pad <= ((kernelSize - 1) * dilation + 1) // 2,
        lambda: (
            f"pad should be at most half of effective kernel size, but got pad={pad}, "
            f"kernel_size={kernelSize} and dilation={dilation}"
        ),
    )
    return pooling_output_shape_pad_lr(
        inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
    )
def pool2d_shape_check(
3794
nOutputPlane = nInputPlane
3798
lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
3802
lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}",
3805
dilationH > 0 and dilationW > 0,
3806
lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
3809
valid_dims = input.size(1) != 0 and input.size(2) != 0
3811
if memory_format == torch.channels_last:
3813
ndim == 4 and valid_dims and input.size(3) != 0,
3814
lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
3815
" with optional 0 dim batch size for input, but got: {input.size()}",
3819
(ndim == 3 and input.size(0) != 0 and valid_dims)
3820
or (ndim == 4 and valid_dims and input.size(3) != 0),
3821
lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
3825
kW // 2 >= padW and kH // 2 >= padH,
3826
lambda: "pad should be smaller than or equal to half of kernel size, but got "
3827
f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
3831
outputWidth >= 1 and outputHeight >= 1,
3832
lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
3833
f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
3834
"Output size is too small",
3838
def pool3d_shape_check(
3860
check_input_size: bool = False,
3865
kT > 0 and kW > 0 and kH > 0,
3867
f"kernel size should be greater than zero, but got "
3868
f"kT: {kT}, kH: {kH}, kW: {kW}"
3872
dT > 0 and dW > 0 and dH > 0,
3874
f"stride should be greater than zero, but got "
3875
f"dT: {dT}, dH: {dH}, dW: {dW}"
3879
dilationT > 0 and dilationW > 0 and dilationH > 0,
3881
f"dilation should be greater than zero, but got "
3882
f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}"
3888
lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}",
3891
for i in range(ndim):
3892
if ndim == 5 and i == 0:
3898
f"{fn_name}: Expected input's non-batch dimensions to have positive length,"
3899
f" but input has a shape of {input.shape}"
3900
f" and non-batch dimension {input.size(i)} has length zero!"
3904
if check_input_size:
3906
itime >= kT and iheight >= kH and iwidth >= kW,
3908
f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than "
3909
f"kernel size (kT: {kT} kH: {kH} kW: {kW})"
3914
kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH,
3916
f"pad should be smaller than or equal to half of kernel size, but got "
3917
f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}"
3922
otime >= 1 and owidth >= 1 and oheight >= 1,
3924
f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). "
3925
f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). "
3926
f"Output size is too small"
3931
def max_pool3d_backward_shape_check(
3982
check_dim_size(grad_output, ndim, ndim - 4, nslices)
3983
check_dim_size(grad_output, ndim, ndim - 3, otime)
3984
check_dim_size(grad_output, ndim, ndim - 2, oheight)
3985
check_dim_size(grad_output, ndim, ndim - 1, owidth)
3987
check_dim_size(indices, ndim, ndim - 4, nslices)
3988
check_dim_size(indices, ndim, ndim - 3, otime)
3989
check_dim_size(indices, ndim, ndim - 2, oheight)
3990
check_dim_size(indices, ndim, ndim - 1, owidth)
3993
def avg_pool3d_backward_shape_check(
3995
grad_output: Tensor,
4041
check_dim_size(grad_output, ndim, ndim - 4, nslices)
4042
check_dim_size(grad_output, ndim, ndim - 3, otime)
4043
check_dim_size(grad_output, ndim, ndim - 2, oheight)
4044
check_dim_size(grad_output, ndim, ndim - 1, owidth)
4047
def max_pool2d_checks_and_compute_shape(
4056
def unpack(name, val):
4059
lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
4062
W = H if len(val) == 1 else val[1]
4065
kH, kW = unpack("kernel_size", kernel_size)
4068
len(stride) in [0, 1, 2],
4069
lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
4071
if len(stride) == 0:
4074
dH, dW = unpack("stride", stride)
4076
padH, padW = unpack("padding", padding)
4077
dilationH, dilationW = unpack("dilation", dilation)
4078
nInputPlane = input.size(-3)
4079
inputHeight = input.size(-2)
4080
inputWidth = input.size(-1)
4082
memory_format = utils.suggest_memory_format(input)
4083
if memory_format == torch.channels_last:
4086
lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
4088
elif memory_format == torch.contiguous_format:
4090
input.dim() in [3, 4],
4091
lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
4096
lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
4099
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
4100
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
4120
return nInputPlane, outputHeight, outputWidth
4123
@register_meta(aten.max_pool2d_with_indices_backward.default)
4124
def meta_max_pool2d_with_indices_backward(
4138
) = max_pool2d_checks_and_compute_shape(
4139
self, kernel_size, stride, padding, dilation, ceil_mode
4143
self.dtype == grad_output.dtype,
4144
lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
4147
nOutputPlane = nInputPlane
4150
def _check_dim_size(t):
4151
check_dim_size(t, ndim, ndim - 3, nOutputPlane)
4152
check_dim_size(t, ndim, ndim - 2, outputHeight)
4153
check_dim_size(t, ndim, ndim - 1, outputWidth)
4155
_check_dim_size(grad_output)
4156
_check_dim_size(indices)
4158
memory_format = utils.suggest_memory_format(self)
4163
memory_format=memory_format,
4167
@register_meta(aten.max_pool2d_with_indices.default)
4168
def meta_max_pool2d_with_indices(
4180
) = max_pool2d_checks_and_compute_shape(
4181
input, kernel_size, stride, padding, dilation, ceil_mode
4184
nbatch = input.size(-4) if input.dim() == 4 else 1
4185
memory_format = utils.suggest_memory_format(input)
4186
if input.dim() == 3:
4187
size = [nInputPlane, outputHeight, outputWidth]
4189
size = [nbatch, nInputPlane, outputHeight, outputWidth]
4194
device=input.device,
4195
memory_format=memory_format,
4200
device=input.device,
4201
memory_format=memory_format,
4206
@register_meta(aten.fractional_max_pool2d.default)
4207
def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
4209
self.ndim in (3, 4),
4210
lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self.ndim}",
4214
for d in range(ndim - 3, ndim):
4217
f"fractional_max_pool2d: Expected input to have non-zero "
4218
f" size for non-batch dimenions, but got {self.size()} with dimension {d} empty",
4223
len(kernel_size) == 2,
4224
lambda: "fractional_max_pool2d: kernel_size must"
4225
"either be a single int or tuple of Ints",
4228
len(output_size) == 2,
4229
lambda: "fractional_max_pool2d: output_size must "
4230
"either be a single int or tuple of Ints",
4233
input_channels = self.size(-3)
4234
input_height = self.size(-2)
4235
input_width = self.size(-1)
4237
input_batch = self.size(0)
4242
self.dtype == random_samples.dtype,
4243
lambda: "Expect _random_samples to have the same dtype as input",
4246
random_samples.ndim == 3,
4247
lambda: f"Expect _random samples to have 3 dimensions got, {random_samples.ndim}",
4250
n = random_samples.size(0)
4251
c = random_samples.size(1)
4252
d = random_samples.size(2)
4255
"Expect _random_samples.size(0) no less then input batch size.",
4258
c == input_channels,
4259
lambda: "Expect _random_samples.size(1) equals to input channel size.",
4261
torch._check(d == 2, lambda: f"Expect _random_samples.size(2) equals to 2 got {d}.")
4264
output_size[0] + kernel_size[0] - 1 <= input_height,
4265
lambda: f"fractional_max_pool2d: kernel height {kernel_size[0]} is too large relative to input height {input_height}",
4268
output_size[1] + kernel_size[1] - 1 <= input_width,
4269
lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}",
4273
size = [input_batch, input_channels, output_size[0], output_size[1]]
4275
size = [input_channels, output_size[0], output_size[1]]
4291
@register_meta(aten.max_unpool2d)
4293
def meta_max_unpool2d(self, indices, output_size):
4294
utils.alert_not_deterministic("max_unpooling2d_forward_out")
4297
indices.dtype == torch.int64,
4298
lambda: f"elements in indices should be type int64 but got: {indices.dtype}",
4301
len(output_size) == 2,
4303
f"There should be exactly two elements (height, width) in output_size, "
4304
f"but got {len(output_size)} elements."
4308
oheight, owidth = output_size
4311
self.ndim in (3, 4),
4313
f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
4314
f"but got a tensor with {self.ndim} dimensions."
4318
self.shape == indices.shape,
4320
f"Expected shape of indices to be same as that of the input tensor ({self.shape}) "
4321
f"but got indices tensor with shape: {indices.shape}"
4325
for i in range(1, self.ndim):
4329
f"max_unpooling2d(): "
4330
f"Expected input to have non-zero size for non-batch dimensions, "
4331
f"but got {self.shape} with dimension {i} being empty."
4335
self = self.contiguous()
4338
nchannels = self.size(0)
4339
result = self.new_empty((nchannels, oheight, owidth))
4341
nbatch = self.size(0)
4342
nchannels = self.size(1)
4343
result = self.new_empty((nbatch, nchannels, oheight, owidth))
4348
def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, fn_name):
4350
indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
4353
input.ndim in (4, 5),
4354
lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.",
4357
len(output_size) == 3,
4359
f"There should be exactly three elements (depth, height, width) in output_size, "
4360
f"but got {len(output_size)} elements."
4365
lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.",
4369
lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.",
4372
input.shape == indices.shape,
4374
f"Expected shape of indices to be same as that of the input tensor ({input.shape}) "
4375
f"but got indices tensor with shape: {indices.shape}"
4379
for i in range(1, input.ndim):
4384
f"Expected input to have non-zero size for non-batch dimensions, "
4385
f"but got {input.shape} with dimension {i} being empty."
4390
stride[0] > 0 and stride[1] > 0 and stride[2] > 0,
4391
lambda: f"strides should be greater than zero, but got stride: {stride}",
4395
@register_meta(aten.max_unpool3d)
4397
def meta_max_unpool3d(self, indices, output_size, stride, padding):
4398
utils.alert_not_deterministic("max_unpooling3d_forward_out")
4400
_max_unpooling3d_shape_check(
4401
self, indices, output_size, stride, padding, "max_unpooling3d()"
4404
self = self.contiguous()
4406
odepth, oheight, owidth = output_size
4409
nchannels = self.size(0)
4410
result = self.new_empty((nchannels, odepth, oheight, owidth))
4412
nbatch = self.size(0)
4413
nchannels = self.size(1)
4414
result = self.new_empty((nbatch, nchannels, odepth, oheight, owidth))
4419
@register_meta(aten.max_pool3d_with_indices)
4420
@out_wrapper("out", "indices")
4421
def meta_max_pool3d_with_indices(
4430
len(kernel_size) in (1, 3),
4431
lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
4434
kH = kT if len(kernel_size) == 1 else kernel_size[1]
4435
kW = kT if len(kernel_size) == 1 else kernel_size[2]
4438
not stride or len(stride) in (1, 3),
4439
lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
4441
dT = kT if not stride else stride[0]
4442
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
4443
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
4446
len(padding) in (1, 3),
4447
lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
4450
pH = pT if len(padding) == 1 else padding[1]
4451
pW = pT if len(padding) == 1 else padding[2]
4454
len(dilation) in (1, 3),
4455
lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
4457
dilationT = dilation[0]
4458
dilationH = dilationT if len(dilation) == 1 else dilation[1]
4459
dilationW = dilationT if len(dilation) == 1 else dilation[2]
4462
input.ndim in (4, 5),
4463
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
4466
nbatch = input.size(-5) if input.ndim == 5 else 1
4467
nslices = input.size(-4)
4468
itime = input.size(-3)
4469
iheight = input.size(-2)
4470
iwidth = input.size(-1)
4472
otime = pooling_output_shape(itime, kT, pT, dT, dilationT, ceil_mode)
4473
oheight = pooling_output_shape(iheight, kH, pH, dH, dilationH, ceil_mode)
4474
owidth = pooling_output_shape(iwidth, kW, pW, dW, dilationW, ceil_mode)
4497
"max_pool3d_with_indices()",
4501
input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
4504
input_channels_last_check = input.unsqueeze(0)
4506
not input_channels_last_check.is_contiguous()
4507
) and input_channels_last_check.is_contiguous(
4508
memory_format=torch.channels_last_3d
4510
out_shape = (nslices, otime, oheight, owidth)
4512
out_shape = (nbatch, nslices, otime, oheight, owidth)
4514
out = input.new_empty(out_shape)
4515
indices = input.new_empty(out_shape, dtype=torch.int64)
4518
out = out.to(memory_format=torch.channels_last_3d)
4519
indices = indices.to(memory_format=torch.channels_last_3d)
4524
@register_meta(aten.max_pool3d_with_indices_backward)
4525
@out_wrapper("grad_input")
4526
def meta_max_pool3d_with_indices_backward(
4537
len(kernel_size) in (1, 3),
4538
lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
4541
kH = kT if len(kernel_size) == 1 else kernel_size[1]
4542
kW = kT if len(kernel_size) == 1 else kernel_size[2]
4545
not stride or len(stride) in (1, 3),
4546
lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
4548
dT = kT if not stride else stride[0]
4549
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
4550
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
4553
len(padding) in (1, 3),
4554
lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
4557
pH = pT if len(padding) == 1 else padding[1]
4558
pW = pT if len(padding) == 1 else padding[2]
4561
len(dilation) in (1, 3),
4562
lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
4564
dilationT = dilation[0]
4565
dilationH = dilationT if len(dilation) == 1 else dilation[1]
4566
dilationW = dilationT if len(dilation) == 1 else dilation[2]
4569
input.ndim in (4, 5),
4570
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
4573
nslices = input.size(-4)
4574
itime = input.size(-3)
4575
iheight = input.size(-2)
4576
iwidth = input.size(-1)
4578
otime = grad_output.size(-3)
4579
oheight = grad_output.size(-2)
4580
owidth = grad_output.size(-1)
4582
max_pool3d_backward_shape_check(
4605
"max_pool3d_with_indices_backward()",
4609
input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
4612
input_channels_last_check = input.unsqueeze(0)
4614
not input_channels_last_check.is_contiguous()
4615
) and input_channels_last_check.is_contiguous(
4616
memory_format=torch.channels_last_3d
4619
grad_input = input.new_empty(input.shape)
4622
grad_input = grad_input.to(memory_format=torch.channels_last_3d)
4627
def check_grid_sampler_common(input: Tensor, grid: Tensor):
4629
input.device == grid.device,
4631
f"grid_sampler(): expected input and grid to be on same device, but input "
4632
f"is on {input.device} and grid is on {grid.device}"
4636
input.layout == torch.strided and grid.layout == torch.strided,
4638
f"grid_sampler(): expected input and grid to have torch.strided layout, but "
4639
f"input has {input.layout} and grid has {grid.layout}"
4643
input.shape[0] == grid.shape[0],
4645
f"grid_sampler(): expected grid and input to have same batch size, but got "
4646
f"input with sizes {input.shape} and grid with sizes {grid.shape}"
4650
grid.shape[-1] == input.ndim - 2,
4652
f"grid_sampler(): expected grid to have size {input.ndim - 2} in last "
4653
f"dimension, but got grid with sizes {grid.shape}"
4657
for i in range(2, input.ndim):
4661
f"grid_sampler(): expected input to have non-empty spatial dimensions, "
4662
f"but input has sizes {input.shape} with dimension {i} being empty"
4667
class GridSamplerInterpolation(Enum):
4673
def check_grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: int):
4675
input.ndim == 5 and input.ndim == grid.ndim,
4677
f"grid_sampler(): expected 5D input and grid with same number of "
4678
f"dimensions, but got input with sizes {input.shape}"
4679
f" and grid with sizes {grid.shape}"
4685
and interpolation_mode == GridSamplerInterpolation.BICUBIC.value
4687
lambda: "grid_sampler(): bicubic interpolation only supports 4D input",
4691
@register_meta(aten.grid_sampler_2d_backward.default)
4692
def grid_sampler_2d_backward_meta(
4701
input_requires_grad = output_mask[0]
4702
if input_requires_grad:
4703
grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format)
4706
grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format)
4707
return (grad_input, grad_grid)
4710
@register_meta(aten.grid_sampler_3d)
4719
check_grid_sampler_common(input, grid)
4720
check_grid_sampler_3d(input, grid, interpolation_mode)
4723
out_D = grid.shape[1]
4724
out_H = grid.shape[2]
4725
out_W = grid.shape[3]
4726
return input.new_empty((N, C, out_D, out_H, out_W))
4729
@register_meta(aten.grid_sampler_3d_backward)
4730
@out_wrapper("grad_input", "grad_grid")
4731
def grid_sampler_3d_backward(
4740
check_grid_sampler_common(input, grid)
4741
check_grid_sampler_3d(input, grid, interpolation_mode)
4742
input_requires_grad = output_mask[0]
4743
if input_requires_grad:
4744
grad_input = torch.zeros_like(
4745
input, memory_format=torch.legacy_contiguous_format
4749
grad_grid = torch.empty_like(grid, memory_format=torch.legacy_contiguous_format)
4750
return grad_input, grad_grid
4753
@register_meta([aten.full.default])
4754
def full(size, fill_value, *args, **kwargs):
4755
dtype = kwargs.get("dtype", None)
4757
dtype = utils.get_dtype(fill_value)
4758
kwargs["dtype"] = dtype
4759
return torch.empty(size, *args, **kwargs)
4763
@register_meta(aten.zeros_like.default)
4772
if layout == torch.sparse_coo:
4774
memory_format is None,
4775
lambda: "memory format option is only supported by strided tensors",
4780
dtype=self.dtype if dtype is None else dtype,
4782
device=self.device if device is None else device,
4783
pin_memory=pin_memory,
4787
res.sparse_resize_and_clear_(
4788
self.size(), self.sparse_dim(), self.dense_dim()
4791
res.sparse_resize_and_clear_(self.size(), self.dim(), 0)
4793
res._coalesced_(True)
4795
res = aten.empty_like.default(
4800
pin_memory=pin_memory,
4801
memory_format=memory_format,
4808
@register_meta(aten.select.int)
4809
def meta_select(self, dim, index):
4813
lambda: "select() cannot be applied to a 0-dim tensor.",
4816
dim = dim if dim >= 0 else dim + ndim
4817
size = self.size(dim)
4820
not (-index > size or index >= size),
4821
lambda: f"select(): index {index} out of range for tensor of size "
4822
f"{self.size()} at dimension {dim}",
4825
index = index if index >= 0 else index + size
4827
new_size = list(self.size())
4828
new_stride = list(self.stride())
4830
new_storage_offset = self.storage_offset() + index * new_stride[dim]
4834
return self.as_strided(new_size, new_stride, new_storage_offset)
4837
@register_meta(aten.select_scatter.default)
4838
def meta_select_scatter(self, src, dim, index):
    """Meta for select_scatter: output has exactly self's sizes and strides."""
    result = utils.clone_preserve_strides(self)
    return result
4842
@register_meta(aten.slice_scatter.default)
4843
def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
    """Meta for slice_scatter: output has exactly self's sizes and strides."""
    result = utils.clone_preserve_strides(self)
    return result
4848
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
4849
if dim_post_expr <= 0:
4852
min = -dim_post_expr
4853
max = dim_post_expr - 1
4854
assert not (dim < min or dim > max), f"dim {dim} out of bounds ({min}, {max})"
4856
dim += dim_post_expr
4860
def ensure_nonempty_size(t, dim):
    """Return the size of `t` along `dim`, treating a 0-dim tensor as size 1."""
    if t.dim() == 0:
        # Scalar tensors are handled as if they had a single element per dim.
        return 1
    return t.shape[dim]
4865
def gather_shape_check(self, dim, index):
4866
self_dims = max(self.dim(), 1)
4867
index_dims = max(index.dim(), 1)
4869
self_dims == index_dims,
4870
lambda: "Index tensor must have the same number of dimensions as input tensor",
4872
for i in range(self_dims):
4875
ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
4876
lambda: f"Size does not match at dimension {i} expected index {index.shape}"
4877
+ f" to be smaller than self {self.shape} apart from dimension {dim}",
4881
@register_meta(aten.gather.default)
4882
def meta_gather(self, dim, index, sparse_grad=False):
4883
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
4885
wrapped_dim = maybe_wrap_dim(dim, self.dim())
4886
is_index_empty = guard_size_oblivious(index.numel() == 0)
4887
if not is_index_empty:
4889
index.dtype == torch.long,
4890
lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
4892
gather_shape_check(self, wrapped_dim, index)
4893
return self.new_empty(index.shape)
4897
def get_operator_enum(reduce_, use_new_options=False):
4899
if reduce_ == "sum":
4901
elif reduce_ == "prod":
4902
return "REDUCE_MULTIPLY"
4903
elif reduce_ == "mean":
4904
return "REDUCE_MEAN"
4905
elif reduce_ == "amax":
4906
return "REDUCE_MAXIMUM"
4907
elif reduce_ == "amin":
4908
return "REDUCE_MINIMUM"
4911
lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
4915
if reduce_ == "add":
4917
elif reduce_ == "multiply":
4918
return "REDUCE_MULTIPLY"
4919
torch._check(False, lambda: "reduce argument must be either add or multiply.")
4924
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
4925
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
4927
if guard_size_oblivious(index.numel() != 0):
4929
index.dtype == torch.long,
4930
lambda: f"{method_name}(): Expected dtype int64 for index",
4933
if src_opt is not None:
4935
self.dtype == src_opt.dtype,
4936
lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
4940
def ensure_nonempty_dim(dim):
4945
def scatter_shape_check(self, dim, index, src_opt=None):
4946
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
4948
if guard_size_oblivious(index.numel() == 0):
4951
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
4952
lambda: "Index tensor must have the same number of dimensions as self tensor",
4955
is_wrong_shape = False
4956
self_dims = ensure_nonempty_dim(self.dim())
4959
for d in range(self_dims):
4960
index_d_size = ensure_nonempty_size(index, d)
4963
if index_d_size > ensure_nonempty_size(self, d):
4964
is_wrong_shape = True
4968
if not is_wrong_shape and src_opt is not None:
4969
for d in range(self_dims):
4970
index_d_size = ensure_nonempty_size(index, d)
4971
if index_d_size > ensure_nonempty_size(src_opt, d):
4972
is_wrong_shape = True
4975
if src_opt is not None:
4977
ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
4978
lambda: "Index tensor must have the same number of dimensions as self tensor",
4982
lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
4983
+ f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
4988
lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
4989
+ f" apart from dimension {dim}",
4994
def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
4995
wrapped_dim = maybe_wrap_dim(dim, self.dim())
4996
scatter_gather_dtype_check("scatter", self, index, src)
4997
scatter_shape_check(self, wrapped_dim, index, src)
4998
if reduce_ is not None:
5000
get_operator_enum(reduce_, use_new_options)
5003
@register_meta(aten.scatter_add.default)
5004
def meta_scatter_add(self, dim, index, src):
    """Meta kernel for scatter_add: validate arguments, then allocate an
    uninitialized output with self's shape and dtype."""
    scatter_meta_impl(self, dim, index, src, "add")
    out = self.new_empty(self.shape)
    return out
5009
@register_meta(aten.scatter_add_)
5010
def meta_scatter_add_(self, dim, index, src):
5011
scatter_meta_impl(self, dim, index, src, "add")
5019
aten.scatter.reduce,
5020
aten.scatter.value_reduce,
5024
def meta_scatter(self, dim, index, src_or_value, reduce=None):
    """Meta kernel for scatter: src_or_value may be a Tensor or a scalar.

    Only a Tensor source participates in the shape/dtype checks; a scalar
    source is passed to the checker as None.
    """
    if isinstance(src_or_value, torch.Tensor):
        src = src_or_value
    else:
        src = None
    scatter_meta_impl(self, dim, index, src, reduce)
    return self.new_empty(self.shape)
5033
aten.scatter_.value,
5034
aten.scatter_.reduce,
5035
aten.scatter_.value_reduce,
5038
def meta_scatter_(self, dim, index, src_or_value, reduce=None):
5039
src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
5040
scatter_meta_impl(self, dim, index, src, reduce)
5044
@register_meta([aten._scaled_dot_product_flash_attention])
5045
def meta__scaled_dot_product_flash_attention(
5049
dropout_p: float = 0.0,
5050
is_causal: bool = False,
5051
return_debug_mask: bool = False,
5052
scale: Optional[float] = None,
5054
batch_size = query.size(0)
5055
num_heads = query.size(1)
5056
max_seqlen_batch_q = query.size(2)
5057
head_dim = query.size(3)
5058
max_seqlen_batch_k = key.size(2)
5060
query_t = query.transpose(1, 2)
5061
attention = torch.empty_like(query_t).transpose(1, 2)
5062
logsumexp = torch.empty(
5063
(batch_size, num_heads, max_seqlen_batch_q),
5065
device=query.device,
5068
if return_debug_mask:
5069
blocksize_c = 128 if head_dim > 64 else 256
5070
max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
5071
if max_seqlen_batch_k <= 128:
5073
elif max_seqlen_batch_k <= 256:
5075
debug_mask = torch.empty(
5076
(batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
5078
device=query.device,
5081
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
5095
torch.empty((), dtype=torch.long, device="meta"),
5096
torch.empty((), dtype=torch.long, device="meta"),
5101
@register_meta([aten._scaled_dot_product_cudnn_attention])
5102
def meta__scaled_dot_product_cudnn_attention(
5106
attn_bias: Optional[Tensor],
5107
compute_log_sumexp: bool,
5108
dropout_p: float = 0.0,
5109
is_causal: bool = False,
5110
return_debug_mask: bool = False,
5111
scale: Optional[float] = None,
5117
D_QK = query.size(-1)
5118
D_V = value.size(-1)
5120
res = torch.empty((B, H, S_Q, D_V), dtype=query.dtype, device=query.device)
5121
logsum_exp = torch.empty(
5124
device=query.device,
5128
seed = torch.empty((), dtype=torch.long, device="meta")
5129
offset = torch.empty((), dtype=torch.long, device="meta")
5146
aten._scaled_dot_product_flash_attention_backward,
5149
def meta__scaled_dot_product_flash_backward(
5162
philox_seed: Tensor,
5163
philox_offset: Tensor,
5164
scale: Optional[float] = None,
5166
grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2)
5167
grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2)
5168
grad_v = torch.empty_like(value.transpose(1, 2)).transpose(1, 2)
5169
return grad_q, grad_k, grad_v
5174
aten._scaled_dot_product_flash_attention_for_cpu,
5177
def meta__scaled_dot_product_flash_attention_for_cpu(
5181
dropout_p: float = 0.0,
5182
is_causal: bool = False,
5183
attn_mask: Optional[Tensor] = None,
5184
scale: Optional[float] = None,
5186
batch_size = query.size(0)
5187
num_heads = query.size(1)
5188
max_seqlen_batch_q = query.size(2)
5189
head_dim = query.size(3)
5191
attention = torch.empty_like(query)
5192
logsumexp = torch.empty(
5199
device=query.device,
5209
aten._scaled_dot_product_flash_attention_for_cpu_backward,
5212
def meta__scaled_dot_product_flash_attention_for_cpu_backward(
5221
attn_mask: Optional[Tensor] = None,
5222
scale: Optional[float] = None,
5226
batch_size = query.size(0)
5227
num_heads = query.size(1)
5228
head_dim = query.size(3)
5229
len_q = query.size(2)
5232
grad_q = torch.empty_permuted(
5233
(batch_size, num_heads, len_q, head_dim),
5236
device=query.device,
5238
grad_k = torch.empty_permuted(
5239
(batch_size, num_heads, len_k, head_dim),
5244
grad_v = torch.empty_permuted(
5245
(batch_size, num_heads, len_k, head_dim),
5248
device=value.device,
5251
return grad_q, grad_k, grad_v
5254
@register_meta([aten._scaled_dot_product_efficient_attention])
5255
def meta__scaled_dot_product_efficient_attention(
5259
attn_bias: Optional[Tensor],
5260
compute_log_sumexp: bool,
5262
is_causal: bool = False,
5263
scale: Optional[float] = None,
5265
query = query.transpose(1, 2)
5266
key = key.transpose(1, 2)
5267
value = value.transpose(1, 2)
5272
num_heads = query.size(-2)
5276
res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
5278
logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
5279
logsum_exp = torch.empty(
5280
(B, num_heads, logsumexp_dim),
5282
device=query.device,
5285
res = res.transpose(1, 2)
5288
seed = torch.empty((), dtype=torch.long, device="meta")
5289
offset = torch.empty((), dtype=torch.long, device="meta")
5291
return res, logsum_exp, seed, offset
5296
aten._scaled_dot_product_efficient_attention_backward,
5299
def meta__scaled_dot_product_efficient_backward(
5304
attn_bias: Optional[Tensor],
5307
philox_seed: Tensor,
5308
philox_offset: Tensor,
5310
grad_input_mask: List[bool],
5311
is_causal: bool = False,
5312
scale: Optional[float] = None,
5314
batch_size = query.size(0)
5315
num_heads = query.size(1)
5316
max_q = query.size(2)
5317
head_dim = query.size(3)
5318
head_dim_v = value.size(3)
5322
grad_q = torch.empty_permuted(
5323
(batch_size, num_heads, max_q, head_dim),
5326
device=query.device,
5328
grad_k = torch.empty_permuted(
5329
(batch_size, num_heads, max_k, head_dim),
5334
grad_v = torch.empty_permuted(
5335
(batch_size, num_heads, max_k, head_dim_v),
5338
device=value.device,
5341
if attn_bias is not None and grad_input_mask[3]:
5342
lastDim = attn_bias.size(-1)
5343
lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
5344
new_sizes = list(attn_bias.size())
5345
new_sizes[-1] = lastDimAligned
5346
grad_bias = torch.empty(
5347
new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
5349
grad_bias = grad_bias[..., :lastDim]
5351
return grad_q, grad_k, grad_v, grad_bias
5356
aten._scaled_dot_product_cudnn_attention_backward,
5359
def meta__scaled_dot_product_cudnn_backward(
5366
philox_seed: Tensor,
5367
philox_offset: Tensor,
5375
scale: Optional[float] = None,
5377
grad_q = torch.empty_like(query)
5378
grad_k = torch.empty_like(key)
5379
grad_v = torch.empty_like(value)
5380
return grad_q, grad_k, grad_v
5385
aten._flash_attention_forward,
5388
def meta__flash_attention_forward(
5392
cum_seq_q: Optional[Tensor],
5393
cum_seq_k: Optional[Tensor],
5398
return_debug_mask: bool,
5399
scale: Optional[float] = None,
5400
window_size_left: Optional[int] = None,
5401
window_size_right: Optional[int] = None,
5402
seqused_k: Optional[Tensor] = None,
5403
alibi_slopes: Optional[Tensor] = None,
5409
batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
5410
max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
5411
max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
5412
num_heads = query.size(-2)
5413
head_dim = query.size(-1)
5416
attention = torch.empty_like(query)
5417
logsumexp = torch.empty(
5418
(batch_size, num_heads, max_seqlen_batch_q),
5420
device=query.device,
5423
if return_debug_mask:
5424
blocksize_c = 128 if head_dim > 64 else 256
5425
max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
5426
if max_seqlen_batch_k <= 128:
5428
elif max_seqlen_batch_k <= 256:
5430
debug_mask = torch.empty(
5431
(batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
5433
device=query.device,
5436
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
5442
torch.empty((), dtype=torch.long, device="meta"),
5443
torch.empty((), dtype=torch.long, device="meta"),
5450
aten._flash_attention_backward,
5453
def meta__flash_attention_backward(
5466
philox_seed: Tensor,
5467
philox_offset: Tensor,
5468
scale: Optional[float] = None,
5469
window_size_left: Optional[int] = None,
5470
window_size_right: Optional[int] = None,
5472
grad_query = torch.empty_like(query)
5473
grad_key = torch.empty_like(key)
5474
grad_value = torch.empty_like(value)
5476
return grad_query, grad_key, grad_value
5481
aten._efficient_attention_forward,
5484
def meta__efficient_attention_forward(
5488
bias: Optional[Tensor],
5489
cu_seqlens_q: Optional[Tensor],
5490
cu_seqlens_k: Optional[Tensor],
5491
max_seqlen_q: Optional[int],
5492
max_seqlen_k: Optional[int],
5494
custom_mask_type: int,
5495
compute_log_sumexp: bool = False,
5496
scale: Optional[float] = None,
5497
causal_diagonal: Optional[Tensor] = None,
5498
seqlen_k: Optional[Tensor] = None,
5499
window_size: Optional[int] = None,
5504
num_heads = query.size(-2)
5508
res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
5510
logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
5511
actual_max_seqlen_q = M
5512
if cu_seqlens_q is not None:
5513
assert max_seqlen_q is not None
5514
actual_max_seqlen_q = max_seqlen_q
5515
actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
5517
math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
5519
logsum_exp = torch.empty(
5520
(logsumexp_batch_dim, num_heads, logsumexp_dim),
5522
device=query.device,
5526
seed = torch.empty((), dtype=torch.long, device="meta")
5527
offset = torch.empty((), dtype=torch.long, device="meta")
5529
return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k
5534
aten._efficient_attention_backward,
5537
def meta__efficient_attention_backward(
5542
bias: Optional[Tensor],
5543
cu_seqlens_q: Optional[Tensor],
5544
cu_seqlens_k: Optional[Tensor],
5545
max_seqlen_q: torch.SymInt,
5546
max_seqlen_k: torch.SymInt,
5549
philox_seed: Tensor,
5550
philox_offset: Tensor,
5551
custom_mask_type: int,
5552
bias_requires_grad: bool,
5553
scale: Optional[float] = None,
5554
num_splits_key: Optional[int] = None,
5555
shared_storage_dqdkdv: bool = False,
5557
if shared_storage_dqdkdv:
5559
query.shape[1] == key.shape[1],
5560
lambda: "seqlen must match for `shared_storage_dqdkdv",
5563
query.shape[3] == key.shape[3],
5564
lambda: "embedding dim must match for `shared_storage_dqdkdv",
5566
chunk = torch.empty(
5567
(*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
5569
device=query.device,
5571
grad_query = chunk.select(-3, 0)
5572
grad_key = chunk.select(-3, 1)
5573
grad_value = chunk.select(-3, 2)
5575
grad_query = torch.empty_like(query)
5576
grad_key = torch.empty_like(key)
5577
grad_value = torch.empty_like(value)
5579
if bias is not None:
5580
lastDim = bias.size(-1)
5581
lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
5582
new_sizes = list(bias.size())
5583
new_sizes[-1] = lastDimAligned
5584
grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device)
5585
grad_bias = grad_bias[..., :lastDim]
5587
grad_bias = torch.empty((), device=query.device)
5589
return grad_query, grad_key, grad_value, grad_bias
5592
@register_meta([aten._scaled_mm.default])
5596
scale_a: torch.Tensor,
5597
scale_b: torch.Tensor,
5598
bias: Optional[torch.Tensor] = None,
5599
scale_result: Optional[torch.Tensor] = None,
5600
out_dtype: Optional[torch.dtype] = None,
5601
use_fast_accum: bool = False,
5603
def is_row_major(stride):
5604
return stride[0] > stride[1] and stride[1] == 1
5606
def is_col_major(stride):
5607
return stride[0] == 1 and stride[1] > 1
5609
def is_fp8_type(dtype):
5611
torch.float8_e4m3fn,
5613
torch.float8_e4m3fnuz,
5614
torch.float8_e5m2fnuz,
5618
self.dim() == 2 and mat2.dim() == 2,
5619
lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
5622
is_row_major(self.stride()),
5623
lambda: "self must be row_major",
5626
is_col_major(mat2.stride()),
5627
lambda: "mat2 must be col_major",
5630
self.size(1) % 16 == 0,
5631
lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}",
5634
mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
5635
lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}",
5638
is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
5639
lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
5644
scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
5645
lambda: "Both scale_a and scale_b must be float (fp32) tensors.",
5649
if scale_a.numel() == 1 and scale_b.numel() == 1:
5655
scale_a.dim() == 2 and scale_b.dim() == 2,
5656
lambda: f"For non-tensorwise scaling, scale tensors must be 2D, but got {scale_a.dim()=} and {scale_b.dim()=}",
5660
scale_a.size(0) == m
5661
and scale_a.size(1) == 1
5662
and scale_b.size(0) == 1
5663
and scale_b.size(1) == n
5667
scale_a.is_contiguous() and scale_b.is_contiguous(),
5668
lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.",
5675
"Invalid scaling configuration. "
5676
"For tensorwise scaling, both scales should be scalar. "
5677
f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). "
5678
f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) "
5679
f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})"
5683
_out_dtype = out_dtype if out_dtype is not None else self.dtype
5684
return torch.empty(self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device)
5687
@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
5689
def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
    """Meta kernel for scatter_reduce.two.

    include_self is accepted for signature compatibility but does not affect
    the meta computation, which only needs shapes and dtypes.
    """
    scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
    out = self.new_empty(self.shape)
    return out
5694
@register_meta(aten.scatter_reduce_.two)
5695
def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
5696
scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
5700
@register_meta([aten.multinomial.default, aten.multinomial.out])
5702
def meta_multinomial(input, num_samples, replacement=False, *, generator=None):
5704
0 < input.dim() <= 2,
5705
lambda: f"The probabilty distributions dimensions must be 1 or 2, but got {input.dim()}",
5707
if input.dim() == 1:
5708
return torch.empty(num_samples, dtype=torch.long, device=input.device)
5710
input.size(0), num_samples, dtype=torch.long, device=input.device
5714
def multiply_integers(vs):
5721
def upsample_common_check(input_size, output_size, num_spatial_dims):
5723
len(output_size) == num_spatial_dims,
5724
lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
5726
expected_input_dims = num_spatial_dims + 2
5728
len(input_size) == expected_input_dims,
5729
lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
5733
all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size),
5734
lambda: f"Input and output sizes should be greater than 0, but got "
5735
f"input size {input_size} and output size {output_size}",
5738
nbatch, channels = input_size[:2]
5739
return (nbatch, channels, *output_size)
5743
[aten.upsample_nearest1d.default, aten._upsample_nearest_exact1d.default]
5745
def upsample_nearest1d(input, output_size, scales=None):
5747
input.numel() != 0 or multiply_integers(input.size()[1:]),
5748
lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
5750
full_output_size = upsample_common_check(
5751
input.size(), output_size, num_spatial_dims=1
5753
return input.new_empty(full_output_size).to(
5754
memory_format=utils.suggest_memory_format(input)
5759
[aten.upsample_nearest2d.default, aten._upsample_nearest_exact2d.default]
5761
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
5763
input.numel() != 0 or multiply_integers(input.size()[1:]),
5764
lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
5766
full_output_size = upsample_common_check(
5767
input.size(), output_size, num_spatial_dims=2
5769
output = input.new_empty(full_output_size)
5772
memory_format = utils.suggest_memory_format(input)
5775
_, n_channels, _, _ = input.shape
5776
if input.device.type == "cuda" and n_channels < 4:
5777
memory_format = torch.contiguous_format
5779
output = output.contiguous(memory_format=memory_format)
5786
aten.upsample_nearest2d_backward.default,
5787
aten._upsample_nearest_exact2d_backward.default,
5790
def upsample_nearest2d_backward(
5791
grad_output: Tensor,
5792
output_size: Sequence[Union[int, torch.SymInt]],
5793
input_size: Sequence[Union[int, torch.SymInt]],
5794
scales_h: Optional[float] = None,
5795
scales_w: Optional[float] = None,
5797
full_output_size = upsample_common_check(
5798
input_size, output_size, num_spatial_dims=2
5801
grad_output.ndim == 4,
5802
lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
5806
grad_output.size(i) == full_output_size[i],
5808
f"Expected grad_output to have the same shape as output;"
5809
f" output.size({i}) = {full_output_size[i]}"
5810
f" but got grad_output.size({i}) = {grad_output.size(i)}"
5814
return grad_output.new_empty(input_size).to(
5815
memory_format=utils.suggest_memory_format(grad_output)
5820
[aten.upsample_nearest3d.default, aten._upsample_nearest_exact3d.default]
5822
def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
5824
input.numel() != 0 or multiply_integers(input.size()[1:]),
5825
lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
5827
full_output_size = upsample_common_check(
5828
input.size(), output_size, num_spatial_dims=3
5830
return input.new_empty(full_output_size).to(
5831
memory_format=utils.suggest_memory_format(input)
5840
aten.sort.values_stable,
5843
def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=None):
5844
v, i = torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)
5845
if values is not None and indices is not None:
5846
assert isinstance(values, TensorLike)
5847
assert isinstance(indices, TensorLike)
5851
out_stride = v.stride()
5852
values = _maybe_resize_out(values, out_shape)
5853
indices = _maybe_resize_out(indices, out_shape)
5854
values.as_strided_(out_shape, out_stride)
5855
indices.as_strided_(out_shape, out_stride)
5856
_safe_copy_out(copy_from=v, copy_to=values)
5857
_safe_copy_out(copy_from=i, copy_to=indices)
5858
return values, indices
5862
def rnn_cell_checkSizes(
5870
torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
5872
input_gates.shape == hidden_gates.shape,
5873
lambda: f"{input_gates.shape} != {hidden_gates.shape}",
5875
gates_size = input_gates.size(1)
5876
if input_bias is not None:
5877
torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
5879
input_bias.numel() == gates_size,
5880
lambda: f"{input_bias.numel()} != {gates_size}",
5883
input_bias.shape == hidden_bias.shape,
5884
lambda: f"{input_bias.shape} != {hidden_bias.shape}",
5886
torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
5887
expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
5889
prev_hidden.numel() == expected_prev_hidden_numel,
5890
lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
5894
x.device == input_gates.device
5895
for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
5897
lambda: "expected all inputs to be same device",
5901
@register_meta(aten._thnn_fused_lstm_cell.default)
5902
def _thnn_fused_lstm_cell_meta(
5909
rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
5910
workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format)
5911
hy = torch.empty_like(cx, memory_format=torch.contiguous_format)
5912
cy = torch.empty_like(cx, memory_format=torch.contiguous_format)
5913
return (hy, cy, workspace)
5916
@register_meta(aten._cudnn_rnn.default)
5935
is_input_packed = len(batch_sizes) != 0
5937
seq_length = len(batch_sizes)
5938
mini_batch = batch_sizes[0]
5939
batch_sizes_sum = input.shape[0]
5941
seq_length = input.shape[1] if batch_first else input.shape[0]
5942
mini_batch = input.shape[0] if batch_first else input.shape[1]
5943
batch_sizes_sum = -1
5945
num_directions = 2 if bidirectional else 1
5946
out_size = proj_size if proj_size != 0 else hidden_size
5948
out_shape = [batch_sizes_sum, out_size * num_directions]
5951
[mini_batch, seq_length, out_size * num_directions]
5953
else [seq_length, mini_batch, out_size * num_directions]
5955
output = input.new_empty(out_shape)
5957
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
5959
cy = torch.empty(0, device=input.device)
5961
cy = cx.new_empty(cell_shape)
5963
hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
5966
reserve_shape = 0 if train else 0
5967
reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
5969
return output, hy, cy, reserve, weight_buf
5972
@register_meta(aten.mkldnn_rnn_layer.default)
5973
def mkldnn_rnn_layer(
5991
seq_length = input.shape[1] if batch_first else input.shape[0]
5992
mini_batch = input.shape[0] if batch_first else input.shape[1]
5993
output_chanels = hidden_size
5995
[mini_batch, seq_length, output_chanels]
5997
else [seq_length, mini_batch, output_chanels]
5999
output = input.new_empty(out_shape)
6001
hy = torch.empty(0, device=input.device)
6003
hy = hx_.new_empty(hx_.shape)
6005
cy = torch.empty(0, device=input.device)
6007
cy = cx_.new_empty(cx_.shape)
6008
workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
6009
return output, hy, cy, workspace
6012
def zero_numel_check_dims(self, dim, fn_name):
6015
dim == 0 or dim == -1,
6016
lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
6020
self.size(dim) != 0,
6021
lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
6026
def check_argmax_argmin(name, self, dim):
6028
dim = maybe_wrap_dim(dim, self.dim())
6029
zero_numel_check_dims(self, dim, name)
6033
lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
6037
@register_meta([aten.argmax.default, aten.argmin.default])
6038
def argmax_argmin_meta(self, dim=None, keepdim=False):
    """Shared meta kernel for aten.argmax and aten.argmin.

    Validates the reduction dim, then returns an empty int64 tensor with the
    reduced shape (keepdim preserved).
    """
    # NOTE(review): the op name forwarded for error reporting is always
    # "argmax", even when this meta kernel runs for argmin — error text may
    # name the wrong op. Confirm against the dispatcher registration.
    check_argmax_argmin("argmax", self, dim)
    reduce_dims = utils.reduction_dims(self.shape, None if dim is None else (dim,))
    out_shape = _compute_reduction_shape(self, reduce_dims, keepdim)
    # Both ops return indices, hence int64 regardless of input dtype.
    return self.new_empty(out_shape, dtype=torch.int64)
6045
@register_meta(aten.scalar_tensor.default)
6046
def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
6048
(), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
6052
@register_meta(aten.topk.default)
6053
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
6055
dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
6056
sliceSize = 1 if self.dim() == 0 else self.size(dim)
6057
torch._check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")
6059
topKSize = list(self.shape)
6060
if len(topKSize) > 0:
6062
return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
6065
@register_meta([aten.kthvalue.default, aten.kthvalue.values])
6066
@out_wrapper("values", "indices")
6067
def kthvalue_meta(self, k, dim=-1, keepdim=False):
6068
dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
6069
dimSize = self.size(dim) if self.dim() > 0 else 1
6071
k >= 1 and k <= dimSize,
6072
lambda: f"kthvalue(): selected number k out of range for dimension {dim}",
6075
shape = list(self.shape[:dim] + self.shape[dim + 1 :])
6076
if keepdim and self.dim() > 0:
6077
shape.insert(dim, 1)
6078
return self.new_empty(shape), self.new_empty(shape, dtype=torch.int64)
6081
# "Legacy" contiguous layout — presumably mirrors ATen's c10::MemoryFormat
# default (TODO confirm); used below when allocating LSTM-backward gradient
# buffers with torch.empty_like.
legacy_contiguous_memory_format = torch.contiguous_format
6085
def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
6086
defined_grad = grad_hy if grad_hy is not None else grad_cy
6087
torch._check(defined_grad.dim() == 2, lambda: "")
6088
exp_size = defined_grad.size()
6089
if grad_hy is not None:
6090
torch._check(grad_hy.size() == exp_size, lambda: "")
6091
if grad_cy is not None:
6092
torch._check(grad_cy.size() == exp_size, lambda: "")
6093
torch._check(cx.size() == exp_size, lambda: "")
6094
torch._check(cy.size() == exp_size, lambda: "")
6095
torch._check(workspace.dim() == 2, lambda: "")
6096
torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
6100
@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
6101
def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
6102
if grad_hy is None and grad_cy is None:
6103
return None, None, None
6104
checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
6105
grad_gates = torch.empty_like(
6106
workspace, memory_format=legacy_contiguous_memory_format
6108
grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
6109
grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
6110
return grad_gates, grad_cx, grad_bias
6114
@register_meta(aten.linear_backward.default)
6115
def linear_backward(input_, grad_output_, weight_, output_mask):
6120
grad_input = grad_output_.new_empty(input_.size())
6121
if output_mask[1] or output_mask[2]:
6122
grad_weight = grad_output_.new_empty((grad_output_.size(-1), input_.size(-1)))
6123
grad_bias = grad_output_.new_empty(grad_output_.size(-1))
6124
return (grad_input, grad_weight, grad_bias)
6127
@register_meta(aten.pixel_shuffle.default)
6128
def meta_pixel_shuffle(self, upscale_factor):
6130
len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
6131
), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"
6133
def is_channels_last(ten):
6134
return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
6136
def pick_memory_format():
6137
if is_channels_last(self):
6138
if device_hint(self) == "cuda":
6139
return torch.contiguous_format
6141
return torch.channels_last
6142
elif self.is_contiguous(memory_format=torch.contiguous_format):
6143
return torch.contiguous_format
6144
elif self.is_contiguous(memory_format=torch.preserve_format):
6145
return torch.preserve_format
6147
C = self.shape[-3] // (upscale_factor * upscale_factor)
6148
Hr = self.shape[-2] * upscale_factor
6149
Wr = self.shape[-1] * upscale_factor
6150
out_shape = (*self.shape[:-3], C, Hr, Wr)
6152
out = self.new_empty(out_shape)
6153
out = out.to(memory_format=pick_memory_format())
6157
@register_meta(aten.mkldnn_rnn_layer_backward.default)
6158
def mkldnn_rnn_layer_backward(
6183
diff_x = input.new_empty(input.shape)
6184
diff_hx = hx_.new_empty(hx_.shape)
6185
diff_cx = cx_tmp.new_empty(cx_tmp.shape)
6186
diff_w1 = weight0.new_empty(weight0.shape)
6187
diff_w2 = weight1.new_empty(weight1.shape)
6188
diff_b = weight2.new_empty(weight2.shape)
6189
return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx
6192
@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
6194
def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
6195
return torch.empty_like(
6196
self, dtype=torch.int32 if out_int32 else torch.int64
6200
@register_meta([aten.histc])
6202
def meta_histc(input, bins=100, min=0, max=0):
6204
if device_hint(input) == "cpu":
6206
input.is_floating_point(),
6207
lambda: f"\"histogram_cpu\" not implemented for '{input.dtype}'",
6210
isinstance(bins, IntLike),
6211
lambda: f"{fn_name}: argument 'bins' must be int, not {type(bins)}",
6213
torch._check(bins > 0, lambda: f"{fn_name}: bins must be > 0, but got {bins}")
6215
isinstance(min, Number),
6216
lambda: f"{fn_name}: argument 'min' must be Number, not {type(min)}",
6219
isinstance(max, Number),
6220
lambda: f"{fn_name}: argument 'max' must be Number, not {type(max)}",
6222
torch._check(max >= min, lambda: "{fn_name}: max must be larger than min")
6223
return torch.empty(bins, device=input.device, dtype=input.dtype)
6227
[aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default]
6229
def meta_upsample_bimode2d_aa(
6236
full_output_size = upsample_common_check(
6237
input.size(), output_size, num_spatial_dims=2
6240
input.numel() != 0 or all(size > 0 for size in input.size()[1:]),
6241
lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
6243
return input.new_empty(full_output_size).to(
6244
memory_format=utils.suggest_memory_format(input)
6249
@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
6250
def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
6252
found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor."
6255
inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor."
6258
found_inf.dtype.is_floating_point,
6259
lambda: "found_inf must be a float tensor.",
6262
inv_scale.dtype.is_floating_point,
6263
lambda: "inv_scale must be a float tensor.",
6268
@register_meta([aten.nan_to_num.default, aten.nan_to_num.out])
6270
def nan_to_num(self, nan=None, posinf=None, neginf=None):
    """Meta kernel for aten.nan_to_num.

    The replacement values (nan/posinf/neginf) never change the output shape
    or dtype, so the meta result is simply a fresh tensor shaped like the
    input.
    """
    return self.new_empty(list(self.size()))
6275
@register_meta(torch.ops.aten.transpose_)
6276
def transpose_(self, dim0, dim1):
6285
), f"torch.transpose_: in-place transposition is not supported for {self.layout} layout"
6289
dim0 = maybe_wrap_dim(dim0, ndims)
6290
dim1 = maybe_wrap_dim(dim1, ndims)
6295
size = list(self.size())
6296
stride = list(self.stride())
6298
stride[dim0], stride[dim1] = stride[dim1], stride[dim0]
6299
size[dim0], size[dim1] = size[dim1], size[dim0]
6301
self.as_strided_(size, stride)
6305
@register_meta(torch.ops.aten.t_)
6310
sparse_dim = self.sparse_dim()
6311
dense_dim = self.dense_dim()
6313
sparse_dim <= 2 and dense_dim == 0
6314
), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions"
6318
), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D"
6320
return transpose_(self, 0, 0 if ndims < 2 else 1)
6323
@register_meta(aten.searchsorted)
6325
def meta_searchsorted(
6334
dtype = torch.int32 if out_int32 else torch.int64
6335
if isinstance(self, torch.Tensor):
6336
return torch.empty_like(self, dtype=dtype).contiguous()
6338
return torch.empty((), dtype=dtype, device=sorted_sequence.device)
6341
def _check_for_unsupported_isin_dtype(dtype):
6343
dtype not in [torch.bool, torch.bfloat16, torch.complex128, torch.complex64],
6344
lambda: f"Unsupported input type encountered for isin(): {dtype}",
6348
@register_meta(aten._embedding_bag_backward)
6349
def meta_embedding_bag_backward(
6364
return aten._embedding_bag_sparse_backward(
6377
return meta_embedding_bag_dense_backward(
6391
@register_meta(aten._embedding_bag_dense_backward)
6392
def meta_embedding_bag_dense_backward(
6405
grad.dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64],
6406
lambda: f"Unsupported input type encountered: {grad.dtype}",
6408
MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
6409
if mode == MODE_MAX:
6410
torch._check(maximum_indices is not None)
6411
index_grad_weight = grad.new_empty((num_weights, grad.size(1)))
6412
return index_grad_weight
6415
@register_meta(aten._embedding_bag_per_sample_weights_backward)
6416
def meta_embedding_bag_per_sample_weights_backward(
6425
MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
6426
embedding_features = grad.size(1)
6429
"embedding_bag_backward: per_sample_weights only supported for mode='sum'",
6431
torch._check(grad.dim() == 2)
6432
torch._check(indices.dim() == 1)
6433
num_samples = indices.size(0)
6434
torch._check(weight.dim() == 2)
6435
torch._check(weight.size(1) == embedding_features)
6436
output = grad.new_empty((num_samples,))
6440
@register_meta(aten.isin)
6442
def meta_isin(elements, test_elements, *, assume_unique=False, invert=False):
6444
isinstance(elements, Tensor) or isinstance(test_elements, Tensor),
6445
lambda: "At least one of elements and test_elements must be a Tensor.",
6447
if not isinstance(elements, Tensor):
6448
elements = torch.tensor(elements, device=test_elements.device)
6450
if not isinstance(test_elements, Tensor):
6451
test_elements = torch.tensor(test_elements, device=elements.device)
6453
_check_for_unsupported_isin_dtype(elements.dtype)
6454
_check_for_unsupported_isin_dtype(test_elements.dtype)
6455
return torch.empty_like(elements, dtype=torch.bool)
6458
@register_meta(aten.polygamma)
6460
def meta_polygamma(n: int, self: Tensor) -> Tensor:
6461
torch._check(n >= 0, lambda: "polygamma(n, x) does not support negative n.")
6462
_, result_dtype = elementwise_dtypes(
6464
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
6466
return torch.empty_like(self, dtype=result_dtype)
6469
@register_meta(aten._local_scalar_dense)
6470
def meta_local_scalar_dense(self: Tensor):
    """Meta kernel for aten._local_scalar_dense.

    A meta tensor has shape/dtype but no storage, so there is no concrete
    scalar value to extract — this operation is always an error on meta.
    """
    msg = "Tensor.item() cannot be called on meta tensors"
    raise RuntimeError(msg)
6474
@register_meta(aten._jagged_to_padded_dense_forward.default)
6475
def meta__jagged_to_padded_dense_forward(
6477
offsets: List[Tensor],
6478
max_lengths: List[int],
6479
padding_value: float = 0.0,
6482
assert len(offsets) == 1
6483
assert len(max_lengths) == 1
6485
B = offsets[0].shape[0] - 1
6487
output_shape = (B, S, *values.shape[1:])
6488
return values.new_empty(output_shape)
6491
@register_meta(aten._padded_dense_to_jagged_forward.default)
6492
def meta__padded_dense_to_jagged_forward(
6494
offsets: List[Tensor],
6495
total_L: Optional[int] = None,
6498
assert len(offsets) == 1
6501
assert isinstance(padded, torch._subclasses.FakeTensor)
6502
shape_env = padded.fake_mode.shape_env
6503
assert shape_env is not None
6504
total_L = shape_env.create_unbacked_symint()
6505
torch.fx.experimental.symbolic_shapes._constrain_range_for_size(
6506
total_L, min=0, max=None
6509
output_shape = (total_L, *padded.shape[2:])
6510
return padded.new_empty(output_shape)
6513
def _create_unary_float_meta_func(func):
6514
@register_meta(func)
6517
return elementwise_meta(
6518
x, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
6524
def _create_binary_float_meta_func(func):
6525
@register_meta(func)
6528
return elementwise_meta(
6529
x, y, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
6535
# Register INT_TO_FLOAT unary elementwise meta kernels for the special
# Bessel-family ops.
for _unary_special_op in (
    aten.special_airy_ai,
    aten.special_bessel_y0,
    aten.special_bessel_y1,
    aten.special_modified_bessel_i0,
    aten.special_modified_bessel_i1,
    aten.special_modified_bessel_k0,
    aten.special_modified_bessel_k1,
    aten.special_scaled_modified_bessel_k0,
    aten.special_scaled_modified_bessel_k1,
):
    _create_unary_float_meta_func(_unary_special_op)
del _unary_special_op
6546
# Register INT_TO_FLOAT binary elementwise meta kernels for the orthogonal
# polynomial families (Chebyshev, shifted Chebyshev, Hermite, Laguerre,
# Legendre).
for _binary_special_op in (
    aten.special_chebyshev_polynomial_t,
    aten.special_chebyshev_polynomial_u,
    aten.special_chebyshev_polynomial_v,
    aten.special_chebyshev_polynomial_w,
    aten.special_shifted_chebyshev_polynomial_t,
    aten.special_shifted_chebyshev_polynomial_u,
    aten.special_shifted_chebyshev_polynomial_v,
    aten.special_shifted_chebyshev_polynomial_w,
    aten.special_hermite_polynomial_h,
    aten.special_hermite_polynomial_he,
    aten.special_laguerre_polynomial_l,
    aten.special_legendre_polynomial_p,
):
    _create_binary_float_meta_func(_binary_special_op)
del _binary_special_op
6563
import torch._refs.nn.functional
6564
import torch._refs.special
6568
activate_meta_table = {}
6572
for type in ["meta", "post_autograd", "pre_autograd"]:
6573
registry = global_decomposition_table[type]
6575
for opo in registry:
6576
if opo not in activate_meta_table:
6577
activate_meta_table[opo] = registry[opo]
6579
for op_overload, fn in activate_meta_table.items():
6584
if isinstance(op_overload, torch._ops.HigherOrderOperator):
6586
assert isinstance(op_overload, OpOverload)
6588
op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)
6590
if torch._C._dispatch_has_kernel_for_dispatch_key(
6591
op_overload.name(), "CompositeImplicitAutograd"
6597
if op_overload in global_decomposition_table["meta"]:
6599
f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
6600
"register meta function for it. Instead, we should let the decomposition run and write "
6601
"meta kernels for the base operators."
6603
elif op_overload.is_view:
6611
"aten::empty_strided",
6615
"aten::constant_pad_nd",
6617
"aten::as_strided_scatter",
6622
if "mkldnn::" in op_overload.name():
6623
_meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
6624
elif "mkl::" in op_overload.name():
6625
_meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
6626
elif "onednn::" in op_overload.name():
6627
_meta_lib_dont_use_me_use_register_meta_for_onednn.impl(op_overload, fn)
6628
elif "quantized::" in op_overload.name():
6629
_meta_lib_dont_use_me_use_register_meta_for_quantized.impl(
6633
_meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)