# mypy: allow-untyped-defs
from __future__ import annotations

import operator
import warnings
from contextlib import nullcontext
from enum import Enum
from functools import reduce
from typing import (
    Any,
    Callable,
    cast,
    List,
    NamedTuple,
    Optional,
    overload,
    Sequence,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import deprecated, TypeAlias


if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.
    import sympy

import torch
from torch import sym_float, sym_int, sym_max


ShapeType: TypeAlias = Union[torch.Size, List[int], Tuple[int, ...]]
StrideType: TypeAlias = Union[List[int], Tuple[int, ...]]
DimsType: TypeAlias = Union[int, List[int], Tuple[int, ...]]
DimsSequenceType: TypeAlias = Union[List[int], Tuple[int, ...]]
# TODO: Type[torch.SymInt], Type[torch.SymFloat]
NumberTypeType: TypeAlias = Union[Type[bool], Type[int], Type[float], Type[complex]]
# TODO: This needs a lot more type annotations
# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
NumberType: TypeAlias = Union[bool, int, float, complex]
RealNumberType: TypeAlias = Union[bool, int, float]

Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat, torch.SymBool)
# I don't call it Integral because numbers.Integral includes bool, but IntLike
# does not
Dim = int
IntLike = (int, torch.SymInt)
FloatLike = (float, torch.SymFloat)
BoolLike = (bool, torch.SymBool)
IntWithoutSymInt = int
FloatWithoutSymFloat = float
DeviceLikeType: TypeAlias = Union[str, torch.device, int]
Tensor = torch.Tensor


torch_function_passthrough = {
    torch.device,
    torch.sym_not,
    torch.sym_float,
    torch.sym_int,
    torch.sym_max,
    torch.sym_min,
    torch._sym_sqrt,  # type: ignore[attr-defined]
    torch.sym_ite,
    torch.Tensor.dim,
    torch.Tensor.ndim.__get__,  # type: ignore[attr-defined]
    torch.Tensor.numel,
    torch.Tensor.size,
    torch.Tensor.storage_offset,
    torch.Tensor.stride,
    torch.Tensor.dtype.__get__,  # type: ignore[attr-defined]
    torch.Tensor.is_sparse.__get__,  # type: ignore[attr-defined]
    torch.Tensor.shape.__get__,  # type: ignore[attr-defined]
    torch.Tensor.device.__get__,  # type: ignore[attr-defined]
    torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
    torch.Tensor.layout.__get__,  # type: ignore[attr-defined]
    torch.Tensor.is_contiguous,
    # For TorchRefsMode only
    torch.Tensor.__format__,
    torch.Tensor.__repr__,
    torch.Tensor.__getitem__,
}


TensorLikeType = torch.Tensor
TensorLike = torch.Tensor
TensorSequenceType: TypeAlias = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType]

CustomOutParamAnnotation = "__custom_out_param__"


def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if len(a) != len(b):
        return False

    for x, y in zip(a, b):
        if allow_rhs_unbacked:
            # TODO: We should check that the symbols are consistent
            # with each other
            if isinstance(y, torch.SymInt):
                continue
        # NB: Naively, you would not expect to have to do an oblivious guard
        # here because there is seemingly no broadcasting here, but in fact we
        # use this in some situations to determine if we need to do an expand
        # on the tensor because they don't line up, so you can definitely end
        # up trying to prove u0 != 1 in this situation. See
        # python test/test_proxy_tensor.py -k test_cumsum_unbacked
        if guard_size_oblivious(x != y):
            return False

    return True


def _maybe_get_pytype(t):
    if t is torch.SymFloat:
        return float
    elif t is torch.SymInt:
        return int
    elif t is torch.SymBool:
        return bool
    else:
        return t


# TODO: look at using torch.testing.assert_close instead with an option
# to just compare metadata
def compare_tensor_meta(
    a: TensorLikeType,
    b: TensorLikeType,
    check_strides=False,
    *,
    allow_rhs_unbacked=False,
    check_conj=True,
):
    """
    Checks that two tensor likes have the same shape,
    dtype and device.

    In the future this will validate additional metadata, like
    strides.
    """
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    if not same_shape(a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked):
        msg = f"Shapes {a.shape} and {b.shape} are not equal!"
        raise AssertionError(msg)

    if a.dtype != b.dtype:
        msg = f"Dtypes {a.dtype} and {b.dtype} are not equal!"
        raise AssertionError(msg)

    if a.device != b.device:
        # Handles special cuda:0 vs cuda case
        # TODO: we should review why this happens and see about fixing it
        if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and (
            str(b.device) == "cuda:0" or str(b.device) == "cuda"
        ):
            pass
        else:
            msg = f"Devices {a.device} and {b.device} are not equal!"
            raise AssertionError(msg)

    # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050
    if check_strides:
        same_strides, idx = check_significant_strides(a, b)
        if not same_strides:
            msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!"
            raise RuntimeError(msg)

        if a.storage_offset() != b.storage_offset():
            msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!"
            raise RuntimeError(msg)

    if check_conj:
        if a.is_conj() != b.is_conj():
            raise RuntimeError(
                f"Conj mismatch! is_conj is set to {a.is_conj()} and {b.is_conj()}"
            )

    if a.is_neg() != b.is_neg():
        raise RuntimeError(
            f"Neg mismatch! is_neg is set to {a.is_neg()} and {b.is_neg()}"
        )


def _check_strides_helper(
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True
) -> Tuple[bool, Optional[int]]:
    # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch
    # See https://github.com/pytorch/pytorch/issues/77553
    # Only compares strides that are "meaningful" -- strides for dimensions with length > 1
    # and for tensors with more than one element
    if (
        not only_cuda or a.device.type == "cuda" or b.device.type == "cuda"
    ) and a.numel() > 0:
        for idx in range(a.ndim):
            check = not significant_only or a.shape[idx] > 1
            if a.stride()[idx] != b.stride()[idx] and check:
                return False, idx

    return True, None


def check_significant_strides(
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
) -> Tuple[bool, Optional[int]]:
    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True)


def check_all_strides(
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
) -> Tuple[bool, Optional[int]]:
    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)


# This function is equivalent to compute_contiguous() from TensorImpl.cpp
def is_contiguous(a: TensorLikeType) -> bool:
    """
    Tests whether a tensor is contiguous or not.

    Tensors are contiguous when they have no elements,
    one element, or when they have "nested" strides.
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if guard_size_oblivious(a.numel() < 2):
        return True

    expected_stride = 1
    for x, y in reversed(tuple(zip(a.shape, a.stride()))):
        # Skips checking strides when a dimension has length 1
        if guard_size_oblivious(x == 1):
            continue

        if guard_size_oblivious(y != expected_stride):
            return False
        expected_stride = expected_stride * x

    return True
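

# Illustrative sketch (not part of the upstream file): contiguity follows from
# the "nested" stride rule above. Assuming ordinary dense CPU tensors:
#
#   >>> t = torch.empty(2, 3)   # shape (2, 3), strides (3, 1)
#   >>> is_contiguous(t)
#   True
#   >>> is_contiguous(t.t())    # shape (3, 2), strides (1, 3)
#   False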


# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
def is_channels_last_contiguous_2d(a: Tensor) -> bool:
    # NHWC or not channels last 2D contiguous
    if a.ndim != 4:
        return False

    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    expected_stride = 1
    for idx in (1, 3, 2, 0):
        length = a.shape[idx]
        if guard_size_oblivious(length == 1):
            continue

        stride = a.stride()[idx]
        if guard_size_oblivious(stride != expected_stride):
            return False

        expected_stride *= length

    return True
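

# Illustrative sketch (not part of the upstream file): for an NHWC tensor the
# 'C' stride is 1 and strides nest in (C, W, H, N) order. Assuming a dense tensor:
#
#   >>> t = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
#   >>> t.stride()
#   (60, 1, 15, 3)
#   >>> is_channels_last_contiguous_2d(t)
#   True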


def is_channels_last_contiguous_3d(a: Tensor) -> bool:
    # NDHWC or not channels last 3D contiguous
    if a.ndim != 5:
        return False

    expected_stride = 1
    for idx in (1, 4, 3, 2, 0):
        length = a.shape[idx]
        if length == 1:
            continue

        stride = a.stride()[idx]
        if stride != expected_stride:
            return False

        expected_stride *= length

    return True


_memory_formats = {
    torch.contiguous_format,
    torch.preserve_format,
    torch.channels_last,
    torch.channels_last_3d,
}


def validate_memory_format(memory_format: torch.memory_format):
    torch._check(
        memory_format in _memory_formats,
        lambda: f"Received unknown memory format {memory_format}!",
    )


def is_contiguous_for_memory_format(  # type: ignore[return]
    a: Tensor, *, memory_format: torch.memory_format
) -> bool:
    validate_memory_format(memory_format)

    if memory_format == torch.contiguous_format:
        return is_contiguous(a)
    if memory_format == torch.channels_last:
        return is_channels_last_contiguous_2d(a)
    if memory_format == torch.channels_last_3d:
        return is_channels_last_contiguous_3d(a)

    torch._check(
        False,
        lambda: f"is_contiguous received unsupported memory format {memory_format}",
    )


# NOTE: that tensors with no elements and channels last is ???
def is_channels_last_contiguous(a: Tensor) -> bool:
    """
    True when a tensor is channels-last contiguous.

    This requires that:

      - the tensor is conceptually either 4 (NHWC) or 5 (NDHWC) dimensions
      - if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the
        stride of the 'C' dimension (Cs) is 1 and the strides corresponding to
        each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are
        "nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension,
        for example.
    """
    return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)


def is_non_overlapping_and_dense(a: Tensor) -> bool:
    """
    True when a tensor is non-overlapping and dense.

    A tensor is non-overlapping and dense when there exists a permutation of
    its dimensions that is contiguous.
    """

    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if a.is_sparse:
        return False

    # Short-circuits if the tensor is already contiguous or channels-last contiguous
    if is_contiguous(a) or is_channels_last_contiguous(a):
        return True

    # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp

    # Short-circuits for tensors of rank one, which are
    # non-overlapping and "dense" if their stride is one
    if a.ndim == 1:
        return a.stride()[0] == 1

    # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
    # Sorts (length, stride) pairs by stride
    #
    # This sort is done in a size-oblivious way, which helps if we do a
    # comparison like 2048*u0 > u0; we just want this to return True
    # (and not worry about what if u0 is zero).
    class K(NamedTuple):
        size: int
        stride: int

        def __lt__(self, other):
            return guard_size_oblivious(self.stride < other.stride)

        def __gt__(self, other):
            return guard_size_oblivious(self.stride > other.stride)

        def __le__(self, other):
            return guard_size_oblivious(self.stride <= other.stride)

        def __ge__(self, other):
            return guard_size_oblivious(self.stride >= other.stride)

        def __eq__(self, other):
            return guard_size_oblivious(self.stride == other.stride)

    lengths_and_strides = sorted(map(K, a.shape, a.stride()))

    expected_stride = 1
    for length, stride in lengths_and_strides:
        if guard_size_oblivious(length == 1):
            continue

        if stride != expected_stride:
            return False

        expected_stride *= length

    return True


# NOTE: Based on the implementation in TensorIterator.cpp, but note that
# the note [Computing output strides] is incorrect, because it
# says that strides will be preserved even if they are not
# "non overlapping and dense", but this is incorrect. The
# output of elementwise operations are always given
# non overlapping and dense strides.
# This is also INCORRECT because it does not model TensorIterator's
# short-circuit, which can cause different strides.
def compute_elementwise_output_logical_to_physical_perm(
    *tensors, _skip_checks=False
) -> List[int]:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if not _skip_checks and len(tensors) == 0:
        msg = "Can't compute elementwise output strides for zero tensors!"
        raise ValueError(msg)

    if not _skip_checks:
        check_same_shape(*tensors, allow_cpu_scalar_tensors=True)

    # Filters the tensors to actual tensors
    if not _skip_checks:
        tensors = tuple(
            a
            for a in tensors
            if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
        )

    # Short-circuits for CPU scalar case
    if len(tensors) == 0:
        return []

    # Short-circuits for shapes with zero or one dimensions
    # TODO: are these necessary?
    ndim = tensors[0].ndim
    if ndim == 0:
        return []
    if ndim == 1:
        return [0]

    # Short-circuits if contiguous or channels last, following the fake fast path.
    # This reduces the number of guards we end up making
    is_contiguous = True
    is_channels_last = True
    for t in tensors:
        is_contiguous = is_contiguous and t.is_contiguous(
            memory_format=torch.contiguous_format
        )
        is_channels_last = is_channels_last and t.is_contiguous(
            memory_format=torch.channels_last
        )

    if is_contiguous and not is_channels_last:
        return list(range(ndim))

    if is_channels_last and not is_contiguous:
        return [0, *list(range(2, ndim)), 1]

    shape = tensors[0].shape

    def should_swap(idx_a, idx_b):
        for tensor in tensors:
            stride_a = tensor.stride()[idx_a]
            stride_b = tensor.stride()[idx_b]

            if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
                stride_b == 0
            ):
                continue

            if guard_size_oblivious(stride_a < stride_b):
                return -1

            if guard_size_oblivious(stride_a > stride_b):
                return 1

            # stride_a == stride_b
            if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
                return 1

        # Note: this case is hit if all strides are zero,
        # or all strides are equal and all dimensions have the same length
        return 0

    # The "sort" order for the permutation is back-to-front, but
    # the natural order for permutations is front-to-back. Do the
    # sorting back-to-front and then reverse it on output.
    #
    # also, note this returns the logical to physical shape permutation
    perm = list(reversed(range(ndim)))

    # insertion sort with support for ambiguous comparisons
    for i in range(1, ndim):
        dim1 = i
        for dim0 in reversed(range(i)):
            comparison = should_swap(perm[dim0], perm[dim1])
            if comparison > 0:
                perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
                dim1 = dim0
            elif comparison < 0:
                break

    return list(reversed(perm))


def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
    """
    Computes the output strides for elementwise operations.
    """
    if len(tensors) == 0:
        msg = "Can't compute elementwise output strides for zero tensors!"
        raise ValueError(msg)

    check_same_shape(*tensors, allow_cpu_scalar_tensors=True)

    # Filters the tensors to actual tensors
    tensors = tuple(
        a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
    )

    # Short-circuits for CPU scalar case
    if len(tensors) == 0:
        return ()

    ndim = tensors[0].ndim
    shape = tensors[0].shape

    if ndim == 0:
        return ()
    if ndim == 1:
        return (1,)

    logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
        *tensors, _skip_checks=True
    )
    permuted_shape = apply_perm(shape, logical_to_physical_perm)  # to physical

    new_strides = make_contiguous_strides_for(permuted_shape)
    permuted_strides = apply_perm(
        new_strides, invert_perm(logical_to_physical_perm)
    )  # to logical

    return tuple(permuted_strides)
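

# Illustrative sketch (not part of the upstream file): the output layout follows
# the inputs' physical ordering rather than defaulting to contiguous:
#
#   >>> compute_elementwise_output_strides(torch.empty(4, 3))      # contiguous input
#   (3, 1)
#   >>> compute_elementwise_output_strides(torch.empty(3, 4).t())  # transposed input
#   (1, 4)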


# Identity permutation is [0, 1, 2]
def apply_perm(inp, perm):
    ndim = len(inp)
    permuted_inp = [-1] * ndim
    for idx, x in enumerate(perm):
        permuted_inp[idx] = inp[x]
    return permuted_inp


def invert_perm(perm):
    ndim = len(perm)
    new_perm = [-1] * ndim
    for idx, x in enumerate(perm):
        new_perm[x] = idx
    return new_perm
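

# Illustrative sketch (not part of the upstream file): applying a permutation
# and then its inverse round-trips:
#
#   >>> apply_perm([10, 20, 30], [2, 0, 1])
#   [30, 10, 20]
#   >>> invert_perm([2, 0, 1])
#   [1, 2, 0]
#   >>> apply_perm(apply_perm([10, 20, 30], [2, 0, 1]), invert_perm([2, 0, 1]))
#   [10, 20, 30]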


#
# Common helper functions
#


def validate_dim_length(length: int):
    """
    Validates that an object represents a valid
    dimension length.
    """

    if isinstance(length, (int, torch.SymInt)):
        torch._check_is_size(length)
    else:
        # sometimes called with sympy expression by inductor
        assert length >= 0


def validate_shape(shape: ShapeType):
    """
    Validates that a sequence represents a valid shape.
    """

    assert isinstance(shape, Sequence), type(shape)
    for l in shape:
        validate_dim_length(l)


def validate_strides(strides: StrideType):
    """
    Verifies the object specifies valid strides.
    """

    assert isinstance(strides, Sequence)
    for stride in strides:
        assert stride >= 0


def validate_idx(rank: int, idx: int):
    """
    Validates that idx is a valid index for the given shape.
    Assumes the index is already canonicalized.
    """

    assert isinstance(idx, Dim)
    assert isinstance(rank, Dim)

    assert idx >= 0 and idx < rank or idx == 0


def validate_dimension_indices(rank: int, indices: DimsSequenceType):
    for idx in indices:
        validate_idx(rank, idx)


def validate_exclusive_idx(rank: int, ex_idx: int):
    """
    Validates that ex_idx is a valid exclusive index
    for the given shape.
    """

    assert isinstance(ex_idx, Dim)
    assert isinstance(rank, Dim)
    assert ex_idx > 0 and ex_idx <= rank


# "Wraps" a dim (up to one time) for the given rank, allowing dims to be
# specified using negative indices. If `wrap_scalar` is true then scalar
# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise,
# idx should be in the range [-rank, rank-1].
def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
    if rank < 0:
        msg = f"Rank cannot be negative but got {rank}"
        raise IndexError(msg)

    if rank == 0:
        if not wrap_scalar:
            msg = f"Dimension specified as {idx} but tensor has no dimensions"
            raise IndexError(msg)
        rank = 1

    if idx >= 0 and idx < rank:
        return idx

    if idx < 0:
        _idx = idx + rank
    else:
        _idx = idx

    if _idx < 0 or _idx >= rank:
        # Same error message as in aten/src/ATen/WrapDimUtils.h:49
        msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})"
        raise IndexError(msg)

    return _idx
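

# Illustrative sketch (not part of the upstream file): negative dims wrap once,
# and scalar tensors accept -1 and 0 when wrap_scalar=True:
#
#   >>> canonicalize_dim(4, -1)
#   3
#   >>> canonicalize_dim(4, 2)
#   2
#   >>> canonicalize_dim(0, -1)   # rank-0 tensor, wrap_scalar=True
#   0
#   >>> canonicalize_dim(2, 2)    # raises IndexError: range is [-2, 1]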


# Takes a dimension or sequence of dimensions and "wraps" them,
# mapping negative offsets to positive ones
@overload
def canonicalize_dims(
    rank: int, indices: Sequence[int], wrap_scalar: bool = True
) -> Tuple[int, ...]:
    pass


@overload
def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int:
    pass


def canonicalize_dims(rank, indices, wrap_scalar=True):
    if isinstance(indices, Dim):
        return canonicalize_dim(rank, indices, wrap_scalar)

    return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices)


def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool:
    """
    Validates that perm is a permutation of length rank.
    """

    return isinstance(perm, Sequence) and sorted(perm) == list(range(rank))


def is_same_shape(a: Sequence, b: Sequence) -> bool:
    """
    Compares two shapes a and b, returning True if they are the same
    (their ranks and corresponding lengths match) and False otherwise.
    """

    return tuple(a) == tuple(b)


def is_cpu_scalar_tensor(a: Any) -> bool:
    return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu"


def check_same_device(*args, allow_cpu_scalar_tensors):
    """
    Checks that all Tensors in args have the same device.

    Raises a RuntimeError when:
      - args contains an object whose type is not Tensor or Number
      - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True
    """
    # Short-circuits if all (one or fewer) arguments are trivially on the same device
    if len(args) <= 1:
        return

    # Note: cannot initialize device to the first arg's device (it may not have one)
    device = None
    for arg in args:
        if isinstance(arg, Number):
            continue
        elif isinstance(arg, TensorLike):
            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
                continue

            if device is None:
                device = arg.device

            if device != arg.device:
                msg = (
                    "Tensor on device "
                    + str(arg.device)
                    + " is not on the expected device "
                    + str(device)
                    + "!"
                )
                raise RuntimeError(msg)
        else:
            msg = (
                "Unexpected type when checking for same device, " + str(type(arg)) + "!"
            )
            raise RuntimeError(msg)


def canonicalize_device(device: DeviceLikeType) -> torch.device:
    if isinstance(device, torch.device):
        return device

    assert isinstance(device, str)
    return torch.device(device)


# Asserts if any of the following are true:
#   - a non-scalar or non-Tensor is given
#   - the shape of any tensors is distinct
def check_same_shape(*args, allow_cpu_scalar_tensors: bool):
    """
    Checks that all Tensors in args have the same shape.

    Raises a RuntimeError when:
      - args contains an object whose type is not Tensor or Number
      - two Tensor objects in args have different shapes
    """
    shape = None

    for arg in args:
        if isinstance(arg, Number):
            continue
        elif isinstance(arg, TensorLike):
            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
                continue

            if shape is None:
                shape = arg.shape

            if not is_same_shape(shape, arg.shape):
                msg = f"Shape {arg.shape} is not the expected shape {shape}!"
                raise RuntimeError(msg)
        else:
            msg = (
                "Unexpected type when checking for same shape, " + str(type(arg)) + "!"
            )
            raise RuntimeError(msg)


# Acquires a common shape, if it exists, from one or more tensor arguments,
# filtering number arguments
def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
    shape = None
    scalar_shape = None

    for arg in args:
        if isinstance(arg, Number):
            continue
        elif isinstance(arg, TensorLike):
            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
                scalar_shape = arg.shape
                continue

            if shape is None:
                shape = arg.shape

            if not is_same_shape(shape, arg.shape):
                return None
        else:
            return None

    return shape if shape is not None else scalar_shape


# Extracts dimensions that might be passed either as a list/tuple or as varargs.
# A typical case is Tensor.permute .
def extract_dims_from_varargs(
    dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]]
) -> DimsSequenceType:
    if dims and isinstance(dims[0], Sequence):
        assert len(dims) == 1
        dims = cast(Tuple[DimsSequenceType], dims)
        return dims[0]
    else:
        return cast(DimsSequenceType, dims)


def extract_shape_from_varargs(
    shape: Union[ShapeType, Tuple[ShapeType]],
    validate=True,
) -> Tuple[int, ...]:
    """
    Returns a shape from varargs.

    In PyTorch, operations that accept shapes often accept them as varargs, like
    foo(*shape). However a user can pass the shape as separate integer arguments,
    like this:

      foo(1, 2, 3)

    or as a single sequence of integers, like this:

      foo((1, 2, 3))

    In the first case shape will be a tuple of integers, and in the second case it's a tuple
    containing a tuple of integers. This validates those inputs and canonicalizes them
    to a tuple of integers.
    """

    # Handles tuple unwrapping
    if len(shape) == 1 and isinstance(shape[0], Sequence):
        shape = shape[0]

    if validate:
        validate_shape(shape)  # type: ignore[arg-type]
    return shape  # type: ignore[return-value]


def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]:
    ndim = max(len(a), len(b))
    expandedSizes = [0] * ndim

    for i in range(ndim - 1, -1, -1):
        offset = ndim - 1 - i
        dimA = len(a) - 1 - offset
        dimB = len(b) - 1 - offset
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1

        torch._check(
            (sizeA == sizeB) or (sizeA == 1) or (sizeB == 1),
            lambda: (
                f"The size of tensor a ({sizeA}) must match the size of "
                f"tensor b ({sizeB}) at non-jagged dimension {i}"
            ),
        )

        # 1s map to the other size (even 0)
        expandedSizes[i] = sizeB if sizeA == 1 else sizeA

    return tuple(expandedSizes)
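

# Illustrative sketch (not part of the upstream file): standard right-aligned
# broadcasting, where size-1 (or missing) dims take the other operand's size:
#
#   >>> infer_size_shapes((3, 1, 5), (4, 5))
#   (3, 4, 5)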


def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
    """
    Infers the size of a dim with size -1, if it exists.
    Also checks that new shape is compatible with the number of elements.
    """
    dim = None
    newsize = 1
    for i, d in enumerate(shape):
        if d == -1:
            torch._check(dim is None, lambda: "only one dimension can be inferred")
            dim = i
        elif d >= 0:
            newsize *= d
        else:
            torch._check(False, lambda: f"invalid shape dimension {d}")
    if dim is None:
        torch._check(
            numel == newsize,
            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
        )
    else:
        from torch.fx.experimental.symbolic_shapes import definitely_true

        torch._check(
            newsize != 0,
            lambda: (
                f"cannot reshape tensor of 0 elements into shape {list(shape)} because the "
                f"unspecified dimension size -1 can be any value and is ambiguous"
                if definitely_true(numel == 0)
                else f"shape '{list(shape)}' is invalid for input of size {numel}"
            ),
        )
        torch._check(
            numel % newsize == 0,
            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
        )
        # Convert to list to produce a compatible error message with core
        # PyTorch, which prints sequences in square brackets.
        shape = list(shape)
        shape[dim] = numel // newsize
        # NB: This is pretty important when you have unbacked SymInts.
        # Suppose you have (i0, 12) resizing into (2, -1, 12).  The old
        # range for i0 is typically [2, inf], which means if you divide
        # by two the new range should be [1, inf].  But this is bad news
        # if you have an unbacked SymInt: we need to reapply the unsound
        # assumption that the size is >= 2.
        torch._check_is_size(shape[dim])
    return tuple(shape)
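

# Illustrative sketch (not part of the upstream file): the -1 dim is solved from
# the element count, as reshape/view do:
#
#   >>> infer_size((2, -1, 4), 24)
#   (2, 3, 4)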


_integer_dtypes = (
    torch.uint8,
    torch.uint16,
    torch.uint32,
    torch.uint64,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
)
_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128)


def is_boolean_dtype(dtype: torch.dtype) -> bool:
    assert isinstance(dtype, torch.dtype)
    return dtype is torch.bool


def is_integer_dtype(dtype: torch.dtype) -> bool:
    assert isinstance(dtype, torch.dtype)
    return dtype in _integer_dtypes


def is_low_precision_dtype(dtype: torch.dtype) -> bool:
    assert isinstance(dtype, torch.dtype)
    return dtype in _low_precision_dtypes


def is_float_dtype(dtype: torch.dtype) -> bool:
    assert isinstance(dtype, torch.dtype)
    return dtype.is_floating_point


def is_complex_dtype(dtype: torch.dtype) -> bool:
    assert isinstance(dtype, torch.dtype)
    return dtype in _complex_dtypes


def is_grad_dtype(dtype: torch.dtype) -> bool:
    """
    Checks if the dtype can require a gradient.
    """
    return dtype.is_floating_point or is_complex_dtype(dtype)


_complex_to_real_dtype_map = {
    torch.complex128: torch.float64,
    torch.complex64: torch.float32,
    torch.complex32: torch.float16,
}

_real_to_complex_dtype_map = {
    torch.float16: torch.complex32,
    torch.bfloat16: torch.complex64,
    torch.float32: torch.complex64,
    torch.float64: torch.complex128,
}


def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype:
    return _complex_to_real_dtype_map[dtype]


def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype:
    return _real_to_complex_dtype_map[dtype]


def dtype_to_type(dtype: torch.dtype) -> type:
    """
    Computes the corresponding Python type (AKA "type kind") for the
    given dtype.
    """
    assert isinstance(dtype, torch.dtype)

    if dtype is torch.bool:
        return bool
    if dtype in _integer_dtypes:
        return int
    if dtype.is_floating_point:
        return float
    if dtype in _complex_dtypes:
        return complex

    raise ValueError("Invalid dtype!")


def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]:
    """
    Computes the corresponding Python type constructor for the
    given dtype.
    """
    assert isinstance(dtype, torch.dtype)

    if dtype is torch.bool:
        return lambda x: bool(x)
    if dtype in _integer_dtypes:
        return sym_int
    if dtype.is_floating_point:
        return sym_float
    if dtype in _complex_dtypes:
        # TODO: type error here is real, replace with sym_complex
        return lambda x: complex(x)  # type: ignore[arg-type]

    raise ValueError("Invalid dtype!")


def type_to_dtype(typ: type) -> torch.dtype:
    """
    Computes the corresponding dtype for a Number type.
    """

    assert isinstance(typ, type)

    if typ in (bool, torch.SymBool):
        return torch.bool
    if typ in (int, torch.SymInt):
        return torch.long
    if typ in (float, torch.SymFloat):
        return torch.get_default_dtype()
    # TODO: sym_complex_float?
    if typ is complex:
        return corresponding_complex_dtype(torch.get_default_dtype())

    raise ValueError(f"Invalid type {typ}!")


def get_dtype(x: Union[torch.Tensor, NumberType]):
    if isinstance(x, torch.Tensor):
        return x.dtype
    else:
        return type_to_dtype(type(x))


_ordered_types = (bool, int, float, complex)


def check_fp_or_complex(
    dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True
):
    """
    Checks whether the input is floating point or complex.
    If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32
    """
    torch._check(
        is_float_dtype(dtype) or is_complex_dtype(dtype),
        lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}",
    )
    torch._check(
        allow_low_precision_dtypes or not is_low_precision_dtype(dtype),
        lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}",
    )


def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"):
    torch._check(
        len(A.shape) >= 2,
        lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
    )


def get_higher_type(a: type, b: type) -> type:
    """
    Returns the higher of the two given Number types.

    The types are ordered bool -> int -> float -> complex.
    """
    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)
    # Type checking
    if a not in _ordered_types or b not in _ordered_types:
        raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}")

    if a is b:
        return a

    for typ in _ordered_types:
        if a is typ:
            return b
        if b is typ:
            return a

    raise ValueError("Unknown Python scalar type!")


# Returns the higher of two torch datatypes a and b or, if the two
# are not ordered relative to each other, the next
# higher datatype
def get_higher_dtype(
    a: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
    b: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
) -> Optional[torch.dtype]:
    """
    Computes the "lowest" datatype that is weakly
    "higher" than both a and b.
    """

    # Type checking
    assert a is None or isinstance(a, (torch.dtype, TensorLike, Number))
    assert b is None or isinstance(b, (torch.dtype, TensorLike, Number))

    def _extract_dtype(
        x: Optional[Union[torch.dtype, TensorLikeType, NumberType]]
    ) -> Optional[torch.dtype]:
        if x is None:
            return None
        if isinstance(x, torch.dtype):
            return x
        if isinstance(x, TensorLike):
            return x.dtype
        if isinstance(x, Number):
            return type_to_dtype(type(x))

        raise RuntimeError("Unexpected type given to _extract_dtype!")

    a, b = _extract_dtype(a), _extract_dtype(b)

    if a is b:
        return a

    if a is None:
        return b

    if b is None:
        return a

    ordered_datatypes = (
        (torch.bool,),
        (torch.uint8, torch.int8),
        (torch.int16,),
        (torch.int32,),
        (torch.int64,),
        (torch.float16, torch.bfloat16),
        (torch.float32,),
        (torch.float64,),
        (torch.complex32,),
        (torch.complex64,),
        (torch.complex128,),
    )

    for idx, dtypes in enumerate(ordered_datatypes):
        if a in dtypes and b in dtypes:
            return ordered_datatypes[idx + 1][0]
        if a in dtypes:
            return b
        if b in dtypes:
            return a

    raise RuntimeError("Unexpected termination!")
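

# Illustrative sketch (not part of the upstream file): ordered pairs return the
# higher dtype; unordered pairs promote to the next group in the lattice above:
#
#   >>> get_higher_dtype(torch.int32, torch.float16)
#   torch.float16
#   >>> get_higher_dtype(torch.float16, torch.bfloat16)   # unordered pair
#   torch.float32
#   >>> get_higher_dtype(torch.uint8, torch.int8)         # unordered pair
#   torch.int16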


def check_pin_memory(pin_memory: bool):
    torch._check_not_implemented(
        not pin_memory, lambda: "PrimTorch does not support pinned memory"
    )


def check_layout(layout: torch.layout):
    torch._check_not_implemented(
        layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}"
    )


# TODO: maybe unify with can_cast_to?
def is_weakly_lesser_type(a: type, b: type) -> bool:
    """
    Compares two types, a and b, returning True if a is weakly "less" than b.

    The comparison is determined by the following type ordering: bool, int, float, complex.
    """

    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)

    if a not in _ordered_types or b not in _ordered_types:
        raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}")

    for typ in _ordered_types:
        if a == typ:
            return True
        if b == typ:
            return False

    raise RuntimeError("Unexpected termination!")


def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool:
    for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype):
        if fn(cast_to):
            return True
        if fn(cast_from):
            return False

    raise ValueError(f"Received unknown dtypes {cast_to}, {cast_from}!")


def check_same_dtype(*args):
    """
    Checks that all Tensors in args have the same dtype and that all Numbers have the
    same corresponding Python type.

    Raises a RuntimeError when:
      - args contains an object whose type is not Tensor or Number
      - two Tensor objects in args have different dtypes
      - two Number objects in args have different types
      - there are Tensors and Numbers in args, and one of those Tensors corresponding
          Python types is different from the type of one of those Numbers
    """
    full_dtype = None
    scalar_type = None

    for arg in args:
        if isinstance(arg, Number):
            # Scalar type checking is disabled (and may be removed in the future)
            continue
            # if scalar_type is None:
            #     scalar_type = type(arg)

            # if scalar_type is not type(arg):
            #     msg = (
            #         "Scalar of type "
            #         + str(type(arg))
            #         + " is not the expected type of "
            #         + str(scalar_type)
            #         + "!"
            #     )
            #     raise RuntimeError(msg)
        elif isinstance(arg, TensorLike):
            if full_dtype is None:
                full_dtype = arg.dtype
            if scalar_type is None:
                scalar_type = dtype_to_type(arg.dtype)

            if full_dtype is not arg.dtype:
                msg = (
                    "Tensor with dtype "
                    + str(arg.dtype)
                    + " is not the expected dtype of "
                    + str(full_dtype)
                    + "!"
                )
                raise RuntimeError(msg)

            arg_type = dtype_to_type(arg.dtype)
            if arg_type is not scalar_type:
                msg = (
                    "Tensor with corresponding Python type "
                    + str(arg_type)
                    + " is not the expected type of "
                    + str(scalar_type)
                    + "!"
                )
                raise RuntimeError(msg)
        else:
            msg = (
                "Unexpected type when checking for same dtype, " + str(type(arg)) + "!"
            )
            raise RuntimeError(msg)


# Maps datatypes to their computation types for elementwise operations
_computation_dtype_map = {
    torch.bfloat16: torch.float32,
    torch.float16: torch.float32,
    torch.complex32: torch.complex64,
}


def get_computation_dtype(dtype: torch.dtype) -> torch.dtype:
    return _computation_dtype_map.get(dtype, dtype)


_cpu_acc_type_map = {
    torch.bfloat16: torch.float64,
    torch.float16: torch.float64,
    torch.float32: torch.float64,
    torch.complex32: torch.complex128,
    torch.complex64: torch.complex128,
}


def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype:
    # Equivalent to at::toAccumulateType, prefer computation_dtype where possible
    if device.type == "cpu":
        return _cpu_acc_type_map.get(dtype, dtype)
    else:
        return get_computation_dtype(dtype)


class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
    DEFAULT = (0,)
    NO_OPMATH = (1,)
    INT_TO_FLOAT = (2,)
    ALWAYS_BOOL = (3,)
    COMPLEX_TO_FLOAT = (4,)
    BOOL_TO_LONG = (5,)


class REDUCTION_OUTPUT_TYPE_KIND(Enum):
    SAME = (0,)
    COMPLEX_TO_FLOAT = (1,)  # for complex types outputs corresponding real type
    KEEP_PROMOTED_TYPE = (2,)  # keep output in opmath type, needed for mean
    ALWAYS_BOOL = (3,)


# Describes the return type of the primitive:
#
#   - NEW, a new tensor is created
#   - VIEW, a view of an input tensor is returned
#   - INPLACE, one or more input tensors is modified
#   - NONE, nothing is returned
#
# these descriptors are mutually exclusive and exhaustive.
class RETURN_TYPE(Enum):
    NEW = (0,)
    VIEW = (1,)
    INPLACE = (2,)
    NONE = (3,)


# TODO: when NumberType contains the sym types, can simplify this
def number_type(
    x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool]
) -> Type:
    if isinstance(x, torch.SymInt):
        return int
    elif isinstance(x, torch.SymFloat):
        return float
    elif isinstance(x, torch.SymBool):
        return bool
    else:
        return type(x)


def expr_type(x: sympy.Basic) -> Type:
    import sympy

    if x.kind is sympy.core.kind.BooleanKind:
        return bool
    elif x.is_integer:  # type: ignore[attr-defined]
        return int
    else:
        # NB: Not strictly correct, but we don't support SymPy complex or bool.
        return float


# TODO: document type promotion kinds
def elementwise_dtypes(
    *_args,
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
) -> Tuple[torch.dtype, torch.dtype]:
    """
    Computes the computation and result dtypes for elementwise type promotion
    on the given arguments and with the given elementwise type promotion kind.

    Note that not all inputs to an elementwise operation necessarily participate in type promotion.
    For example, the "alpha" parameter of torch.add does not participate in type promotion,
    although it may be cast to the Python type corresponding to the computation dtype that
    the type promotion algorithm determines.

    Default elementwise type promotion, which all other type promotion kinds tweak (see below),
    first decides which of four ordered types to use:

      bool -> integer -> floating point -> complex

    The selected type is the "lowest" type in the above list such that all number arguments
    have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
    type for their dtype.

    Once the type is determined, the particular result dtype is found. The dtypes are
    partially ordered as follows:

      bool -> uint8, int8 -> int16 -> int32 -> int64 ->
        float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128

    The result dtype is selected by:
      - if no tensor's dtype has the same corresponding type as the one selected,
          then the result dtype is the (default) dtype corresponding to the selected type
          (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
      - if the result type is complex then the dtype is:
        -  the default complex dtype if there are no floating point or complex tensors
        -  if there are floating point or complex tensors with one or more dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
            (for example, double + cfloat -> cdouble)
        -  if there are only floating point or complex tensors with zero dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
      - if the first two cases do not apply, the result dtype is the highest dtype among
          all tensors with one or more dimensions of the output type, and if there are no such
          tensors then it's the highest dtype among all tensors with zero dimensions of the output type
          (for example, long + half -> half, even if the half tensor has zero dimensions)

    The "corresponding complex dtypes" are:
      float16    -> complex32
      bfloat16   -> complex64
      float32    -> complex64
      float64    -> complex128
      complex32  -> complex32
      complex64  -> complex64
      complex128 -> complex128

    The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation
    dtype by mapping low precision floating point and complex dtypes as follows:

      float16   -> float32
      bfloat16  -> float32
      complex32 -> complex64

    This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the
    computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels
    which perform no mathematical operations on their tensors (see below for examples).

    The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype,
    and computation dtypes to the appropriate op math dtype.

    The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this
    mapping:

      complex32  -> float16
      complex64  -> float32
      complex128 -> float64

    Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does.

    The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long.

    The ALWAYS_BOOL type promotion kind always sets the result dtype to bool.

    Example operators for each type promotion option:
      DEFAULT          : add
      NO_OPMATH        : where, nextafter, cat
      INT_TO_FLOAT     : sin
      COMPLEX_TO_FLOAT : abs
      BOOL_TO_LONG     : pow
      ALWAYS_BOOL      : eq

    """

    args = tuple(x for x in _args if x is not None)

    highest_type: type = bool

    # Import sympy locally, as importing it eagerly at a module level is too slow
    # See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589
    import sympy

    for x in args:
        if not isinstance(x, (Number, TensorLike, sympy.Basic)):
            msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
            raise ValueError(msg)

        if isinstance(x, Number):
            highest_type = get_higher_type(highest_type, number_type(x))
        elif isinstance(x, sympy.Basic):
            highest_type = get_higher_type(highest_type, expr_type(x))
        else:
            # x is a TensorLike
            highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype))

    result_dtype = None

    def _find_highest_dtype_filtered(
        args, filter, *, float_as_complex=False
    ) -> Optional[torch.dtype]:
        zero_dim_tensor_dtype = None
        one_plus_dim_tensor_dtype = None
        for x in args:
            if isinstance(x, TensorLike) and filter(x.dtype):
                _dtype = x.dtype
                if float_as_complex and is_float_dtype(_dtype):
                    _dtype = corresponding_complex_dtype(_dtype)
                if x.ndim == 0:
                    zero_dim_tensor_dtype = get_higher_dtype(
                        zero_dim_tensor_dtype, _dtype
                    )
                else:
                    # x.ndim > 0
                    one_plus_dim_tensor_dtype = get_higher_dtype(
                        one_plus_dim_tensor_dtype, _dtype
                    )

        # Prefers dtype of tensors with one or more dimensions
        if one_plus_dim_tensor_dtype is not None:
            return one_plus_dim_tensor_dtype

        return zero_dim_tensor_dtype

    if highest_type is float:
        result_dtype = _find_highest_dtype_filtered(args, is_float_dtype)
        result_dtype = (
            torch.get_default_dtype() if result_dtype is None else result_dtype
        )
    elif highest_type is complex:
        result_dtype = _find_highest_dtype_filtered(
            args,
            lambda x: is_float_dtype(x) or is_complex_dtype(x),
            float_as_complex=True,
        )
        if result_dtype is None:
            result_dtype = corresponding_complex_dtype(torch.get_default_dtype())
    elif highest_type is int:
        result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype)
        result_dtype = torch.long if result_dtype is None else result_dtype
    else:
        # highest_type is bool
        result_dtype = torch.bool

    if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
        return get_computation_dtype(result_dtype), result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH:
        return result_dtype, result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
        if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype):
            result_dtype = torch.get_default_dtype()
        return get_computation_dtype(result_dtype), result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
        # NOTE: computation can still occur in a complex dtype
        computation_dtype = get_computation_dtype(result_dtype)
        if is_complex_dtype(result_dtype):
            result_dtype = corresponding_real_dtype(result_dtype)
        return computation_dtype, result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG:
        if is_boolean_dtype(result_dtype):
            return torch.long, torch.long
        return get_computation_dtype(result_dtype), result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
        return get_computation_dtype(result_dtype), torch.bool
    else:
        raise ValueError(f"Unknown type promotion kind {str(type_promotion_kind)}")
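

# Illustrative sketch (not part of the upstream file): the returned pair is
# (computation dtype, result dtype). Assuming the default dtype is float32:
#
#   >>> elementwise_dtypes(
#   ...     torch.empty(3, dtype=torch.float16), 2.0,
#   ...     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
#   ... )
#   (torch.float32, torch.float16)
#   >>> elementwise_dtypes(
#   ...     torch.empty(3, dtype=torch.int64),
#   ...     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
#   ... )
#   (torch.float32, torch.float32)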


def reduction_dtypes(
    arg,
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.dtype, Optional[torch.dtype]]:
    # even though some reductions, like amin or amax, don't strictly require type promotion,
    # all the math ops (including comparisons) are still defined only for a computation type,
    # so promotion will still happen. We are doing it explicitly here
    inp_dtype = dtype if dtype is not None else arg.dtype
    computation_dtype = get_computation_dtype(inp_dtype)
    if (
        output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME
        or output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
    ):
        result_dtype = dtype if dtype else arg.dtype
        if (
            output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
            and is_complex_dtype(result_dtype)
        ):
            result_dtype = corresponding_real_dtype(result_dtype)
    elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE:
        result_dtype = None
    else:  # ALWAYS_BOOL
        result_dtype = torch.bool
    return computation_dtype, result_dtype


# This function's logic is borrowed from the following functions defined in C++:
# batched_matrix_contiguous_strides and contiguous_strides
def make_contiguous_strides_for(
    shape: ShapeType, row_major: bool = True
) -> Tuple[int, ...]:
    """
    Returns the strides of a contiguous tensor if row_major
    If row_major=False, it returns the strides of a contiguous batch of Fortran-contiguous matrices
    This is often used when calling external libraries like BLAS/LAPACK/cuSolver...
    """
    # contiguous_strides from c10/util/strides.h
    validate_shape(shape)
    if not shape:
        return ()

    from torch.fx.experimental.symbolic_shapes import is_nested_int

    multiplier = 1
    strides = []
    for l in reversed(shape):
        strides.append(multiplier)
        multiplier *= l if is_nested_int(l) else sym_max(l, 1)

    result = tuple(reversed(strides))

    # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h
    if row_major:
        return result
    else:
        if len(shape) < 2:
            return result
        return result[:-2] + (1, max(shape[-2], 1))
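

# Illustrative sketch (not part of the upstream file): row-major strides are the
# familiar C-contiguous ones; row_major=False makes the trailing two dims
# Fortran-contiguous for BLAS/LAPACK-style batches:
#
#   >>> make_contiguous_strides_for((2, 3, 4))
#   (12, 4, 1)
#   >>> make_contiguous_strides_for((2, 3, 4), row_major=False)
#   (12, 1, 3)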


def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    torch._check(
        len(shape) == 3,
        lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
    )

    multiplier = 1
    strides = [0] * 3
    for idx in (1, -1, 0):
        # NOTE: intentional divergence from make_contiguous_strides_for
        # This is consistent with eager
        strides[idx] = multiplier
        multiplier *= shape[idx]

    return tuple(strides)


def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
    torch._check(
        len(shape) == 4,
        lambda: "Only tensors of rank 4 can use the channels_last memory format",
    )

    multiplier = 1
    strides = [0] * 4
    for idx in (1, -1, -2, 0):
        # NOTE: intentional divergence from make_contiguous_strides_for
        # This is consistent with eager
        strides[idx] = multiplier
        multiplier *= shape[idx]

    return tuple(strides)
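

# Illustrative sketch (not part of the upstream file): for an NCHW shape
# (N, C, H, W) the strides nest in (C, W, H, N) order, so C is innermost:
#
#   >>> make_channels_last_2d_strides_for((2, 3, 4, 5))
#   (60, 1, 15, 3)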


def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    torch._check(
        len(shape) == 5,
        lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
    )

    multiplier = 1
    strides = [0] * 5
    for idx in (1, -1, -2, -3, 0):
        # NOTE: intentional divergence from make_contiguous_strides_for
        # This is consistent with eager
        strides[idx] = multiplier
        multiplier *= shape[idx]

    return tuple(strides)


def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    ndim = len(shape) if isinstance(shape, Sequence) else 1
    if ndim == 3:
        return make_channels_last_1d_strides_for(shape)
    elif ndim == 4:
        return make_channels_last_2d_strides_for(shape)
    elif ndim == 5:
        return make_channels_last_3d_strides_for(shape)
    else:
        raise RuntimeError(
            f"no channels last format strides exist in {ndim} dimensions"
        )


def compute_reduction_output_shape(
    shape: ShapeType, dimensions: Sequence
) -> Tuple[int, ...]:
    for idx in dimensions:
        validate_idx(len(shape), idx)

    new_shape = []
    for idx in range(len(shape)):
        if idx in dimensions:
            continue

        new_shape.append(shape[idx])

    return tuple(new_shape)


def validate_no_repeating_dims(dims: Sequence):
    if len(dims) != len(set(dims)):
        raise RuntimeError("duplicate value in the list of dims")


def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]:
    if dims is None:
        return tuple(range(len(shape)))
    dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims)
    validate_no_repeating_dims(dims)
    return dims


def set_correction(
    unbiased: Optional[bool] = None,
    correction: Optional[NumberType] = None,
) -> float:
    if correction is not None and unbiased is not None:
        raise RuntimeError("cannot specify both correction and unbiased arguments")
    elif correction is None and unbiased is None:
        correction = 1.0
    elif correction is None and unbiased is not None:
        correction = 0.0 if unbiased is False else 1.0
    # NB: we don't actually support symint here, but it's harmless to accept
    if not isinstance(correction, (IntLike, FloatLike)):
        raise ValueError("correction argument should be integer or float")
    if correction < 0:
        raise ValueError("correction argument should be non-negative")
    return sym_float(correction)


def compute_required_storage_length(
    shape: ShapeType, strides: StrideType, storage_offset: int
) -> int:
    """Computes the minimum storage size to hold the given tensor geometry.

    Example
    =======

    This is the size of a newly allocated tensor's storage, in units of elements

    >>> t = torch.empty((10, 20))
    >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset())
    200

    >>> # xdoctest: +SKIP(failing)
    >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11))
    >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset())
    >>> size == t.storage().size()
    True

    A valid tensor may have a larger storage size, but never smaller

    >>> slice = torch.empty(100)[20:40]
    >>> slice.storage().size()
    100

    >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset())
    40

    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    # Short-circuits if the shape has no elements
    if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0):
        return 0

    max_offset = sum((x - 1) * y for x, y in zip(shape, strides))
    # +1 to account for the first element which offsets are taken from
    return 1 + storage_offset + max_offset


def check_in_bounds_for_storage(
    a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
):
    """
    Determines if the given shape, strides, and offset are valid for the given storage.
    """

    required_length = compute_required_storage_length(shape, strides, storage_offset)
    if a.size() < required_length:
        msg = (
            f"Can't view a storage of size {a.size()} with an offset of {storage_offset}, "
            f"shape of {str(shape)}, and strides of {str(strides)}, "
            f"which requires a storage of size {required_length}"
        )
        raise ValueError(msg)


# NOTE: This function should ideally be removed, but some Meta internal models
# packaged with `torch.package` are using it, so it will have to be removed
# at some point in the future when those models no longer use this function.
@deprecated(
    "`torch._prims_common.check` is deprecated and will be removed in the future. "
    "Please use `torch._check*` functions instead.",
    category=FutureWarning,
)
def check(
    b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
) -> None:
    """
    Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.
    Error message is a callable producing a string (to avoid wasting time
    string formatting in non-error case, and also to make it easier for torchdynamo
    to trace.)

    .. note:: This function is planned for removal in the future. Please use
        `torch._check*` functions instead.
    """
    torch._check_with(exc_type, b, s)


# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in
# c10/core/MemoryFormat.h into one function
def are_strides_like_channels_last(
    shape: Sequence[int], strides: Sequence[int]
) -> bool:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    ndim = len(shape)

    if ndim == 4:
        # Check for channels_last_2d
        dim_order = [1, 3, 2, 0]
    elif ndim == 5:
        # Check for channels_last_3d
        dim_order = [1, 4, 3, 2, 0]
    else:
        return False

    if guard_size_oblivious(strides[1] == 0):
        return False

    min = 0
    for d in dim_order:
        if guard_size_oblivious(shape[d] == 0):
            return False
        if guard_size_oblivious(strides[d] < min):
            return False
        if d == 0 and min == strides[1]:
            return False
        min = strides[d]
        if guard_size_oblivious(strides[d] > 1):
            min *= shape[d]
    return True


def suggest_memory_format(x: TensorLikeType) -> torch.memory_format:
    if x.layout != torch.strided:
        return torch.contiguous_format

    if are_strides_like_channels_last(x.shape, x.stride()):
        return torch.channels_last if x.ndim == 4 else torch.channels_last_3d

    return torch.contiguous_format
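

# Illustrative sketch (not part of the upstream file): the suggestion tracks the
# tensor's actual stride order, assuming dense strided tensors:
#
#   >>> suggest_memory_format(torch.empty(2, 3, 4, 5))
#   torch.contiguous_format
#   >>> suggest_memory_format(
#   ...     torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
#   ... )
#   torch.channels_last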


def prod(xs: Sequence[NumberType]) -> NumberType:
    """Product of elements in input sequence. Returns 1 for empty sequence"""
    return reduce(operator.mul, xs, 1)


def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool:
    """Checks if a shape can be expanded to another shape.
    This is equivalent to checking if the two shapes are broadcastable.
    """
    # This is a Python implementation of
    # aten/src/ATen/ExpandUtils.h:is_expandable_to
    if len(shape) > len(desired):
        return False
    for i in range(len(shape)):
        if shape[-i - 1] != desired[-i - 1] and shape[-i - 1] != 1:
            return False
    return True
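

# Illustrative sketch (not part of the upstream file): expansion aligns shapes
# from the right, and only size-1 dims may grow:
#
#   >>> is_expandable_to((3, 1), (2, 3, 4))
#   True
#   >>> is_expandable_to((2, 3), (3, 3))
#   False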


def mask_tensor(mask: TensorLikeType, t: TensorLikeType):
    """
    Similar to torch.where(mask, t, 0) but if t is boolean,
    result is also boolean and not promoted to int.
    """
    # torch.where(mask, t, False) is equivalent
    # but feels hacky and might break in the future
    if t.dtype is torch.bool:
        return mask.logical_and(t)
    else:
        return torch.where(mask, t, 0)


def get_aten_op(fn: Callable, name: str):
    """
    Given the __module__ of reference and its name, it returns
    (our best guess of) the ATen name of the associated operation

    Note: In ATen, the __name__ of a function within a module often
    starts with the module name. E.g. linalg_eigh, or special_zeta
    """
    module = fn.__module__
    prefix = "torch._refs"
    assert module.startswith(prefix)
    module = module[len(prefix) :]
    # We want to go from .special / .nn.functional
    # to special and special_ / nn_functional_
    if module:
        module = module[1:]
        module = module.replace(".", "_")
        module = module + "_"
    return getattr(torch._ops.ops.aten, f"{module}{name}")


def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype:
    return dtype if dtype is not None else torch.get_default_dtype()


def device_or_default(device: Optional[DeviceLikeType]) -> DeviceLikeType:
    return device if device is not None else torch.device("cpu")


def layout_or_default(layout: Optional[torch.layout]) -> torch.layout:
    return layout if layout is not None else torch.strided


def clone_preserve_strides(x):
    needed_size = compute_required_storage_length(
        x.size(), x.stride(), x.storage_offset()
    )
    # Our eager implementations for *_scatter ops are all primitives w.r.t autograd,
    # so these as_strided() calls are not seen by autograd.
    # We need to mimic this behavior in our ref/prim implementations.
    # TODO: a better way to handle this would be with a new op, "_unsafe_as_strided"
    # We should revisit this when we add a compositional as_strided op,
    # and also as part of https://github.com/pytorch/pytorch/issues/90507
    try:
        old = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView
        )
        torch._C._dispatch_tls_set_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView, True
        )
        buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone()
        return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
    finally:
        torch._C._dispatch_tls_set_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView, old
        )


def alert_not_deterministic(caller: str):
    if torch.are_deterministic_algorithms_enabled():
        if torch.is_deterministic_algorithms_warn_only_enabled():
            warnings.warn(
                f"{caller} does not have a deterministic implementation, but you set "
                f"'torch.use_deterministic_algorithms(True, warn_only=True)'. "
                f"You can file an issue at https://github.com/pytorch/pytorch/issues "
                f"to help us prioritize adding deterministic support for this operation."
            )
        else:
            torch._check(
                False,
                lambda: (
                    f"{caller} does not have a deterministic implementation, but you set "
                    f"'torch.use_deterministic_algorithms(True)'. You can turn off "
                    f"determinism just for this operation, or you can use the "
                    f"'warn_only=True' option, if that's acceptable for your application. "
                    f"You can also file an issue at https://github.com/pytorch/pytorch/issues "
                    f"to help us prioritize adding deterministic support for this operation."
                ),
            )


class CUDARngStateHelper:
    @staticmethod
    def get_torch_state_as_tuple(fake_mode=nullcontext()):
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")

        with fake_mode:
            seed = torch.tensor(torch.cuda.initial_seed())
            offset = torch.tensor(torch.cuda._get_rng_state_offset())
            return seed, offset

    @staticmethod
    def set_torch_state_tensor(seed, offset):
        # Rng state is [64-bit seed, 64-bit offset]
        seed_portion = seed.reshape([1]).view(torch.uint8)
        offset_portion = offset.reshape([1]).view(torch.uint8)
        new_state = torch.cat([seed_portion, offset_portion])
        torch.cuda.set_rng_state(new_state)

    @staticmethod
    def set_new_offset(relative_offset):
        torch.cuda._set_rng_state_offset(relative_offset.item())