3
from functools import partial
4
from typing import List, Optional, Sequence, Tuple, Union
7
import torch._prims_common as utils
8
from torch import SymBool, SymFloat, Tensor
9
from torch._decomp import (
12
global_decomposition_table,
15
from torch._ops import OpOverload
16
from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
17
from torch._prims_common import (
18
corresponding_complex_dtype,
19
corresponding_real_dtype,
21
ELEMENTWISE_TYPE_PROMOTION_KIND,
23
make_contiguous_strides_for,
27
from torch._prims_common.wrappers import (
28
_maybe_convert_to_dtype,
34
from torch._refs import _broadcast_shapes, _maybe_broadcast
35
from torch.utils import _pytree as pytree
40
_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
45
fn = _convert_out_params(fn)
48
_add_op_to_registry(meta_table, op, fn)
50
pytree.tree_map_(register, op)
58
type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND,
61
_, result_dtype = utils.elementwise_dtypes(
63
type_promotion_kind=type_promotion,
65
args = [_maybe_convert_to_dtype(x, result_dtype) for x in args]
68
args = _maybe_broadcast(*args)
71
return _prim_elementwise_meta(
72
*args, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
76
def toRealValueType(dtype):
    """Map a complex dtype to its real counterpart.

    Non-complex dtypes are returned unchanged. Used by meta kernels whose
    outputs are real-valued views of complex inputs (e.g. _fft_c2r, slogdet).
    """
    from_complex = {
        torch.complex32: torch.half,
        torch.cfloat: torch.float,
        torch.cdouble: torch.double,
    }
    # .get with the dtype itself as default makes this a no-op for real dtypes.
    return from_complex.get(dtype, dtype)
85
def check_inplace_broadcast(self_shape, *args_shape):
86
broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape))
88
broadcasted_shape == self_shape,
89
lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}",
93
@register_meta([aten.linspace, aten.logspace])
95
def meta_linspace_logspace(
102
layout=torch.strided,
106
if isinstance(start, torch.Tensor):
109
lambda: "linspace only supports 0-dimensional start and end tensors",
111
if isinstance(end, torch.Tensor):
114
lambda: "linspace only supports 0-dimensional start and end tensors",
117
if any(isinstance(arg, complex) for arg in (start, end, steps)):
118
default_complex_dtype = utils.corresponding_complex_dtype(
119
torch.get_default_dtype()
122
dtype = default_complex_dtype
125
utils.is_complex_dtype(dtype),
126
lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
129
dtype = dtype or torch.get_default_dtype()
130
assert isinstance(dtype, torch.dtype)
134
isinstance(steps, IntLike),
135
lambda: f"received an invalid combination of arguments - got \
136
({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})",
138
assert isinstance(steps, IntLike)
139
torch._check(steps >= 0, lambda: "number of steps must be non-negative")
146
pin_memory=pin_memory,
147
requires_grad=requires_grad,
151
@register_meta([aten.take.default, aten.take.out])
153
def meta_take(self, index):
156
index.dtype == torch.long,
157
lambda: f"take(): Expected a long tensor for index, but got {index.dtype}",
161
not (self.numel() == 0 and index.numel() != 0),
162
lambda: "take(): tried to take from an empty tensor",
164
return self.new_empty(index.shape)
167
@register_meta([aten.linalg_cross.default, aten.linalg_cross.out])
169
def linalg_cross(self, other, *, dim=-1):
174
lambda: "linalg.cross: inputs must have the same number of dimensions.",
177
self.size(dim) == 3 and other.size(dim) == 3,
179
f"linalg.cross: inputs dimension {dim} must have length 3. "
180
f"Got {self.size(dim)} and {other.size(dim)}"
183
out_shape = _broadcast_shapes(self.shape, other.shape)
184
return self.new_empty(out_shape)
187
@register_meta(aten.linalg_matrix_exp)
189
def linalg_matrix_exp(self):
190
squareCheckInputs(self, "linalg.matrix_exp")
191
checkFloatingOrComplex(self, "linalg.matrix_exp")
192
return torch.empty_like(self, memory_format=torch.contiguous_format)
196
[aten.cummax.default, aten.cummax.out, aten.cummin.default, aten.cummin.out]
198
@out_wrapper("values", "indices")
199
def cummaxmin(self, dim):
200
values = torch.empty(self.shape, device=self.device, dtype=self.dtype)
201
indices = torch.empty(self.shape, device=self.device, dtype=torch.int64)
202
if self.numel() != 0 and self.ndim != 0:
204
maybe_wrap_dim(dim, self.ndim)
205
return values, indices
208
@register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out])
210
def logcumsumexp(self, dim):
212
maybe_wrap_dim(dim, self.ndim)
213
return torch.empty_like(self).contiguous()
217
def _exec_fft(out, self, out_sizes, dim, forward):
219
signal_ndim = len(dim)
220
batch_dims = ndim - signal_ndim
223
dim_permute = list(range(ndim))
225
is_transformed_dim = [False for _ in range(ndim)]
227
is_transformed_dim[d] = True
231
for d in dim_permute:
232
if not is_transformed_dim[d]:
236
dim_permute = left + right
237
batch_end = len(left)
239
self_strides = self.stride()
240
tmp = dim_permute[:batch_end]
241
tmp.sort(key=lambda x: self_strides[x], reverse=True)
242
dim_permute = tmp + dim_permute[batch_end:]
243
input = self.permute(dim_permute)
246
batched_sizes = [-1] + list(input.shape[batch_dims:])
247
input = input.reshape(batched_sizes)
249
batch_size = input.size(0)
250
batched_sizes[0] = batch_size
251
batched_out_sizes = batched_sizes
252
for i in range(len(dim)):
253
batched_out_sizes[i + 1] = out_sizes[dim[i]]
254
out = out.reshape(batched_out_sizes)
257
out_strides = [0 for _ in range(ndim)]
261
out_strides[dim_permute[i]] = batch_numel * out.stride(0)
262
batch_numel *= out_sizes[dim_permute[i]]
264
for i in range(batch_dims, ndim):
265
out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
266
return out.as_strided(out_sizes, out_strides, out.storage_offset())
271
@register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
273
def meta_fft_c2c(self, dim, normalization, forward):
274
assert self.dtype.is_complex
276
out_sizes = self.shape
277
output = self.new_empty(out_sizes)
283
self_strides = self.stride()
284
sorted_dims.sort(key=lambda x: self_strides[x], reverse=True)
285
output = _exec_fft(output, self, out_sizes, sorted_dims, forward)
290
@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
292
def meta_fft_r2c(self, dim, normalization, onesided):
293
assert self.dtype.is_floating_point
294
output_sizes = list(self.size())
298
last_dim_halfsize = (output_sizes[last_dim] // 2) + 1
299
output_sizes[last_dim] = last_dim_halfsize
301
return self.new_empty(
302
output_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
306
@register_meta(aten.randperm.generator_out)
def meta_randperm(n, *, generator=None, out):
    """Meta kernel for randperm.generator_out: shape `out` to (n,) and return it."""
    target_shape = torch.Size([n])
    return _maybe_resize_out(out, target_shape)
311
@register_meta(aten.randperm.default)
312
def meta_randperm_default(
313
n, *, dtype=torch.long, layout=None, device=None, pin_memory=None
316
n, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
320
@register_meta(aten.randint.default)
322
high, size, *, dtype=torch.long, layout=None, device=None, pin_memory=None
325
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
329
@register_meta(aten.randint.low)
341
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
345
@register_meta(aten.rand.default)
346
def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
348
size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
352
@register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
354
def meta_fft_c2r(self, dim, normalization, lastdim):
355
assert self.dtype.is_complex
356
output_sizes = list(self.size())
357
output_sizes[dim[-1]] = lastdim
358
return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))
361
@register_meta(aten.copy_.default)
362
def meta_copy_(self, src, non_blocking=False):
368
if torch._debug_has_internal_overlap(self) == 1:
370
"more than one element of the written-to tensor refers to a single memory location"
373
if isinstance(src, Tensor):
374
intermediate = src.to(self, non_blocking)
375
if self.size() != intermediate.size():
376
aten.expand_copy.default(intermediate, self.size())
380
def inferUnsqueezeGeometry(tensor, dim):
    """Compute the (sizes, strides) of `tensor` after inserting a size-1
    dimension at position `dim`, without touching the tensor itself."""
    sizes = list(tensor.size())
    strides = list(tensor.stride())
    # Stride for the new unit dimension: appended past the end it is 1;
    # otherwise it spans the dimension it is inserted in front of.
    if dim >= tensor.dim():
        unit_stride = 1
    else:
        unit_stride = sizes[dim] * strides[dim]
    sizes.insert(dim, 1)
    strides.insert(dim, unit_stride)
    return sizes, strides
389
@register_meta(aten.unsqueeze_.default)
def meta_unsqueeze_(self, dim):
    """Meta kernel for in-place unsqueeze_.

    Restrides `self` with a new size-1 dimension at `dim` (wrapped against
    dim()+1, since the new dim may be appended at the end). In-place ATen ops
    return `self`; without the final return this kernel would yield None.
    """
    dim = maybe_wrap_dim(dim, self.dim() + 1)
    g_sizes, g_strides = inferUnsqueezeGeometry(self, dim)
    self.as_strided_(g_sizes, g_strides)
    return self
397
@register_meta(aten._sparse_semi_structured_linear)
398
def meta_sparse_structured_linear(
402
bias: Optional[Tensor] = None,
403
_activation_opt: Optional[str] = None,
404
out_dtype: Optional[torch.dtype] = None,
406
output_sizes = list(input.shape)
408
assert weight.size(0) == bias.size(0), "output size mismatch"
409
assert weight.size(1) == input.size(-1) / 2
410
output_sizes[-1] = weight.size(0)
416
assert len(input.shape) == 2, "we can only handle the squashed input case"
417
transposed_strides = (1, input.size(0))
419
if out_dtype is not None:
421
input.dtype == torch.int8 and out_dtype == torch.int32
422
), "out_dtype is only supported for i8i8->i32 linear operator"
423
output = input.new_empty(
425
dtype=input.dtype if out_dtype is None else out_dtype,
426
).as_strided(output_sizes, transposed_strides)
431
@register_meta(aten._cslt_sparse_mm)
432
def meta__cslt_sparse_mm(
433
compressed_A: torch.Tensor,
434
dense_B: torch.Tensor,
435
bias: Optional[Tensor] = None,
436
alpha: Optional[Tensor] = None,
437
out_dtype: Optional[torch.dtype] = None,
438
transpose_result: bool = False,
440
assert dense_B.dtype in {
445
}, "_cslt_sparse_mm only supports fp16, bf16, and int8"
446
assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
447
assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"
449
is_int8_input_type = compressed_A.dtype == torch.int8
450
compression_factor = 10 if is_int8_input_type else 9
453
m = (compressed_A.numel() * 16) // (compression_factor * k)
455
assert m == bias.size(0)
457
if out_dtype is not None:
458
assert is_int8_input_type and out_dtype in {
462
}, "out_dtype is only supported for i8i8->fp16, bf16, or i32 matmul"
463
output_shape = (n, m) if transpose_result else (m, n)
464
result = dense_B.new_empty(output_shape, dtype=out_dtype)
468
@register_meta(aten.index_reduce.default)
469
def meta_index_reduce(
473
source: torch.Tensor,
476
include_self: bool = True,
478
return torch.empty_like(self, memory_format=torch.contiguous_format)
481
@register_meta(aten.index_reduce_.default)
482
def meta_index_reduce_(
486
source: torch.Tensor,
489
include_self: bool = True,
496
@register_meta(aten.index_select.default)
497
def meta_index_select(self, dim, index):
498
result_size = list(self.size())
500
result_size[dim] = index.numel()
501
return self.new_empty(result_size)
504
@register_meta(aten.segment_reduce.default)
505
def meta_segment_reduce(
509
lengths: Optional[Tensor] = None,
510
indices: Optional[Tensor] = None,
511
offsets: Optional[Tensor] = None,
513
unsafe: bool = False,
516
if indices is not None:
517
raise NotImplementedError(
518
"segment_reduce(): indices based reduction is not supported yet."
521
def segment_reduce_lengths_tensor(lengths_shape):
523
lengths_shape + data.shape[axis + 1 :],
526
memory_format=torch.contiguous_format,
529
if lengths is not None:
530
return segment_reduce_lengths_tensor(lengths.shape)
533
if offsets is not None:
535
lengths_shape = offsets.shape[:-1] + (offsets.shape[-1] - 1,)
536
return segment_reduce_lengths_tensor(lengths_shape)
537
raise RuntimeError("segment_reduce(): Either lengths or offsets must be defined.")
540
@register_meta([aten.max.default, aten.max.unary_out])
543
return self.new_empty(())
546
@register_meta(aten.max.dim)
547
def meta_max_dim(self, dim, keepdim=False):
548
dim = utils.reduction_dims(self.shape, (dim,))
549
output_shape = _compute_reduction_shape(self, dim, keepdim)
551
self.new_empty(output_shape),
552
self.new_empty(output_shape, dtype=torch.long),
556
@register_meta([aten.min.default, aten.min.unary_out])
559
return self.new_empty(())
562
@register_meta(aten.min.dim)
563
def meta_min_dim(self, dim, keepdim=False):
564
dim = utils.reduction_dims(self.shape, (dim,))
565
output_shape = _compute_reduction_shape(self, dim, keepdim)
567
self.new_empty(output_shape),
568
self.new_empty(output_shape, dtype=torch.long),
572
@register_meta(aten.angle.default)
574
if self.is_complex():
575
result_dtype = corresponding_real_dtype(self.dtype)
577
_, result_dtype = elementwise_dtypes(
579
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
581
return torch.empty_like(self, dtype=result_dtype)
584
@register_meta(aten.angle.out)
585
def meta_angle_out(self, out):
586
torch._resize_output_(out, self.size(), self.device)
587
return out.copy_(torch.angle(self))
590
@register_meta(aten._assert_async.default)
591
def assert_async(val):
595
@register_meta(aten._assert_async.msg)
596
def assert_async_meta(val, assert_msg):
600
@register_meta(aten._print.default)
605
@register_meta(aten._make_dep_token.default)
614
return torch.empty([], device="meta")
617
@register_meta(aten.sym_constrain_range.default)
618
def sym_constrain_range(size, min=None, max=None):
620
from torch.fx.experimental.symbolic_shapes import constrain_range
622
if isinstance(size, (SymFloat, SymBool)):
623
raise ValueError("Constraining SymFloat or Symbool is nyi")
624
constrain_range(size, min=min, max=max)
627
@register_meta(aten._functional_sym_constrain_range.default)
628
def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):
629
aten.sym_constrain_range(size, min=min, max=max)
633
@register_meta(aten.sym_constrain_range_for_size.default)
634
def sym_constrain_range_for_size(size, min=None, max=None):
636
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
638
if isinstance(size, (SymFloat, SymBool)):
639
raise ValueError("Constraining SymFloat or Symbool is nyi")
640
_constrain_range_for_size(size, min=min, max=max)
643
@register_meta(aten._functional_sym_constrain_range_for_size.default)
644
def functional_sym_constrain_range_for_size(size, min, max, dep_token):
645
aten.sym_constrain_range_for_size(size, min=min, max=max)
649
@register_meta(aten._functional_assert_async.msg)
650
def functional_assert_async_meta(val, assert_msg, dep_token):
655
def squareCheckInputs(self: Tensor, f_name: str):
    """Assert that `self` is a (batch of) square matrix: ndim >= 2 and the
    last two dimensions are equal. `f_name` is the op name used in messages."""
    assert (
        self.dim() >= 2
    ), f"{f_name}: The input tensor must have at least 2 dimensions."
    assert self.size(-1) == self.size(
        -2
    ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
667
def linearSolveCheckInputs(
673
self.device == A.device,
675
f"Expected b and A to be on the same device, but found b on "
676
f"{self.device} and A on {A.device} instead."
681
self.dtype == A.dtype,
683
f"Expected b and A to have the same dtype, but found b of type "
684
f"{self.dtype} and A of type {A.dtype} instead."
689
A.size(-1) == A.size(-2),
691
f"A must be batches of square matrices, "
692
f"but they are {A.size(-2)} by {A.size(-1)} matrices"
697
A.size(-1) == self.size(-2),
699
f"Incompatible matrix sizes for {name}: each A "
700
f"matrix is {A.size(-1)} by {A.size(-1)}"
701
f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
707
def checkFloatingOrComplex(
708
t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True
712
t.is_floating_point() or t.is_complex(),
713
lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
715
if not allow_low_precision_dtypes:
717
dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
718
lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
723
def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
726
lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
730
def checkInputsSolver(
736
squareCheckInputs(A, f_name)
737
checkIsMatrix(B, f_name)
739
A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
741
f"{f_name}: Incompatible shapes of A and B for the equation "
742
f"{'AX = B' if left else 'XA = B'}"
743
f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
749
fn_name: str, result: Tensor, input: Tensor, result_name: str = "result"
752
result.device == input.device,
754
f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got "
755
f"{result_name} on {result.device} and input on {input.device}"
760
def checkUplo(UPLO: str):
761
UPLO_uppercase = UPLO.upper()
763
len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"),
764
lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}",
768
@register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues])
769
@out_wrapper("eigenvalues", "eigenvectors")
770
def meta__linalg_eigh(
773
compute_v: bool = True,
775
squareCheckInputs(A, "linalg.eigh")
778
shape = list(A.shape)
780
vecs = A.new_empty(shape)
781
vecs.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))
783
vecs = A.new_empty([0])
786
vals = A.new_empty(shape, dtype=toRealValueType(A.dtype))
791
def cloneBatchedColumnMajor(src: Tensor) -> Tensor:
    """Clone `src` into column-major (Fortran-order) layout per matrix.

    Cloning the matrix transpose contiguously and transposing back yields a
    tensor equal to `src` whose last two strides are column-major.
    """
    col_major = src.mT.clone(memory_format=torch.contiguous_format)
    return col_major.transpose(-2, -1)
795
@register_meta(aten._cholesky_solve_helper)
797
def _cholesky_solve_helper(self: Tensor, A: Tensor, upper: bool) -> Tensor:
798
return cloneBatchedColumnMajor(self)
801
@register_meta(aten.cholesky_solve)
803
def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor:
806
lambda: f"b should have at least 2 dimensions, but has {self.ndim} dimensions instead",
810
lambda: f"u should have at least 2 dimensions, but has {A.ndim} dimensions instead",
812
self_broadcasted, A_broadcasted = _linalg_broadcast_batch_dims_name(
813
self, A, "cholesky_solve"
815
return _cholesky_solve_helper(self_broadcasted, A_broadcasted, upper)
818
@register_meta(aten.cholesky)
820
def cholesky(self: Tensor, upper: bool = False) -> Tensor:
821
if self.numel() == 0:
822
return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
823
squareCheckInputs(self, "cholesky")
824
return cloneBatchedColumnMajor(self)
827
@register_meta(aten.cholesky_inverse)
829
def cholesky_inverse(self: Tensor, upper: bool = False) -> Tensor:
830
squareCheckInputs(self, "cholesky_inverse")
831
return cloneBatchedColumnMajor(self)
835
@register_meta(aten.linalg_cholesky_ex.default)
836
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
837
squareCheckInputs(A, "linalg.cholesky")
838
checkFloatingOrComplex(A, "linalg.cholesky")
844
L_strides = make_contiguous_strides_for(A_shape, False)
845
L = A.new_empty(A_shape)
846
L.as_strided_(A_shape, L_strides)
849
infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32)
854
[aten.linalg_householder_product.default, aten.linalg_householder_product.out]
857
def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor:
860
lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.",
863
input.size(-2) >= input.size(-1),
864
lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]",
867
input.size(-1) >= tau.size(-1),
868
lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]",
872
input.ndim - tau.ndim == 1,
874
f"torch.linalg.householder_product: Expected tau to have one dimension less than input, "
875
f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
879
expected_batch_tau_shape = input.shape[:-2]
880
actual_batch_tau_shape = tau.shape[:-1]
882
actual_batch_tau_shape == expected_batch_tau_shape,
884
f"torch.linalg.householder_product: Expected batch dimensions of tau to be "
885
f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
890
tau.dtype == input.dtype,
892
f"torch.linalg.householder_product: tau dtype {tau.dtype}"
893
f" does not match input dtype {input.dtype}"
896
checkSameDevice("torch.linalg.householder_product", tau, input, "tau")
898
return torch.empty_strided(
900
stride=make_contiguous_strides_for(input.shape, row_major=False),
907
@register_meta(aten.linalg_inv_ex.default)
908
def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
909
squareCheckInputs(A, "linalg.inv_ex")
910
checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)
912
L = A.new_empty(A.shape)
913
L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
915
infos = A.new_empty(A.shape[:-2], dtype=torch.int32)
919
@register_meta([aten.linalg_ldl_factor_ex.default, aten.linalg_ldl_factor_ex.out])
920
@out_wrapper("LD", "pivots", "info")
921
def linalg_ldl_factor_ex_meta(
924
hermitian: bool = False,
925
check_errors: bool = False,
926
) -> Tuple[Tensor, Tensor, Tensor]:
927
squareCheckInputs(self, "torch.linalg.ldl_factor_ex")
928
checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex")
929
LD = torch.empty_strided(
931
stride=make_contiguous_strides_for(self.shape, row_major=False),
935
pivots = self.new_empty(self.shape[:-1], dtype=torch.int)
936
info = self.new_empty(self.shape[:-2], dtype=torch.int)
937
return LD, pivots, info
940
@register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out])
942
def linalg_ldl_solve_meta(
943
LD: Tensor, pivots: Tensor, B: Tensor, *, hermitian: bool = False
945
squareCheckInputs(LD, "torch.linalg.ldl_solve")
946
checkFloatingOrComplex(LD, "torch.linalg.ldl_solve")
947
linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve")
951
f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
952
f"but it has {B.ndim} dimensions instead"
955
expected_pivots_shape = LD.shape[:-1]
957
expected_pivots_shape == pivots.shape,
959
f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
960
f"but got pivots with shape {pivots.shape} instead"
964
utils.is_integer_dtype(pivots.dtype),
965
lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
969
lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
971
B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LD)
972
return torch.empty_strided(
973
size=B_broadcast_size,
974
stride=make_contiguous_strides_for(B_broadcast_size, row_major=False),
980
@register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
981
@out_wrapper("P", "L", "U")
982
def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
985
lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
988
sizes = list(A.shape)
995
P = A.new_empty(sizes)
1000
L = A.new_empty(sizes)
1004
U = A.new_empty(sizes)
1008
@register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out])
1009
@out_wrapper("LU", "pivots", "info")
1010
def linalg_lu_factor_ex_meta(
1011
A: Tensor, *, pivot: bool = True, check_errors: bool = False
1012
) -> Tuple[Tensor, Tensor, Tensor]:
1015
lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
1018
sizes = list(A.shape)
1022
LU = torch.empty_strided(
1024
stride=make_contiguous_strides_for(sizes, row_major=False),
1031
sizes[-1] = min(m, n)
1032
pivots = A.new_empty(sizes, dtype=torch.int)
1036
info = A.new_empty(sizes, dtype=torch.int)
1038
return LU, pivots, info
1041
@register_meta([aten.linalg_lu_solve.default, aten.linalg_lu_solve.out])
1043
def linalg_lu_solve_meta(
1049
adjoint: bool = False,
1052
checkFloatingOrComplex(LU, "torch.linalg.lu_solve")
1054
LU.dtype == B.dtype,
1056
f"linalg.lu_solve: Expected LU and B to have the same dtype, "
1057
f"but found LU of type {LU.dtype} and B of type {B.dtype} instead"
1061
pivots.dtype == torch.int,
1062
lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32",
1066
squareCheckInputs(LU, "torch.linalg.lu_solve")
1067
checkInputsSolver(LU, B, left, "linalg.lu_solve")
1069
LU.size(-1) == pivots.size(-1),
1070
lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix",
1075
LU.shape[:-1] == pivots.shape,
1077
f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, "
1078
f"but got pivots with shape {pivots.shape} instead"
1082
B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LU)
1084
result = torch.empty_strided(
1085
size=B_broadcast_size,
1086
stride=make_contiguous_strides_for(B_broadcast_size, row_major=not left),
1091
if result.numel() != 0 and not left:
1092
if result.is_complex():
1093
result = result.conj()
1098
@register_meta(aten.lu_unpack)
1099
@out_wrapper("P", "L", "U")
1103
unpack_data: bool = True,
1104
unpack_pivots: bool = True,
1105
) -> Tuple[Tensor, Tensor, Tensor]:
1108
lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: {LU.shape} instead",
1112
pivots.dtype == torch.int32,
1114
"torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n"
1115
"Note: this function is intended to be used with the output produced by torch.linalg.lu_factor"
1118
sizes = list(LU.shape)
1124
P = LU.new_empty(sizes)
1126
P = LU.new_empty([0])
1129
L = LU.new_empty(sizes)
1132
U = LU.new_empty(sizes)
1134
L = LU.new_empty([0])
1135
U = LU.new_empty([0])
1140
def _parse_qr_mode(mode: str) -> Tuple[bool, bool]:
1141
if mode == "reduced":
1144
elif mode == "complete":
1154
f"qr received unrecognized mode '{mode}' "
1155
f"but expected one of 'reduced' (default), 'r', or 'complete'"
1158
return compute_q, reduced
1161
@register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
1162
@out_wrapper("Q", "R")
1165
mode: str = "reduced",
1166
) -> Tuple[Tensor, Tensor]:
1167
checkIsMatrix(A, "linalg.qr")
1168
checkFloatingOrComplex(A, "linalg.qr")
1170
compute_q, reduced_mode = _parse_qr_mode(mode)
1177
Q_shape = list(A.shape)
1178
Q_shape[-1] = k if reduced_mode else m
1179
Q = A.new_empty(Q_shape)
1180
Q.as_strided_(Q_shape, make_contiguous_strides_for(Q_shape, row_major=False))
1182
Q = A.new_empty([0])
1185
R_shape = list(A.shape)
1186
R_shape[-2] = k if reduced_mode or not compute_q else m
1187
R = A.new_empty(R_shape)
1188
R.as_strided_(R_shape, make_contiguous_strides_for(R_shape, row_major=False))
1192
@register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign])
1193
@out_wrapper("sign", "logabsdet", "LU", "pivots")
1194
def _linalg_slogdet(A: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
1195
squareCheckInputs(A, "linalg.slogdet")
1196
checkFloatingOrComplex(A, "linalg.slogdet", False)
1198
sign = A.new_empty(shape[:-2])
1199
logabsdet = A.new_empty(shape[:-2], dtype=toRealValueType(A.dtype))
1200
LU = torch.empty_strided(
1202
stride=make_contiguous_strides_for(shape, False),
1206
pivots = A.new_empty(shape[:-1], dtype=torch.int32)
1207
return sign, logabsdet, LU, pivots
1212
@register_meta(aten._linalg_svd.default)
1213
def _linalg_svd_meta(
1215
full_matrices: bool = False,
1216
compute_uv: bool = True,
1217
driver: Optional[str] = None,
1219
checkIsMatrix(A, "linalg.svd")
1220
checkFloatingOrComplex(A, "linalg.svd")
1222
batch_dims = list(A.shape[:-2])
1228
U_shape = batch_dims + [m, m if full_matrices else k]
1229
U = A.new_empty(U_shape)
1230
U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))
1232
V_shape = batch_dims + [n if full_matrices else k, n]
1233
V = A.new_empty(V_shape)
1238
is_cuda = device_hint(A) == "cuda"
1239
V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=is_cuda))
1242
U = A.new_empty([0])
1243
V = A.new_empty([0])
1246
S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
1250
def _linalg_broadcast_batch_dims(
1251
arg1: Tensor, arg2: Tensor
1252
) -> Tuple[List[int], List[int]]:
1254
arg1_batch_sizes = arg1.shape[:-2]
1255
arg2_batch_sizes = arg2.shape[:-2]
1256
expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes)
1258
arg1_expand_size = list(expand_batch_portion)
1259
arg1_expand_size += [arg1.size(-2), arg1.size(-1)]
1261
arg2_expand_size = list(expand_batch_portion)
1262
arg2_expand_size += [arg2.size(-2), arg2.size(-1)]
1263
return arg1_expand_size, arg2_expand_size
1266
def _linalg_broadcast_batch_dims_name(
1267
arg1: Tensor, arg2: Tensor, name: Optional[str]
1268
) -> Tuple[Tensor, Tensor]:
1271
linearSolveCheckInputs(arg1, arg2, name)
1273
arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2)
1275
arg1_broadcasted = (
1276
arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size)
1278
arg2_broadcasted = (
1279
arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size)
1281
return arg1_broadcasted, arg2_broadcasted
1284
def linalg_solve_is_vector_rhs(input: Tensor, other: Tensor) -> bool:
    """Return True when `other` is a vector right-hand side for linalg.solve.

    `other` counts as a vector case when it is 1-D, or when it has exactly one
    dimension fewer than `input` and its shape matches input.shape[:-1]
    (a batch of vectors). Otherwise it is treated as a batch of matrices.
    """
    expected_batched_rhs_shape = input.shape[:-1]
    vector_case = other.ndim == 1 or (
        input.ndim - 1 == other.ndim and other.shape == expected_batched_rhs_shape
    )
    return vector_case
1292
@register_meta(aten._linalg_solve_ex)
1293
def _linalg_solve_ex(
1298
check_errors: bool = False,
1299
result: Optional[Tensor] = None,
1300
LU: Optional[Tensor] = None,
1301
pivots: Optional[Tensor] = None,
1302
info: Optional[Tensor] = None,
1303
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
1304
checkFloatingOrComplex(A, "linalg.solve")
1308
f"linalg.solve: Expected A and B to have the same dtype, but found A of type "
1309
f"{A.dtype} and B of type {B.dtype} instead"
1312
vector_case = linalg_solve_is_vector_rhs(A, B)
1313
B_ = B.unsqueeze(-1) if vector_case else B
1314
checkInputsSolver(A, B_, left, "linalg.solve")
1315
B_broad_shape, _ = _linalg_broadcast_batch_dims(B_, A)
1317
left or not vector_case,
1319
"linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. "
1320
"In this case linalg.solve is equivalent to B / A.squeeze(-1)"
1323
result_shape = B_broad_shape[:-1] if vector_case else B_broad_shape
1324
result_ = torch.empty_strided(
1326
stride=make_contiguous_strides_for(result_shape, not left),
1332
LU_ = torch.empty_strided(
1334
stride=make_contiguous_strides_for(shape, False),
1338
pivots_ = A.new_empty(shape[:-1], dtype=torch.int32)
1339
info_ = A.new_empty(shape[:-2], dtype=torch.int32)
1340
out = (result, LU, pivots, info)
1341
res = (result_, LU_, pivots_, info_)
1342
if all(x is not None for x in out):
1343
for r, o in zip(res, out):
1345
_maybe_resize_out(o, r.shape)
1347
o.as_strided_(r.shape, r.stride())
1348
_safe_copy_out(copy_from=r, copy_to=o, exact_dtype=False)
1352
@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
1353
def linalg_solve_triangular_meta(
1359
unitriangular: bool = False,
1360
out: Optional[Tensor] = None,
1363
out = A.new_empty([0])
1364
assert isinstance(out, TensorLike)
1365
checkInputsSolver(A, B, left, "linalg.solve_triangular")
1366
B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)
1367
avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj()
1369
out = _maybe_resize_out(out, B_.shape)
1372
if _resize_output_check(out, B_.shape):
1373
out.resize_(B_.transpose(-2, -1).shape)
1374
out.transpose_(-2, -1)
1378
@register_meta(aten.triangular_solve)
1379
@out_wrapper("solution", "cloned_coefficient")
1380
def triangular_solve_meta(
1384
transpose: bool = False,
1385
unitriangular: bool = False,
1386
) -> Tuple[Tensor, Tensor]:
1390
f"torch.triangular_solve: Expected b to have at least 2 dimensions, "
1391
f"but it has {self.ndim} dimensions instead"
1397
f"torch.triangular_solve: Expected A to have at least 2 dimensions, "
1398
f"but it has {A.ndim} dimensions instead"
1402
linearSolveCheckInputs(self, A, "triangular_solve")
1404
if A.layout == torch.strided:
1405
self_broadcast_size, A_broadcast_size = _linalg_broadcast_batch_dims(self, A)
1406
solution = torch.empty_strided(
1407
size=self_broadcast_size,
1408
stride=make_contiguous_strides_for(self_broadcast_size, row_major=False),
1412
cloned_coefficient = torch.empty_strided(
1413
size=A_broadcast_size,
1414
stride=make_contiguous_strides_for(A_broadcast_size, row_major=False),
1418
elif A.layout == torch.sparse_csr or A.layout == torch.sparse_bsr:
1419
solution = torch.empty_like(self)
1420
cloned_coefficient = self.new_empty([0])
1422
torch._check(False, lambda: "triangular_solve: Got an unexpected layout.")
1423
return solution, cloned_coefficient
1427
@register_meta(aten._linalg_det.default)
def _linalg_det_meta(A):
    """Meta (shape-only) kernel for aten._linalg_det.

    Returns empty tensors shaped like the real op's outputs:
    (det, LU, pivots) for a batch of square matrices A.
    """
    squareCheckInputs(A, "linalg.det")
    checkFloatingOrComplex(A, "linalg.det")

    # Determinant: one scalar per batch element.
    determinant = A.new_empty(A.shape[:-2])

    # LU factor has A's shape, laid out column-major (row_major=False),
    # matching what the LAPACK-style factorization produces.
    lu_factor = A.new_empty(A.shape)
    lu_factor.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))

    # Pivot indices: int32, one row of pivots per matrix.
    pivot_indices = A.new_empty(A.shape[:-1], dtype=torch.int32)
    return determinant, lu_factor, pivot_indices
1441
@register_meta(aten.ormqr)
1448
transpose: bool = False,
1451
input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions."
1454
other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions."
1457
left_size_condition = -2 if left else -1
1459
other.shape[left_size_condition] >= tau.shape[-1],
1460
lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be greater than or equal to tau.shape[-1]",
1463
other.shape[left_size_condition] == input.shape[-2],
1464
lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be equal to input.shape[-2]",
1468
tau.shape[-1] <= input.shape[-1],
1469
lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]",
1473
input.ndim - tau.ndim == 1,
1475
f"torch.ormqr: Expected tau to have one dimension less than input, "
1476
f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
1480
input.ndim == other.ndim,
1482
f"torch.ormqr: Expected other to have the same number of dimensions as input, "
1483
f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}"
1488
expected_batch_shape = input.shape[:-2]
1489
actual_batch_tau_shape = tau.shape[:-1]
1491
actual_batch_tau_shape == expected_batch_shape,
1493
f"torch.ormqr: Expected batch dimensions of tau to be "
1494
f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
1498
actual_batch_other_shape = other.shape[:-2]
1500
actual_batch_other_shape == expected_batch_shape,
1502
f"torch.ormqr: Expected batch dimensions of other to be "
1503
f"equal to input.shape[:-2], but got {actual_batch_other_shape}"
1508
tau.dtype == input.dtype,
1510
f"torch.ormqr: Expected input and tau to have the same dtype, "
1511
f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}"
1515
other.dtype == input.dtype,
1517
f"torch.ormqr: Expected input and other to have the same dtype, "
1518
f"but input has dtype {input.dtype} and other has dtype {other.dtype}"
1522
checkSameDevice("torch.ormqr", tau, input, "tau")
1523
checkSameDevice("torch.ormqr", other, input, "other")
1525
return torch.empty_strided(
1527
stride=make_contiguous_strides_for(other.shape, row_major=False),
1529
device=other.device,
1533
def _padding_check_valid_input(input, padding, *, dim):
1535
len(padding) == 2 * dim,
1536
lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}",
1539
input_dim = input.ndim
1541
is_batch_mode = input_dim == (dim + 2)
1543
valid_batch_mode = is_batch_mode
1544
valid_non_batch_mode = not is_batch_mode
1548
for d in range(1, input_dim):
1549
valid_batch_mode = valid_batch_mode and input.size(d) != 0
1551
for d in range(0, input_dim):
1552
valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0
1556
valid_batch_mode or valid_non_batch_mode,
1558
f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size "
1559
f"and other non-zero dimensions for input, but got: {input.shape}"
1564
def _pad1d_common(input, padding, *, is_reflection):
1570
nbatch = input.size(0)
1574
_padding_check_valid_input(input, padding, dim=1)
1576
pad_l, pad_r = padding
1578
nplane = input.size(dim_plane)
1579
input_w = input.size(dim_w)
1580
output_w = input_w + pad_l + pad_r
1584
pad_l < input_w and pad_r < input_w,
1586
f"Argument #4: Padding size should be less than the corresponding input dimension, "
1587
f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1593
lambda: f"input (W: {input_w}) is too small. Calculated output W: {output_w}",
1597
return input.new_empty((nplane, output_w))
1599
return input.new_empty((nbatch, nplane, output_w))
1602
@register_meta(aten.reflection_pad1d)
def meta_reflection_pad1d(input, padding):
    """Shape-inference (meta) kernel for aten.reflection_pad1d."""
    return _pad1d_common(input, padding, is_reflection=True)
1608
@register_meta(aten.replication_pad1d)
def meta_replication_pad1d(input, padding):
    """Shape-inference (meta) kernel for aten.replication_pad1d."""
    return _pad1d_common(input, padding, is_reflection=False)
1614
def _pad1d_backward_common(grad_output, input, padding, *, is_reflection):
1616
if not is_reflection:
1617
torch._check(len(padding) == 2, lambda: "padding size is expected to be 2")
1622
pad_l, pad_r = padding
1624
input_w = input.size(dim_w)
1625
output_w = input_w + pad_l + pad_r
1629
pad_l < input_w and pad_r < input_w,
1631
f"Argument #4: Padding size should be less than the corresponding input dimension, "
1632
f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1637
output_w == grad_output.size(dim_w),
1638
lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1641
return input.new_empty(input.shape)
1644
@register_meta(aten.reflection_pad1d_backward)
@out_wrapper("grad_input")
def meta_reflection_pad1d_backward(grad_output, input, padding):
    """Meta kernel for reflection_pad1d backward; grad_input mirrors input."""
    return _pad1d_backward_common(grad_output, input, padding, is_reflection=True)
1650
@register_meta(aten.replication_pad1d_backward)
@out_wrapper("grad_input")
def meta_replication_pad1d_backward(grad_output, input, padding):
    """Meta kernel for replication_pad1d backward; grad_input mirrors input."""
    return _pad1d_backward_common(grad_output, input, padding, is_reflection=False)
1656
def _pad2d_common(input, padding, *, is_reflection):
1662
_padding_check_valid_input(input, padding, dim=2)
1666
nbatch = input.size(0)
1671
pad_l, pad_r, pad_t, pad_b = padding
1673
nplane = input.size(dim_slices)
1674
input_h = input.size(dim_h)
1675
input_w = input.size(dim_w)
1676
output_h = input_h + pad_t + pad_b
1677
output_w = input_w + pad_l + pad_r
1681
pad_l < input_w and pad_r < input_w,
1683
f"Argument #4: Padding size should be less than the corresponding input dimension, "
1684
f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1688
pad_t < input_h and pad_b < input_h,
1690
f"Argument #6: Padding size should be less than the corresponding input dimension, "
1691
f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
1696
output_w >= 1 or output_h >= 1,
1698
f"input (H: {input_h} W: {input_w}) is too small. "
1699
f"Calculated output H: {output_h} W: {output_w}"
1704
return input.new_empty((nplane, output_h, output_w))
1706
return input.new_empty((nbatch, nplane, output_h, output_w))
1709
@register_meta(aten.reflection_pad2d)
def meta_reflection_pad2d(input, padding):
    """Shape-inference (meta) kernel for aten.reflection_pad2d."""
    return _pad2d_common(input, padding, is_reflection=True)
1715
@register_meta(aten.replication_pad2d)
def meta_replication_pad2d(input, padding):
    """Shape-inference (meta) kernel for aten.replication_pad2d."""
    return _pad2d_common(input, padding, is_reflection=False)
1723
aten.reflection_pad2d_backward.default,
1724
aten.reflection_pad2d_backward.grad_input,
1725
aten.replication_pad2d_backward.default,
1726
aten.replication_pad2d_backward.grad_input,
1729
@out_wrapper("grad_input")
1730
def meta_pad2d_backward(grad_output, self, padding):
1736
self_shape = self.shape
1738
nbatch = self_shape[0]
1743
pad_l, pad_r, pad_t, pad_b = padding
1745
nplane = self_shape[dim_plane]
1746
input_h = self_shape[dim_h]
1747
input_w = self_shape[dim_w]
1748
output_h = input_h + pad_t + pad_b
1749
output_w = input_w + pad_l + pad_r
1752
output_w == grad_output.size(dim_w),
1753
lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1756
output_h == grad_output.size(dim_h),
1757
lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
1759
return self.new_empty(self.shape)
1762
def _pad3d_common(input, padding, *, is_reflection):
1768
_padding_check_valid_input(input, padding, dim=3)
1770
batch_mode = input.ndim == 5
1772
nbatch = input.size(0)
1778
pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
1780
nplane = input.size(dim_plane)
1781
input_d = input.size(dim_d)
1782
input_h = input.size(dim_h)
1783
input_w = input.size(dim_w)
1784
output_d = input_d + pad_f + pad_bk
1785
output_h = input_h + pad_t + pad_b
1786
output_w = input_w + pad_l + pad_r
1790
pad_l < input_w and pad_r < input_w,
1792
f"Argument #4: Padding size should be less than the corresponding input dimension, "
1793
f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
1797
pad_t < input_h and pad_b < input_h,
1799
f"Argument #6: Padding size should be less than the corresponding input dimension, "
1800
f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
1804
pad_f < input_d and pad_bk < input_d,
1806
f"Argument #8: Padding size should be less than the corresponding input dimension, "
1807
f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}"
1812
output_w >= 1 or output_h >= 1 or output_d >= 1,
1814
f"input (D: {input_d} H: {input_h} W: {input_w}) is too small. "
1815
f"Calculated output D: {output_d} H: {output_h} W: {output_w}"
1820
return input.new_empty((nbatch, nplane, output_d, output_h, output_w))
1822
return input.new_empty((nplane, output_d, output_h, output_w))
1825
@register_meta(aten.reflection_pad3d)
def meta_reflection_pad3d(input, padding):
    """Shape-inference (meta) kernel for aten.reflection_pad3d."""
    return _pad3d_common(input, padding, is_reflection=True)
1831
@register_meta(aten.replication_pad3d)
def meta_replication_pad3d(input, padding):
    """Shape-inference (meta) kernel for aten.replication_pad3d."""
    return _pad3d_common(input, padding, is_reflection=False)
1839
aten.reflection_pad3d_backward.default,
1840
aten.reflection_pad3d_backward.grad_input,
1841
aten.replication_pad3d_backward.default,
1842
aten.replication_pad3d_backward.grad_input,
1845
@out_wrapper("grad_input")
1846
def meta_pad3d_backward(grad_output, input, padding):
1847
torch._check(len(padding) == 6, lambda: "padding size is expected to be 6")
1848
assert input.ndim > 3
1849
assert grad_output.ndim == input.ndim
1860
pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
1862
input_d = input.size(dim_d)
1863
input_h = input.size(dim_h)
1864
input_w = input.size(dim_w)
1865
output_d = input_d + pad_f + pad_bk
1866
output_h = input_h + pad_t + pad_b
1867
output_w = input_w + pad_l + pad_r
1870
output_w == grad_output.size(dim_w),
1871
lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
1874
output_h == grad_output.size(dim_h),
1875
lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
1878
output_d == grad_output.size(dim_d),
1879
lambda: f"grad_output depth unexpected. Expected: {output_d}, Got: {grad_output.size(dim_d)}",
1882
return input.new_empty(input.shape)
1885
@register_meta(aten._pdist_forward)
1887
def meta__pdist_forward(self: Tensor, p: float = 2) -> Tensor:
1889
self.is_contiguous(), lambda: "_pdist_forward requires contiguous input"
1893
return self.new_empty([0]).to(memory_format=torch.legacy_contiguous_format)
1895
return self.new_empty((n * (n - 1) // 2,)).to(
1896
memory_format=torch.legacy_contiguous_format
1900
@register_meta(aten._pdist_backward)
def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> Tensor:
    """Meta kernel for aten._pdist_backward.

    Validates that both `self` and `pdist` are contiguous (the real kernel
    requires it), then returns an empty gradient with `self`'s shape in
    legacy-contiguous memory format.

    NOTE(review): the extracted source had bare condition/message pairs with
    their `torch._check(...)` wrappers dropped; they are restored here,
    matching the check idiom used throughout this file.
    """
    torch._check(
        self.is_contiguous(), lambda: "_pdist_backward requires self to be contiguous"
    )
    torch._check(
        pdist.is_contiguous(), lambda: "_pdist_backward requires pdist to be contiguous"
    )
    return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
1912
@register_meta([aten.baddbmm.default, aten.baddbmm.out])
def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
    """Meta kernel for aten.baddbmm.

    Checks that batch1/batch2 are 3D with compatible batch and contraction
    dimensions and matching dtypes, then returns an empty (b, n, p) result
    (self expanded to the batched-matmul output shape).

    NOTE(review): the extracted source had two condition/message pairs with
    their `torch._check(...)` wrappers dropped; they are restored here,
    matching the check idiom used throughout this file.
    """
    dim1 = batch1.size(0)
    dim2 = batch1.size(1)
    dim3 = batch2.size(2)
    # Output shape is (b, n, p); self must broadcast/expand to it.
    self = self.expand((dim1, dim2, dim3))
    torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
    torch._check(
        self.dtype == batch1.dtype == batch2.dtype,
        lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
    )
    batch1_sizes = batch1.shape
    batch2_sizes = batch2.shape
    bs = batch1_sizes[0]
    contraction_size = batch1_sizes[2]
    torch._check(
        batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
        lambda: (
            f"Expected size for first two dimensions of batch2 tensor to be: "
            f"[{bs}, {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}]."
        ),
    )
    return self.new_empty(self.size())
1939
@register_meta([aten.bernoulli.default, aten.bernoulli.out])
def meta_bernoulli(self, *, generator=None):
    """Meta for bernoulli: result matches self's shape/dtype, made contiguous."""
    result = torch.empty_like(self)
    return result.contiguous()
1946
@register_meta(aten.bernoulli_.float)
1947
def meta_bernoulli_(self, p=0.5, generator=None):
1951
@register_meta(aten.bernoulli.p)
def meta_bernoulli_p(self, p=0.5, generator=None):
    """Meta for bernoulli.p: result matches self's shape/dtype, made contiguous."""
    result = torch.empty_like(self)
    return result.contiguous()
1957
@register_meta(aten._fused_moving_avg_obs_fq_helper.default)
1958
def meta__fused_moving_avg_obs_fq_helper(
1970
per_row_fake_quant=False,
1971
symmetric_quant=False,
1974
ch_axis < self.dim(),
1975
lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
1977
mask = torch.empty_like(self, dtype=torch.bool)
1978
return (torch.empty_like(self), mask)
1981
@register_meta(aten.mm)
1984
torch._check(a.dim() == 2, lambda: "a must be 2D")
1985
torch._check(b.dim() == 2, lambda: "b must be 2D")
1990
lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].",
1992
return a.new_empty(N, P)
1995
def _compute_reduction_shape(self, dims, keepdim):
1997
return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
1999
return utils.compute_reduction_output_shape(self.shape, dims)
2006
def device_hint(tensor) -> "str":
2007
if isinstance(tensor, torch._subclasses.FakeTensor):
2008
return tensor.fake_device.type
2013
def calc_conv_nd_return_shape(
2014
input_tensor: torch.Tensor,
2015
weight: torch.Tensor,
2016
stride: Union[List[int], int],
2017
padding: Union[List[int], int],
2018
dilation: Union[List[int], int],
2019
is_transposed: bool,
2021
output_padding: Optional[Union[List[int], int]] = None,
2023
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
2025
Formula to apply to calculate the length of some dimension of the output
2027
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
2030
ln: length of the dimension
2031
p: padding in that dim
2032
d: dilation in that dim
2033
k: kernel size in that dim
2034
s: stride in that dim
2038
return (ln + 2 * p - d * (k - 1) - 1) // s + 1
2040
def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
2042
Formula to apply to calculate the length of some dimension of the output
2043
if transposed convolution is used.
2044
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
2047
ln: length of the dimension
2048
p: padding in that dim
2049
d: dilation in that dim
2050
k: kernel size in that dim
2051
s: stride in that dim
2052
op: output padding in that dim
2057
return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
2059
kernel_size = weight.shape[2:]
2060
dims = input_tensor.shape[2:]
2062
out_channels = groups * weight.shape[1]
2064
out_channels = weight.shape[0]
2065
if weight.shape[1] * groups != input_tensor.shape[1]:
2066
raise RuntimeError("Invalid channel dimensions")
2068
ret_shape = [input_tensor.shape[0], out_channels]
2069
if isinstance(stride, IntLike):
2070
stride = [stride] * len(dims)
2071
elif len(stride) == 1:
2072
stride = [stride[0]] * len(dims)
2074
if isinstance(padding, IntLike):
2075
padding = [padding] * len(dims)
2076
elif len(padding) == 1:
2077
padding = [padding[0]] * len(dims)
2079
if isinstance(dilation, IntLike):
2080
dilation = [dilation] * len(dims)
2081
elif len(dilation) == 1:
2082
dilation = [dilation[0]] * len(dims)
2084
output_padding_list: Optional[List[int]] = None
2086
if isinstance(output_padding, IntLike):
2087
output_padding_list = [output_padding] * len(dims)
2088
elif len(output_padding) == 1:
2089
output_padding_list = [output_padding[0]] * len(dims)
2091
output_padding_list = output_padding
2093
for i in range(len(dims)):
2095
if output_padding_list:
2097
_formula_transposed(
2103
output_padding_list[i],
2108
_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
2114
def is_channels_last(ten):
    """Return True when the suggested memory format for `ten` is channels_last."""
    suggested = torch._prims_common.suggest_memory_format(ten)
    return suggested == torch.channels_last
2118
@register_meta(aten.convolution.default)
2120
input_tensor: torch.Tensor,
2121
weight: torch.Tensor,
2125
dilation: List[int],
2126
is_transposed: bool,
2127
output_padding: List[int],
2130
def pick_memory_format():
2131
if device_hint(input_tensor) == "cuda":
2132
if is_channels_last(input_tensor) or is_channels_last(weight):
2133
return torch.channels_last
2135
if is_channels_last(input_tensor):
2136
return torch.channels_last
2137
if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
2138
return torch.contiguous_format
2139
elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
2140
return torch.preserve_format
2142
shape_out = calc_conv_nd_return_shape(
2150
output_padding if is_transposed else None,
2153
input_channels_dim = 1
2154
output_channels_dim = 1
2155
if input_tensor.size(input_channels_dim) == 0:
2156
shape_out[output_channels_dim] = 0
2158
out = input_tensor.new_empty(shape_out)
2159
out = out.to(memory_format=pick_memory_format())
2163
if torch._C._has_mkldnn:
2164
_meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
2165
"mkldnn", "IMPL", "Meta"
2168
@register_meta(torch.ops.mkldnn._convolution_pointwise.default)
2169
def meta_mkldnn_convolution_default(
2181
shape_out = calc_conv_nd_return_shape(
2182
input_tensor, weight, stride, padding, dilation, False, groups, []
2184
out = input_tensor.new_empty(shape_out)
2185
out_memory_format = torch.channels_last
2186
out = out.to(memory_format=out_memory_format)
2189
@register_meta(torch.ops.mkldnn._linear_pointwise.default)
2190
def meta_linear_pointwise_default(
2191
input_tensor, weight, bias, attr, scalars, algorithm
2193
return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))
2195
if torch._C.has_mkl:
2196
_meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
2197
"mkl", "IMPL", "Meta"
2200
@register_meta(torch.ops.mkl._mkl_linear)
2201
def meta_mkl_linear(
2208
return input_tensor.new_empty(
2209
(*input_tensor.shape[:-1], orig_weight.shape[0])
2212
_meta_lib_dont_use_me_use_register_meta_for_onednn = torch.library.Library(
2213
"onednn", "IMPL", "Meta"
2216
@register_meta(torch.ops.onednn.qconv2d_pointwise.default)
2217
def meta_qconv2d_pointwise(
2236
shape_out = calc_conv_nd_return_shape(
2246
assert output_dtype in [torch.float32, torch.bfloat16]
2247
out = x.new_empty(shape_out, dtype=output_dtype)
2248
out = out.to(memory_format=torch.channels_last)
2251
@register_meta(torch.ops.onednn.qlinear_pointwise.default)
2252
@register_meta(torch.ops.onednn.qlinear_pointwise.tensor)
2253
def meta_qlinear_pointwise(
2268
output_shape = list(x.shape)
2270
output_shape[-1] = w.shape[1]
2271
assert output_dtype in [torch.float32, torch.bfloat16]
2272
out = x.new_empty(output_shape, dtype=output_dtype)
2275
_meta_lib_dont_use_me_use_register_meta_for_quantized = torch.library.Library(
2276
"quantized", "IMPL", "Meta"
2279
@register_meta(torch.ops.quantized.max_pool2d)
2280
def meta_quantized_max_pool2d(
2292
) = max_pool2d_checks_and_compute_shape(
2293
input, kernel_size, stride, padding, dilation, ceil_mode
2295
nbatch = input.size(-4) if input.dim() == 4 else 1
2296
memory_format = torch.channels_last
2297
if input.dim() == 3:
2298
size = [nInputPlane, outputHeight, outputWidth]
2300
size = [nbatch, nInputPlane, outputHeight, outputWidth]
2304
device=input.device,
2305
memory_format=memory_format,
2310
def check_dim_size(tensor, dim, dim_size, size):
2312
tensor.dim() == dim and tensor.shape[dim_size] == size,
2313
lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
2314
+ f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
2318
@register_meta(aten.avg_pool2d.default)
2325
count_include_pad=True,
2326
divisor_override=None,
2328
def unpack(name, val):
2331
lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
2334
W = H if len(val) == 1 else val[1]
2337
kH, kW = unpack("kernel_size", kernel_size)
2339
len(stride) in [0, 1, 2],
2340
lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
2342
if len(stride) == 0:
2344
elif len(stride) == 1:
2345
dH, dW = stride[0], stride[0]
2347
dH, dW = unpack("stride", stride)
2349
padH, padW = unpack("padding", padding)
2352
divisor_override is None or divisor_override != 0,
2353
lambda: "divisor must be not zero",
2356
nbatch = input.size(-4) if input.dim() == 4 else 1
2357
nInputPlane = input.size(-3)
2358
inputHeight = input.size(-2)
2359
inputWidth = input.size(-1)
2361
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
2362
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
2364
memory_format = utils.suggest_memory_format(input)
2383
if input.dim() == 3:
2384
size = [nInputPlane, outputHeight, outputWidth]
2386
size = [nbatch, nInputPlane, outputHeight, outputWidth]
2390
device=input.device,
2391
memory_format=memory_format,
2396
def avg_pool2d_backward_shape_check(
2432
nOutputPlane = nInputPlane
2434
check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
2435
check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
2436
check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)
2440
@register_meta(aten.avg_pool2d_backward.default)
2441
def meta_avg_pool2d_backward(
2453
len(kernel_size) == 1 or len(kernel_size) == 2,
2454
lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
2457
kW = kH if len(kernel_size) == 1 else kernel_size[1]
2459
len(stride) == 0 or len(stride) == 1 or len(stride) == 2,
2460
lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
2462
dH = kH if len(stride) == 0 else stride[0]
2463
dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
2465
len(padding) == 1 or len(padding) == 2,
2466
lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
2469
padW = padH if len(padding) == 1 else padding[1]
2472
divisor_override is None or divisor_override != 0,
2473
lambda: "divisor must be not zero",
2476
input_size = input.shape
2477
nbatch = input_size[-4] if input.dim() == 4 else 1
2478
nInputPlane = input_size[-3]
2479
inputHeight = input_size[-2]
2480
inputWidth = input_size[-1]
2482
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
2483
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
2485
mem_format = utils.suggest_memory_format(input)
2487
avg_pool2d_backward_shape_check(
2508
device=input.device,
2509
memory_format=mem_format,
2513
@register_meta(aten.avg_pool3d)
2521
count_include_pad=True,
2522
divisor_override=None,
2525
len(kernel_size) in (1, 3),
2526
lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
2529
kH = kT if len(kernel_size) == 1 else kernel_size[1]
2530
kW = kT if len(kernel_size) == 1 else kernel_size[2]
2533
not stride or len(stride) in (1, 3),
2534
lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
2536
dT = kT if not stride else stride[0]
2537
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
2538
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
2541
len(padding) in (1, 3),
2542
lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
2545
padH = padT if len(padding) == 1 else padding[1]
2546
padW = padT if len(padding) == 1 else padding[2]
2549
input.ndim in (4, 5),
2550
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
2554
not divisor_override or divisor_override != 0,
2555
lambda: "divisor must be not zero",
2558
nbatch = input.size(0)
2559
nslices = input.size(-4)
2560
itime = input.size(-3)
2561
iheight = input.size(-2)
2562
iwidth = input.size(-1)
2564
otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
2565
oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
2566
owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)
2590
check_input_size=True,
2594
return input.new_empty((nslices, otime, oheight, owidth))
2596
return input.new_empty((nbatch, nslices, otime, oheight, owidth))
2599
@register_meta(aten.avg_pool3d_backward)
2600
@out_wrapper("grad_input")
2601
def meta_avg_pool3d_backward(
2612
len(kernel_size) in (1, 3),
2613
lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
2616
kH = kT if len(kernel_size) == 1 else kernel_size[1]
2617
kW = kT if len(kernel_size) == 1 else kernel_size[2]
2620
not stride or len(stride) in (1, 3),
2621
lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
2623
dT = kT if not stride else stride[0]
2624
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
2625
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
2628
len(padding) in (1, 3),
2629
lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
2632
padH = padT if len(padding) == 1 else padding[1]
2633
padW = padT if len(padding) == 1 else padding[2]
2636
input.ndim in (4, 5),
2637
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
2641
not divisor_override or divisor_override != 0,
2642
lambda: "divisor must be not zero",
2645
nslices = input.size(-4)
2646
itime = input.size(-3)
2647
iheight = input.size(-2)
2648
iwidth = input.size(-1)
2650
otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
2651
oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
2652
owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)
2654
avg_pool3d_backward_shape_check(
2670
otime_for_shape_check,
2671
oheight_for_shape_check,
2672
owidth_for_shape_check,
2673
"avg_pool3d_backward()",
2676
return input.new_empty(input.shape)
2679
@register_meta(aten._adaptive_avg_pool2d.default)
2680
def meta_adaptive_avg_pool2d(self, output_size):
2682
self.ndim == 3 or self.ndim == 4,
2683
lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
2685
output_shape = self.shape[:-2] + tuple(output_size)
2686
memory_format = utils.suggest_memory_format(self)
2693
memory_format=memory_format,
2697
@register_meta(aten._adaptive_avg_pool3d.default)
2698
def meta_adaptive_avg_pool3d(self, output_size):
2700
self.ndim == 4 or self.ndim == 5,
2701
lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
2703
return self.new_empty(self.shape[:-3] + tuple(output_size))
2706
@register_meta(aten._adaptive_avg_pool2d_backward.default)
2707
def meta__adaptive_avg_pool2d_backward(grad_out, self):
2708
ndim = grad_out.ndim
2709
for i in range(1, ndim):
2711
grad_out.size(i) > 0,
2712
lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
2713
size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
2716
ndim == 3 or ndim == 4,
2717
lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
2720
self.dtype == grad_out.dtype,
2721
lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
2723
memory_format = torch.contiguous_format
2724
if is_channels_last(self):
2725
memory_format = torch.channels_last
2726
return self.new_empty(self.shape).to(memory_format=memory_format)
2729
@register_meta(aten._adaptive_avg_pool3d_backward)
@out_wrapper("grad_input")
def meta__adaptive_avg_pool3d_backward(grad_output, self):
    """Meta for adaptive_avg_pool3d backward: grad_input mirrors self.

    Rejects grad_output tensors with empty non-batch dimensions, then
    allocates the gradient in legacy-contiguous memory format.
    """
    _adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool3d_backward")
    return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
2736
def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str):
2737
ndim = grad_output.ndim
2738
for i in range(1, ndim):
2740
grad_output.size(i) > 0,
2742
f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, "
2743
f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty"
2748
@register_meta(aten.adaptive_max_pool2d)
2749
@out_wrapper("out", "indices")
2750
def meta_adaptive_max_pool2d(input, output_size):
2754
lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}",
2756
for i in range(1, ndim):
2760
f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
2761
f"but input has sizes {input.shape} with dimension {i} being empty"
2766
len(output_size) == 2,
2767
lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2",
2775
sizeB = input.size(0)
2778
sizeD = input.size(dimH - 1)
2779
osizeH, osizeW = output_size
2782
out_shape = (sizeD, osizeH, osizeW)
2783
out = input.new_empty(out_shape)
2784
indices = input.new_empty(out_shape, dtype=torch.int64)
2787
out_shape = (sizeB, sizeD, osizeH, osizeW)
2788
memory_format = utils.suggest_memory_format(input)
2789
out = input.new_empty(out_shape).to(memory_format=memory_format)
2790
indices = input.new_empty(out_shape, dtype=torch.int64).to(
2791
memory_format=memory_format
2796
@register_meta(aten.adaptive_max_pool2d_backward)
2797
@out_wrapper("grad_input")
2798
def meta_adaptive_max_pool2d_backward(grad_output, input, indices):
2799
ndim = grad_output.ndim
2802
lambda: f"adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}",
2805
_adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward")
2808
input.dtype == grad_output.dtype,
2809
lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}",
2812
memory_format = utils.suggest_memory_format(input)
2813
return input.new_empty(input.shape).to(memory_format=memory_format)
2816
@register_meta(aten.adaptive_max_pool3d)
2817
@out_wrapper("out", "indices")
2818
def meta_adaptive_max_pool3d(input, output_size):
2822
lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}",
2824
for i in range(1, ndim):
2828
f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
2829
f"but input has sizes {input.shape} with dimension {i} being empty"
2834
len(output_size) == 3,
2835
lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3",
2843
sizeB = input.size(0)
2846
sizeD = input.size(dimD)
2847
osizeT, osizeH, osizeW = output_size
2850
out_shape = (sizeD, osizeT, osizeH, osizeW)
2852
out_shape = (sizeB, sizeD, osizeT, osizeH, osizeW)
2854
out = input.new_empty(out_shape)
2855
indices = input.new_empty(out_shape, dtype=torch.int64)
2860
@register_meta(aten.adaptive_max_pool3d_backward)
@out_wrapper("grad_input")
def meta_adaptive_max_pool3d_backward(grad_output, input, indices):
    """Meta for adaptive_max_pool3d backward: grad_input has input's shape/dtype."""
    _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward")
    grad_input = input.new_empty(input.shape)
    return grad_input
2867
@register_meta(aten.repeat_interleave.Tensor)
def meta_repeat_interleave_Tensor(repeats, output_size=None):
    """Meta for repeat_interleave.Tensor.

    The output length is data-dependent, so on meta tensors an explicit
    output_size is mandatory.
    """
    if output_size is not None:
        return repeats.new_empty(output_size)
    raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
2874
@register_meta([aten.complex.default, aten.complex.out])
def meta_complex(real, imag):
    """Meta for complex(real, imag): broadcast shapes, complex dtype of real."""
    assert real.dtype.is_floating_point
    assert imag.dtype.is_floating_point
    broadcast_shape = _broadcast_shapes(real.shape, imag.shape)
    result_dtype = corresponding_complex_dtype(real.dtype)
    return real.new_empty(broadcast_shape, dtype=result_dtype)
2883
@register_meta([aten.nonzero_static.default, aten.nonzero_static.out])
def nonzero_static(self, *, size: int, fill_value: int = -1):
    """Meta for nonzero_static: output is always (size, ndim), int64."""
    out_shape = (size, self.dim())
    return self.new_empty(out_shape, dtype=torch.long)
2889
@register_meta([aten.index.Tensor, aten._unsafe_index.Tensor])
2890
def meta_index_Tensor(self, indices):
2891
torch._check(bool(indices), lambda: "at least one index must be provided")
2894
result: List[Optional[Tensor]] = []
2895
for i, index in enumerate(indices):
2896
if index is not None:
2898
index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
2899
lambda: "tensors used as indices must be long, int, byte or bool tensors",
2901
if index.dtype in [torch.int8, torch.bool]:
2902
nonzero = index.nonzero()
2905
k + index.ndim <= self.ndim,
2906
lambda: f"too many indices for tensor of dimension {self.ndim}",
2908
for j in range(index.ndim):
2910
index.shape[j] == self.shape[k + j],
2911
lambda: f"The shape of the mask {index.shape} at index {i} "
2912
f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
2914
result.append(nonzero.select(1, j))
2916
result.append(index)
2918
result.append(index)
2921
len(indices) <= self.ndim,
2922
lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
2925
import torch._refs as refs
2927
indices = list(refs._maybe_broadcast(*indices))
2929
while len(indices) < self.ndim:
2930
indices.append(None)
2938
has_contiguous_subspace = False
2939
for index in indices:
2941
if index is not None:
2947
if index is not None:
2950
has_contiguous_subspace = True
2955
if not has_contiguous_subspace:
2957
transposed_indices = []
2958
for i, index in enumerate(indices):
2959
if index is not None:
2961
transposed_indices.append(index)
2962
for i, index in enumerate(indices):
2965
transposed_indices.append(index)
2966
self = self.permute(dims)
2967
indices = transposed_indices
2975
before_shape: List[int] = []
2976
after_shape: List[int] = []
2977
replacement_shape: List[int] = []
2978
for dim, index in enumerate(indices):
2980
if replacement_shape:
2981
after_shape.append(self.shape[dim])
2983
before_shape.append(self.shape[dim])
2985
replacement_shape = list(index.shape)
2986
return self.new_empty(before_shape + replacement_shape + after_shape)
2989
@register_meta([aten.convolution_backward.default])
2990
def meta_convolution_backward(
3005
backend_grad_input = None
3006
backend_grad_weight = None
3007
backend_grad_bias = None
3010
backend_grad_input = grad_output_.new_empty(input_.size())
3012
backend_grad_weight = grad_output_.new_empty(weight_.size())
3014
backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
3016
return (backend_grad_input, backend_grad_weight, backend_grad_bias)
3019
@register_meta([aten.addbmm.default, aten.addbmm.out])
3021
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
3022
dim1 = batch1.size(1)
3023
dim2 = batch2.size(2)
3024
self = self.expand((dim1, dim2))
3025
torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
3026
torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
3028
batch1.size(0) == batch2.size(0),
3029
lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
3032
batch1.size(2) == batch2.size(1),
3034
f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
3035
f"and {batch2.size(1)}x{batch2.size(2)})"
3039
self.size(0) == dim1 and self.size(1) == dim2,
3040
lambda: "self tensor does not match matmul output shape",
3042
return self.new_empty(self.size())
3045
def register_meta_foreach(ops):
3048
op_name = str(op).split(".")[1]
3049
scalar_op = getattr(aten, op_name.replace("_foreach_", ""))
3051
_add_op_to_registry(
3056
_scalar_op=scalar_op,
3060
pytree.tree_map_(register, ops)
3066
@register_meta_foreach(
3078
aten._foreach_expm1,
3080
aten._foreach_floor,
3081
aten._foreach_lgamma,
3083
aten._foreach_log10,
3084
aten._foreach_log1p,
3088
aten._foreach_reciprocal,
3089
aten._foreach_round,
3090
aten._foreach_sigmoid,
3097
aten._foreach_trunc,
3103
aten._foreach_clamp_min,
3104
aten._foreach_clamp_max,
3108
def _meta_foreach_out_of_place(*args, _scalar_op=None, **kwargs):
3110
isinstance(args[0], list),
3111
lambda: (f"The first argument must be List[Tensor], but got {type(args[0])}."),
3114
nelem = len(args[0])
3117
lambda: ("Tensor list must have at least one tensor."),
3121
for iarg, arg in enumerate(args[1:]):
3122
if isinstance(arg, list):
3127
f"self and argument-{iarg+2} must match in length, "
3128
f"but got {nelem} and {len(arg)}."
3131
elif isinstance(arg, Tensor):
3133
arg.dim() == 0 and arg.numel() == 1,
3135
"scalar tensor expected to be 0 dim but it has "
3136
f"{arg.dim()} dimensions and {arg.numel()} elements."
3143
for elem in range(nelem):
3144
each_args = [args[i][elem] for i in range(nlists)]
3145
result.append(_scalar_op(*each_args, *args[nlists:], **kwargs))
3150
@register_meta_foreach(
3153
aten._foreach_acos_,
3154
aten._foreach_asin_,
3155
aten._foreach_atan_,
3156
aten._foreach_ceil_,
3158
aten._foreach_cosh_,
3160
aten._foreach_erfc_,
3162
aten._foreach_expm1_,
3163
aten._foreach_frac_,
3164
aten._foreach_floor_,
3165
aten._foreach_lgamma_,
3167
aten._foreach_log10_,
3168
aten._foreach_log1p_,
3169
aten._foreach_log2_,
3171
aten._foreach_reciprocal_,
3172
aten._foreach_round_,
3173
aten._foreach_sigmoid_,
3174
aten._foreach_sign_,
3176
aten._foreach_sinh_,
3177
aten._foreach_sqrt_,
3179
aten._foreach_tanh_,
3180
aten._foreach_trunc_,
3181
aten._foreach_zero_,
3186
aten._foreach_clamp_min_,
3187
aten._foreach_clamp_max_,
3188
aten._foreach_lerp_,
3189
aten._foreach_copy_,
3192
def _meta_foreach_inplace(*args, _scalar_op=None, **kwargs):
3193
_meta_foreach_out_of_place(*args, _scalar_op=_scalar_op, **kwargs)
3197
@register_meta([aten._foreach_pow.ScalarAndTensor])
3198
def meta__foreach_pow_scalar_and_tensor(self, exponent):
3202
isinstance(exponent, List),
3203
lambda: f"exponent must be a tensor list but got {type(exponent)}",
3205
return [torch.empty_like(e) for e in exponent]
3208
def _check_foreach_binop_tensor_lists(self, other):
    """Validate the two tensor-list arguments of a _foreach_* binary op.

    Raises (via torch._check) when either argument is not a list, or when
    the lists are empty or differ in length.
    """
    torch._check(
        isinstance(self, List) and isinstance(other, List),
        lambda: (
            # Fixed garbled wording: was "The first two arguments of must be".
            "The first two arguments must be List[Tensor], "
            f"but got {type(self)} and {type(other)}."
        ),
    )
    torch._check(
        len(self) > 0 and len(self) == len(other),
        lambda: (
            "self and other must be non-empty and match in length, "
            f"but got {len(self)} and {len(other)}."
        ),
    )
3227
aten._foreach_maximum,
3228
aten._foreach_minimum,
3231
def meta__foreach_binop_scalar(*args):
3233
return _meta_foreach_out_of_place(*args, _scalar_op=aten.clamp_min)
3238
aten._foreach_maximum_,
3239
aten._foreach_minimum_,
3242
def meta__foreach_binop__scalar(*args):
3244
_meta_foreach_inplace(*args, _scalar_op=aten.clamp_min_)
3250
aten._foreach_addcdiv.Scalar,
3251
aten._foreach_addcmul.Scalar,
3254
def meta__foreach_addcop_scalar(self, tensor1, tensor2, scalar=1):
3258
all(isinstance(l, List) for l in [self, tensor1, tensor2]),
3260
"All arguments must be List[Tensor], "
3261
f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
3264
torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
3266
len(self) == len(tensor1) and len(self) == len(tensor2),
3267
lambda: "All input tensor lists must have the same length",
3270
return [torch.empty_like(s) for s in self]
3273
@register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor])
3274
def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
3276
all(isinstance(l, List) for l in [self, tensor1, tensor2])
3277
and isinstance(scalars, torch.Tensor),
3279
"_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], tensor, "
3280
f"but got: {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}"
3283
torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
3285
len(self) == len(tensor1) and len(self) == len(tensor2),
3286
lambda: "All input tensor lists must have the same length",
3292
aten._foreach_addcdiv_.Scalar,
3293
aten._foreach_addcmul_.Scalar,
3296
def meta__foreach_addcop__scalar(self, tensor1, tensor2, scalar=1):
3298
all(isinstance(l, List) for l in [self, tensor1, tensor2]),
3300
"All arguments of _foreach_addc*_ must be List[Tensor], "
3301
f"but got {type(self)}, {type(tensor1)}, and {type(tensor2)}"
3304
torch._check(len(self) > 0, lambda: "input tensor list must not be empty.")
3306
len(self) == len(tensor1) and len(self) == len(tensor2),
3307
lambda: "All input tensor lists must have the same length",
3311
@register_meta([aten._fused_adam_.default])
3312
def meta__fused_adam_(
3330
for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
3332
isinstance(l, List),
3333
lambda: f"exponent must be a tensor list but got {type(l)}",
3337
@register_meta([aten._fused_adam.default])
3338
def meta__fused_adam(
3356
for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
3358
isinstance(l, List),
3359
lambda: f"exponent must be a tensor list but got {type(l)}",
3362
def empty_like_list(tensor_list):
3363
return [torch.empty_like(t) for t in tensor_list]
3366
empty_like_list(self),
3367
empty_like_list(grads),
3368
empty_like_list(exp_avgs),
3369
empty_like_list(exp_avg_sqs),
3370
empty_like_list(max_exp_avg_sqs),
3374
@register_meta([aten._int_mm])
def meta__int_mm(a, b):
    """Meta kernel for _int_mm: (m, k) int8 @ (k, n) int8 -> (m, n) int32."""
    # Validation order matters for which error message fires first:
    # dims, then dtypes, then the inner-dimension match.
    torch._check(a.dim() == 2, lambda: "a must be a 2D tensor")
    torch._check(b.dim() == 2, lambda: "b must be a 2D tensor")
    torch._check(
        a.dtype is torch.int8,
        lambda: f"expected self to be int8, got {a.dtype}",
    )
    torch._check(
        b.dtype is torch.int8,
        lambda: f"expected mat2 to be int8, got {b.dtype}",
    )
    torch._check(
        a.size(1) == b.size(0),
        lambda: (
            f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} "
            f"and {b.size(0)}x{b.size(1)})"
        ),
    )
    return a.new_empty((a.size(0), b.size(1)), dtype=torch.int32)
3397
@register_meta([aten._convert_weight_to_int4pack])
3398
def meta__convert_weight_to_int4pack(w, inner_k_tiles):
3399
torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
3401
w.dtype is torch.int32,
3402
lambda: f"expected w to be int32, got {w.dtype}",
3406
if device_hint(w) == "cpu":
3415
k // (inner_k_tiles * 16),
3423
@register_meta([aten._weight_int4pack_mm])
3424
def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
3425
torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
3427
x.dtype is torch.bfloat16,
3428
lambda: f"expected x to be bf16, got {x.dtype}",
3430
if device_hint(w) == "cpu":
3431
torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
3433
w.dtype is torch.uint8,
3434
lambda: f"expected w to be uint8, got {w.dtype}",
3437
torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
3439
w.dtype is torch.int32,
3440
lambda: f"expected w to be int32, got {w.dtype}",
3442
return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
3445
@register_meta([aten._weight_int8pack_mm])
def meta__weight_int8pack_mm(x, w, q_scales):
    """Meta kernel for _weight_int8pack_mm: bf16 activations times an
    int8-packed weight; output is (x.rows, w.rows) in x's dtype."""
    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
    torch._check(
        x.dtype is torch.bfloat16,
        lambda: f"expected x to be bf16, got {x.dtype}",
    )
    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
    torch._check(
        w.dtype is torch.int8,
        lambda: f"expected w to be int8, got {w.dtype}",
    )
    # NOTE(review): q_scales is accepted but not validated here — confirm
    # against the eager kernel whether additional checks are expected.
    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
3460
@register_meta(aten._cdist_forward.default)
3461
def meta_cdist_forward(x1, x2, p, compute_mode):
3464
lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
3468
lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
3471
x1.size(-1) == x2.size(-1),
3472
lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
3475
utils.is_float_dtype(x1.dtype),
3476
lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
3479
utils.is_float_dtype(x2.dtype),
3480
lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
3482
torch._check(p >= 0, lambda: "cdist only supports non-negative p values")
3484
compute_mode in (None, 1, 2),
3485
lambda: f"possible modes: None, 1, 2, but was: {compute_mode}",
3489
batch_tensor1 = x1.shape[:-2]
3490
batch_tensor2 = x2.shape[:-2]
3491
output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
3492
output_shape.extend([r1, r2])
3493
return x1.new_empty(output_shape)
3496
@register_meta(aten._cdist_backward)
3498
def meta_cdist_backward(grad, x1, x2, p, cdist):
3502
batch_tensor1 = x1.shape[:-2]
3503
batch_tensor2 = x2.shape[:-2]
3504
expand_batch_portion = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
3505
tensor1_expand_size = expand_batch_portion.copy()
3506
tensor1_expand_size.extend([r1, c1])
3507
batch_product = math.prod(expand_batch_portion)
3508
if r1 == 0 or r2 == 0 or c1 == 0 or batch_product == 0:
3509
return torch.zeros_like(x1)
3510
if tensor1_expand_size != list(x1.shape):
3511
x1 = x1.expand(tensor1_expand_size)
3512
return torch.empty_like(x1, memory_format=torch.contiguous_format)
3519
@register_meta(aten._embedding_bag.default)
3520
def meta_embedding_bag(
3524
scale_grad_by_freq=False,
3527
per_sample_weights=None,
3528
include_last_offset=False,
3532
indices.dtype in (torch.long, torch.int),
3533
lambda: f"expected indices to be long or int, got {indices.dtype}",
3536
offsets.dtype in (torch.long, torch.int),
3537
lambda: f"expected offsets to be long or int, got {offsets.dtype}",
3540
utils.is_float_dtype(weight.dtype),
3541
lambda: f"expected weight to be floating point type, got {weight.dtype}",
3544
num_bags = offsets.size(0)
3545
if include_last_offset:
3548
lambda: "include_last_offset: numBags should be at least 1",
3552
output = weight.new_empty(num_bags, weight.size(1))
3553
MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
3555
if per_sample_weights is not None:
3558
lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
3561
per_sample_weights.dtype == weight.dtype,
3562
lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
3565
per_sample_weights.ndim == 1,
3566
lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
3569
per_sample_weights.numel() == indices.numel(),
3571
f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
3572
f"to be the same as indices.numel() ({indices.numel()})"
3576
def is_fast_path_index_select_scale(src, scale, output, padding_idx):
3578
is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
3581
def is_fast_path_index_select(src, output, padding_idx):
3583
(src.dtype == torch.float or src.dtype == torch.half)
3584
and src.stride(1) == 1
3585
and output.stride(1) == 1
3589
def is_fast_path(src, scale, output, padding_idx):
3590
if scale is not None:
3591
return is_fast_path_index_select_scale(src, scale, output, padding_idx)
3593
return is_fast_path_index_select(src, output, padding_idx)
3595
if device_hint(offsets) != "cpu":
3596
offset2bag = indices.new_empty(indices.size(0))
3597
bag_size = indices.new_empty(offsets.size())
3598
if mode == MODE_MAX:
3599
max_indices = indices.new_empty(num_bags, weight.size(1))
3601
max_indices = indices.new_empty(0)
3603
fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
3604
if mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum:
3605
offset2bag = offsets.new_empty(indices.size(0))
3607
offset2bag = offsets.new_empty(0)
3608
bag_size = offsets.new_empty(num_bags)
3610
numBags = offsets.shape[0]
3611
if mode == MODE_MAX:
3612
if include_last_offset:
3615
lambda: "include_last_offset: numBags should be at least 1",
3618
max_indices = offsets.new_empty(numBags, weight.shape[1])
3620
max_indices = offsets.new_empty(bag_size.size())
3621
return output, offset2bag, bag_size, max_indices
3624
@register_meta(aten._embedding_bag_forward_only.default)
3625
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
3626
output, offset2bag, bag_size, max_indices = meta_embedding_bag(
3627
weight, indices, offsets, *args
3629
if device_hint(offsets) == "cpu":
3630
bag_size = offsets.new_empty(offsets.size())
3631
return output, offset2bag, bag_size, max_indices
3634
def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
3639
if input.dtype.is_floating_point or input.dtype.is_complex:
3641
elif promote_int_to_long:
3647
@register_meta([aten.nansum.default, aten.nansum.out])
def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
    """Meta kernel for nansum: a reduction whose integer inputs are promoted
    to long (see _get_reduction_dtype)."""
    out_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
    reduce_dims = utils.reduction_dims(input.shape, dims)
    out_shape = _compute_reduction_shape(input, reduce_dims, keepdim)
    return input.new_empty(out_shape, dtype=out_dtype)
3656
@register_meta([aten.median.default, aten.nanmedian.default])
def meta_median(input):
    """Meta kernel for the dim-less median/nanmedian: reduces over every
    dimension, producing a 0-d result."""
    all_dims = tuple(range(input.dim()))
    output_shape = utils.compute_reduction_output_shape(input.shape, all_dims)
    return input.new_empty(output_shape)
3667
aten.median.dim_values,
3669
aten.nanmedian.dim_values,
3674
@out_wrapper("values", "indices")
3675
def meta_median_mode_dim(input, dim=-1, keepdim=False):
3676
if device_hint(input) == "cuda":
3677
utils.alert_not_deterministic("median CUDA with indices output")
3678
dim = utils.reduction_dims(input.shape, (dim,))
3679
output_shape = _compute_reduction_shape(input, dim, keepdim)
3681
input.new_empty(output_shape),
3682
input.new_empty(output_shape, dtype=torch.long),
3686
@register_meta(aten.logical_not_.default)
3687
def meta_logical_not_(self):
3691
@register_meta(aten.repeat.default)
3692
def meta_repeat(self, repeats):
3694
len(repeats) >= self.dim(),
3695
lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
3700
num_new_dimensions = len(repeats) - self.dim()
3701
padded_size = (1,) * num_new_dimensions + tuple(self.shape)
3702
target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))]
3703
return self.new_empty(target_size)
3706
@register_meta(aten.zero_.default)
3707
def meta_zero_(self):
3717
aten.logical_and_.default,
3718
aten.logical_or_.default,
3719
aten.logical_xor_.default,
3722
def meta_binop_inplace(self, other):
3723
if isinstance(other, torch.Tensor):
3724
check_inplace_broadcast(self.shape, other.shape)
3736
def meta_binop_inplace_alpha(self, other, alpha=1):
3737
if isinstance(other, torch.Tensor):
3738
check_inplace_broadcast(self.shape, other.shape)
3742
@register_meta([aten.round.default, aten.round.decimals])
def meta_round(self, **kwargs):
    """Meta kernel for round/round.decimals: shape/dtype follow the default
    elementwise type-promotion rules; `decimals` does not affect either."""
    return elementwise_meta(self, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
3749
def shift_dtype_check(fn_name, self, val):
    """Validate dtypes for the shift operators (__lshift__/__rshift__).

    The input must be integral; the shift amount must be an integral tensor
    or a plain int-like scalar. `fn_name` is interpolated into the errors.
    """
    torch._check(
        utils.is_integer_dtype(self.dtype),
        lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}",
    )
    if isinstance(val, torch.Tensor):
        torch._check(
            utils.is_integer_dtype(val.dtype),
            lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}",
        )
    else:
        torch._check(
            isinstance(val, IntLike),
            lambda: f"{fn_name}: Expected shift value to be an int. Got {val}",
        )
3766
@register_meta([aten.__rshift__.Tensor, aten.__rshift__.Scalar])
def meta_rshifts(self, other):
    """Meta kernel for __rshift__: check integral dtypes, then apply the
    default elementwise promotion to get the output shape/dtype."""
    shift_dtype_check("rshift", self, other)
    return elementwise_meta(
        self,
        other,
        type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    )
3774
@register_meta([aten.__lshift__.Tensor, aten.__lshift__.Scalar])
def meta_lshifts(self, other):
    """Meta kernel for __lshift__: check integral dtypes, then apply the
    default elementwise promotion to get the output shape/dtype."""
    shift_dtype_check("lshift", self, other)
    return elementwise_meta(
        self,
        other,
        type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    )
3782
@register_meta(aten.zero.default)
3784
return self.new_empty(self.shape)
3787
@register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
3788
def meta_fill_(self, val):
3792
@register_meta([aten.fill.Tensor, aten.fill.Scalar])
def meta_fill(self, val):
    """Meta kernel for out-of-place fill: the fill value never changes the
    output's shape or dtype, so an empty clone of `self` suffices."""
    return torch.empty_like(self)
3797
@register_meta(aten.relu_.default)
3798
def meta_relu_(self):
3802
@register_meta([aten.index_put.default, aten._unsafe_index_put.default])
def meta_index_put(self, indices, values, accumulate=False):
    """Meta kernel for out-of-place index_put: indices/values only affect
    contents, never the result metadata, so mirror `self`."""
    return torch.empty_like(self)
3807
@register_meta(aten.masked_fill_.Scalar)
3808
def meta_masked_fill_(self, mask, value):
3809
check_inplace_broadcast(self.shape, mask.shape)
3813
@register_meta(aten.masked_scatter_)
def meta_masked_scatter_(self, mask, source):
    """Meta kernel for in-place masked_scatter_.

    Validates the mask dtype and that self/source dtypes agree, then returns
    `self` per the in-place convention.
    """
    torch._check(
        mask.dtype in (torch.bool, torch.uint8), lambda: "Mask must be bool or uint8"
    )
    torch._check(
        self.dtype == source.dtype,
        # Bug fix: the message was a plain string containing "{self.dtype}"
        # placeholders with no `f` prefix, so the dtypes were never
        # interpolated into the error text.
        lambda: f"masked_scatter: expected self and source to have same "
        f"dtypes but got {self.dtype} and {source.dtype}",
    )
    return self
3826
@register_meta(aten.masked_scatter)
def meta_masked_scatter(self, mask, source):
    """Meta kernel for out-of-place masked_scatter: broadcast self against
    mask, then run the in-place meta kernel's validation on a fresh
    contiguous buffer and return it."""
    self, mask = _maybe_broadcast(self, mask)
    result = torch.empty_like(self, memory_format=torch.contiguous_format)
    return meta_masked_scatter_(result, mask, source)
3834
@register_meta(aten.masked_scatter_backward)
def meta_masked_scatter_backward(self, mask, sizes):
    """Meta kernel for masked_scatter's backward: a gradient buffer of the
    requested sizes, matching self's dtype/device."""
    return self.new_empty(sizes)
3839
@register_meta(aten.index_put_.default)
3840
def meta_index_put_(self, indices, values, accumulate=False):
3844
@register_meta(aten.alias.default)
def meta_alias(self):
    """Meta kernel for alias: a view of `self` with an unchanged shape."""
    return self.view(self.shape)
3849
def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None):
3850
torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
3851
torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
3853
batch1_sizes = batch1.size()
3854
batch2_sizes = batch2.size()
3856
bs = batch1_sizes[0]
3857
contraction_size = batch1_sizes[2]
3858
res_rows = batch1_sizes[1]
3859
res_cols = batch2_sizes[2]
3860
output_size = (bs, res_rows, res_cols)
3863
batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
3864
lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
3865
f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
3870
output = batch2.new_empty(output_size)
3872
if not is_bmm and self_baddbmm is not None:
3873
torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
3875
self_baddbmm.size() == output_size,
3876
lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
3882
@register_meta(aten.bmm.default)
def meta_bmm(self, mat2):
    """Meta kernel for bmm: delegate to the shared baddbmm/bmm shape logic
    with is_bmm=True (no self_baddbmm operand)."""
    result = common_meta_baddbmm_bmm(self, mat2, True)
    return result
3892
if r != 0 and (bool(r < 0) != bool(y < 0)):
3897
def pooling_output_shape_pad_lr(
3898
inputSize, kernelSize, pad_l, pad_r, stride, dilation, ceil_mode
3905
- dilation * (kernelSize - 1)
3907
+ (stride - 1 if ceil_mode else 0),
3913
if (outputSize - 1) * stride >= inputSize + pad_l:
3918
def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
3919
torch._check(stride != 0, lambda: "stride should not be zero")
3920
torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
3922
pad <= ((kernelSize - 1) * dilation + 1) // 2,
3924
f"pad should be at most half of effective kernel size, but got pad={pad}, "
3925
f"kernel_size={kernelSize} and dilation={dilation}"
3928
return pooling_output_shape_pad_lr(
3929
inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
3933
def pool2d_shape_check(
3951
nOutputPlane = nInputPlane
3955
lambda: "kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
3959
lambda: "stride should be greater than zero, but got dH: {dH}, dW: {dW}",
3962
dilationH > 0 and dilationW > 0,
3963
lambda: "dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
3966
valid_dims = input.size(1) != 0 and input.size(2) != 0
3968
if memory_format == torch.channels_last:
3970
ndim == 4 and valid_dims and input.size(3) != 0,
3971
lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
3972
" with optional 0 dim batch size for input, but got: {input.size()}",
3976
(ndim == 3 and input.size(0) != 0 and valid_dims)
3977
or (ndim == 4 and valid_dims and input.size(3) != 0),
3978
lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
3982
kW // 2 >= padW and kH // 2 >= padH,
3983
lambda: "pad should be smaller than or equal to half of kernel size, but got "
3984
f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
3988
outputWidth >= 1 and outputHeight >= 1,
3989
lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
3990
f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
3991
"Output size is too small",
3995
def pool3d_shape_check(
4017
check_input_size: bool = False,
4022
kT > 0 and kW > 0 and kH > 0,
4024
f"kernel size should be greater than zero, but got "
4025
f"kT: {kT}, kH: {kH}, kW: {kW}"
4029
dT > 0 and dW > 0 and dH > 0,
4031
f"stride should be greater than zero, but got "
4032
f"dT: {dT}, dH: {dH}, dW: {dW}"
4036
dilationT > 0 and dilationW > 0 and dilationH > 0,
4038
f"dilation should be greater than zero, but got "
4039
f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}"
4045
lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}",
4048
for i in range(ndim):
4049
if ndim == 5 and i == 0:
4055
f"{fn_name}: Expected input's non-batch dimensions to have positive length,"
4056
f" but input has a shape of {input.shape}"
4057
f" and non-batch dimension {input.size(i)} has length zero!"
4061
if check_input_size:
4063
itime >= kT and iheight >= kH and iwidth >= kW,
4065
f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than "
4066
f"kernel size (kT: {kT} kH: {kH} kW: {kW})"
4071
kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH,
4073
f"pad should be smaller than or equal to half of kernel size, but got "
4074
f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}"
4079
otime >= 1 and owidth >= 1 and oheight >= 1,
4081
f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). "
4082
f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). "
4083
f"Output size is too small"
4088
def max_pool3d_backward_shape_check(
4139
check_dim_size(grad_output, ndim, ndim - 4, nslices)
4140
check_dim_size(grad_output, ndim, ndim - 3, otime)
4141
check_dim_size(grad_output, ndim, ndim - 2, oheight)
4142
check_dim_size(grad_output, ndim, ndim - 1, owidth)
4144
check_dim_size(indices, ndim, ndim - 4, nslices)
4145
check_dim_size(indices, ndim, ndim - 3, otime)
4146
check_dim_size(indices, ndim, ndim - 2, oheight)
4147
check_dim_size(indices, ndim, ndim - 1, owidth)
4150
def avg_pool3d_backward_shape_check(
4152
grad_output: Tensor,
4198
check_dim_size(grad_output, ndim, ndim - 4, nslices)
4199
check_dim_size(grad_output, ndim, ndim - 3, otime)
4200
check_dim_size(grad_output, ndim, ndim - 2, oheight)
4201
check_dim_size(grad_output, ndim, ndim - 1, owidth)
4204
def max_pool2d_checks_and_compute_shape(
4205
input, kernel_size, stride, padding, dilation, ceil_mode
4208
def unpack(name, val):
4211
lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
4214
W = H if len(val) == 1 else val[1]
4217
kH, kW = unpack("kernel_size", kernel_size)
4220
len(stride) in [0, 1, 2],
4221
lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
4223
if len(stride) == 0:
4226
dH, dW = unpack("stride", stride)
4228
padH, padW = unpack("padding", padding)
4229
dilationH, dilationW = unpack("dilation", dilation)
4230
nInputPlane = input.size(-3)
4231
inputHeight = input.size(-2)
4232
inputWidth = input.size(-1)
4234
memory_format = utils.suggest_memory_format(input)
4235
if memory_format == torch.channels_last:
4238
lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
4240
elif memory_format == torch.contiguous_format:
4242
input.dim() in [3, 4],
4243
lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
4248
lambda: "Unsupport memory format. Supports only ChannelsLast, Contiguous",
4251
outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
4252
outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
4272
return nInputPlane, outputHeight, outputWidth
4275
@register_meta(aten.max_pool2d_with_indices_backward.default)
4276
def meta_max_pool2d_with_indices_backward(
4290
) = max_pool2d_checks_and_compute_shape(
4291
self, kernel_size, stride, padding, dilation, ceil_mode
4295
self.dtype == grad_output.dtype,
4296
lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
4299
nOutputPlane = nInputPlane
4302
def _check_dim_size(t):
4303
check_dim_size(t, ndim, ndim - 3, nOutputPlane)
4304
check_dim_size(t, ndim, ndim - 2, outputHeight)
4305
check_dim_size(t, ndim, ndim - 1, outputWidth)
4307
_check_dim_size(grad_output)
4308
_check_dim_size(indices)
4310
memory_format = utils.suggest_memory_format(self)
4315
memory_format=memory_format,
4319
@register_meta(aten.max_pool2d_with_indices.default)
4320
def meta_max_pool2d_with_indices(
4321
input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False
4327
) = max_pool2d_checks_and_compute_shape(
4328
input, kernel_size, stride, padding, dilation, ceil_mode
4331
nbatch = input.size(-4) if input.dim() == 4 else 1
4332
memory_format = utils.suggest_memory_format(input)
4333
if input.dim() == 3:
4334
size = [nInputPlane, outputHeight, outputWidth]
4336
size = [nbatch, nInputPlane, outputHeight, outputWidth]
4341
device=input.device,
4342
memory_format=memory_format,
4347
device=input.device,
4348
memory_format=memory_format,
4353
@register_meta(aten.fractional_max_pool2d.default)
4354
def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
4356
self_.ndim in (3, 4),
4357
lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self_.ndim}",
4361
for d in range(ndim - 3, ndim):
4364
f"fractional_max_pool2d: Expected input to have non-zero "
4365
f" size for non-batch dimenions, but got {self_.size()} with dimension {d} empty",
4370
len(kernel_size) == 2,
4371
lambda: "fractional_max_pool2d: kernel_size must"
4372
"either be a single int or tuple of Ints",
4375
len(output_size) == 2,
4376
lambda: "fractional_max_pool2d: output_size must "
4377
"either be a single int or tuple of Ints",
4380
input_channels = self_.size(-3)
4381
input_height = self_.size(-2)
4382
input_width = self_.size(-1)
4384
input_batch = self_.size(0)
4389
self_.dtype == random_samples.dtype,
4390
lambda: "Expect _random_samples to have the same dtype as input",
4393
random_samples.ndim == 3,
4394
lambda: f"Expect _random samples to have 3 dimensions got, {random_samples.ndim}",
4397
n = random_samples.size(0)
4398
c = random_samples.size(1)
4399
d = random_samples.size(2)
4402
"Expect _random_samples.size(0) no less then input batch size.",
4405
c == input_channels,
4406
lambda: "Expect _random_samples.size(1) equals to input channel size.",
4408
torch._check(d == 2, lambda: f"Expect _random_samples.size(2) equals to 2 got {d}.")
4411
output_size[0] + kernel_size[0] - 1 <= input_height,
4412
lambda: f"fractional_max_pool2d: kernel height {kernel_size[0]} is too large relative to input height {input_height}",
4415
output_size[1] + kernel_size[1] - 1 <= input_width,
4416
lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}",
4419
if self_.dim() == 4:
4420
size = [input_batch, input_channels, output_size[0], output_size[1]]
4422
size = [input_channels, output_size[0], output_size[1]]
4428
device=self_.device,
4433
device=self_.device,
4438
@register_meta(aten.max_unpool2d)
4440
def meta_max_unpool2d(self_, indices, output_size):
4441
utils.alert_not_deterministic("max_unpooling2d_forward_out")
4444
indices.dtype == torch.int64,
4445
lambda: f"elements in indices should be type int64 but got: {indices.dtype}",
4448
len(output_size) == 2,
4450
f"There should be exactly two elements (height, width) in output_size, "
4451
f"but got {len(output_size)} elements."
4455
oheight, owidth = output_size
4458
self_.ndim in (3, 4),
4460
f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
4461
f"but got a tensor with {self_.ndim} dimensions."
4465
self_.shape == indices.shape,
4467
f"Expected shape of indices to be same as that of the input tensor ({self_.shape}) "
4468
f"but got indices tensor with shape: {indices.shape}"
4472
for i in range(1, self_.ndim):
4476
f"max_unpooling2d(): "
4477
f"Expected input to have non-zero size for non-batch dimensions, "
4478
f"but got {self_.shape} with dimension {i} being empty."
4482
self = self_.contiguous()
4485
nchannels = self.size(0)
4486
result = self.new_empty((nchannels, oheight, owidth))
4488
nbatch = self.size(0)
4489
nchannels = self.size(1)
4490
result = self.new_empty((nbatch, nchannels, oheight, owidth))
4495
def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, fn_name):
4497
indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
4500
input.ndim in (4, 5),
4501
lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.",
4504
len(output_size) == 3,
4506
f"There should be exactly three elements (depth, height, width) in output_size, "
4507
f"but got {len(output_size)} elements."
4512
lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.",
4516
lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.",
4519
input.shape == indices.shape,
4521
f"Expected shape of indices to be same as that of the input tensor ({input.shape}) "
4522
f"but got indices tensor with shape: {indices.shape}"
4526
for i in range(1, input.ndim):
4531
f"Expected input to have non-zero size for non-batch dimensions, "
4532
f"but got {input.shape} with dimension {i} being empty."
4537
stride[0] > 0 and stride[1] > 0 and stride[2] > 0,
4538
lambda: f"strides should be greater than zero, but got stride: {stride}",
4542
@register_meta(aten.max_unpool3d)
4544
def meta_max_unpool3d(self_, indices, output_size, stride, padding):
4545
utils.alert_not_deterministic("max_unpooling3d_forward_out")
4547
_max_unpooling3d_shape_check(
4548
self_, indices, output_size, stride, padding, "max_unpooling3d()"
4551
self = self_.contiguous()
4553
odepth, oheight, owidth = output_size
4556
nchannels = self.size(0)
4557
result = self.new_empty((nchannels, odepth, oheight, owidth))
4559
nbatch = self.size(0)
4560
nchannels = self.size(1)
4561
result = self.new_empty((nbatch, nchannels, odepth, oheight, owidth))
4566
@register_meta(aten.max_pool3d_with_indices)
4567
@out_wrapper("out", "indices")
4568
def meta_max_pool3d_with_indices(
4577
len(kernel_size) in (1, 3),
4578
lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
4581
kH = kT if len(kernel_size) == 1 else kernel_size[1]
4582
kW = kT if len(kernel_size) == 1 else kernel_size[2]
4585
not stride or len(stride) in (1, 3),
4586
lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
4588
dT = kT if not stride else stride[0]
4589
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
4590
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
4593
len(padding) in (1, 3),
4594
lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
4597
pH = pT if len(padding) == 1 else padding[1]
4598
pW = pT if len(padding) == 1 else padding[2]
4601
len(dilation) in (1, 3),
4602
lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
4604
dilationT = dilation[0]
4605
dilationH = dilationT if len(dilation) == 1 else dilation[1]
4606
dilationW = dilationT if len(dilation) == 1 else dilation[2]
4609
input.ndim in (4, 5),
4610
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
4613
nbatch = input.size(-5) if input.ndim == 5 else 1
4614
nslices = input.size(-4)
4615
itime = input.size(-3)
4616
iheight = input.size(-2)
4617
iwidth = input.size(-1)
4619
otime = pooling_output_shape(itime, kT, pT, dT, dilationT, ceil_mode)
4620
oheight = pooling_output_shape(iheight, kH, pH, dH, dilationH, ceil_mode)
4621
owidth = pooling_output_shape(iwidth, kW, pW, dW, dilationW, ceil_mode)
4644
"max_pool3d_with_indices()",
4648
input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
4651
input_channels_last_check = input.unsqueeze(0)
4653
not input_channels_last_check.is_contiguous()
4654
) and input_channels_last_check.is_contiguous(
4655
memory_format=torch.channels_last_3d
4657
out_shape = (nslices, otime, oheight, owidth)
4659
out_shape = (nbatch, nslices, otime, oheight, owidth)
4661
out = input.new_empty(out_shape)
4662
indices = input.new_empty(out_shape, dtype=torch.int64)
4665
out = out.to(memory_format=torch.channels_last_3d)
4666
indices = indices.to(memory_format=torch.channels_last_3d)
4671
@register_meta(aten.max_pool3d_with_indices_backward)
4672
@out_wrapper("grad_input")
4673
def meta_max_pool3d_with_indices_backward(
4684
len(kernel_size) in (1, 3),
4685
lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
4688
kH = kT if len(kernel_size) == 1 else kernel_size[1]
4689
kW = kT if len(kernel_size) == 1 else kernel_size[2]
4692
not stride or len(stride) in (1, 3),
4693
lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
4695
dT = kT if not stride else stride[0]
4696
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
4697
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
4700
len(padding) in (1, 3),
4701
lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
4704
pH = pT if len(padding) == 1 else padding[1]
4705
pW = pT if len(padding) == 1 else padding[2]
4708
len(dilation) in (1, 3),
4709
lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
4711
dilationT = dilation[0]
4712
dilationH = dilationT if len(dilation) == 1 else dilation[1]
4713
dilationW = dilationT if len(dilation) == 1 else dilation[2]
4716
input.ndim in (4, 5),
4717
lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
4720
nslices = input.size(-4)
4721
itime = input.size(-3)
4722
iheight = input.size(-2)
4723
iwidth = input.size(-1)
4725
otime = grad_output.size(-3)
4726
oheight = grad_output.size(-2)
4727
owidth = grad_output.size(-1)
4729
max_pool3d_backward_shape_check(
4752
"max_pool3d_with_indices_backward()",
4756
input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
4759
input_channels_last_check = input.unsqueeze(0)
4761
not input_channels_last_check.is_contiguous()
4762
) and input_channels_last_check.is_contiguous(
4763
memory_format=torch.channels_last_3d
4766
grad_input = input.new_empty(input.shape)
4769
grad_input = grad_input.to(memory_format=torch.channels_last_3d)
4774
def check_grid_sampler_common(input: Tensor, grid: Tensor):
4776
input.device == grid.device,
4778
f"grid_sampler(): expected input and grid to be on same device, but input "
4779
f"is on {input.device} and grid is on {grid.device}"
4783
input.layout == torch.strided and grid.layout == torch.strided,
4785
f"grid_sampler(): expected input and grid to have torch.strided layout, but "
4786
f"input has {input.layout} and grid has {grid.layout}"
4790
input.shape[0] == grid.shape[0],
4792
f"grid_sampler(): expected grid and input to have same batch size, but got "
4793
f"input with sizes {input.shape} and grid with sizes {grid.shape}"
4797
grid.shape[-1] == input.ndim - 2,
4799
f"grid_sampler(): expected grid to have size {input.ndim - 2} in last "
4800
f"dimension, but got grid with sizes {grid.shape}"
4804
for i in range(2, input.ndim):
4808
f"grid_sampler(): expected input to have non-empty spatial dimensions, "
4809
f"but input has sizes {input.shape} with dimension {i} being empty"
4814
class GridSamplerInterpolation(Enum):
4820
def check_grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: int):
4822
input.ndim == 5 and input.ndim == grid.ndim,
4824
f"grid_sampler(): expected 5D input and grid with same number of "
4825
f"dimensions, but got input with sizes {input.shape}"
4826
f" and grid with sizes {grid.shape}"
4832
and interpolation_mode == GridSamplerInterpolation.BICUBIC.value
4834
lambda: "grid_sampler(): bicubic interpolation only supports 4D input",
4838
@register_meta(aten.grid_sampler_2d_backward.default)
4839
def grid_sampler_2d_backward_meta(
4848
input_requires_grad = output_mask[0]
4849
if input_requires_grad:
4850
grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format)
4853
grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format)
4854
return (grad_input, grad_grid)
4857
@register_meta(aten.grid_sampler_3d)
4866
check_grid_sampler_common(input, grid)
4867
check_grid_sampler_3d(input, grid, interpolation_mode)
4870
out_D = grid.shape[1]
4871
out_H = grid.shape[2]
4872
out_W = grid.shape[3]
4873
return input.new_empty((N, C, out_D, out_H, out_W))
4876
@register_meta(aten.grid_sampler_3d_backward)
4877
@out_wrapper("grad_input", "grad_grid")
4878
def grid_sampler_3d_backward(
4887
check_grid_sampler_common(input, grid)
4888
check_grid_sampler_3d(input, grid, interpolation_mode)
4889
input_requires_grad = output_mask[0]
4890
if input_requires_grad:
4891
grad_input = torch.zeros_like(
4892
input, memory_format=torch.legacy_contiguous_format
4896
grad_grid = torch.empty_like(grid, memory_format=torch.legacy_contiguous_format)
4897
return grad_input, grad_grid
4900
@register_meta([aten.full.default])
def full(size, fill_value, *args, **kwargs):
    """Meta kernel for aten.full.

    When no dtype is supplied, infer it from ``fill_value`` (matching eager
    behavior) before delegating to ``torch.empty`` for the shape/stride result.
    """
    dtype = kwargs.get("dtype", None)
    if not dtype:
        # Infer dtype from the Python scalar (bool/int/float/complex).
        dtype = utils.get_dtype(fill_value)
    kwargs["dtype"] = dtype
    return torch.empty(size, *args, **kwargs)
4910
@register_meta(aten.zeros_like.default)
4919
if layout == torch.sparse_coo:
4921
memory_format is None,
4922
lambda: "memory format option is only supported by strided tensors",
4927
dtype=self.dtype if dtype is None else dtype,
4929
device=self.device if device is None else device,
4930
pin_memory=pin_memory,
4934
res.sparse_resize_and_clear_(
4935
self.size(), self.sparse_dim(), self.dense_dim()
4938
res.sparse_resize_and_clear_(self.size(), self.dim(), 0)
4940
res._coalesced_(True)
4942
res = aten.empty_like.default(
4947
pin_memory=pin_memory,
4948
memory_format=memory_format,
4955
@register_meta(aten.select.int)
def meta_select(self, dim, index):
    """Meta kernel for aten.select.int: a view dropping dimension ``dim``.

    Mirrors the eager implementation: the result shares storage geometry with
    ``self`` (same strides minus the selected dim, offset advanced by
    ``index * stride[dim]``).
    """
    ndim = self.dim()
    torch._check_index(
        ndim != 0,
        lambda: "select() cannot be applied to a 0-dim tensor.",
    )

    dim = dim if dim >= 0 else dim + ndim
    size = self.size(dim)

    torch._check_index(
        not (-index > size or index >= size),
        lambda: f"select(): index {index} out of range for tensor of size "
        f"{self.size()} at dimension {dim}",
    )

    index = index if index >= 0 else index + size

    new_size = list(self.size())
    new_stride = list(self.stride())

    # The selected dimension disappears from the view's geometry.
    new_storage_offset = self.storage_offset() + index * new_stride[dim]
    del new_size[dim]
    del new_stride[dim]

    return self.as_strided(new_size, new_stride, new_storage_offset)
4984
@register_meta(aten.select_scatter.default)
def meta_select_scatter(self, src, dim, index):
    """Meta kernel for aten.select_scatter: result has self's shape and strides."""
    out = utils.clone_preserve_strides(self)
    return out
4989
@register_meta(aten.slice_scatter.default)
def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
    """Meta kernel for aten.slice_scatter: result has self's shape and strides."""
    out = utils.clone_preserve_strides(self)
    return out
4995
def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
    """Wrap a possibly-negative dimension index into ``[0, dim_post_expr)``.

    Mirrors ATen's ``maybe_wrap_dim``: for a tensor with ``dim_post_expr``
    dimensions, a negative ``dim`` counts from the end.  Scalar tensors
    (``dim_post_expr <= 0``) are treated as 1-d when ``wrap_scalar`` is true,
    so they accept dims -1 and 0.
    """
    if dim_post_expr <= 0:
        assert wrap_scalar
        dim_post_expr = 1
    min = -dim_post_expr
    max = dim_post_expr - 1
    assert not (dim < min or dim > max), f"dim {dim} out of bounds ({min}, {max})"
    if dim < 0:
        # Wrap negatives only: e.g. -1 -> dim_post_expr - 1.
        dim += dim_post_expr
    return dim
5007
def ensure_nonempty_size(t, dim):
    # A 0-d (scalar) tensor is treated as having size 1 along any dimension.
    if t.dim() == 0:
        return 1
    return t.shape[dim]
5012
def gather_shape_check(self, dim, index):
    """Validate that ``index`` is gather-compatible with ``self`` along ``dim``.

    ``index`` must have the same (rank treated as at-least-1) number of
    dimensions as ``self``, and be no larger than ``self`` in every dimension
    other than ``dim`` (the gather dimension itself may be any size).
    """
    self_dims = max(self.dim(), 1)
    index_dims = max(index.dim(), 1)
    torch._check(
        self_dims == index_dims,
        lambda: "Index tensor must have the same number of dimensions as input tensor",
    )
    for i in range(self_dims):
        if i != dim:
            torch._check(
                ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
                lambda: f"Size does not match at dimension {i} expected index {index.shape}"
                + f" to be smaller than self {self.shape} apart from dimension {dim}",
            )
5028
@register_meta(aten.gather.default)
def meta_gather(self, dim, index, sparse_grad=False):
    """Meta kernel for aten.gather: validate dtype/shape, allocate output.

    The output always has ``index``'s shape regardless of ``self``'s shape.
    """
    wrapped_dim = maybe_wrap_dim(dim, self.dim())
    is_index_empty = index.numel() == 0
    if not is_index_empty:
        # Checks are skipped for empty index tensors, matching eager.
        torch._check(
            index.dtype == torch.long,
            lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
        )
        gather_shape_check(self, wrapped_dim, index)
    return self.new_empty(index.shape)
5042
def get_operator_enum(reduce_, use_new_options=False):
    """Map a scatter ``reduce`` string to its internal enum name.

    ``use_new_options`` selects the scatter_reduce vocabulary
    (sum/prod/mean/amax/amin); otherwise the legacy scatter vocabulary
    (add/multiply) is used.  Raises via ``torch._check`` on anything else.
    """
    if use_new_options:
        if reduce_ == "sum":
            return "REDUCE_ADD"
        elif reduce_ == "prod":
            return "REDUCE_MULTIPLY"
        elif reduce_ == "mean":
            return "REDUCE_MEAN"
        elif reduce_ == "amax":
            return "REDUCE_MAXIMUM"
        elif reduce_ == "amin":
            return "REDUCE_MINIMUM"
        torch._check(
            False,
            lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
        )
    else:
        if reduce_ == "add":
            return "REDUCE_ADD"
        elif reduce_ == "multiply":
            return "REDUCE_MULTIPLY"
        torch._check(False, lambda: "reduce argument must be either add or multiply.")
5069
def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
5070
if index.numel() != 0:
5072
index.dtype == torch.long,
5073
lambda: f"{method_name}(): Expected dtype int64 for index",
5076
if src_opt is not None:
5078
self.dtype == src_opt.dtype,
5079
lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
5083
def ensure_nonempty_dim(dim):
    # Treat 0-d (scalar) tensors as 1-d for shape-checking purposes.
    return max(dim, 1)
5088
def scatter_shape_check(self, dim, index, src_opt=None):
    """Shape validation for the scatter-family meta kernels.

    ``index`` must match ``self``'s rank, fit inside ``self`` in every
    dimension except ``dim``, and (when ``src_opt`` is given) fit inside the
    source tensor in every dimension.  Empty ``index`` skips all checks.
    """
    if index.numel() == 0:
        return
    torch._check(
        ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
        lambda: "Index tensor must have the same number of dimensions as self tensor",
    )

    is_wrong_shape = False
    self_dims = ensure_nonempty_dim(self.dim())

    # index.size(d) <= self.size(d) for all d != dim
    for d in range(self_dims):
        index_d_size = ensure_nonempty_size(index, d)
        if d == dim:
            continue
        if index_d_size > ensure_nonempty_size(self, d):
            is_wrong_shape = True
            break

    # index.size(d) <= src.size(d) for all d, when a source tensor is given
    if not is_wrong_shape and src_opt is not None:
        for d in range(self_dims):
            index_d_size = ensure_nonempty_size(index, d)
            if index_d_size > ensure_nonempty_size(src_opt, d):
                is_wrong_shape = True
                break

    if src_opt is not None:
        torch._check(
            ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
            lambda: "Index tensor must have the same number of dimensions as self tensor",
        )
        torch._check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}",
        )
    else:
        torch._check(
            not is_wrong_shape,
            lambda: f"Expected index {index.shape} to be smaller than self {self.shape}"
            + f" apart from dimension {dim}",
        )
5135
def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
    """Shared validation used by the scatter-family meta kernels."""
    wrapped = maybe_wrap_dim(dim, self.dim())
    scatter_gather_dtype_check("scatter", self, index, src)
    scatter_shape_check(self, wrapped, index, src)
    if reduce_ is None:
        return
    # Raises if reduce_ is not a recognized reduction mode.
    get_operator_enum(reduce_, use_new_options)
5144
@register_meta(aten.scatter_add.default)
def meta_scatter_add(self, dim, index, src):
    """Meta kernel for aten.scatter_add: validate, then allocate the output."""
    scatter_meta_impl(self, dim, index, src, "add")
    out = self.new_empty(self.shape)
    return out
5150
@register_meta(aten.scatter_add_)
5151
def meta_scatter_add_(self, dim, index, src):
5152
scatter_meta_impl(self, dim, index, src, "add")
5160
@register_meta(
    [
        aten.scatter.src,
        aten.scatter.value,
        aten.scatter.reduce,
        aten.scatter.value_reduce,
    ]
)
def meta_scatter(self, dim, index, src_or_value, reduce=None):
    """Meta kernel for the out-of-place aten.scatter overloads.

    ``src_or_value`` is a Tensor for the .src/.reduce overloads and a Scalar
    for the .value/.value_reduce overloads; only a Tensor participates in the
    shape/dtype checks.
    """
    src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
    scatter_meta_impl(self, dim, index, src, reduce)
    return self.new_empty(self.shape)
5174
@register_meta(
    [
        aten.scatter_.src,
        aten.scatter_.value,
        aten.scatter_.reduce,
        aten.scatter_.value_reduce,
    ]
)
def meta_scatter_(self, dim, index, src_or_value, reduce=None):
    """Meta kernel for the in-place aten.scatter_ overloads.

    Same validation as meta_scatter; in-place ops return ``self``.
    """
    src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
    scatter_meta_impl(self, dim, index, src, reduce)
    return self
5187
aten._scaled_dot_product_flash_attention_backward,
5190
def meta__scaled_dot_product_flash_backward(
5203
philox_seed: Tensor,
5204
philox_offset: Tensor,
5205
scale: Optional[float] = None,
5207
grad_q = torch.empty_like(query.transpose(1, 2)).transpose(1, 2)
5208
grad_k = torch.empty_like(key.transpose(1, 2)).transpose(1, 2)
5209
grad_v = torch.empty_like(value.transpose(1, 2)).transpose(1, 2)
5210
return grad_q, grad_k, grad_v
5215
aten._scaled_dot_product_flash_attention_for_cpu,
5218
def meta__scaled_dot_product_flash_attention_for_cpu(
5222
dropout_p: float = 0.0,
5223
is_causal: bool = False,
5224
attn_mask: Optional[Tensor] = None,
5225
scale: Optional[float] = None,
5227
batch_size = query.size(0)
5228
num_heads = query.size(1)
5229
max_seqlen_batch_q = query.size(2)
5230
head_dim = query.size(3)
5232
attention = torch.empty(
5233
(batch_size, max_seqlen_batch_q, num_heads, head_dim),
5235
device=query.device,
5237
logsumexp = torch.empty(
5244
device=query.device,
5254
aten._scaled_dot_product_flash_attention_for_cpu_backward,
5257
def meta__scaled_dot_product_flash_attention_for_cpu_backward(
5266
attn_mask: Optional[Tensor] = None,
5267
scale: Optional[float] = None,
5271
batch_size = query.size(0)
5272
num_heads = query.size(1)
5273
head_dim = query.size(3)
5274
len_q = query.size(2)
5277
grad_q = torch.empty_permuted(
5278
(batch_size, num_heads, len_q, head_dim),
5281
device=query.device,
5283
grad_k = torch.empty_permuted(
5284
(batch_size, num_heads, len_k, head_dim),
5289
grad_v = torch.empty_permuted(
5290
(batch_size, num_heads, len_k, head_dim),
5293
device=value.device,
5296
return grad_q, grad_k, grad_v
5301
aten._scaled_dot_product_efficient_attention_backward,
5304
def meta__scaled_dot_product_efficient_backward(
5309
attn_bias: Optional[Tensor],
5312
philox_seed: Tensor,
5313
philox_offset: Tensor,
5315
grad_input_mask: List[bool],
5316
is_causal: bool = False,
5317
scale: Optional[float] = None,
5319
batch_size = query.size(0)
5320
num_heads = query.size(1)
5321
max_q = query.size(2)
5322
head_dim = query.size(3)
5323
head_dim_v = value.size(3)
5327
grad_q = torch.empty_permuted(
5328
(batch_size, num_heads, max_q, head_dim),
5331
device=query.device,
5333
grad_k = torch.empty_permuted(
5334
(batch_size, num_heads, max_k, head_dim),
5339
grad_v = torch.empty_permuted(
5340
(batch_size, num_heads, max_k, head_dim_v),
5343
device=value.device,
5346
if attn_bias is not None and grad_input_mask[3]:
5347
lastDim = attn_bias.size(-1)
5348
lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
5349
new_sizes = list(attn_bias.size())
5350
new_sizes[-1] = lastDimAligned
5351
grad_bias = torch.empty(
5352
new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
5354
grad_bias = grad_bias[..., :lastDim]
5356
return grad_q, grad_k, grad_v, grad_bias
5361
aten._flash_attention_backward,
5364
def meta__flash_attention_backward(
5377
philox_seed: Tensor,
5378
philox_offset: Tensor,
5379
scale: Optional[float] = None,
5381
grad_query = torch.empty_like(query)
5382
grad_key = torch.empty_like(key)
5383
grad_value = torch.empty_like(value)
5385
return grad_query, grad_key, grad_value
5390
aten._efficient_attention_backward,
5393
def meta__efficient_attention_backward(
5398
bias: Optional[Tensor],
5399
cu_seqlens_q: Optional[Tensor],
5400
cu_seqlens_k: Optional[Tensor],
5405
philox_seed: Tensor,
5406
philox_offset: Tensor,
5407
custom_mask_type: int,
5408
bias_requires_grad: bool,
5409
scale: Optional[float] = None,
5410
num_splits_key: Optional[int] = None,
5412
grad_query = torch.empty_like(query)
5413
grad_key = torch.empty_like(key)
5414
grad_value = torch.empty_like(value)
5416
if bias is not None:
5417
lastDim = bias.size(-1)
5418
lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
5419
new_sizes = list(bias.size())
5420
new_sizes[-1] = lastDimAligned
5421
grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device)
5422
grad_bias = grad_bias[..., :lastDim]
5424
grad_bias = torch.empty((), device=query.device)
5426
return grad_query, grad_key, grad_value, grad_bias
5429
@register_meta([aten._scaled_mm.default])
5433
bias: Optional[torch.Tensor] = None,
5434
out_dtype: Optional[torch.dtype] = None,
5435
scale_a: Optional[torch.Tensor] = None,
5436
scale_b: Optional[torch.Tensor] = None,
5437
scale_result: Optional[torch.Tensor] = None,
5438
use_fast_accum: bool = False,
5440
def is_row_major(stride):
5441
return stride[0] > stride[1] and stride[1] == 1
5443
def is_col_major(shape, stride):
5444
return stride[0] == 1 and stride[1] == shape[0]
5446
def is_fp8_type(dtype):
5448
torch.float8_e4m3fn,
5450
torch.float8_e4m3fnuz,
5451
torch.float8_e5m2fnuz,
5455
self.dim() == 2 and mat2.dim() == 2,
5456
lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
5459
is_row_major(self.stride()),
5460
lambda: "self must be row_major",
5463
is_col_major(mat2.shape, mat2.stride()),
5464
lambda: "mat2 must be col_major",
5467
self.size(1) % 16 == 0,
5468
lambda: f"Expected self.size(0) to be divisible by 16, but got self.size(1)={self.size(1)}",
5471
mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
5472
lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}",
5475
is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
5476
lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
5478
_out_dtype = out_dtype if out_dtype is not None else self.dtype
5480
self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device
5481
), torch.empty((), dtype=torch.float32, device=self.device)
5484
@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
5486
def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
5487
scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
5488
return self.new_empty(self.shape)
5491
@register_meta(aten.scatter_reduce_.two)
5492
def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
5493
scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
5497
@register_meta([aten.multinomial.default, aten.multinomial.out])
@out_wrapper()
def meta_multinomial(input, num_samples, replacement=False, *, generator=None):
    """Meta kernel for aten.multinomial.

    1-d input -> (num_samples,) int64 output; 2-d input -> (rows, num_samples).
    """
    torch._check(
        0 < input.dim() <= 2,
        lambda: f"The probabilty distributions dimensions must be 1 or 2, but got {input.dim()}",
    )
    if input.dim() == 1:
        return torch.empty(num_samples, dtype=torch.long, device=input.device)
    return torch.empty(
        input.size(0), num_samples, dtype=torch.long, device=input.device
    )
5511
def multiply_integers(vs):
    """Return the product of the integers in ``vs`` (1 for an empty sequence)."""
    r = 1
    for v in vs:
        r *= v
    return r
5518
def upsample_common_check(input_size, output_size, num_spatial_dims):
5520
len(output_size) == num_spatial_dims,
5521
lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
5523
expected_input_dims = num_spatial_dims + 2
5525
len(input_size) == expected_input_dims,
5526
lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
5530
all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size),
5531
lambda: f"Input and output sizes should be greater than 0, but got "
5532
f"input size {input_size} and output size {output_size}",
5535
nbatch, channels = input_size[:2]
5536
return (nbatch, channels, *output_size)
5540
@register_meta(
    [aten.upsample_nearest1d.default, aten._upsample_nearest_exact1d.default]
)
def upsample_nearest1d(input, output_size, scales=None):
    """Meta kernel for nearest(-exact) 1d upsampling: (N, C, L_out) output."""
    torch._check(
        input.numel() != 0 or multiply_integers(input.size()[1:]),
        lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
    )
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=1
    )
    # Preserve the input's suggested memory format on the output.
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )
5556
[aten.upsample_nearest2d.default, aten._upsample_nearest_exact2d.default]
5558
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
5560
input.numel() != 0 or multiply_integers(input.size()[1:]),
5561
lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
5563
full_output_size = upsample_common_check(
5564
input.size(), output_size, num_spatial_dims=2
5566
output = input.new_empty(full_output_size)
5569
memory_format = utils.suggest_memory_format(input)
5572
_, n_channels, _, _ = input.shape
5573
if input.device.type == "cuda" and n_channels < 4:
5574
memory_format = torch.contiguous_format
5576
output = output.contiguous(memory_format=memory_format)
5583
aten.upsample_nearest2d_backward.default,
5584
aten._upsample_nearest_exact2d_backward.default,
5587
def upsample_nearest2d_backward(
5588
grad_output: Tensor,
5589
output_size: Sequence[Union[int, torch.SymInt]],
5590
input_size: Sequence[Union[int, torch.SymInt]],
5591
scales_h: Optional[float] = None,
5592
scales_w: Optional[float] = None,
5594
full_output_size = upsample_common_check(
5595
input_size, output_size, num_spatial_dims=2
5598
grad_output.ndim == 4,
5599
lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
5603
grad_output.size(i) == full_output_size[i],
5605
f"Expected grad_output to have the same shape as output;"
5606
f" output.size({i}) = {full_output_size[i]}"
5607
f" but got grad_output.size({i}) = {grad_output.size(i)}"
5611
return grad_output.new_empty(input_size).to(
5612
memory_format=utils.suggest_memory_format(grad_output)
5617
[aten.upsample_nearest3d.default, aten._upsample_nearest_exact3d.default]
5619
def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
5621
input.numel() != 0 or multiply_integers(input.size()[1:]),
5622
lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
5624
full_output_size = upsample_common_check(
5625
input.size(), output_size, num_spatial_dims=3
5627
return input.new_empty(full_output_size).to(
5628
memory_format=utils.suggest_memory_format(input)
5637
@register_meta(
    [
        aten.sort.default,
        aten.sort.stable,
        aten.sort.values,
        aten.sort.values_stable,
    ]
)
def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=None):
    """Meta kernel for aten.sort: (values, indices) shaped like ``self``.

    When out= tensors are supplied they are resized/restrided to match and
    the fresh meta results are copied into them.
    """
    v, i = torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)
    if values is not None and indices is not None:
        assert isinstance(values, TensorLike)
        assert isinstance(indices, TensorLike)
        # Makes sure values and indices have the same strides. For cases where
        # these have different shapes, like (5, 10, 5) and (0) in msort.
        out_shape = v.shape
        out_stride = v.stride()
        values = _maybe_resize_out(values, out_shape)
        indices = _maybe_resize_out(indices, out_shape)
        values.as_strided_(out_shape, out_stride)
        indices.as_strided_(out_shape, out_stride)
        _safe_copy_out(copy_from=v, copy_to=values)
        _safe_copy_out(copy_from=i, copy_to=indices)
        return values, indices
    return v, i
5659
@register_meta(aten.argsort.stable)
def meta_argsort(self, *, stable, dim=-1, descending=False):
    """Meta kernel for aten.argsort.stable: the indices half of meta_sort."""
    _, indices = meta_sort(self, stable=stable, dim=dim, descending=descending)
    return indices
5664
def rnn_cell_checkSizes(
5665
input_gates, hidden_gates, input_bias, hidden_bias, factor, prev_hidden
5667
torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
5669
input_gates.shape == hidden_gates.shape,
5670
lambda: f"{input_gates.shape} != {hidden_gates.shape}",
5672
gates_size = input_gates.size(1)
5673
if input_bias is not None:
5674
torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
5676
input_bias.numel() == gates_size,
5677
lambda: f"{input_bias.numel()} != {gates_size}",
5680
input_bias.shape == hidden_bias.shape,
5681
lambda: f"{input_bias.shape} != {hidden_bias.shape}",
5683
torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
5684
expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
5686
prev_hidden.numel() == expected_prev_hidden_numel,
5687
lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
5691
x.device == input_gates.device
5692
for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
5694
lambda: "expected all inputs to be same device",
5698
@register_meta(aten._thnn_fused_lstm_cell.default)
5699
def _thnn_fused_lstm_cell_meta(
5700
input_gates, hidden_gates, cx, input_bias=None, hidden_bias=None
5702
rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
5703
workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format)
5704
hy = torch.empty_like(cx, memory_format=torch.contiguous_format)
5705
cy = torch.empty_like(cx, memory_format=torch.contiguous_format)
5706
return (hy, cy, workspace)
5709
@register_meta(aten._cudnn_rnn.default)
5728
is_input_packed = len(batch_sizes) != 0
5730
seq_length = len(batch_sizes)
5731
mini_batch = batch_sizes[0]
5732
batch_sizes_sum = input.shape[0]
5734
seq_length = input.shape[1] if batch_first else input.shape[0]
5735
mini_batch = input.shape[0] if batch_first else input.shape[1]
5736
batch_sizes_sum = -1
5738
num_directions = 2 if bidirectional else 1
5739
out_size = proj_size if proj_size != 0 else hidden_size
5741
out_shape = [batch_sizes_sum, out_size * num_directions]
5744
[mini_batch, seq_length, out_size * num_directions]
5746
else [seq_length, mini_batch, out_size * num_directions]
5748
output = input.new_empty(out_shape)
5750
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
5752
cy = torch.empty(0, device=input.device)
5754
cy = cx.new_empty(cell_shape)
5756
hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
5759
reserve_shape = 0 if train else 0
5760
reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
5762
return output, hy, cy, reserve, weight_buf
5765
@register_meta(aten.mkldnn_rnn_layer.default)
5766
def mkldnn_rnn_layer(
5784
seq_length = input.shape[1] if batch_first else input.shape[0]
5785
mini_batch = input.shape[0] if batch_first else input.shape[1]
5786
output_chanels = hidden_size
5788
[mini_batch, seq_length, output_chanels]
5790
else [seq_length, mini_batch, output_chanels]
5792
output = input.new_empty(out_shape)
5794
hy = torch.empty(0, device=input.device)
5796
hy = hx_.new_empty(hx_.shape)
5798
cy = torch.empty(0, device=input.device)
5800
cy = cx_.new_empty(cx_.shape)
5801
workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
5802
return output, hy, cy, workspace
5805
def zero_numel_check_dims(self, dim, fn_name):
5808
dim == 0 or dim == -1,
5809
lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
5813
self.size(dim) != 0,
5814
lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
5819
def check_argmax_argmin(name, self, dim):
    """Validate ``dim`` for argmax/argmin; ``name`` appears in error text.

    With an explicit dim, wrap it and require a non-empty slice; with
    dim=None the whole tensor is reduced, so it must be non-empty.
    """
    if dim is not None:
        dim = maybe_wrap_dim(dim, self.dim())
        zero_numel_check_dims(self, dim, name)
    else:
        torch._check(
            self.numel() != 0,
            lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
        )
5830
@register_meta([aten.argmax.default, aten.argmin.default])
def argmax_argmin_meta(self, dim=None, keepdim=False):
    """Meta kernel shared by aten.argmax and aten.argmin."""
    # NOTE(review): "argmax" is passed for both ops, as in the original code;
    # it only affects the wording of error messages.
    check_argmax_argmin("argmax", self, dim)
    reduce_dims = utils.reduction_dims(self.shape, None if dim is None else (dim,))
    out_shape = _compute_reduction_shape(self, reduce_dims, keepdim)
    return self.new_empty(out_shape, dtype=torch.int64)
5838
@register_meta(aten.scalar_tensor.default)
def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
    """Meta kernel for aten.scalar_tensor: a 0-d empty tensor.

    The scalar value ``s`` is irrelevant for meta (only shape/dtype matter).
    """
    return torch.empty(
        (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
5845
@register_meta(aten.topk.default)
def topk_meta(self, k, dim=-1, largest=True, sorted=True):
    """Meta kernel for aten.topk: (values, indices) with dim's size set to k."""
    # From aten/src/ATen/native/Sorting.cpp
    dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
    torch._check(
        k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1),
        lambda: "selected index k out of range",
    )
    sliceSize = 1 if self.dim() == 0 else self.size(dim)
    torch._check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension")

    topKSize = list(self.shape)
    if len(topKSize) > 0:
        topKSize[dim] = k
    return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
5862
# Alias used by the fused-LSTM backward meta kernels below when allocating
# gradient buffers with torch.empty_like.
legacy_contiguous_memory_format = torch.contiguous_format
5866
def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
5867
defined_grad = grad_hy if grad_hy is not None else grad_cy
5868
torch._check(defined_grad.dim() == 2, lambda: "")
5869
exp_size = defined_grad.size()
5870
if grad_hy is not None:
5871
torch._check(grad_hy.size() == exp_size, lambda: "")
5872
if grad_cy is not None:
5873
torch._check(grad_cy.size() == exp_size, lambda: "")
5874
torch._check(cx.size() == exp_size, lambda: "")
5875
torch._check(cy.size() == exp_size, lambda: "")
5876
torch._check(workspace.dim() == 2, lambda: "")
5877
torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
5881
@register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
    """Meta kernel for the fused LSTM cell backward.

    Returns (grad_gates, grad_cx, grad_bias).  If neither output gradient is
    defined, short-circuits to (None, None, None); grad_bias is produced only
    when has_bias is true.
    """
    if grad_hy is None and grad_cy is None:
        return None, None, None
    checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
    # NOTE(review): the closing paren of this call was lost in the paste; restored.
    grad_gates = torch.empty_like(
        workspace, memory_format=legacy_contiguous_memory_format
    )
    grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
    grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
    return grad_gates, grad_cx, grad_bias
5895
@register_meta(aten.linear_backward.default)
def linear_backward(input_, grad_output_, weight_, output_mask):
    """Meta kernel for aten.linear_backward.

    output_mask selects which of (grad_input, grad_weight, grad_bias) to
    materialize; unselected entries are returned as None.
    """
    # NOTE(review): the None defaults and the output_mask[0] guard were lost
    # in the mangled paste; restored so unselected grads stay None.
    grad_input = None
    grad_weight = None
    grad_bias = None
    if output_mask[0]:
        grad_input = grad_output_.new_empty(input_.size())
    if output_mask[1] or output_mask[2]:
        grad_weight = grad_output_.new_empty((grad_output_.size(-1), input_.size(-1)))
        grad_bias = grad_output_.new_empty(grad_output_.size(-1))
    return (grad_input, grad_weight, grad_bias)
5908
@register_meta(aten.pixel_shuffle.default)
def meta_pixel_shuffle(self, upscale_factor):
    """Meta kernel for aten.pixel_shuffle:
    (..., C*r*r, H, W) -> (..., C, H*r, W*r) where r = upscale_factor."""
    assert (
        len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
    ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"

    def is_channels_last(ten):
        return torch._prims_common.suggest_memory_format(ten) == torch.channels_last

    def pick_memory_format():
        # Choose the output memory format from the input's layout;
        # channels-last inputs on CUDA fall back to contiguous output.
        if is_channels_last(self):
            if device_hint(self) == "cuda":
                return torch.contiguous_format
            else:
                return torch.channels_last
        elif self.is_contiguous(memory_format=torch.contiguous_format):
            return torch.contiguous_format
        elif self.is_contiguous(memory_format=torch.preserve_format):
            return torch.preserve_format

    C = self.shape[-3] // (upscale_factor * upscale_factor)
    Hr = self.shape[-2] * upscale_factor
    Wr = self.shape[-1] * upscale_factor
    out_shape = (*self.shape[:-3], C, Hr, Wr)

    out = self.new_empty(out_shape)
    out = out.to(memory_format=pick_memory_format())
    # NOTE(review): the trailing `return out` was lost in the paste; restored.
    return out
5938
@register_meta(aten.mkldnn_rnn_layer_backward.default)
5939
def mkldnn_rnn_layer_backward(
5964
diff_x = input.new_empty(input.shape)
5965
diff_hx = hx_.new_empty(hx_.shape)
5966
diff_cx = cx_tmp.new_empty(cx_tmp.shape)
5967
diff_w1 = weight0.new_empty(weight0.shape)
5968
diff_w2 = weight1.new_empty(weight1.shape)
5969
diff_b = weight2.new_empty(weight2.shape)
5970
return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx
5973
@register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
    """Meta kernel for aten.bucketize: indices shaped like `self`, int32 when
    out_int32 else int64, returned contiguous."""
    # NOTE(review): the closing `).contiguous()` was lost in the mangled
    # paste; restored from upstream — confirm.
    return torch.empty_like(
        self, dtype=torch.int32 if out_int32 else torch.int64
    ).contiguous()
5982
@register_meta(
    [aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default]
)
def meta_upsample_bimode2d_aa(
    input, output_size, align_corners, scales_h=None, scales_w=None
):
    """Meta kernel for anti-aliased bilinear/bicubic 2d upsampling: output has
    the common upsample shape and the input's suggested memory format."""
    # NOTE(review): decorator opener and the call/check closers were lost in
    # the mangled paste; restored.
    full_output_size = upsample_common_check(
        input.size(), output_size, num_spatial_dims=2
    )
    torch._check(
        input.numel() != 0 or all(size > 0 for size in input.size()[1:]),
        lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
    )
    return input.new_empty(full_output_size).to(
        memory_format=utils.suggest_memory_format(input)
    )
6000
@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
    """Meta kernel for the in-place AMP unscale/inf-check: only validates that
    found_inf and inv_scale are single-element float tensors (no output —
    the op mutates `self` in place)."""
    # NOTE(review): the four torch._check( openers/closers were lost in the
    # mangled paste; restored.
    torch._check(
        found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor."
    )
    torch._check(
        inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor."
    )
    torch._check(
        found_inf.dtype.is_floating_point,
        lambda: "found_inf must be a float tensor.",
    )
    torch._check(
        inv_scale.dtype.is_floating_point,
        lambda: "inv_scale must be a float tensor.",
    )
6019
@register_meta([aten.nan_to_num.default, aten.nan_to_num.out])
def nan_to_num(self, nan=None, posinf=None, neginf=None):
    """Meta kernel for aten.nan_to_num: output mirrors the input's shape; the
    replacement values affect only data, never shape or dtype."""
    out_size = list(self.size())
    return self.new_empty(out_size)
6026
@register_meta(torch.ops.aten.transpose_)
def transpose_(self, dim0, dim1):
    """Meta kernel for in-place transpose_: swaps the sizes and strides of
    dim0 and dim1 via as_strided_ and returns self."""
    assert self.layout not in {
        # NOTE(review): the sparse layout members of this set were lost in
        # the mangled paste; reconstructed from upstream — confirm.
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }, f"torch.transpose_: in-place transposition is not supported for {self.layout} layout"

    ndims = self.ndim
    dim0 = maybe_wrap_dim(dim0, ndims)
    dim1 = maybe_wrap_dim(dim1, ndims)

    # Transposing a dim with itself is a no-op.
    if dim0 == dim1:
        return self

    size = list(self.size())
    stride = list(self.stride())

    stride[dim0], stride[dim1] = stride[dim1], stride[dim0]
    size[dim0], size[dim1] = size[dim1], size[dim0]

    self.as_strided_(size, stride)
    return self
6053
@register_meta(torch.ops.aten.t_)
def t_(self):
    """Meta kernel for in-place t_: transposes a <=2-d tensor (or a sparse
    tensor with <=2 sparse and 0 dense dims) by delegating to transpose_."""
    # NOTE(review): the def line, `ndims` binding, and the sparse/dense
    # branch skeleton were lost in the mangled paste; reconstructed from
    # upstream — confirm.
    ndims = self.ndim

    if self.is_sparse:
        sparse_dim = self.sparse_dim()
        dense_dim = self.dense_dim()
        assert (
            sparse_dim <= 2 and dense_dim == 0
        ), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions"
    else:
        assert (
            self.dim() <= 2
        ), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D"

    # For 0/1-d tensors, transpose dim 0 with itself (a no-op).
    return transpose_(self, 0, 0 if ndims < 2 else 1)
6071
@register_meta(aten.searchsorted)
def meta_searchsorted(
    sorted_sequence, self, *, out_int32=False, right=False, side=None, sorter=None
):
    """Meta kernel for aten.searchsorted: tensor input -> contiguous result of
    the same shape; scalar input -> 0-d result on sorted_sequence's device.
    Index dtype is int32 when out_int32 else int64."""
    dtype = torch.int32 if out_int32 else torch.int64
    if isinstance(self, torch.Tensor):
        return torch.empty_like(self, dtype=dtype).contiguous()
    else:
        # Scalar `self`: allocate a 0-d result on the boundaries' device.
        return torch.empty((), dtype=dtype, device=sorted_sequence.device)
6083
def _check_for_unsupported_isin_dtype(dtype):
6085
dtype not in [torch.bool, torch.bfloat16, torch.complex128, torch.complex64],
6086
lambda: f"Unsupported input type encountered for isin(): {dtype}",
6090
@register_meta(aten.isin)
def meta_isin(elements, test_elements, *, assume_unique=False, invert=False):
    """Meta kernel for aten.isin: result is a bool tensor shaped like
    `elements`; non-tensor arguments are promoted to tensors first."""
    # NOTE(review): the torch._check( opener/closer were lost in the mangled
    # paste; restored.
    torch._check(
        isinstance(elements, Tensor) or isinstance(test_elements, Tensor),
        lambda: "At least one of elements and test_elements must be a Tensor.",
    )
    if not isinstance(elements, Tensor):
        elements = torch.tensor(elements, device=test_elements.device)
    if not isinstance(test_elements, Tensor):
        test_elements = torch.tensor(test_elements, device=elements.device)
    _check_for_unsupported_isin_dtype(elements.dtype)
    _check_for_unsupported_isin_dtype(test_elements.dtype)
    return torch.empty_like(elements, dtype=torch.bool)
6108
@register_meta(aten.polygamma)
def meta_polygamma(n: int, self: Tensor) -> Tensor:
    """Meta kernel for aten.polygamma: shape is preserved and integer inputs
    promote to a floating dtype (INT_TO_FLOAT promotion)."""
    torch._check(n >= 0, lambda: "polygamma(n, x) does not support negative n.")
    # NOTE(review): the `self,` argument line and closing paren were lost in
    # the mangled paste; restored.
    _, result_dtype = elementwise_dtypes(
        self,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )
    return torch.empty_like(self, dtype=result_dtype)
6119
def _create_unary_float_meta_func(func):
    """Register (and return) a meta kernel for a unary elementwise op whose
    integer inputs promote to float."""
    # NOTE(review): the inner def line, @out_wrapper(), the call closer, and
    # `return _f` were lost in the mangled paste; reconstructed from
    # upstream — confirm out_wrapper is imported at the top of this file.
    @register_meta(func)
    @out_wrapper()
    def _f(x):
        return elementwise_meta(
            x, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )

    return _f
6130
def _create_binary_float_meta_func(func):
    """Register (and return) a meta kernel for a binary elementwise op whose
    integer inputs promote to float."""
    # NOTE(review): the inner def line, @out_wrapper(), the call closer, and
    # `return _f` were lost in the mangled paste; reconstructed from
    # upstream — confirm out_wrapper is imported at the top of this file.
    @register_meta(func)
    @out_wrapper()
    def _f(x, y):
        return elementwise_meta(
            x, y, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )

    return _f
6141
_create_unary_float_meta_func(aten.special_airy_ai)
6142
_create_unary_float_meta_func(aten.special_bessel_y0)
6143
_create_unary_float_meta_func(aten.special_bessel_y1)
6144
_create_unary_float_meta_func(aten.special_modified_bessel_i0)
6145
_create_unary_float_meta_func(aten.special_modified_bessel_i1)
6146
_create_unary_float_meta_func(aten.special_modified_bessel_k0)
6147
_create_unary_float_meta_func(aten.special_modified_bessel_k1)
6148
_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k0)
6149
_create_unary_float_meta_func(aten.special_scaled_modified_bessel_k1)
6152
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_t)
6153
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_u)
6154
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_v)
6155
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_w)
6156
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_t)
6157
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_u)
6158
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_v)
6159
_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_w)
6160
_create_binary_float_meta_func(aten.special_hermite_polynomial_h)
6161
_create_binary_float_meta_func(aten.special_hermite_polynomial_he)
6162
_create_binary_float_meta_func(aten.special_laguerre_polynomial_l)
6163
_create_binary_float_meta_func(aten.special_legendre_polynomial_p)
6169
import torch._refs.nn.functional
6170
import torch._refs.special
6174
def activate_meta():
    """Wire every collected meta/decomposition function into the Meta
    dispatch key, routing namespaced ops to their dedicated libraries.

    NOTE(review): the `def activate_meta():` line, the `continue`/`pass`
    statements, the final `else:`, and part of the denylist set were lost in
    the mangled paste; reconstructed from upstream — confirm before merging.
    """
    activate_meta_table = {}

    # For a given op, pick the most specific entry, in precedence order
    # meta > post_autograd > pre_autograd.
    for type in ["meta", "post_autograd", "pre_autograd"]:
        registry = global_decomposition_table[type]
        for opo in registry:
            if opo not in activate_meta_table:
                activate_meta_table[opo] = registry[opo]

    for op_overload, fn in activate_meta_table.items():
        # HigherOrderOperators don't go through the OpOverload machinery.
        if isinstance(op_overload, torch._ops.HigherOrderOperator):
            continue
        assert isinstance(op_overload, OpOverload)

        op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)

        if torch._C._dispatch_has_kernel_for_dispatch_key(
            op_overload.name(), "CompositeImplicitAutograd"
        ):
            # CompositeImplicitAutograd ops should decompose; meta kernels
            # belong on the base operators instead.
            if op_overload in global_decomposition_table["meta"]:
                raise RuntimeError(
                    f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
                    "register meta function for it. Instead, we should let the decomposition run and write "
                    "meta kernels for the base operators."
                )
        elif op_overload.is_view:
            # View ops are skipped here — presumably they keep their existing
            # (aliasing-aware) kernels; confirm against upstream comments.
            pass
        elif op_overload.name() in {
            # NOTE(review): denylist partially reconstructed from upstream —
            # confirm the exact entries.
            "aten::empty_strided",
            "aten::clone",
            "aten::_to_copy",
            "aten::copy_",
            "aten::constant_pad_nd",
            "aten::rot90",
            "aten::as_strided_scatter",
        }:
            pass
        else:
            if "mkldnn::" in op_overload.name():
                _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
            elif "mkl::" in op_overload.name():
                _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
            elif "onednn::" in op_overload.name():
                _meta_lib_dont_use_me_use_register_meta_for_onednn.impl(op_overload, fn)
            elif "quantized::" in op_overload.name():
                _meta_lib_dont_use_me_use_register_meta_for_quantized.impl(
                    op_overload, fn
                )
            else:
                _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)