# Source: pytorch repository, __init__.py (3031 lines, 81.5 KB)
import contextlib
import itertools
import operator
import weakref
from enum import Enum
from functools import partial, reduce
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

import torch

import torch._prims_common as utils
import torch.library
from torch import sym_float, Tensor, TypedStorage
from torch._C import _get_default_device
from torch._prims.debug_prims import register_debug_prims
from torch._prims.rng_prims import register_rng_prims
from torch._prims_common import (
    Dim,
    DimsSequenceType,
    DimsType,
    IntLike,
    Number,
    NumberType,
    RETURN_TYPE,
    ShapeType,
    StrideType,
    TensorLike,
    TensorLikeType,
    type_to_dtype,
)
from torch._prims_common.wrappers import backwards_not_supported
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.overrides import handle_torch_function, has_torch_function
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
35

36
# Libraries for the "prims" operator namespace. `prim` ("DEF") holds the
# operator schema definitions; the "IMPL" libraries register kernels for the
# CompositeExplicitAutograd, BackendSelect, Autograd and Meta dispatch keys
# respectively (see _make_prim, which fills all of them in).
prim = torch.library.Library("prims", "DEF")
prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd")
prim_backend_select_impl = torch.library.Library("prims", "IMPL", "BackendSelect")
prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd")
prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta")
41

42
# Experimental module containing prototype "primitive" operations.

__all__ = [
    #
    # Common datastructures and helpers
    #
    "RETURN_TYPE",
    #
    # Elementwise unary prims
    #
    "abs",
    "acos",
    "acosh",
    "asin",
    "asinh",
    "atan",
    "atanh",
    "cos",
    "cosh",
    "bessel_i0",
    "bessel_i0e",
    "bessel_i1",
    "bessel_i1e",
    "bessel_j0",
    "bessel_j1",
    "bitwise_not",
    "cbrt",
    "ceil",
    "conj_physical",
    "digamma",
    "erf",
    "erf_inv",
    "erfc",
    "erfcx",
    "exp",
    "expm1",
    "exp2",
    "fill",
    "floor",
    "imag",
    "isfinite",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "ndtri",
    "neg",
    "real",
    "reciprocal",
    "round",
    "rsqrt",
    "sign",
    "signbit",
    "sin",
    "sinh",
    "spherical_bessel_j0",
    "sqrt",
    "tan",
    "tanh",
    "trunc",
    #
    # Elementwise binary prims
    #
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    # 'complex',  # needs custom meta
    "div",
    "eq",
    "fmax",
    "fmin",
    "fmod",
    "frexp",
    "gcd",
    "ge",
    "gt",
    "hypot",
    "igamma",
    "igammac",
    "le",
    "lt",
    "maximum",
    "minimum",
    "mul",
    "ne",
    "nextafter",
    "pow",
    "remainder",
    "shift_left",
    "shift_right_arithmetic",
    "shift_right_logical",  # not implemented
    "sub",
    "zeta",
    #
    # View prims
    #
    "as_strided",
    "broadcast_in_dim",
    "collapse_view",
    "conj",
    "expand_dims",
    "slice",
    "slice_in_dim",  # implemented using slice -- make this a ref?
    "split_dim",
    "squeeze",
    "transpose",
    "view_of",
    "view_element_type",
    #
    # Functionalized view mutations
    #
    "as_strided_scatter",
    #
    # Shape prims
    #
    "collapse",
    "cat",
    "reshape",
    "rev",
    #
    # Conditional prims
    #
    "where",
    #
    # Data conversion and movement prims
    #
    "clone",
    "convert_element_type",
    "device_put",
    "item",
    "maximum_value",
    "minimum_value",
    "copy_strided",
    #
    # Inplace prims
    #
    "copy_to",
    "resize",
    # "_set",  # Commented out, see note below
    #
    # Reduction prims
    #
    "amax",
    "amin",
    "prod",
    "sum",
    "xor_sum",
    "var",
    #
    # Tensor Creation Prims
    #
    "empty_strided",
    "empty_permuted",
    "scalar_tensor",
    "iota",
    #
    # Linear algebra (linalg) Prims
    #
    "svd",
    #
    # Randomness Prims
    #
    "normal",
    "_uniform_helper",
    #
    # FFT prims
    #
    "fft_r2c",
    "fft_c2c",
    "fft_c2r",
]
216

217

218
def TensorMeta(
    tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
    *,
    shape: Optional[ShapeType] = None,
    strides: Optional[StrideType] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
):
    """
    Returns an uninitialized tensor (``torch.empty_strided``) carrying the
    requested metadata.

    Metadata is first inferred from ``tensorlike``. Any keyword argument that
    is not ``None`` then overrides the inferred value:

    * a Number gives shape ``()``, strides ``()``, dtype
      ``type_to_dtype(type(tensorlike))`` and the CPU device. ``shape`` and
      ``strides``, if passed, must be empty sequences.
    * a ``torch.Tensor`` gives its own shape, strides, dtype and device.
    * ``None`` requires ``shape``, ``strides``, ``dtype`` and ``device`` to
      all be given explicitly.

    A string ``device`` is converted with ``torch.device``. The preconditions
    above are checked with ``assert``.
    """
    if isinstance(tensorlike, Number):
        assert not shape and (shape is None or isinstance(shape, Sequence))
        assert not strides and (strides is None or isinstance(strides, Sequence))
        inferred_shape: Tuple[int, ...] = ()
        inferred_strides: Tuple[int, ...] = ()
        inferred_dtype = type_to_dtype(type(tensorlike))
        inferred_device = torch.device("cpu")
        # TODO: This looks wrong, a number that is wrapped into a tensor
        # needs to behave differently than a scalar tensor for type
        # promotion purposes
    elif tensorlike is not None:
        assert isinstance(tensorlike, torch.Tensor)
        inferred_shape = tuple(tensorlike.shape)
        inferred_strides = tuple(tensorlike.stride())
        inferred_dtype = tensorlike.dtype
        inferred_device = tensorlike.device
    else:
        # If no tensorlike "example" is given then all metadata
        # must be provided explicitly
        assert shape is not None
        assert strides is not None
        assert dtype is not None
        assert device is not None

    # Each inferred_* name is only read when the matching argument is None,
    # which the else-branch above rules out, so it is always bound when used.
    shape = inferred_shape if shape is None else tuple(shape)  # type: ignore[possibly-undefined]
    strides = inferred_strides if strides is None else tuple(strides)  # type: ignore[possibly-undefined]
    dtype = inferred_dtype if dtype is None else dtype  # type: ignore[possibly-undefined]
    device = inferred_device if device is None else device  # type: ignore[possibly-undefined]

    if isinstance(device, str):
        device = torch.device(device)

    return torch.empty_strided(shape, strides, dtype=dtype, device=device)
259

260

261
def _make_prim(
    *,
    schema: str,
    return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]],
    meta: Callable,
    impl_aten: Callable,
    doc: str,
    tags: Optional[Sequence[torch.Tag]] = None,
):
    """
    Creates a primitive operation.

    Defines ``schema`` in the ``prims`` library (tagged
    ``torch.Tag.pt2_compliant_tag``) and registers kernels for it:

    * CompositeExplicitAutograd: runs ``meta`` to validate the inputs, then
      returns ``impl_aten``'s result.
    * Autograd: the op wrapped with ``backwards_not_supported``.
    * Meta: ``meta``.
    * BackendSelect: registered only when no schema argument has a
      tensor-containing type, or for ``prims.device_put``. It calls ``meta``
      when a ``torch.device`` of type ``"meta"`` is passed (as the ``device``
      kwarg or positionally) and the CompositeExplicitAutograd kernel otherwise.

    Args:
        schema: Operator schema string. The text before the first ``(`` is
            used as the op name.
        return_type: Stored as ``return_type`` on the op packet and overload.
        meta: Meta function; must accept the same arguments as the op.
        impl_aten: Eager implementation; must accept the same arguments.
        doc: Set as ``__doc__`` of the op packet and its default overload.
        tags: If non-empty, assigned to the default overload's ``_tags``.

    Returns:
        The default overload, ``torch.ops.prims.<name>.default``. The packet
        and overload both get ``schema``, ``prim_impl``, ``prim_meta_impl``
        and ``impl_aten`` attributes.
    """

    prim.define(schema, tags=torch.Tag.pt2_compliant_tag)

    def _prim_impl(*args, **kwargs):
        # always run the meta function because aten implementation will
        # typically accept more inputs (e.g., it will do promotion and
        # broadcasting) which we want to reject
        meta(*args, **kwargs)
        return impl_aten(*args, **kwargs)

    # Right now prims don't support autograd (we can and should add an
    # argument that provides an implementation for backward here.)  Because we
    # don't have derivative formulas, we must setup a custom autograd function
    # that raises an error if backwards is invoked
    def _autograd_impl(*args, **kwargs):
        return backwards_not_supported(_prim)(*args, **kwargs)

    def _backend_select_impl(*args, **kwargs):
        # Ops without tensor inputs cannot pick a backend from their
        # arguments, so route explicit "meta" device requests to `meta`.
        if kwargs.get("device") and kwargs["device"].type == "meta":
            return meta(*args, **kwargs)
        if any(isinstance(x, torch.device) and x.type == "meta" for x in args):
            return meta(*args, **kwargs)
        else:
            return _prim_impl(*args, **kwargs)

    name = schema.split("(")[0]
    prim_impl.impl(name, _prim_impl)
    prim_autograd_impl.impl(name, _autograd_impl)
    prim_meta_impl.impl(name, meta)

    _prim_packet = getattr(torch._ops.ops.prims, name)
    _prim = _prim_packet.default
    if tags:
        _prim._tags = tags

    from torch._subclasses.fake_tensor import contains_tensor_types

    if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments) or str(
        _prim
    ) in [
        # See https://github.com/pytorch/pytorch/issues/103532
        "prims.device_put.default"
    ]:
        prim_backend_select_impl.impl(name, _backend_select_impl)

    for p in (_prim_packet, _prim):
        p.__doc__ = doc
        p.return_type = return_type  # type: ignore[attr-defined]

        p.schema = schema
        p.prim_impl = _prim_impl
        p.prim_meta_impl = meta
        p.impl_aten = impl_aten

    return _prim
329

330

331
class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
    """
    How an elementwise prim's meta function derives the output dtype from the
    input dtype (applied by _prim_elementwise_meta when a tensor input exists).

    DEFAULT: keep the input dtype.
    INT_TO_FLOAT: integer and bool inputs produce the default floating dtype.
    ALWAYS_BOOL: the output is always ``torch.bool``.
    COMPLEX_TO_FLOAT: complex inputs produce the corresponding real dtype.
    """

    # NOTE: the values are one-element tuples (trailing commas) and there is
    # no member with value (1,); they are kept unchanged in case anything
    # compares against ``.value``.
    DEFAULT = (0,)
    INT_TO_FLOAT = (2,)
    ALWAYS_BOOL = (3,)
    COMPLEX_TO_FLOAT = (4,)
336

337

338
# TODO: implement dtype validation here, too, or on the corresponding refs
def _prim_elementwise_meta(
    *args,
    type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
    args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
) -> FakeTensor:
    """
    Meta function shared by the elementwise prims.

    Validation:
      * ``args`` must have the same dtype (``utils.check_same_dtype``).
      * ``args_with_fixed_dtypes`` followed by ``args`` must share device and
        shape (checked with ``allow_cpu_scalar_tensors=True``).

    If at least one tensor argument is present, the result is an empty tensor
    with the common shape and the physical layout given by
    ``utils.compute_elementwise_output_logical_to_physical_perm``, and:
      * device: the first tensor that is not a CPU scalar tensor, otherwise
        the first CPU scalar tensor.
      * dtype: taken from ``args`` only, never from ``args_with_fixed_dtypes``.
        It is the first non-CPU-scalar tensor's dtype, otherwise the last CPU
        scalar tensor's, otherwise ``type_to_dtype`` of the last Number's
        type. It is then adjusted by ``type_promotion``:
        ALWAYS_BOOL -> bool; INT_TO_FLOAT -> default dtype for integer/bool;
        COMPLEX_TO_FLOAT -> corresponding real dtype for complex; DEFAULT ->
        unchanged.

    If there are only numbers, the result is a 0-d tensor built by
    ``TensorMeta`` from the first number; ``type_promotion`` is not applied
    (see the TODO below). A SymInt/SymFloat number is first converted with
    ``sym_float`` if any argument is a float/SymFloat.

    NOTE(review): an older docstring stated "Stride logic is currently
    incorrect"; the layout now comes from the logical-to-physical permutation
    above. Confirm whether that caveat still applies.
    """

    assert len(args) > 0

    utils.check_same_dtype(*args)

    args_ = list(args)
    if args_with_fixed_dtypes is not None:
        args_ = list(args_with_fixed_dtypes) + args_

    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
    utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)

    l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_)
    shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)

    # Acquires the dtype
    dtype = None
    scalar_type = None
    for arg in args:
        if isinstance(arg, TensorLike):
            if not utils.is_cpu_scalar_tensor(arg):
                dtype = arg.dtype
                break
            else:
                dtype = arg.dtype
        elif isinstance(arg, Number):
            scalar_type = type(arg)

    if dtype is None and scalar_type is not None:
        dtype = utils.type_to_dtype(scalar_type)

    # Acquires the device (if it exists) or number
    device = None
    number = None
    for arg in args_:
        if isinstance(arg, TensorLike):
            if utils.is_cpu_scalar_tensor(arg):
                if device is None:
                    device = arg.device
                # keep going, in case there is a cuda tensor later
            else:
                device = arg.device
                break

        elif isinstance(arg, Number):
            if number is None:
                number = arg

    # NOTE: type promotion behavior here is mostly hidden from tests because
    # references will typically handle the type promotion properly even if this doesn't
    # (but getting it wrong will cause too many casts to be inserted in traces!)
    if device is not None:
        assert dtype is not None
        # DEFAULT (and COMPLEX_TO_FLOAT on a non-complex dtype) keep `dtype`.
        if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
            dtype = torch.bool
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
            if utils.is_integer_dtype(dtype) or utils.is_boolean_dtype(dtype):
                dtype = torch.get_default_dtype()
        elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
            if utils.is_complex_dtype(dtype):
                dtype = utils.corresponding_real_dtype(dtype)

        assert shape is not None
        return torch.empty_permuted(shape, l2p_perm, device=device, dtype=dtype)  # type: ignore[return-value]

    # Number case
    # TODO: fix number type promotion (bool, complex->float)

    # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat)
    seen_float = False
    if isinstance(number, (torch.SymInt, torch.SymFloat)):
        for a in args:
            assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI"
            seen_float = seen_float or isinstance(a, (float, torch.SymFloat))
        if seen_float:
            number = sym_float(number)

    return TensorMeta(number)  # type: ignore[arg-type]
432

433

434
def _complex_only_elementwise_meta(*args, **kwargs):
    """
    Elementwise meta function that first requires ``args[0]`` to have a
    complex dtype (checked with ``torch._check``), then delegates to
    ``_prim_elementwise_meta`` with the same arguments.
    """
    torch._check(
        utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
    )
    return _prim_elementwise_meta(*args, **kwargs)
439

440

441
def _make_elementwise_unary_prim(
    name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
):
    """
    Creates an elementwise unary prim.

    The prim gets the schema ``"<name>(Tensor self) -> Tensor"``, return type
    ``RETURN_TYPE.NEW`` and ``_prim_elementwise_meta`` (bound to
    ``type_promotion``) as its meta function. Remaining keyword arguments
    (e.g. ``impl_aten``, ``doc``) are forwarded to ``_make_prim``.
    """

    return _make_prim(
        schema=f"{name}(Tensor self) -> Tensor",
        meta=partial(_prim_elementwise_meta, type_promotion=type_promotion),
        return_type=RETURN_TYPE.NEW,
        **kwargs,
    )
454

455

456
def _make_elementwise_binary_prim(
    name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
):
    """
    Creates an elementwise binary prim.

    The prim gets the schema ``"<name>(Tensor self, Tensor other) -> Tensor"``,
    return type ``RETURN_TYPE.NEW`` and ``_prim_elementwise_meta`` (bound to
    ``type_promotion``) as its meta function. Remaining keyword arguments
    (e.g. ``impl_aten``, ``doc``) are forwarded to ``_make_prim``.
    """

    return _make_prim(
        schema=f"{name}(Tensor self, Tensor other) -> Tensor",
        meta=partial(_prim_elementwise_meta, type_promotion=type_promotion),
        return_type=RETURN_TYPE.NEW,
        **kwargs,
    )
469

470

471
def _not_impl(*args, **kwargs):
472
    raise NotImplementedError
473

474

475
#
# Elementwise unary operations
#
# Each prim here is built by _make_elementwise_unary_prim from its eager ATen
# implementation (impl_aten) and the type-promotion kind its meta function
# applies. abs uses COMPLEX_TO_FLOAT, so complex inputs yield a real dtype.


abs = _make_elementwise_unary_prim(
    "abs",
    impl_aten=torch.abs,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)

acos = _make_elementwise_unary_prim(
    "acos",
    impl_aten=torch.acos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

acosh = _make_elementwise_unary_prim(
    "acosh",
    impl_aten=torch.acosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asin = _make_elementwise_unary_prim(
    "asin",
    impl_aten=torch.asin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

asinh = _make_elementwise_unary_prim(
    "asinh",
    impl_aten=torch.asinh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atan = _make_elementwise_unary_prim(
    "atan",
    impl_aten=torch.atan,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atanh = _make_elementwise_unary_prim(
    "atanh",
    impl_aten=torch.atanh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cos = _make_elementwise_unary_prim(
    "cos",
    impl_aten=torch.cos,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

cosh = _make_elementwise_unary_prim(
    "cosh",
    impl_aten=torch.cosh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_j0 = _make_elementwise_unary_prim(
    "bessel_j0",
    impl_aten=torch.special.bessel_j0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_j1 = _make_elementwise_unary_prim(
    "bessel_j1",
    impl_aten=torch.special.bessel_j1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i0 = _make_elementwise_unary_prim(
    "bessel_i0",
    impl_aten=torch.i0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i0e = _make_elementwise_unary_prim(
    "bessel_i0e",
    impl_aten=torch.special.i0e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1 = _make_elementwise_unary_prim(
    "bessel_i1",
    impl_aten=torch.special.i1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bessel_i1e = _make_elementwise_unary_prim(
    "bessel_i1e",
    impl_aten=torch.special.i1e,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_not = _make_elementwise_unary_prim(
    "bitwise_not",
    impl_aten=torch.bitwise_not,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
591

592

593
def _cbrt_aten(a: torch.Tensor) -> Tensor:
594
    torch._check(
595
        not a.is_complex(),
596
        lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
597
    )
598
    # Returns the real cubic root of the number.
599
    # Note that if a < 0, pow(a, (1. / 3.)) returns th complex number
600
    # exp(1/3 * log(a)) = exp(1/3 * (log(abs(a)) + pi*i)) = cbrt(abs(a)) * e^{pi/3*i}
601
    # which is a complex number.
602
    # For more info see the section Note in
603
    # https://en.cppreference.com/w/cpp/numeric/math/cbrt
604
    return torch.copysign(torch.pow(a.abs(), 1 / 3), a)
605

606

607
# cbrt's eager implementation is _cbrt_aten (real cube root) rather than a
# torch function.
cbrt = _make_elementwise_unary_prim(
    "cbrt",
    impl_aten=_cbrt_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

ceil = _make_elementwise_unary_prim(
    "ceil",
    impl_aten=torch.ceil,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
620

621

622
def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType:
    """
    Meta function for ``prims.conj_physical``.

    Raises:
        RuntimeError: if ``input`` does not have a complex dtype.

    Returns:
        A ``TensorMeta`` with ``input``'s shape, dtype and device and the
        strides from ``utils.compute_elementwise_output_strides(input)``.
    """
    if not input.dtype.is_complex:
        raise RuntimeError("prims.conj_physical is only defined for complex dtypes")

    strides = utils.compute_elementwise_output_strides(input)
    return TensorMeta(input, strides=strides)
628

629

630
# Built with _make_prim directly, using _conj_physical_meta (which rejects
# non-complex inputs) as its meta function.
conj_physical = _make_prim(
    schema="conj_physical(Tensor self) -> Tensor",
    meta=_conj_physical_meta,
    impl_aten=torch._conj_physical,
    doc="Returns the physical conjugation of a complex tensor",
    return_type=RETURN_TYPE.NEW,
)
637

638

639
def _clone_meta(
640
    input: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
641
) -> TensorLikeType:
642
    if memory_format != torch.preserve_format:
643
        return torch.empty(
644
            input.shape,
645
            dtype=input.dtype,
646
            layout=input.layout,
647
            device=input.device,
648
            memory_format=memory_format,
649
        )
650

651
    # memory_format == torch.preserve_format
652
    strides = utils.compute_elementwise_output_strides(input)
653
    return torch.empty_strided(
654
        input.shape,
655
        strides,
656
        dtype=input.dtype,
657
        layout=input.layout,
658
        device=input.device,
659
    )
660

661

662
# clone uses _make_prim directly because its schema has a keyword-only
# memory_format argument and a dedicated meta function (_clone_meta).
clone = _make_prim(
    schema="clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
    meta=_clone_meta,
    impl_aten=torch.clone,
    doc="Returns the copy of a tensor",
    return_type=RETURN_TYPE.NEW,
)

digamma = _make_elementwise_unary_prim(
    "digamma",
    impl_aten=torch.digamma,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erf = _make_elementwise_unary_prim(
    "erf",
    impl_aten=torch.erf,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erf_inv = _make_elementwise_unary_prim(
    "erf_inv",
    impl_aten=torch.special.erfinv,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erfc = _make_elementwise_unary_prim(
    "erfc",
    impl_aten=torch.special.erfc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

erfcx = _make_elementwise_unary_prim(
    "erfcx",
    impl_aten=torch.special.erfcx,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp = _make_elementwise_unary_prim(
    "exp",
    impl_aten=torch.exp,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

expm1 = _make_elementwise_unary_prim(
    "expm1",
    impl_aten=torch.special.expm1,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

exp2 = _make_elementwise_unary_prim(
    "exp2",
    impl_aten=torch.special.exp2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
725

726

727
def _fill_meta(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    """
    Meta function for ``prims.fill``.

    ``value`` does not affect the output metadata. The result is
    ``_prim_elementwise_meta(a)`` with DEFAULT type promotion, i.e. it follows
    ``a``'s shape, dtype and device.
    """
    return _prim_elementwise_meta(
        a, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    )
731

732

733
# NOTE: fill uses _make_prim directly because it has a value parameter
# (the elementwise unary helper only produces "(Tensor self)" schemas).
fill = _make_prim(
    schema="fill(Tensor self, Scalar value) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_fill_meta,
    impl_aten=torch.fill,
    doc="",
)
741

742
floor = _make_elementwise_unary_prim(
    "floor",
    impl_aten=torch.floor,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# imag and real only accept complex inputs (_complex_only_elementwise_meta)
# and produce the corresponding real dtype (COMPLEX_TO_FLOAT).
imag = _make_prim(
    schema="imag(Tensor self) -> Tensor",
    meta=partial(
        _complex_only_elementwise_meta,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    return_type=RETURN_TYPE.VIEW,
    impl_aten=torch.imag,
    doc="",
)

isfinite = _make_elementwise_unary_prim(
    "isfinite",
    impl_aten=torch.isfinite,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

lgamma = _make_elementwise_unary_prim(
    "lgamma",
    impl_aten=torch.lgamma,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log = _make_elementwise_unary_prim(
    "log",
    impl_aten=torch.log,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log1p = _make_elementwise_unary_prim(
    "log1p",
    impl_aten=torch.log1p,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log2 = _make_elementwise_unary_prim(
    "log2",
    impl_aten=torch.log2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

log10 = _make_elementwise_unary_prim(
    "log10",
    impl_aten=torch.log10,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

real = _make_prim(
    schema="real(Tensor self) -> Tensor",
    meta=partial(
        _complex_only_elementwise_meta,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    return_type=RETURN_TYPE.VIEW,
    impl_aten=torch.real,
    doc="",
)

reciprocal = _make_elementwise_unary_prim(
    "reciprocal",
    impl_aten=torch.reciprocal,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

ndtri = _make_elementwise_unary_prim(
    "ndtri",
    impl_aten=torch.special.ndtri,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

neg = _make_elementwise_unary_prim(
    "neg",
    impl_aten=torch.neg,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

round = _make_elementwise_unary_prim(
    "round",
    impl_aten=torch.round,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

rsqrt = _make_elementwise_unary_prim(
    "rsqrt",
    impl_aten=torch.rsqrt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sign = _make_elementwise_unary_prim(
    "sign",
    impl_aten=torch.sign,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# torch.signbit returns a bool tensor, so the meta function must report bool
# (like isfinite) or it disagrees with the eager result.
signbit = _make_elementwise_unary_prim(
    "signbit",
    impl_aten=torch.signbit,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

sin = _make_elementwise_unary_prim(
    "sin",
    impl_aten=torch.sin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sinh = _make_elementwise_unary_prim(
    "sinh",
    impl_aten=torch.sinh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

spherical_bessel_j0 = _make_elementwise_unary_prim(
    "spherical_bessel_j0",
    impl_aten=torch.special.spherical_bessel_j0,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

sqrt = _make_elementwise_unary_prim(
    "sqrt",
    impl_aten=torch.sqrt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

tan = _make_elementwise_unary_prim(
    "tan",
    impl_aten=torch.tan,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

tanh = _make_elementwise_unary_prim(
    "tanh",
    impl_aten=torch.tanh,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

trunc = _make_elementwise_unary_prim(
    "trunc",
    impl_aten=torch.trunc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
910

911
#
# Elementwise binary operations
#
# Each prim here is built by _make_elementwise_binary_prim from its eager ATen
# implementation (impl_aten) and the type-promotion kind its meta applies.

add = _make_elementwise_binary_prim(
    name="add",
    impl_aten=torch.add,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

atan2 = _make_elementwise_binary_prim(
    name="atan2",
    impl_aten=torch.atan2,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_and = _make_elementwise_binary_prim(
    "bitwise_and",
    impl_aten=torch.bitwise_and,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_or = _make_elementwise_binary_prim(
    "bitwise_or",
    impl_aten=torch.bitwise_or,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

bitwise_xor = _make_elementwise_binary_prim(
    "bitwise_xor",
    impl_aten=torch.bitwise_xor,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

# TODO: complex needs a special meta to account for its float -> complex behavior
# complex = _make_elementwise_binary_prim(
#   impl_aten=torch.complex,
#   doc="",
# )
955

956

957
# div prim performs truncation division on integer inputs
958
#   and true division for floating and complex inputs
959
def _div_aten(a, b):
960
    is_integral = isinstance(a, (bool, int, torch.SymInt)) or (
961
        isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
962
    )
963

964
    if is_integral:
965
        return torch.div(a, b, rounding_mode="trunc")
966
    else:
967
        return torch.true_divide(a, b)
968

969

970
div = _make_elementwise_binary_prim(
    "div",
    impl_aten=_div_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

eq = _make_elementwise_binary_prim(
    "eq",
    impl_aten=torch.eq,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

fmax = _make_elementwise_binary_prim(
    "fmax",
    impl_aten=torch.fmax,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

fmin = _make_elementwise_binary_prim(
    "fmin",
    impl_aten=torch.fmin,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

fmod = _make_elementwise_binary_prim(
    "fmod",
    impl_aten=torch.fmod,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

gcd = _make_elementwise_binary_prim(
    "gcd",
    impl_aten=torch.gcd,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

ge = _make_elementwise_binary_prim(
    "ge",
    impl_aten=torch.ge,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

gt = _make_elementwise_binary_prim(
    "gt",
    impl_aten=torch.gt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

hypot = _make_elementwise_binary_prim(
    "hypot",
    impl_aten=torch.hypot,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

igamma = _make_elementwise_binary_prim(
    "igamma",
    impl_aten=torch.special.gammainc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

igammac = _make_elementwise_binary_prim(
    "igammac",
    impl_aten=torch.special.gammaincc,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)

le = _make_elementwise_binary_prim(
    "le",
    impl_aten=torch.le,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)

lt = _make_elementwise_binary_prim(
    "lt",
    impl_aten=torch.lt,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
1062

1063

1064
# Note: the following impls are because torch.maximum and torch.minimum do not support scalar inputs
1065
def _maximum_aten(
1066
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
1067
) -> TensorLikeType:
1068
    if isinstance(a, TensorLike) and isinstance(b, Number):
1069
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
1070
    elif isinstance(b, TensorLike) and isinstance(a, Number):
1071
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
1072

1073
    return torch.maximum(a, b)  # type: ignore[arg-type]
1074

1075

1076
# Uses _maximum_aten, which accepts a Python scalar for either operand.
maximum = _make_elementwise_binary_prim(
    "maximum",
    impl_aten=_maximum_aten,
    doc="",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
)
1082

1083

1084
def _minimum_aten(
1085
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
1086
) -> TensorLikeType:
1087
    if isinstance(a, TensorLike) and isinstance(b, Number):
1088
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
1089
    elif isinstance(b, TensorLike) and isinstance(a, Number):
1090
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
1091

1092
    return torch.minimum(a, b)  # type: ignore[arg-type]
1093

1094

1095
# Uses _minimum_aten so a lone Python-number operand is accepted.
minimum = _make_elementwise_binary_prim(
    "minimum",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=_minimum_aten,
    doc="",
)
1101

1102
# Arithmetic, comparison, bit-shift and special-function binary prims.
mul = _make_elementwise_binary_prim(
    "mul",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.mul,
    doc="",
)

ne = _make_elementwise_binary_prim(
    "ne",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    impl_aten=torch.ne,
    doc="",
)

nextafter = _make_elementwise_binary_prim(
    "nextafter",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.nextafter,
    doc="",
)

pow = _make_elementwise_binary_prim(
    "pow",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.pow,
    doc="",
)

remainder = _make_elementwise_binary_prim(
    "remainder",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.remainder,
    doc="",
)

# Shift prims: left and arithmetic-right shifts map to torch.bitwise_* ops.
shift_left = _make_elementwise_binary_prim(
    "shift_left",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.bitwise_left_shift,
    doc="",
)

shift_right_arithmetic = _make_elementwise_binary_prim(
    "shift_right_arithmetic",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.bitwise_right_shift,
    doc="",
)

# No implementation is wired up for the logical right shift; it is bound to
# the _not_impl placeholder defined elsewhere in this module.
shift_right_logical = _not_impl

sub = _make_elementwise_binary_prim(
    "sub",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.sub,
    doc="",
)

zeta = _make_elementwise_binary_prim(
    "zeta",
    type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
    impl_aten=torch.special.zeta,
    doc="",
)
1167

1168

1169
#
1170
# View operations
1171
def _as_strided_meta(
1172
    a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int
1173
) -> TensorLikeType:
1174
    assert len(size) == len(stride)
1175
    assert storage_offset >= 0
1176
    utils.validate_strides(stride)
1177
    utils.validate_shape(size)
1178

1179
    if reduce(operator.mul, size) == 0:
1180
        # NOTE: This special case is to avoid having to acquire the storage below
1181
        # as_strided to shapes with no elements are trivially valid, so it's OK
1182
        pass
1183
    elif isinstance(a, torch.Tensor):
1184
        utils.check_in_bounds_for_storage(
1185
            a._typed_storage(), size, stride, storage_offset
1186
        )
1187

1188
    return torch.as_strided(a, size, stride, storage_offset)
1189

1190

1191
def _as_strided_aten(
1192
    a: Tensor, size: ShapeType, stride: StrideType, storage_offset: int
1193
) -> Tensor:
1194
    return torch.as_strided(a, size, stride, storage_offset)
1195

1196

1197
_as_strided_doc = """
    Creates a view of the tensor with the given shape (size), strides (stride) and
    storage offset (storage_offset).
"""

# View prim: _as_strided_meta validates the arguments, _as_strided_aten
# performs the eager as_strided call.
as_strided = _make_prim(
    schema="as_strided(Tensor(a!) a, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor(a!)",
    return_type=RETURN_TYPE.VIEW,
    meta=_as_strided_meta,
    impl_aten=_as_strided_aten,
    doc=_as_strided_doc,
)
1209

1210

1211
def _broadcast_in_dim_meta(
    a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
):
    """Meta function for the broadcast_in_dim prim.

    ``broadcast_dimensions[i]`` is the position in ``shape`` that dimension
    ``i`` of ``a`` maps to. The positions must be strictly increasing and
    within ``len(shape)``. Each mapped dimension of ``a`` must have length one
    or equal the target length. Returns an ``as_strided`` view of ``a`` with
    ``shape`` and computed strides, keeping ``a``'s storage offset.
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    # Type checks
    assert isinstance(a, TensorLike)
    assert isinstance(shape, Sequence)
    assert isinstance(broadcast_dimensions, Sequence)

    # every dimension must be accounted for
    assert a.ndim == len(broadcast_dimensions)

    # broadcast shape must have weakly more dimensions
    assert len(shape) >= a.ndim

    # broadcast_dimensions must be an ascending sequence
    # (no relative reordering of dims) of integers and
    # each dimension must be within the new shape
    def _greater_than_reduce(acc, x):
        assert isinstance(x, Dim)
        assert x > acc
        assert x < len(shape)

        return x

    # reduce() is run only for the assertions above; the initial -1 lets a
    # first entry of 0 pass the ``x > acc`` check.
    reduce(_greater_than_reduce, broadcast_dimensions, -1)

    # shape must be broadcastable to
    for idx, new_idx in enumerate(broadcast_dimensions):
        # Length-one dims of ``a`` may broadcast to any length; every other
        # mapped dim must equal its target length.
        if not guard_size_oblivious(a.shape[idx] == 1):
            torch._check(
                a.shape[idx] == shape[new_idx],
                lambda: f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}",
            )

    new_strides = []
    # Index of the next dimension of ``a`` that has not yet been placed.
    original_idx = 0
    for idx in range(len(shape)):
        if idx in broadcast_dimensions:
            # Assigns a stride of zero to dimensions
            # which were actually broadcast
            if guard_size_oblivious(a.shape[original_idx] != shape[idx]):
                new_strides.append(0)
            else:
                new_strides.append(a.stride()[original_idx])
            original_idx = original_idx + 1
        else:
            # A newly added dimension. Any length other than one is a pure
            # broadcast (stride 0). For a length-one new dim the stride is never
            # stepped over. It is set to 1 when every dim of ``a`` is already
            # placed, otherwise to stride * size of the next unplaced dim of ``a``.
            if guard_size_oblivious(shape[idx] != 1):
                new_strides.append(0)
            elif original_idx == a.ndim:
                new_strides.append(1)
            else:
                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

    return a.as_strided(shape, new_strides, a.storage_offset())
1267

1268

1269
def _broadcast_in_dim_aten(a, shape, broadcast_dimensions):
1270
    s = list(shape)
1271
    for broadcast_dimension in broadcast_dimensions:
1272
        s[broadcast_dimension] = -1
1273

1274
    v = a
1275
    for idx, x in enumerate(s):
1276
        if x != -1:
1277
            v = v.unsqueeze(idx)
1278

1279
    return v.expand(shape)
1280

1281

1282
_broadcast_in_dim_doc = """
  Creates a view of a with the specified shape.

  Allows adding dimensions of any length and broadcasting
  dimensions of length one in a to any length.

  The location of the broadcast dimensions must be specified
  using the broadcast_dimensions argument. Changing the
  relative order of dimensions is not supported.
  """

# View prim: _broadcast_in_dim_meta computes the strided view;
# _broadcast_in_dim_aten implements it with unsqueeze + expand.
broadcast_in_dim = _make_prim(
    schema="broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)",
    return_type=RETURN_TYPE.VIEW,
    meta=_broadcast_in_dim_meta,
    impl_aten=_broadcast_in_dim_aten,
    doc=_broadcast_in_dim_doc,
)
1300

1301

1302
def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:
1303
    # Special-case for zero dimensional tensors
1304
    ndim = max(1, a.dim())
1305
    utils.validate_idx(ndim, start)
1306
    utils.validate_idx(ndim, end)
1307

1308
    # Verifies end is strictly greater than start
1309
    # (Collapse requires a non-empty interval)
1310
    torch._check_value(
1311
        end >= start,
1312
        lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!",
1313
    )
1314

1315

1316
def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]:
1317
    """
1318
    Returns the shape of a with dims in [start, end) merged into a single dimension.
1319
    """
1320
    # Special-case for zero dimensional tensors
1321
    shape = (1,) if len(shape) == 0 else tuple(shape)
1322

1323
    dim_length = 1
1324
    for s in shape[start : end + 1]:
1325
        dim_length = dim_length * s
1326

1327
    return shape[0:start] + (dim_length,) + shape[end + 1 :]
1328

1329

1330
def _collapse_view_helper(
    a: TensorLikeType, start: int, end: int
) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
    """Compute the shape and strides of a view of ``a`` that merges dims
    ``start`` through ``end`` (inclusive) into one dimension.

    Returns ``(None, None)`` when the dims cannot be expressed as a single
    strided dimension, i.e. no such view exists. A 0-dim tensor is treated as
    shape ``(1,)`` with strides ``(1,)``.
    """
    assert isinstance(a, TensorLike)

    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    _validate_collapse_args(a, start, end)

    # Special-case for zero dimensional tensors
    if a.ndim == 0:
        shape = (1,)
        strides = (1,)
    else:
        shape = a.shape  # type: ignore[assignment]
        strides = a.stride()  # type: ignore[assignment]

    # Nothing to merge: a 0-dim tensor or a single-dimension range.
    if a.ndim == 0 or (end == start):
        return shape, strides

    # Walk from the innermost dim of the range (end) outward, accumulating the
    # merged length and the smallest stride seen.
    length = shape[end]
    stride = strides[end]
    for idx in range(end - 1, start - 1, -1):
        # A zero-length dim anywhere in the range makes the merged dim empty.
        if guard_size_oblivious(shape[idx] == 0) or guard_size_oblivious(
            shape[idx + 1] == 0
        ):
            length = 0
            stride = 0
            break

        # Length-one dims neither change the merged length nor need the
        # nesting check below, so they are skipped.
        if guard_size_oblivious(shape[idx] == 1):
            continue

        length = length * shape[idx]
        stride = min(stride, strides[idx])

        # Adjacent dims must be "nested" (stride[idx] == stride[idx+1] * shape[idx+1]).
        # The check is skipped when the tensor is empty or the next dim has length one.
        if (
            guard_size_oblivious(a.numel() > 0)
            and guard_size_oblivious(shape[idx + 1] != 1)
            and not guard_size_oblivious(
                strides[idx] == strides[idx + 1] * shape[idx + 1]
            )
        ):
            return None, None

    new_shape = shape[:start] + (length,) + shape[end + 1 :]
    new_strides = strides[:start] + (stride,) + strides[end + 1 :]

    # NOTE: when the input has no elements it's restrided as if it were contiguous
    if guard_size_oblivious(a.numel() == 0):
        new_strides = utils.make_contiguous_strides_for(new_shape)

    return new_shape, new_strides
1383

1384

1385
def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
    """Return a view of ``a`` with dims ``start``..``end`` (inclusive) merged.

    Raises:
        ValueError: if the requested dims cannot be merged into a single view.
    """
    new_shape, new_strides = _collapse_view_helper(a, start, end)

    if new_shape is not None:
        assert new_strides is not None
        return a.as_strided(new_shape, new_strides, a.storage_offset())

    raise ValueError("Attempting to view a collapsed tensor, but no such view exists!")
1394

1395

1396
def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor:
    """Eager implementation: ``view`` ``a`` with the collapsed shape."""
    return a.view(_collapsed_shape(a.shape, start, end))
1399

1400

1401
_collapse_view_doc = """
1402
  Creates a view of a with the dimensions between
1403
  start (inclusive) and end (exclusive) merged into a
1404
  single dimension.
1405

1406
  If it's not possible to take such a view then an error
1407
  is thrown. See collapse instead.
1408

1409
  The dimensions can be merged if and only if
1410
  they are all "nested" with each other. That is, they all
1411
  have the property that
1412

1413
  stride[i] = stride[i+1] * shape[i+1]
1414

1415
  for all i in [start, end - 1).
1416
  """
1417

1418
# View prim merging a range of dims; _collapse_view_meta raises when no such
# view exists.
collapse_view = _make_prim(
    schema="collapse_view(Tensor(a) a, int start, int end) -> Tensor(a)",
    return_type=RETURN_TYPE.VIEW,
    meta=_collapse_view_meta,
    impl_aten=_collapse_view_aten,
    doc=_collapse_view_doc,
)
1425

1426

1427
def _conj_meta(a: TensorLikeType) -> TensorLikeType:
1428
    if not a.dtype.is_complex:
1429
        raise RuntimeError("Expected complex dtype in prims.conj")
1430
    out = a.as_strided(a.shape, a.stride(), a.storage_offset())
1431
    torch._C._set_conj(out, not a.is_conj())
1432
    return out
1433

1434

1435
_conj_doc = """
Returns a conjugated view of the original tensor
"""

# View prim flipping the conjugate bit; _conj_meta rejects non-complex dtypes.
conj = _make_prim(
    schema="conj(Tensor(a) a) -> Tensor(a)",
    return_type=RETURN_TYPE.VIEW,
    meta=_conj_meta,
    impl_aten=torch.conj,
    doc=_conj_doc,
)
1446

1447

1448
def expand_dims(
    a: TensorLikeType, dimensions: DimsSequenceType, ndim=None
) -> TensorLikeType:
    """
    Creates a view of a with a.ndim + len(dimensions) dimensions, with new
    dimensions of length one at the dimensions specified by dimensions.

    ``dimensions`` is passed through ``utils.canonicalize_dims`` against
    ``ndim`` when given, or against ``a.ndim`` otherwise (presumably this wraps
    negative indices).

    Raises:
        ValueError: if ``dimensions`` names the same dimension more than once.
    """
    if ndim is not None:
        # TODO: this is only here to support the unsqueeze ref
        dims = sorted(utils.canonicalize_dims(ndim, dimensions))  # type: ignore[arg-type]
    else:
        dims = sorted(utils.canonicalize_dims(a.ndim, dimensions))  # type: ignore[arg-type]
    if len(set(dims)) != len(dims):
        msg = f"Received duplicate dimensions to expand in {str(dimensions)}"
        raise ValueError(msg)

    new_shape = list(a.shape)
    for idx in dims:
        new_shape.insert(idx, 1)

    # Every position that is not a newly inserted dim holds an existing dim of
    # ``a``. Compare against the canonicalized ``dims``, not the raw
    # ``dimensions``. Otherwise a negative entry never matches, and
    # broadcast_dimensions ends up longer than a.ndim.
    inserted = set(dims)
    broadcast_dimensions = [
        idx for idx in range(len(new_shape)) if idx not in inserted
    ]
    return broadcast_in_dim(a, new_shape, broadcast_dimensions)
1472

1473

1474
# Note: saves the Python slice object because we're about to clobber its name with the slice prim
1475
pyslice: Type[slice] = slice  # type: ignore[has-type]
1476

1477

1478
def _slice_meta(
1479
    a: TensorLikeType,
1480
    start_indices: DimsSequenceType,
1481
    limit_indices: DimsSequenceType,
1482
    strides: Optional[StrideType] = None,
1483
) -> TensorLikeType:
1484
    _strides = strides if strides is not None else [1] * len(start_indices)
1485

1486
    if a.ndim != len(start_indices):
1487
        msg = f"Attempting to slice tensor of rank {a.ndim} with start_indices of length {len(start_indices)}!"
1488
        raise ValueError(msg)
1489

1490
    if a.ndim != len(limit_indices):
1491
        msg = f"Attempting to slice tensor of rank {a.ndim} with limit_indices of length {len(limit_indices)}!"
1492
        raise ValueError(msg)
1493

1494
    if a.ndim != len(_strides):
1495
        msg = f"Attempting to slice tensor of rank {a.ndim} with strides of length {len(limit_indices)}!"
1496
        raise ValueError(msg)
1497

1498
    for x, y in zip(start_indices, a.shape):
1499
        if x < 0:
1500
            msg = f"Attempting to slice a tensor with a negative start index of {x}!"
1501
            raise ValueError(msg)
1502
        if x > y:
1503
            msg = (
1504
                f"Attempting to slice a tensor but a start index in {start_indices} is greater than"
1505
                f" the length of its corresponding dimension in shape {a.shape}"
1506
            )
1507
            raise ValueError(msg)
1508

1509
    for x, y, z in zip(limit_indices, a.shape, start_indices):
1510
        if x < 0:
1511
            msg = f"Attempting to slice a tensor with a negative stop index of {x}!"
1512
            raise ValueError(msg)
1513
        if x > y:
1514
            msg = (
1515
                f"Attempting to slice a tensor but a stop index in {limit_indices} is greater than the length of "
1516
                f" its corresponding dimension in shape {a.shape}"
1517
            )
1518
            raise ValueError(msg)
1519
        if x < z:
1520
            msg = (
1521
                f"Attempting to slice a tensor but a start index in {x} is greater than "
1522
                f" its corresponding stop index {z}"
1523
            )
1524

1525
    for x in _strides:
1526
        if x <= 0:
1527
            msg = f"Attempting to slice a tensor with a non-positive step of {x}!"
1528
            raise ValueError(msg)
1529

1530
    new_shape = []
1531
    for x, y, z in zip(start_indices, limit_indices, _strides):
1532
        new_shape.append(1 + (y - x - 1) // z)
1533

1534
    new_strides = []
1535
    for x, y in zip(a.stride(), _strides):
1536
        new_strides.append(x * y)
1537

1538
    return a.as_strided(new_shape, new_strides, a.storage_offset())
1539

1540

1541
def _slice_aten(
    a: Tensor,
    start_indices: DimsSequenceType,
    limit_indices: DimsSequenceType,
    strides: Optional[StrideType] = None,
) -> Tensor:
    """Eager implementation: index ``a`` with one Python slice per dimension."""
    steps = [1] * len(start_indices) if strides is None else strides
    # ``pyslice`` is the builtin ``slice``; the module-level name ``slice`` is
    # rebound to the prim.
    slices = [
        pyslice(start, stop, step)
        for start, stop, step in zip(start_indices, limit_indices, steps)
    ]
    return operator.getitem(a, slices)  # type: ignore[call-overload]
1554

1555

1556
_slice_doc = """
    Creates a view of a "bounding box" within the tensor.

    The bounding box is specified independently in each of the tensor's dimensions.
    start_indices and limit_indices describe the box's boundaries for their corresponding
    dimensions. If strides is specified then they specify the step size between elements
    in their corresponding dimension.

    This operation is analogous to slicing in NumPy, but does not permit slices where
    the stop indices are less than the start indices.
    """

# View prim. Note this rebinds the module-level name ``slice`` (the builtin
# remains reachable as ``pyslice``).
slice = _make_prim(
    schema="slice(Tensor(a) a, SymInt[] start_indices, SymInt[] limit_indices, SymInt[]? strides=None) -> Tensor(a)",
    return_type=RETURN_TYPE.VIEW,
    meta=_slice_meta,
    impl_aten=_slice_aten,
    doc=_slice_doc,
)
1575

1576

1577
def _slice_in_dim_meta(
    a: TensorLikeType,
    start_index: int,
    limit_index: int,
    stride: int = 1,
    axis: int = 0,
) -> TensorLikeType:
    """Validate a single-axis slice and delegate to ``_slice_meta``.

    Builds full-rank start/limit/stride lists that keep every dimension whole
    except ``axis``. On ``axis`` they select ``[start_index, limit_index)``
    with step ``stride``.

    Raises:
        ValueError: on an out-of-range axis or index, ``limit_index <
            start_index``, or a non-positive stride.
    """
    if axis < 0:
        msg = f"slice_in_dim: received a negative axis {axis}"
        raise ValueError(msg)
    if axis >= a.ndim:
        msg = f"slice_in_dim: axis {axis} is greater or equal to the rank {a.ndim} of the tensor"
        raise ValueError(msg)

    if start_index < 0:
        msg = f"slice_in_dim: received a negative start_index {start_index}"
        raise ValueError(msg)

    # Report the offending index and the real dimension length (the previous
    # messages printed the index in the "length" slot).
    if start_index > a.shape[axis]:
        msg = f"slice_in_dim: start_index {start_index} is greater than the length {a.shape[axis]} of dimension {axis}"
        raise ValueError(msg)

    if limit_index > a.shape[axis]:
        msg = f"slice_in_dim: limit_index {limit_index} is greater than the length {a.shape[axis]} of dimension {axis}"
        raise ValueError(msg)

    if limit_index < start_index:
        msg = f"slice_in_dim: received a limit_index {limit_index} less than the start_index {start_index}"
        raise ValueError(msg)

    # A zero stride is also invalid; reject it here, as the message promises.
    if stride <= 0:
        msg = f"slice_in_dim: received a non-positive stride of {stride}!"
        raise ValueError(msg)

    start_indices = [0] * a.ndim
    limit_indices = list(a.shape)
    strides = [1] * a.ndim

    start_indices[axis] = start_index
    limit_indices[axis] = limit_index
    strides[axis] = stride

    return _slice_meta(a, start_indices, limit_indices, strides)
1620

1621

1622
def _slice_in_dim_aten(
    a: Tensor,
    start_index: int,
    limit_index: int,
    stride: int = 1,
    axis: int = 0,
) -> Tensor:
    """Eager implementation: a full-range ``slice`` everywhere except ``axis``."""
    starts, limits, steps = [0] * a.ndim, list(a.shape), [1] * a.ndim
    starts[axis], limits[axis], steps[axis] = start_index, limit_index, stride
    # ``slice`` here is the slice prim, not the builtin.
    return slice(a, starts, limits, steps)
1638

1639

1640
_slice_in_dim_doc = """
    Convenience wrapper for slicing just one dimension using slice.
    """

# TODO: make stride SymInt
slice_in_dim = _make_prim(
    schema="slice_in_dim(Tensor(a) a, SymInt start_index, SymInt limit_index, int stride=1, int axis=0) -> Tensor(a)",
    return_type=RETURN_TYPE.VIEW,
    meta=_slice_in_dim_meta,
    impl_aten=_slice_in_dim_aten,
    doc=_slice_in_dim_doc,
)
1652

1653

1654
def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
1655
    assert isinstance(a, TensorLike)
1656
    utils.validate_idx(a.ndim, dim)
1657
    utils.validate_dim_length(outer_length)
1658

1659
    # Verifies the dim can be split with the specified lhs_length
1660
    inner_length = a.shape[dim] // outer_length
1661

1662
    if (a.shape[dim] % outer_length) != 0:
1663
        msg = "Attempting to split dimension of length {}, but outer length of {} divides it with a remainder!".format(
1664
            a.shape[dim], outer_length
1665
        )
1666
        raise ValueError(msg)
1667

1668
    new_shape: List[int] = []
1669
    new_strides: List[int] = []
1670
    for idx in range(a.ndim):
1671
        if idx == dim:
1672
            new_shape.extend((outer_length, inner_length))
1673
            new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx]))
1674
        else:
1675
            new_shape.append(a.shape[idx])
1676
            new_strides.append(a.stride()[idx])
1677

1678
    return a.as_strided(new_shape, new_strides, a.storage_offset())
1679

1680

1681
def _split_dim_aten(a: Tensor, dim: int, outer_length: int) -> Tensor:
1682
    inner_length = a.shape[dim] // outer_length
1683
    new_shape = a.shape[0:dim] + (outer_length, inner_length) + a.shape[dim + 1 :]
1684

1685
    return a.view(new_shape)
1686

1687

1688
_split_dim_doc = """
1689
  Creates a view of a with the given dimension (of length l) split
1690
  into two dimensions, with the outer of the two having
1691
  length outer_length and the inner of the two having computed
1692
  length inner_length such outer_length * inner_length = l.
1693
  """
1694

1695
# TODO: consider renaming split_dim_view
split_dim = _make_prim(
    schema="split_dim(Tensor(a) a, int dim, SymInt outer_length) -> Tensor(a)",
    return_type=RETURN_TYPE.VIEW,
    meta=_split_dim_meta,
    impl_aten=_split_dim_aten,
    doc=_split_dim_doc,
)
1703

1704

1705
# Note: allows dimensions to be specified redundantly
1706
def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
1707
    assert isinstance(a, TensorLike)
1708

1709
    for idx in dimensions:
1710
        utils.validate_idx(a.ndim, idx)
1711
        assert a.shape[idx] == 1
1712

1713
    new_shape = []
1714
    new_strides = []
1715
    for idx in range(len(a.shape)):
1716
        if idx in dimensions:
1717
            continue
1718

1719
        new_shape.append(a.shape[idx])
1720
        new_strides.append(a.stride()[idx])
1721

1722
    return a.as_strided(new_shape, new_strides, a.storage_offset())
1723

1724

1725
_squeeze_doc = """
  Creates a view of the tensor with the specified dimensions removed.

  The removed dimensions must each have length one.
  """

squeeze = _make_prim(
    schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)",
    return_type=RETURN_TYPE.VIEW,
    meta=_squeeze_meta,
    impl_aten=torch.squeeze,
    doc=_squeeze_doc,
)
1738

1739

1740
def _transpose_meta(a: TensorLikeType, permutation: DimsSequenceType) -> TensorLikeType:
1741
    if a.ndim != len(permutation):
1742
        msg = "Attempting to permute a tensor of rank {}, but received a permutation of length {}!".format(
1743
            a.ndim, len(permutation)
1744
        )
1745
        raise ValueError(msg)
1746

1747
    if not utils.is_valid_permutation(a.ndim, permutation):
1748
        msg = f"Received an invalid permutation, {permutation}!"
1749
        raise ValueError(msg)
1750

1751
    new_shape = [0] * a.ndim
1752
    new_strides = [0] * a.ndim
1753
    for idx, dim in enumerate(permutation):
1754
        new_shape[idx] = a.shape[dim]
1755
        new_strides[idx] = a.stride()[dim]
1756

1757
    return a.as_strided(tuple(new_shape), tuple(new_strides), a.storage_offset())
1758

1759

1760
def _transpose_aten(a: Tensor, permutation: DimsSequenceType) -> Tensor:
1761
    return torch.permute(a, permutation)
1762

1763

1764
_transpose_doc = """
    Creates a view of the tensor with its dimensions permuted.

    The length of the permutation must be the rank of the tensor,
    and each element of the permutation specifies the new order
    for the corresponding dimension.
    """

transpose = _make_prim(
    schema="transpose(Tensor(a) a, int[] permutation) -> Tensor(a)",
    return_type=RETURN_TYPE.VIEW,
    meta=_transpose_meta,
    impl_aten=_transpose_aten,
    doc=_transpose_doc,
)
1779

1780

1781
def _view_of_meta(a: TensorLikeType) -> TensorLikeType:
1782
    return a.as_strided(a.shape, a.stride(), a.storage_offset())
1783

1784

1785
def _view_of_aten(a: Tensor) -> Tensor:
1786
    return a.view(a.shape)
1787

1788

1789
_view_of_doc = """
    Creates a view of the tensor.
    """

view_of = _make_prim(
    schema="view_of(Tensor(a) a) -> Tensor",
    return_type=RETURN_TYPE.VIEW,
    meta=_view_of_meta,
    impl_aten=_view_of_aten,
    doc=_view_of_doc,
)
1800

1801

1802
def _view_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
1803
    return a.view(dtype)
1804

1805

1806
def _view_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
1807
    return a.view(dtype)
1808

1809

1810
_view_element_type_doc = """
    Creates a view of the tensor with a different dtype.
    """

# Note the registered op name is "view_of_dtype" even though the Python
# binding is view_element_type.
view_element_type = _make_prim(
    schema="view_of_dtype(Tensor(a) a, ScalarType dtype) -> Tensor",
    return_type=RETURN_TYPE.VIEW,
    meta=_view_element_type_meta,
    impl_aten=_view_element_type_aten,
    doc=_view_element_type_doc,
)
1821

1822
#
1823
# Functionalized view mutations
1824
#
1825

1826

1827
def _as_strided_scatter_meta(
1828
    input: TensorLikeType,
1829
    src: TensorLikeType,
1830
    size: ShapeType,
1831
    stride: StrideType,
1832
    storage_offset: int,
1833
) -> TensorLikeType:
1834
    utils.validate_shape(size)
1835
    utils.validate_strides(stride)
1836

1837
    required_size = utils.compute_required_storage_length(size, stride, storage_offset)
1838
    torch._check(
1839
        input.numel() >= required_size,
1840
        lambda: (
1841
            f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
1842
            f" and itemsize {input.element_size()} requiring a storage size of "
1843
            f"{required_size * input.element_size()} are out of bounds "
1844
            f"for storage of size {input.numel() * input.element_size()}"
1845
        ),
1846
    )
1847
    torch._check(
1848
        utils.is_same_shape(src.shape, size),
1849
        lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
1850
    )
1851

1852
    return utils.clone_preserve_strides(input)
1853

1854

1855
_as_strided_scatter_doc = """
    Creates a new tensor equivalent to ``out = input.clone()`` after mutation by
    ``out.as_strided(size, stride, storage_offset).copy_(src)``.
"""

# Functional (non-mutating) counterpart of writing through an as_strided view;
# returns a NEW tensor.
as_strided_scatter = _make_prim(
    schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_as_strided_scatter_meta,
    impl_aten=torch.as_strided_scatter,
    doc=_as_strided_scatter_doc,
)
1867

1868

1869
#
1870
# Shape operations
1871
#
1872

1873

1874
def _collapse_meta(a: Tensor, start: int, end: int) -> Tensor:
    """Validate the collapse range and allocate an uninitialized result.

    The result has ``a``'s shape with dims ``start``..``end`` (inclusive)
    merged. It is created with ``a.new_empty``, so it takes ``a``'s dtype and
    device. 0-dim inputs are handled inside the two helpers.
    """
    _validate_collapse_args(a, start, end)
    return a.new_empty(_collapsed_shape(a.shape, start, end))
1879

1880

1881
def _collapse_aten(a: Tensor, start: int, end: int) -> Tensor:
    """Eager implementation: copy ``a`` into a new tensor of the collapsed shape."""
    result = a.new_empty(_collapsed_shape(a.shape, start, end))
    # Copy through a view of the result shaped like ``a``; no_grad keeps the
    # in-place copy_ from being recorded by autograd.
    with torch.no_grad():
        result.view_as(a).copy_(a)
    return result
1887

1888

1889
_collapse_doc = """
Collapse a span of neighboring dimensions into one.

See collapse_view for the corresponding view operation.
"""

# Copying (NEW) counterpart of collapse_view; always succeeds for a valid range.
collapse = _make_prim(
    schema="collapse(Tensor a, int start, int end) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_collapse_meta,
    impl_aten=_collapse_aten,
    doc=_collapse_doc,
)
1901

1902

1903
# TODO: review stride logic
# NB: unlike torch.cat, this is more strict about empty tensors and dim is
# never negative
def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
    """Check that ``tensors`` agree on every dim except ``dim``; describe the result.

    The result is built with ``TensorMeta`` from ``tensors[0]``, using the
    summed length along ``dim`` and contiguous strides.
    """
    # Verifies same shape (except in the concat dimension)
    assert dim >= 0
    reference_shape = tensors[0].shape
    concat_length = 0
    for tensor_idx, tensor in enumerate(tensors):
        assert len(reference_shape) == len(tensor.shape)
        for idx, (common_length, length) in enumerate(
            zip(reference_shape, tensor.shape)
        ):
            if idx == dim:
                concat_length = concat_length + length
                continue
            # torch._check invokes the message lambda immediately on failure,
            # so capturing the loop variables here is safe.
            torch._check(
                length == common_length,
                lambda: f"Sizes of tensors must match except in dimension {dim}. "
                f"Expected {common_length} but got {length} for tensor number "
                f"{tensor_idx} in the list",
            )

    # list() already produces a fresh list; no extra copy is needed.
    new_shape = list(reference_shape)
    new_shape[dim] = concat_length
    return TensorMeta(
        tensors[0],
        shape=new_shape,
        strides=utils.make_contiguous_strides_for(new_shape),
    )
1931

1932

1933
def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor:
1934
    return torch.cat(tensors, dim)
1935

1936

1937
_cat_doc = """
  Concatenates tensors along the specified dimension.

  The tensors' shapes must have the same rank and same length for other dimensions.
  """

cat = _make_prim(
    schema="cat(Tensor[] tensors, int dim) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_cat_meta,
    impl_aten=_cat_aten,
    doc=_cat_doc,
)
1950

1951

1952
def _reshape_meta(a: TensorLikeType, shape: ShapeType):
    """Validate a reshape and describe a contiguous result with ``shape``.

    Raises:
        ValueError: if ``shape`` does not hold exactly ``a.numel()`` elements.
    """
    assert isinstance(a, TensorLike)
    utils.validate_shape(shape)

    # Validates the tensor and the requested shape have the
    # same number of elements. The initializer 1 makes a 0-dim target shape
    # (an empty sequence) count as one element instead of making reduce() raise.
    numel = reduce(operator.mul, shape, 1)
    if numel != a.numel():
        msg = f"Attempting to reshape a tensor with {a.numel()} elements to a shape with {numel} elements!"
        raise ValueError(msg)

    return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))
1964

1965

1966
def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor:
1967
    return a.reshape(shape).contiguous().clone()
1968

1969

1970
_reshape_doc = """
  Creates a contiguous tensor with the specified shape
  containing a copy of the data in a.
  """

# Unlike the view prims above, reshape always returns a NEW tensor.
reshape = _make_prim(
    schema="reshape(Tensor a, SymInt[] shape) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_reshape_meta,
    impl_aten=_reshape_aten,
    doc=_reshape_doc,
)
1981

1982

1983
def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
1984
    utils.validate_dimension_indices(a.ndim, dims)
1985
    return torch.empty_like(a, memory_format=torch.preserve_format)
1986

1987

1988
_rev_doc = """
1989
    Reverses the order of elements along the given dimensions.
1990
    """
1991

1992
rev = _make_prim(
1993
    schema="rev(Tensor a, int[] dims) -> Tensor",
1994
    meta=_rev_meta,
1995
    impl_aten=torch.flip,
1996
    return_type=RETURN_TYPE.NEW,
1997
    doc=_rev_doc,
1998
)
1999

2000
#
2001
# Conditional prims
2002
#
2003

2004

2005
def _where_meta(
2006
    pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType
2007
) -> TensorLikeType:
2008
    return _prim_elementwise_meta(
2009
        a,
2010
        b,
2011
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
2012
        args_with_fixed_dtypes=(pred,),
2013
    )
2014

2015

2016
_where_doc = """
2017
  Selects elements from a and b according to pred.
2018

2019
  Where pred is true the result contains the element from a, and
2020
  where pred is false the result contains the element from b.
2021
  """
2022

2023
where = _make_prim(
2024
    schema="where(Tensor pred, Tensor a, Tensor b) -> Tensor",
2025
    meta=_where_meta,
2026
    impl_aten=torch.where,
2027
    return_type=RETURN_TYPE.NEW,
2028
    doc=_where_doc,
2029
)
2030

2031

2032
#
2033
# Type conversions
2034
#
2035
def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
2036
    # Type checks
2037
    assert isinstance(a, TensorLike)
2038
    assert isinstance(dtype, torch.dtype)
2039

2040
    # dtype conversion preserves dense strides
2041
    if torch._prims_common.is_non_overlapping_and_dense(a):
2042
        strides = a.stride()
2043
    else:
2044
        strides = utils.compute_elementwise_output_strides(a)
2045

2046
    return TensorMeta(a, strides=strides, dtype=dtype)
2047

2048

2049
def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
2050
    # Propagates requires grad when possible
2051
    if not utils.is_grad_dtype(dtype):
2052
        requires_grad = False
2053
    else:
2054
        # TODO: update meta objects so this can be acquired directly
2055
        try:
2056
            requires_grad = a.requires_grad
2057
        except Exception as e:
2058
            requires_grad = False
2059

2060
    result = torch.empty_like(
2061
        a, device=a.device, dtype=dtype, requires_grad=requires_grad
2062
    )
2063
    with torch.no_grad():
2064
        return copy_to(result, a)
2065

2066

2067
_convert_element_type_doc = """
2068
  Creates a copy of a tensor with the given dtype.
2069
  """
2070

2071
convert_element_type = _make_prim(
2072
    schema="convert_element_type(Tensor a, ScalarType dtype) -> Tensor",
2073
    meta=_convert_element_type_meta,
2074
    impl_aten=_convert_element_type_aten,
2075
    return_type=RETURN_TYPE.NEW,
2076
    doc=_convert_element_type_doc,
2077
    tags=(torch.Tag.pointwise,),
2078
)
2079

2080

2081
def _device_put_meta(
2082
    a: TensorLikeType, device: Union[str, torch.device]
2083
) -> TensorLikeType:
2084
    assert isinstance(a, TensorLike)
2085
    assert isinstance(device, (str, torch.device))
2086

2087
    return TensorMeta(a, device=utils.canonicalize_device(device))
2088

2089

2090
def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor:
2091
    return a.to(device)
2092

2093

2094
_device_put_doc = """
2095
  Creates a copy of a tensor on the given device.
2096
  """
2097

2098
device_put = _make_prim(
2099
    schema="device_put(Tensor a, Device device) -> Tensor",
2100
    meta=_device_put_meta,
2101
    impl_aten=_device_put_aten,
2102
    return_type=RETURN_TYPE.NEW,
2103
    doc=_device_put_doc,
2104
)
2105

2106

2107
# NOTE: need to model meta scalars
2108
# See https://github.com/pytorch/pytorch/issues/78070
2109
def _item_meta(a: TensorLikeType) -> FakeTensor:
2110
    number_type = utils.dtype_to_type(a.dtype)
2111
    return TensorMeta(number_type(-1))
2112

2113

2114
_item_doc = """
2115
    Converts a tensor with one element to a Python number.
2116
"""
2117

2118
# TODO: create a new return type for scalars?
2119
# FIXME: currently returns integers for boolean tensors
2120
# https://github.com/pytorch/pytorch/issues/78071
2121
item = _make_prim(
2122
    schema="item(Tensor a) -> Scalar",
2123
    meta=_item_meta,
2124
    impl_aten=torch.Tensor.item,
2125
    return_type=RETURN_TYPE.NEW,
2126
    doc=_item_doc,
2127
)
2128

2129

2130
# NOTE: need to model meta scalars
2131
# See https://github.com/pytorch/pytorch/issues/78070
2132
def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:
2133
    number_type = utils.dtype_to_type(dtype)
2134
    return TensorMeta(number_type(-1))
2135

2136

2137
def _maximum_value_aten(dtype: torch.dtype):
2138
    if dtype == torch.bool:
2139
        return True
2140
    elif dtype.is_complex or dtype.is_floating_point:
2141
        return torch.finfo(dtype).max
2142
    else:
2143
        return torch.iinfo(dtype).max
2144

2145

2146
_maximum_value_doc = """
2147
    Return the maximum finite value for a dtype.
2148
"""
2149

2150
# TODO: create a new return type for scalars?
2151
# FIXME: currently returns integers for boolean tensors
2152
# https://github.com/pytorch/pytorch/issues/78071
2153
maximum_value = _make_prim(
2154
    schema="maximum_value(ScalarType dtype) -> Scalar",
2155
    meta=_maximum_value_meta,
2156
    impl_aten=_maximum_value_aten,
2157
    return_type=RETURN_TYPE.NEW,
2158
    doc=_maximum_value_doc,
2159
)
2160

2161

2162
# NOTE: need to model meta scalars
2163
# See https://github.com/pytorch/pytorch/issues/78070
2164
def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor:
2165
    number_type = utils.dtype_to_type(dtype)
2166
    return TensorMeta(number_type(-1))
2167

2168

2169
def _minimum_value_aten(dtype: torch.dtype):
2170
    if dtype == torch.bool:
2171
        return False
2172
    elif dtype.is_complex or dtype.is_floating_point:
2173
        return torch.finfo(dtype).min
2174
    else:
2175
        return torch.iinfo(dtype).min
2176

2177

2178
_minimum_value_doc = """
2179
    Return the minimum finite value for a dtype.
2180
"""
2181

2182
# TODO: create a new return type for scalars?
2183
# FIXME: currently returns integers for boolean tensors
2184
# https://github.com/pytorch/pytorch/issues/78071
2185
minimum_value = _make_prim(
2186
    schema="minimum_value(ScalarType dtype) -> Scalar",
2187
    meta=_minimum_value_meta,
2188
    impl_aten=_minimum_value_aten,
2189
    return_type=RETURN_TYPE.NEW,
2190
    doc=_minimum_value_doc,
2191
)
2192

2193
#
2194
# Inplace operators
2195
#
2196

2197

2198
def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
2199
    assert isinstance(a, TensorLike)
2200
    assert isinstance(b, TensorLike)
2201

2202
    # Validates the cast is safe
2203
    # TODO: move this as an option on the reference
2204
    # a_typ = utils.dtype_to_type(a.dtype)
2205
    # b_typ = utils.dtype_to_type(b.dtype)
2206
    # if a_typ is not utils.get_higher_type(a_typ, b_typ):
2207
    #     raise RuntimeError(str(b.dtype), " can't be cast safely to ", str(a.dtype), "!")
2208

2209
    # Validates the tensors have the same number of elements
2210
    if a.numel() != b.numel():
2211
        msg = f"Attempting to copy {b.numel()} elements to a tensor with {a.numel()} elements!"
2212
        raise RuntimeError(msg)
2213

2214
    return a
2215

2216

2217
def _copy_to_aten(a: Tensor, b: Tensor) -> Tensor:
2218
    return a.copy_(b)
2219

2220

2221
_copy_to_doc = """
2222
  Copies the data in b to a and returns the modified a.
2223
  """
2224

2225
# TODO: Remove safe casting and implement on reference instead
2226
copy_to = _make_prim(
2227
    schema="copy_to(Tensor(a!) a, Tensor b) -> Tensor(a!)",
2228
    meta=_copy_to_meta,
2229
    impl_aten=_copy_to_aten,
2230
    return_type=RETURN_TYPE.INPLACE,
2231
    doc=_copy_to_doc,
2232
)
2233

2234

2235
def _copy_strided_meta(a: TensorLikeType, stride: ShapeType):
2236
    assert isinstance(a, TensorLike)
2237
    return torch.empty_strided(
2238
        a.shape,
2239
        stride,
2240
        dtype=a.dtype,
2241
        layout=a.layout,
2242
        device=a.device,
2243
        requires_grad=a.requires_grad,
2244
    )
2245

2246

2247
def _copy_strided_aten(a: Tensor, stride: ShapeType) -> Tensor:
2248
    out = torch.empty_strided(
2249
        a.size(),
2250
        stride=stride,
2251
        dtype=a.dtype,
2252
        layout=a.layout,
2253
        device=a.device,
2254
        requires_grad=a.requires_grad,
2255
    )
2256
    out.copy_(a)
2257
    return out
2258

2259

2260
_copy_strided_doc = """
2261
  Copies the data in a to a new tensor, the new tensor has same shape with a size, but has different stride.
2262
  """
2263

2264

2265
copy_strided = _make_prim(
2266
    schema="copy_strided(Tensor a, SymInt[] stride) -> Tensor",
2267
    meta=_copy_strided_meta,
2268
    impl_aten=_copy_strided_aten,
2269
    return_type=RETURN_TYPE.NEW,
2270
    doc=_copy_strided_doc,
2271
)
2272

2273

2274
def _resize_meta(a: TensorLikeType, shape: ShapeType):
2275
    return a.resize_(shape)
2276

2277

2278
def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor:
2279
    return a.resize_(shape)
2280

2281

2282
_resize_doc = """
2283
  Gives a tensor with no elements a new shape, returning the modified tensor.
2284

2285
  The tensor's strides are contiguous and its values are unitialized.
2286
  """
2287

2288
# TODO: review support arbitrary resizes
2289
resize = _make_prim(
2290
    schema="resize(Tensor(a!) a, SymInt[] shape) -> Tensor(a!)",
2291
    meta=_resize_meta,
2292
    impl_aten=_resize_aten,
2293
    return_type=RETURN_TYPE.INPLACE,
2294
    doc=_resize_doc,
2295
)
2296

2297

2298
def _reduction_meta(inp, dims, *, output_dtype=None):
2299
    """
2300
    Meta function for single output reduction operations
2301
    Stride logic is incorrect
2302
    """
2303
    assert isinstance(inp, TensorLike)
2304
    if output_dtype is None:
2305
        output_dtype = inp.dtype
2306
    output_shape = utils.compute_reduction_output_shape(inp.shape, dims)
2307
    return TensorMeta(
2308
        shape=output_shape,
2309
        strides=utils.make_contiguous_strides_for(output_shape),
2310
        dtype=output_dtype,
2311
        device=inp.device,
2312
    )
2313

2314

2315
def _var_reduction_meta(inp, dims, *, correction):
2316
    if utils.is_complex_dtype(inp.dtype):
2317
        output_dtype = utils.corresponding_real_dtype(inp.dtype)
2318
    else:
2319
        output_dtype = inp.dtype
2320
    return _reduction_meta(inp, dims, output_dtype=output_dtype)
2321

2322

2323
_sum_doc = """
2324
    Computes the sum of elements in the input tensor over the list of dimensions
2325
    specified in the dim argument
2326
    """
2327
_xor_sum_doc = """
2328
    Computes the xor sum of elements in the input tensor over the list of dimensions
2329
    specified in the dim argument
2330
    """
2331
_prod_doc = """
2332
    Computes the product of elements in the input tensor over the list of dimensions
2333
    specified in the dim argument
2334
    """
2335
_amax_doc = """
2336
    Computes the maximum value of elements in the input tensor over the list of dimensions
2337
    specified in the dim argument
2338
    """
2339
_amin_doc = """
2340
    Computes the minimum value of elements in the input tensor over the list of dimensions
2341
    specified in the dim argument
2342
    """
2343
_var_doc = """
2344
    Computes the biased variance of x over the list of dimensions specified in the dim argument
2345
    """
2346

2347

2348
def _make_reduction_prim(name: str, impl_aten, doc):
2349
    """Creates a reduction prim."""
2350
    return _make_prim(
2351
        schema=f"{name}(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor",
2352
        meta=_reduction_meta,
2353
        impl_aten=impl_aten,
2354
        return_type=RETURN_TYPE.NEW,
2355
        doc=doc,
2356
    )
2357

2358

2359
def _make_var_reduction_prim(name: str, impl_aten, doc):
2360
    """Creates a reduction prim."""
2361
    return _make_prim(
2362
        schema=f"{name}(Tensor inp, int[]? dims, *, float correction, ScalarType? output_dtype=None) -> Tensor",
2363
        meta=_var_reduction_meta,
2364
        impl_aten=impl_aten,
2365
        return_type=RETURN_TYPE.NEW,
2366
        doc=doc,
2367
    )
2368

2369

2370
sum = _make_reduction_prim(
2371
    name="sum",
2372
    impl_aten=torch.sum,
2373
    doc=_sum_doc,
2374
)
2375

2376

2377
def _xor_sum_aten(
2378
    inp: TensorLikeType,
2379
    dims: Optional[DimsSequenceType],
2380
    *,
2381
    dtype: Optional[torch.dtype] = None,
2382
) -> Tensor:
2383
    raise NotImplementedError("xor_sum only implemented with inductor")
2384

2385

2386
xor_sum = _make_reduction_prim(
2387
    name="xor_sum",
2388
    impl_aten=_xor_sum_aten,
2389
    doc=_xor_sum_doc,
2390
)
2391

2392

2393
def _prod_aten(
2394
    inp: TensorLikeType,
2395
    dims: Optional[DimsSequenceType],
2396
    *,
2397
    dtype: Optional[torch.dtype] = None,
2398
) -> Tensor:
2399
    if dims is not None:
2400
        for d in sorted(dims, reverse=True):
2401
            assert d >= 0
2402
            inp = torch.prod(inp, d, dtype=dtype)
2403
        return inp
2404
    else:
2405
        return torch.prod(inp, dims, dtype=dtype)
2406

2407

2408
prod = _make_reduction_prim(
2409
    name="prod",
2410
    impl_aten=_prod_aten,
2411
    doc=_prod_doc,
2412
)
2413

2414
var = _make_var_reduction_prim(
2415
    name="var",
2416
    impl_aten=torch.var,
2417
    doc=_var_doc,
2418
)
2419

2420
amax = _make_reduction_prim(
2421
    name="amax",
2422
    impl_aten=torch.amax,
2423
    doc=_amax_doc,
2424
)
2425

2426
amin = _make_reduction_prim(
2427
    name="amin",
2428
    impl_aten=torch.amin,
2429
    doc=_amin_doc,
2430
)
2431

2432

2433
_iota_doc = """
2434
    Constructs a 1-D tensor t where ``t[i] == start + i * step``.
2435
"""
2436

2437

2438
# TODO: layout, pin_memory, memory_format
2439
# TODO: model requires_grad on TensorMeta
2440
def _iota_meta(
2441
    length: int,
2442
    *,
2443
    start: int,
2444
    step: int,
2445
    dtype: torch.dtype,
2446
    device: torch.device,
2447
    requires_grad: bool,
2448
) -> TensorLikeType:
2449
    torch._check(
2450
        utils.is_integer_dtype(dtype),
2451
        lambda: "prims.iota only supports integer dtypes",
2452
    )
2453
    torch._check(step != 0, lambda: "step must be nonzero")
2454
    return torch.empty(
2455
        length,
2456
        dtype=dtype,
2457
        device=device,
2458
        requires_grad=requires_grad,
2459
    )
2460

2461

2462
def _iota_aten(
2463
    length: int,
2464
    *,
2465
    start: int,
2466
    step: int,
2467
    dtype: torch.dtype,
2468
    device: torch.device,
2469
    requires_grad: bool,
2470
) -> TensorLikeType:
2471
    end = start + length * step
2472
    return torch.arange(
2473
        start, end, step, dtype=dtype, device=device, requires_grad=requires_grad
2474
    )
2475

2476

2477
iota = _make_prim(
2478
    schema="iota(SymInt length, *, SymInt start, SymInt step, ScalarType dtype, Device device, bool requires_grad) -> Tensor",  # noqa: B950
2479
    return_type=RETURN_TYPE.NEW,
2480
    meta=_iota_meta,
2481
    impl_aten=_iota_aten,
2482
    doc=_iota_doc,
2483
)
2484

2485

2486
# TODO: layout, pin_memory, memory_format
2487
# TODO: model requires_grad on TensorMeta
2488
def _empty_meta(
2489
    shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
2490
) -> TensorLikeType:
2491
    strides = utils.make_contiguous_strides_for(shape)
2492
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
2493

2494

2495
def _empty_aten(
2496
    shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
2497
) -> Tensor:
2498
    return torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
2499

2500

2501
_empty_doc = """
2502
    Creates a tensor with uninitialized values and the specified shape, dtype, and device.
2503
"""
2504

2505
empty = _make_prim(
2506
    schema="empty(SymInt[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
2507
    meta=_empty_meta,
2508
    impl_aten=_empty_aten,
2509
    return_type=RETURN_TYPE.NEW,
2510
    doc=_empty_doc,
2511
)
2512

2513

2514
def _empty_strided_meta(
2515
    shape: ShapeType,
2516
    strides: StrideType,
2517
    *,
2518
    dtype: torch.dtype,
2519
    device: torch.device,
2520
    requires_grad: bool,
2521
) -> TensorLikeType:
2522
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
2523

2524

2525
_empty_strided_doc = """
2526
    Creates a tensor with uninitialized values.
2527
"""
2528

2529
# TODO: add layout, pin_memory
2530
empty_strided = _make_prim(
2531
    schema="empty_strided(SymInt[] shape, SymInt[] strides, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
2532
    return_type=RETURN_TYPE.NEW,
2533
    meta=_empty_strided_meta,
2534
    impl_aten=torch.empty_strided,
2535
    doc=_empty_strided_doc,
2536
)
2537

2538

2539
def _empty_permuted_meta(
2540
    shape: ShapeType,
2541
    physical_layout: DimsSequenceType,
2542
    *,
2543
    dtype: torch.dtype,
2544
    device: torch.device,
2545
    requires_grad: bool,
2546
) -> TensorLikeType:
2547
    p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout])
2548
    dim = len(shape)
2549
    torch._check(
2550
        len(physical_layout) == dim,
2551
        lambda: (
2552
            "Number of dimensions in the tensor input does not match the "
2553
            f"length of the physical layout; i.e. len(size) = {dim} "
2554
            f"is not equal to len(physical_layout) = {len(physical_layout)}"
2555
        ),
2556
    )
2557
    strides = [0] * len(shape)
2558
    seen_dims = set()
2559
    for p, l in enumerate(physical_layout):
2560
        torch._check(
2561
            0 <= l < dim,
2562
            lambda: (
2563
                f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "
2564
                f"{l} at index {p}).  NB: negative dims "
2565
                "not currently supported; file an issue if you want it."
2566
            ),
2567
        )
2568
        torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed")
2569
        strides[l] = p_strides[p]
2570
        seen_dims.add(l)
2571
    return TensorMeta(
2572
        shape=shape,
2573
        strides=strides,
2574
        dtype=dtype,
2575
        device=device,
2576
    )
2577

2578

2579
_empty_permuted_doc = """
2580
    Creates a tensor with uninitialized values according to some physical layout,
2581
    that is guaranteed to be non-overlapping and dense.
2582
"""
2583

2584
# TODO: add layout, pin_memory
2585
empty_permuted = _make_prim(
2586
    schema="empty_permuted(SymInt[] shape, int[] physical_layout, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",  # noqa: B950
2587
    return_type=RETURN_TYPE.NEW,
2588
    meta=_empty_permuted_meta,
2589
    impl_aten=torch.empty_permuted,
2590
    doc=_empty_permuted_doc,
2591
)
2592

2593

2594
def _full_meta(
2595
    shape: ShapeType,
2596
    fill_value: NumberType,
2597
    *,
2598
    dtype: torch.dtype,
2599
    device: torch.device,
2600
    requires_grad: bool,
2601
) -> TensorLikeType:
2602
    strides = utils.make_contiguous_strides_for(shape)
2603
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
2604

2605

2606
def _full_aten(
2607
    shape: ShapeType,
2608
    fill_value: NumberType,
2609
    *,
2610
    dtype: torch.dtype,
2611
    device: torch.device,
2612
    requires_grad: bool,
2613
) -> Tensor:
2614
    # Note that Mypy thinks torch.full can't accept a complex fill_value
2615
    return torch.full(
2616
        shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad  # type: ignore[arg-type]
2617
    )
2618

2619

2620
_full_doc = """
2621
    Creates a tensor filled with the given fill value, and with the specified shape, dtype, and device.
2622
"""
2623

2624
# TODO: add layout
2625
full = _make_prim(
2626
    schema="full(SymInt[] shape, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
2627
    meta=_full_meta,
2628
    impl_aten=_full_aten,
2629
    return_type=RETURN_TYPE.NEW,
2630
    doc=_full_doc,
2631
)
2632

2633

2634
def _full_like_meta(
2635
    a: TensorLikeType,
2636
    fill_value: NumberType,
2637
    *,
2638
    dtype: torch.dtype,
2639
    device: torch.device,
2640
    requires_grad: bool,
2641
) -> TensorLikeType:
2642
    strides = utils.compute_elementwise_output_strides(a)
2643
    if a.numel() == 0:
2644
        strides = a.stride()
2645

2646
    return TensorMeta(a, strides=strides, dtype=dtype, device=device)
2647

2648

2649
def _full_like_aten(
2650
    a: Tensor,
2651
    fill_value: NumberType,
2652
    *,
2653
    dtype: torch.dtype,
2654
    device: torch.device,
2655
    requires_grad: bool,
2656
) -> Tensor:
2657
    # Note that Mypy thinks torch.full can't accept a complex fill_value
2658
    return torch.full_like(
2659
        a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad  # type: ignore[arg-type]
2660
    )
2661

2662

2663
_full_like_doc = """
2664
    Creates a tensor filled with the given fill value, and the same shape, dtype, and device as the
2665
    given tensor by default. The dtype and device settings can be overridden
2666
    by specifying them explicitly.
2667
"""
2668

2669
full_like = _make_prim(
2670
    schema="full_like(Tensor a, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
2671
    meta=_full_like_meta,
2672
    impl_aten=_full_like_aten,
2673
    return_type=RETURN_TYPE.NEW,
2674
    doc=_full_like_doc,
2675
)
2676

2677

2678
def _scalar_tensor_meta(
2679
    scalar: NumberType,
2680
    *,
2681
    dtype: torch.dtype,
2682
    device: torch.device,
2683
) -> TensorLikeType:
2684
    shape: ShapeType = []
2685
    strides = utils.make_contiguous_strides_for(shape)
2686
    return TensorMeta(scalar, shape=shape, strides=strides, dtype=dtype, device=device)
2687

2688

2689
def _scalar_tensor_aten(
2690
    scalar: NumberType,
2691
    *,
2692
    dtype: torch.dtype,
2693
    device: torch.device,
2694
) -> Tensor:
2695
    if isinstance(scalar, complex) and (
2696
        dtype is None or not utils.is_complex_dtype(dtype)
2697
    ):
2698
        raise TypeError("Complex scalar requires complex tensor dtype.")
2699
    # Note that Mypy thinks torch.scalar can't accept a complex scalar
2700
    return torch.scalar_tensor(scalar, dtype=dtype, device=device)  # type: ignore[arg-type]
2701

2702

2703
_scalar_tensor_doc = """
2704
    Wraps a Number into a Tensor with the specified dtype and device.
2705
"""
2706

2707
# TODO: add layout and pin_memory support
2708
scalar_tensor = _make_prim(
2709
    schema="scalar_tensor(Scalar s, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
2710
    meta=_scalar_tensor_meta,
2711
    impl_aten=_scalar_tensor_aten,
2712
    return_type=RETURN_TYPE.NEW,
2713
    doc=_scalar_tensor_doc,
2714
)
2715

2716

2717
#
2718
# Linear algebra (linalg) prims
2719
#
2720

2721

2722
def _svd_meta(
2723
    A: TensorLikeType, *, full_matrices: bool
2724
) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
2725
    utils.check_is_matrix(A, "linalg.svd")
2726
    utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False)
2727

2728
    A_shape = A.shape
2729
    batch = A_shape[:-2]
2730
    m, n = A_shape[-2:]
2731
    k = min(m, n)
2732

2733
    shape_U = batch + (m, m if full_matrices else k)
2734
    strides_U = utils.make_contiguous_strides_for(shape_U, row_major=False)
2735
    U = TensorMeta(shape=shape_U, strides=strides_U, dtype=A.dtype, device=A.device)
2736

2737
    shape_S = batch + (k,)
2738
    strides_S = utils.make_contiguous_strides_for(shape_S)
2739
    S = TensorMeta(
2740
        shape=shape_S,
2741
        strides=strides_S,
2742
        dtype=utils.corresponding_real_dtype(A.dtype) if A.is_complex() else A.dtype,
2743
        device=A.device,
2744
    )
2745

2746
    shape_Vh = batch + (n if full_matrices else k, n)
2747
    # The CPU backend returns V, but the cuSolver backend returns V^H
2748
    # TODO The MAGMA backend returns V, so this is wrong if used with the MAGMA backend
2749
    is_cuda = A.device.type == "cuda"
2750
    strides_Vh = utils.make_contiguous_strides_for(shape_Vh, row_major=is_cuda)
2751
    Vh = TensorMeta(shape=shape_Vh, strides=strides_Vh, dtype=A.dtype, device=A.device)
2752
    # Also makes sure this is CUDA or HIP:
2753
    # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
2754
    if A.numel() != 0 and Vh.is_complex() and torch.cuda.is_available():
2755
        Vh = Vh.conj()
2756
    return U, S, Vh
2757

2758

2759
def _svd_aten(
2760
    A: TensorLikeType, *, full_matrices: bool
2761
) -> Tuple[Tensor, Tensor, Tensor]:
2762
    return torch.linalg.svd(A, full_matrices=full_matrices)
2763

2764

2765
_svd_doc = """
2766
    Returns the SVD of a matrix or batch of matrices.
2767

2768
    The `full_matrices` flag controls whether the full or reduced SVD decomposition is returned.
2769
"""
2770

2771
svd = _make_prim(
2772
    schema="svd(Tensor A, *, bool full_matrices) -> (Tensor U, Tensor S, Tensor Vh)",
2773
    meta=_svd_meta,
2774
    impl_aten=_svd_aten,
2775
    return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW, RETURN_TYPE.NEW),
2776
    doc=_svd_doc,
2777
)
2778

2779

2780
#
2781
# Randomness Prims
2782
#
2783

2784

2785
def _normal_meta(
2786
    shape: ShapeType,
2787
    *,
2788
    mean: Union[float, complex],
2789
    std: float,
2790
    dtype: torch.dtype,
2791
    device: torch.device,
2792
    requires_grad: bool,
2793
    generator: Optional[torch.Generator] = None,
2794
) -> TensorLikeType:
2795
    torch._check(
2796
        std >= 0.0,
2797
        lambda: f"expected non-negative standard deviation, but got std={std}",
2798
    )
2799

2800
    torch._check(
2801
        utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
2802
        lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
2803
    )
2804

2805
    strides = utils.make_contiguous_strides_for(shape)
2806
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
2807

2808

2809
def _normal_aten(
2810
    shape: ShapeType,
2811
    *,
2812
    mean: Union[float, complex],
2813
    std: float,
2814
    dtype: torch.dtype,
2815
    device: torch.device,
2816
    requires_grad: bool,
2817
    generator: Optional[torch.Generator] = None,
2818
) -> Tensor:
2819
    a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
2820
    with torch.no_grad():
2821
        # NOTE: normal_ is incorrectly annotated to expect mean to be a float
2822
        a.normal_(mean, std, generator=generator)  # type: ignore[arg-type]
2823
    return a
2824

2825

2826
_normal_doc = """
2827
    Constructs a tensor filled with values drawn from a normal distribution with the specified mean
2828
    and standard deviation.
2829

2830
    Only supports floating-point types.
2831
"""
2832

2833
normal = _make_prim(
2834
    schema=(
2835
        "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor"  # noqa: B950
2836
    ),
2837
    return_type=RETURN_TYPE.NEW,
2838
    meta=_normal_meta,
2839
    impl_aten=_normal_aten,
2840
    doc=_normal_doc,
2841
)
2842

2843

2844
def _uniform_meta(
2845
    shape: ShapeType,
2846
    *,
2847
    low: float,
2848
    high: float,
2849
    dtype: torch.dtype,
2850
    device: torch.device,
2851
    generator: Optional[torch.Generator] = None,
2852
) -> TensorLikeType:
2853
    strides = utils.make_contiguous_strides_for(shape)
2854
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
2855

2856

2857
def _uniform_aten(
2858
    shape: ShapeType,
2859
    *,
2860
    low: float,
2861
    high: float,
2862
    dtype: torch.dtype,
2863
    device: torch.device,
2864
    generator: Optional[torch.Generator] = None,
2865
) -> Tensor:
2866
    a = torch.empty(shape, dtype=dtype, device=device)
2867
    a.uniform_(low, high, generator=generator)
2868
    return a
2869

2870

2871
_uniform_doc = """
2872
    Constructs a tensor filled with values drawn uniformly from low to high.
2873
"""
2874

2875
# TODO: we should more seriously review randomness modeling and prims
2876
_uniform_helper = _make_prim(
2877
    schema=(
2878
        "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device, Generator? generator=None) -> Tensor"
2879
    ),
2880
    return_type=RETURN_TYPE.NEW,
2881
    meta=_uniform_meta,
2882
    impl_aten=_uniform_aten,
2883
    doc=_uniform_doc,
2884
)
2885

2886
#
2887
# FFT prims
2888
#
2889

2890

2891
def _fft_r2c_meta(
2892
    input: TensorLike,
2893
    *,
2894
    dim: DimsSequenceType,
2895
    onesided: bool,
2896
) -> TensorLikeType:
2897
    dim = utils.canonicalize_dims(input.ndim, dim)
2898
    utils.validate_no_repeating_dims(dim)
2899

2900
    shape = list(input.shape)
2901
    if onesided:
2902
        last_dim = dim[-1]
2903
        shape[last_dim] = shape[last_dim] // 2 + 1
2904

2905
    dtype = utils.corresponding_complex_dtype(input.dtype)
2906
    strides = utils.make_contiguous_strides_for(shape)
2907
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)
2908

2909

2910
def _fft_r2c_aten(
2911
    input: TensorLike,
2912
    *,
2913
    dim: DimsSequenceType,
2914
    onesided: bool,
2915
) -> TensorLikeType:
2916
    normalization = 0  # No normalization
2917
    return torch._fft_r2c(input, dim, normalization, onesided)
2918

2919

2920
_fft_r2c_doc = """
2921
    Performs a real to complex Fast Fourier Transform
2922
"""
2923

2924

2925
fft_r2c = _make_prim(
2926
    schema="fft_r2c(Tensor self, *, int[] dim, bool onesided) -> Tensor",
2927
    meta=_fft_r2c_meta,
2928
    impl_aten=_fft_r2c_aten,
2929
    return_type=RETURN_TYPE.NEW,
2930
    doc=_fft_r2c_doc,
2931
)
2932

2933

2934
def _fft_c2c_meta(
2935
    input: TensorLike,
2936
    *,
2937
    dim: DimsSequenceType,
2938
    forward: bool,
2939
) -> TensorLikeType:
2940
    dim = utils.canonicalize_dims(input.ndim, dim)
2941
    utils.validate_no_repeating_dims(dim)
2942

2943
    shape = input.shape
2944
    strides = utils.make_contiguous_strides_for(shape)
2945
    return TensorMeta(
2946
        shape=shape, strides=strides, dtype=input.dtype, device=input.device
2947
    )
2948

2949

2950
def _fft_c2c_aten(
2951
    input: TensorLike,
2952
    *,
2953
    dim: DimsSequenceType,
2954
    forward: bool,
2955
) -> TensorLikeType:
2956
    normalization = 0  # No normalization
2957
    return torch._fft_c2c(input, dim, normalization, forward)
2958

2959

2960
_fft_c2c_doc = """
2961
    Performs either a Fast Fourier Transform, or its inverse
2962
"""
2963

2964

2965
fft_c2c = _make_prim(
2966
    schema="fft_c2c(Tensor self, *, int[] dim, bool forward) -> Tensor",
2967
    meta=_fft_c2c_meta,
2968
    impl_aten=_fft_c2c_aten,
2969
    return_type=RETURN_TYPE.NEW,
2970
    doc=_fft_c2c_doc,
2971
)
2972

2973

2974
def _fft_c2r_meta(
2975
    input: TensorLike,
2976
    *,
2977
    dim: DimsSequenceType,
2978
    last_dim_size: int,
2979
) -> TensorLikeType:
2980
    dim = utils.canonicalize_dims(input.ndim, dim)
2981
    utils.validate_no_repeating_dims(dim)
2982

2983
    shape = list(input.shape)
2984
    shape[dim[-1]] = last_dim_size
2985
    dtype = utils.corresponding_real_dtype(input.dtype)
2986
    strides = utils.make_contiguous_strides_for(shape)
2987
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)
2988

2989

2990
def _fft_c2r_aten(
2991
    input: TensorLike,
2992
    *,
2993
    dim: DimsSequenceType,
2994
    last_dim_size: int,
2995
) -> TensorLikeType:
2996
    normalization = 0  # No normalization
2997
    return torch._fft_c2r(input, dim, normalization, last_dim_size)
2998

2999

3000
_fft_c2r_doc = """
3001
    Performs a complex to real Inverse Fast Fourier Transform
3002
"""
3003

3004

3005
fft_c2r = _make_prim(
3006
    schema="fft_c2r(Tensor self, *, int[] dim, SymInt last_dim_size) -> Tensor",
3007
    meta=_fft_c2r_meta,
3008
    impl_aten=_fft_c2r_aten,
3009
    return_type=RETURN_TYPE.NEW,
3010
    doc=_fft_c2r_doc,
3011
)
3012

3013

3014
def _frexp_meta(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
3015
    torch._check(
3016
        self.dtype.is_floating_point,
3017
        lambda: "torch.frexp() only supports floating-point dtypes",
3018
    )
3019
    return torch.empty_like(self), torch.empty_like(self, dtype=torch.int32)
3020

3021

3022
frexp = _make_prim(
3023
    schema="frexp(Tensor self) -> (Tensor mantissa, Tensor exponent)",
3024
    meta=_frexp_meta,
3025
    return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW),
3026
    impl_aten=torch.frexp,
3027
    doc="",
3028
)
3029

3030
register_rng_prims()
3031
register_debug_prims()
3032

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.