from __future__ import annotations

import operator
import warnings
import weakref

from contextlib import nullcontext
from enum import Enum
from functools import cmp_to_key, reduce
from typing import (
    Any,
    Callable,
    cast,
    List,
    NamedTuple,
    Optional,
    overload,
    Sequence,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

from typing_extensions import TypeAlias


if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.

    import sympy

import torch
from torch import sym_float, sym_int, sym_max


ShapeType: TypeAlias = Union[torch.Size, List[int], Tuple[int, ...]]
StrideType: TypeAlias = Union[List[int], Tuple[int, ...]]
DimsType: TypeAlias = Union[int, List[int], Tuple[int, ...]]
DimsSequenceType: TypeAlias = Union[List[int], Tuple[int, ...]]
# TODO: Type[torch.SymInt], Type[torch.SymFloat]
NumberTypeType: TypeAlias = Union[Type[bool], Type[int], Type[float], Type[complex]]
# TODO: This needs a lot more type annotations
# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
NumberType: TypeAlias = Union[bool, int, float, complex]
RealNumberType: TypeAlias = Union[bool, int, float]

Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat)
# I don't call it Integral because numbers.Integral includes bool, but IntLike
# does not
Dim = int
IntLike = (int, torch.SymInt)
FloatLike = (float, torch.SymFloat)
IntWithoutSymInt = int
FloatWithoutSymFloat = float
DeviceLikeType: TypeAlias = Union[str, torch.device, int]
Tensor = torch.Tensor


torch_function_passthrough = {
    torch.device,
    torch.sym_not,
    torch.sym_float,
    torch.sym_int,
    torch.sym_max,
    torch.sym_min,
    torch._sym_sqrt,  # type: ignore[attr-defined]
    torch.sym_ite,
    torch.Tensor.dim,
    torch.Tensor.ndim.__get__,  # type: ignore[attr-defined]
    torch.Tensor.numel,
    torch.Tensor.size,
    torch.Tensor.storage_offset,
    torch.Tensor.stride,
    torch.Tensor.dtype.__get__,  # type: ignore[attr-defined]
    torch.Tensor.is_sparse.__get__,  # type: ignore[attr-defined]
    torch.Tensor.shape.__get__,  # type: ignore[attr-defined]
    torch.Tensor.device.__get__,  # type: ignore[attr-defined]
    torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
    torch.Tensor.layout.__get__,  # type: ignore[attr-defined]
    torch.Tensor.is_contiguous,
    # For TorchRefsMode only
    torch.Tensor.__format__,
    torch.Tensor.__repr__,
    torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
}


TensorLikeType = torch.Tensor
TensorLike = torch.Tensor
TensorSequenceType: TypeAlias = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType]

CustomOutParamAnnotation = "__custom_out_param__"


def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if len(a) != len(b):
        return False

    for x, y in zip(a, b):
        if allow_rhs_unbacked:
            # TODO: We should check that the symbols are consistent
            # with each other
            if isinstance(y, torch.SymInt):
                continue
        # NB: Naively, you would not expect to have to do an oblivious guard
        # here because there is seemingly no broadcasting here, but in fact we
        # use this in some situations to determine if we need to do an expand
        # on the tensor because they don't line up, so you can definitely end
        # up trying to prove u0 != 1 in this situation. See
        #   python test/test_proxy_tensor.py -k test_cumsum_unbacked
        if guard_size_oblivious(x != y):
            return False

    return True


def _maybe_get_pytype(t):
    if t is torch.SymFloat:
        return float
    elif t is torch.SymInt:
        return int
    elif t is torch.SymBool:
        return bool
    else:
        return t


# TODO: look at using torch.testing.assert_close instead with an option
# to just compare metadata
def compare_tensor_meta(
    a: TensorLikeType,
    b: TensorLikeType,
    check_strides=False,
    *,
    allow_rhs_unbacked=False,
    check_conj=True,
):
    """
    Checks that two tensor likes have the same shape,
    dtype and device.

    In the future this will validate additional metadata, like
    strides.
    """
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    if not same_shape(a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked):
        msg = f"Shapes {a.shape} and {b.shape} are not equal!"
        raise AssertionError(msg)

    if a.dtype != b.dtype:
        msg = f"Dtypes {a.dtype} and {b.dtype} are not equal!"
        raise AssertionError(msg)

    if a.device != b.device:
        # Handles special cuda:0 vs cuda case
        # TODO: we should review why this happens and see about fixing it
        if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and (
            str(b.device) == "cuda:0" or str(b.device) == "cuda"
        ):
            pass
        else:
            msg = f"Devices {a.device} and {b.device} are not equal!"
            raise AssertionError(msg)

    # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050
    if check_strides:
        same_strides, idx = check_significant_strides(a, b)
        if not same_strides:
            msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!"
            raise RuntimeError(msg)

        if a.storage_offset() != b.storage_offset():
            msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!"
            raise RuntimeError(msg)

    if check_conj:
        if a.is_conj() != b.is_conj():
            raise RuntimeError(
                f"Conj mismatch! is_conj is set to {a.is_conj()} and {b.is_conj()}"
            )

    if a.is_neg() != b.is_neg():
        raise RuntimeError(
            f"Neg mismatch! is_neg is set to {a.is_neg()} and {b.is_neg()}"
        )


def _check_strides_helper(
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True
) -> Tuple[bool, Optional[int]]:
    # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch
    # See https://github.com/pytorch/pytorch/issues/77553
    # Only compares strides that are "meaningful" -- strides for dimensions with length > 1
    # and for tensors with more than one element
    if (
        not only_cuda or a.device.type == "cuda" or b.device.type == "cuda"
    ) and a.numel() > 0:
        for idx in range(a.ndim):
            check = not significant_only or a.shape[idx] > 1
            if a.stride()[idx] != b.stride()[idx] and check:
                return False, idx

    return True, None


def check_significant_strides(
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
) -> Tuple[bool, Optional[int]]:
    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True)


def check_all_strides(
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
) -> Tuple[bool, Optional[int]]:
    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)


# This function is equivalent to compute_contiguous() from TensorImpl.cpp
def is_contiguous(a: TensorLikeType) -> bool:
    """
    Tests whether a tensor is contiguous or not.

    Tensors are contiguous when they have no elements,
    one element, or when they have "nested" strides.
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if guard_size_oblivious(a.numel() < 2):
        return True

    expected_stride = 1
    for x, y in reversed(tuple(zip(a.shape, a.stride()))):
        # Skips checking strides when a dimension has length 1
        if guard_size_oblivious(x == 1):
            continue

        if y != expected_stride:
            return False
        expected_stride = expected_stride * x

    return True


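# Illustrative sketch of the check above (not exhaustive; assumes ordinary
# dense, strided tensors): a freshly allocated tensor has "nested" strides,
# while a transposed view of it generally does not.
#
#   >>> is_contiguous(torch.empty(2, 3))       # strides (3, 1)
#   True
#   >>> is_contiguous(torch.empty(2, 3).t())   # strides (1, 3)
#   False

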
# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
def is_channels_last_contiguous_2d(a: Tensor) -> bool:
    # NHWC or not channels last 2D contiguous
    if a.ndim != 4:
        return False

    expected_stride = 1
    for idx in (1, 3, 2, 0):
        length = a.shape[idx]
        if length == 1:
            continue

        stride = a.stride()[idx]
        if stride != expected_stride:
            return False

        expected_stride *= length

    return True


def is_channels_last_contiguous_3d(a: Tensor) -> bool:
    # NDHWC or not channels last 3D contiguous
    if a.ndim != 5:
        return False

    expected_stride = 1
    for idx in (1, 4, 3, 2, 0):
        length = a.shape[idx]
        if length == 1:
            continue

        stride = a.stride()[idx]
        if stride != expected_stride:
            return False

        expected_stride *= length

    return True


_memory_formats = {
    torch.contiguous_format,
    torch.preserve_format,
    torch.channels_last,
    torch.channels_last_3d,
}


def validate_memory_format(memory_format: torch.memory_format):
    torch._check(
        memory_format in _memory_formats,
        lambda: f"Received unknown memory format {memory_format}!",
    )


def is_contiguous_for_memory_format(  # type: ignore[return]
    a: Tensor, *, memory_format: torch.memory_format
) -> bool:
    validate_memory_format(memory_format)

    if memory_format == torch.contiguous_format:
        return is_contiguous(a)
    if memory_format == torch.channels_last:
        return is_channels_last_contiguous_2d(a)
    if memory_format == torch.channels_last_3d:
        return is_channels_last_contiguous_3d(a)

    torch._check(
        False,
        lambda: f"is_contiguous received unsupported memory format {memory_format}",
    )


# NOTE: it is unclear whether a tensor with no elements should count as channels-last contiguous
def is_channels_last_contiguous(a: Tensor) -> bool:
    """
    True when a tensor is channels-last contiguous.

    This requires that:

      - the tensor is conceptually either 4 (NHWC) or 5 (NDHWC) dimensions
      - if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the
        stride of the 'C' dimension (Cs) is 1 and the strides corresponding to
        each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are
        "nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension,
        for example.
    """
    return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)


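# Illustrative sketch of the channels-last checks above (assumes a standard
# dense NCHW tensor converted via .contiguous(memory_format=...)):
#
#   >>> x = torch.empty(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)
#   >>> x.stride()                     # C is innermost: (C*H*W, 1, C*W, C)
#   (60, 1, 15, 3)
#   >>> is_channels_last_contiguous(x)
#   True
#   >>> is_channels_last_contiguous(torch.empty(2, 3, 4, 5))
#   False

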
def is_non_overlapping_and_dense(a: Tensor) -> bool:
    """
    True when a tensor is non-overlapping and dense.

    A tensor is non-overlapping and dense when there exists a permutation of
    its dimensions that is contiguous.
    """

    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if a.is_sparse:
        return False

    # Short-circuits if the tensor is already contiguous or channels-last contiguous
    if is_contiguous(a) or is_channels_last_contiguous(a):
        return True

    # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp

    # Short-circuits for tensors of rank one, which are
    # non-overlapping and "dense" if their stride is one
    if a.ndim == 1:
        return a.stride()[0] == 1

    # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
    # Sorts (length, stride) pairs by stride
    #
    # This sort is done in a size-oblivious way, which helps if we do a
    # comparison like 2048*u0 > u0; we just want this to return True
    # (and not worry about what if u0 is zero).
    class K(NamedTuple):
        size: int
        stride: int

        def __lt__(self, other):
            return guard_size_oblivious(self.stride < other.stride)

        def __gt__(self, other):
            return guard_size_oblivious(self.stride > other.stride)

        def __le__(self, other):
            return guard_size_oblivious(self.stride <= other.stride)

        def __ge__(self, other):
            return guard_size_oblivious(self.stride >= other.stride)

        def __eq__(self, other):
            return guard_size_oblivious(self.stride == other.stride)

    lengths_and_strides = sorted(map(K, a.shape, a.stride()))

    expected_stride = 1
    for length, stride in lengths_and_strides:
        if guard_size_oblivious(length == 1):
            continue

        if stride != expected_stride:
            return False

        expected_stride *= length

    return True


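# Sketch (illustrative values only): a permuted view of a contiguous tensor is
# still non-overlapping and dense -- some permutation of its dims is
# contiguous -- even though it is not itself contiguous.
#
#   >>> x = torch.empty(2, 3, 4).permute(2, 0, 1)   # shape (4, 2, 3), strides (1, 12, 4)
#   >>> is_contiguous(x)
#   False
#   >>> is_non_overlapping_and_dense(x)
#   True

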
# NOTE: Based on the implementation in TensorIterator.cpp, but note that
# the note [Computing output strides] is incorrect, because it
# says that strides will be preserved even if they are not
# "non overlapping and dense", but this is incorrect. The
# output of elementwise operations is always given
# non overlapping and dense strides.
# This is also INCORRECT because it does not model TensorIterator's
# short-circuit, which can cause different strides.
def compute_elementwise_output_logical_to_physical_perm(
    *tensors, _skip_checks=False
) -> List[int]:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if not _skip_checks and len(tensors) == 0:
        msg = "Can't compute elementwise output strides for zero tensors!"
        raise ValueError(msg)

    if not _skip_checks:
        check_same_shape(*tensors, allow_cpu_scalar_tensors=True)

    # Filters the tensors to actual tensors
    if not _skip_checks:
        tensors = tuple(
            a
            for a in tensors
            if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
        )

    # Short-circuits for CPU scalar case
    if len(tensors) == 0:
        return []

    # Short-circuits for shapes with zero or one dimensions
    # TODO: are these necessary?
    ndim = tensors[0].ndim
    if ndim == 0:
        return []
    if ndim == 1:
        return [0]

    # Short-circuits if contiguous, following the fake fast path.
    # This reduces the number of guards we end up making
    # TODO: do channels last too
    is_contiguous = True
    for t in tensors:
        is_contiguous = is_contiguous and t.is_contiguous(
            memory_format=torch.contiguous_format
        )

    if is_contiguous:
        return list(range(ndim))

    shape = tensors[0].shape

    def should_swap(idx_a, idx_b):
        for tensor in tensors:
            stride_a = tensor.stride()[idx_a]
            stride_b = tensor.stride()[idx_b]

            if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
                stride_b == 0
            ):
                continue

            if guard_size_oblivious(stride_a < stride_b):
                return -1

            if guard_size_oblivious(stride_a > stride_b):
                return 1

            # stride_a == stride_b
            if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
                return 1

        # Note: this case is hit if all strides are zero,
        # or all strides are equal and all dimensions have the same length
        return 0

    # The "sort" order for the permutation is back-to-front, but
    # the natural order for permutations is front-to-back. Do the
    # sorting back-to-front and then reverse it on output.
    #
    # also, note this returns the logical to physical shape permutation
    perm = list(reversed(range(ndim)))

    # insertion sort with support for ambiguous comparisons
    for i in range(1, ndim):
        dim1 = i
        for dim0 in reversed(range(i)):
            comparison = should_swap(perm[dim0], perm[dim1])
            if comparison > 0:
                perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
                dim1 = dim0
            elif comparison < 0:
                break

    return list(reversed(perm))


def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
    """
    Computes the output strides for elementwise operations.
    """
    if len(tensors) == 0:
        msg = "Can't compute elementwise output strides for zero tensors!"
        raise ValueError(msg)

    check_same_shape(*tensors, allow_cpu_scalar_tensors=True)

    # Filters the tensors to actual tensors
    tensors = tuple(
        a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
    )

    # Short-circuits for CPU scalar case
    if len(tensors) == 0:
        return ()

    ndim = tensors[0].ndim
    shape = tensors[0].shape

    if ndim == 0:
        return ()
    if ndim == 1:
        return (1,)

    logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
        *tensors, _skip_checks=True
    )
    permuted_shape = apply_perm(shape, logical_to_physical_perm)  # to physical

    new_strides = make_contiguous_strides_for(permuted_shape)
    permuted_strides = apply_perm(
        new_strides, invert_perm(logical_to_physical_perm)
    )  # to logical

    return tuple(permuted_strides)


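# Sketch of the stride propagation above (illustrative; assumes ordinary
# strided tensors): elementwise outputs follow the physical layout of the
# inputs rather than always being row-major.
#
#   >>> compute_elementwise_output_strides(torch.empty(2, 3))
#   (3, 1)
#   >>> compute_elementwise_output_strides(torch.empty(3, 2).t())   # column-major input
#   (1, 2)

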
# Identity permutation is [0, 1, 2]
def apply_perm(inp, perm):
    ndim = len(inp)
    permuted_inp = [-1] * ndim
    for idx, x in enumerate(perm):
        permuted_inp[idx] = inp[x]
    return permuted_inp


def invert_perm(perm):
    ndim = len(perm)
    new_perm = [-1] * ndim
    for idx, x in enumerate(perm):
        new_perm[x] = idx
    return new_perm


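# Small worked example for the two permutation helpers (illustrative):
#
#   >>> apply_perm([10, 20, 30], [2, 0, 1])   # output[i] = input[perm[i]]
#   [30, 10, 20]
#   >>> invert_perm([2, 0, 1])
#   [1, 2, 0]
#   >>> apply_perm(apply_perm([10, 20, 30], [2, 0, 1]), invert_perm([2, 0, 1]))
#   [10, 20, 30]

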
#
# Common helper functions
#


def validate_dim_length(length: int):
    """
    Validates that an object represents a valid
    dimension length.
    """

    if isinstance(length, (int, torch.SymInt)):
        torch._check_is_size(length)
    else:
        # sometimes called with sympy expression by inductor
        assert length >= 0


def validate_shape(shape: ShapeType):
    """
    Validates that a sequence represents a valid shape.
    """

    assert isinstance(shape, Sequence), type(shape)
    for l in shape:
        validate_dim_length(l)


def validate_strides(strides: StrideType):
    """
    Verifies the object specifies valid strides.
    """

    assert isinstance(strides, Sequence)
    for stride in strides:
        assert stride >= 0


def validate_idx(rank: int, idx: int):
    """
    Validates that idx is a valid index for the given shape.
    Assumes the index is already canonicalized.
    """

    assert isinstance(idx, Dim)
    assert isinstance(rank, Dim)

    assert idx >= 0 and idx < rank or idx == 0


def validate_dimension_indices(rank: int, indices: DimsSequenceType):
    for idx in indices:
        validate_idx(rank, idx)


def validate_exclusive_idx(rank: int, ex_idx: int):
    """
    Validates that ex_idx is a valid exclusive index
    for the given shape.
    """

    assert isinstance(ex_idx, Dim)
    assert isinstance(rank, Dim)
    assert ex_idx > 0 and ex_idx <= rank


# "Wraps" a dim (up to one time) for the given rank, allowing dims to be
# specified using negative indices. If `wrap_scalar` is true then scalar
# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise,
# idx should be in the range [-rank, rank-1].
def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
    if rank < 0:
        msg = f"Rank cannot be negative but got {rank}"
        raise IndexError(msg)

    if rank == 0:
        if not wrap_scalar:
            msg = f"Dimension specified as {idx} but tensor has no dimensions"
            raise IndexError(msg)
        rank = 1

    if idx >= 0 and idx < rank:
        return idx

    if idx < 0:
        _idx = idx + rank
    else:
        _idx = idx

    if _idx < 0 or _idx >= rank:
        # Same error message as in aten/src/ATen/WrapDimUtils.h:49
        msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})"
        raise IndexError(msg)

    return _idx


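# Sketch of dim wrapping (illustrative):
#
#   >>> canonicalize_dim(4, -1)
#   3
#   >>> canonicalize_dim(4, 2)
#   2
#   >>> canonicalize_dim(0, -1)   # scalar tensors wrap to 0 by default
#   0
#   >>> canonicalize_dim(2, 3)    # out of range
#   Traceback (most recent call last):
#       ...
#   IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 3)

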
# Takes a dimension or sequence of dimensions and "wraps" them,
# mapping negative offsets to positive ones
@overload
def canonicalize_dims(
    rank: int, indices: Sequence[int], wrap_scalar: bool = True
) -> Tuple[int, ...]:
    pass


@overload
def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int:
    pass


def canonicalize_dims(rank, indices, wrap_scalar=True):
    if isinstance(indices, Dim):
        return canonicalize_dim(rank, indices, wrap_scalar)

    return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices)


def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool:
    """
    Validates that perm is a permutation of length rank.
    """

    if not isinstance(perm, Sequence):
        return False

    if not (tuple(sorted(perm)) == tuple(range(0, rank))):
        return False

    return True


def is_same_shape(a: Sequence, b: Sequence) -> bool:
    """
    Compares two shapes a and b, returning True if they are the same
    (their ranks and corresponding lengths match) and False otherwise.
    """

    return tuple(a) == tuple(b)


def is_cpu_scalar_tensor(a: Any) -> bool:
    return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu"


def check_same_device(*args, allow_cpu_scalar_tensors):
    """
    Checks that all Tensors in args have the same device.

    Raises a RuntimeError when:
      - args contains an object whose type is not Tensor or Number
      - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True
    """
    # Short-circuits if all (one or fewer) arguments are trivially on the same device
    if len(args) <= 1:
        return

    # Note: cannot initialize device to the first arg's device (it may not have one)
    device = None
    for arg in args:
        if isinstance(arg, Number):
            continue
        elif isinstance(arg, TensorLike):
            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
                continue

            if device is None:
                device = arg.device

            if device != arg.device:
                msg = (
                    "Tensor on device "
                    + str(arg.device)
                    + " is not on the expected device "
                    + str(device)
                    + "!"
                )
                raise RuntimeError(msg)
        else:
            msg = (
                "Unexpected type when checking for same device, " + str(type(arg)) + "!"
            )
            raise RuntimeError(msg)


def canonicalize_device(device: DeviceLikeType) -> torch.device:
    if isinstance(device, torch.device):
        return device

    assert isinstance(device, str)
    return torch.device(device)


# Asserts if any of the following are true:
#   - a non-scalar or non-Tensor is given
#   - the shape of any tensors is distinct
def check_same_shape(*args, allow_cpu_scalar_tensors: bool):
    """
    Checks that all Tensors in args have the same shape.

    Raises a RuntimeError when:
      - args contains an object whose type is not Tensor or Number
      - two Tensor objects in args have different shapes
    """
    shape = None

    for arg in args:
        if isinstance(arg, Number):
            continue
        elif isinstance(arg, TensorLike):
            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
                continue

            if shape is None:
                shape = arg.shape

            if not is_same_shape(shape, arg.shape):
                msg = f"Shape {arg.shape} is not the expected shape {shape}!"
                raise RuntimeError(msg)
        else:
            msg = (
                "Unexpected type when checking for same shape, " + str(type(arg)) + "!"
            )
            raise RuntimeError(msg)


# Acquires a common shape, if it exists, from one or more tensor arguments,
# filtering number arguments
def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
    shape = None
    scalar_shape = None

    for arg in args:
        if isinstance(arg, Number):
            continue
        elif isinstance(arg, TensorLike):
            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
                scalar_shape = arg.shape
                continue

            if shape is None:
                shape = arg.shape

            if not is_same_shape(shape, arg.shape):
                return None
        else:
            return None

    return shape if shape is not None else scalar_shape


# Extracts dimensions that might be passed either as a list/tuple or as varargs.
# A typical case is Tensor.permute .
def extract_dims_from_varargs(
    dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]]
) -> DimsSequenceType:
    if dims and isinstance(dims[0], Sequence):
        assert len(dims) == 1
        dims = cast(Tuple[DimsSequenceType], dims)
        return dims[0]
    else:
        return cast(DimsSequenceType, dims)


def extract_shape_from_varargs(
    shape: Union[ShapeType, Tuple[ShapeType]],
    validate=True,
) -> Tuple[int, ...]:
    """
    Returns a shape from varargs.

    In PyTorch, operations that accept shapes often accept them as varargs, like
    foo(*shape). However a user can pass the shape either as separate integers:

      foo(1, 2, 3)

    or as a single sequence of integers:

      foo((1, 2, 3))

    In the first case shape will be a tuple of integers, and in the second case it's a tuple
    containing a tuple of integers. This validates those inputs and canonicalizes them
    to a tuple of integers.
    """

    # Handles tuple unwrapping
    if len(shape) == 1 and isinstance(shape[0], Sequence):
        shape = shape[0]

    if validate:
        validate_shape(shape)  # type: ignore[arg-type]
    return shape  # type: ignore[return-value]


def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]:
    ndim = max(len(a), len(b))
    expandedSizes = [0] * ndim

    for i in range(ndim - 1, -1, -1):
        offset = ndim - 1 - i
        dimA = len(a) - 1 - offset
        dimB = len(b) - 1 - offset
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1

        torch._check(
            (sizeA == sizeB) or (sizeA == 1) or (sizeB == 1),
            lambda: (
                f"The size of tensor a ({sizeA}) must match the size of "
                f"tensor b ({sizeB}) at non-jagged dimension {i}"
            ),
        )

        # 1s map to the other size (even 0)
        expandedSizes[i] = sizeB if sizeA == 1 else sizeA

    return tuple(expandedSizes)


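# Broadcast sketch (illustrative): size-1 dims stretch to match, and missing
# leading dims are treated as 1.
#
#   >>> infer_size_shapes((3, 1), (1, 4))
#   (3, 4)
#   >>> infer_size_shapes((2, 1, 5), (5,))
#   (2, 1, 5)

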
def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
    """
    Infers the size of a dim with size -1, if it exists.
    Also checks that new shape is compatible with the number of elements.
    """
    dim = None
    newsize = 1
    for i, d in enumerate(shape):
        if d == -1:
            torch._check(dim is None, lambda: "only one dimension can be inferred")
            dim = i
        elif d >= 0:
            newsize *= d
        else:
            torch._check(False, lambda: f"invalid shape dimension {d}")
    if dim is None:
        torch._check(
            numel == newsize,
            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
        )
    else:
        from torch.fx.experimental.symbolic_shapes import definitely_true

        torch._check(
            newsize != 0,
            lambda: (
                f"cannot reshape tensor of 0 elements into shape {list(shape)} because the "
                f"unspecified dimension size -1 can be any value and is ambiguous"
                if definitely_true(numel == 0)
                else f"shape '{list(shape)}' is invalid for input of size {numel}"
            ),
        )
        torch._check(
            numel % newsize == 0,
            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
        )
        # Convert to list to produce a compatible error message with core
        # PyTorch, which prints sequences in square brackets.
        shape = list(shape)
        shape[dim] = numel // newsize
        # NB: This is pretty important when you have unbacked SymInts.
        # Suppose you have (i0, 12) resizing into (2, -1, 12). The old
        # range for i0 is typically [2, inf], which means if you divide
        # by two the new range should be [1, inf]. But this is bad news
        # if you have an unbacked SymInt: we need to reapply the unsound
        # assumption that the size is >= 2.
        torch._check_is_size(shape[dim])
    return tuple(shape)


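# Sketch of -1 inference (illustrative):
#
#   >>> infer_size((2, -1, 4), 24)
#   (2, 3, 4)
#   >>> infer_size((5, 5), 25)
#   (5, 5)

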
_integer_dtypes = (
    torch.uint8,
    torch.uint16,
    torch.uint32,
    torch.uint64,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
)
_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128)


def is_boolean_dtype(dtype: torch.dtype) -> bool:
    assert isinstance(dtype, torch.dtype)
    return dtype is torch.bool


def is_integer_dtype(dtype: torch.dtype) -> bool:
    assert isinstance(dtype, torch.dtype)
    return dtype in _integer_dtypes


def is_low_precision_dtype(dtype: torch.dtype) -> bool:
    assert isinstance(dtype, torch.dtype)
    return dtype in _low_precision_dtypes


def is_float_dtype(dtype: torch.dtype) -> bool:
    assert isinstance(dtype, torch.dtype)
    return dtype.is_floating_point


def is_complex_dtype(dtype: torch.dtype) -> bool:
    assert isinstance(dtype, torch.dtype)
    return dtype in _complex_dtypes


def is_grad_dtype(dtype: torch.dtype) -> bool:
    """
    Checks if the dtype can require a gradient.
    """
    return dtype.is_floating_point or is_complex_dtype(dtype)


_complex_to_real_dtype_map = {
    torch.complex128: torch.float64,
    torch.complex64: torch.float32,
    torch.complex32: torch.float16,
}

_real_to_complex_dtype_map = {
    torch.float16: torch.complex32,
    torch.bfloat16: torch.complex64,
    torch.float32: torch.complex64,
    torch.float64: torch.complex128,
}


def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype:
    return _complex_to_real_dtype_map[dtype]


def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype:
    return _real_to_complex_dtype_map[dtype]


def dtype_to_type(dtype: torch.dtype) -> type:
    """
    Computes the corresponding Python type (AKA "type kind") for the
    given dtype.
    """
    assert isinstance(dtype, torch.dtype)

    if dtype is torch.bool:
        return bool
    if dtype in _integer_dtypes:
        return int
    if dtype.is_floating_point:
        return float
    if dtype in _complex_dtypes:
        return complex

    raise ValueError("Invalid dtype!")


def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]:
    """
    Computes the corresponding Python type constructor for the
    given dtype.
    """
    assert isinstance(dtype, torch.dtype)

    if dtype is torch.bool:
        return lambda x: bool(x)
    if dtype in _integer_dtypes:
        return sym_int
    if dtype.is_floating_point:
        return sym_float
    if dtype in _complex_dtypes:
        # TODO: type error here is real, replace with sym_complex
        return lambda x: complex(x)  # type: ignore[arg-type]

    raise ValueError("Invalid dtype!")


def type_to_dtype(typ: type) -> torch.dtype:
    """
    Computes the corresponding dtype for a Number type.
    """

    assert isinstance(typ, type)

    if typ is bool:
        return torch.bool
    if typ in [int, torch.SymInt]:
        return torch.long
    if typ in [float, torch.SymFloat]:
        return torch.get_default_dtype()
    # TODO: sym_complex_float?
    if typ is complex:
        return corresponding_complex_dtype(torch.get_default_dtype())

    raise ValueError("Invalid type!")


def get_dtype(x: Union[torch.Tensor, NumberType]):
    if isinstance(x, torch.Tensor):
        return x.dtype
    else:
        return type_to_dtype(type(x))


_ordered_types = (bool, int, float, complex)


def check_fp_or_complex(
    dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True
):
    """
    Checks whether the input is floating point or complex.
    If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32
    """
    torch._check(
        is_float_dtype(dtype) or is_complex_dtype(dtype),
        lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}",
    )
    torch._check(
        allow_low_precision_dtypes or not is_low_precision_dtype(dtype),
        lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}",
    )


def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"):
    torch._check(
        len(A.shape) >= 2,
        lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
    )


def get_higher_type(a: type, b: type) -> type:
    """
    Returns the higher of the two given Number types.

    The types are ordered bool -> int -> float -> complex.
    """
    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)
    # Type checking
    if a not in _ordered_types or b not in _ordered_types:
        raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}")

    if a is b:
        return a

    for typ in _ordered_types:
        if a is typ:
            return b
        if b is typ:
            return a

    raise ValueError("Unknown Python scalar type!")


# Returns the higher of two torch datatypes a and b or, if the two
# are not ordered relative to each other, the next
# higher datatype
def get_higher_dtype(
    a: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
    b: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
) -> Optional[torch.dtype]:
    """
    Computes the "lowest" datatype that is weakly
    "higher" than both a and b.
    """

    # Type checking
    assert a is None or isinstance(a, (torch.dtype, TensorLike, Number))
    assert b is None or isinstance(b, (torch.dtype, TensorLike, Number))

    def _extract_dtype(
        x: Optional[Union[torch.dtype, TensorLikeType, NumberType]]
    ) -> Optional[torch.dtype]:
        if x is None:
            return None
        if isinstance(x, torch.dtype):
            return x
        if isinstance(x, TensorLike):
            return x.dtype
        if isinstance(x, Number):
            return type_to_dtype(type(x))

        raise RuntimeError("Unexpected type given to _extract_dtype!")

    a, b = _extract_dtype(a), _extract_dtype(b)

    if a is b:
        return a

    if a is None:
        return b

    if b is None:
        return a

    ordered_datatypes = (
        (torch.bool,),
        (torch.uint8, torch.int8),
        (torch.int16,),
        (torch.int32,),
        (torch.int64,),
        (torch.float16, torch.bfloat16),
        (torch.float32,),
        (torch.float64,),
        (torch.complex32,),
        (torch.complex64,),
        (torch.complex128,),
    )

    for idx, dtypes in enumerate(ordered_datatypes):
        if a in dtypes and b in dtypes:
            return ordered_datatypes[idx + 1][0]
        if a in dtypes:
            return b
        if b in dtypes:
            return a

    raise RuntimeError("Unexpected termination!")


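# Promotion sketch (illustrative): dtypes within one bucket above are
# unordered relative to each other, so two dtypes from the same bucket promote
# to the first dtype of the next bucket.
#
#   >>> get_higher_dtype(torch.int32, torch.float16)
#   torch.float16
#   >>> get_higher_dtype(torch.float16, torch.bfloat16)   # same bucket -> next bucket
#   torch.float32
#   >>> get_higher_dtype(torch.uint8, torch.int8)
#   torch.int16

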
def check_pin_memory(pin_memory: bool):
    torch._check_not_implemented(
        not pin_memory, lambda: "PrimTorch does not support pinned memory"
    )


def check_layout(layout: torch.layout):
    torch._check_not_implemented(
        layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}"
    )


# TODO: maybe unify with can_cast_to?
def is_weakly_lesser_type(a: type, b: type) -> bool:
    """
    Compares two types, a and b, returning True if a is weakly "less" than b.

    The comparison is determined by the following type ordering: bool, int, float, complex.
    """

    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)

    if a not in _ordered_types or b not in _ordered_types:
        raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}")

    for typ in _ordered_types:
        if a == typ:
            return True
        if b == typ:
            return False

    raise RuntimeError("Unexpected termination!")


def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool:
    for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype):
        if fn(cast_to):
            return True
        if fn(cast_from):
            return False

    raise ValueError(f"Received unknown dtypes {cast_to}, {cast_from}!")


def check_same_dtype(*args):
    """
    Checks that all Tensors in args have the same dtype and that all Numbers have the
    same corresponding Python type.

    Raises a RuntimeError when:
      - args contains an object whose type is not Tensor or Number
      - two Tensor objects in args have different dtypes
      - two Number objects in args have different types
      - there are Tensors and Numbers in args, and one of those Tensors' corresponding
        Python types is different from the type of one of those Numbers
    """
    full_dtype = None
    scalar_type = None

    for arg in args:
        if isinstance(arg, Number):
            # Scalar type checking is disabled (and may be removed in the future)
            continue
            # if scalar_type is None:
            #     scalar_type = type(arg)

            # if scalar_type is not type(arg):
            #     msg = (
            #         "Scalar of type "
            #         + str(type(arg))
            #         + " is not the expected type of "
            #         + str(scalar_type)
            #         + "!"
            #     )
            #     raise RuntimeError(msg)
        elif isinstance(arg, TensorLike):
            if full_dtype is None:
                full_dtype = arg.dtype
            if scalar_type is None:
                scalar_type = dtype_to_type(arg.dtype)

            if full_dtype is not arg.dtype:
                msg = (
                    "Tensor with dtype "
                    + str(arg.dtype)
                    + " is not the expected dtype of "
                    + str(full_dtype)
                    + "!"
                )
                raise RuntimeError(msg)

            arg_type = dtype_to_type(arg.dtype)
            if arg_type is not scalar_type:
                msg = (
                    "Tensor with corresponding Python type "
                    + str(arg_type)
                    + " is not the expected type of "
                    + str(scalar_type)
                    + "!"
                )
                raise RuntimeError(msg)
        else:
            msg = (
                "Unexpected type when checking for same dtype, " + str(type(arg)) + "!"
            )
            raise RuntimeError(msg)


# Maps datatypes to their computation types for elementwise operations
_computation_dtype_map = {
    torch.bfloat16: torch.float32,
    torch.float16: torch.float32,
    torch.complex32: torch.complex64,
}


def get_computation_dtype(dtype: torch.dtype) -> torch.dtype:
    return _computation_dtype_map.get(dtype, dtype)


_cpu_acc_type_map = {
    torch.bfloat16: torch.float64,
    torch.float16: torch.float64,
    torch.float32: torch.float64,
    torch.complex32: torch.complex128,
    torch.complex64: torch.complex128,
}


def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype:
    # Equivalent to at::toAccumulateType, prefer computation_dtype where possible
    if device.type == "cpu":
        return _cpu_acc_type_map.get(dtype, dtype)
    else:
        return get_computation_dtype(dtype)


class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
    DEFAULT = (0,)
    NO_OPMATH = (1,)
    INT_TO_FLOAT = (2,)
    ALWAYS_BOOL = (3,)
    COMPLEX_TO_FLOAT = (4,)
    BOOL_TO_LONG = (5,)


class REDUCTION_OUTPUT_TYPE_KIND(Enum):
    SAME = (0,)
    COMPLEX_TO_FLOAT = (1,)  # for complex types outputs corresponding real type
    KEEP_PROMOTED_TYPE = (2,)  # keep output in opmath type, needed for mean
    ALWAYS_BOOL = (3,)


# Describes the return type of the primitive:
#
#   - NEW, a new tensor is created
#   - VIEW, a view of an input tensor is returned
#   - INPLACE, one or more input tensors are modified
#
# these descriptors are mutually exclusive and exhaustive.
class RETURN_TYPE(Enum):
    NEW = (0,)
    VIEW = (1,)
    INPLACE = (2,)


# TODO: when NumberType contains the sym types, can simplify this
def number_type(x: Union[NumberType, torch.SymInt, torch.SymFloat]) -> Type:
    if isinstance(x, torch.SymInt):
        return int
    elif isinstance(x, torch.SymFloat):
        return float
    else:
        return type(x)


def expr_type(x: sympy.Expr) -> Type:
    if x.is_integer:  # type: ignore[attr-defined]
        return int
    else:
        # NB: Not strictly correct, but we don't support SymPy complex or bool.
        return float


# TODO: document type promotion kinds
def elementwise_dtypes(
    *_args,
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
) -> Tuple[torch.dtype, torch.dtype]:
    """
    Computes the computation and result dtypes for elementwise type promotion
    on the given arguments and with the given elementwise type promotion kind.

    Note that not all inputs to an elementwise operation necessarily participate in type promotion.
    For example, the "alpha" parameter of torch.add does not participate in type promotion,
    although it may be cast to the Python type corresponding to the computation dtype that
    the type promotion algorithm determines.

    Default elementwise type promotion, which all other type promotion kinds tweak (see below),
    first decides which of four ordered types to use:

      bool -> integer -> floating point -> complex

    The selected type is the "lowest" type in the above list such that all number arguments
    have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
    type for their dtype.

    Once the type is determined, the particular result dtype is found. The dtypes are
    partially ordered as follows:

      bool -> uint8, int8 -> int16 -> int32 -> int64 ->
        float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128

    The result dtype is selected by:
      - if no tensor's dtype has the same corresponding type as the one selected,
        then the result dtype is the (default) dtype corresponding to the selected type
        (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
      - if the result type is complex then the dtype is:
        - the default complex dtype if there are no floating point or complex tensors
        - if there are floating point or complex tensors with one or more dimensions, then
          the complex dtype corresponding to the highest corresponding complex dtype among those tensors
          (for example, double + cfloat -> cdouble)
        - if there are only floating point or complex tensors with zero dimensions, then
          the complex dtype corresponding to the highest corresponding complex dtype among those tensors
      - if the first two cases do not apply, the result dtype is the highest dtype among
        all tensors with one or more dimensions of the output type, and if there are no such
        tensors then it's the highest dtype among all tensors with zero dimensions of the output type
        (for example, long + half -> half, even if the half tensor has zero dimensions)

    The "corresponding complex dtypes" are:
      float16    -> complex32
      bfloat16   -> complex64
      float32    -> complex64
      float64    -> complex128
      complex32  -> complex32
      complex64  -> complex64
      complex128 -> complex128

    The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation
    dtype by mapping low precision floating point and complex dtypes as follows:

      float16   -> float32
      bfloat16  -> float32
      complex32 -> complex64

    This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the
    computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels
    which perform no mathematical operations on their tensors (see below for examples).

    The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype,
    and computation dtypes to the appropriate op math dtype.

    The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this
    mapping:

      complex32  -> float16
      complex64  -> float32
      complex128 -> float64

    Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does.

    The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long.

    The ALWAYS_BOOL type promotion kind always sets the result dtype to bool.

    Example operators for each type promotion option:
      DEFAULT          : add
      NO_OPMATH        : where, nextafter, cat
      INT_TO_FLOAT     : sin
      COMPLEX_TO_FLOAT : abs
      BOOL_TO_LONG     : pow
      ALWAYS_BOOL      : eq

    """

    args = tuple(x for x in _args if x is not None)

    highest_type: type = bool

    # Import sympy locally, as importing it eagerly at a module level is too slow
    # See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589
    import sympy

    for x in args:
        if not isinstance(x, (Number, TensorLike, sympy.Expr)):
            msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
            raise ValueError(msg)

        if isinstance(x, Number):
            highest_type = get_higher_type(highest_type, number_type(x))
        elif isinstance(x, sympy.Expr):
            highest_type = get_higher_type(highest_type, expr_type(x))
        else:
            # x is a TensorLike
            highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype))

    result_dtype = None

    def _find_highest_dtype_filtered(
        args, filter, *, float_as_complex=False
    ) -> Optional[torch.dtype]:
        zero_dim_tensor_dtype = None
        one_plus_dim_tensor_dtype = None
        for x in args:
            if isinstance(x, TensorLike) and filter(x.dtype):
                _dtype = x.dtype
                if float_as_complex and is_float_dtype(_dtype):
                    _dtype = corresponding_complex_dtype(_dtype)
                if x.ndim == 0:
                    zero_dim_tensor_dtype = get_higher_dtype(
                        zero_dim_tensor_dtype, _dtype
                    )
                else:
                    # x.ndim > 0
                    one_plus_dim_tensor_dtype = get_higher_dtype(
                        one_plus_dim_tensor_dtype, _dtype
                    )

        # Prefers dtype of tensors with one or more dimensions
        if one_plus_dim_tensor_dtype is not None:
            return one_plus_dim_tensor_dtype

        return zero_dim_tensor_dtype

    if highest_type is float:
        result_dtype = _find_highest_dtype_filtered(args, is_float_dtype)
        result_dtype = (
            torch.get_default_dtype() if result_dtype is None else result_dtype
        )
    elif highest_type is complex:
        result_dtype = _find_highest_dtype_filtered(
            args,
            lambda x: is_float_dtype(x) or is_complex_dtype(x),
            float_as_complex=True,
        )
        if result_dtype is None:
            result_dtype = corresponding_complex_dtype(torch.get_default_dtype())
    elif highest_type is int:
        result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype)
        result_dtype = torch.long if result_dtype is None else result_dtype
    else:
        # highest_type is bool
        result_dtype = torch.bool

    if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
        return get_computation_dtype(result_dtype), result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH:
        return result_dtype, result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
        if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype):
            result_dtype = torch.get_default_dtype()
        return get_computation_dtype(result_dtype), result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
        # NOTE: computation can still occur in a complex dtype
        computation_dtype = get_computation_dtype(result_dtype)
        if is_complex_dtype(result_dtype):
            result_dtype = corresponding_real_dtype(result_dtype)
        return computation_dtype, result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG:
        if is_boolean_dtype(result_dtype):
            return torch.long, torch.long
        return get_computation_dtype(result_dtype), result_dtype
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
        return get_computation_dtype(result_dtype), torch.bool
    else:
        raise ValueError(f"Unknown type promotion kind {str(type_promotion_kind)}")


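# Worked sketch of the promotion rules above (illustrative; assumes the
# default dtype is torch.float32):
#
#   >>> int_tensor = torch.ones(2, dtype=torch.int64)
#   >>> elementwise_dtypes(
#   ...     int_tensor, 1.5,
#   ...     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
#   ... )
#   (torch.float32, torch.float32)
#   >>> elementwise_dtypes(
#   ...     torch.ones(2, dtype=torch.float16),
#   ...     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
#   ... )   # "op math": compute in float32, result stays float16
#   (torch.float32, torch.float16)

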
def reduction_dtypes(
    arg,
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.dtype, Optional[torch.dtype]]:
    # even though some reductions, like amin or amax, don't strictly require type promotion,
    # all the math ops (including comparisons) are still defined only for a computation type,
    # so promotion will still happen. We are doing it explicitly here
    inp_dtype = dtype if dtype is not None else arg.dtype
    computation_dtype = get_computation_dtype(inp_dtype)
    if (
        output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME
        or output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
    ):
        result_dtype = dtype if dtype else arg.dtype
        if (
            output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
            and is_complex_dtype(result_dtype)
        ):
            result_dtype = corresponding_real_dtype(result_dtype)
    elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE:
        result_dtype = None
    else:  # ALWAYS_BOOL
        result_dtype = torch.bool
    return computation_dtype, result_dtype


# This function's logic is borrowed from the following functions defined in C++:
# batched_matrix_contiguous_strides and contiguous_strides
def make_contiguous_strides_for(
    shape: ShapeType, row_major: bool = True
) -> Tuple[int, ...]:
    """
    Returns the strides of a contiguous tensor if row_major=True.
    If row_major=False, it returns the strides of a contiguous batch of Fortran-contiguous matrices.
    This is often used when calling external libraries like BLAS/LAPACK/cuSolver...
    """
    # contiguous_strides from c10/util/strides.h
    validate_shape(shape)
    if not shape:
        return ()

    from torch.fx.experimental.symbolic_shapes import is_nested_int

    multiplier = 1
    strides = []
    for l in reversed(shape):
        strides.append(multiplier)
        multiplier *= l if is_nested_int(l) else sym_max(l, 1)

    result = tuple(reversed(strides))

    # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h
    if row_major:
        return result
    else:
        if len(shape) < 2:
            return result
        return result[:-2] + (1, max(shape[-2], 1))


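# Stride sketch (illustrative):
#
#   >>> make_contiguous_strides_for((2, 3, 4))
#   (12, 4, 1)
#   >>> make_contiguous_strides_for((2, 3, 4), row_major=False)   # batch of Fortran matrices
#   (12, 1, 3)

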
def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    torch._check(
        len(shape) == 3,
        lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
    )

    multiplier = 1
    strides = [0] * 3
    for idx in (1, -1, 0):
        # NOTE: intentional divergence from make_contiguous_strides_for
        # This is consistent with eager
        strides[idx] = multiplier
        multiplier *= shape[idx]

    return tuple(strides)


def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
    torch._check(
        len(shape) == 4,
        lambda: "Only tensors of rank 4 can use the channels_last memory format",
    )

    multiplier = 1
    strides = [0] * 4
    for idx in (1, -1, -2, 0):
        # NOTE: intentional divergence from make_contiguous_strides_for
        # This is consistent with eager
        strides[idx] = multiplier
        multiplier *= shape[idx]

    return tuple(strides)


def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    torch._check(
        len(shape) == 5,
        lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
    )

    multiplier = 1
    strides = [0] * 5
    for idx in (1, -1, -2, -3, 0):
        # NOTE: intentional divergence from make_contiguous_strides_for
        # This is consistent with eager
        strides[idx] = multiplier
        multiplier *= shape[idx]

    return tuple(strides)


def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    ndim = len(shape) if isinstance(shape, Sequence) else 1
    if ndim == 3:
        return make_channels_last_1d_strides_for(shape)
    elif ndim == 4:
        return make_channels_last_2d_strides_for(shape)
    elif ndim == 5:
        return make_channels_last_3d_strides_for(shape)
    else:
        raise RuntimeError(
            f"no channels last format strides exist in {ndim} dimensions"
        )


def compute_reduction_output_shape(
    shape: ShapeType, dimensions: Sequence
) -> Tuple[int, ...]:
    for idx in dimensions:
        validate_idx(len(shape), idx)

    new_shape = []
    for idx in range(len(shape)):
        if idx in dimensions:
            continue

        new_shape.append(shape[idx])

    return tuple(new_shape)


def validate_no_repeating_dims(dims: Sequence):
    if len(dims) != len(set(dims)):
        raise RuntimeError("duplicate value in the list of dims")


def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]:
    if dims is None:
        return tuple(range(len(shape)))
    dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims)
    validate_no_repeating_dims(dims)
    return dims


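# Reduction sketch (illustrative): dims are wrapped and checked for
# duplicates, and the reduced dims are dropped from the output shape.
#
#   >>> reduction_dims((2, 3, 4), (-1,))
#   (2,)
#   >>> reduction_dims((2, 3, 4), None)
#   (0, 1, 2)
#   >>> compute_reduction_output_shape((2, 3, 4), (1,))
#   (2, 4)

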
def set_correction(
    unbiased: Optional[bool] = None,
    correction: Optional[NumberType] = None,
) -> float:
    if correction is not None and unbiased is not None:
        raise RuntimeError("cannot specify both correction and unbiased arguments")
    elif correction is None and unbiased is None:
        correction = 1.0
    elif correction is None and unbiased is not None:
        correction = 0.0 if unbiased is False else 1.0
    # NB: we don't actually support symint here, but it's harmless to accept
    if not isinstance(correction, (IntLike, FloatLike)):
        raise ValueError("correction argument should be integer or float")
    if correction < 0:
        raise ValueError("correction argument should be non-negative")
    return sym_float(correction)


def compute_required_storage_length(
    shape: ShapeType, strides: StrideType, storage_offset: int
) -> int:
    """Computes the minimum storage size to hold the given tensor geometry.

    Example
    =======

    This is the size of a newly allocated tensor's storage, in units of elements

    >>> t = torch.empty((10, 20))
    >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset())
    200

    >>> # xdoctest: +SKIP(failing)
    >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11))
    >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset())
    >>> size == t2.storage().size()
    True

    A valid tensor may have a larger storage size, but never smaller

    >>> slice = torch.empty(100)[20:40]
    >>> slice.storage().size()
    100

    >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset())
    40

    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    # Short-circuits if the shape has no elements
    if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0):
        return 0

    max_offset = sum((x - 1) * y for x, y in zip(shape, strides))
    # +1 to account for the first element, which offsets are taken from
    return 1 + storage_offset + max_offset


def check_in_bounds_for_storage(
    a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
):
    """
    Determines if the given shape, strides, and offset are valid for the given storage.
    """

    required_length = compute_required_storage_length(shape, strides, storage_offset)
    if a.size() < required_length:
        msg = (
            "Can't view a storage of size {} with an offset of {}, shape of {}, and strides of {}, "
            "which requires a storage of size {}".format(
                a.size(), storage_offset, str(shape), str(strides), required_length
            )
        )
        raise ValueError(msg)


# NOTE: This function should ideally be removed, but some Meta internal models
# packaged with `torch.package` are using it, so it will have to be removed
# at some point in the future when those models no longer use this function.
def check(
    b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
) -> None:
    """
    Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.
    Error message is a callable producing a string (to avoid wasting time
    string formatting in non-error case, and also to make it easier for torchdynamo
    to trace.)

    .. note:: This function is planned for removal in the future. Please use
        `torch._check*` functions instead.
    """
    warnings.warn(
        DeprecationWarning(
            "'torch._prims_common.check' will be removed in the future. Please use "
            "'torch._check*' functions instead"
        )
    )
    torch._check_with(exc_type, b, s)


# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in
# c10/core/MemoryFormat.h into one function
def are_strides_like_channels_last(
    shape: Sequence[int], strides: Sequence[int]
) -> bool:
    ndim = len(shape)

    if ndim == 4:
        # Check for channels_last_2d
        dim_order = [1, 3, 2, 0]
    elif ndim == 5:
        # Check for channels_last_3d
        dim_order = [1, 4, 3, 2, 0]
    else:
        return False

    if strides[1] == 0:
        return False

    min = 0
    for d in dim_order:
        if shape[d] == 0:
            return False
        if strides[d] < min:
            return False
        if d == 0 and min == strides[1]:
            return False
        min = strides[d]
        if strides[d] > 1:
            min *= shape[d]
    return True


def suggest_memory_format(x: TensorLikeType) -> torch.memory_format:
    if x.layout != torch.strided:
        return torch.contiguous_format

    if are_strides_like_channels_last(x.shape, x.stride()):
        return torch.channels_last if x.ndim == 4 else torch.channels_last_3d

    return torch.contiguous_format


def prod(xs: Sequence[NumberType]) -> NumberType:
    """Product of elements in input sequence. Returns 1 for empty sequence"""
    return reduce(operator.mul, xs, 1)


def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool:
    """Checks if a shape can be expanded to another shape.
    This is equivalent to checking if the two shapes are broadcastable.
    """
    # This is a Python implementation of
    # aten/src/ATen/ExpandUtils.h:is_expandable_to
    if len(shape) > len(desired):
        return False
    for i in range(len(shape)):
        if shape[-i - 1] != desired[-i - 1] and shape[-i - 1] != 1:
            return False
    return True


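# Broadcastability sketch (illustrative):
#
#   >>> is_expandable_to((3, 1), (2, 3, 4))
#   True
#   >>> is_expandable_to((3, 2), (3, 4))
#   False

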
def mask_tensor(mask: TensorLikeType, t: TensorLikeType):
    """
    Similar to torch.where(mask, t, 0) but if t is boolean,
    result is also boolean and not promoted to int.
    """
    # torch.where(mask, t, False) is equivalent
    # but feels hacky and might break in the future
    if t.dtype is torch.bool:
        return mask.logical_and(t)
    else:
        return torch.where(mask, t, 0)


def get_aten_op(fn: Callable, name: str):
    """
    Given the __module__ of a reference function and its name, returns
    (our best guess of) the ATen name of the associated operation.

    Note: In ATen, the __name__ of a function within a module often
    starts with the module name. E.g. linalg_eigh, or special_zeta
    """
    module = fn.__module__
    prefix = "torch._refs"
    assert module.startswith(prefix)
    module = module[len(prefix) :]
    # We want to go from .special / .nn.functional
    # to special and special_ / nn_functional_
    if module:
        module = module[1:]
        module = module.replace(".", "_")
        module = module + "_"
    return getattr(torch._ops.ops.aten, f"{module}{name}")


def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype:
    return dtype if dtype is not None else torch.get_default_dtype()


def device_or_default(device: Optional[DeviceLikeType]) -> DeviceLikeType:
    return device if device is not None else torch.device("cpu")


def layout_or_default(layout: Optional[torch.layout]) -> torch.layout:
    return layout if layout is not None else torch.strided


def clone_preserve_strides(x):
    needed_size = compute_required_storage_length(
        x.size(), x.stride(), x.storage_offset()
    )
    # Our eager implementations for *_scatter ops are all primitives w.r.t autograd,
    # so these as_strided() calls are not seen by autograd.
    # We need to mimic this behavior in our ref/prim implementations.
    # TODO: a better way to handle this would be with a new op, "_unsafe_as_strided"
    # We should revisit this when we add a compositional as_strided op,
    # and also as part of https://github.com/pytorch/pytorch/issues/90507
    try:
        old = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView
        )
        torch._C._dispatch_tls_set_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView, True
        )
        buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone()
        return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
    finally:
        torch._C._dispatch_tls_set_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView, old
        )


def alert_not_deterministic(caller: str):
    if torch.are_deterministic_algorithms_enabled():
        if torch.is_deterministic_algorithms_warn_only_enabled():
            warnings.warn(
                f"{caller} does not have a deterministic implementation, but you set "
                f"'torch.use_deterministic_algorithms(True, warn_only=True)'. "
                f"You can file an issue at https://github.com/pytorch/pytorch/issues "
                f"to help us prioritize adding deterministic support for this operation."
            )
        else:
            torch._check(
                False,
                lambda: (
                    f"{caller} does not have a deterministic implementation, but you set "
                    f"'torch.use_deterministic_algorithms(True)'. You can turn off "
                    f"determinism just for this operation, or you can use the "
                    f"'warn_only=True' option, if that's acceptable for your application. "
                    f"You can also file an issue at https://github.com/pytorch/pytorch/issues "
                    f"to help us prioritize adding deterministic support for this operation."
                ),
            )


class CUDARngStateHelper:
    @staticmethod
    def get_torch_state_as_tuple(fake_mode=nullcontext()):
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")

        with fake_mode:
            seed = torch.tensor(torch.cuda.initial_seed())
            offset = torch.tensor(torch.cuda._get_rng_state_offset())
            return seed, offset

    @staticmethod
    def set_torch_state_tensor(seed, offset):
        # Rng state is [64-bit seed, 64-bit offset]
        seed_portion = seed.reshape([1]).view(torch.uint8)
        offset_portion = offset.reshape([1]).view(torch.uint8)
        new_state = torch.cat([seed_portion, offset_portion])
        torch.cuda.set_rng_state(new_state)

    @staticmethod
    def set_new_offset(relative_offset):
        torch.cuda._set_rng_state_offset(relative_offset.item())