6
from functools import partial, reduce
7
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
11
import torch._prims_common as utils
13
from torch import sym_float, Tensor, TypedStorage
14
from torch._C import _get_default_device
15
from torch._prims.debug_prims import register_debug_prims
16
from torch._prims.rng_prims import register_rng_prims
17
from torch._prims_common import (
31
from torch._prims_common.wrappers import backwards_not_supported
32
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
33
from torch.overrides import handle_torch_function, has_torch_function
34
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
36
# Torch library handles for the "prims" namespace: one handle to define the
# operator schemas ("DEF") and one per dispatch key used to register
# implementations (CompositeExplicitAutograd, BackendSelect, Autograd, Meta).
# NOTE(review): stray bare-integer artifact lines (leftover line numbers from
# a bad paste) were interleaved here; they were no-op statements and are removed.
prim = torch.library.Library("prims", "DEF")
prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd")
prim_backend_select_impl = torch.library.Library("prims", "IMPL", "BackendSelect")
prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd")
prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta")
97
"spherical_bessel_j0",
134
"shift_right_arithmetic",
135
"shift_right_logical",
156
"as_strided_scatter",
172
"convert_element_type",
219
tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
221
shape: Optional[ShapeType] = None,
222
strides: Optional[StrideType] = None,
223
dtype: Optional[torch.dtype] = None,
224
device: Optional[Union[torch.device, str]] = None,
226
if isinstance(tensorlike, Number):
227
assert not shape and (shape is None or isinstance(shape, Sequence))
228
assert not strides and (strides is None or isinstance(strides, Sequence))
229
inferred_shape: Tuple[int, ...] = ()
230
inferred_strides: Tuple[int, ...] = ()
231
inferred_dtype = type_to_dtype(type(tensorlike))
232
inferred_device = torch.device("cpu")
236
elif tensorlike is not None:
237
assert isinstance(tensorlike, torch.Tensor)
238
inferred_shape = tuple(tensorlike.shape)
239
inferred_strides = tuple(tensorlike.stride())
240
inferred_dtype = tensorlike.dtype
241
inferred_device = tensorlike.device
245
assert shape is not None
246
assert strides is not None
247
assert dtype is not None
248
assert device is not None
250
shape = inferred_shape if shape is None else tuple(shape)
251
strides = inferred_strides if strides is None else tuple(strides)
252
dtype = inferred_dtype if dtype is None else dtype
253
device = inferred_device if device is None else device
255
if isinstance(device, str):
256
device = torch.device(device)
258
return torch.empty_strided(shape, strides, dtype=dtype, device=device)
264
return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]],
268
tags: Optional[Sequence[torch.Tag]] = None,
271
Creates a primitive operation.
275
prim.define(schema, tags=torch.Tag.pt2_compliant_tag)
277
def _prim_impl(*args, **kwargs):
281
meta(*args, **kwargs)
282
return impl_aten(*args, **kwargs)
288
def _autograd_impl(*args, **kwargs):
289
return backwards_not_supported(_prim)(*args, **kwargs)
291
def _backend_select_impl(*args, **kwargs):
292
if kwargs.get("device") and kwargs["device"].type == "meta":
293
return meta(*args, **kwargs)
294
if any(isinstance(x, torch.device) and x.type == "meta" for x in args):
295
return meta(*args, **kwargs)
297
return _prim_impl(*args, **kwargs)
299
name = schema.split("(")[0]
300
prim_impl.impl(name, _prim_impl)
301
prim_autograd_impl.impl(name, _autograd_impl)
302
prim_meta_impl.impl(name, meta)
304
_prim_packet = getattr(torch._ops.ops.prims, name)
305
_prim = _prim_packet.default
309
from torch._subclasses.fake_tensor import contains_tensor_types
311
if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments) or str(
315
"prims.device_put.default"
317
prim_backend_select_impl.impl(name, _backend_select_impl)
319
for p in (_prim_packet, _prim):
321
p.return_type = return_type
324
p.prim_impl = _prim_impl
325
p.prim_meta_impl = meta
326
p.impl_aten = impl_aten
331
class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
335
COMPLEX_TO_FLOAT = (4,)
339
def _prim_elementwise_meta(
341
type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
342
args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
345
Meta function for elementwise operations that produce outputs in the same dtype
348
Stride logic is currently incorrect.
353
utils.check_same_dtype(*args)
356
if args_with_fixed_dtypes is not None:
357
args_ = list(args_with_fixed_dtypes) + args_
359
utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
360
utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)
362
l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
363
shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)
369
if isinstance(arg, TensorLike):
370
if not utils.is_cpu_scalar_tensor(arg):
375
elif isinstance(arg, Number):
376
scalar_type = type(arg)
378
if dtype is None and scalar_type is not None:
379
dtype = utils.type_to_dtype(scalar_type)
385
if isinstance(arg, TensorLike):
386
if utils.is_cpu_scalar_tensor(arg):
394
elif isinstance(arg, Number):
401
if device is not None:
402
assert dtype is not None
403
if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT:
405
elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
407
elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
408
if utils.is_integer_dtype(dtype) or utils.is_boolean_dtype(dtype):
409
dtype = torch.get_default_dtype()
410
elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
411
if utils.is_complex_dtype(dtype):
412
dtype = utils.corresponding_real_dtype(dtype)
416
assert shape is not None
417
return torch.empty_permuted(shape, l2p_perm, device=device, dtype=dtype)
424
if isinstance(number, (torch.SymInt, torch.SymFloat)):
426
assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI"
427
seen_float = seen_float or isinstance(a, (float, torch.SymFloat))
429
number = sym_float(number)
431
return TensorMeta(number)
434
def _complex_only_elementwise_meta(*args, **kwargs):
436
utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
438
return _prim_elementwise_meta(*args, **kwargs)
441
def _make_elementwise_unary_prim(
442
name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
445
Creates an elementwise unary prim.
449
schema=f"{name}(Tensor self) -> Tensor",
450
meta=partial(_prim_elementwise_meta, type_promotion=type_promotion),
451
return_type=RETURN_TYPE.NEW,
456
def _make_elementwise_binary_prim(
457
name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
460
Creates an elementwise binary prim.
464
schema=f"{name}(Tensor self, Tensor other) -> Tensor",
465
meta=partial(_prim_elementwise_meta, type_promotion=type_promotion),
466
return_type=RETURN_TYPE.NEW,
471
def _not_impl(*args, **kwargs):
472
raise NotImplementedError
480
abs = _make_elementwise_unary_prim(
484
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
487
acos = _make_elementwise_unary_prim(
489
impl_aten=torch.acos,
491
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
494
acosh = _make_elementwise_unary_prim(
496
impl_aten=torch.acosh,
498
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
501
asin = _make_elementwise_unary_prim(
503
impl_aten=torch.asin,
505
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
508
asinh = _make_elementwise_unary_prim(
510
impl_aten=torch.asinh,
512
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
515
atan = _make_elementwise_unary_prim(
517
impl_aten=torch.atan,
519
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
522
atanh = _make_elementwise_unary_prim(
524
impl_aten=torch.atanh,
526
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
529
cos = _make_elementwise_unary_prim(
533
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
536
cosh = _make_elementwise_unary_prim(
538
impl_aten=torch.cosh,
540
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
543
bessel_j0 = _make_elementwise_unary_prim(
545
impl_aten=torch.special.bessel_j0,
547
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
550
bessel_j1 = _make_elementwise_unary_prim(
552
impl_aten=torch.special.bessel_j1,
554
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
557
bessel_i0 = _make_elementwise_unary_prim(
561
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
564
bessel_i0e = _make_elementwise_unary_prim(
566
impl_aten=torch.special.i0e,
568
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
571
bessel_i1 = _make_elementwise_unary_prim(
573
impl_aten=torch.special.i1,
575
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
578
bessel_i1e = _make_elementwise_unary_prim(
580
impl_aten=torch.special.i1e,
582
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
585
bitwise_not = _make_elementwise_unary_prim(
587
impl_aten=torch.bitwise_not,
589
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
593
def _cbrt_aten(a: torch.Tensor) -> Tensor:
596
lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
604
return torch.copysign(torch.pow(a.abs(), 1 / 3), a)
607
cbrt = _make_elementwise_unary_prim(
609
impl_aten=_cbrt_aten,
611
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
614
ceil = _make_elementwise_unary_prim(
616
impl_aten=torch.ceil,
618
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
622
def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType:
623
if not input.dtype.is_complex:
624
raise RuntimeError("prims.conj_physical is only defined for complex dtypes")
626
strides = utils.compute_elementwise_output_strides(input)
627
return TensorMeta(input, strides=strides)
630
conj_physical = _make_prim(
631
schema="conj_physical(Tensor self) -> Tensor",
632
meta=_conj_physical_meta,
633
impl_aten=torch._conj_physical,
634
doc="Returns the physical conjugation of a complex tensor",
635
return_type=RETURN_TYPE.NEW,
640
input: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
642
if memory_format != torch.preserve_format:
648
memory_format=memory_format,
652
strides = utils.compute_elementwise_output_strides(input)
653
return torch.empty_strided(
663
schema="clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
665
impl_aten=torch.clone,
666
doc="Returns the copy of a tensor",
667
return_type=RETURN_TYPE.NEW,
670
digamma = _make_elementwise_unary_prim(
672
impl_aten=torch.digamma,
674
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
677
erf = _make_elementwise_unary_prim(
681
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
684
erf_inv = _make_elementwise_unary_prim(
686
impl_aten=torch.special.erfinv,
688
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
691
erfc = _make_elementwise_unary_prim(
693
impl_aten=torch.special.erfc,
695
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
698
erfcx = _make_elementwise_unary_prim(
700
impl_aten=torch.special.erfcx,
702
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
705
exp = _make_elementwise_unary_prim(
709
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
712
expm1 = _make_elementwise_unary_prim(
714
impl_aten=torch.special.expm1,
716
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
719
exp2 = _make_elementwise_unary_prim(
721
impl_aten=torch.special.exp2,
723
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
727
def _fill_meta(a: TensorLikeType, value: NumberType) -> TensorLikeType:
728
return _prim_elementwise_meta(
729
a, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
735
schema="fill(Tensor self, Scalar value) -> Tensor",
736
return_type=RETURN_TYPE.NEW,
738
impl_aten=torch.fill,
742
floor = _make_elementwise_unary_prim(
744
impl_aten=torch.floor,
746
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
750
schema="imag(Tensor self) -> Tensor",
752
_complex_only_elementwise_meta,
753
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
755
return_type=RETURN_TYPE.VIEW,
756
impl_aten=torch.imag,
760
isfinite = _make_elementwise_unary_prim(
762
impl_aten=torch.isfinite,
764
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
767
lgamma = _make_elementwise_unary_prim(
769
impl_aten=torch.lgamma,
771
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
774
log = _make_elementwise_unary_prim(
778
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
781
log1p = _make_elementwise_unary_prim(
783
impl_aten=torch.log1p,
785
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
788
log2 = _make_elementwise_unary_prim(
790
impl_aten=torch.log2,
792
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
795
log10 = _make_elementwise_unary_prim(
797
impl_aten=torch.log10,
799
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
803
schema="real(Tensor self) -> Tensor",
805
_complex_only_elementwise_meta,
806
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
808
return_type=RETURN_TYPE.VIEW,
809
impl_aten=torch.real,
813
reciprocal = _make_elementwise_unary_prim(
815
impl_aten=torch.reciprocal,
817
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
820
ndtri = _make_elementwise_unary_prim(
822
impl_aten=torch.special.ndtri,
824
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
827
neg = _make_elementwise_unary_prim(
831
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
834
round = _make_elementwise_unary_prim(
836
impl_aten=torch.round,
838
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
841
rsqrt = _make_elementwise_unary_prim(
843
impl_aten=torch.rsqrt,
845
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
848
sign = _make_elementwise_unary_prim(
850
impl_aten=torch.sign,
852
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
855
signbit = _make_elementwise_unary_prim(
857
impl_aten=torch.signbit,
859
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
862
sin = _make_elementwise_unary_prim(
866
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
869
sinh = _make_elementwise_unary_prim(
871
impl_aten=torch.sinh,
873
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
876
spherical_bessel_j0 = _make_elementwise_unary_prim(
877
"spherical_bessel_j0",
878
impl_aten=torch.special.spherical_bessel_j0,
880
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
883
sqrt = _make_elementwise_unary_prim(
885
impl_aten=torch.sqrt,
887
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
890
tan = _make_elementwise_unary_prim(
894
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
897
tanh = _make_elementwise_unary_prim(
899
impl_aten=torch.tanh,
901
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
904
trunc = _make_elementwise_unary_prim(
906
impl_aten=torch.trunc,
908
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
915
add = _make_elementwise_binary_prim(
919
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
922
atan2 = _make_elementwise_binary_prim(
924
impl_aten=torch.atan2,
926
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
929
bitwise_and = _make_elementwise_binary_prim(
931
impl_aten=torch.bitwise_and,
933
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
936
bitwise_or = _make_elementwise_binary_prim(
938
impl_aten=torch.bitwise_or,
940
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
943
bitwise_xor = _make_elementwise_binary_prim(
945
impl_aten=torch.bitwise_xor,
947
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
960
is_integral = isinstance(a, (bool, int, torch.SymInt)) or (
961
isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
965
return torch.div(a, b, rounding_mode="trunc")
967
return torch.true_divide(a, b)
970
div = _make_elementwise_binary_prim(
974
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
977
eq = _make_elementwise_binary_prim(
981
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
984
fmax = _make_elementwise_binary_prim(
986
impl_aten=torch.fmax,
988
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
991
fmin = _make_elementwise_binary_prim(
993
impl_aten=torch.fmin,
995
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
998
fmod = _make_elementwise_binary_prim(
1000
impl_aten=torch.fmod,
1002
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1006
gcd = _make_elementwise_binary_prim(
1008
impl_aten=torch.gcd,
1010
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1014
ge = _make_elementwise_binary_prim(
1018
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1021
gt = _make_elementwise_binary_prim(
1025
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1028
hypot = _make_elementwise_binary_prim(
1030
impl_aten=torch.hypot,
1032
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1035
igamma = _make_elementwise_binary_prim(
1037
impl_aten=torch.special.gammainc,
1039
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1042
igammac = _make_elementwise_binary_prim(
1044
impl_aten=torch.special.gammaincc,
1046
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1049
le = _make_elementwise_binary_prim(
1053
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1056
lt = _make_elementwise_binary_prim(
1060
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1066
a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
1068
if isinstance(a, TensorLike) and isinstance(b, Number):
1069
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
1070
elif isinstance(b, TensorLike) and isinstance(a, Number):
1071
a = scalar_tensor(a, dtype=b.dtype, device=b.device)
1073
return torch.maximum(a, b)
1076
maximum = _make_elementwise_binary_prim(
1078
impl_aten=_maximum_aten,
1080
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1085
a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
1087
if isinstance(a, TensorLike) and isinstance(b, Number):
1088
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
1089
elif isinstance(b, TensorLike) and isinstance(a, Number):
1090
a = scalar_tensor(a, dtype=b.dtype, device=b.device)
1092
return torch.minimum(a, b)
1095
minimum = _make_elementwise_binary_prim(
1097
impl_aten=_minimum_aten,
1099
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1102
mul = _make_elementwise_binary_prim(
1104
impl_aten=torch.mul,
1106
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1109
ne = _make_elementwise_binary_prim(
1113
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
1116
nextafter = _make_elementwise_binary_prim(
1118
impl_aten=torch.nextafter,
1120
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1123
pow = _make_elementwise_binary_prim(
1125
impl_aten=torch.pow,
1127
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1130
remainder = _make_elementwise_binary_prim(
1132
impl_aten=torch.remainder,
1134
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1138
shift_left = _make_elementwise_binary_prim(
1140
impl_aten=torch.bitwise_left_shift,
1142
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1145
shift_right_arithmetic = _make_elementwise_binary_prim(
1146
"shift_right_arithmetic",
1147
impl_aten=torch.bitwise_right_shift,
1149
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1152
# Logical (zero-filling) right shift is not implemented as a prim; calling it
# raises NotImplementedError (see _not_impl above).
shift_right_logical = _not_impl
1154
sub = _make_elementwise_binary_prim(
1156
impl_aten=torch.sub,
1158
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1161
zeta = _make_elementwise_binary_prim(
1163
impl_aten=torch.special.zeta,
1165
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
1171
def _as_strided_meta(
1172
a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int
1174
assert len(size) == len(stride)
1175
assert storage_offset >= 0
1176
utils.validate_strides(stride)
1177
utils.validate_shape(size)
1179
if reduce(operator.mul, size) == 0:
1183
elif isinstance(a, torch.Tensor):
1184
utils.check_in_bounds_for_storage(
1185
a._typed_storage(), size, stride, storage_offset
1188
return torch.as_strided(a, size, stride, storage_offset)
1191
def _as_strided_aten(
1192
a: Tensor, size: ShapeType, stride: StrideType, storage_offset: int
1194
return torch.as_strided(a, size, stride, storage_offset)
1197
_as_strided_doc = """
1198
Creates a view of the tensor with the given shape (size), strides (stride) and
1199
storage offset (storage_offset).
1202
as_strided = _make_prim(
1203
schema="as_strided(Tensor(a!) a, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor(a!)",
1204
meta=_as_strided_meta,
1205
impl_aten=_as_strided_aten,
1206
return_type=RETURN_TYPE.VIEW,
1207
doc=_as_strided_doc,
1211
def _broadcast_in_dim_meta(
1212
a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
1214
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
1217
assert isinstance(a, TensorLike)
1218
assert isinstance(shape, Sequence)
1219
assert isinstance(broadcast_dimensions, Sequence)
1222
assert a.ndim == len(broadcast_dimensions)
1225
assert len(shape) >= a.ndim
1230
def _greater_than_reduce(acc, x):
1231
assert isinstance(x, Dim)
1233
assert x < len(shape)
1237
reduce(_greater_than_reduce, broadcast_dimensions, -1)
1240
for idx, new_idx in enumerate(broadcast_dimensions):
1241
if not guard_size_oblivious(a.shape[idx] == 1):
1243
a.shape[idx] == shape[new_idx],
1244
lambda: f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}",
1249
for idx in range(len(shape)):
1250
if idx in broadcast_dimensions:
1253
if guard_size_oblivious(a.shape[original_idx] != shape[idx]):
1254
new_strides.append(0)
1256
new_strides.append(a.stride()[original_idx])
1257
original_idx = original_idx + 1
1259
if guard_size_oblivious(shape[idx] != 1):
1260
new_strides.append(0)
1261
elif original_idx == a.ndim:
1262
new_strides.append(1)
1264
new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
1266
return a.as_strided(shape, new_strides, a.storage_offset())
1269
def _broadcast_in_dim_aten(a, shape, broadcast_dimensions):
1271
for broadcast_dimension in broadcast_dimensions:
1272
s[broadcast_dimension] = -1
1275
for idx, x in enumerate(s):
1277
v = v.unsqueeze(idx)
1279
return v.expand(shape)
1282
_broadcast_in_dim_doc = """
1283
Creates a view of a with the specified shape.
1285
Allows adding dimensions of any length and broadcasting
1286
dimensions of length one in a to any length.
1288
The location of the broadcast dimensions must be specified
1289
using the broadcast_dimensions argument. Changing the
1290
relative order of dimensions is not supported.
1293
broadcast_in_dim = _make_prim(
1294
schema="broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)",
1295
meta=_broadcast_in_dim_meta,
1296
impl_aten=_broadcast_in_dim_aten,
1297
return_type=RETURN_TYPE.VIEW,
1298
doc=_broadcast_in_dim_doc,
1302
def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:
1304
ndim = max(1, a.dim())
1305
utils.validate_idx(ndim, start)
1306
utils.validate_idx(ndim, end)
1312
lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!",
1316
def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]:
1318
Returns the shape of a with dims in [start, end) merged into a single dimension.
1321
shape = (1,) if len(shape) == 0 else tuple(shape)
1324
for s in shape[start : end + 1]:
1325
dim_length = dim_length * s
1327
return shape[0:start] + (dim_length,) + shape[end + 1 :]
1330
def _collapse_view_helper(
1331
a: TensorLikeType, start: int, end: int
1332
) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
1333
assert isinstance(a, TensorLike)
1335
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
1337
_validate_collapse_args(a, start, end)
1345
strides = a.stride()
1347
if a.ndim == 0 or (end == start):
1348
return shape, strides
1351
stride = strides[end]
1352
for idx in range(end - 1, start - 1, -1):
1353
if guard_size_oblivious(shape[idx] == 0) or guard_size_oblivious(
1360
if guard_size_oblivious(shape[idx] == 1):
1363
length = length * shape[idx]
1364
stride = min(stride, strides[idx])
1367
guard_size_oblivious(a.numel() > 0)
1368
and guard_size_oblivious(shape[idx + 1] != 1)
1369
and not guard_size_oblivious(
1370
strides[idx] == strides[idx + 1] * shape[idx + 1]
1375
new_shape = shape[:start] + (length,) + shape[end + 1 :]
1376
new_strides = strides[:start] + (stride,) + strides[end + 1 :]
1379
if guard_size_oblivious(a.numel() == 0):
1380
new_strides = utils.make_contiguous_strides_for(new_shape)
1382
return new_shape, new_strides
1385
def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
1386
new_shape, new_strides = _collapse_view_helper(a, start, end)
1388
if new_shape is None:
1389
msg = "Attempting to view a collapsed tensor, but no such view exists!"
1390
raise ValueError(msg)
1392
assert new_strides is not None
1393
return a.as_strided(new_shape, new_strides, a.storage_offset())
1396
def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor:
1397
new_shape = _collapsed_shape(a.shape, start, end)
1398
return a.view(new_shape)
1401
_collapse_view_doc = """
1402
Creates a view of a with the dimensions between
1403
start (inclusive) and end (exclusive) merged into a
1406
If it's not possible to take such a view then an error
1407
is thrown. See collapse instead.
1409
The dimensions can be merged if and only if
1410
they are all "nested" with each other. That is, they all
1411
have the property that
1413
stride[i] = stride[i+1] * shape[i+1]
1415
for all i in [start, end - 1).
1418
collapse_view = _make_prim(
1419
schema="collapse_view(Tensor(a) a, int start, int end) -> Tensor(a)",
1420
meta=_collapse_view_meta,
1421
impl_aten=_collapse_view_aten,
1422
return_type=RETURN_TYPE.VIEW,
1423
doc=_collapse_view_doc,
1427
def _conj_meta(a: TensorLikeType) -> TensorLikeType:
1428
if not a.dtype.is_complex:
1429
raise RuntimeError("Expected complex dtype in prims.conj")
1430
out = a.as_strided(a.shape, a.stride(), a.storage_offset())
1431
torch._C._set_conj(out, not a.is_conj())
1436
Returns a conjugated view of the original tensor
1440
schema="conj(Tensor(a) a) -> Tensor(a)",
1442
impl_aten=torch.conj,
1443
return_type=RETURN_TYPE.VIEW,
1449
a: TensorLikeType, dimensions: DimsSequenceType, ndim=None
1452
Creates a view of a with a.ndim + len(dimensions) dimensions, with new
1453
dimensions of length one at the dimensions specified by dimensions.
1455
if ndim is not None:
1457
dims = sorted(utils.canonicalize_dims(ndim, dimensions))
1459
dims = sorted(utils.canonicalize_dims(a.ndim, dimensions))
1460
if len(set(dims)) != len(dims):
1461
msg = f"Received duplicate dimensions to expand in {str(dimensions)}"
1462
raise ValueError(msg)
1464
new_shape = list(a.shape)
1466
new_shape.insert(idx, 1)
1468
broadcast_dimensions = [
1469
idx for idx in range(len(new_shape)) if idx not in dimensions
1471
return broadcast_in_dim(a, new_shape, broadcast_dimensions)
1475
# Alias for the builtin `slice` type, kept because the name `slice` is reused
# in this file for the prims slice operation.
pyslice: Type[slice] = slice
1480
start_indices: DimsSequenceType,
1481
limit_indices: DimsSequenceType,
1482
strides: Optional[StrideType] = None,
1484
_strides = strides if strides is not None else [1] * len(start_indices)
1486
if a.ndim != len(start_indices):
1487
msg = f"Attempting to slice tensor of rank {a.ndim} with start_indices of length {len(start_indices)}!"
1488
raise ValueError(msg)
1490
if a.ndim != len(limit_indices):
1491
msg = f"Attempting to slice tensor of rank {a.ndim} with limit_indices of length {len(limit_indices)}!"
1492
raise ValueError(msg)
1494
if a.ndim != len(_strides):
1495
msg = f"Attempting to slice tensor of rank {a.ndim} with strides of length {len(limit_indices)}!"
1496
raise ValueError(msg)
1498
for x, y in zip(start_indices, a.shape):
1500
msg = f"Attempting to slice a tensor with a negative start index of {x}!"
1501
raise ValueError(msg)
1504
f"Attempting to slice a tensor but a start index in {start_indices} is greater than"
1505
f" the length of its corresponding dimension in shape {a.shape}"
1507
raise ValueError(msg)
1509
for x, y, z in zip(limit_indices, a.shape, start_indices):
1511
msg = f"Attempting to slice a tensor with a negative stop index of {x}!"
1512
raise ValueError(msg)
1515
f"Attempting to slice a tensor but a stop index in {limit_indices} is greater than the length of "
1516
f" its corresponding dimension in shape {a.shape}"
1518
raise ValueError(msg)
1521
f"Attempting to slice a tensor but a start index in {x} is greater than "
1522
f" its corresponding stop index {z}"
1527
msg = f"Attempting to slice a tensor with a non-positive step of {x}!"
1528
raise ValueError(msg)
1531
for x, y, z in zip(start_indices, limit_indices, _strides):
1532
new_shape.append(1 + (y - x - 1) // z)
1535
for x, y in zip(a.stride(), _strides):
1536
new_strides.append(x * y)
1538
return a.as_strided(new_shape, new_strides, a.storage_offset())
1543
start_indices: DimsSequenceType,
1544
limit_indices: DimsSequenceType,
1545
strides: Optional[StrideType] = None,
1547
_strides = strides if strides is not None else [1] * len(start_indices)
1550
for start, stop, step in zip(start_indices, limit_indices, _strides):
1551
slices.append(pyslice(start, stop, step))
1553
return operator.getitem(a, slices)
1557
Creates a view of a "bounding box" within the tensor.
1559
The bounding box is specified independently in each of the tensor's dimensions.
1560
start_indices and limit_indices describe the box's boundaries for their corresponding
1561
dimensions. If strides is specified then they specify the step size between elements
1562
in their corresponding dimension.
1564
This operation is analogous to slicing in NumPy, but does not permit slices where
1565
the stop indices are less than the start indices.
1569
schema="slice(Tensor(a) a, SymInt[] start_indices, SymInt[] limit_indices, SymInt[]? strides=None) -> Tensor(a)",
1571
impl_aten=_slice_aten,
1572
return_type=RETURN_TYPE.VIEW,
1577
def _slice_in_dim_meta(
1585
msg = f"slice_in_dim: received a negative axis {axis}"
1586
raise ValueError(msg)
1588
msg = f"slice_in_dim: axis {axis} is greater or equal to the rank {a.ndim} of the tensor"
1589
raise ValueError(msg)
1592
msg = f"slice_in_dim: received a negative start_index {start_index}"
1593
raise ValueError(msg)
1595
if start_index > a.shape[axis]:
1596
msg = f"slice_in_dim: start_index is greater than the length {start_index} of dimension {axis}"
1597
raise ValueError(msg)
1599
if limit_index > a.shape[axis]:
1600
msg = f"slice_in_dim: limit_index is greater than the length {limit_index} of dimension {axis}"
1601
raise ValueError(msg)
1603
if limit_index < start_index:
1604
msg = f"slice_in_dim: received a limit_index {limit_index} less than the start_index {start_index}"
1605
raise ValueError(msg)
1608
msg = f"slice_in_dim: received a non-positive stride of {stride}!"
1609
raise ValueError(msg)
1611
start_indices = [0] * a.ndim
1612
limit_indices = list(a.shape)
1613
strides = [1] * a.ndim
1615
start_indices[axis] = start_index
1616
limit_indices[axis] = limit_index
1617
strides[axis] = stride
1619
return _slice_meta(a, start_indices, limit_indices, strides)
1622
def _slice_in_dim_aten(
1629
start_indices = [0] * a.ndim
1630
limit_indices = list(a.shape)
1631
strides = [1] * a.ndim
1633
start_indices[axis] = start_index
1634
limit_indices[axis] = limit_index
1635
strides[axis] = stride
1637
return slice(a, start_indices, limit_indices, strides)
1640
_slice_in_dim_doc = """
1641
Convenience wrapper for slicing just one dimension using slice.
1645
slice_in_dim = _make_prim(
1646
schema="slice_in_dim(Tensor(a) a, SymInt start_index, SymInt limit_index, int stride=1, int axis=0) -> Tensor(a)",
1647
meta=_slice_in_dim_meta,
1648
impl_aten=_slice_in_dim_aten,
1649
return_type=RETURN_TYPE.VIEW,
1650
doc=_slice_in_dim_doc,
1654
def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
1655
assert isinstance(a, TensorLike)
1656
utils.validate_idx(a.ndim, dim)
1657
utils.validate_dim_length(outer_length)
1660
inner_length = a.shape[dim] // outer_length
1662
if (a.shape[dim] % outer_length) != 0:
1663
msg = "Attempting to split dimension of length {}, but outer length of {} divides it with a remainder!".format(
1664
a.shape[dim], outer_length
1666
raise ValueError(msg)
1668
new_shape: List[int] = []
1669
new_strides: List[int] = []
1670
for idx in range(a.ndim):
1672
new_shape.extend((outer_length, inner_length))
1673
new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx]))
1675
new_shape.append(a.shape[idx])
1676
new_strides.append(a.stride()[idx])
1678
return a.as_strided(new_shape, new_strides, a.storage_offset())
1681
def _split_dim_aten(a: Tensor, dim: int, outer_length: int) -> Tensor:
1682
inner_length = a.shape[dim] // outer_length
1683
new_shape = a.shape[0:dim] + (outer_length, inner_length) + a.shape[dim + 1 :]
1685
return a.view(new_shape)
1689
Creates a view of a with the given dimension (of length l) split
1690
into two dimensions, with the outer of the two having
1691
length outer_length and the inner of the two having computed
1692
length inner_length such outer_length * inner_length = l.
1696
split_dim = _make_prim(
1697
schema="split_dim(Tensor(a) a, int dim, SymInt outer_length) -> Tensor(a)",
1698
meta=_split_dim_meta,
1699
impl_aten=_split_dim_aten,
1700
return_type=RETURN_TYPE.VIEW,
1706
def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
1707
assert isinstance(a, TensorLike)
1709
for idx in dimensions:
1710
utils.validate_idx(a.ndim, idx)
1711
assert a.shape[idx] == 1
1715
for idx in range(len(a.shape)):
1716
if idx in dimensions:
1719
new_shape.append(a.shape[idx])
1720
new_strides.append(a.stride()[idx])
1722
return a.as_strided(new_shape, new_strides, a.storage_offset())
1726
Creates a view of the tensor with the specified dimensions removed.
1728
The removed dimensions must each have length one.
1731
squeeze = _make_prim(
1732
schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)",
1734
impl_aten=torch.squeeze,
1735
return_type=RETURN_TYPE.VIEW,
1740
def _transpose_meta(a: TensorLikeType, permutation: DimsSequenceType) -> TensorLikeType:
1741
if a.ndim != len(permutation):
1742
msg = "Attempting to permute a tensor of rank {}, but received a permutation of length {}!".format(
1743
a.ndim, len(permutation)
1745
raise ValueError(msg)
1747
if not utils.is_valid_permutation(a.ndim, permutation):
1748
msg = f"Received an invalid permutation, {permutation}!"
1749
raise ValueError(msg)
1751
new_shape = [0] * a.ndim
1752
new_strides = [0] * a.ndim
1753
for idx, dim in enumerate(permutation):
1754
new_shape[idx] = a.shape[dim]
1755
new_strides[idx] = a.stride()[dim]
1757
return a.as_strided(tuple(new_shape), tuple(new_strides), a.storage_offset())
1760
def _transpose_aten(a: Tensor, permutation: DimsSequenceType) -> Tensor:
1761
return torch.permute(a, permutation)
1765
Creates a view of the tensor with its dimensions permuted.
1767
The length of the permutation must be the rank of the tensor,
1768
and each element of the permutation specifies the new order
1769
for the corresponding dimension.
1772
transpose = _make_prim(
1773
schema="transpose(Tensor(a) a, int[] permutation) -> Tensor(a)",
1774
meta=_transpose_meta,
1775
impl_aten=_transpose_aten,
1776
return_type=RETURN_TYPE.VIEW,
1781
def _view_of_meta(a: TensorLikeType) -> TensorLikeType:
1782
return a.as_strided(a.shape, a.stride(), a.storage_offset())
1785
def _view_of_aten(a: Tensor) -> Tensor:
1786
return a.view(a.shape)
1790
Creates a view of the tensor.
1793
view_of = _make_prim(
1794
schema="view_of(Tensor(a) a) -> Tensor",
1796
impl_aten=_view_of_aten,
1797
return_type=RETURN_TYPE.VIEW,
1802
def _view_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
1803
return a.view(dtype)
1806
def _view_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
1807
return a.view(dtype)
1810
_view_element_type_doc = """
1811
Creates a view of the tensor with a different dtype.
1814
view_element_type = _make_prim(
1815
schema="view_of_dtype(Tensor(a) a, ScalarType dtype) -> Tensor",
1816
meta=_view_element_type_meta,
1817
impl_aten=_view_element_type_aten,
1818
return_type=RETURN_TYPE.VIEW,
1819
doc=_view_element_type_doc,
1827
def _as_strided_scatter_meta(
1828
input: TensorLikeType,
1829
src: TensorLikeType,
1832
storage_offset: int,
1834
utils.validate_shape(size)
1835
utils.validate_strides(stride)
1837
required_size = utils.compute_required_storage_length(size, stride, storage_offset)
1839
input.numel() >= required_size,
1841
f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
1842
f" and itemsize {input.element_size()} requiring a storage size of "
1843
f"{required_size * input.element_size()} are out of bounds "
1844
f"for storage of size {input.numel() * input.element_size()}"
1848
utils.is_same_shape(src.shape, size),
1849
lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
1852
return utils.clone_preserve_strides(input)
1855
_as_strided_scatter_doc = """
1856
Creates a new tensor equivalent to ``out = input.clone()`` after mutation by
1857
``out.as_strided(size, stride, storage_offset).copy_(src)``.
1860
as_strided_scatter = _make_prim(
1861
schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor",
1862
meta=_as_strided_scatter_meta,
1863
impl_aten=torch.as_strided_scatter,
1864
return_type=RETURN_TYPE.NEW,
1865
doc=_as_strided_scatter_doc,
1874
def _collapse_meta(a: Tensor, start: int, end: int) -> Tensor:
    """Meta function for collapse: allocates a new tensor whose shape has
    dimensions [start, end] of `a` merged into a single dimension."""
    _validate_collapse_args(a, start, end)
    collapsed = _collapsed_shape(a.shape, start, end)
    return a.new_empty(collapsed)
1881
def _collapse_aten(a: Tensor, start: int, end: int) -> Tensor:
    """Eager implementation of collapse: copies `a` into a new contiguous
    tensor whose dimensions [start, end] are merged into one.

    Fix: the result tensor was built and filled but never returned, so the
    function implicitly returned None.
    """
    new_shape = _collapsed_shape(a.shape, start, end)
    out = a.new_empty(new_shape)
    # copy_ under no_grad: the data movement must not be recorded by autograd.
    with torch.no_grad():
        out.view_as(a).copy_(a)
    return out
1890
Collapse a span of neighboring dimensions into one.
1892
See collapse_view for the corresponding view operation.
1894
collapse = _make_prim(
1895
schema="collapse(Tensor a, int start, int end) -> Tensor",
1896
meta=_collapse_meta,
1897
impl_aten=_collapse_aten,
1898
return_type=RETURN_TYPE.NEW,
1906
def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
1909
shape = tensors[0].shape
1911
for tensor_idx, tensor in enumerate(tensors):
1912
assert len(shape) == len(tensor.shape)
1913
for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
1915
concat_length = concat_length + length
1918
length == common_length,
1919
lambda: f"Sizes of tensors must match except in dimension {dim}. "
1920
f"Expected {common_length} but got {length} for tensor number "
1921
f"{tensor_idx} in the list",
1924
new_shape = list(tensors[0].shape).copy()
1925
new_shape[dim] = concat_length
1929
strides=utils.make_contiguous_strides_for(new_shape),
1933
def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor:
1934
return torch.cat(tensors, dim)
1938
Concatenates tensors along the specified dimension.
1940
The tensors' shapes must have the same rank and same length for other dimensions.
1944
schema="cat(Tensor[] tensors, int dim) -> Tensor",
1946
impl_aten=_cat_aten,
1947
return_type=RETURN_TYPE.NEW,
1952
def _reshape_meta(a: TensorLikeType, shape: ShapeType):
    """Meta function for reshape: validates that `shape` describes the same
    number of elements as `a` and returns a contiguous meta output.

    Raises ValueError if the element counts disagree.
    """
    assert isinstance(a, TensorLike)
    utils.validate_shape(shape)

    # Validates the tensor and the requested shape have the same number of
    # elements. The initial value 1 makes the empty (scalar) shape work:
    # reduce over an empty sequence with no initializer raises TypeError.
    numel = reduce(operator.mul, shape, 1)
    if numel != a.numel():
        msg = f"Attempting to reshape a tensor with {a.numel()} elements to a shape with {numel} elements!"
        raise ValueError(msg)

    return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))
1966
def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor:
1967
return a.reshape(shape).contiguous().clone()
1971
Creates a contiguous tensor with the specified shape
1972
containing a copy of the data in a.
1974
reshape = _make_prim(
1975
schema="reshape(Tensor a, SymInt[] shape) -> Tensor",
1977
impl_aten=_reshape_aten,
1978
return_type=RETURN_TYPE.NEW,
1983
def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
1984
utils.validate_dimension_indices(a.ndim, dims)
1985
return torch.empty_like(a, memory_format=torch.preserve_format)
1989
Reverses the order of elements along the given dimensions.
1993
schema="rev(Tensor a, int[] dims) -> Tensor",
1995
impl_aten=torch.flip,
1996
return_type=RETURN_TYPE.NEW,
2006
pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType
2008
return _prim_elementwise_meta(
2011
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
2012
args_with_fixed_dtypes=(pred,),
2017
Selects elements from a and b according to pred.
2019
Where pred is true the result contains the element from a, and
2020
where pred is false the result contains the element from b.
2024
schema="where(Tensor pred, Tensor a, Tensor b) -> Tensor",
2026
impl_aten=torch.where,
2027
return_type=RETURN_TYPE.NEW,
2035
def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
2037
assert isinstance(a, TensorLike)
2038
assert isinstance(dtype, torch.dtype)
2041
if torch._prims_common.is_non_overlapping_and_dense(a):
2042
strides = a.stride()
2044
strides = utils.compute_elementwise_output_strides(a)
2046
return TensorMeta(a, strides=strides, dtype=dtype)
2049
def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
2051
if not utils.is_grad_dtype(dtype):
2052
requires_grad = False
2056
requires_grad = a.requires_grad
2057
except Exception as e:
2058
requires_grad = False
2060
result = torch.empty_like(
2061
a, device=a.device, dtype=dtype, requires_grad=requires_grad
2063
with torch.no_grad():
2064
return copy_to(result, a)
2067
_convert_element_type_doc = """
2068
Creates a copy of a tensor with the given dtype.
2071
convert_element_type = _make_prim(
2072
schema="convert_element_type(Tensor a, ScalarType dtype) -> Tensor",
2073
meta=_convert_element_type_meta,
2074
impl_aten=_convert_element_type_aten,
2075
return_type=RETURN_TYPE.NEW,
2076
doc=_convert_element_type_doc,
2077
tags=(torch.Tag.pointwise,),
2081
def _device_put_meta(
2082
a: TensorLikeType, device: Union[str, torch.device]
2084
assert isinstance(a, TensorLike)
2085
assert isinstance(device, (str, torch.device))
2087
return TensorMeta(a, device=utils.canonicalize_device(device))
2090
def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor:
2094
_device_put_doc = """
2095
Creates a copy of a tensor on the given device.
2098
device_put = _make_prim(
2099
schema="device_put(Tensor a, Device device) -> Tensor",
2100
meta=_device_put_meta,
2101
impl_aten=_device_put_aten,
2102
return_type=RETURN_TYPE.NEW,
2103
doc=_device_put_doc,
2109
def _item_meta(a: TensorLikeType) -> FakeTensor:
    """Meta function for item: produces a scalar meta value whose Python
    number type corresponds to a's dtype (-1 is only a placeholder value)."""
    return TensorMeta(utils.dtype_to_type(a.dtype)(-1))
2115
Converts a tensor with one element to a Python number.
2122
schema="item(Tensor a) -> Scalar",
2124
impl_aten=torch.Tensor.item,
2125
return_type=RETURN_TYPE.NEW,
2132
def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:
    """Meta function for maximum_value: a scalar meta placeholder of the
    Python type corresponding to `dtype` (-1 is only a placeholder value)."""
    scalar_type = utils.dtype_to_type(dtype)
    return TensorMeta(scalar_type(-1))
2137
def _maximum_value_aten(dtype: torch.dtype):
2138
if dtype == torch.bool:
2140
elif dtype.is_complex or dtype.is_floating_point:
2141
return torch.finfo(dtype).max
2143
return torch.iinfo(dtype).max
2146
_maximum_value_doc = """
2147
Return the maximum finite value for a dtype.
2153
maximum_value = _make_prim(
2154
schema="maximum_value(ScalarType dtype) -> Scalar",
2155
meta=_maximum_value_meta,
2156
impl_aten=_maximum_value_aten,
2157
return_type=RETURN_TYPE.NEW,
2158
doc=_maximum_value_doc,
2164
def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor:
    """Meta function for minimum_value: a scalar meta placeholder of the
    Python type corresponding to `dtype` (-1 is only a placeholder value)."""
    scalar_type = utils.dtype_to_type(dtype)
    return TensorMeta(scalar_type(-1))
2169
def _minimum_value_aten(dtype: torch.dtype):
2170
if dtype == torch.bool:
2172
elif dtype.is_complex or dtype.is_floating_point:
2173
return torch.finfo(dtype).min
2175
return torch.iinfo(dtype).min
2178
_minimum_value_doc = """
2179
Return the minimum finite value for a dtype.
2185
minimum_value = _make_prim(
2186
schema="minimum_value(ScalarType dtype) -> Scalar",
2187
meta=_minimum_value_meta,
2188
impl_aten=_minimum_value_aten,
2189
return_type=RETURN_TYPE.NEW,
2190
doc=_minimum_value_doc,
2198
def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
2199
assert isinstance(a, TensorLike)
2200
assert isinstance(b, TensorLike)
2210
if a.numel() != b.numel():
2211
msg = f"Attempting to copy {b.numel()} elements to a tensor with {a.numel()} elements!"
2212
raise RuntimeError(msg)
2217
def _copy_to_aten(a: Tensor, b: Tensor) -> Tensor:
2222
Copies the data in b to a and returns the modified a.
2226
copy_to = _make_prim(
2227
schema="copy_to(Tensor(a!) a, Tensor b) -> Tensor(a!)",
2229
impl_aten=_copy_to_aten,
2230
return_type=RETURN_TYPE.INPLACE,
2235
def _copy_strided_meta(a: TensorLikeType, stride: ShapeType):
2236
assert isinstance(a, TensorLike)
2237
return torch.empty_strided(
2243
requires_grad=a.requires_grad,
2247
def _copy_strided_aten(a: Tensor, stride: ShapeType) -> Tensor:
2248
out = torch.empty_strided(
2254
requires_grad=a.requires_grad,
2260
_copy_strided_doc = """
2261
Copies the data in a to a new tensor, the new tensor has same shape with a size, but has different stride.
2265
copy_strided = _make_prim(
2266
schema="copy_strided(Tensor a, SymInt[] stride) -> Tensor",
2267
meta=_copy_strided_meta,
2268
impl_aten=_copy_strided_aten,
2269
return_type=RETURN_TYPE.NEW,
2270
doc=_copy_strided_doc,
2274
def _resize_meta(a: TensorLikeType, shape: ShapeType):
2275
return a.resize_(shape)
2278
def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor:
2279
return a.resize_(shape)
2283
Gives a tensor with no elements a new shape, returning the modified tensor.
2285
The tensor's strides are contiguous and its values are unitialized.
2290
schema="resize(Tensor(a!) a, SymInt[] shape) -> Tensor(a!)",
2292
impl_aten=_resize_aten,
2293
return_type=RETURN_TYPE.INPLACE,
2298
def _reduction_meta(inp, dims, *, output_dtype=None):
2300
Meta function for single output reduction operations
2301
Stride logic is incorrect
2303
assert isinstance(inp, TensorLike)
2304
if output_dtype is None:
2305
output_dtype = inp.dtype
2306
output_shape = utils.compute_reduction_output_shape(inp.shape, dims)
2309
strides=utils.make_contiguous_strides_for(output_shape),
2315
def _var_reduction_meta(inp, dims, *, correction):
2316
if utils.is_complex_dtype(inp.dtype):
2317
output_dtype = utils.corresponding_real_dtype(inp.dtype)
2319
output_dtype = inp.dtype
2320
return _reduction_meta(inp, dims, output_dtype=output_dtype)
2324
Computes the sum of elements in the input tensor over the list of dimensions
2325
specified in the dim argument
2328
Computes the xor sum of elements in the input tensor over the list of dimensions
2329
specified in the dim argument
2332
Computes the product of elements in the input tensor over the list of dimensions
2333
specified in the dim argument
2336
Computes the maximum value of elements in the input tensor over the list of dimensions
2337
specified in the dim argument
2340
Computes the minimum value of elements in the input tensor over the list of dimensions
2341
specified in the dim argument
2344
Computes the biased variance of x over the list of dimensions specified in the dim argument
2348
def _make_reduction_prim(name: str, impl_aten, doc):
2349
"""Creates a reduction prim."""
2351
schema=f"{name}(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor",
2352
meta=_reduction_meta,
2353
impl_aten=impl_aten,
2354
return_type=RETURN_TYPE.NEW,
2359
def _make_var_reduction_prim(name: str, impl_aten, doc):
2360
"""Creates a reduction prim."""
2362
schema=f"{name}(Tensor inp, int[]? dims, *, float correction, ScalarType? output_dtype=None) -> Tensor",
2363
meta=_var_reduction_meta,
2364
impl_aten=impl_aten,
2365
return_type=RETURN_TYPE.NEW,
2370
sum = _make_reduction_prim(
2372
impl_aten=torch.sum,
2378
inp: TensorLikeType,
2379
dims: Optional[DimsSequenceType],
2381
dtype: Optional[torch.dtype] = None,
2383
raise NotImplementedError("xor_sum only implemented with inductor")
2386
xor_sum = _make_reduction_prim(
2388
impl_aten=_xor_sum_aten,
2394
inp: TensorLikeType,
2395
dims: Optional[DimsSequenceType],
2397
dtype: Optional[torch.dtype] = None,
2399
if dims is not None:
2400
for d in sorted(dims, reverse=True):
2402
inp = torch.prod(inp, d, dtype=dtype)
2405
return torch.prod(inp, dims, dtype=dtype)
2408
prod = _make_reduction_prim(
2410
impl_aten=_prod_aten,
2414
var = _make_var_reduction_prim(
2416
impl_aten=torch.var,
2420
amax = _make_reduction_prim(
2422
impl_aten=torch.amax,
2426
amin = _make_reduction_prim(
2428
impl_aten=torch.amin,
2434
Constructs a 1-D tensor t where ``t[i] == start + i * step``.
2446
device: torch.device,
2447
requires_grad: bool,
2450
utils.is_integer_dtype(dtype),
2451
lambda: "prims.iota only supports integer dtypes",
2453
torch._check(step != 0, lambda: "step must be nonzero")
2458
requires_grad=requires_grad,
2468
device: torch.device,
2469
requires_grad: bool,
2471
end = start + length * step
2472
return torch.arange(
2473
start, end, step, dtype=dtype, device=device, requires_grad=requires_grad
2478
schema="iota(SymInt length, *, SymInt start, SymInt step, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
2479
return_type=RETURN_TYPE.NEW,
2481
impl_aten=_iota_aten,
2489
shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
2491
strides = utils.make_contiguous_strides_for(shape)
2492
return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
2496
shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
2498
return torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
2502
Creates a tensor with uninitialized values and the specified shape, dtype, and device.
2506
schema="empty(SymInt[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
2508
impl_aten=_empty_aten,
2509
return_type=RETURN_TYPE.NEW,
2514
def _empty_strided_meta(
2516
strides: StrideType,
2519
device: torch.device,
2520
requires_grad: bool,
2522
return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
2525
_empty_strided_doc = """
2526
Creates a tensor with uninitialized values.
2530
empty_strided = _make_prim(
2531
schema="empty_strided(SymInt[] shape, SymInt[] strides, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
2532
return_type=RETURN_TYPE.NEW,
2533
meta=_empty_strided_meta,
2534
impl_aten=torch.empty_strided,
2535
doc=_empty_strided_doc,
2539
def _empty_permuted_meta(
2541
physical_layout: DimsSequenceType,
2544
device: torch.device,
2545
requires_grad: bool,
2547
p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout])
2550
len(physical_layout) == dim,
2552
"Number of dimensions in the tensor input does not match the "
2553
f"length of the physical layout; i.e. len(size) = {dim} "
2554
f"is not equal to len(physical_layout) = {len(physical_layout)}"
2557
strides = [0] * len(shape)
2559
for p, l in enumerate(physical_layout):
2563
f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "
2564
f"{l} at index {p}). NB: negative dims "
2565
"not currently supported; file an issue if you want it."
2568
torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed")
2569
strides[l] = p_strides[p]
2579
_empty_permuted_doc = """
2580
Creates a tensor with uninitialized values according to some physical layout,
2581
that is guaranteed to be non-overlapping and dense.
2585
empty_permuted = _make_prim(
2586
schema="empty_permuted(SymInt[] shape, int[] physical_layout, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
2587
return_type=RETURN_TYPE.NEW,
2588
meta=_empty_permuted_meta,
2589
impl_aten=torch.empty_permuted,
2590
doc=_empty_permuted_doc,
2596
fill_value: NumberType,
2599
device: torch.device,
2600
requires_grad: bool,
2602
strides = utils.make_contiguous_strides_for(shape)
2603
return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
2608
fill_value: NumberType,
2611
device: torch.device,
2612
requires_grad: bool,
2616
shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad
2621
Creates a tensor filled with the given fill value, and with the specified shape, dtype, and device.
2626
schema="full(SymInt[] shape, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
2628
impl_aten=_full_aten,
2629
return_type=RETURN_TYPE.NEW,
2636
fill_value: NumberType,
2639
device: torch.device,
2640
requires_grad: bool,
2642
strides = utils.compute_elementwise_output_strides(a)
2644
strides = a.stride()
2646
return TensorMeta(a, strides=strides, dtype=dtype, device=device)
2651
fill_value: NumberType,
2654
device: torch.device,
2655
requires_grad: bool,
2658
return torch.full_like(
2659
a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad
2664
Creates a tensor filled with the given fill value, and the same shape, dtype, and device as the
2665
given tensor by default. The dtype and device settings can be overridden
2666
by specifying them explicitly.
2669
full_like = _make_prim(
2670
schema="full_like(Tensor a, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
2671
meta=_full_like_meta,
2672
impl_aten=_full_like_aten,
2673
return_type=RETURN_TYPE.NEW,
2678
def _scalar_tensor_meta(
2682
device: torch.device,
2684
shape: ShapeType = []
2685
strides = utils.make_contiguous_strides_for(shape)
2686
return TensorMeta(scalar, shape=shape, strides=strides, dtype=dtype, device=device)
2689
def _scalar_tensor_aten(
2693
device: torch.device,
2695
if isinstance(scalar, complex) and (
2696
dtype is None or not utils.is_complex_dtype(dtype)
2698
raise TypeError("Complex scalar requires complex tensor dtype.")
2700
return torch.scalar_tensor(scalar, dtype=dtype, device=device)
2703
_scalar_tensor_doc = """
2704
Wraps a Number into a Tensor with the specified dtype and device.
2708
scalar_tensor = _make_prim(
2709
schema="scalar_tensor(Scalar s, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
2710
meta=_scalar_tensor_meta,
2711
impl_aten=_scalar_tensor_aten,
2712
return_type=RETURN_TYPE.NEW,
2713
doc=_scalar_tensor_doc,
2723
A: TensorLikeType, *, full_matrices: bool
2724
) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
2725
utils.check_is_matrix(A, "linalg.svd")
2726
utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False)
2729
batch = A_shape[:-2]
2733
shape_U = batch + (m, m if full_matrices else k)
2734
strides_U = utils.make_contiguous_strides_for(shape_U, row_major=False)
2735
U = TensorMeta(shape=shape_U, strides=strides_U, dtype=A.dtype, device=A.device)
2737
shape_S = batch + (k,)
2738
strides_S = utils.make_contiguous_strides_for(shape_S)
2742
dtype=utils.corresponding_real_dtype(A.dtype) if A.is_complex() else A.dtype,
2746
shape_Vh = batch + (n if full_matrices else k, n)
2749
is_cuda = A.device.type == "cuda"
2750
strides_Vh = utils.make_contiguous_strides_for(shape_Vh, row_major=is_cuda)
2751
Vh = TensorMeta(shape=shape_Vh, strides=strides_Vh, dtype=A.dtype, device=A.device)
2754
if A.numel() != 0 and Vh.is_complex() and torch.cuda.is_available():
2760
A: TensorLikeType, *, full_matrices: bool
2761
) -> Tuple[Tensor, Tensor, Tensor]:
2762
return torch.linalg.svd(A, full_matrices=full_matrices)
2766
Returns the SVD of a matrix or batch of matrices.
2768
The `full_matrices` flag controls whether the full or reduced SVD decomposition is returned.
2772
schema="svd(Tensor A, *, bool full_matrices) -> (Tensor U, Tensor S, Tensor Vh)",
2774
impl_aten=_svd_aten,
2775
return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW, RETURN_TYPE.NEW),
2788
mean: Union[float, complex],
2791
device: torch.device,
2792
requires_grad: bool,
2793
generator: Optional[torch.Generator] = None,
2797
lambda: f"expected non-negative standard deviation, but got std={std}",
2801
utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
2802
lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
2805
strides = utils.make_contiguous_strides_for(shape)
2806
return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
2812
mean: Union[float, complex],
2815
device: torch.device,
2816
requires_grad: bool,
2817
generator: Optional[torch.Generator] = None,
2819
a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
2820
with torch.no_grad():
2822
a.normal_(mean, std, generator=generator)
2827
Constructs a tensor filled with values drawn from a normal distribution with the specified mean
2828
and standard deviation.
2830
Only supports floating-point types.
2835
"normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor"
2837
return_type=RETURN_TYPE.NEW,
2839
impl_aten=_normal_aten,
2850
device: torch.device,
2851
generator: Optional[torch.Generator] = None,
2853
strides = utils.make_contiguous_strides_for(shape)
2854
return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
2863
device: torch.device,
2864
generator: Optional[torch.Generator] = None,
2866
a = torch.empty(shape, dtype=dtype, device=device)
2867
a.uniform_(low, high, generator=generator)
2872
Constructs a tensor filled with values drawn uniformly from low to high.
2876
_uniform_helper = _make_prim(
2878
"uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device, Generator? generator=None) -> Tensor"
2880
return_type=RETURN_TYPE.NEW,
2882
impl_aten=_uniform_aten,
2894
dim: DimsSequenceType,
2897
dim = utils.canonicalize_dims(input.ndim, dim)
2898
utils.validate_no_repeating_dims(dim)
2900
shape = list(input.shape)
2903
shape[last_dim] = shape[last_dim] // 2 + 1
2905
dtype = utils.corresponding_complex_dtype(input.dtype)
2906
strides = utils.make_contiguous_strides_for(shape)
2907
return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)
2913
dim: DimsSequenceType,
2917
return torch._fft_r2c(input, dim, normalization, onesided)
2921
Performs a real to complex Fast Fourier Transform
2925
fft_r2c = _make_prim(
2926
schema="fft_r2c(Tensor self, *, int[] dim, bool onesided) -> Tensor",
2928
impl_aten=_fft_r2c_aten,
2929
return_type=RETURN_TYPE.NEW,
2937
dim: DimsSequenceType,
2940
dim = utils.canonicalize_dims(input.ndim, dim)
2941
utils.validate_no_repeating_dims(dim)
2944
strides = utils.make_contiguous_strides_for(shape)
2946
shape=shape, strides=strides, dtype=input.dtype, device=input.device
2953
dim: DimsSequenceType,
2957
return torch._fft_c2c(input, dim, normalization, forward)
2961
Performs either a Fast Fourier Transform, or its inverse
2965
fft_c2c = _make_prim(
2966
schema="fft_c2c(Tensor self, *, int[] dim, bool forward) -> Tensor",
2968
impl_aten=_fft_c2c_aten,
2969
return_type=RETURN_TYPE.NEW,
2977
dim: DimsSequenceType,
2980
dim = utils.canonicalize_dims(input.ndim, dim)
2981
utils.validate_no_repeating_dims(dim)
2983
shape = list(input.shape)
2984
shape[dim[-1]] = last_dim_size
2985
dtype = utils.corresponding_real_dtype(input.dtype)
2986
strides = utils.make_contiguous_strides_for(shape)
2987
return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)
2993
dim: DimsSequenceType,
2997
return torch._fft_c2r(input, dim, normalization, last_dim_size)
3001
Performs a complex to real Inverse Fast Fourier Transform
3005
fft_c2r = _make_prim(
3006
schema="fft_c2r(Tensor self, *, int[] dim, SymInt last_dim_size) -> Tensor",
3008
impl_aten=_fft_c2r_aten,
3009
return_type=RETURN_TYPE.NEW,
3014
def _frexp_meta(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
3016
self.dtype.is_floating_point,
3017
lambda: "torch.frexp() only supports floating-point dtypes",
3019
return torch.empty_like(self), torch.empty_like(self, dtype=torch.int32)
3023
schema="frexp(Tensor self) -> (Tensor mantissa, Tensor exponent)",
3025
return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW),
3026
impl_aten=torch.frexp,
3031
register_debug_prims()