from __future__ import annotations
2

3
import operator
4
import warnings
5
import weakref
6

7
from contextlib import nullcontext
8
from enum import Enum
9
from functools import cmp_to_key, reduce
10
from typing import (
11
    Any,
12
    Callable,
13
    cast,
14
    List,
15
    NamedTuple,
16
    Optional,
17
    overload,
18
    Sequence,
19
    Tuple,
20
    Type,
21
    TYPE_CHECKING,
22
    Union,
23
)
24

25
from typing_extensions import TypeAlias
26

27

28
if TYPE_CHECKING:
29
    # Import the following modules during type checking to enable code intelligence features,
30
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
31
    # imported in user code.
32

33
    import sympy
34

35
import torch
36
from torch import sym_float, sym_int, sym_max
37

38

39
ShapeType: TypeAlias = Union[torch.Size, List[int], Tuple[int, ...]]
40
StrideType: TypeAlias = Union[List[int], Tuple[int, ...]]
41
DimsType: TypeAlias = Union[int, List[int], Tuple[int, ...]]
42
DimsSequenceType: TypeAlias = Union[List[int], Tuple[int, ...]]
43
# TODO: Type[torch.SymInt], Type[torch.SymFloat]
44
NumberTypeType: TypeAlias = Union[Type[bool], Type[int], Type[float], Type[complex]]
45
# TODO: This needs a lot more type annotations
46
# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
47
NumberType: TypeAlias = Union[bool, int, float, complex]
48
RealNumberType: TypeAlias = Union[bool, int, float]
49

50
Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat)
51
# I don't call it Integral because numbers.Integral includes bool, but IntLike
52
# does not
53
Dim = int
54
IntLike = (int, torch.SymInt)
55
FloatLike = (float, torch.SymFloat)
56
IntWithoutSymInt = int
57
FloatWithoutSymFloat = float
58
DeviceLikeType: TypeAlias = Union[str, torch.device, int]
59
Tensor = torch.Tensor
60

61

62
torch_function_passthrough = {
63
    torch.device,
64
    torch.sym_not,
65
    torch.sym_float,
66
    torch.sym_int,
67
    torch.sym_max,
68
    torch.sym_min,
69
    torch._sym_sqrt,  # type: ignore[attr-defined]
70
    torch.sym_ite,
71
    torch.Tensor.dim,
72
    torch.Tensor.ndim.__get__,  # type: ignore[attr-defined]
73
    torch.Tensor.numel,
74
    torch.Tensor.size,
75
    torch.Tensor.storage_offset,
76
    torch.Tensor.stride,
77
    torch.Tensor.dtype.__get__,  # type: ignore[attr-defined]
78
    torch.Tensor.is_sparse.__get__,  # type: ignore[attr-defined]
79
    torch.Tensor.shape.__get__,  # type: ignore[attr-defined]
80
    torch.Tensor.device.__get__,  # type: ignore[attr-defined]
81
    torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
82
    torch.Tensor.layout.__get__,  # type: ignore[attr-defined]
83
    torch.Tensor.is_contiguous,
84
    # For TorchRefsMode only
85
    torch.Tensor.__format__,
86
    torch.Tensor.__repr__,
87
    torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
88
}
89

90

91
TensorLikeType = torch.Tensor
92
TensorLike = torch.Tensor
93
TensorSequenceType: TypeAlias = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
94
TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType]
95

96
CustomOutParamAnnotation = "__custom_out_param__"
97

98

99
def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool:
100
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
101

102
    if len(a) != len(b):
103
        return False
104

105
    for x, y in zip(a, b):
106
        if allow_rhs_unbacked:
107
            # TODO: We should check that the symbols are consistent
108
            # with each other
109
            if isinstance(y, torch.SymInt):
110
                continue
111
        # NB: Naively, you would not expect to have to do an oblivious guard
112
        # here because there is seemingly no broadcasting here, but in fact we
113
        # use this in some situations to determine if we need to do an expand
114
        # on the tensor because they don't line up, so you can definitely end
115
        # up trying to prove u0 != 1 in this situation.  See
116
        # python test/test_proxy_tensor.py -k test_cumsum_unbacked
117
        if guard_size_oblivious(x != y):
118
            return False
119

120
    return True
121

122

123
def _maybe_get_pytype(t):
124
    if t is torch.SymFloat:
125
        return float
126
    elif t is torch.SymInt:
127
        return int
128
    elif t is torch.SymBool:
129
        return bool
130
    else:
131
        return t
132

133

134
# TODO: look at using torch.testing.assert_close instead with an option
135
#   to just compare metadata
136
def compare_tensor_meta(
137
    a: TensorLikeType,
138
    b: TensorLikeType,
139
    check_strides=False,
140
    *,
141
    allow_rhs_unbacked=False,
142
    check_conj=True,
143
):
144
    """
145
    Checks that two tensor likes have the same shape,
146
    dtype and device.
147

148
    In the future this will validate additional metadata, like
149
    strides.
150
    """
151
    assert isinstance(a, TensorLike)
152
    assert isinstance(b, TensorLike)
153

154
    if not same_shape(a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked):
155
        msg = f"Shapes {a.shape} and {b.shape} are not equal!"
156
        raise AssertionError(msg)
157

158
    if a.dtype != b.dtype:
159
        msg = f"Dtypes {a.dtype} and {b.dtype} are not equal!"
160
        raise AssertionError(msg)
161

162
    if a.device != b.device:
163
        # Handles special cuda:0 vs cuda case
164
        # TODO: we should review why this happens and see about fixing it
165
        if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and (
166
            str(b.device) == "cuda:0" or str(b.device) == "cuda"
167
        ):
168
            pass
169
        else:
170
            msg = f"Devices {a.device} and {b.device} are not equal!"
171
            raise AssertionError(msg)
172

173
    # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050
174
    if check_strides:
175
        same_strides, idx = check_significant_strides(a, b)
176
        if not same_strides:
177
            msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!"
178
            raise RuntimeError(msg)
179

180
        if a.storage_offset() != b.storage_offset():
181
            msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!"
182
            raise RuntimeError(msg)
183

184
    if check_conj:
185
        if a.is_conj() != b.is_conj():
186
            raise RuntimeError(
187
                f"Conj mismatch! is_conj is set to {a.is_conj()} and {b.is_conj()}"
188
            )
189

190
    if a.is_neg() != b.is_neg():
191
        raise RuntimeError(
192
            f"Neg mismatch! is_neg is set to {a.is_neg()} and {b.is_neg()}"
193
        )
194

195

196
def _check_strides_helper(
197
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True
198
) -> Tuple[bool, Optional[int]]:
199
    # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch
200
    # See https://github.com/pytorch/pytorch/issues/77553
201
    # Only compares strides that are "meaningful" -- strides for dimensions with length > 1
202
    # and for tensors with more than one element
203
    if (
204
        not only_cuda or a.device.type == "cuda" or b.device.type == "cuda"
205
    ) and a.numel() > 0:
206
        for idx in range(a.ndim):
207
            check = not significant_only or a.shape[idx] > 1
208
            if a.stride()[idx] != b.stride()[idx] and check:
209
                return False, idx
210

211
    return True, None
212

213

214
def check_significant_strides(
215
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
216
) -> Tuple[bool, Optional[int]]:
217
    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True)
218

219

220
def check_all_strides(
221
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
222
) -> Tuple[bool, Optional[int]]:
223
    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)
224

225

226
# This function is equivalent to compute_contiguous() from TensorImpl.cpp
227
def is_contiguous(a: TensorLikeType) -> bool:
228
    """
229
    Tests whether a tensor is contiguous or not.
230

231
    Tensors are contiguous when they have no elements,
232
    one element, or when they have "nested" strides.
233
    """
234
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
235

236
    if guard_size_oblivious(a.numel() < 2):
237
        return True
238

239
    expected_stride = 1
240
    for x, y in reversed(tuple(zip(a.shape, a.stride()))):
241
        # Skips checking strides when a dimension has length 1
242
        if guard_size_oblivious(x == 1):
243
            continue
244

245
        if y != expected_stride:
246
            return False
247
        expected_stride = expected_stride * x
248

249
    return True
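
# Illustrative note (added commentary, not from the upstream source): "nested" strides
# means each stride equals the product of the lengths of the dimensions to its right.
# For a hypothetical tensor of shape (2, 3, 4):
#
#   strides (12, 4, 1) -> contiguous
#   strides (1, 2, 6)  -> not contiguous (this is the column-major layout)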
250

251

252
# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
253
def is_channels_last_contiguous_2d(a: Tensor) -> bool:
254
    # NHWC or not channels last 2D contiguous
255
    if a.ndim != 4:
256
        return False
257

258
    expected_stride = 1
259
    for idx in (1, 3, 2, 0):
260
        length = a.shape[idx]
261
        if length == 1:
262
            continue
263

264
        stride = a.stride()[idx]
265
        if stride != expected_stride:
266
            return False
267

268
        expected_stride *= length
269

270
    return True
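
# Worked example (illustrative commentary added here, not part of the upstream source):
# a hypothetical NCHW tensor of shape (1, 3, 4, 5) stored channels-last has strides
# (C*H*W, 1, W*C, C) = (60, 1, 15, 3); visiting the dims in the order (C, W, H, N)
# above, the expected strides are 1, 3, 15, 60, so the check passes.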
271

272

273
def is_channels_last_contiguous_3d(a: Tensor) -> bool:
274
    # NDHWC or not channels last 3D contiguous
275
    if a.ndim != 5:
276
        return False
277

278
    expected_stride = 1
279
    for idx in (1, 4, 3, 2, 0):
280
        length = a.shape[idx]
281
        if length == 1:
282
            continue
283

284
        stride = a.stride()[idx]
285
        if stride != expected_stride:
286
            return False
287

288
        expected_stride *= length
289

290
    return True
291

292

293
_memory_formats = {
294
    torch.contiguous_format,
295
    torch.preserve_format,
296
    torch.channels_last,
297
    torch.channels_last_3d,
298
}
299

300

301
def validate_memory_format(memory_format: torch.memory_format):
302
    torch._check(
303
        memory_format in _memory_formats,
304
        lambda: f"Received unknown memory format {memory_format}!",
305
    )
306

307

308
def is_contiguous_for_memory_format(  # type: ignore[return]
309
    a: Tensor, *, memory_format: torch.memory_format
310
) -> bool:
311
    validate_memory_format(memory_format)
312

313
    if memory_format == torch.contiguous_format:
314
        return is_contiguous(a)
315
    if memory_format == torch.channels_last:
316
        return is_channels_last_contiguous_2d(a)
317
    if memory_format == torch.channels_last_3d:
318
        return is_channels_last_contiguous_3d(a)
319

320
    torch._check(
321
        False,
322
        lambda: f"is_contiguous received unsupported memory format {memory_format}",
323
    )
324

325

326
# NOTE: it is unclear whether tensors with no elements count as channels-last contiguous
327
def is_channels_last_contiguous(a: Tensor) -> bool:
328
    """
329
    True when a tensor is channels-last contiguous.
330

331
    This requires that:
332

333
      - the tensor is conceptually either 4 (NHWC) or 5 (NDHWC) dimensions
334
      - if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the
335
        stride of the 'C' dimension (Cs) is 1 and the strides corresponding to
336
        each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are
337
        "nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension,
338
        for example.
339
    """
340
    return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)
341

342

343
def is_non_overlapping_and_dense(a: Tensor) -> bool:
344
    """
345
    True when a tensor is non-overlapping and dense.
346

347
    A tensor is non-overlapping and dense when there exists a permutation of
348
    its dimensions that is contiguous.
349
    """
350

351
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
352

353
    if a.is_sparse:
354
        return False
355

356
    # Short-circuits if the tensor is already contiguous or channels-last contiguous
357
    if is_contiguous(a) or is_channels_last_contiguous(a):
358
        return True
359

360
    # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp
361

362
    # Short-circuits for tensors of rank one, which are
363
    # non-overlapping and "dense" if their stride is one
364
    if a.ndim == 1:
365
        return a.stride()[0] == 1
366

367
    # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
368
    # Sorts (length, stride) pairs by stride
369
    #
370
    # This sort is done in a size-oblivious way, which helps if we do a
371
    # comparison like 2048*u0 > u0; we just want this to return True
372
    # (and not worry about what if u0 is zero).
373
    class K(NamedTuple):
374
        size: int
375
        stride: int
376

377
        def __lt__(self, other):
378
            return guard_size_oblivious(self.stride < other.stride)
379

380
        def __gt__(self, other):
381
            return guard_size_oblivious(self.stride > other.stride)
382

383
        def __le__(self, other):
384
            return guard_size_oblivious(self.stride <= other.stride)
385

386
        def __ge__(self, other):
387
            return guard_size_oblivious(self.stride >= other.stride)
388

389
        def __eq__(self, other):
390
            return guard_size_oblivious(self.stride == other.stride)
391

392
    lengths_and_strides = sorted(map(K, a.shape, a.stride()))
393

394
    expected_stride = 1
395
    for length, stride in lengths_and_strides:
396
        if guard_size_oblivious(length == 1):
397
            continue
398

399
        if stride != expected_stride:
400
            return False
401

402
        expected_stride *= length
403

404
    return True
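
# Illustrative example (added commentary, not from the upstream source): a hypothetical
# tensor of shape (4, 2, 3) with strides (1, 12, 4) -- a permutation of a contiguous
# (2, 3, 4) tensor -- is non-overlapping and dense: sorting its (length, stride) pairs
# by stride gives (4, 1), (3, 4), (2, 12), which nest like contiguous strides, even
# though is_contiguous and is_channels_last_contiguous both return False for it.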
405

406

407
# NOTE: Based on the implementation in TensorIterator.cpp, but note that
# the note [Computing output strides] there is inaccurate: it claims that
# strides will be preserved even if they are not "non overlapping and
# dense", whereas the output of elementwise operations is always given
# non overlapping and dense strides.
# This implementation is also INCORRECT because it does not model TensorIterator's
# short-circuit, which can cause different strides.
415
def compute_elementwise_output_logical_to_physical_perm(
416
    *tensors, _skip_checks=False
417
) -> List[int]:
418
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
419

420
    if not _skip_checks and len(tensors) == 0:
421
        msg = "Can't compute elementwise output strides for zero tensors!"
422
        raise ValueError(msg)
423

424
    if not _skip_checks:
425
        check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
426

427
    # Filters the tensors to actual tensors
428
    if not _skip_checks:
429
        tensors = tuple(
430
            a
431
            for a in tensors
432
            if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
433
        )
434

435
    # Short-circuits for CPU scalar case
436
    if len(tensors) == 0:
437
        return []
438

439
    # Short-circuits for shapes with zero or one dimensions
440
    # TODO: are these necessary?
441
    ndim = tensors[0].ndim
442
    if ndim == 0:
443
        return []
444
    if ndim == 1:
445
        return [0]
446

447
    # Short-circuits if contiguous, following the fake fast path.
448
    # This reduces the number of guards we end up making
449
    # TODO: do channels last too
450
    is_contiguous = True
451
    for t in tensors:
452
        is_contiguous = is_contiguous and t.is_contiguous(
453
            memory_format=torch.contiguous_format
454
        )
455

456
    if is_contiguous:
457
        return list(range(ndim))
458

459
    shape = tensors[0].shape
460

461
    def should_swap(idx_a, idx_b):
462
        for tensor in tensors:
463
            stride_a = tensor.stride()[idx_a]
464
            stride_b = tensor.stride()[idx_b]
465

466
            if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
467
                stride_b == 0
468
            ):
469
                continue
470

471
            if guard_size_oblivious(stride_a < stride_b):
472
                return -1
473

474
            if guard_size_oblivious(stride_a > stride_b):
475
                return 1
476

477
            # stride_a == stride_b
478
            if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
479
                return 1
480

481
        # Note: this case is hit if all strides are zero,
482
        # or all strides are equal and all dimensions have the same length
483
        return 0
484

485
    # The "sort" order for the permutation is back-to-front, but
486
    # the natural order for permutations is front-to-back.  Do the
487
    # sorting back-to-front and then reverse it on output.
488
    #
489
    # also, note this returns the logical to physical shape permutation
490
    perm = list(reversed(range(ndim)))
491

492
    # insertion sort with support for ambiguous comparisons
493
    for i in range(1, ndim):
494
        dim1 = i
495
        for dim0 in reversed(range(i)):
496
            comparison = should_swap(perm[dim0], perm[dim1])
497
            if comparison > 0:
498
                perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
499
                dim1 = dim0
500
            elif comparison < 0:
501
                break
502

503
    return list(reversed(perm))
504

505

506
def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
507
    """
508
    Computes the output strides for elementwise operations.
509
    """
510
    if len(tensors) == 0:
511
        msg = "Can't compute elementwise output strides for zero tensors!"
512
        raise ValueError(msg)
513

514
    check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
515

516
    # Filters the tensors to actual tensors
517
    tensors = tuple(
518
        a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
519
    )
520

521
    # Short-circuits for CPU scalar case
522
    if len(tensors) == 0:
523
        return ()
524

525
    ndim = tensors[0].ndim
526
    shape = tensors[0].shape
527

528
    if ndim == 0:
529
        return ()
530
    if ndim == 1:
531
        return (1,)
532

533
    logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
534
        *tensors, _skip_checks=True
535
    )
536
    permuted_shape = apply_perm(shape, logical_to_physical_perm)  # to physical
537

538
    new_strides = make_contiguous_strides_for(permuted_shape)
539
    permuted_strides = apply_perm(
540
        new_strides, invert_perm(logical_to_physical_perm)
541
    )  # to logical
542

543
    return tuple(permuted_strides)
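
# Illustrative example (added commentary, not from the upstream source): for a
# hypothetical tensor of shape (4, 3) with strides (1, 4) -- the transpose of a
# contiguous (3, 4) tensor -- this returns (1, 4): elementwise operations preserve
# the permuted layout instead of forcing row-major output strides.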
544

545

546
# Identity permutation is [0, 1, 2]
547
def apply_perm(inp, perm):
548
    ndim = len(inp)
549
    permuted_inp = [-1] * ndim
550
    for idx, x in enumerate(perm):
551
        permuted_inp[idx] = inp[x]
552
    return permuted_inp
553

554

555
def invert_perm(perm):
556
    ndim = len(perm)
557
    new_perm = [-1] * ndim
558
    for idx, x in enumerate(perm):
559
        new_perm[x] = idx
560
    return new_perm
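
# Quick sanity examples for the two helpers above (illustrative, added commentary):
#
#   apply_perm([10, 20, 30], [2, 0, 1])  -> [30, 10, 20]
#   invert_perm([2, 0, 1])               -> [1, 2, 0]
#   apply_perm([30, 10, 20], [1, 2, 0])  -> [10, 20, 30]  # round-trips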
561

562

563
#
564
# Common helper functions
565
#
566

567

568
def validate_dim_length(length: int):
569
    """
570
    Validates that an object represents a valid
571
    dimension length.
572
    """
573

574
    if isinstance(length, (int, torch.SymInt)):
575
        torch._check_is_size(length)
576
    else:
577
        # sometimes called with sympy expression by inductor
578
        assert length >= 0
579

580

581
def validate_shape(shape: ShapeType):
582
    """
583
    Validates that a sequence represents a valid shape.
584
    """
585

586
    assert isinstance(shape, Sequence), type(shape)
587
    for l in shape:
588
        validate_dim_length(l)
589

590

591
def validate_strides(strides: StrideType):
592
    """
593
    Verifies the object specifies valid strides.
594
    """
595

596
    assert isinstance(strides, Sequence)
597
    for stride in strides:
598
        assert stride >= 0
599

600

601
def validate_idx(rank: int, idx: int):
602
    """
603
    Validates that idx is a valid index for the given shape.
604
    Assumes the index is already canonicalized.
605
    """
606

607
    assert isinstance(idx, Dim)
608
    assert isinstance(rank, Dim)
609

610
    assert idx >= 0 and idx < rank or idx == 0
611

612

613
def validate_dimension_indices(rank: int, indices: DimsSequenceType):
614
    for idx in indices:
615
        validate_idx(rank, idx)
616

617

618
def validate_exclusive_idx(rank: int, ex_idx: int):
619
    """
620
    Validates that ex_idx is a valid exclusive index
621
    for the given shape.
622
    """
623

624
    assert isinstance(ex_idx, Dim)
625
    assert isinstance(rank, Dim)
626
    assert ex_idx > 0 and ex_idx <= rank
627

628

629
# "Wraps" a dim (up to one time) for the given rank, allowing dims to be
630
# specified using negative indices. If `wrap_scalar` is true then scalar
631
# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise,
632
# idx should be in the range [-rank, rank-1].
633
def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
634
    if rank < 0:
635
        msg = f"Rank cannot be negative but got {rank}"
636
        raise IndexError(msg)
637

638
    if rank == 0:
639
        if not wrap_scalar:
640
            msg = f"Dimension specified as {idx} but tensor has no dimensions"
641
            raise IndexError(msg)
642
        rank = 1
643

644
    if idx >= 0 and idx < rank:
645
        return idx
646

647
    if idx < 0:
648
        _idx = idx + rank
649
    else:
650
        _idx = idx
651

652
    if _idx < 0 or _idx >= rank:
653
        # Same error message as in aten/src/ATen/WrapDimUtils.h:49
654
        msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})"
655
        raise IndexError(msg)
656

657
    return _idx
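
# Illustrative examples (added commentary, not from the upstream source):
#
#   canonicalize_dim(4, -1)  -> 3
#   canonicalize_dim(4, 2)   -> 2
#   canonicalize_dim(0, -1)  -> 0           # wrap_scalar=True treats rank 0 as rank 1
#   canonicalize_dim(3, 3)   -> IndexError  # valid range is [-3, 2]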
658

659

660
# Takes a dimension or sequence of dimensions and "wraps" them,
661
# mapping negative offsets to positive ones
662
@overload
663
def canonicalize_dims(
664
    rank: int, indices: Sequence[int], wrap_scalar: bool = True
665
) -> Tuple[int, ...]:
666
    pass
667

668

669
@overload
670
def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int:
671
    pass
672

673

674
def canonicalize_dims(rank, indices, wrap_scalar=True):
675
    if isinstance(indices, Dim):
676
        return canonicalize_dim(rank, indices, wrap_scalar)
677

678
    return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices)
679

680

681
def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool:
682
    """
683
    Validates that perm is a permutation of length rank.
684
    """
685

686
    if not isinstance(perm, Sequence):
687
        return False
688

689
    if not (tuple(sorted(perm)) == tuple(range(0, rank))):
690
        return False
691

692
    return True
693

694

695
def is_same_shape(a: Sequence, b: Sequence) -> bool:
696
    """
697
    Compares two shapes a and b, returning True if they are the same
698
    (their ranks and corresponding lengths match) and False otherwise.
699
    """
700

701
    return tuple(a) == tuple(b)
702

703

704
def is_cpu_scalar_tensor(a: Any) -> bool:
705
    return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu"
706

707

708
def check_same_device(*args, allow_cpu_scalar_tensors):
709
    """
710
    Checks that all Tensors in args have the same device.
711

712
    Raises a RuntimeError when:
713
      - args contains an object whose type is not Tensor or Number
714
      - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True
715
    """
716
    # Short-circuits if all (one or fewer) arguments are trivially on the same device
717
    if len(args) <= 1:
718
        return
719

720
    # Note: cannot initialize device to the first arg's device (it may not have one)
721
    device = None
722
    for arg in args:
723
        if isinstance(arg, Number):
724
            continue
725
        elif isinstance(arg, TensorLike):
726
            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
727
                continue
728

729
            if device is None:
730
                device = arg.device
731

732
            if device != arg.device:
733
                msg = (
734
                    "Tensor on device "
735
                    + str(arg.device)
736
                    + " is not on the expected device "
737
                    + str(device)
738
                    + "!"
739
                )
740
                raise RuntimeError(msg)
741
        else:
742
            msg = (
743
                "Unexpected type when checking for same device, " + str(type(arg)) + "!"
744
            )
745
            raise RuntimeError(msg)
746

747

748
def canonicalize_device(device: DeviceLikeType) -> torch.device:
749
    if isinstance(device, torch.device):
750
        return device
751

752
    assert isinstance(device, str)
753
    return torch.device(device)
754

755

756
# Asserts if any of the following are true:
757
#   - a non-scalar or non-Tensor is given
758
#   - the shape of any tensors is distinct
759
def check_same_shape(*args, allow_cpu_scalar_tensors: bool):
760
    """
761
    Checks that all Tensors in args have the same shape.
762

763
    Raises a RuntimeError when:
764
      - args contains an object whose type is not Tensor or Number
765
      - two Tensor objects in args have different shapes
766
    """
767
    shape = None
768

769
    for arg in args:
770
        if isinstance(arg, Number):
771
            continue
772
        elif isinstance(arg, TensorLike):
773
            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
774
                continue
775

776
            if shape is None:
777
                shape = arg.shape
778

779
            if not is_same_shape(shape, arg.shape):
780
                msg = f"Shape {arg.shape} is not the expected shape {shape}!"
781
                raise RuntimeError(msg)
782
        else:
783
            msg = (
784
                "Unexpected type when checking for same shape, " + str(type(arg)) + "!"
785
            )
786
            raise RuntimeError(msg)
787

788

789
# Acquires a common shape, if it exists, from one or more tensor arguments,
790
# filtering number arguments
791
def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
792
    shape = None
793
    scalar_shape = None
794

795
    for arg in args:
796
        if isinstance(arg, Number):
797
            continue
798
        elif isinstance(arg, TensorLike):
799
            if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
800
                scalar_shape = arg.shape
801
                continue
802

803
            if shape is None:
804
                shape = arg.shape
805

806
            if not is_same_shape(shape, arg.shape):
807
                return None
808
        else:
809
            return None
810

811
    return shape if shape is not None else scalar_shape
812

813

814
# Extracts dimensions that might be passed either as a list/tuple or as varargs.
815
# A typical case is Tensor.permute .
816
def extract_dims_from_varargs(
817
    dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]]
818
) -> DimsSequenceType:
819
    if dims and isinstance(dims[0], Sequence):
820
        assert len(dims) == 1
821
        dims = cast(Tuple[DimsSequenceType], dims)
822
        return dims[0]
823
    else:
824
        return cast(DimsSequenceType, dims)
825

826

827
def extract_shape_from_varargs(
828
    shape: Union[ShapeType, Tuple[ShapeType]],
829
    validate=True,
830
) -> Tuple[int, ...]:
831
    """
832
    Returns a shape from varargs.
833

834
    In PyTorch, operations that accept shapes often accept them as varargs, like
    foo(*shape). A user can therefore pass the shape either as separate integers:

      foo(1, 2, 3)

    or as a single sequence of integers:

      foo((1, 2, 3))

    In the first case shape will be a tuple of integers, and in the second case it's a tuple
    containing a tuple of integers. This validates those inputs and canonicalizes them
    to a tuple of integers.
847
    """
848

849
    # Handles tuple unwrapping
850
    if len(shape) == 1 and isinstance(shape[0], Sequence):
851
        shape = shape[0]
852

853
    if validate:
854
        validate_shape(shape)  # type: ignore[arg-type]
855
    return shape  # type: ignore[return-value]
856

857

858
def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]:
859
    ndim = max(len(a), len(b))
860
    expandedSizes = [0] * ndim
861

862
    for i in range(ndim - 1, -1, -1):
863
        offset = ndim - 1 - i
864
        dimA = len(a) - 1 - offset
865
        dimB = len(b) - 1 - offset
866
        sizeA = a[dimA] if dimA >= 0 else 1
867
        sizeB = b[dimB] if dimB >= 0 else 1
868

869
        torch._check(
870
            (sizeA == sizeB) or (sizeA == 1) or (sizeB == 1),
871
            lambda: (
872
                f"The size of tensor a ({sizeA}) must match the size of "
873
                f"tensor b ({sizeB}) at non-jagged dimension {i}"
874
            ),
875
        )
876

877
        # 1s map to the other size (even 0)
878
        expandedSizes[i] = sizeB if sizeA == 1 else sizeA
879

880
    return tuple(expandedSizes)
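
# Broadcasting examples (illustrative, added commentary):
#
#   infer_size_shapes((3, 1, 5), (4, 5))  -> (3, 4, 5)
#   infer_size_shapes((2, 3), (3, 3))     -> raises; sizes 2 and 3 do not broadcast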
881

882

883
def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
884
    """
885
    Infers the size of a dim with size -1, if it exists.
886
    Also checks that new shape is compatible with the number of elements.
887
    """
888
    dim = None
889
    newsize = 1
890
    for i, d in enumerate(shape):
891
        if d == -1:
892
            torch._check(dim is None, lambda: "only one dimension can be inferred")
893
            dim = i
894
        elif d >= 0:
895
            newsize *= d
896
        else:
897
            torch._check(False, lambda: f"invalid shape dimension {d}")
898
    if dim is None:
899
        torch._check(
900
            numel == newsize,
901
            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
902
        )
903
    else:
904
        from torch.fx.experimental.symbolic_shapes import definitely_true
905

906
        torch._check(
907
            newsize != 0,
908
            lambda: (
909
                f"cannot reshape tensor of 0 elements into shape {list(shape)} because the "
910
                f"unspecified dimension size -1 can be any value and is ambiguous"
911
                if definitely_true(numel == 0)
912
                else f"shape '{list(shape)}' is invalid for input of size {numel}"
913
            ),
914
        )
915
        torch._check(
916
            numel % newsize == 0,
917
            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
918
        )
919
        # Convert to list to produce a compatible error message with core
920
        # PyTorch, which prints sequences in square brackets.
921
        shape = list(shape)
922
        shape[dim] = numel // newsize
923
        # NB: This is pretty important when you have unbacked SymInts.
924
        # Suppose you have (i0, 12) resizing into (2, -1, 12).  The old
925
        # range for i0 is typically [2, inf], which means if you divide
926
        # by two the new range should be [1, inf].  But this is bad news
927
        # if you have an unbacked SymInt: we need to reapply the unsound
928
        # assumption that the size is >= 2.
929
        torch._check_is_size(shape[dim])
930
    return tuple(shape)
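
# Illustrative examples (added commentary, not from the upstream source):
#
#   infer_size((2, -1, 4), 24)  -> (2, 3, 4)  # -1 is inferred as 24 // (2 * 4)
#   infer_size((2, -1), 7)      -> raises; 7 is not divisible by 2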
931

932

933
_integer_dtypes = (
934
    torch.uint8,
935
    torch.uint16,
936
    torch.uint32,
937
    torch.uint64,
938
    torch.int8,
939
    torch.int16,
940
    torch.int32,
941
    torch.int64,
942
)
943
_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
944
_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128)
945

946

947
def is_boolean_dtype(dtype: torch.dtype) -> bool:
948
    assert isinstance(dtype, torch.dtype)
949
    return dtype is torch.bool
950

951

952
def is_integer_dtype(dtype: torch.dtype) -> bool:
953
    assert isinstance(dtype, torch.dtype)
954
    return dtype in _integer_dtypes
955

956

957
def is_low_precision_dtype(dtype: torch.dtype) -> bool:
958
    assert isinstance(dtype, torch.dtype)
959
    return dtype in _low_precision_dtypes
960

961

962
def is_float_dtype(dtype: torch.dtype) -> bool:
963
    assert isinstance(dtype, torch.dtype)
964
    return dtype.is_floating_point
965

966

967
def is_complex_dtype(dtype: torch.dtype) -> bool:
968
    assert isinstance(dtype, torch.dtype)
969
    return dtype in _complex_dtypes
970

971

972
def is_grad_dtype(dtype: torch.dtype) -> bool:
973
    """
974
    Checks if the dtype can require a gradient.
975
    """
976
    return dtype.is_floating_point or is_complex_dtype(dtype)
977

978

979
_complex_to_real_dtype_map = {
980
    torch.complex128: torch.float64,
981
    torch.complex64: torch.float32,
982
    torch.complex32: torch.float16,
983
}
984

985
_real_to_complex_dtype_map = {
986
    torch.float16: torch.complex32,
987
    torch.bfloat16: torch.complex64,
988
    torch.float32: torch.complex64,
989
    torch.float64: torch.complex128,
990
}
991

992

993
def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype:
994
    return _complex_to_real_dtype_map[dtype]
995

996

997
def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype:
998
    return _real_to_complex_dtype_map[dtype]
999

1000

1001
def dtype_to_type(dtype: torch.dtype) -> type:
1002
    """
1003
    Computes the corresponding Python type (AKA "type kind") for the
1004
    given dtype.
1005
    """
1006
    assert isinstance(dtype, torch.dtype)
1007

1008
    if dtype is torch.bool:
1009
        return bool
1010
    if dtype in _integer_dtypes:
1011
        return int
1012
    if dtype.is_floating_point:
1013
        return float
1014
    if dtype in _complex_dtypes:
1015
        return complex
1016

1017
    raise ValueError("Invalid dtype!")
1018

1019

1020
def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]:
1021
    """
1022
    Computes the corresponding Python type constructor for the
1023
    given dtype.
1024
    """
1025
    assert isinstance(dtype, torch.dtype)
1026

1027
    if dtype is torch.bool:
1028
        return lambda x: bool(x)
1029
    if dtype in _integer_dtypes:
1030
        return sym_int
1031
    if dtype.is_floating_point:
1032
        return sym_float
1033
    if dtype in _complex_dtypes:
1034
        # TODO: type error here is real, replace with sym_complex
1035
        return lambda x: complex(x)  # type: ignore[arg-type]
1036

1037
    raise ValueError("Invalid dtype!")
1038

1039

1040
def type_to_dtype(typ: type) -> torch.dtype:
1041
    """
1042
    Computes the corresponding dtype for a Number type.
1043
    """
1044

1045
    assert isinstance(typ, type)
1046

1047
    if typ is bool:
1048
        return torch.bool
1049
    if typ in [int, torch.SymInt]:
1050
        return torch.long
1051
    if typ in [float, torch.SymFloat]:
1052
        return torch.get_default_dtype()
1053
    # TODO: sym_complex_float?
1054
    if typ is complex:
1055
        return corresponding_complex_dtype(torch.get_default_dtype())
1056

1057
    raise ValueError("Invalid type!")
1058

1059

1060
def get_dtype(x: Union[torch.Tensor, NumberType]):
1061
    if isinstance(x, torch.Tensor):
1062
        return x.dtype
1063
    else:
1064
        return type_to_dtype(type(x))
1065

1066

1067
_ordered_types = (bool, int, float, complex)
1068

1069

1070
def check_fp_or_complex(
1071
    dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True
1072
):
1073
    """
1074
    Checks whether the input is floating point or complex.
1075
    If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32
1076
    """
1077
    torch._check(
1078
        is_float_dtype(dtype) or is_complex_dtype(dtype),
1079
        lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}",
1080
    )
1081
    torch._check(
1082
        allow_low_precision_dtypes or not is_low_precision_dtype(dtype),
1083
        lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}",
1084
    )
1085

1086

1087
def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"):
1088
    torch._check(
1089
        len(A.shape) >= 2,
1090
        lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
1091
    )
1092

1093

1094
def get_higher_type(a: type, b: type) -> type:
1095
    """
1096
    Returns the higher of the two given Number types.
1097

1098
    The types are ordered bool -> int -> float -> complex.
1099
    """
1100
    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)
1101
    # Type checking
1102
    if a not in _ordered_types or b not in _ordered_types:
1103
        raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}")
1104

1105
    if a is b:
1106
        return a
1107

1108
    for typ in _ordered_types:
1109
        if a is typ:
1110
            return b
1111
        if b is typ:
1112
            return a
1113

1114
    raise ValueError("Unknown Python scalar type!")
1115

1116

1117
# Returns the higher of two torch datatypes a and b or, if the two
1118
#   are not ordered relative to each other, the next
1119
#   higher datatype
1120
def get_higher_dtype(
1121
    a: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
1122
    b: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
1123
) -> Optional[torch.dtype]:
1124
    """
1125
    Computes the "lowest" datatype that is weakly
1126
    "higher" than both a and b.
1127
    """
1128

1129
    # Type checking
1130
    assert a is None or isinstance(a, (torch.dtype, TensorLike, Number))
1131
    assert b is None or isinstance(b, (torch.dtype, TensorLike, Number))
1132

1133
    def _extract_dtype(
1134
        x: Optional[Union[torch.dtype, TensorLikeType, NumberType]]
1135
    ) -> Optional[torch.dtype]:
1136
        if x is None:
1137
            return None
1138
        if isinstance(x, torch.dtype):
1139
            return x
1140
        if isinstance(x, TensorLike):
1141
            return x.dtype
1142
        if isinstance(x, Number):
1143
            return type_to_dtype(type(x))
1144

1145
        raise RuntimeError("Unexpected type given to _extract_dtype!")
1146

1147
    a, b = _extract_dtype(a), _extract_dtype(b)
1148

1149
    if a is b:
1150
        return a
1151

1152
    if a is None:
1153
        return b
1154

1155
    if b is None:
1156
        return a
1157

1158
    ordered_datatypes = (
1159
        (torch.bool,),
1160
        (torch.uint8, torch.int8),
1161
        (torch.int16,),
1162
        (torch.int32,),
1163
        (torch.int64,),
1164
        (torch.float16, torch.bfloat16),
1165
        (torch.float32,),
1166
        (torch.float64,),
1167
        (torch.complex32,),
1168
        (torch.complex64,),
1169
        (torch.complex128,),
1170
    )
1171

1172
    for idx, dtypes in enumerate(ordered_datatypes):
1173
        if a in dtypes and b in dtypes:
1174
            return ordered_datatypes[idx + 1][0]
1175
        if a in dtypes:
1176
            return b
1177
        if b in dtypes:
1178
            return a
1179

1180
    raise RuntimeError("Unexpected termination!")
1181

1182

1183
def check_pin_memory(pin_memory: bool):
1184
    torch._check_not_implemented(
1185
        not pin_memory, lambda: "PrimTorch does not support pinned memory"
1186
    )
1187

1188

1189
def check_layout(layout: torch.layout):
1190
    torch._check_not_implemented(
1191
        layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}"
1192
    )
1193

1194

1195
# TODO: maybe unify with can_cast_to?
1196
def is_weakly_lesser_type(a: type, b: type) -> bool:
1197
    """
1198
    Compares two types, a and b, returning True if a is weakly "less" than b.
1199

1200
    The comparison is determined by the following type ordering: bool, int, float, complex.
1201
    """
1202

1203
    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)
1204

1205
    if a not in _ordered_types or b not in _ordered_types:
1206
        raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}")
1207

1208
    for typ in _ordered_types:
1209
        if a == typ:
1210
            return True
1211
        if b == typ:
1212
            return False
1213

1214
    raise RuntimeError("Unexpected termination!")
1215

1216

1217
def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool:
1218
    for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype):
1219
        if fn(cast_to):
1220
            return True
1221
        if fn(cast_from):
1222
            return False
1223

1224
    raise ValueError(f"Received unknown dtypes {cast_to}, {cast_from}!")
1225

1226

1227
def check_same_dtype(*args):
1228
    """
1229
    Checks that all Tensors in args have the same device and that all Numbers have the
1230
    same corresponding Python type.
1231

1232
    Raises a RuntimeError when:
1233
      - args contains an object whose type is not Tensor or Number
1234
      - two Tensors objects in args have different dtypes
1235
      - two Number objects in args have different types
1236
      - there are Tensors and Numbers in args, and one of those Tensors corresponding
1237
          Python types is different from the type of one of those Numbers
1238
    """
1239
    full_dtype = None
1240
    scalar_type = None
1241

1242
    for arg in args:
1243
        if isinstance(arg, Number):
1244
            # Scalar type checking is disabled (and may be removed in the future)
1245
            continue
1246
            # if scalar_type is None:
1247
            #     scalar_type = type(arg)
1248

1249
            # if scalar_type is not type(arg):
1250
            #     msg = (
1251
            #         "Scalar of type "
1252
            #         + str(type(arg))
1253
            #         + " is not the expected type of "
1254
            #         + str(scalar_type)
1255
            #         + "!"
1256
            #     )
1257
            #     raise RuntimeError(msg)
1258
        elif isinstance(arg, TensorLike):
1259
            if full_dtype is None:
1260
                full_dtype = arg.dtype
1261
            if scalar_type is None:
1262
                scalar_type = dtype_to_type(arg.dtype)
1263

1264
            if full_dtype is not arg.dtype:
1265
                msg = (
1266
                    "Tensor with dtype "
1267
                    + str(arg.dtype)
1268
                    + " is not the expected dtype of "
1269
                    + str(full_dtype)
1270
                    + "!"
1271
                )
1272
                raise RuntimeError(msg)
1273

1274
            arg_type = dtype_to_type(arg.dtype)
1275
            if arg_type is not scalar_type:
1276
                msg = (
1277
                    "Tensor with corresponding Python type "
1278
                    + str(arg_type)
1279
                    + " is not the expected type of "
1280
                    + str(scalar_type)
1281
                    + "!"
1282
                )
1283
                raise RuntimeError(msg)
1284
        else:
1285
            msg = (
1286
                "Unexpected type when checking for same dtype, " + str(type(arg)) + "!"
1287
            )
1288
            raise RuntimeError(msg)
1289

1290

1291
# Maps datatypes to their computation types for elementwise operations
1292
_computation_dtype_map = {
1293
    torch.bfloat16: torch.float32,
1294
    torch.float16: torch.float32,
1295
    torch.complex32: torch.complex64,
1296
}
1297

1298

1299
def get_computation_dtype(dtype: torch.dtype) -> torch.dtype:
1300
    return _computation_dtype_map.get(dtype, dtype)
1301

1302

1303
_cpu_acc_type_map = {
1304
    torch.bfloat16: torch.float64,
1305
    torch.float16: torch.float64,
1306
    torch.float32: torch.float64,
1307
    torch.complex32: torch.complex128,
1308
    torch.complex64: torch.complex128,
1309
}
1310

1311

1312
def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype:
1313
    # Equivalent to at::toAccumulateType, prefer computation_dtype where possible
1314
    if device.type == "cpu":
1315
        return _cpu_acc_type_map.get(dtype, dtype)
1316
    else:
1317
        return get_computation_dtype(dtype)
1318

1319

1320
class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
1321
    DEFAULT = (0,)
1322
    NO_OPMATH = (1,)
1323
    INT_TO_FLOAT = (2,)
1324
    ALWAYS_BOOL = (3,)
1325
    COMPLEX_TO_FLOAT = (4,)
1326
    BOOL_TO_LONG = (5,)
1327

1328

1329
class REDUCTION_OUTPUT_TYPE_KIND(Enum):
1330
    SAME = (0,)
1331
    COMPLEX_TO_FLOAT = (1,)  # for complex types outputs corresponding real type
1332
    KEEP_PROMOTED_TYPE = (2,)  # keep output in opmath type, needed for mean
1333
    ALWAYS_BOOL = (3,)
1334

1335

1336
# Describes the return type of the primitive:
1337
#
1338
#   - NEW, a new tensor is created
1339
#   - VIEW, a view of an input tensor is returned
1340
#   - INPLACE, one or more input tensors is modified
1341
#
1342
# these descriptors are mutually exclusive and exhaustive.
1343
class RETURN_TYPE(Enum):
1344
    NEW = (0,)
1345
    VIEW = (1,)
1346
    INPLACE = (2,)
1347

1348

1349
# TODO: when NumberType contains the sym types, can simplify this
1350
def number_type(x: Union[NumberType, torch.SymInt, torch.SymFloat]) -> Type:
1351
    if isinstance(x, torch.SymInt):
1352
        return int
1353
    elif isinstance(x, torch.SymFloat):
1354
        return float
1355
    else:
1356
        return type(x)
1357

1358

1359
def expr_type(x: sympy.Expr) -> Type:
1360
    if x.is_integer:  # type: ignore[attr-defined]
1361
        return int
1362
    else:
1363
        # NB: Not strictly correct, but we don't support SymPy complex or bool.
1364
        return float
1365

1366

1367
# TODO: document type promotion kinds
1368
def elementwise_dtypes(
1369
    *_args,
1370
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
1371
) -> Tuple[torch.dtype, torch.dtype]:
1372
    """
1373
    Computes the computation and result dtypes for elementwise type promotion
1374
    on the given arguments and with the given elementwise type promotion kind.
1375

1376
    Note that not all inputs to an elementwise operation necessarily participate in type promotion.
1377
    For example, the "alpha" parameter of torch.add does not participate in type promotion,
1378
    although it may be cast to the Python type corresponding to the computation dtype that
1379
    the type promotion algorithm determines.
1380

1381
    Default elementwise type promotion, which all other type promotion kinds tweak (see below),
1382
    first decides which of four ordered types to use:
1383

1384
    bool -> integer -> floating point -> complex
1385

1386
    The selected type is the "lowest" type in the above list such that all number arguments
1387
    have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
1388
    type for their dtype.
1389

1390
    Once the type is determined, the particular result dtype is found. The dtypes are
1391
    partially ordered as follows:
1392

1393
    bool -> uint8, int8 -> int16 -> int32 -> int64 ->
1394
      float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128
1395

1396
    The result dtype is selected by:
1397
      - if no tensor's dtype has the same corresponding type as the one selected,
1398
          then the result dtype is the (default) dtype corresponding to the selected type
1399
          (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
1400
      - if the result type is complex then the dtype is:
1401
        -  the default complex dtype if there are no floating point or complex tensors
1402
        -  if there are floating point or complex tensors with one or more dimensions, then
1403
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
1404
            (for example, double + cfloat -> cdouble)
1405
        -  if there are only floating point or complex tensors with zero dimensions, then
1406
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
1407
      - if the first two cases do not apply, the result dtype is the highest dtype among
1408
          all tensors with one or more dimensions of the output type, and if there are no such
1409
          tensors then it's the highest dtype among all tensors with zero dimensions of the output type
1410
          (for example, long + half -> half, even if the half tensor has zero dimensions)
1411

1412
    The "corresponding complex dtypes" are:
1413
      float16    -> complex32
1414
      bfloat16   -> complex64
1415
      float32    -> complex64
1416
      float64    -> complex128
1417
      complex32  -> complex32
1418
      complex64  -> complex64
1419
      complex128 -> complex128
1420

1421
    The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation
1422
    dtype by mapping low precision floating point and complex dtypes as follows:
1423

1424
      float16   -> float32
1425
      bfloat16  -> float32
1426
      complex32 -> complex64
1427

1428
    This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the
1429
    computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels
1430
    which perform no mathematical operations on their tensors (see below for examples).
1431

1432
    The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype,
1433
    and computation dtypes to the appropriate op math dtype.
1434

1435
    The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this
1436
    mapping:
1437

1438
        complex32  -> float16
1439
        complex64  -> float32
1440
        complex128 -> float64
1441

1442
    Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does.
1443

1444
    The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long.
1445

1446
    The ALWAYS_BOOL type promotion kind always sets the result dtype to bool.
1447

1448
    Example operators for each type promotion option:
1449
      DEFAULT                 : add
1450
      NO_OPMATH               : where, nextafter, cat
1451
      INT_TO_FLOAT            : sin
1452
      COMPLEX_TO_FLOAT        : abs
1453
      BOOL_TO_LONG            : pow
1454
      ALWAYS_BOOL             : eq
1455

1456
    """
1457

1458
    args = tuple(x for x in _args if x is not None)
1459

1460
    highest_type: type = bool
1461

1462
    # Import sympy locally, as importing it eagerly at a module level is too slow
1463
    # See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589
1464
    import sympy
1465

1466
    for x in args:
1467
        if not isinstance(x, (Number, TensorLike, sympy.Expr)):
1468
            msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
1469
            raise ValueError(msg)
1470

1471
        if isinstance(x, Number):
1472
            highest_type = get_higher_type(highest_type, number_type(x))
1473
        elif isinstance(x, sympy.Expr):
1474
            highest_type = get_higher_type(highest_type, expr_type(x))
1475
        else:
1476
            # x is a TensorLike
1477
            highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype))
1478

1479
    result_dtype = None
1480

1481
    def _find_highest_dtype_filtered(
1482
        args, filter, *, float_as_complex=False
1483
    ) -> Optional[torch.dtype]:
1484
        zero_dim_tensor_dtype = None
1485
        one_plus_dim_tensor_dtype = None
1486
        for x in args:
1487
            if isinstance(x, TensorLike) and filter(x.dtype):
1488
                _dtype = x.dtype
1489
                if float_as_complex and is_float_dtype(_dtype):
1490
                    _dtype = corresponding_complex_dtype(_dtype)
1491
                if x.ndim == 0:
1492
                    zero_dim_tensor_dtype = get_higher_dtype(
1493
                        zero_dim_tensor_dtype, _dtype
1494
                    )
1495
                else:
1496
                    # x.ndim > 0
1497
                    one_plus_dim_tensor_dtype = get_higher_dtype(
1498
                        one_plus_dim_tensor_dtype, _dtype
1499
                    )
1500

1501
        # Prefers dtype of tensors with one or more dimensions
1502
        if one_plus_dim_tensor_dtype is not None:
1503
            return one_plus_dim_tensor_dtype
1504

1505
        return zero_dim_tensor_dtype
1506

1507
    if highest_type is float:
1508
        result_dtype = _find_highest_dtype_filtered(args, is_float_dtype)
1509
        result_dtype = (
1510
            torch.get_default_dtype() if result_dtype is None else result_dtype
1511
        )
1512
    elif highest_type is complex:
1513
        result_dtype = _find_highest_dtype_filtered(
1514
            args,
1515
            lambda x: is_float_dtype(x) or is_complex_dtype(x),
1516
            float_as_complex=True,
1517
        )
1518
        if result_dtype is None:
1519
            result_dtype = corresponding_complex_dtype(torch.get_default_dtype())
1520
    elif highest_type is int:
1521
        result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype)
1522
        result_dtype = torch.long if result_dtype is None else result_dtype
1523
    else:
1524
        # highest_type is bool
1525
        result_dtype = torch.bool
1526

1527
    if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
1528
        return get_computation_dtype(result_dtype), result_dtype
1529
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH:
1530
        return result_dtype, result_dtype
1531
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
1532
        if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype):
1533
            result_dtype = torch.get_default_dtype()
1534
        return get_computation_dtype(result_dtype), result_dtype
1535
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
1536
        # NOTE: computation can still occur in a complex dtype
1537
        computation_dtype = get_computation_dtype(result_dtype)
1538
        if is_complex_dtype(result_dtype):
1539
            result_dtype = corresponding_real_dtype(result_dtype)
1540
        return computation_dtype, result_dtype
1541
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG:
1542
        if is_boolean_dtype(result_dtype):
1543
            return torch.long, torch.long
1544
        return get_computation_dtype(result_dtype), result_dtype
1545
    elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
1546
        return get_computation_dtype(result_dtype), torch.bool
1547
    else:
1548
        raise ValueError(f"Unknown type promotion kind {str(type_promotion_kind)}")


def reduction_dtypes(
    arg,
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.dtype, Optional[torch.dtype]]:
    # even though some reductions, like amin or amax, don't strictly require type promotion,
    # all the math ops (including comparisons) are still defined only for a computation type,
    # so promotion will still happen. We are doing it explicitly here
    inp_dtype = dtype if dtype is not None else arg.dtype
    computation_dtype = get_computation_dtype(inp_dtype)
    if (
        output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME
        or output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
    ):
        result_dtype = dtype if dtype else arg.dtype
        if (
            output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
            and is_complex_dtype(result_dtype)
        ):
            result_dtype = corresponding_real_dtype(result_dtype)
    elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE:
        result_dtype = None
    else:  # ALWAYS_BOOL
        result_dtype = torch.bool
    return computation_dtype, result_dtype
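
# For instance (illustrative; tensors are named by their dtype and assume the usual
# float16 -> float32 computation upcast):
#   reduction_dtypes(float16_tensor, REDUCTION_OUTPUT_TYPE_KIND.SAME)
#       -> (torch.float32, torch.float16)
#   reduction_dtypes(complex64_tensor, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT)
#       -> (torch.complex64, torch.float32)
#   reduction_dtypes(float16_tensor, REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE)
#       -> (torch.float32, None)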


# This function's logic is borrowed from the following functions defined in C++:
# batched_matrix_contiguous_strides and contiguous_strides
def make_contiguous_strides_for(
    shape: ShapeType, row_major: bool = True
) -> Tuple[int, ...]:
    """
    Returns the strides of a contiguous tensor if row_major=True.
    If row_major=False, it returns the strides of a contiguous batch of Fortran-contiguous matrices.
    This is often used when calling external libraries like BLAS/LAPACK/cuSolver...
    """
    # contiguous_strides from c10/util/strides.h
    validate_shape(shape)
    if not shape:
        return ()

    from torch.fx.experimental.symbolic_shapes import is_nested_int

    multiplier = 1
    strides = []
    for l in reversed(shape):
        strides.append(multiplier)
        multiplier *= l if is_nested_int(l) else sym_max(l, 1)

    result = tuple(reversed(strides))

    # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h
    if row_major:
        return result
    else:
        if len(shape) < 2:
            return result
        return result[:-2] + (1, max(shape[-2], 1))
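
# Example (illustrative): for a batch of two 3x4 matrices,
#   make_contiguous_strides_for((2, 3, 4))                  == (12, 4, 1)  # C-contiguous
#   make_contiguous_strides_for((2, 3, 4), row_major=False) == (12, 1, 3)  # Fortran-contiguous matrices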


def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    torch._check(
        len(shape) == 3,
        lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
    )

    multiplier = 1
    strides = [0] * 3
    for idx in (1, -1, 0):
        # NOTE: intentional divergence from make_contiguous_strides_for
        # This is consistent with eager
        strides[idx] = multiplier
        multiplier *= shape[idx]

    return tuple(strides)
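
# Example (illustrative): for shape (N, C, L) == (2, 3, 4) the channel stride is 1, so
# make_channels_last_1d_strides_for((2, 3, 4)) == (12, 1, 3).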


def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
    torch._check(
        len(shape) == 4,
        lambda: "Only tensors of rank 4 can use the channels_last memory format",
    )

    multiplier = 1
    strides = [0] * 4
    for idx in (1, -1, -2, 0):
        # NOTE: intentional divergence from make_contiguous_strides_for
        # This is consistent with eager
        strides[idx] = multiplier
        multiplier *= shape[idx]

    return tuple(strides)
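
# Example (illustrative): for shape (N, C, H, W) == (2, 3, 4, 5),
# make_channels_last_2d_strides_for((2, 3, 4, 5)) == (60, 1, 15, 3),
# i.e. the strides of an NHWC-laid-out tensor addressed as NCHW.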


def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    torch._check(
        len(shape) == 5,
        lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
    )

    multiplier = 1
    strides = [0] * 5
    for idx in (1, -1, -2, -3, 0):
        # NOTE: intentional divergence from make_contiguous_strides_for
        # This is consistent with eager
        strides[idx] = multiplier
        multiplier *= shape[idx]

    return tuple(strides)
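
# Example (illustrative): for shape (N, C, D, H, W) == (2, 3, 4, 5, 6),
# make_channels_last_3d_strides_for((2, 3, 4, 5, 6)) == (360, 1, 90, 18, 3),
# i.e. the strides of an NDHWC-laid-out tensor addressed as NCDHW.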


def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    ndim = len(shape) if isinstance(shape, Sequence) else 1
    if ndim == 3:
        return make_channels_last_1d_strides_for(shape)
    elif ndim == 4:
        return make_channels_last_2d_strides_for(shape)
    elif ndim == 5:
        return make_channels_last_3d_strides_for(shape)
    else:
        raise RuntimeError(
            f"no channels last format strides exist in {ndim} dimensions"
        )


def compute_reduction_output_shape(
    shape: ShapeType, dimensions: Sequence
) -> Tuple[int, ...]:
    for idx in dimensions:
        validate_idx(len(shape), idx)

    new_shape = []
    for idx in range(len(shape)):
        if idx in dimensions:
            continue

        new_shape.append(shape[idx])

    return tuple(new_shape)
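
# Example (illustrative): reducing dims (0, 2) of a (2, 3, 4) shape keeps only dim 1,
# so compute_reduction_output_shape((2, 3, 4), (0, 2)) == (3,).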


def validate_no_repeating_dims(dims: Sequence):
    if len(dims) != len(set(dims)):
        raise RuntimeError("duplicate value in the list of dims")


def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]:
    if dims is None:
        return tuple(range(len(shape)))
    dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims)
    validate_no_repeating_dims(dims)
    return dims
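
# Example (illustrative): negative dims are canonicalized and None means "all dims":
#   reduction_dims((2, 3, 4), (-1, 0)) == (2, 0)
#   reduction_dims((2, 3, 4), None)    == (0, 1, 2)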


def set_correction(
    unbiased: Optional[bool] = None,
    correction: Optional[NumberType] = None,
) -> float:
    if correction is not None and unbiased is not None:
        raise RuntimeError("cannot specify both correction and unbiased arguments")
    elif correction is None and unbiased is None:
        correction = 1.0
    elif correction is None and unbiased is not None:
        correction = 0.0 if unbiased is False else 1.0
    # NB: we don't actually support symint here, but it's harmless to accept
    if not isinstance(correction, (IntLike, FloatLike)):
        raise ValueError("correction argument should be integer or float")
    if correction < 0:
        raise ValueError("correction argument should be non-negative")
    return sym_float(correction)
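
# Example (illustrative): the correction used by var/std-style reductions:
#   set_correction()                == 1.0  # default is the unbiased (Bessel) correction
#   set_correction(unbiased=False)  == 0.0
#   set_correction(correction=2)    == 2.0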


def compute_required_storage_length(
    shape: ShapeType, strides: StrideType, storage_offset: int
) -> int:
    """Computes the minimum storage size to hold the given tensor geometry.

    Example
    =======

    This is the size of a newly allocated tensor's storage, in units of elements

    >>> t = torch.empty((10, 20))
    >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset())
    200

    >>> # xdoctest: +SKIP(failing)
    >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11))
    >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset())
    >>> size == t2.storage().size()
    True

    A valid tensor may have a larger storage size, but never smaller

    >>> slice = torch.empty(100)[20:40]
    >>> slice.storage().size()
    100

    >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset())
    40

    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    # Short-circuits if the shape has no elements
    if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0):
        return 0

    max_offset = sum((x - 1) * y for x, y in zip(shape, strides))
    # +1 to account for the first element, from which the offsets are measured
    return 1 + storage_offset + max_offset
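
# Worked example of the formula above (illustrative): for shape (2, 3), strides (3, 1)
# and storage_offset 0, max_offset == 1 * 3 + 2 * 1 == 5, so the required storage
# length is 1 + 0 + 5 == 6 elements.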


def check_in_bounds_for_storage(
    a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
):
    """
    Validates that the given shape, strides, and offset fit within the given storage,
    raising a ValueError otherwise.
    """

    required_length = compute_required_storage_length(shape, strides, storage_offset)
    if a.size() < required_length:
        msg = (
            "Can't view a storage of size {} with an offset of {}, shape of {}, and strides of {}, "
            "which requires a storage of size {}".format(
                a.size(), storage_offset, str(shape), str(strides), required_length
            )
        )
        raise ValueError(msg)


# NOTE: This function should ideally be removed, but some Meta internal models
# packaged with `torch.package` are using it, so it will have to be removed
# at some point in the future when those models no longer use this function.
def check(
    b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
) -> None:
    """
    Helper function for raising an exception of type exc_type (default: RuntimeError)
    if a boolean condition fails. The error message is a callable producing a string
    (to avoid wasting time on string formatting in the non-error case, and also to
    make it easier for torchdynamo to trace).

    .. note:: This function is planned for removal in the future. Please use
        `torch._check*` functions instead.
    """
    warnings.warn(
        DeprecationWarning(
            "'torch._prims_common.check' will be removed in the future. Please use "
            "'torch._check*' functions instead"
        )
    )
    torch._check_with(exc_type, b, s)


# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in
# c10/core/MemoryFormat.h into one function
def are_strides_like_channels_last(
    shape: Sequence[int], strides: Sequence[int]
) -> bool:
    ndim = len(shape)

    if ndim == 4:
        # Check for channels_last_2d
        dim_order = [1, 3, 2, 0]
    elif ndim == 5:
        # Check for channels_last_3d
        dim_order = [1, 4, 3, 2, 0]
    else:
        return False

    if strides[1] == 0:
        return False

    min = 0
    for d in dim_order:
        if shape[d] == 0:
            return False
        if strides[d] < min:
            return False
        if d == 0 and min == strides[1]:
            return False
        min = strides[d]
        if strides[d] > 1:
            min *= shape[d]
    return True
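
# Example (illustrative): for shape (N, C, H, W) == (2, 3, 4, 5),
#   are_strides_like_channels_last((2, 3, 4, 5), (60, 1, 15, 3)) is True   # NHWC-style strides
#   are_strides_like_channels_last((2, 3, 4, 5), (60, 20, 5, 1)) is False  # contiguous strides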


def suggest_memory_format(x: TensorLikeType) -> torch.memory_format:
    if x.layout != torch.strided:
        return torch.contiguous_format

    if are_strides_like_channels_last(x.shape, x.stride()):
        return torch.channels_last if x.ndim == 4 else torch.channels_last_3d

    return torch.contiguous_format
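
# Example (illustrative):
#   suggest_memory_format(torch.empty(2, 3, 4, 5)) is torch.contiguous_format
#   suggest_memory_format(
#       torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
#   ) is torch.channels_last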


def prod(xs: Sequence[NumberType]) -> NumberType:
    """Product of elements in input sequence. Returns 1 for empty sequence"""
    return reduce(operator.mul, xs, 1)


def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool:
    """Checks if a shape can be expanded to another shape.
    This is equivalent to checking if the two shapes are broadcastable.
    """
    # This is a Python implementation of
    # aten/src/ATen/ExpandUtils.h:is_expandable_to
    if len(shape) > len(desired):
        return False
    for i in range(len(shape)):
        if shape[-i - 1] != desired[-i - 1] and shape[-i - 1] != 1:
            return False
    return True
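
# Example (illustrative): trailing dimensions must match or be 1:
#   is_expandable_to((3, 1), (2, 3, 4)) is True
#   is_expandable_to((3, 2), (2, 3, 4)) is False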


def mask_tensor(mask: TensorLikeType, t: TensorLikeType):
    """
    Similar to torch.where(mask, t, 0) but if t is boolean,
    result is also boolean and not promoted to int.
    """
    # torch.where(mask, t, False) is equivalent
    # but feels hacky and might break in the future
    if t.dtype is torch.bool:
        return mask.logical_and(t)
    else:
        return torch.where(mask, t, 0)
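
# Example (illustrative): mask_tensor preserves t's dtype, e.g. a bool t yields
# mask.logical_and(t) (still bool), while a float t yields torch.where(mask, t, 0)
# (still float, since the scalar 0 promotes to t's dtype).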


def get_aten_op(fn: Callable, name: str):
    """
    Given a reference function's __module__ and its name, returns
    (our best guess of) the ATen name of the associated operation.

    Note: In ATen, the __name__ of a function within a module often
    starts with the module name. E.g. linalg_eigh, or special_zeta
    """
    module = fn.__module__
    prefix = "torch._refs"
    assert module.startswith(prefix)
    module = module[len(prefix) :]
    # We want to go from .special / .nn.functional
    # to special and special_ / nn_functional_
    if module:
        module = module[1:]
        module = module.replace(".", "_")
        module = module + "_"
    return getattr(torch._ops.ops.aten, f"{module}{name}")
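
# Example (illustrative, assuming the reference lives under torch._refs):
#   a ref defined in torch._refs with name "add"          -> torch.ops.aten.add
#   a ref defined in torch._refs.special with name "zeta" -> torch.ops.aten.special_zeta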


def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype:
    return dtype if dtype is not None else torch.get_default_dtype()


def device_or_default(device: Optional[DeviceLikeType]) -> DeviceLikeType:
    return device if device is not None else torch.device("cpu")


def layout_or_default(layout: Optional[torch.layout]) -> torch.layout:
    return layout if layout is not None else torch.strided


def clone_preserve_strides(x):
    needed_size = compute_required_storage_length(
        x.size(), x.stride(), x.storage_offset()
    )
    # Our eager implementations for *_scatter ops are all primitives w.r.t autograd,
    # so these as_strided() calls are not seen by autograd.
    # We need to mimic this behavior in our ref/prim implementations.
    # TODO: a better way to handle this would be with a new op, "_unsafe_as_strided"
    # We should revisit this when we add a compositional as_strided op,
    # and also as part of https://github.com/pytorch/pytorch/issues/90507
    try:
        old = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView
        )
        torch._C._dispatch_tls_set_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView, True
        )
        buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone()
        return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
    finally:
        torch._C._dispatch_tls_set_dispatch_key_excluded(
            torch._C.DispatchKey.ADInplaceOrView, old
        )


def alert_not_deterministic(caller: str):
    if torch.are_deterministic_algorithms_enabled():
        if torch.is_deterministic_algorithms_warn_only_enabled():
            warnings.warn(
                f"{caller} does not have a deterministic implementation, but you set "
                f"'torch.use_deterministic_algorithms(True, warn_only=True)'. "
                f"You can file an issue at https://github.com/pytorch/pytorch/issues "
                f"to help us prioritize adding deterministic support for this operation."
            )
        else:
            torch._check(
                False,
                lambda: (
                    f"{caller} does not have a deterministic implementation, but you set "
                    f"'torch.use_deterministic_algorithms(True)'. You can turn off "
                    f"determinism just for this operation, or you can use the "
                    f"'warn_only=True' option, if that's acceptable for your application. "
                    f"You can also file an issue at https://github.com/pytorch/pytorch/issues "
                    f"to help us prioritize adding deterministic support for this operation."
                ),
            )


class CUDARngStateHelper:
    @staticmethod
    def get_torch_state_as_tuple(fake_mode=nullcontext()):
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")

        with fake_mode:
            seed = torch.tensor(torch.cuda.initial_seed())
            offset = torch.tensor(torch.cuda._get_rng_state_offset())
            return seed, offset

    @staticmethod
    def set_torch_state_tensor(seed, offset):
        # Rng state is [64-bit seed, 64-bit offset]
        seed_portion = seed.reshape([1]).view(torch.uint8)
        offset_portion = offset.reshape([1]).view(torch.uint8)
        new_state = torch.cat([seed_portion, offset_portion])
        torch.cuda.set_rng_state(new_state)

    @staticmethod
    def set_new_offset(relative_offset):
        torch.cuda._set_rng_state_offset(relative_offset.item())