# pytorch · Fork 0 · 1996 lines · 64.0 KB
# mypy: allow-untyped-defs
from __future__ import annotations

import operator
import warnings
from contextlib import nullcontext
from enum import Enum
from functools import reduce
from typing import (
    Any,
    Callable,
    cast,
    List,
    NamedTuple,
    Optional,
    overload,
    Sequence,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import deprecated, TypeAlias


if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.

    import sympy

import torch
from torch import sym_float, sym_int, sym_max
35

36

37
# Static type aliases for shapes, strides and dimension arguments.
ShapeType: TypeAlias = Union[torch.Size, List[int], Tuple[int, ...]]
StrideType: TypeAlias = Union[List[int], Tuple[int, ...]]
DimsType: TypeAlias = Union[int, List[int], Tuple[int, ...]]
DimsSequenceType: TypeAlias = Union[List[int], Tuple[int, ...]]
# TODO: Type[torch.SymInt], Type[torch.SymFloat]
NumberTypeType: TypeAlias = Union[Type[bool], Type[int], Type[float], Type[complex]]
# TODO: This needs a lot more type annotations
# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
NumberType: TypeAlias = Union[bool, int, float, complex]
RealNumberType: TypeAlias = Union[bool, int, float]

# Runtime tuples of classes, usable as the second argument of isinstance().
Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat, torch.SymBool)
# I don't call it Integral because numbers.Integral includes bool, but IntLike
# does not
Dim = int
IntLike = (int, torch.SymInt)
FloatLike = (float, torch.SymFloat)
BoolLike = (bool, torch.SymBool)
IntWithoutSymInt = int
FloatWithoutSymFloat = float
DeviceLikeType: TypeAlias = Union[str, torch.device, int]
Tensor = torch.Tensor
59

60

61
torch_function_passthrough = {
62
    torch.device,
63
    torch.sym_not,
64
    torch.sym_float,
65
    torch.sym_int,
66
    torch.sym_max,
67
    torch.sym_min,
68
    torch._sym_sqrt,  # type: ignore[attr-defined]
69
    torch.sym_ite,
70
    torch.Tensor.dim,
71
    torch.Tensor.ndim.__get__,  # type: ignore[attr-defined]
72
    torch.Tensor.numel,
73
    torch.Tensor.size,
74
    torch.Tensor.storage_offset,
75
    torch.Tensor.stride,
76
    torch.Tensor.dtype.__get__,  # type: ignore[attr-defined]
77
    torch.Tensor.is_sparse.__get__,  # type: ignore[attr-defined]
78
    torch.Tensor.shape.__get__,  # type: ignore[attr-defined]
79
    torch.Tensor.device.__get__,  # type: ignore[attr-defined]
80
    torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
81
    torch.Tensor.layout.__get__,  # type: ignore[attr-defined]
82
    torch.Tensor.is_contiguous,
83
    # For TorchRefsMode only
84
    torch.Tensor.__format__,
85
    torch.Tensor.__repr__,
86
    torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
87
    torch.Tensor.__getitem__,
88
}
89

90

91
# Aliases for "tensor-like" objects; both currently resolve to torch.Tensor.
TensorLikeType = torch.Tensor
TensorLike = torch.Tensor
TensorSequenceType: TypeAlias = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType]

CustomOutParamAnnotation = "__custom_out_param__"
97

98

99
def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool:
100
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
101

102
    if len(a) != len(b):
103
        return False
104

105
    for x, y in zip(a, b):
106
        if allow_rhs_unbacked:
107
            # TODO: We should check that the symbols are consistent
108
            # with each other
109
            if isinstance(y, torch.SymInt):
110
                continue
111
        # NB: Naively, you would not expect to have to do an oblivious guard
112
        # here because there is seemingly no broadcasting here, but in fact we
113
        # use this in some situations to determine if we need to do an expand
114
        # on the tensor because they don't line up, so you can definitely end
115
        # up trying to prove u0 != 1 in this situation.  See
116
        # python test/test_proxy_tensor.py -k test_cumsum_unbacked
117
        if guard_size_oblivious(x != y):
118
            return False
119

120
    return True
121

122

123
def _maybe_get_pytype(t):
124
    if t is torch.SymFloat:
125
        return float
126
    elif t is torch.SymInt:
127
        return int
128
    elif t is torch.SymBool:
129
        return bool
130
    else:
131
        return t
132

133

134
# TODO: look at using torch.testing.assert_close instead with an option
135
#   to just compare metadata
136
def compare_tensor_meta(
    a: TensorLikeType,
    b: TensorLikeType,
    check_strides=False,
    *,
    allow_rhs_unbacked=False,
    check_conj=True,
):
    """
    Checks that two tensor likes have the same shape, dtype and device,
    and the same neg bit. Optionally also compares strides and storage
    offset (check_strides) and the conj bit (check_conj).

    Args:
        a, b: the tensors to compare.
        check_strides: when True, compare significant strides (via
            check_significant_strides) and storage offsets.
        allow_rhs_unbacked: forwarded to same_shape.
        check_conj: when True, compare is_conj().

    Raises:
        AssertionError: if a or b is not a TensorLike, or on a shape,
            dtype or device mismatch.
        RuntimeError: on a stride, storage offset, conj or neg mismatch.
    """
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    if not same_shape(a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked):
        msg = f"Shapes {a.shape} and {b.shape} are not equal!"
        raise AssertionError(msg)

    if a.dtype != b.dtype:
        msg = f"Dtypes {a.dtype} and {b.dtype} are not equal!"
        raise AssertionError(msg)

    if a.device != b.device:
        # Handles special cuda:0 vs cuda case
        # TODO: we should review why this happens and see about fixing it
        cuda_aliases = ("cuda:0", "cuda")
        if not (str(a.device) in cuda_aliases and str(b.device) in cuda_aliases):
            msg = f"Devices {a.device} and {b.device} are not equal!"
            raise AssertionError(msg)

    # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050
    if check_strides:
        same_strides, idx = check_significant_strides(a, b)
        if not same_strides:
            msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!"
            raise RuntimeError(msg)

        if a.storage_offset() != b.storage_offset():
            msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!"
            raise RuntimeError(msg)

    if check_conj and a.is_conj() != b.is_conj():
        raise RuntimeError(
            f"Conj mismatch! is_conj is set to {a.is_conj()} and {b.is_conj()}"
        )

    # The neg bit is compared unconditionally.
    if a.is_neg() != b.is_neg():
        raise RuntimeError(
            f"Neg mismatch! is_neg is set to {a.is_neg()} and {b.is_neg()}"
        )
194

195

196
def _check_strides_helper(
197
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True
198
) -> Tuple[bool, Optional[int]]:
199
    # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch
200
    # See https://github.com/pytorch/pytorch/issues/77553
201
    # Only compares strides that are "meaningful" -- strides for dimensions with length > 1
202
    # and for tensors with more than one element
203
    if (
204
        not only_cuda or a.device.type == "cuda" or b.device.type == "cuda"
205
    ) and a.numel() > 0:
206
        for idx in range(a.ndim):
207
            check = not significant_only or a.shape[idx] > 1
208
            if a.stride()[idx] != b.stride()[idx] and check:
209
                return False, idx
210

211
    return True, None
212

213

214
def check_significant_strides(
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
) -> Tuple[bool, Optional[int]]:
    """
    Compares the strides of a and b, ignoring dimensions of length <= 1.
    Returns (True, None) on a match or (False, idx) for the first mismatch.
    """
    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True)
218

219

220
def check_all_strides(
    a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
) -> Tuple[bool, Optional[int]]:
    """
    Compares the strides of a and b on every dimension.
    Returns (True, None) on a match or (False, idx) for the first mismatch.
    """
    return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)
224

225

226
# This function is equivalent to compute_contiguous() from TensorImpl.cpp
227
def is_contiguous(a: TensorLikeType) -> bool:
228
    """
229
    Tests whether a tensor is contiguous or not.
230

231
    Tensors are contiguous when they have no elements,
232
    one element, or when they have "nested" strides.
233
    """
234
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
235

236
    if guard_size_oblivious(a.numel() < 2):
237
        return True
238

239
    expected_stride = 1
240
    for x, y in reversed(tuple(zip(a.shape, a.stride()))):
241
        # Skips checking strides when a dimension has length 1
242
        if guard_size_oblivious(x == 1):
243
            continue
244

245
        if guard_size_oblivious(y != expected_stride):
246
            return False
247
        expected_stride = expected_stride * x
248

249
    return True
250

251

252
# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
253
def is_channels_last_contiguous_2d(a: Tensor) -> bool:
254
    # NHWC or not channels last 2D contiguous
255
    if a.ndim != 4:
256
        return False
257

258
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
259

260
    expected_stride = 1
261
    for idx in (1, 3, 2, 0):
262
        length = a.shape[idx]
263
        if guard_size_oblivious(length == 1):
264
            continue
265

266
        stride = a.stride()[idx]
267
        if guard_size_oblivious(stride != expected_stride):
268
            return False
269

270
        expected_stride *= length
271

272
    return True
273

274

275
def is_channels_last_contiguous_3d(a: Tensor) -> bool:
276
    # NDHWC or not channels last 3D contiguous
277
    if a.ndim != 5:
278
        return False
279

280
    expected_stride = 1
281
    for idx in (1, 4, 3, 2, 0):
282
        length = a.shape[idx]
283
        if length == 1:
284
            continue
285

286
        stride = a.stride()[idx]
287
        if stride != expected_stride:
288
            return False
289

290
        expected_stride *= length
291

292
    return True
293

294

295
_memory_formats = {
296
    torch.contiguous_format,
297
    torch.preserve_format,
298
    torch.channels_last,
299
    torch.channels_last_3d,
300
}
301

302

303
def validate_memory_format(memory_format: torch.memory_format):
    """
    Checks (via torch._check) that memory_format is one of the formats
    listed in _memory_formats.
    """
    torch._check(
        memory_format in _memory_formats,
        lambda: f"Received unknown memory format {memory_format}!",
    )
308

309

310
def is_contiguous_for_memory_format(  # type: ignore[return]
    a: Tensor, *, memory_format: torch.memory_format
) -> bool:
    """
    Returns whether a is contiguous in the given memory format.

    Supports torch.contiguous_format, torch.channels_last and
    torch.channels_last_3d. Any other format that passes
    validate_memory_format (torch.preserve_format) fails the final
    torch._check.
    """
    validate_memory_format(memory_format)

    if memory_format == torch.contiguous_format:
        return is_contiguous(a)
    if memory_format == torch.channels_last:
        return is_channels_last_contiguous_2d(a)
    if memory_format == torch.channels_last_3d:
        return is_channels_last_contiguous_3d(a)

    # Valid but unsupported format
    torch._check(
        False,
        lambda: f"is_contiguous received unsupported memory format {memory_format}",
    )
326

327

328
# NOTE: that tensors with no elements and channels last is ???
329
def is_channels_last_contiguous(a: Tensor) -> bool:
    """
    True when a tensor is channels-last contiguous.

    This requires that:

      - the tensor is conceptually either 4 (NHWC) or 5 (NDHWC) dimensions
      - if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the
        stride of the 'C' dimension (Cs) is 1 and the strides corresponding to
        each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are
        "nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension,
        for example.
    """
    # The 2D check only accepts rank 4 and the 3D check only rank 5
    return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)
343

344

345
def is_non_overlapping_and_dense(a: Tensor) -> bool:
    """
    True when a tensor is non-overlapping and dense.

    A tensor is non-overlapping and dense when there exists a permutation of
    its dimensions that is contiguous.

    Sparse tensors always return False.
    """

    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if a.is_sparse:
        return False

    # Short-circuits if the tensor is already contiguous or channels-last contiguous
    if is_contiguous(a) or is_channels_last_contiguous(a):
        return True

    # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp

    # Short-circuits for tensors of rank one, which are
    # non-overlapping and "dense" if their stride is one
    if a.ndim == 1:
        return a.stride()[0] == 1

    # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
    # Sorts (length, stride) pairs by stride
    #
    # This sort is done in a size-oblivious way, which helps if we do a
    # comparison like 2048*u0 > u0; we just want this to return True
    # (and not worry about what if u0 is zero).
    class K(NamedTuple):
        size: int
        stride: int

        # All comparisons look only at the stride field.
        def __lt__(self, other):
            return guard_size_oblivious(self.stride < other.stride)

        def __gt__(self, other):
            return guard_size_oblivious(self.stride > other.stride)

        def __le__(self, other):
            return guard_size_oblivious(self.stride <= other.stride)

        def __ge__(self, other):
            return guard_size_oblivious(self.stride >= other.stride)

        def __eq__(self, other):
            return guard_size_oblivious(self.stride == other.stride)

    lengths_and_strides = sorted(map(K, a.shape, a.stride()))

    # With pairs ordered by ascending stride, each stride must equal the
    # product of the lengths of the preceding (non-unit) dimensions.
    expected_stride = 1
    for length, stride in lengths_and_strides:
        if guard_size_oblivious(length == 1):
            continue

        if stride != expected_stride:
            return False

        expected_stride *= length

    return True
407

408

409
# NOTE: Based on the implementation in TensorIterator.cpp, but note that
410
# the note [Computing output strides] is incorrect, because it
411
# says that strides will be preserved even if they are not
412
# "non overlapping and dense", but this is incorrect. The
413
# output of elementwise operations are always given
414
# non overlapping and dense strides.
415
# This is also INCORRECT because it does not model TensorIterator's
416
# short-circuit, which can cause different strides.
417
def compute_elementwise_output_logical_to_physical_perm(
418
    *tensors, _skip_checks=False
419
) -> List[int]:
420
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
421

422
    if not _skip_checks and len(tensors) == 0:
423
        msg = "Can't compute elementwise output strides for zero tensors!"
424
        raise ValueError(msg)
425

426
    if not _skip_checks:
427
        check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
428

429
    # Filters the tensors to actual tensors
430
    if not _skip_checks:
431
        tensors = tuple(
432
            a
433
            for a in tensors
434
            if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
435
        )
436

437
    # Short-circuits for CPU scalar case
438
    if len(tensors) == 0:
439
        return []
440

441
    # Short-circuits for shapes with zero or one dimensions
442
    # TODO: are these necessary?
443
    ndim = tensors[0].ndim
444
    if ndim == 0:
445
        return []
446
    if ndim == 1:
447
        return [0]
448

449
    # Short-circuits if contiguous or channels last, following the fake fast path.
450
    # This reduces the number of guards we end up making
451
    is_contiguous = True
452
    is_channels_last = True
453
    for t in tensors:
454
        is_contiguous = is_contiguous and t.is_contiguous(
455
            memory_format=torch.contiguous_format
456
        )
457
        is_channels_last = is_channels_last and t.is_contiguous(
458
            memory_format=torch.channels_last
459
        )
460

461
    if is_contiguous and not is_channels_last:
462
        return list(range(ndim))
463

464
    if is_channels_last and not is_contiguous:
465
        return [0, *list(range(2, ndim)), 1]
466

467
    shape = tensors[0].shape
468

469
    def should_swap(idx_a, idx_b):
470
        for tensor in tensors:
471
            stride_a = tensor.stride()[idx_a]
472
            stride_b = tensor.stride()[idx_b]
473

474
            if guard_size_oblivious(stride_a == 0) or guard_size_oblivious(
475
                stride_b == 0
476
            ):
477
                continue
478

479
            if guard_size_oblivious(stride_a < stride_b):
480
                return -1
481

482
            if guard_size_oblivious(stride_a > stride_b):
483
                return 1
484

485
            # stride_a == stride_b
486
            if guard_size_oblivious(shape[idx_a] > shape[idx_b]):
487
                return 1
488

489
        # Note: this case is hit if all strides are zero,
490
        # or all strides are equal and all dimensions have the same length
491
        return 0
492

493
    # The "sort" order for the permutation is back-to-front, but
494
    # the natural order for permutations is front-to-back.  Do the
495
    # sorting back-to-front and then reverse it on output.
496
    #
497
    # also, note this returns the logical to physical shape permutation
498
    perm = list(reversed(range(ndim)))
499

500
    # insertion sort with support for ambiguous comparisons
501
    for i in range(1, ndim):
502
        dim1 = i
503
        for dim0 in reversed(range(i)):
504
            comparison = should_swap(perm[dim0], perm[dim1])
505
            if comparison > 0:
506
                perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
507
                dim1 = dim0
508
            elif comparison < 0:
509
                break
510

511
    return list(reversed(perm))
512

513

514
def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
    """
    Computes the output strides for elementwise operations.

    Arguments are validated with check_same_shape; CPU scalar tensors and
    non-tensor arguments are then ignored. Returns () when no tensors
    remain or they are 0-dimensional, and (1,) for 1-dimensional tensors.

    Raises:
        ValueError: if called with no arguments.
    """
    if len(tensors) == 0:
        msg = "Can't compute elementwise output strides for zero tensors!"
        raise ValueError(msg)

    check_same_shape(*tensors, allow_cpu_scalar_tensors=True)

    # Filters the tensors to actual tensors
    tensors = tuple(
        a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
    )

    # Short-circuits for CPU scalar case
    if len(tensors) == 0:
        return ()

    ndim = tensors[0].ndim
    shape = tensors[0].shape

    if ndim == 0:
        return ()
    if ndim == 1:
        return (1,)

    # Inputs were validated and filtered above, so the helper can skip that
    logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm(
        *tensors, _skip_checks=True
    )
    permuted_shape = apply_perm(shape, logical_to_physical_perm)  # to physical

    # Contiguous strides in physical order, mapped back to logical order
    new_strides = make_contiguous_strides_for(permuted_shape)
    permuted_strides = apply_perm(
        new_strides, invert_perm(logical_to_physical_perm)
    )  # to logical

    return tuple(permuted_strides)
552

553

554
# Identity permutation is [0, 1, 2]
555
def apply_perm(inp, perm):
    """
    Returns a new list whose element at position idx is inp[perm[idx]].

    The result always has len(inp) entries; positions not covered by perm
    keep the placeholder value -1.
    """
    ndim = len(inp)
    permuted_inp = [-1] * ndim
    for idx, x in enumerate(perm):
        permuted_inp[idx] = inp[x]
    return permuted_inp
561

562

563
def invert_perm(perm):
    """
    Returns the inverse of permutation perm as a list, i.e. a list inv
    with inv[perm[idx]] == idx for every position idx.
    """
    ndim = len(perm)
    new_perm = [-1] * ndim
    for idx, x in enumerate(perm):
        new_perm[x] = idx
    return new_perm
569

570

571
#
572
# Common helper functions
573
#
574

575

576
def validate_dim_length(length: int):
577
    """
578
    Validates that an object represents a valid
579
    dimension length.
580
    """
581

582
    if isinstance(length, (int, torch.SymInt)):
583
        torch._check_is_size(length)
584
    else:
585
        # sometimes called with sympy expression by inductor
586
        assert length >= 0
587

588

589
def validate_shape(shape: ShapeType):
    """
    Validates that a sequence represents a valid shape.

    Asserts that shape is a Sequence and runs validate_dim_length on
    each of its entries.
    """

    assert isinstance(shape, Sequence), type(shape)
    for length in shape:
        validate_dim_length(length)
597

598

599
def validate_strides(strides: StrideType):
    """
    Verifies the object specifies valid strides.

    Asserts that strides is a Sequence and that every stride is >= 0.
    """

    assert isinstance(strides, Sequence)
    for stride in strides:
        assert stride >= 0
607

608

609
def validate_idx(rank: int, idx: int):
    """
    Validates that idx is a valid index for a tensor of the given rank.
    Assumes the index is already canonicalized.

    idx must satisfy 0 <= idx < rank; idx == 0 is additionally accepted
    for any rank (including rank 0).
    """

    assert isinstance(idx, Dim)
    assert isinstance(rank, Dim)

    # Parenthesized to make the and/or precedence explicit
    assert (idx >= 0 and idx < rank) or idx == 0
619

620

621
def validate_dimension_indices(rank: int, indices: DimsSequenceType):
    """
    Runs validate_idx(rank, idx) on every idx in indices.
    """
    for idx in indices:
        validate_idx(rank, idx)
624

625

626
def validate_exclusive_idx(rank: int, ex_idx: int):
    """
    Validates that ex_idx is a valid exclusive index
    for a tensor of the given rank, i.e. 0 < ex_idx <= rank.
    """

    assert isinstance(ex_idx, Dim)
    assert isinstance(rank, Dim)
    assert ex_idx > 0 and ex_idx <= rank
635

636

637
# "Wraps" a dim (up to one time) for the given rank, allowing dims to be
638
# specified using negative indices. If `wrap_scalar` is true then scalar
639
# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise,
640
# idx should be in the range [-rank, rank-1].
641
def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
    """
    Maps a possibly negative dimension index to its non-negative equivalent.

    Args:
        rank: number of dimensions; must be non-negative.
        idx: the index to wrap; valid values lie in [-rank, rank - 1].
        wrap_scalar: when True, rank 0 is treated like rank 1, so idx may
            be -1 or 0.

    Raises:
        IndexError: if rank is negative, if rank is 0 and wrap_scalar is
            False, or if idx is out of range.
    """
    if rank < 0:
        msg = f"Rank cannot be negative but got {rank}"
        raise IndexError(msg)

    if rank == 0:
        if not wrap_scalar:
            msg = f"Dimension specified as {idx} but tensor has no dimensions"
            raise IndexError(msg)
        rank = 1

    # Already canonical
    if idx >= 0 and idx < rank:
        return idx

    # Wraps at most once
    _idx = idx + rank if idx < 0 else idx

    if _idx < 0 or _idx >= rank:
        # Same error message as in aten/src/ATen/WrapDimUtils.h:49
        msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})"
        raise IndexError(msg)

    return _idx
666

667

668
# Takes a dimension or sequence of dimensions and "wraps" them,
669
# mapping negative offsets to positive ones
670
@overload
def canonicalize_dims(
    rank: int, indices: Sequence[int], wrap_scalar: bool = True
) -> Tuple[int, ...]:
    pass


@overload
def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int:
    pass


def canonicalize_dims(rank, indices, wrap_scalar=True):
    """
    Applies canonicalize_dim to a single dimension or to each dimension
    of a sequence.

    Returns an int when indices is an int, otherwise a tuple of ints.
    """
    if isinstance(indices, Dim):
        return canonicalize_dim(rank, indices, wrap_scalar)

    return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices)
687

688

689
def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool:
    """
    Validates that perm is a permutation of length rank, i.e. a Sequence
    containing each of 0, 1, ..., rank - 1 exactly once.
    """

    return isinstance(perm, Sequence) and sorted(perm) == list(range(rank))
695

696

697
def is_same_shape(a: Sequence, b: Sequence) -> bool:
    """
    Compares two shapes a and b, returning True if they are the same
    (their ranks and corresponding lengths match) and False otherwise.

    Both arguments are converted to tuples first, so a list and a tuple
    with the same entries compare equal.
    """

    return tuple(a) == tuple(b)
704

705

706
def is_cpu_scalar_tensor(a: Any) -> bool:
    """
    Returns True when a is a TensorLike with zero dimensions on a CPU device.
    """
    return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu"
708

709

710
def check_same_device(*args, allow_cpu_scalar_tensors):
    """
    Checks that all Tensors in args have the same device.

    Raises a RuntimeError when:
      - args contains an object whose type is not Tensor or Number
      - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True
    """
    # Short-circuits if all (one or fewer) arguments are trivially on the same device
    if len(args) <= 1:
        return

    # Note: cannot initialize device to the first arg's device (it may not have one)
    device = None
    for arg in args:
        # Numbers have no device
        if isinstance(arg, Number):
            continue

        if not isinstance(arg, TensorLike):
            msg = f"Unexpected type when checking for same device, {type(arg)}!"
            raise RuntimeError(msg)

        if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
            continue

        # The first tensor that is not skipped sets the expected device
        if device is None:
            device = arg.device

        if device != arg.device:
            msg = f"Tensor on device {arg.device} is not on the expected device {device}!"
            raise RuntimeError(msg)
748

749

750
def canonicalize_device(device: DeviceLikeType) -> torch.device:
    """
    Converts a device specification to a torch.device.

    torch.device objects are returned as-is and strings are passed to
    torch.device(). NOTE: although DeviceLikeType also admits ints, any
    other type (including int) fails the assertion below.
    """
    if isinstance(device, torch.device):
        return device

    assert isinstance(device, str)
    return torch.device(device)
756

757

758
# Asserts if any of the following are true:
759
#   - a non-scalar or non-Tensor is given
760
#   - the shape of any tensors is distinct
761
def check_same_shape(*args, allow_cpu_scalar_tensors: bool):
    """
    Checks that all Tensors in args have the same shape.

    Raises a RuntimeError when:
      - args contains an object whose type is not Tensor or Number
      - two Tensor objects in args have different shapes, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True
    """
    shape = None

    for arg in args:
        # Numbers have no shape
        if isinstance(arg, Number):
            continue

        if not isinstance(arg, TensorLike):
            msg = f"Unexpected type when checking for same shape, {type(arg)}!"
            raise RuntimeError(msg)

        if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
            continue

        # The first tensor that is not skipped sets the expected shape
        if shape is None:
            shape = arg.shape

        if not is_same_shape(shape, arg.shape):
            msg = f"Shape {arg.shape} is not the expected shape {shape}!"
            raise RuntimeError(msg)
789

790

791
# Acquires a common shape, if it exists, from one or more tensor arguments,
792
# filtering number arguments
793
def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
    """
    Returns the shape shared by the tensors in args, ignoring Numbers.

    Returns None when args contains an object that is neither a Number nor
    a TensorLike, or when two tensors disagree on shape. If every tensor
    was skipped as a CPU scalar tensor (allow_cpu_scalar_tensors=True), the
    shape of the last such tensor is returned; with no tensors at all the
    result is None.
    """
    shape = None
    scalar_shape = None

    for arg in args:
        if isinstance(arg, Number):
            continue

        if not isinstance(arg, TensorLike):
            return None

        if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
            # Remembered as a fallback only
            scalar_shape = arg.shape
            continue

        if shape is None:
            shape = arg.shape

        if not is_same_shape(shape, arg.shape):
            return None

    return shape if shape is not None else scalar_shape
814

815

816
# Extracts dimensions that might be passed either as a list/tuple or as varargs.
817
# A typical case is Tensor.permute .
818
def extract_dims_from_varargs(
    dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]]
) -> DimsSequenceType:
    """
    Unwraps dims that were passed as a single sequence inside varargs.

    If dims is non-empty and its first element is a Sequence, dims must
    have exactly one element and that element is returned; otherwise dims
    itself is returned.
    """
    # The cast targets are given as strings: typing.cast ignores its first
    # argument at runtime, so this only avoids evaluating the aliases.
    if dims and isinstance(dims[0], Sequence):
        assert len(dims) == 1
        dims = cast("Tuple[DimsSequenceType]", dims)
        return dims[0]
    else:
        return cast("DimsSequenceType", dims)
827

828

829
def extract_shape_from_varargs(
    shape: Union[ShapeType, Tuple[ShapeType]],
    validate=True,
) -> Tuple[int, ...]:
    """
    Returns a shape from varargs.

    In PyTorch, operations that accept shapes often accept them as varargs, like
    foo(*shape). A user can pass the shape as individual integers,
    like this:

      foo(1, 2, 3)

    or as a single sequence of integers

      foo((1, 2, 3))

    In the first case shape will be a tuple of integers, and in the second case it's a tuple
    containing a tuple of integers. This unwraps the second form and, when
    validate is True, checks the result with validate_shape.
    """

    # Handles tuple unwrapping
    if len(shape) == 1 and isinstance(shape[0], Sequence):
        shape = shape[0]

    if validate:
        validate_shape(shape)  # type: ignore[arg-type]
    return shape  # type: ignore[return-value]
858

859

860
def infer_size_shapes(a: ShapeType, b: ShapeType) -> Tuple[int, ...]:
861
    ndim = max(len(a), len(b))
862
    expandedSizes = [0] * ndim
863

864
    for i in range(ndim - 1, -1, -1):
865
        offset = ndim - 1 - i
866
        dimA = len(a) - 1 - offset
867
        dimB = len(b) - 1 - offset
868
        sizeA = a[dimA] if dimA >= 0 else 1
869
        sizeB = b[dimB] if dimB >= 0 else 1
870

871
        torch._check(
872
            (sizeA == sizeB) or (sizeA == 1) or (sizeB == 1),
873
            lambda: (
874
                f"The size of tensor a ({sizeA}) must match the size of "
875
                f"tensor b ({sizeB}) at non-jagged dimension {i}"
876
            ),
877
        )
878

879
        # 1s map to the other size (even 0)
880
        expandedSizes[i] = sizeB if sizeA == 1 else sizeA
881

882
    return tuple(expandedSizes)
883

884

885
def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
886
    """
887
    Infers the size of a dim with size -1, if it exists.
888
    Also checks that new shape is compatible with the number of elements.
889
    """
890
    dim = None
891
    newsize = 1
892
    for i, d in enumerate(shape):
893
        if d == -1:
894
            torch._check(dim is None, lambda: "only one dimension can be inferred")
895
            dim = i
896
        elif d >= 0:
897
            newsize *= d
898
        else:
899
            torch._check(False, lambda: f"invalid shape dimension {d}")
900
    if dim is None:
901
        torch._check(
902
            numel == newsize,
903
            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
904
        )
905
    else:
906
        from torch.fx.experimental.symbolic_shapes import definitely_true
907

908
        torch._check(
909
            newsize != 0,
910
            lambda: (
911
                f"cannot reshape tensor of 0 elements into shape {list(shape)} because the "
912
                f"unspecified dimension size -1 can be any value and is ambiguous"
913
                if definitely_true(numel == 0)
914
                else f"shape '{list(shape)}' is invalid for input of size {numel}"
915
            ),
916
        )
917
        torch._check(
918
            numel % newsize == 0,
919
            lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
920
        )
921
        # Convert to list to produce a compatible error message with core
922
        # PyTorch, which prints sequences in square brackets.
923
        shape = list(shape)
924
        shape[dim] = numel // newsize
925
        # NB: This is pretty important when you have unbacked SymInts.
926
        # Suppose you have (i0, 12) resizing into (2, -1, 12).  The old
927
        # range for i0 is typically [2, inf], which means if you divide
928
        # by two the new range should be [1, inf].  But this is bad news
929
        # if you have an unbacked SymInt: we need to reapply the unsound
930
        # assumption that the size is >= 2.
931
        torch._check_is_size(shape[dim])
932
    return tuple(shape)
933

934

935
_integer_dtypes = (
936
    torch.uint8,
937
    torch.uint16,
938
    torch.uint32,
939
    torch.uint64,
940
    torch.int8,
941
    torch.int16,
942
    torch.int32,
943
    torch.int64,
944
)
945
_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
946
_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128)
947

948

949
def is_boolean_dtype(dtype: torch.dtype) -> bool:
    """Return True only when ``dtype`` is ``torch.bool``.

    Asserts that ``dtype`` is a ``torch.dtype`` instance.
    """
    assert isinstance(dtype, torch.dtype)
    matches_bool = dtype is torch.bool
    return matches_bool
952

953

954
def is_integer_dtype(dtype: torch.dtype) -> bool:
    """Return True when ``dtype`` is one of the dtypes in ``_integer_dtypes``.

    Asserts that ``dtype`` is a ``torch.dtype`` instance.
    """
    assert isinstance(dtype, torch.dtype)
    listed_as_integer = dtype in _integer_dtypes
    return listed_as_integer
957

958

959
def is_low_precision_dtype(dtype: torch.dtype) -> bool:
    """Return True when ``dtype`` is one of the dtypes in ``_low_precision_dtypes``.

    Asserts that ``dtype`` is a ``torch.dtype`` instance.
    """
    assert isinstance(dtype, torch.dtype)
    listed_as_low_precision = dtype in _low_precision_dtypes
    return listed_as_low_precision
962

963

964
def is_float_dtype(dtype: torch.dtype) -> bool:
    """Return the dtype's own ``is_floating_point`` flag.

    Asserts that ``dtype`` is a ``torch.dtype`` instance.
    """
    assert isinstance(dtype, torch.dtype)
    floating = dtype.is_floating_point
    return floating
967

968

969
def is_complex_dtype(dtype: torch.dtype) -> bool:
    """Return True when ``dtype`` is one of the dtypes in ``_complex_dtypes``.

    Asserts that ``dtype`` is a ``torch.dtype`` instance.
    """
    assert isinstance(dtype, torch.dtype)
    listed_as_complex = dtype in _complex_dtypes
    return listed_as_complex
972

973

974
def is_grad_dtype(dtype: torch.dtype) -> bool:
    """
    Checks if the dtype can require a gradient.

    That is the case for floating point dtypes and for complex dtypes.
    """
    if dtype.is_floating_point:
        return True
    return is_complex_dtype(dtype)
979

980

981
# Complex dtype -> the real dtype returned by corresponding_real_dtype.
_complex_to_real_dtype_map = {
    torch.complex128: torch.float64,
    torch.complex64: torch.float32,
    torch.complex32: torch.float16,
}

# Real floating point dtype -> the complex dtype returned by
# corresponding_complex_dtype. Note that bfloat16 maps to complex64, the same
# target as float32.
_real_to_complex_dtype_map = {
    torch.float16: torch.complex32,
    torch.bfloat16: torch.complex64,
    torch.float32: torch.complex64,
    torch.float64: torch.complex128,
}
993

994

995
def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype:
    """Look up the real dtype paired with a complex ``dtype``.

    The lookup is a plain subscript of ``_complex_to_real_dtype_map``, so a
    dtype that is not a key of that map raises ``KeyError``.
    """
    real_dtype = _complex_to_real_dtype_map[dtype]
    return real_dtype
997

998

999
def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype:
    """Look up the complex dtype paired with a real floating point ``dtype``.

    The lookup is a plain subscript of ``_real_to_complex_dtype_map``, so a
    dtype that is not a key of that map raises ``KeyError``.
    """
    complex_dtype = _real_to_complex_dtype_map[dtype]
    return complex_dtype
1001

1002

1003
def dtype_to_type(dtype: torch.dtype) -> type:
    """
    Computes the corresponding Python type (AKA "type kind") for the
    given dtype.

    The result is one of bool, int, float or complex. A dtype that falls in
    none of those categories raises ValueError.
    """
    assert isinstance(dtype, torch.dtype)

    if dtype is torch.bool:
        type_kind = bool
    elif dtype in _integer_dtypes:
        type_kind = int
    elif dtype.is_floating_point:
        type_kind = float
    elif dtype in _complex_dtypes:
        type_kind = complex
    else:
        raise ValueError("Invalid dtype!")

    return type_kind
1020

1021

1022
def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]:
    """
    Computes the corresponding Python type constructor for the
    given dtype.

    The returned callable is a ``bool(x)`` wrapper for torch.bool, ``sym_int``
    for integer dtypes, ``sym_float`` for floating point dtypes and a
    ``complex(x)`` wrapper for complex dtypes. Any other dtype raises ValueError.
    """
    assert isinstance(dtype, torch.dtype)

    if dtype is torch.bool:
        ctor = lambda x: bool(x)  # noqa: E731
    elif dtype in _integer_dtypes:
        ctor = sym_int
    elif dtype.is_floating_point:
        ctor = sym_float
    elif dtype in _complex_dtypes:
        # TODO: type error here is real, replace with sym_complex
        ctor = lambda x: complex(x)  # type: ignore[arg-type]  # noqa: E731
    else:
        raise ValueError("Invalid dtype!")

    return ctor
1040

1041

1042
def type_to_dtype(typ: type) -> torch.dtype:
    """
    Computes the corresponding dtype for a Number type.

    bool / SymBool map to torch.bool, int / SymInt to torch.long,
    float / SymFloat to the current default dtype, and complex to the complex
    dtype paired with the current default dtype. Anything else raises ValueError.
    """

    assert isinstance(typ, type)

    if typ in (bool, torch.SymBool):
        resolved = torch.bool
    elif typ in (int, torch.SymInt):
        resolved = torch.long
    elif typ in (float, torch.SymFloat):
        resolved = torch.get_default_dtype()
    elif typ is complex:
        # TODO: sym_complex_float?
        resolved = corresponding_complex_dtype(torch.get_default_dtype())
    else:
        raise ValueError(f"Invalid type {typ}!")

    return resolved
1060

1061

1062
def get_dtype(x: Union[torch.Tensor, NumberType]):
    """Return the dtype of a tensor, or the dtype mapped from a number's type.

    Non-tensor inputs are resolved through ``type_to_dtype(type(x))``.
    """
    if not isinstance(x, torch.Tensor):
        return type_to_dtype(type(x))
    return x.dtype
1067

1068

1069
# Python number types from "lowest" to "highest"; get_higher_type and
# is_weakly_lesser_type rely on this ordering.
_ordered_types = (bool, int, float, complex)
1070

1071

1072
def check_fp_or_complex(
    dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True
):
    """
    Checks whether the input is floating point or complex.
    If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32

    Both conditions are enforced through torch._check; ``fn_name`` is used as
    the prefix of the error messages.
    """

    def _wrong_kind_message() -> str:
        return f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}"

    def _low_precision_message() -> str:
        return f"{fn_name}: Half precision dtypes not supported. Got {dtype}"

    torch._check(
        is_float_dtype(dtype) or is_complex_dtype(dtype),
        _wrong_kind_message,
    )
    torch._check(
        allow_low_precision_dtypes or not is_low_precision_dtype(dtype),
        _low_precision_message,
    )
1087

1088

1089
def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"):
1090
    torch._check(
1091
        len(A.shape) >= 2,
1092
        lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
1093
    )
1094

1095

1096
def get_higher_type(a: type, b: type) -> type:
    """
    Returns the higher of the two given Number types.

    The types are ordered bool -> int -> float -> complex.

    Both inputs are first passed through ``_maybe_get_pytype``; a RuntimeError
    is raised if either result is not one of the four ordered types.
    """
    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)
    # Type checking
    if a not in _ordered_types or b not in _ordered_types:
        raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}")

    # The type sitting later in _ordered_types is the higher one; on a tie
    # (a is b) the first argument is returned.
    rank_a = _ordered_types.index(a)
    rank_b = _ordered_types.index(b)
    return a if rank_a >= rank_b else b
1117

1118

1119
# Returns the higher of two torch datatypes a and b or, if the two
#   are not ordered relative to each other, the next
#   higher datatype
def get_higher_dtype(
    a: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
    b: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
) -> Optional[torch.dtype]:
    """
    Computes the "lowest" datatype that is weakly
    "higher" than both a and b.

    Each argument may be None, a torch.dtype, a tensor-like object (its
    ``.dtype`` is used) or a Number (its type is mapped with ``type_to_dtype``).
    When only one argument resolves to a dtype that dtype is returned; when
    neither does, None is returned.
    """

    # Type checking
    accepted_kinds = (torch.dtype, TensorLike, Number)
    assert a is None or isinstance(a, accepted_kinds)
    assert b is None or isinstance(b, accepted_kinds)

    def _extract_dtype(
        value: Optional[Union[torch.dtype, TensorLikeType, NumberType]]
    ) -> Optional[torch.dtype]:
        if value is None:
            return None
        if isinstance(value, torch.dtype):
            return value
        if isinstance(value, TensorLike):
            return value.dtype
        if isinstance(value, Number):
            return type_to_dtype(type(value))

        raise RuntimeError("Unexpected type given to _extract_dtype!")

    lhs, rhs = _extract_dtype(a), _extract_dtype(b)

    if lhs is None:
        return rhs
    if rhs is None or lhs is rhs:
        return lhs

    # Tiers from lowest to highest; dtypes sharing a tier are not ordered
    # relative to each other.
    # NOTE(review): torch.uint16/uint32/uint64 do not appear in any tier, so
    # they are never matched by the loop below -- confirm this is intended.
    tiers = (
        (torch.bool,),
        (torch.uint8, torch.int8),
        (torch.int16,),
        (torch.int32,),
        (torch.int64,),
        (torch.float16, torch.bfloat16),
        (torch.float32,),
        (torch.float64,),
        (torch.complex32,),
        (torch.complex64,),
        (torch.complex128,),
    )

    for position, tier in enumerate(tiers):
        lhs_here = lhs in tier
        rhs_here = rhs in tier
        if lhs_here and rhs_here:
            # Two distinct dtypes of one tier: promote to the next tier.
            return tiers[position + 1][0]
        if lhs_here:
            # lhs sits in the lower tier, so rhs is the higher dtype.
            return rhs
        if rhs_here:
            return lhs

    raise RuntimeError("Unexpected termination!")
1183

1184

1185
def check_pin_memory(pin_memory: bool):
1186
    torch._check_not_implemented(
1187
        not pin_memory, lambda: "PrimTorch does not support pinned memory"
1188
    )
1189

1190

1191
def check_layout(layout: torch.layout):
1192
    torch._check_not_implemented(
1193
        layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}"
1194
    )
1195

1196

1197
# TODO: maybe unify with can_cast_to?
1198
def is_weakly_lesser_type(a: type, b: type) -> bool:
    """
    Compares two types, a and b, returning True if a is weakly "less" than b.

    The comparison is determined by the following type ordering: bool, int, float, complex.

    Both inputs are first passed through ``_maybe_get_pytype``; a RuntimeError
    is raised if either result is not one of the four ordered types.
    """

    a, b = _maybe_get_pytype(a), _maybe_get_pytype(b)

    if a not in _ordered_types or b not in _ordered_types:
        raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}")

    # "Weakly less" means a does not come after b in _ordered_types.
    return _ordered_types.index(a) <= _ordered_types.index(b)
1217

1218

1219
def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool:
    """Return True when ``cast_to`` is in the same or a higher dtype category than ``cast_from``.

    Categories are probed from highest to lowest (complex, float, integer,
    boolean). The first category containing ``cast_to`` yields True; if
    ``cast_from`` is met first the result is False. A ValueError is raised
    when neither dtype falls in any category.
    """
    categories_high_to_low = (
        is_complex_dtype,
        is_float_dtype,
        is_integer_dtype,
        is_boolean_dtype,
    )
    for in_category in categories_high_to_low:
        if in_category(cast_to):
            return True
        if in_category(cast_from):
            return False

    raise ValueError(f"Received unknown dtypes {cast_to}, {cast_from}!")
1227

1228

1229
def check_same_dtype(*args):
    """
    Checks that all Tensors in args have the same dtype.

    Numbers in args are accepted but skipped: scalar type checking is
    disabled (and may be removed in the future).

    Raises a RuntimeError when:
      - args contains an object whose type is not Tensor or Number
      - two Tensor objects in args have different dtypes
      - a Tensor's corresponding Python type differs from the one of the
          first Tensor seen
    """
    full_dtype = None
    scalar_type = None

    for arg in args:
        if isinstance(arg, Number):
            # Scalar type checking is disabled (and may be removed in the future)
            continue

        if not isinstance(arg, TensorLike):
            raise RuntimeError(
                "Unexpected type when checking for same dtype, " + str(type(arg)) + "!"
            )

        # The first tensor seen fixes the expected dtype and Python type.
        if full_dtype is None:
            full_dtype = arg.dtype
        if scalar_type is None:
            scalar_type = dtype_to_type(arg.dtype)

        if full_dtype is not arg.dtype:
            raise RuntimeError(
                f"Tensor with dtype {arg.dtype} is not the expected dtype of {full_dtype}!"
            )

        arg_type = dtype_to_type(arg.dtype)
        if arg_type is not scalar_type:
            raise RuntimeError(
                f"Tensor with corresponding Python type {arg_type} "
                f"is not the expected type of {scalar_type}!"
            )
1291

1292

1293
# Maps datatypes to their computation types for elementwise operations.
# Only the low precision dtypes are listed; get_computation_dtype returns
# any other dtype unchanged.
_computation_dtype_map = {
    torch.bfloat16: torch.float32,
    torch.float16: torch.float32,
    torch.complex32: torch.complex64,
}
1299

1300

1301
def get_computation_dtype(dtype: torch.dtype) -> torch.dtype:
    """Return the computation dtype for ``dtype``.

    Dtypes listed in ``_computation_dtype_map`` are replaced by their mapped
    value; every other dtype is returned as is.
    """
    if dtype in _computation_dtype_map:
        return _computation_dtype_map[dtype]
    return dtype
1303

1304

1305
# Accumulation dtypes used by get_acc_type when the device type is "cpu".
# Dtypes that are not listed are returned unchanged by get_acc_type.
_cpu_acc_type_map = {
    torch.bfloat16: torch.float64,
    torch.float16: torch.float64,
    torch.float32: torch.float64,
    torch.complex32: torch.complex128,
    torch.complex64: torch.complex128,
}
1312

1313

1314
def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype:
    """Return the accumulation dtype for ``dtype`` on ``device``.

    On a device whose type is "cpu" the dtype is looked up in
    ``_cpu_acc_type_map`` (falling back to ``dtype`` itself); on any other
    device the computation dtype is used.
    """
    # Equivalent to at::toAccumulateType, prefer computation_dtype where possible
    if device.type != "cpu":
        return get_computation_dtype(dtype)
    return _cpu_acc_type_map.get(dtype, dtype)
1320

1321

1322
class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
    # Selects how elementwise_dtypes derives the computation and result dtypes;
    # the individual kinds are described in the elementwise_dtypes docstring.
    # NOTE: because of the trailing commas every member value is a 1-tuple,
    # e.g. DEFAULT.value == (0,).
    DEFAULT = (0,)
    NO_OPMATH = (1,)
    INT_TO_FLOAT = (2,)
    ALWAYS_BOOL = (3,)
    COMPLEX_TO_FLOAT = (4,)
    BOOL_TO_LONG = (5,)
1329

1330

1331
class REDUCTION_OUTPUT_TYPE_KIND(Enum):
    # Selects how reduction_dtypes picks the result dtype of a reduction.
    # NOTE: because of the trailing commas every member value is a 1-tuple,
    # e.g. SAME.value == (0,).
    SAME = (0,)
    COMPLEX_TO_FLOAT = (1,)  # for complex types outputs corresponding real type
    KEEP_PROMOTED_TYPE = (2,)  # keep output in opmath type, needed for mean
    ALWAYS_BOOL = (3,)
1336

1337

1338
# Describes the return type of the primitive:
#
#   - NEW, a new tensor is created
#   - VIEW, a view of an input tensor is returned
#   - INPLACE, one or more input tensors is modified
#
# these descriptors are mutually exclusive and exhaustive.
# NOTE(review): NONE is not described above -- presumably it marks primitives
# that return nothing; confirm against the users of this enum.
class RETURN_TYPE(Enum):
    # NOTE: because of the trailing commas every member value is a 1-tuple,
    # e.g. NEW.value == (0,).
    NEW = (0,)
    VIEW = (1,)
    INPLACE = (2,)
    NONE = (3,)
1350

1351

1352
# TODO: when NumberType contains the sym types, can simplify this
1353
def number_type(
    x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool]
) -> Type:
    """Return the Python number type of ``x``.

    Symbolic values are reported as their plain counterparts (SymInt -> int,
    SymFloat -> float, SymBool -> bool); anything else yields ``type(x)``.
    """
    symbolic_to_plain = (
        (torch.SymInt, int),
        (torch.SymFloat, float),
        (torch.SymBool, bool),
    )
    for symbolic_cls, plain_type in symbolic_to_plain:
        if isinstance(x, symbolic_cls):
            return plain_type
    return type(x)
1364

1365

1366
def expr_type(x: sympy.Basic) -> Type:
    """Return the Python number type (bool, int or float) matching a SymPy expression.

    Expressions of boolean kind map to bool, expressions whose ``is_integer``
    is truthy map to int, and everything else maps to float.
    """
    import sympy

    if x.kind is sympy.core.kind.BooleanKind:
        return bool
    # NB: Not strictly correct, but we don't support SymPy complex or bool.
    return int if x.is_integer else float  # type: ignore[attr-defined]
1376

1377

1378
# TODO: document type promotion kinds
1379
def elementwise_dtypes(
    *_args,
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
) -> Tuple[torch.dtype, torch.dtype]:
    """
    Computes the computation and result dtypes for elementwise type promotion
    on the given arguments and with the given elementwise type promotion kind.

    Returns a ``(computation_dtype, result_dtype)`` tuple. Arguments that are
    None are ignored; any other argument that is not a Number, a tensor-like
    object or a SymPy expression raises ValueError.

    Note that not all inputs to an elementwise operation necessarily participate in type promotion.
    For example, the "alpha" parameter of torch.add does not participate in type promotion,
    although it may be cast to the Python type corresponding to the computation dtype that
    the type promotion algorithm determines.

    Default elementwise type promotion, which all other type promotion kinds tweak (see below),
    first decides which of four ordered types to use:

    bool -> integer -> floating point -> complex

    The selected type is the "lowest" type in the above list such that all number arguments
    have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
    type for their dtype.

    Once the type is determined, the particular result dtype is found. The dtypes are
    partially ordered as follows:

    bool -> uint8, int8 -> int16 -> int32 -> int64 ->
      float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128

    The result dtype is selected by:
      - if no tensor's dtype has the same corresponding type as the one selected,
          then the result dtype is the (default) dtype corresponding to the selected type
          (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
      - if the result type is complex then the dtype is:
        -  the default complex dtype if there are no floating point or complex tensors
        -  if there are floating point or complex tensors with one or more dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
            (for example, double + cfloat -> cdouble)
        -  if there are only floating point or complex tensors with zero dimensions, then
            the complex dtype corresponding to the highest corresponding complex dtype among those tensors
      - if the first two cases do not apply, the result dtype is the highest dtype among
          all tensors with one or more dimensions of the output type, and if there are no such
          tensors then it's the highest dtype among all tensors with zero dimensions of the output type
          (for example, long + half -> half, even if the half tensor has zero dimensions)

    The "corresponding complex dtypes" are:
      float16    -> complex32
      bfloat16   -> complex64
      float32    -> complex64
      float64    -> complex128
      complex32  -> complex32
      complex64  -> complex64
      complex128 -> complex128

    The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation
    dtype by mapping low precision floating point and complex dtypes as follows:

      float16   -> float32
      bfloat16  -> float32
      complex32 -> complex64

    This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the
    computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels
    which perform no mathematical operations on their tensors (see below for examples).

    The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype,
    and computation dtypes to the appropriate op math dtype.

    The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this
    mapping:

        complex32  -> float16
        complex64  -> float32
        complex128 -> float64

    Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does.

    The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long.

    The ALWAYS_BOOL type promotion kind always sets the result dtype to bool.

    Example operators for each type promotion option:
      DEFAULT                 : add
      NO_OPMATH               : where, nextafter, cat
      INT_TO_FLOAT            : sin
      COMPLEX_TO_FLOAT        : abs
      BOOL_TO_LONG            : pow
      ALWAYS_BOOL             : eq

    """

    args = tuple(x for x in _args if x is not None)

    # Import sympy locally, as importing it eagerly at a module level is too slow
    # See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589
    import sympy

    # Step 1: find the highest Python type kind among all participating arguments.
    highest_type: type = bool
    for x in args:
        if isinstance(x, Number):
            arg_kind = number_type(x)
        elif isinstance(x, sympy.Basic):
            arg_kind = expr_type(x)
        elif isinstance(x, TensorLike):
            arg_kind = dtype_to_type(x.dtype)
        else:
            msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
            raise ValueError(msg)
        highest_type = get_higher_type(highest_type, arg_kind)

    def _find_highest_dtype_filtered(
        candidates, dtype_accepted, *, float_as_complex=False
    ) -> Optional[torch.dtype]:
        # Highest dtype among the tensors accepted by `dtype_accepted`, tracked
        # separately for zero-dim tensors and tensors with dimensions.
        zero_dim_best = None
        dimensioned_best = None
        for item in candidates:
            if not (isinstance(item, TensorLike) and dtype_accepted(item.dtype)):
                continue
            item_dtype = item.dtype
            if float_as_complex and is_float_dtype(item_dtype):
                item_dtype = corresponding_complex_dtype(item_dtype)
            if item.ndim == 0:
                zero_dim_best = get_higher_dtype(zero_dim_best, item_dtype)
            else:
                # item.ndim > 0
                dimensioned_best = get_higher_dtype(dimensioned_best, item_dtype)

        # Prefers dtype of tensors with one or more dimensions
        if dimensioned_best is not None:
            return dimensioned_best
        return zero_dim_best

    # Step 2: pick the result dtype within the selected type kind.
    result_dtype = None
    if highest_type is float:
        result_dtype = _find_highest_dtype_filtered(args, is_float_dtype)
        if result_dtype is None:
            result_dtype = torch.get_default_dtype()
    elif highest_type is complex:
        result_dtype = _find_highest_dtype_filtered(
            args,
            lambda dt: is_float_dtype(dt) or is_complex_dtype(dt),
            float_as_complex=True,
        )
        if result_dtype is None:
            result_dtype = corresponding_complex_dtype(torch.get_default_dtype())
    elif highest_type is int:
        result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype)
        if result_dtype is None:
            result_dtype = torch.long
    else:
        # highest_type is bool
        result_dtype = torch.bool

    # Step 3: apply the requested type promotion kind.
    kinds = ELEMENTWISE_TYPE_PROMOTION_KIND
    if type_promotion_kind is kinds.DEFAULT:
        return get_computation_dtype(result_dtype), result_dtype
    if type_promotion_kind is kinds.NO_OPMATH:
        return result_dtype, result_dtype
    if type_promotion_kind is kinds.INT_TO_FLOAT:
        if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype):
            result_dtype = torch.get_default_dtype()
        return get_computation_dtype(result_dtype), result_dtype
    if type_promotion_kind is kinds.COMPLEX_TO_FLOAT:
        # NOTE: computation can still occur in a complex dtype
        computation_dtype = get_computation_dtype(result_dtype)
        if is_complex_dtype(result_dtype):
            result_dtype = corresponding_real_dtype(result_dtype)
        return computation_dtype, result_dtype
    if type_promotion_kind is kinds.BOOL_TO_LONG:
        if is_boolean_dtype(result_dtype):
            return torch.long, torch.long
        return get_computation_dtype(result_dtype), result_dtype
    if type_promotion_kind is kinds.ALWAYS_BOOL:
        return get_computation_dtype(result_dtype), torch.bool

    raise ValueError(f"Unknown type promotion kind {str(type_promotion_kind)}")
1560

1561

1562
def reduction_dtypes(
    arg,
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.dtype, Optional[torch.dtype]]:
    """Compute the ``(computation_dtype, result_dtype)`` pair for a reduction.

    The input dtype is ``dtype`` when it is given and ``arg.dtype`` otherwise.
    The computation dtype is always ``get_computation_dtype`` of that input
    dtype. The result dtype depends on ``output_dtype_kind``:

      - SAME: the input dtype
      - COMPLEX_TO_FLOAT: the corresponding real dtype for complex inputs,
          otherwise the input dtype
      - KEEP_PROMOTED_TYPE: None
      - any other kind (ALWAYS_BOOL): torch.bool
    """
    # even though some reductions, like amin or amax, don't strictly require type promotion,
    # all the math ops (including comparisons) are still defined only for a computation type,
    # so promotion will still happen. We are doing it explicitly here
    inp_dtype = dtype if dtype is not None else arg.dtype
    computation_dtype = get_computation_dtype(inp_dtype)

    result_dtype: Optional[torch.dtype]
    if output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME:
        result_dtype = inp_dtype
    elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT:
        result_dtype = inp_dtype
        if is_complex_dtype(result_dtype):
            result_dtype = corresponding_real_dtype(result_dtype)
    elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE:
        result_dtype = None
    else:  # ALWAYS_BOOL
        result_dtype = torch.bool
    return computation_dtype, result_dtype
1587

1588

1589
# This function's logic is borrowed from the following functions defined in C++:
1590
# batched_matrix_contiguous_strides and contiguous_strides
1591
def make_contiguous_strides_for(
    shape: ShapeType, row_major: bool = True
) -> Tuple[int, ...]:
    """
    Returns the strides of a contiguous tensor if row_major=True.
    If row_major=False, it returns the strides of a contiguous batch of Fortran-contiguous matrices
    This is often used when calling external libraries like BLAS/LAPACK/cuSolver...

    An empty shape yields an empty tuple. With row_major=False and fewer than
    two dimensions the row-major strides are returned unchanged.
    """
    # contiguous_strides from c10/util/strides.h
    validate_shape(shape)
    if not shape:
        return ()

    from torch.fx.experimental.symbolic_shapes import is_nested_int

    multiplier = 1
    strides = []
    for l in reversed(shape):
        strides.append(multiplier)
        # Dimensions of size 0 are treated as size 1 when accumulating the
        # multiplier; nested ints are multiplied in as they are.
        multiplier *= l if is_nested_int(l) else sym_max(l, 1)

    result = tuple(reversed(strides))

    # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h
    if row_major or len(shape) < 2:
        return result

    # Column-major layout for the trailing matrix: stride 1 along the rows and
    # max(rows, 1) along the columns.
    # NOTE(review): this uses the builtin max while the loop above uses sym_max
    # -- confirm whether that difference is intentional for symbolic shapes.
    return result[:-2] + (1, max(shape[-2], 1))
1621

1622

1623
def make_channels_last_1d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
1624
    torch._check(
1625
        len(shape) == 3,
1626
        lambda: "Only tensors of rank 3 can use the channels_last_1d memory format",
1627
    )
1628

1629
    multiplier = 1
1630
    strides = [0] * 3
1631
    for idx in (1, -1, 0):
1632
        # NOTE: intentionally divergence from make_contiguous_strides_for
1633
        # This is consistent with eager
1634
        strides[idx] = multiplier
1635
        multiplier *= shape[idx]
1636

1637
    return tuple(strides)
1638

1639

1640
def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
1641
    # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
1642
    torch._check(
1643
        len(shape) == 4,
1644
        lambda: "Only tensors of rank 4 can use the channels_last memory format",
1645
    )
1646

1647
    multiplier = 1
1648
    strides = [0] * 4
1649
    for idx in (1, -1, -2, 0):
1650
        # NOTE: intentionally divergence from make_contiguous_strides_for
1651
        # This is consistent with eager
1652
        strides[idx] = multiplier
1653
        multiplier *= shape[idx]
1654

1655
    return tuple(strides)
1656

1657

1658
def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
1659
    torch._check(
1660
        len(shape) == 5,
1661
        lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
1662
    )
1663

1664
    multiplier = 1
1665
    strides = [0] * 5
1666
    for idx in (1, -1, -2, -3, 0):
1667
        # NOTE: intentionally divergence from make_contiguous_strides_for
1668
        # This is consistent with eager
1669
        strides[idx] = multiplier
1670
        multiplier *= shape[idx]
1671

1672
    return tuple(strides)
1673

1674

1675
def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]:
    """Return channels-last strides for a shape of rank 3, 4 or 5.

    The rank selects the 1d, 2d or 3d helper respectively; any other rank
    raises RuntimeError. A ``shape`` that is not a Sequence is treated as rank 1.
    """
    ndim = len(shape) if isinstance(shape, Sequence) else 1

    builders_by_rank = {
        3: make_channels_last_1d_strides_for,
        4: make_channels_last_2d_strides_for,
        5: make_channels_last_3d_strides_for,
    }
    if ndim not in builders_by_rank:
        raise RuntimeError(
            f"no channels last format strides exist in {ndim} dimensions"
        )
    return builders_by_rank[ndim](shape)
1687

1688

1689
def compute_reduction_output_shape(
    shape: ShapeType, dimensions: Sequence
) -> Tuple[int, ...]:
    """Return ``shape`` with the entries at the reduced ``dimensions`` removed.

    Every entry of ``dimensions`` is first checked with ``validate_idx``.
    Positions are matched by value, so a dimension is dropped only when its
    non-negative index appears in ``dimensions``.
    """
    rank = len(shape)
    for dim in dimensions:
        validate_idx(rank, dim)

    return tuple(shape[pos] for pos in range(rank) if pos not in dimensions)
1703

1704

1705
def validate_no_repeating_dims(dims: Sequence):
    """Raise RuntimeError when ``dims`` contains the same value more than once."""
    distinct_count = len(set(dims))
    if distinct_count != len(dims):
        raise RuntimeError("duplicate value in the list of dims")
1708

1709

1710
def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]:
    """Return the canonical tuple of dimensions to reduce over.

    ``dims=None`` selects every dimension of ``shape``. Otherwise each entry
    is passed through ``canonicalize_dim`` and the result is checked for
    duplicates with ``validate_no_repeating_dims``.
    """
    rank = len(shape)
    if dims is None:
        return tuple(range(rank))

    canonical = tuple(canonicalize_dim(rank, idx) for idx in dims)
    validate_no_repeating_dims(canonical)
    return canonical
1716

1717

1718
def set_correction(
    unbiased: Optional[bool] = None,
    correction: Optional[NumberType] = None,
) -> float:
    """Resolve the ``unbiased`` / ``correction`` pair into a single correction value.

    With no explicit ``correction`` the value is 0.0 when ``unbiased`` is
    False and 1.0 otherwise (including when ``unbiased`` is not given).
    Passing both arguments raises RuntimeError; a correction that is not
    int-like or float-like, or that is negative, raises ValueError. The
    result is passed through ``sym_float``.
    """
    if correction is None:
        correction = 0.0 if unbiased is False else 1.0
    elif unbiased is not None:
        raise RuntimeError("cannot specify both correction and unbiased arguments")

    # NB: we don't actually support symint here, but it's harmless to accept
    numeric_like = (IntLike, FloatLike)
    if not isinstance(correction, numeric_like):
        raise ValueError("correction argument should be integer or float")
    if correction < 0:
        raise ValueError("correction argument should be non-negative")
    return sym_float(correction)
1734

1735

1736
def compute_required_storage_length(
1737
    shape: ShapeType, strides: StrideType, storage_offset: int
1738
) -> int:
1739
    """Computes the minimum storage size to hold the given tensor geometry.
1740

1741
    Example
1742
    =======
1743

1744
    This is the size of a newly allocated tensor's storage, in units of elements
1745

1746
    >>> t = torch.empty((10, 20))
1747
    >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset())
1748
    200
1749

1750
    >>> # xdoctest: +SKIP(failing)
1751
    >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11))
1752
    >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset())
1753
    >>> size == t.storage().size()
1754
    True
1755

1756
    A valid tensor may have a larger storage size, but never smaller
1757

1758
    >>> slice = torch.empty(100)[20:40]
1759
    >>> slice.storage().size()
1760
    100
1761

1762
    >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset())
1763
    40
1764

1765
    """
1766
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
1767

1768
    # Short-circuits if the shape has no elements
1769
    if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0):
1770
        return 0
1771

1772
    max_offset = sum((x - 1) * y for x, y in zip(shape, strides))
1773
    # +1 to account for the first element which offsets are taken from
1774
    return 1 + storage_offset + max_offset
1775

1776

1777
def check_in_bounds_for_storage(
    a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
):
    """
    Validates that storage ``a`` is large enough for the given shape, strides,
    and storage offset.

    Returns ``None`` when the geometry fits. Raises ``ValueError`` when
    ``a.size()`` is smaller than the length computed by
    ``compute_required_storage_length``.
    """
    needed = compute_required_storage_length(shape, strides, storage_offset)
    if a.size() < needed:
        raise ValueError(
            f"Can't view a storage of size {a.size()} with an offset of {storage_offset}, "
            f"shape of {str(shape)}, and strides of {str(strides)}, "
            f"which requires a storage of size {needed}"
        )
1792

1793

1794
# NOTE: Ideally this function would already be gone. It is kept only because
# some Meta internal models packaged with `torch.package` still call it; it can
# be deleted once those models stop depending on it.
@deprecated(
    "`torch._prims_common.check` is deprecated and will be removed in the future. "
    "Please use `torch._check*` functions instead.",
    category=FutureWarning,
)
def check(
    b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
) -> None:
    """
    Deprecated thin wrapper that forwards to ``torch._check_with``.

    Args:
        b: the condition being checked.
        s: zero-argument callable that produces the error message. Passing a
            callable avoids paying for string formatting in the non-error
            case and makes the check easier for torchdynamo to trace.
        exc_type: exception type handed to ``torch._check_with``
            (default: ``RuntimeError``).

    .. note:: This function is planned for removal in the future. Please use
        `torch._check*` functions instead.
    """
    torch._check_with(exc_type, b, s)
1815

1816

1817
# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in
1818
# c10/core/MemoryFormat.h into one function
1819
def are_strides_like_channels_last(
1820
    shape: Sequence[int], strides: Sequence[int]
1821
) -> bool:
1822
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
1823

1824
    ndim = len(shape)
1825

1826
    if ndim == 4:
1827
        # Check for channels_last_2d
1828
        dim_order = [1, 3, 2, 0]
1829
    elif ndim == 5:
1830
        # Check for channels_last_3d
1831
        dim_order = [1, 4, 3, 2, 0]
1832
    else:
1833
        return False
1834

1835
    if guard_size_oblivious(strides[1] == 0):
1836
        return False
1837

1838
    min = 0
1839
    for d in dim_order:
1840
        if guard_size_oblivious(shape[d] == 0):
1841
            return False
1842
        if guard_size_oblivious(strides[d] < min):
1843
            return False
1844
        if d == 0 and min == strides[1]:
1845
            return False
1846
        min = strides[d]
1847
        if guard_size_oblivious(strides[d] > 1):
1848
            min *= shape[d]
1849
    return True
1850

1851

1852
def suggest_memory_format(x: TensorLikeType) -> torch.memory_format:
    """Picks a memory format for ``x`` based on its layout and strides.

    Returns ``torch.channels_last`` (4-dim) or ``torch.channels_last_3d``
    (any other rank) when ``x`` is strided and
    ``are_strides_like_channels_last`` accepts its shape and strides;
    otherwise returns ``torch.contiguous_format``.
    """
    if x.layout == torch.strided and are_strides_like_channels_last(
        x.shape, x.stride()
    ):
        if x.ndim == 4:
            return torch.channels_last
        return torch.channels_last_3d

    return torch.contiguous_format
1860

1861

1862
def prod(xs: Sequence[NumberType]) -> NumberType:
    """Multiplies together all elements of the input sequence, starting from 1.

    An empty sequence therefore yields 1.
    """
    result = 1
    for factor in xs:
        # Plain binary multiply (not *=) so the fold matches operator.mul
        result = result * factor
    return result
1865

1866

1867
def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool:
    """Checks if ``shape`` can be expanded to ``desired``.

    Dimensions are compared from the trailing end: every dimension of
    ``shape`` must either equal the matching dimension of ``desired`` or be 1.
    A ``shape`` with more dimensions than ``desired`` is never expandable.
    """
    # This is a Python implementation of
    # aten/src/ATen/ExpandUtils.h:is_expandable_to
    if len(shape) > len(desired):
        return False
    # zip stops at the shorter (``shape``) side, so only the trailing
    # len(shape) dimensions of ``desired`` are inspected
    for have, want in zip(reversed(shape), reversed(desired)):
        if have != want and have != 1:
            return False
    return True
1879

1880

1881
def mask_tensor(mask: TensorLikeType, t: TensorLikeType):
    """
    Behaves like torch.where(mask, t, 0), except that a boolean ``t``
    produces a boolean result instead of being promoted to int.
    """
    if t.dtype is not torch.bool:
        return torch.where(mask, t, 0)

    # torch.where(mask, t, False) would be equivalent here,
    # but feels hacky and might break in the future
    return mask.logical_and(t)
1892

1893

1894
def get_aten_op(fn: Callable, name: str):
    """
    Looks up (our best guess of) the ATen operation that corresponds to a
    reference, given the reference function and its name.

    The lookup uses ``fn.__module__``, which must start with ``torch._refs``.

    Note: In ATen, the __name__ of a function within a module often
    starts by the module name. E.g. linalg_eigh, or special_zeta
    """
    refs_root = "torch._refs"
    module = fn.__module__
    assert module.startswith(refs_root)

    # What remains is "" for torch._refs itself, or something like
    # ".special" / ".nn.functional" for a submodule
    submodule = module[len(refs_root) :]

    # Drop the leading dot and flatten the rest, so that
    # ".special" -> "special_" and ".nn.functional" -> "nn_functional_"
    op_prefix = (submodule[1:].replace(".", "_") + "_") if submodule else ""
    return getattr(torch._ops.ops.aten, f"{op_prefix}{name}")
1913

1914

1915
def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype:
    """Returns ``dtype`` unchanged, or ``torch.get_default_dtype()`` when it is None."""
    if dtype is None:
        return torch.get_default_dtype()
    return dtype
1917

1918

1919
def device_or_default(device: Optional[DeviceLikeType]) -> DeviceLikeType:
    """Returns ``device`` unchanged, or ``torch.device("cpu")`` when it is None."""
    if device is None:
        return torch.device("cpu")
    return device
1921

1922

1923
def layout_or_default(layout: Optional[torch.layout]) -> torch.layout:
    """Returns ``layout`` unchanged, or ``torch.strided`` when it is None."""
    if layout is None:
        return torch.strided
    return layout
1925

1926

1927
def clone_preserve_strides(x):
    """Clones ``x`` so that the copy keeps the same sizes, strides and storage offset.

    A flat view covering the first ``needed_size`` elements of ``x``'s storage
    (starting at offset 0) is cloned, and the clone is then re-viewed with
    ``x``'s original geometry. The ADInplaceOrView dispatch key is excluded
    while this happens and restored to its previous state afterwards.
    """
    needed_size = compute_required_storage_length(
        x.size(), x.stride(), x.storage_offset()
    )
    # Our eager implementations for *_scatter ops are all primitives w.r.t autograd,
    # so these as_strided() calls are not seen by autograd.
    # We need to mimic this behavior in our ref/prim implementations.
    # TODO: a better way to handle this would be with a new op, "_unsafe_as_strided"
    # We should revisit this when we add a compositional as_strided op,
    # and also as part of https://github.com/pytorch/pytorch/issues/90507
    ad_key = torch._C.DispatchKey.ADInplaceOrView
    # Read the previous exclusion state BEFORE entering the try block: if this
    # call raised inside the try, the finally clause would reference an
    # unbound name and mask the real error with an UnboundLocalError.
    old = torch._C._dispatch_tls_is_dispatch_key_excluded(ad_key)
    try:
        torch._C._dispatch_tls_set_dispatch_key_excluded(ad_key, True)
        buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone()
        return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
    finally:
        # Always restore the caller's exclusion state, even on failure
        torch._C._dispatch_tls_set_dispatch_key_excluded(ad_key, old)
1950

1951

1952
def alert_not_deterministic(caller: str):
1953
    if torch.are_deterministic_algorithms_enabled():
1954
        if torch.is_deterministic_algorithms_warn_only_enabled():
1955
            warnings.warn(
1956
                f"{caller} does not have a deterministic implementation, but you set "
1957
                f"'torch.use_deterministic_algorithms(True, warn_only=True)'. "
1958
                f"You can file an issue at https://github.com/pytorch/pytorch/issues "
1959
                f"to help us prioritize adding deterministic support for this operation."
1960
            )
1961
        else:
1962
            torch._check(
1963
                False,
1964
                lambda: (
1965
                    f"{caller} does not have a deterministic implementation, but you set "
1966
                    f"'torch.use_deterministic_algorithms(True)'. You can turn off "
1967
                    f"determinism just for this operation, or you can use the "
1968
                    f"'warn_only=True' option, if that's acceptable for your application. "
1969
                    f"You can also file an issue at https://github.com/pytorch/pytorch/issues "
1970
                    f"to help us prioritize adding deterministic support for this operation."
1971
                ),
1972
            )
1973

1974

1975
class CUDARngStateHelper:
    """Static helpers that read and write the CUDA RNG state as seed/offset tensors."""

    @staticmethod
    def get_torch_state_as_tuple(fake_mode=nullcontext()):
        """Returns ``(seed, offset)`` tensors built from the current CUDA RNG state.

        The tensors are created inside the ``fake_mode`` context manager.
        Raises ``RuntimeError`` if CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")

        with fake_mode:
            seed_tensor = torch.tensor(torch.cuda.initial_seed())
            offset_tensor = torch.tensor(torch.cuda._get_rng_state_offset())
            return seed_tensor, offset_tensor

    @staticmethod
    def set_torch_state_tensor(seed, offset):
        """Installs a CUDA RNG state assembled from ``seed`` and ``offset`` tensors."""
        # Rng state is [64-bit seed, 64-bit offset]
        byte_views = [
            part.reshape([1]).view(torch.uint8) for part in (seed, offset)
        ]
        torch.cuda.set_rng_state(torch.cat(byte_views))

    @staticmethod
    def set_new_offset(relative_offset):
        """Passes ``relative_offset.item()`` to ``torch.cuda._set_rng_state_offset``."""
        offset_value = relative_offset.item()
        torch.cuda._set_rng_state_offset(offset_value)
1997

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.