pytorch

Форк
0
/
model.py 
2839 строк · 109.9 Кб
1
from __future__ import annotations
2

3
import dataclasses
4
import itertools
5
import re
6
from dataclasses import dataclass
7
from enum import auto, Enum
8
from typing import Callable, Iterator, Sequence
9

10
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
11

12

13
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
14
#
15
#                           DATA MODEL
16
#
17
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
18
#
19
# Some general principles for our data model.
20
#
21
# - Stop using C++ data types as the internal data representation
22
#   format.  Instead, the internal data structures are centered
23
#   around JIT schema representation.  This avoids a big problem
24
#   with the old codegen where we read in all the types from
25
#   native_functions.yaml and then immediately had to retranslate
26
#   them into C++ types.
27
#
28
# - More semantic data representation.  Instead of representing
29
#   everything as dicts and strings, we define dataclasses for
30
#   every interesting entity the code generation has to deal with.
31
#   These dataclasses have strong semantic invariants: for example,
32
#   we generally require them to roundtrip losslessly into the
33
#   form they were parsed from.  These structures are immutable
34
#   and you're expected to populate information once during
35
#   construction.
36

37

38
# Represent a source location; used for better error reporting
@dataclass(frozen=True)
class Location:
    """An immutable ``file:line`` position used when reporting errors."""

    # Path (or name) of the file the entry came from.
    file: str
    # Line number within ``file``.
    line: int

    def __str__(self) -> str:
        # Rendered as "file:line" so it can be pasted into error messages.
        return f"{self.file}:{self.line}"
46

47

48
# Valid values of the 'variants' field in native_functions.yaml
class Variant(Enum):
    """The binding kinds that may be listed under ``variants``."""

    function = auto()
    method = auto()
52

53

54
# Default kernel namespace
DEFAULT_KERNEL_NAMESPACE = "at::native"

# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
BACKEND_COMPONENTS = [
    "CPU",
    "CUDA",
    "HIP",
    "XLA",
    "MTIA",
    "MPS",
    "IPU",
    "XPU",
    "HPU",
    "VE",
    "Lazy",
    "Meta",
    "PrivateUse1",
    "PrivateUse2",
    "PrivateUse3",
]
# Prefixes combined with every backend component to form per-backend keys
# (the empty prefix yields the bare backend names, e.g. "CPU").
FUNCTIONALITY_KEYS = [
    "",
    "Quantized",
    "Sparse",
    "SparseCsr",
    "NestedTensor",
    "Autograd",
]

# This list guards dispatches that can be used in derivatives.yaml
# For now we omit AutogradFunctionality and AutogradOther
AUTOGRAD_KEYS = ["AutogradNestedTensor"] + [
    "Autograd" + component for component in BACKEND_COMPONENTS
]

FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}
75

76

77
# This doesn't have to be in sync with the header, it only needs to contain
# entries that we actually use in the codegen or want pyi entries for
class DispatchKey(Enum):
    """Dispatch keys known to the code generator.

    ``Undefined`` is explicitly 0; every other value is assigned by ``auto()``
    in declaration order, so reordering members changes their values.
    """

    Undefined = 0
    # Same value as Undefined, so Enum makes this an alias of it.
    CatchAll = Undefined

    FPGA = auto()
    MAIA = auto()
    Vulkan = auto()
    Metal = auto()
    MKLDNN = auto()
    OpenGL = auto()
    OpenCL = auto()
    IDEEP = auto()
    CustomRNGKeyId = auto()
    MkldnnCPU = auto()
    Sparse = auto()
    SparseCsr = auto()
    NestedTensor = auto()
    Dense = auto()

    PythonTLSSnapshot = auto()
    PreDispatch = auto()
    PythonDispatcher = auto()
    Python = auto()
    FuncTorchDynamicLayerBackMode = auto()
    ZeroTensor = auto()
    Conjugate = auto()
    Negative = auto()
    BackendSelect = auto()
    Named = auto()
    AutogradOther = auto()
    AutogradFunctionality = auto()
    AutogradNestedTensor = auto()
    Tracer = auto()
    Autocast = auto()
    AutocastCPU = auto()
    AutocastCUDA = auto()
    Batched = auto()
    VmapMode = auto()
    FuncTorchGradWrapper = auto()
    FuncTorchBatched = auto()
    BatchedNestedTensor = auto()
    FuncTorchVmapMode = auto()
    FuncTorchDynamicLayerFrontMode = auto()
    Functionalize = auto()
    TESTING_ONLY_GenericWrapper = auto()
    TESTING_ONLY_GenericMode = auto()

    ADInplaceOrView = auto()
    Autograd = auto()
    CompositeImplicitAutograd = auto()
    CompositeImplicitAutogradNestedTensor = auto()
    CompositeExplicitAutograd = auto()
    CompositeExplicitAutogradNonFunctional = auto()
    FuncTorchBatchedDecomposition = auto()

    # BEGIN autogenerated
    CPU = auto()
    CUDA = auto()
    HIP = auto()
    XLA = auto()
    MTIA = auto()
    MPS = auto()
    IPU = auto()
    XPU = auto()
    HPU = auto()
    VE = auto()
    Lazy = auto()
    Meta = auto()
    PrivateUse1 = auto()
    PrivateUse2 = auto()
    PrivateUse3 = auto()
    QuantizedCPU = auto()
    QuantizedCUDA = auto()
    QuantizedHIP = auto()
    QuantizedXLA = auto()
    QuantizedMTIA = auto()
    QuantizedMPS = auto()
    QuantizedIPU = auto()
    QuantizedXPU = auto()
    QuantizedHPU = auto()
    QuantizedVE = auto()
    QuantizedLazy = auto()
    QuantizedMeta = auto()
    QuantizedPrivateUse1 = auto()
    QuantizedPrivateUse2 = auto()
    QuantizedPrivateUse3 = auto()
    SparseCPU = auto()
    SparseCUDA = auto()
    SparseHIP = auto()
    SparseXLA = auto()
    SparseMTIA = auto()
    SparseMPS = auto()
    SparseIPU = auto()
    SparseXPU = auto()
    SparseHPU = auto()
    SparseVE = auto()
    SparseLazy = auto()
    SparseMeta = auto()
    SparsePrivateUse1 = auto()
    SparsePrivateUse2 = auto()
    SparsePrivateUse3 = auto()
    SparseCsrCPU = auto()
    SparseCsrCUDA = auto()
    SparseCsrHIP = auto()
    SparseCsrXLA = auto()
    SparseCsrMTIA = auto()
    SparseCsrMPS = auto()
    SparseCsrIPU = auto()
    SparseCsrXPU = auto()
    SparseCsrHPU = auto()
    SparseCsrVE = auto()
    SparseCsrLazy = auto()
    SparseCsrMeta = auto()
    SparseCsrPrivateUse1 = auto()
    SparseCsrPrivateUse2 = auto()
    SparseCsrPrivateUse3 = auto()
    NestedTensorCPU = auto()
    NestedTensorCUDA = auto()
    NestedTensorHIP = auto()
    NestedTensorXLA = auto()
    NestedTensorMTIA = auto()
    NestedTensorMPS = auto()
    NestedTensorIPU = auto()
    NestedTensorXPU = auto()
    NestedTensorHPU = auto()
    NestedTensorVE = auto()
    NestedTensorLazy = auto()
    NestedTensorMeta = auto()
    NestedTensorPrivateUse1 = auto()
    NestedTensorPrivateUse2 = auto()
    NestedTensorPrivateUse3 = auto()
    AutogradCPU = auto()
    AutogradCUDA = auto()
    AutogradHIP = auto()
    AutogradXLA = auto()
    AutogradMTIA = auto()
    AutogradMPS = auto()
    AutogradIPU = auto()
    AutogradXPU = auto()
    AutogradHPU = auto()
    AutogradVE = auto()
    AutogradLazy = auto()
    AutogradMeta = auto()
    AutogradPrivateUse1 = auto()
    AutogradPrivateUse2 = auto()
    AutogradPrivateUse3 = auto()
    # END autogenerated

    def __str__(self) -> str:
        return self.name

    def lower(self) -> str:
        """Return the member name in lower case (e.g. ``"cpu"``)."""
        return str(self).lower()

    @staticmethod
    def parse(value: str) -> DispatchKey:
        """Return the member (or alias, e.g. ``CatchAll``) named *value*.

        Raises:
            AssertionError: if *value* names no member.
        """
        # __members__ is a name -> member mapping that includes aliases, so a
        # direct lookup matches the previous linear scan in O(1).
        member = DispatchKey.__members__.get(value)
        if member is None:
            raise AssertionError(f"unknown dispatch key {value}")
        return member
239

240

241
class _TorchDispatchModeKey(Enum):
    """Private identifiers FAKE, PROXY and FUNCTIONAL.

    NOTE(review): their consumers are not visible in this chunk.
    """

    FAKE = auto()
    PROXY = auto()
    FUNCTIONAL = auto()
245

246

247
def codegen_per_backend_entries(
    functionality_keys: Sequence[str] | None = None,
    backend_components: Sequence[str] | None = None,
) -> str:
    """Render the per-backend ``DispatchKey`` member declarations.

    Emits one ``    {functionality}{backend} = auto()`` line per
    (functionality key, backend component) pair, iterating functionality keys
    in the outer loop, joined by newlines.  This is the text expected between
    the ``BEGIN/END autogenerated`` markers of ``DispatchKey``.

    Args:
        functionality_keys: Prefixes to combine; defaults to FUNCTIONALITY_KEYS.
        backend_components: Backend suffixes; defaults to BACKEND_COMPONENTS.

    Returns:
        The newline-joined declaration lines (empty string if either input is empty).
    """
    if functionality_keys is None:
        functionality_keys = FUNCTIONALITY_KEYS
    if backend_components is None:
        backend_components = BACKEND_COMPONENTS
    # product() keeps the original nesting order: functionality-major.
    return "\n".join(
        f"    {fk}{bc} = auto()"
        for fk, bc in itertools.product(functionality_keys, backend_components)
    )
253

254

255
# Import-time sanity check: every "<functionality><backend>" combination must
# be declared on DispatchKey.  All missing names are reported at once (adding a
# backend component typically leaves one missing key per functionality), along
# with the full list of declarations that DispatchKey is expected to contain.
_missing_per_backend_keys = [
    fk + bc
    for fk in FUNCTIONALITY_KEYS
    for bc in BACKEND_COMPONENTS
    if not hasattr(DispatchKey, fk + bc)
]
if _missing_per_backend_keys:
    r = codegen_per_backend_entries()
    print(r)
    raise RuntimeError(
        f"Missing {', '.join(_missing_per_backend_keys)} from DispatchKey enum.  Here is the autogenerated list we expect to have:\n\n{r}"
    )
263

264

265
# Dispatch keys for which structured kernel generation is supported
# (consulted by is_structured_dispatch_key).
STRUCTURED_DISPATCH_KEYS = {
    DispatchKey.MPS,
    DispatchKey.CUDA,
    DispatchKey.CPU,
    DispatchKey.XPU,
}
# Dispatch keys for which ufunc codegen is supported (consulted by
# is_ufunc_dispatch_key); currently a strict subset of STRUCTURED_DISPATCH_KEYS.
UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}

# Set of supported dispatch keys; NativeFunction.from_yaml rejects dispatch
# entries whose key is not listed here.
dispatch_keys = [
    DispatchKey.CPU,
    DispatchKey.SparseCPU,
    DispatchKey.SparseCsrCPU,
    DispatchKey.MkldnnCPU,
    DispatchKey.CUDA,
    DispatchKey.MPS,
    DispatchKey.XPU,
    DispatchKey.SparseCUDA,
    DispatchKey.SparseCsrCUDA,
    DispatchKey.QuantizedCPU,
    DispatchKey.QuantizedCUDA,
    DispatchKey.CompositeImplicitAutograd,
    DispatchKey.CompositeImplicitAutogradNestedTensor,
    DispatchKey.CompositeExplicitAutograd,
    DispatchKey.CompositeExplicitAutogradNonFunctional,
    DispatchKey.NestedTensorCPU,
    DispatchKey.NestedTensorCUDA,
    # Meta is a magic key: it is automatically generated for structured
    # kernels
    DispatchKey.Meta,
    DispatchKey.SparseMeta,
    DispatchKey.SparseCsrMeta,
    DispatchKey.QuantizedMeta,
    DispatchKey.NestedTensorMeta,
    DispatchKey.ZeroTensor,
]
301

302

303
# Dispatch keys that "support all backends".  These codegen slightly differently
# than backend specific keys.
def is_generic_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if *dk* is one of the four backend-agnostic Composite* keys."""
    return dk in {
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
    }
312

313

314
# CUDA specific dispatch keys
def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if *dk* is CUDA or one of its per-backend CUDA variants.

    The variants covered are Quantized, Sparse, SparseCsr, NestedTensor and
    Autograd.
    """
    return dk in {
        DispatchKey.CUDA,
        DispatchKey.QuantizedCUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.NestedTensorCUDA,
        DispatchKey.AutogradCUDA,
    }
324

325

326
# Structured kernel generation is only supported for certain key types;
# otherwise use old-style
def is_structured_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if structured kernels can be generated for *dk*."""
    return dk in STRUCTURED_DISPATCH_KEYS
330

331

332
def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
    """Return True if ufunc codegen supports *dk*."""
    # Ufunc keys do NOT coincide with structured keys: UFUNC_DISPATCH_KEYS is
    # currently a strict subset of STRUCTURED_DISPATCH_KEYS.
    return dk in UFUNC_DISPATCH_KEYS
335

336

337
# This is oddly named ScalarType and not DType for symmetry with C++
class ScalarType(Enum):
    """Element types (dtypes) understood by the code generator."""

    Byte = auto()
    Char = auto()
    Short = auto()
    Int = auto()
    Long = auto()
    Half = auto()
    Float = auto()
    Double = auto()
    ComplexHalf = auto()
    ComplexFloat = auto()
    ComplexDouble = auto()
    Bool = auto()
    BFloat16 = auto()
    Float8_e5m2 = auto()
    Float8_e5m2fnuz = auto()
    Float8_e4m3fn = auto()
    Float8_e4m3fnuz = auto()

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def maybe_parse(value: str) -> ScalarType | None:
        """Return the member named *value*, or ``None`` if there is none."""
        return ScalarType.__members__.get(value)

    @staticmethod
    def parse(value: str) -> ScalarType:
        """Return the member named *value*.

        Raises:
            AssertionError: if *value* names no member.
        """
        mb_r = ScalarType.maybe_parse(value)
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped under ``python -O`` (same exception type and message, and
        # consistent with DispatchKey.parse / UfuncKey.parse).
        if mb_r is None:
            raise AssertionError(f"unknown dtype {value}")
        return mb_r

    @staticmethod
    def parse_set(values: str) -> OrderedSet[ScalarType]:
        """Parse a ``", "``-separated list of dtype names into an ordered set.

        Each item is either a key of DTYPE_CLASSES (expanded to that group's
        dtypes) or a single ScalarType name.

        Raises:
            AssertionError: if an item is neither a group nor a dtype name.
        """
        dtypes: OrderedSet[ScalarType] = OrderedSet()
        for value in values.split(", "):
            if value in DTYPE_CLASSES:
                dtypes.update(DTYPE_CLASSES[value])
            else:
                dtypes.add(ScalarType.parse(value))
        return dtypes
382

383

384
# Named groups of dtypes; ScalarType.parse_set expands a group name into all
# of its members.  Insertion order is preserved by OrderedSet.
DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {}
# NB: Integral doesn't include boolean
DTYPE_CLASSES["Integral"] = OrderedSet(
    [
        ScalarType.Byte,
        ScalarType.Char,
        ScalarType.Int,
        ScalarType.Long,
        ScalarType.Short,
    ]
)
# NB: Floating doesn't include low precision types
DTYPE_CLASSES["Floating"] = OrderedSet([ScalarType.Float, ScalarType.Double])
DTYPE_CLASSES["Complex"] = OrderedSet(
    [ScalarType.ComplexFloat, ScalarType.ComplexDouble]
)
# Composite groups are unions of the groups defined above.
DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
DTYPE_CLASSES["FloatingAndComplex"] = (
    DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
)
405

406

407
# Represents the valid entries for ufunc_inner_loop in native_functions.yaml.
# NB: if you add a new UfuncKey, you will need to teach torchgen.dest.ufunc how
# to process it.  Most logic will ignore keys they don't understand, so your
# new key will get silently ignored until you hook in logic to deal with it.
class UfuncKey(Enum):
    """Keys of a ``ufunc_inner_loop`` entry."""

    # These are low level keys that represent exactly one particular
    # instantiation of the kernel produced by codegen
    CUDAFunctor = auto()
    CUDAFunctorOnOther = auto()
    CUDAFunctorOnSelf = auto()

    CPUScalar = auto()
    CPUVector = auto()

    # These are the ones users will usually specify, and
    # implicitly "fill in" the low level keys
    ScalarOnly = auto()  # CUDA*, CPUScalar
    Generic = auto()  # CUDA*, CPU*

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def parse(value: str) -> UfuncKey:
        """Return the member named *value*.

        Raises:
            AssertionError: if *value* names no member.
        """
        # Direct name lookup; equivalent to scanning __members__ but O(1).
        member = UfuncKey.__members__.get(value)
        if member is None:
            raise AssertionError(f"unknown ufunc key {value}")
        return member
435

436

437
class DeviceCheckType(Enum):
    """Accepted values of the ``device_check`` field.

    NativeFunction.from_yaml defaults to ``ExactSame`` when the field is absent.
    """

    NoCheck = 0
    ExactSame = 1
440

441

442
class ViewSchemaKind(Enum):
    """Aliasing classification of a schema: aliasing, aliasing_inplace, non_aliasing.

    NOTE(review): how each kind is assigned is defined outside this chunk.
    """

    aliasing = auto()
    aliasing_inplace = auto()
    non_aliasing = auto()
446

447

448
# The basic input to the code generation is native_functions.yaml.
449
# The name "native", BTW, comes from the distinction between native
450
# functions and legacy TH functions.  The legacy TH functions are gone,
451
# but the "native" descriptor has stuck.
452
#
453
# NativeFunction models a single entry in native_functions.yaml.  Its
454
# fields roughly correspond to what you would see in the YAML itself,
455
# but after canonicalization and parsing has occurred.
456
#
457
# You can see some of the overall design patterns for how we setup
458
# dataclasses in this class, but we will defer a complete discussion
459
# of this at FunctionSchema.
460
@dataclass(frozen=True)
461
class NativeFunction:
462
    # The namespace for this operator. For example, if we have "at::add"
463
    # then the namespace would be "at". This enables ops to be registered
464
    # through the same DSL with a custom namespace. If not specified, the
465
    # default namespace would be "at".
466
    namespace: str
467

468
    # The function schema of the operator in question.  This schema
469
    # has been parsed; see FunctionSchema for more about its structure.
470
    # (This type is quoted as we are forward referencing a type
471
    # defined later in the file.  I opted for this ordering of the
472
    # classes for expository clarity.)
473
    func: FunctionSchema
474

475
    # Whether or not to generate mutable tensor arguments like regular
476
    # ones
477
    use_const_ref_for_mutable_tensors: bool
478

479
    # Whether or not to omit automatic generation of a DeviceGuard
480
    device_guard: bool
481

482
    # How to emit automatic generation of device check
483
    device_check: DeviceCheckType
484

485
    # What python module to put the function in
486
    python_module: str | None
487

488
    # TODO: figure out what this does
489
    category_override: str | None
490

491
    # If no variants are specified in native_functions.yaml, this is
492
    # assumed to be {'function'}.
493
    variants: set[Variant]
494

495
    # Whether or not we should skip generating registrations for
496
    # this kernel.  This is a bit of a double-edged sword, as manual
497
    # registrations don't participate in codegen-based selective build!
498
    manual_kernel_registration: bool
499

500
    # Whether or not to skip generating TensorMethod/Functions bindings
501
    # for this kernel.  Technically, this doesn't actually skip generating
502
    # the binding; instead, the binding gets generated to __dispatch_{funcname}
503
    # so you can make use of the normal binding if you need it.
504
    manual_cpp_binding: bool
505

506
    # The location in the YAML file were this native function entry was
507
    # defined.  This is for conveniently reporting error messages!
508
    loc: Location
509

510
    # A list of operators that are expected to be auto-generated for this NativeFunction.
511
    # Note: This list isn't actually directly used by the codegen to generate anything.
512
    # Instead, the codegen figures out what operators to generate purely based off of
513
    # function schema, and uses the autogen declarations to error check.
514
    # We expect every NativeFunction that gets auto-generated be explicitly called out
515
    # in native_functions.yaml
516
    autogen: list[OperatorName]
517

518
    # If non-empty, this kernel is subject to ufunc codegen.
519
    # Sorted by ufunc_key
520
    ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]
521

522
    # Whether or not this out functions is a "structured kernel".  Structured
523
    # kernels are defined a little differently from normal kernels; in
524
    # particular, their shape checking logic is defined separately from
525
    # the kernel.  Only out functions can be structured; other functions
526
    # delegate to the out function using the structured_delegate keyword.
527
    # Every structured kernel must have at least an out and a functional
528
    # variant.
529
    structured: bool
530

531
    # Whether or not this non-out function is a structured kernel, defined
532
    # in terms of the out kernel referenced by the string here.
533
    structured_delegate: OperatorName | None
534

535
    # Only valid for structured kernels.  Specifies alternative of what
536
    # to inherit from when defining the meta class for the structured
537
    # operator.  This will usually be TensorIteratorBase.  This also
538
    # changes the semantics of set_output to call the parent class.
539
    structured_inherits: str | None
540

541
    # Structured kernels can declare elements as "precomputed". These elements
542
    # are returned by the meta function in one struct and passed to the impl
543
    # function in lieu of certain kernel arguments that these precomputed
544
    # elements supersede. Information about the names and types of these
545
    # precomputed elements and how they correspond to kernel arguments is stored
546
    # in this member, if applicable.
547
    precomputed: Precompute | None
548

549
    # Argument names whose default  should be excluded from the C++ interface.
550
    # Intended for resolving overload ambiguities between signatures.
551
    cpp_no_default_args: set[str]
552

553
    # Note [Abstract ATen methods]
554
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
555
    # An abstract ATen method is one whose dispatch differs between
556
    # types.  These are implemented in derived types (with a
557
    # standard (throwing) definition in Type).  A concrete ATen
558
    # method is one which has the same dispatch for all types;
559
    # we just implement it in the base Type.  This is exposed
560
    # in Declarations.yaml via a field named 'abstract'.
561
    is_abstract: bool
562

563
    # Whether or not the NativeFunction contains a backend-agnostic kernel
564
    has_composite_implicit_autograd_kernel: bool
565
    has_composite_implicit_autograd_nested_tensor_kernel: bool
566
    has_composite_explicit_autograd_kernel: bool
567
    has_composite_explicit_autograd_non_functional_kernel: bool
568

569
    # Tags are used to describe semantic information about (groups of) operators,
570
    # That aren't easily inferrable directly from the operator's schema.
571
    tags: set[str]
572

573
    # NB: The benefit of defining a dataclass is that we automatically get
574
    # a constructor defined for all the fields we specify.  No need
575
    # to explicitly write it out.
576

577
    # We parse both the NativeFunction + backend-specific information about it, which it stored in a corresponding BackendIndex.
578
    @staticmethod
579
    def from_yaml(
580
        ei: dict[str, object],
581
        loc: Location,
582
        valid_tags: set[str],
583
        ignore_keys: set[DispatchKey] | None = None,
584
    ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
585
        """
586
        Parse a NativeFunction from a dictionary as directly parsed
587
        from native_functions.yaml
588
        """
589
        e = ei.copy()
590

591
        funcs = e.pop("func")
592
        assert isinstance(funcs, str), f"not a str: {funcs}"
593
        # only support one level of namespace. E.g., aten::add
594
        namespace_helper = NamespaceHelper.from_namespaced_entity(
595
            namespaced_entity=funcs, max_level=1
596
        )
597
        namespace = namespace_helper.get_cpp_namespace(default="aten")
598
        func = FunctionSchema.parse(namespace_helper.entity_name)
599

600
        cpp_no_default_args_list = e.pop("cpp_no_default_args", [])
601
        assert isinstance(cpp_no_default_args_list, list)
602
        cpp_no_default_args = set(cpp_no_default_args_list)
603

604
        use_const_ref_for_mutable_tensors = e.pop(
605
            "use_const_ref_for_mutable_tensors", False
606
        )
607
        assert isinstance(use_const_ref_for_mutable_tensors, bool)
608

609
        variants_s = e.pop("variants", "function")
610
        assert isinstance(variants_s, str)
611
        variants: set[Variant] = set()
612
        for v in variants_s.split(", "):
613
            if v == "function":
614
                variants.add(Variant.function)
615
            elif v == "method":
616
                variants.add(Variant.method)
617
            else:
618
                raise AssertionError(f"illegal variant {v}")
619

620
        manual_kernel_registration = e.pop("manual_kernel_registration", False)
621
        assert isinstance(
622
            manual_kernel_registration, bool
623
        ), f"not a bool: {manual_kernel_registration}"
624

625
        manual_cpp_binding = e.pop("manual_cpp_binding", False)
626
        assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}"
627

628
        device_guard = e.pop("device_guard", True)
629
        assert isinstance(device_guard, bool), f"not a bool: {device_guard}"
630

631
        device_check_s = e.pop("device_check", None)
632
        assert device_check_s is None or isinstance(
633
            device_check_s, str
634
        ), f"not a str: {device_check_s}"
635
        assert (
636
            device_check_s is None or device_check_s in DeviceCheckType.__members__
637
        ), f"illegal device_check: {device_check_s}"
638
        device_check: DeviceCheckType
639
        if device_check_s is None:
640
            device_check = DeviceCheckType.ExactSame
641
        else:
642
            device_check = DeviceCheckType[device_check_s]
643

644
        structured = e.pop("structured", False)
645
        assert isinstance(structured, bool), f"not a bool: {structured}"
646

647
        structured_delegate_s = e.pop("structured_delegate", None)
648
        assert structured_delegate_s is None or isinstance(
649
            structured_delegate_s, str
650
        ), f"not a str: {structured_delegate_s}"
651
        assert structured_delegate_s is None or "::" not in structured_delegate_s, (
652
            "namespace is not supported in structured delegate,"
653
            " using the same namespace as the native function"
654
        )
655
        structured_delegate: OperatorName | None = None
656
        if structured_delegate_s is not None:
657
            structured_delegate = OperatorName.parse(structured_delegate_s)
658

659
        structured_inherits = e.pop("structured_inherits", None)
660
        assert structured_inherits is None or isinstance(
661
            structured_inherits, str
662
        ), f"not a str: {structured_inherits}"
663
        assert structured_inherits is None or "::" not in structured_inherits, (
664
            "namespace is not supported in structured inherits,"
665
            " using the same namespace as the native function"
666
        )
667

668
        python_module = e.pop("python_module", None)
669
        assert python_module is None or isinstance(
670
            python_module, str
671
        ), f"not a str: {python_module}"
672
        assert (
673
            python_module is None or Variant.method not in variants
674
        ), "functions in modules cannot be methods"
675

676
        category_override = e.pop("category_override", None)
677
        assert category_override is None or isinstance(
678
            category_override, str
679
        ), f"not a str: {category_override}"
680

681
        precomputed_dict = e.pop("precomputed", None)
682
        assert precomputed_dict is None or structured is True
683
        precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None
684

685
        tags_inp = e.pop("tags", [])
686
        if isinstance(tags_inp, str):
687
            tags_inp = [tags_inp]
688
        assert isinstance(tags_inp, list)
689

690
        # All aten ops generated by torchgen receive the pt2_compliant tag.
691
        if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
692
            tags_inp.append("pt2_compliant_tag")
693

694
        tags: set[str] = set()
695
        for t in tags_inp:
696
            assert len(valid_tags) > 0
697
            # TODO: verify that the tag is valid and has an entry in tags.yaml
698
            if t in valid_tags:
699
                tags.add(t)
700
            else:
701
                raise AssertionError(f"illegal tag {t}")
702

703
        from torchgen.api import cpp
704

705
        raw_dispatch = e.pop("dispatch", None)
706
        assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
707
        dispatch: dict[DispatchKey, BackendMetadata] = {}
708
        num_dispatch_keys: int = 0
709
        if raw_dispatch is not None:
710
            assert not manual_kernel_registration, (
711
                "cannot specify both manual_kernel_registration and dispatch; with "
712
                "manual registration, dispatch has no effect!"
713
            )
714
            redundant_composite_implicit_autograd = False
715
            for ks, v in raw_dispatch.items():
716
                if ks == "__line__":
717
                    continue  # not worth tracking line numbers for dispatch entries
718
                assert isinstance(
719
                    ks, str
720
                ), f"illegal dispatch key '{ks}' in {raw_dispatch}"
721
                assert isinstance(
722
                    v, str
723
                ), f"illegal dispatch value '{v}' in {raw_dispatch}"
724
                for k in ks.split(","):
725
                    dispatch_key = DispatchKey.parse(k.strip())
726
                    num_dispatch_keys += 1
727

728
                    if ignore_keys and dispatch_key in ignore_keys:
729
                        continue
730
                    assert dispatch_key in dispatch_keys, (
731
                        f"Dispatch key {dispatch_key} of kernel {v} "
732
                        "is not a supported dispatch key."
733
                    )
734
                    # We only allow at most 3 levels of namespace for kernels.
735
                    # We will append "native" to a custom kernel namespace.
736
                    namespace_helper = NamespaceHelper.from_namespaced_entity(
737
                        v, max_level=3
738
                    )
739
                    kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
740
                    # Why is 'structured' included? External backends (e.g.
741
                    # XLA) opt into which ops are structured independently
742
                    # of which in-tree ops are structured
743
                    dispatch[dispatch_key] = BackendMetadata(
744
                        kernel=namespace_helper.entity_name,
745
                        structured=structured
746
                        and is_structured_dispatch_key(dispatch_key),
747
                        cpp_namespace=(kernel_namespace + "::native"),
748
                    )
749
                    if (
750
                        dispatch_key is DispatchKey.CompositeImplicitAutograd
751
                        and v == cpp.name(func)
752
                    ):
753
                        redundant_composite_implicit_autograd = True
754

755
            # We count the number of dispatch keys which have not been ignored to prevent a dispatch table
756
            # in which all backend keys are ignored but necessarily kept, remaining compositeimplicit,
757
            # from being treated as redundant.
758
            assert not (
759
                num_dispatch_keys == 1 and redundant_composite_implicit_autograd
760
            ), (
761
                "unnecessary dispatch table for this function; just delete the dispatch "
762
                "key entirely"
763
            )
764
            # if a function is a structured delegate, deleting the dispatch
765
            # table is NOT semantics preserving
766
            assert (
767
                structured_delegate
768
                or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
769
                or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
770
                or num_dispatch_keys != 1
771
            ), (
772
                f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
773
                f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}.  Rename your implementation to the expected "
774
                "name, then delete the dispatch table"
775
            )
776
        elif not structured and structured_delegate is None:
777
            name = str(func.name.name)
778
            assert not (
779
                name.startswith("new_")
780
                or name.endswith("_like")
781
                # TODO: maybe it's better to test the return
782
                or (
783
                    func.arguments.tensor_options
784
                    and not func.arguments.has_tensor_arg()
785
                )
786
            ), (
787
                f"expected {name} to have a CompositeExplicitAutograd "
788
                "dispatch entry, but there was no dispatch table.  Factory functions "
789
                "should not have implicit dispatch as they should not be decomposed "
790
                "for __torch_dispatch__"
791
            )
792
            dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata(
793
                cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE
794
            )
795

796
        composites_in_dispatch = [
797
            d
798
            for d in dispatch
799
            if d == DispatchKey.CompositeExplicitAutograd
800
            or d == DispatchKey.CompositeExplicitAutogradNonFunctional
801
            or d == DispatchKey.CompositeImplicitAutograd
802
            or d == DispatchKey.CompositeImplicitAutogradNestedTensor
803
        ]
804

805
        assert len(composites_in_dispatch) <= 1 or (
806
            len(composites_in_dispatch) == 2
807
            and (
808
                DispatchKey.CompositeExplicitAutogradNonFunctional
809
                not in composites_in_dispatch
810
            )
811
            and (
812
                DispatchKey.CompositeImplicitAutogradNestedTensor
813
                in composites_in_dispatch
814
            )
815
        ), (
816
            "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
817
            "or CompositeImplicitAutograd on a single kernel; each "
818
            "strictly subsumes the other.  If you wanted to provide an explicit autograd "
819
            "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
820
        )
821

822
        autogen_str = e.pop("autogen", "")
823
        assert isinstance(autogen_str, str)
824
        autogen = (
825
            []
826
            if autogen_str == ""
827
            else [OperatorName.parse(x) for x in autogen_str.split(", ")]
828
        )
829

830
        raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
831
        ufunc_inner_loop = {}
832
        if isinstance(raw_ufunc_inner_loop, str):
833
            ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(
834
                raw_ufunc_inner_loop, UfuncKey.Generic
835
            )
836
        elif isinstance(raw_ufunc_inner_loop, dict):
837
            for k, vo in raw_ufunc_inner_loop.items():
838
                if k == "__line__":
839
                    continue
840
                assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}"
841
                assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {v}"
842
                ufunc_key = UfuncKey.parse(k)
843
                ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key)
844
        else:
845
            raise AssertionError(
846
                f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}"
847
            )
848
        # Program the BackendIndex for the implicit dispatch entry from ufunc
849
        if ufunc_inner_loop:
850
            assert structured, "ufunc must be structured"
851

852
            # Delay import ufunc here to avoid circular import issue
853
            # See: https://github.com/pytorch/pytorch/issues/81294
854
            import torchgen.api.ufunc as ufunc
855

856
            for dispatch_key in UFUNC_DISPATCH_KEYS:
857
                assert (
858
                    dispatch_key not in dispatch
859
                ), f"ufunc should not have explicit dispatch entry for {dispatch_key}"
860
                dispatch[dispatch_key] = BackendMetadata(
861
                    kernel=ufunc.schema_kernel_name(func, dispatch_key),
862
                    structured=True,
863
                    cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
864
                )
865

866
        if structured_delegate:
867
            # Structured functions MUST have a dispatch table
868
            is_abstract = True
869
        else:
870
            is_abstract = (
871
                dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
872
                and dispatch.keys()
873
                != {DispatchKey.CompositeImplicitAutogradNestedTensor}
874
                and dispatch.keys()
875
                != {
876
                    DispatchKey.CompositeImplicitAutograd,
877
                    DispatchKey.CompositeImplicitAutogradNestedTensor,
878
                }
879
            )
880

881
        has_composite_implicit_autograd_kernel = (
882
            DispatchKey.CompositeImplicitAutograd in dispatch
883
        )
884
        has_composite_implicit_autograd_nested_tensor_kernel = (
885
            DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch
886
        )
887
        has_composite_explicit_autograd_kernel = (
888
            DispatchKey.CompositeExplicitAutograd in dispatch
889
        )
890
        has_composite_explicit_autograd_non_functional_kernel = (
891
            DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch
892
        )
893

894
        # We aren't going to store dispatch metadata inline in NativeFunctions;
895
        # instead it is separately indexed by backend (so other backends can
896
        # add more dispatch entries after the fact).  Reindex the individual
897
        # metadata by OperatorName!
898
        backend_metadata = {k: {func.name: v} for k, v in dispatch.items()}
899

900
        # don't care if it exists or not; make it easier to use this function
901
        # with other yaml parsers that aren't setting __line__ in the dict
902
        e.pop("__line__", None)
903
        assert not e, f"leftover entries: {e}"
904

905
        # Asserts that we can't do in post_init, because they rely on backend-specific info
906
        if structured_delegate is not None:
907
            for key in STRUCTURED_DISPATCH_KEYS:
908
                assert key not in dispatch, (
909
                    f"if structured_delegate, then must not have {key} in dispatch dictionary "
910
                    "(it is delegated!)"
911
                )
912

913
        return (
914
            NativeFunction(
915
                func=func,
916
                use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
917
                variants=variants,
918
                structured=structured,
919
                structured_delegate=structured_delegate,
920
                structured_inherits=structured_inherits,
921
                precomputed=precomputed,
922
                autogen=autogen,
923
                ufunc_inner_loop=ufunc_inner_loop,
924
                manual_kernel_registration=manual_kernel_registration,
925
                manual_cpp_binding=manual_cpp_binding,
926
                python_module=python_module,
927
                category_override=category_override,
928
                device_guard=device_guard,
929
                device_check=device_check,
930
                loc=loc,
931
                cpp_no_default_args=cpp_no_default_args,
932
                is_abstract=is_abstract,
933
                has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
934
                has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
935
                has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
936
                has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
937
                tags=tags,
938
                namespace=namespace,
939
            ),
940
            backend_metadata,
941
        )
942

943
    def validate_unstructured(self) -> None:
        """Check that this function does not claim to be part of a structured group.

        Presumably called on functions that were not placed into a
        NativeFunctionsGroup (confirm against callers): such a function must be
        neither ``structured`` nor carry a ``structured_delegate``.

        Raises:
            AssertionError: if ``structured`` is set, or if a
                ``structured_delegate`` is set.
        """
        # TODO: probably better to accumulate these errors and report them all
        # at once
        assert not self.structured, (
            "This function is structured, but there was "
            "no valid functional variant of it."
        )
        # The message describes the failing situation (a delegate IS named but
        # could not be resolved), so the invariant is that no delegate is set.
        # Asserting the delegate's presence instead would reject every ordinary
        # unstructured function while letting unresolved delegates through.
        assert not self.structured_delegate, (
            "This function delegates to another structured out function, "
            "but no valid function was found (the delegate may not exist, or it has the wrong type)"
        )
954

955
    # __post_init__ functions in dataclasses can be used to do extra
956
    # validation after construction.
957
    #
958
    # Notice that we don't do any type validation here.  In fact, we
959
    # rely exclusively on mypy to check if you've done types correctly!
960
    # Validation is for nontrivial invariants that cannot be (conveniently)
961
    # encoded in the type system.
962
    def __post_init__(self) -> None:
        """Validate cross-field invariants of a freshly constructed NativeFunction.

        Raises:
            AssertionError: if any of the invariants checked below is violated.
        """
        # Overloads with out= arguments may only be exposed as free functions.
        if self.func.arguments.out:
            assert self.variants == {Variant.function}, (
                "Native functions with out arguments MUST "
                "be declared with only function variant; e.g., variants: function; "
                "otherwise you will tickle a Python argument binding bug "
                "(which usually manifests itself as the result variable being undefined.)"
            )
        # `structured` belongs on the out= overload and requires a device guard.
        if self.structured:
            assert self.func.kind() == SchemaKind.out, (
                "Put structured field on the out= "
                "variant of a function; did you mean structured_delegate?"
            )
            assert (
                self.device_guard
            ), "device_guard: False is not respected by structured kernels"
        # `structured_delegate` belongs on non-out= overloads and also requires
        # a device guard.
        if self.structured_delegate:
            assert self.func.kind() != SchemaKind.out, (
                "structured_delegate field not allowed "
                "on out= functions; did you mean structured?"
            )
            assert (
                self.device_guard
            ), "device_guard: False is not respected by structured kernels"
        # Given the two kind() checks above, this assert can never fire; it is
        # kept as an explicit statement of the invariant.
        assert not (
            self.structured and self.structured_delegate
        ), "Cannot have both structured and structured_delegate on function"
        # Every name in cpp_no_default_args must be a schema argument that
        # actually has a default value.
        defaulted_arguments = {
            a.name for a in self.func.schema_order_arguments() if a.default is not None
        }
        invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
        assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}"
        # structured_inherits only makes sense on a structured kernel.
        if self.structured_inherits is not None:
            assert (
                self.structured
            ), "structured_inherits must also imply structured: True"
        # _foreach ops must opt out of device checks.
        if str(self.func.name).startswith("_foreach"):
            assert self.device_check == DeviceCheckType.NoCheck, (
                "foreach kernels fall back to slow path when tensor are on different devices, "
                "device_check not allowed to be enabled"
            )

        # Name/argument heuristic for randomness: ops whose name contains
        # "rand", non-backward dropout-like ops (by name or argument name), and
        # ops taking a generator argument must be tagged nondeterministic_seeded.
        # NB: if your function accidentally has rand/dropout/... in its name
        # but is not actually random, feel free to amend this to special case
        if (
            "rand" in str(self.func.name)
            or (
                (
                    "dropout" in str(self.func.name)
                    or any(
                        "dropout" in arg.name for arg in self.func.arguments.flat_all
                    )
                )
                # Backwards of dropout is typically deterministic
                and "backward" not in str(self.func.name)
                and str(self.func.name.name) not in ["_cudnn_init_dropout_state"]
            )
            or self.func.arguments.has_generator_arg()
        ):
            assert "nondeterministic_seeded" in self.tags, str(self.func.name)
1024

1025
    @property
    def has_composite_kernel(self) -> bool:
        """Whether any composite kernel is registered for this function.

        True when a CompositeImplicitAutograd, CompositeExplicitAutograd or
        CompositeExplicitAutogradNonFunctional kernel is present.  A
        CompositeImplicitAutogradNestedTensor kernel on its own does not count;
        combined with CompositeImplicitAutograd it is already covered by the
        implicit-autograd term.
        """
        return (
            self.has_composite_implicit_autograd_kernel
            or self.has_composite_explicit_autograd_kernel
            or self.has_composite_explicit_autograd_non_functional_kernel
        )
1035

1036
    @property
    def is_view_op(self) -> bool:
        """Whether this operator aliases (views) one of its inputs.

        Any one of three conditions makes it a view:
          * some return has an alias annotation that is not a write;
          * the op carries the ``inplace_view`` tag and is not ``resize_`` or
            ``resize_as_``;
          * some argument's annotation contains ``*`` in its post-call alias set.
        """
        op_name = str(self.func.name)

        aliases_a_return = any(
            ret.annotation is not None and not ret.annotation.is_write
            for ret in self.func.returns
        )
        # See Note [resize_ in Functionalization] for more details
        tagged_inplace_view = "inplace_view" in self.tags and op_name not in (
            "resize_",
            "resize_as_",
        )
        has_wildcard_alias = any(
            arg.annotation is not None and "*" in arg.annotation.alias_set_after
            for arg in self.func.schema_order_arguments()
        )
        return aliases_a_return or tagged_inplace_view or has_wildcard_alias
1053

1054
    @property
    def view_schema_kind(self) -> ViewSchemaKind:
        """Classify this operator's aliasing behavior.

        Returns ``ViewSchemaKind.non_aliasing`` for non-views,
        ``ViewSchemaKind.aliasing_inplace`` for in-place views (which must carry
        the ``inplace_view`` tag), and ``ViewSchemaKind.aliasing`` otherwise.
        """
        if not self.is_view_op:
            return ViewSchemaKind.non_aliasing
        if self.func.name.name.inplace:
            assert "inplace_view" in self.tags
            return ViewSchemaKind.aliasing_inplace
        return ViewSchemaKind.aliasing
1063

1064
    @property
    def root_name(self) -> str:
        """The ``base`` component of this function's operator name."""
        base_operator = self.func.name.name
        return base_operator.base
1067

1068
    @property
    def part_of_structured_group(self) -> bool:
        """Whether this function is structured itself or delegates to a structured op."""
        has_delegate = self.structured_delegate is not None
        return self.structured or has_delegate
1071

1072

1073
class SchemaKind(Enum):
    """The kind of an operator overload, as reported by ``FunctionSchema.kind()``.

    Member values are the consecutive integers 1-5, in declaration order.
    """

    functional = 1
    inplace = 2
    out = 3
    mutable = 4
    scratch = 5
1079

1080

1081
# A structured kernel is guaranteed to have a functional and out variant, and
# optionally an inplace and/or a mutable variant.
1083
#
1084
# NB: we create NativeFunctionsGroup *even if* the function is not
1085
# actually annotated structured.  Test the structured boolean to see if it
1086
# actually is structured or not.
1087
@dataclass(frozen=True)
class NativeFunctionsGroup:
    """The functional, out=, and optional inplace/mutable overloads of one operator.

    Usually built via ``from_dict`` from a mapping of SchemaKind to
    NativeFunction; ``__post_init__`` checks that the members are consistent.
    """

    # Required functional overload (checked to be SchemaKind.functional).
    functional: NativeFunction
    # Optional in-place overload (checked to be SchemaKind.inplace).
    inplace: NativeFunction | None
    # Optional mutable overload (checked to be SchemaKind.mutable).
    mutable: NativeFunction | None
    # Required out= overload (checked to be SchemaKind.out); it alone decides
    # whether the group is structured.
    out: NativeFunction

    @property
    def structured(self) -> bool:
        # Whether or not the operator has a meta() function. This information is backend-agnostic.
        return self.out.structured

    def __post_init__(self) -> None:
        """Validate that the grouped overloads belong together.

        Raises:
            AssertionError: on mismatched signatures, mixed structured and
                unstructured members, wrong SchemaKinds, differing namespaces,
                a mutable member without a functional overload name, or (for
                structured groups) composite implicit kernels on the out=
                overload or delegates that do not point at it.
            RuntimeError: when the overloads tagged ``generated`` do not match
                the union of every member's ``autogen`` list.
        """
        # Every member must have the same signature as the functional overload.
        test_sig: FunctionSchema = self.functional.func.signature()
        for f in self.functions():
            if test_sig != f.func.signature():
                raise AssertionError(
                    "NativeFunctionsGroup constructed from two NativeFunctions "
                    f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
                )

            # Either every member participates in the structured group or none does.
            if self.structured != f.part_of_structured_group:
                raise AssertionError(
                    "NativeFunctionsGroup constructed from structured and unstructured "
                    f"functions: {self.out.func.name} and {f.func.name}"
                )
        # Each slot must hold the matching SchemaKind, all in one namespace.
        assert self.functional.func.kind() == SchemaKind.functional
        assert self.out.func.kind() == SchemaKind.out
        assert self.functional.namespace == self.out.namespace
        if self.inplace is not None:
            assert self.inplace.func.kind() == SchemaKind.inplace
            assert self.inplace.namespace == self.functional.namespace

        if self.mutable is not None:
            assert self.mutable.func.kind() == SchemaKind.mutable
            assert self.mutable.namespace == self.functional.namespace
            # See Note [Overload Ambiguity With Functional Variants]
            assert self.functional.func.name.name.functional_overload

        if self.structured:
            # For now, structured composite kernels are not supported (need some
            # design work to figure out how to make the composite case work)
            assert (
                not self.out.has_composite_implicit_autograd_kernel
                and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
            )

            # Non-out members of a structured group must delegate to its out= overload.
            assert self.functional.structured_delegate == self.out.func.name, (
                f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
                f"but its actual delegate is {self.out.func.name}"
            )
            if self.inplace is not None:
                assert self.inplace.structured_delegate == self.out.func.name

        # The set of members tagged "generated" must equal the union of all
        # members' `autogen` entries; both sides are compared as sorted,
        # comma-joined strings.
        generated_fns = sorted(
            [str(f.func.name) for f in self.functions() if "generated" in f.tags]
        )
        generated_fns_str = ", ".join(str(x) for x in generated_fns)
        expected_generated_fns: set[str] = set()
        for f in self.functions():
            expected_generated_fns.update(str(op) for op in f.autogen)
        expected_generated_fns_str = ", ".join(
            str(x) for x in sorted(expected_generated_fns)
        )
        if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
            # NOTE(review): `f` below is the loop variable left over from the
            # loop above, i.e. the last overload yielded by functions() (mutable,
            # else inplace, else out) -- not necessarily the yaml entry that
            # should carry `autogen`. Confirm this is the intended name.
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_fns_str}'."
                " In order to generate them however, we expect them to be called out explicitly in the yaml."
                f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
            )
        if expected_generated_fns_str != generated_fns_str:
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_fns_str}'."
                f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
                f" Instead, it found 'autogen: {expected_generated_fns_str}'"
            )

    def signature(self) -> FunctionSchema:
        # The shared signature of the group, taken from the out= overload.
        return self.out.func.signature()

    def functions(self) -> Iterator[NativeFunction]:
        # Yields functional, out, then inplace and mutable when present.
        yield self.functional
        yield self.out
        if self.inplace is not None:
            yield self.inplace
        if self.mutable is not None:
            yield self.mutable

    @property
    def root_name(self) -> str:
        # Delegates to the functional overload's root name.
        return self.functional.root_name

    @staticmethod
    def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None:
        """Build a group from overloads keyed by SchemaKind.

        Returns None when ``d`` has a single entry or no out= overload.

        Raises:
            AssertionError: if ``d`` is empty, contains kinds other than
                functional/inplace/mutable/out, or has no functional overload
                while having more than one entry.
        """
        assert d
        if len(d) == 1:
            return None
        d = dict(d)  # non-destructive updates please
        functional = d.pop(SchemaKind.functional, None)
        inplace = d.pop(SchemaKind.inplace, None)
        mutable = d.pop(SchemaKind.mutable, None)
        out = d.pop(SchemaKind.out, None)
        assert not d
        assert functional is not None
        # There are a few operators which only have functional/inplace variants;
        # these don't count as structured for our purposes here
        if out is None:
            return None
        # assuming all variants have the same namespace
        return NativeFunctionsGroup(
            functional=functional,
            inplace=inplace,
            mutable=mutable,
            out=out,
        )
1202

1203

1204
@dataclass(frozen=True)
class BackendMetadata:
    """Backend-specific metadata describing one operator's kernel."""

    # Name of the backend kernel for a given operator. For in-tree backends
    # these names come directly from the "dispatch" field in
    # native_functions.yaml. The dispatch entry is optional; omitting it is
    # equivalent to having written:
    #
    #   dispatch:
    #       CompositeImplicitAutograd: $operator_name
    kernel: str
    # Whether the operator has a structured kernel for this particular backend.
    # In-tree backends all share the value listed in native_functions.yaml,
    # whereas external backends like XLA can independently toggle which ops
    # are structured.
    structured: bool

    # C++ namespace that contains the kernel (typically
    # DEFAULT_KERNEL_NAMESPACE for in-tree kernels).
    cpp_namespace: str

    def supports_symint(self) -> bool:
        """Whether the kernel name contains ``_symint`` (the SymInt naming convention)."""
        return self.kernel.find("_symint") != -1
1225

1226

1227
@dataclass(frozen=True)
class UfuncInnerLoop:
    """A ufunc inner-loop entry of the form ``"<name> (<dtype>, <dtype>, ...)"``."""

    name: str
    supported_dtypes: OrderedSet[ScalarType]
    # The key lives next to the name because it changes how the name is
    # interpreted; keeping them together simplifies later processing.
    ufunc_key: UfuncKey

    @staticmethod
    def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop:
        """Parse ``value`` as ``"<name> (<dtypes>)"`` for ``ufunc_key``.

        Each comma-separated entry inside the parentheses is expanded with
        ``ScalarType.parse_set`` and the results are unioned in order.
        """
        loop_name, dtype_list = value.split(" ", 1)
        assert dtype_list[0] == "("
        assert dtype_list[-1] == ")"
        inner = dtype_list[1:-1]
        dtypes: OrderedSet[ScalarType] = OrderedSet()
        for entry in inner.split(", "):
            dtypes |= ScalarType.parse_set(entry)
        return UfuncInnerLoop(
            name=loop_name,
            supported_dtypes=dtypes,
            ufunc_key=ufunc_key,
        )
1246

1247

1248
# BackendIndex represents a backend.
1249
# The BackendIndex encodes per-operator information that is potentially different
1250
# for each backend. The most obvious example is the name of the kernel
1251
# (the 'dispatch' entry in native_functions.yaml).
1252
# However, there can be other examples of different backends having different information.
1253
# External backends can choose to opt their kernels to be structured independently from in-tree backends,
1254
# which means that this information isn't inherently tied to a NativeFunction- it's different per backend.
1255
@dataclass(frozen=True)
class BackendIndex:
    """Per-operator kernel metadata for a single dispatch key (backend)."""

    dispatch_key: DispatchKey
    # Mainly important for structured kernels, this determines which variant in the operator group is used to implement the others.
    # All in-tree ops use out kernels, while XLA uses functional kernels.
    use_out_as_primary: bool
    # Whether the backend requires a device guard, and device checks.
    # For in-tree backends, this is currently just CUDA/HIP
    # For out-of-tree backends, this is currently just Intel XPU
    device_guard: bool
    # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
    external: bool
    # Other backend-specific information that is on a per-operator basis
    index: dict[OperatorName, BackendMetadata]

    @staticmethod
    def grow_index(
        parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
        child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
    ) -> None:
        """Merge ``child_index`` into ``parent_index`` in place.

        ``parent_index`` may be a plain dict (as annotated) or a
        ``defaultdict(dict)``: a dispatch key missing from the parent is
        created on demand. Dispatch keys whose child mapping is empty are
        skipped, so no empty entries are added to the parent.

        Raises:
            AssertionError: if an operator already has metadata in the parent
                under the same dispatch key.
        """
        for dispatch_key, kernels in child_index.items():
            if not kernels:
                continue
            # setdefault instead of indexing: a plain-dict parent that lacks
            # this dispatch key would otherwise raise KeyError.
            merged = parent_index.setdefault(dispatch_key, {})
            for op_name, metadata in kernels.items():
                assert (
                    op_name not in merged
                ), f"duplicate operator {op_name} for dispatch key {dispatch_key}"
                merged[op_name] = metadata

    def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
        """Return the group member this backend implements directly (out= or functional)."""
        if self.use_out_as_primary:
            return g.out
        else:
            return g.functional

    def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
        """Whether this backend has metadata for ``g`` (see ``get_kernel``)."""
        return self.get_kernel(g) is not None

    def get_kernel(
        self, g: NativeFunction | NativeFunctionsGroup
    ) -> BackendMetadata | None:
        """Look up this backend's metadata for ``g``, or None if absent.

        For a group, the lookup uses the member selected by ``primary``.
        """
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = self.primary(g)
        else:
            assert_never(g)
        return self.index.get(f.func.name)

    def native_function_class_name(self) -> str | None:
        """Name of the generated NativeFunctions class, or None for in-tree backends."""
        if self.external:
            return f"{str(self.dispatch_key)}NativeFunctions"
        else:
            # TODO: This discrepancy isn't required; we could also generate
            # a class for in-tree kernels. It'll just require carefully
            # updating every kernel definition + callsite of every in-tree aten kernel.
            return None
1313

1314

1315
# The function schema is undoubtedly the most important data structure
1316
# in all of the codegen, as it defines the type signature for operators,
1317
# and most of the code generation we do is type directed (e.g., look at
1318
# the types, decide what to do.  Think about how we code generate
1319
# C++ function stubs!)
1320
#
1321
# We will also see in this class the general structure for how we model
1322
# data in this code generation.  A few notable properties to point out
1323
# ahead of time:
1324
#
1325
#   - These dataclasses are a *lossless* representation of the strings
1326
#     they are parsed from.  In fact, we assert that given the
1327
#     information stored in the dataclass, we can exactly reconstruct
1328
#     the string we parsed from (and assert this inside the parse
1329
#     definition).  There are a few reasons for this:
1330
#
1331
#       - If you find that it is difficult to reconstruct the string
#         given a dataclass, that is a clue that your data
#         representation is wrong.
1334
#
1335
#       - It helps ensure that all relevant information is present
1336
#         in the dataclass, so that downstream users aren't tempted
1337
#         to reparse the original string to get some information
1338
#         that was omitted.
1339
#
1340
#       - It forces you to represent the data in-memory in the same way
1341
#         it is recorded textually, which makes the dataclasses easier
1342
#         to understand for someone who is familiar with the
1343
#         textual format.  (As a tradeoff, it means you have to model
1344
#         the syntax, even when it is inconvenient.  But maybe that means
1345
#         the syntax is bad!)  If you don't understand the internal
1346
#         representation, go look at the printing code to see how
1347
#         it maps onto the surface syntax!
1348
#
1349
#       - It makes it easy to test the parsing code, as parsing code
#         that is inconsistent with the string code will fail early
#         and loudly.  (As a tradeoff, it makes the parsing code a bit
#         brittle: in particular, trivial whitespace changes are
#         likely to trigger an assertion error.)
1354
#
1355
#     In general, try to make the __str__ code as simple as possible
1356
#     (even at the cost of more complex parsing logic.)  Additionally,
1357
#     try to minimize redundancy in data representation.  (Precomputed
1358
#     fields are OK though: they are defined as a simple function on
1359
#     the canonical representation in question.)
1360
#
1361
#   - These dataclasses are all frozen; once constructed their
1362
#     values never change.  This makes it easy to tell where any
1363
#     given data came from: just look to the constructor.  As a
1364
#     tradeoff, you can't easily "decorate" a schema with extra
1365
#     information from a post-facto analysis.  We impose this
1366
#     restriction to make these structures more understandable.
1367
#
1368
@dataclass(frozen=True)
1369
class FunctionSchema:
1370
    # The name of the operator this function schema describes.
1371
    name: OperatorName
1372

1373
    arguments: Arguments
1374

1375
    # TODO: Need to handle collisions with argument names at some point
1376
    returns: tuple[Return, ...]
1377

1378
    @property
1379
    def is_mutable(self) -> bool:
1380
        def is_write(arg: Argument) -> bool:
1381
            if arg.annotation is None:
1382
                return False
1383
            return arg.annotation.is_write
1384

1385
        # Corresponds to torch._C._FunctionSchema.is_mutable
1386
        # See aten/src/ATen/core/function_schema.h (keep these in sync)
1387
        return any(is_write(a) for a in self.arguments.flat_all)
1388

1389
    def schema_order_arguments(self) -> Iterator[Argument]:
1390
        return itertools.chain(
1391
            self.arguments.flat_positional,
1392
            self.arguments.flat_kwarg_only,
1393
            self.arguments.out,
1394
        )
1395

1396
    decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
1397

1398
    @staticmethod
1399
    def parse(func: str) -> FunctionSchema:
1400
        # We should probably get a proper parser here
1401
        decls = FunctionSchema.decl_re.findall(func)
1402
        assert len(decls) == 1, f"Invalid function schema: {func}"
1403
        ops, args, return_decl = decls[0]
1404
        name = OperatorName.parse(ops)
1405
        arguments = Arguments.parse(args)
1406
        returns = parse_returns(return_decl)
1407
        r = FunctionSchema(name=name, arguments=arguments, returns=returns)
1408
        assert str(r) == func, f"{str(r)} != {func}"
1409
        return r
1410

1411
    def returns_are_aliased(self) -> bool:
1412
        # We assert earlier that schemas can't have a mix of aliased and non-aliased returns
1413
        return any(
1414
            r
1415
            for r in self.returns
1416
            if r.annotation is not None and r.annotation.is_write
1417
        )
1418

1419
    def __post_init__(self) -> None:
1420
        for arg, ret in zip(self.arguments.out, self.returns):
1421
            assert arg.annotation == ret.annotation, (
1422
                "Out arguments must have matching return Tensor; furthermore, "
1423
                "the ith-argument needs to correspond to the ith return"
1424
            )
1425
        # We also enforce that if you have any mutable, positional args, then they are not returned.
1426
        # This makes it easier to group these functions properly with their functional/out= counterparts.
1427
        for a in self.arguments.post_self_positional_mutable:
1428
            assert not any(
1429
                a.annotation == r.annotation for r in self.returns
1430
            ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
1431
        # Invariant: we expect out arguments to appear as keyword arguments in the schema.
1432
        # This means that all mutable returns should be aliased to a keyword argument
1433
        # (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
1434
        # See Note [is_out_fn]
1435
        out_and_self = list(self.arguments.out) + [
1436
            arg for arg in self.arguments.flat_positional if arg.name == "self"
1437
        ]
1438
        mutable_returns = [
1439
            ret
1440
            for ret in self.returns
1441
            if ret.annotation is not None and ret.annotation.is_write
1442
        ]
1443
        immutable_returns = [
1444
            ret
1445
            for ret in self.returns
1446
            if ret.annotation is None or not ret.annotation.is_write
1447
        ]
1448
        # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)",
1449
        # because:
1450
        # (1) It's more annoying to handle properly
1451
        # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple.
1452
        # Instead, we expect the (a!) argument to not be returned.
1453
        assert (
1454
            len(mutable_returns) == 0 or len(immutable_returns) == 0
1455
        ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
1456
        for ret in mutable_returns:
1457
            assert any(ret.annotation == arg.annotation for arg in out_and_self), (
1458
                'All mutable returns must be aliased either to a keyword argument, or to "self". '
1459
                "Did you forget to mark an out argument as keyword-only?"
1460
            )
1461
        if self.arguments.out:
1462
            # out= ops that return their mutable inputs are only really useful for method chaining.
1463
            # And method chaining is only really useful if the thing you're returning is a plain Tensor.
1464
            # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor,
1465
            # and all other types of out= op schemas should return void.
1466
            # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that.
1467
            if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
1468
                assert (
1469
                    len(self.returns) == 0
1470
                ), "out= ops that accept tensor lists as out arguments "
1471
                "are expected to have no return type (since you can't do method chaining on them)"
1472
            else:
1473
                # mutable keyword arguments whose name has _scratch_ prefix are
1474
                # scratch tensors for memory planning and should not be returned
1475
                assert len(
1476
                    [
1477
                        arg
1478
                        for arg in self.arguments.out
1479
                        if not arg.name.startswith("_scratch_")
1480
                    ]
1481
                ) == len(
1482
                    self.returns
1483
                ), "Must return as many arguments as there are out arguments, or no return at all"
1484

1485
        if self.name.name.inplace:
1486
            self_a = self.arguments.self_arg
1487
            assert (
1488
                self_a
1489
                and self_a.argument.annotation
1490
                and self_a.argument.annotation.is_write
1491
            )
1492
            if self_a.argument.type == BaseType(BaseTy.Tensor):
1493
                # All inplace ops with an ordinary `Tensor self` argument should return self,
1494
                # to allow for method chaining.
1495
                assert (
1496
                    len(self.returns) == 1
1497
                    and self.returns[0].annotation == self_a.argument.annotation
1498
                )
1499
            else:
1500
                # You can't method chain on non-tensor self arguments though (like a List[Tensor])
1501
                # so in all other cases we expect the return type to be none.
1502
                assert len(self.returns) == 0
1503

1504
        if self.arguments.tensor_options is not None:
1505
            assert self.kind() == SchemaKind.functional, (
1506
                "Found an operator that is not functional or out variant, but has tensor options arguments."
1507
                "This is not allowed- tensor options arguments are only allowed for factory functions."
1508
                f"schema: {str(self)}"
1509
            )
1510
        if self.is_functional_fn():
1511
            assert self.kind() == SchemaKind.functional, (
1512
                "Found an operator that is not functional, but its overload contains the string 'functional'."
1513
                "This is a special keyword in the codegen, please use a different overload name."
1514
                f"schema: {str(self)}"
1515
            )
1516

1517
    def is_functional_fn(self) -> bool:
1518
        return "functional" in self.name.overload_name
1519

1520
    def is_out_fn(self) -> bool:
1521
        # Note [is_out_fn]
1522
        #
1523
        # out functions are the variants which take an explicit out= argument
1524
        # to populate into.  We need to know if a schema corresponds to an
1525
        # out function for several reasons:
1526
        #
1527
        #   - They codegen differently in C++ API
1528
        #       - codegen to at::add_out rather than at::add
1529
        #       - out argument is moved to front of C++ argument list
1530
        #
1531
        # out functions are DEFINED to be any function with a keyword-only
1532
        # argument that is mutable.  In principle, this could lead to a
1533
        # false positive if you define a function that mutates a
1534
        # kwarg only argument, but this isn't the "true" output of this
1535
        # function.  A more robust definition that would work in this
1536
        # case would also look at:
1537
        #
1538
        #   - The output types.  Out functions take in the arguments
1539
        #     they mutate and then return them again; this is sort
1540
        #     of "definitionally" what makes something an out function.
1541
        #     Historically, we DO check this for consistency.
1542
        #   - Correspondence with pure variant.  An out function
1543
        #     should have a signature equivalent to its pure variant,
1544
        #     but just with extra kwargs for the output elements.  This
1545
        #     is difficult to actually check for and historically
1546
        #     we only do this check in tools/
1547
        return bool(self.arguments.out)
1548

1549
    def kind(self) -> SchemaKind:
1550
        """
1551
        What kind of schema is this?  A functional schema is one
1552
        that returns a newly allocated output; an inplace schema
1553
        modifies the self argument inplace; an out schema writes
1554
        the result into an explicitly provided out argument.
1555
        """
1556
        is_out = bool(self.arguments.out)
1557
        is_scratch = bool(
1558
            [arg for arg in self.arguments.out if arg.name.startswith("_scratch_")]
1559
        )
1560
        is_inplace = self.name.name.inplace
1561
        is_mutable = any(
1562
            a.annotation is not None and a.annotation.is_write
1563
            for a in self.arguments.post_self_positional
1564
        )
1565
        assert not (is_out and is_inplace)
1566
        # out= and inplace schemas can also have post_self_positional mutable args,
1567
        # but we give precedence to out= and inplace when deciding the schema kind.
1568
        # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops
1569
        # to also worry about mutable post_self_positional arguments,
1570
        # but it seems like a much bigger lift to classify them has having a new schema kind.
1571
        # The number of ops that fit in this strange category is small enough that
1572
        # we can probably manually write code for them instead of forcing the codegen to handle them.
1573
        if is_inplace:
1574
            return SchemaKind.inplace
1575
        elif is_scratch:
1576
            assert (
1577
                is_out
1578
            ), "invariant: all scratch operators are expected to be out= operators too"
1579
            return SchemaKind.scratch
1580
        elif is_out:
1581
            assert (
1582
                not is_scratch
1583
            ), "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"
1584
            return SchemaKind.out
1585
        elif is_mutable:
1586
            return SchemaKind.mutable
1587
        else:
1588
            return SchemaKind.functional
1589

1590
    # For every return:
1591
    # - If the return aliases an input, we return the input name
1592
    # - Otherwise, we return None.
1593
    # If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
1594
    def aliased_return_names(self) -> list[str | None]:
1595
        outs: list[str | None] = []
1596
        for r in self.returns:
1597
            aliased_args = [
1598
                a
1599
                for a in self.arguments.flat_all
1600
                if a.annotation is not None and a.annotation == r.annotation
1601
            ]
1602
            if len(aliased_args) == 0:
1603
                outs.append(None)
1604
            elif len(aliased_args) == 1:
1605
                outs.append(aliased_args[0].name)
1606
            else:
1607
                aliased_names = ", ".join(a.name for a in aliased_args)
1608
                raise AssertionError(
1609
                    f"Found a return ({r.name})that aliases multiple inputs ({aliased_names})"
1610
                )
1611
        return outs
1612

1613
    def signature(
1614
        self,
1615
        *,
1616
        strip_default: bool = False,
1617
        strip_view_copy_name: bool = False,
1618
        keep_return_names: bool = False,
1619
    ) -> FunctionSchema:
1620
        """
1621
                Certain schemas are 'related', in that they are simply
1622
                inplace/out/functional versions of the same function.  This method
1623
                factors these schemas into the "core" functional signature which
1624
                is equal across all versions.
1625

1626
                Here is what normalization happens to the schema to convert
1627
                it to a signature:
1628
                - The overload name is stripped (name is retained, since
1629
                  it expresses semantic content about what the function does)
1630
                - Inplace is set False
1631
                - Out arguments are stripped
1632
                - Mutable post_self_positional args are converted to returns
1633
                - Mutability annotations are stripped  (this is sound
1634
                  because you cannot overload on mutability annotation)
1635
                - Return names are stripped since they are not overloadable and
1636
                  some variants have return names but some not
1637
                - TensorOptions are dropped
1638
                  because out= variants of factory functions don't include them
1639
                  (and we want to be able to pair up factory functions with their out variants)
1640

1641
                Finally, we want to be able to pair up related "view" and their
1642
                corresponding "view_copy" operators. We do this by optionally
1643
                stripping the trailing "_copy" from the base name.
1644

1645
                Example of a mutable op before and after:
1646

1647
                f.func (Mutable operator):
1648
        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)  # noqa: B950
1649

1650
                f.func (Corresponding functional operator):
1651
        _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)  # noqa: B950
1652

1653
                f.func.signature() output:
1654
        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)  # noqa: B950
1655
        """
1656

1657
        def strip_ret_annotation(r: Return) -> Return:
1658
            return Return(
1659
                name=r.name if keep_return_names else None,
1660
                type=r.type,
1661
                annotation=None,
1662
            )
1663

1664
        base_name = self.name.name.base
1665
        if strip_view_copy_name:
1666
            if base_name.endswith("_copy"):
1667
                base_name = base_name.replace("_copy", "")
1668
            elif base_name.endswith("_scatter"):
1669
                base_name = base_name.replace("scatter", "inverse")
1670

1671
        # find mutable inputs that are not originally returned, and convert them to returns
1672
        returns_from_mutable_inputs = tuple(
1673
            # When we're grouping functions we strip the return names,
1674
            # but when we're generating the actual functional variants then we follow
1675
            # a convention for what to name the returns
1676
            Return(
1677
                name=f"{a.name}_out" if keep_return_names else None,
1678
                type=a.type,
1679
                annotation=None,
1680
            )
1681
            for a in itertools.chain(
1682
                # Order is important here (otherwise e.g. inplace with mutable args
1683
                # and out= with mutable args won't have the same signature)
1684
                [self.arguments.self_arg.argument]
1685
                if self.arguments.self_arg is not None
1686
                else [],
1687
                self.arguments.out,
1688
                self.arguments.post_self_positional,
1689
            )
1690
            if a.annotation is not None
1691
            and a.annotation.is_write
1692
            and not any(a.annotation == r.annotation for r in self.returns)
1693
        )
1694
        original_returns = tuple(map(strip_ret_annotation, self.returns))
1695
        # Ordering is important here. We expect the "mutable input" returns to come last.
1696
        returns = original_returns + returns_from_mutable_inputs
1697

1698
        args_sig = self.arguments.signature(strip_default=strip_default)
1699
        # See Note [bernoulli.p schema]
1700
        if str(self.name) == "bernoulli.p":
1701
            args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))
1702

1703
        return FunctionSchema(
1704
            name=OperatorName(
1705
                name=BaseOperatorName(
1706
                    base=base_name,
1707
                    inplace=False,
1708
                    dunder_method=self.name.name.dunder_method,
1709
                ),
1710
                overload_name="",  # stripped
1711
            ),
1712
            arguments=args_sig,
1713
            returns=returns,
1714
        )
1715

1716
    def view_signature(self) -> FunctionSchema:
1717
        return self.signature(strip_view_copy_name=True)
1718

1719
    def with_name(self, name: OperatorName) -> FunctionSchema:
1720
        return FunctionSchema(
1721
            name=name,
1722
            arguments=self.arguments,
1723
            returns=self.returns,
1724
        )
1725

1726
    @property
1727
    def modifies_arguments(self) -> bool:
1728
        return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
1729

1730
    def has_symint(self) -> bool:
1731
        return self.arguments.has_symint_arg()
1732

1733
    def __str__(self) -> str:
1734
        all_arguments_str = str(self.arguments)
1735
        if len(self.returns) == 1:
1736
            returns = str(self.returns[0])  # omit parentheses
1737
        else:
1738
            returns = "(" + ", ".join(map(str, self.returns)) + ")"
1739
        return f"{self.name}({all_arguments_str}) -> {returns}"
1740

1741

1742
# Here is the rest of the data model, described more briefly.
1743

1744

1745
# Simplified version for what actually shows up in built-ins.
1746
# Look at alias_info.h for expanded syntax.  If you need the structure,
1747
# you also need to make this structure recursive so it can be lined
1748
# up with the type components too.  For primitives this isn't really
1749
# necessary
1750
@dataclass(frozen=True)
1751
class Annotation:
1752
    # Typically only has one element.  Not actually a set so
1753
    # we can conveniently assume it is canonically ordered
1754
    alias_set: tuple[str, ...]
1755
    is_write: bool
1756
    alias_set_after: tuple[str, ...]
1757

1758
    @staticmethod
1759
    def parse(ann: str) -> Annotation:
1760
        # TODO: implement a proper parser if this gets more ugly
1761
        # Regex Explanation:
1762
        # Example: "a! -> a|b"
1763
        # Group #1: alias before optional '|', required. Matches the first
1764
        #   character 'a' in the example
1765
        # Group #2: optional alias set after optional '|', matches empty string
1766
        #   in the example
1767
        # Group #3: optional "is write" flag, matches '!' in the example.
1768
        # Group #4: optional section containing arrow, matches " -> a|b" in the
1769
        #   example.
1770
        # Group #5: optional alias after set, supports wildcard, matches "a|b"
1771
        #   in the example.
1772
        # Group #6: optional sub-section of alias after set, matches "|b" in the
1773
        #   example.
1774
        m = re.match(r"^([a-z])(\|[a-z])*(!?)( -> (\*|[a-z](\|[a-z])*))?$", ann)
1775

1776
        assert m is not None, f"unrecognized alias annotation {ann}"
1777
        before_alias = m.group(1) + (m.group(2) if m.group(2) else "")
1778
        alias_set = tuple(before_alias.split("|"))
1779
        is_write = m.group(3) == "!"
1780
        assert not (
1781
            is_write and len(alias_set) > 1
1782
        ), f"alias set larger than 1 is not mutable, got {ann} instead."
1783
        after_set = tuple(m.group(5).split("|")) if m.group(5) else ()
1784
        assert not (
1785
            len(before_alias) > 1 and len(after_set) > 1
1786
        ), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead."
1787
        r = Annotation(
1788
            alias_set=alias_set, is_write=is_write, alias_set_after=after_set
1789
        )
1790
        assert str(r) == ann, f"{r} != {ann}"
1791
        return r
1792

1793
    def __str__(self) -> str:
1794
        alias_set = "|".join(self.alias_set)
1795
        if self.is_write:
1796
            alias_set = f"{alias_set}!"
1797
        alias_set_after = "|".join(self.alias_set_after)
1798
        if alias_set_after:
1799
            alias_set = f'{alias_set}{" -> "}{alias_set_after}'
1800
        return alias_set
1801

1802

1803
# The base class for the type system.  This is also loosely modeled
1804
# off of jit_type.h, but we've simplified the hierarchy to focus
1805
# in on the aspects of the type system that matter for code generation
1806
# (for example, there's no SingleElementType subclass anymore).
1807
# You never actually construct a Type; usually it's going to be one
1808
# of the subclasses.  If Python had ADTs this would be one!
1809
@dataclass(frozen=True)
1810
class Type:
1811
    @staticmethod
1812
    def parse(t: str) -> Type:
1813
        r = Type._parse(t)
1814
        assert str(r) == t, f"{r} != {t}"
1815
        return r
1816

1817
    @staticmethod
1818
    def _parse(t: str) -> Type:
1819
        m = re.match(r"^(.+)\?$", t)
1820
        if m is not None:
1821
            return OptionalType(Type.parse(m.group(1)))
1822
        m = re.match(r"^(.+)\[([0-9]+)?\]$", t)
1823
        if m is not None:
1824
            size = int(m.group(2)) if m.group(2) is not None else None
1825
            return ListType(elem=Type.parse(m.group(1)), size=size)
1826

1827
        # '__torch__.torch.classes.' is the prefix for custom class
1828
        m = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t)
1829
        if m is not None:
1830
            return CustomClassType(m.group(1))
1831
        try:
1832
            return BaseType(BaseTy[t])
1833
        except KeyError as e:
1834
            raise RuntimeError(f"unrecognized type {t}") from e
1835

1836
    def __str__(self) -> str:
1837
        raise NotImplementedError
1838

1839
    # WARNING: These concepts are not very well-defined.  For example,
1840
    # is "int?" nullable? How about "int?[]".  They are defined
1841
    # so we can conveniently generate legacy Declarations.yaml but
1842
    # really we should probably just remove these at some point
1843

1844
    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1845
        raise NotImplementedError
1846

1847
    def is_tensor_like(self) -> bool:
1848
        return self.is_base_ty_like(BaseTy.Tensor)
1849

1850
    def is_generator_like(self) -> bool:
1851
        return self.is_base_ty_like(BaseTy.Generator)
1852

1853
    def is_symint_like(self) -> bool:
1854
        return self.is_base_ty_like(BaseTy.SymInt)
1855

1856
    def is_nullable(self) -> bool:
1857
        raise NotImplementedError
1858

1859
    def is_list_like(self) -> ListType | None:
1860
        raise NotImplementedError
1861

1862

1863
# Base types are simple, atomic types with no further structure
1864
class BaseTy(Enum):
1865
    Generator = auto()
1866
    ScalarType = auto()
1867
    Tensor = auto()
1868
    int = auto()
1869
    Dimname = auto()
1870
    DimVector = auto()
1871
    float = auto()
1872
    str = auto()
1873
    bool = auto()
1874
    Layout = auto()
1875
    Device = auto()
1876
    DeviceIndex = auto()
1877
    Scalar = auto()
1878
    MemoryFormat = auto()
1879
    QScheme = auto()
1880
    Storage = auto()
1881
    Stream = auto()
1882
    SymInt = auto()
1883
    SymBool = auto()
1884
    ConstQuantizerPtr = auto()  # TODO: rename
1885
    GraphModule = auto()
1886

1887

1888
@dataclass(frozen=True)
1889
class BaseType(Type):
1890
    name: BaseTy
1891

1892
    def __str__(self) -> str:
1893
        return f"{self.name.name}"
1894

1895
    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1896
        return self.name == base_ty
1897

1898
    def is_nullable(self) -> bool:
1899
        return False
1900

1901
    def is_list_like(self) -> ListType | None:
1902
        return None
1903

1904
    def is_symint_like(self) -> bool:
1905
        return self.name == BaseTy.SymInt
1906

1907

1908
# Optional types may be specified, or may also be validly given None
1909
@dataclass(frozen=True)
1910
class OptionalType(Type):
1911
    elem: Type
1912

1913
    def __str__(self) -> str:
1914
        return f"{self.elem}?"
1915

1916
    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1917
        return self.elem.is_base_ty_like(base_ty)
1918

1919
    def is_symint_like(self) -> bool:
1920
        return self.elem.is_symint_like()
1921

1922
    def is_nullable(self) -> bool:
1923
        return True
1924

1925
    def is_list_like(self) -> ListType | None:
1926
        return self.elem.is_list_like()
1927

1928

1929
# A type representing a PyTorch custom class
1930
@dataclass(frozen=True)
1931
class CustomClassType(Type):
1932
    class_name: str
1933

1934
    def __str__(self) -> str:
1935
        """
1936
        Return the class name will prefix __torch__.torch.classes
1937
        """
1938
        return f"__torch__.torch.classes.{self.class_name}"
1939

1940
    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1941
        return False
1942

1943
    def is_symint_like(self) -> bool:
1944
        return False
1945

1946
    def is_nullable(self) -> bool:
1947
        """
1948
        Assume a custom class is not nullable.
1949
        """
1950
        return False
1951

1952
    def is_list_like(self) -> ListType | None:
1953
        return None
1954

1955

1956
# List types specify that we may have multiples of an element.  We
1957
# also support explicit sizes on list types, but these have
1958
# some nontrivial semantics!  (However, for C++ API purposes, explicit
1959
# sizes are mostly erased from the type system.)
1960
#
1961
# DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g.,
1962
# int[] elaborates differently than bool[3]!
1963
@dataclass(frozen=True)
1964
class ListType(Type):
1965
    elem: Type
1966
    size: int | None
1967

1968
    def __str__(self) -> str:
1969
        size = f"{self.size}" if self.size else ""
1970
        return f"{self.elem}[{size}]"
1971

1972
    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
1973
        return self.elem.is_base_ty_like(base_ty)
1974

1975
    def is_symint_like(self) -> bool:
1976
        return self.elem.is_symint_like()
1977

1978
    def is_nullable(self) -> bool:
1979
        return self.elem.is_nullable()
1980

1981
    def is_list_like(self) -> ListType | None:
1982
        return self
1983

1984

1985
@dataclass(frozen=True)
1986
class Argument:
1987
    # NB: I didn't put kwarg_only as a boolean field here, unlike
1988
    # c10::Argument, so that printing works correctly
1989

1990
    name: str
1991
    type: Type
1992
    default: str | None
1993

1994
    # The semantics of the annotation field are a little strange.
1995
    #
1996
    # Alias annotations parametrize Tensors (since Tensors are the only things
1997
    # that can alias.)  This motivates why I write Tensor(a!)?  (and not, for
1998
    # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor,
1999
    # which may be optional (i.e., the alias annotation should bind first to
2000
    # Tensor, before the optional postfix annotation).
2001
    #
2002
    # However, despite being a property of Tensor, we (and c10::Argument)
2003
    # store the annotation at the top level of the Argument, rather than
2004
    # inside the embedded Tensor type.  In the C++ version of this
2005
    # class, we then go through great lengths to mimic the type
2006
    # structure in the annotation structure so we can correlate
2007
    # annotations with types.
2008
    #
2009
    # Now, it turns out, in all applications in code generation, the
2010
    # structure of annotated types is very simple.  So we just hard
2011
    # code it here.  But if we ever do get anything more complex, this
2012
    # model will have to change!
2013
    annotation: Annotation | None
2014

2015
    @property
2016
    def alias_info(self) -> Annotation | None:
2017
        return self.annotation
2018

2019
    @staticmethod
2020
    def parse(arg: str) -> Argument:
2021
        name: str
2022
        default: str | None
2023
        assert " " in arg, f"illegal argument '{arg}'"
2024
        if "=" in arg:
2025
            assert arg.count("=") == 1, f"illegal argument with default value: '{arg}'"
2026
            type_and_annot_and_name, default = arg.split("=")
2027
            type_and_annot, name = type_and_annot_and_name.rsplit(" ", 1)
2028
            name_and_default = f"{name}={default}"
2029
        else:
2030
            type_and_annot, name_and_default = arg.rsplit(" ", 1)
2031
            name = name_and_default
2032
            default = None
2033
        # TODO: deduplicate annotation matching with Return
2034
        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
2035
        annotation: Annotation | None
2036
        if match:
2037
            # If you update this, make sure the __str__ still works too
2038
            assert match.group(2) in [
2039
                "",
2040
                "?",
2041
                "[]",
2042
            ], "unrecognized alias analysis form with Tensor"
2043
            type_s = "Tensor" + match.group(2)
2044
            annotation = Annotation.parse(match.group(1))
2045
        else:
2046
            type_s = type_and_annot
2047
            annotation = None
2048
        type = Type.parse(type_s)
2049
        r = Argument(
2050
            name=name,
2051
            type=type,
2052
            default=default,
2053
            annotation=annotation,
2054
        )
2055
        assert str(r) == arg, f"{str(r)} != {arg}"
2056
        return r
2057

2058
    @property
2059
    def is_write(self) -> bool:
2060
        return self.annotation is not None and self.annotation.is_write
2061

2062
    def __str__(self) -> str:
2063
        type = f"{self.type}"
2064
        if self.annotation:
2065
            assert type in ["Tensor", "Tensor?", "Tensor[]"]
2066
            type = type.replace("Tensor", f"Tensor({self.annotation})")
2067
        if self.name is None:
2068
            return type
2069
        else:
2070
            mb_default = ""
2071
            if self.default:
2072
                mb_default = f"={self.default}"
2073
            return f"{type} {self.name}{mb_default}"
2074

2075

2076
@dataclass(frozen=True)
2077
class Return:
2078
    name: str | None
2079
    type: Type
2080
    annotation: Annotation | None
2081

2082
    @property
2083
    def alias_info(self) -> Annotation | None:
2084
        return self.annotation
2085

2086
    @staticmethod
2087
    def parse(arg: str) -> Return:
2088
        name: str | None
2089
        if " " in arg:
2090
            type_and_annot, name = arg.rsplit(" ", 1)
2091
        else:
2092
            type_and_annot = arg
2093
            name = None
2094
        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
2095
        annotation: Annotation | None
2096
        if match:
2097
            # If you update this, make sure the __str__ still works too
2098
            assert match.group(2) in [
2099
                "",
2100
                "?",
2101
                "[]",
2102
            ], "unrecognized alias analysis form with Tensor"
2103
            type_s = "Tensor" + match.group(2)
2104
            annotation = Annotation.parse(match.group(1))
2105
        else:
2106
            type_s = type_and_annot
2107
            annotation = None
2108
        type = Type.parse(type_s)
2109
        r = Return(
2110
            name=name,
2111
            type=type,
2112
            annotation=annotation,
2113
        )
2114
        assert str(r) == arg, f"{str(r)} != {arg}"
2115
        return r
2116

2117
    @property
2118
    def is_write(self) -> bool:
2119
        return self.annotation is not None and self.annotation.is_write
2120

2121
    def __str__(self) -> str:
2122
        type = f"{self.type}"
2123
        if self.annotation:
2124
            assert type in ["Tensor", "Tensor?", "Tensor[]"]
2125
            type = type.replace("Tensor", f"Tensor({self.annotation})")
2126
        if self.name is None:
2127
            return type
2128
        else:
2129
            return f"{type} {self.name}"
2130

2131

2132
# Represents the self argument for functions that may be methods
2133
@dataclass(frozen=True)
2134
class SelfArgument:
2135
    argument: Argument
2136

2137

2138
# Bundle of arguments that represent a TensorOptions.  This is mostly
2139
# relevant for the public C++ API but we bake it into the core data
2140
# model because other APIs often have to interact with it
2141
@dataclass(frozen=True)
2142
class TensorOptionsArguments:
2143
    dtype: Argument
2144
    layout: Argument
2145
    device: Argument
2146
    pin_memory: Argument
2147

2148
    def all(self) -> Sequence[Argument]:
2149
        return [self.dtype, self.layout, self.device, self.pin_memory]
2150

2151

2152
@dataclass(frozen=True)
2153
class Arguments:
2154
    # pre_self_positional is usually empty, but is notably non-empty
2155
    # for where.self, where the condition argument comes before the
2156
    # self argument
2157
    pre_self_positional: tuple[Argument, ...]
2158
    self_arg: SelfArgument | None
2159
    post_self_positional: tuple[Argument, ...]
2160

2161
    pre_tensor_options_kwarg_only: tuple[Argument, ...]
2162
    tensor_options: TensorOptionsArguments | None
2163
    # post_tensor_options is typically memory format, which should be
2164
    # part of tensor options but isn't right now, and is usually
2165
    # placed after the tensor options arguments
2166
    post_tensor_options_kwarg_only: tuple[Argument, ...]
2167

2168
    # Unlike in the previous codegen, we have factored out 'out' arguments
2169
    # in the canonical representation, removing them from kwarg
2170
    # arguments.  This choice is justified by numerous downstream
2171
    # transformations which treat out arguments specially; additionally,
2172
    # you can see that canonicity is not violated!
2173
    out: tuple[Argument, ...]  # these are also kwarg-only
2174

2175
    @property
2176
    def flat_non_out(self) -> Sequence[Argument]:
2177
        ret: list[Argument] = []
2178
        ret.extend(self.flat_positional)
2179
        ret.extend(self.flat_kwarg_only)
2180
        return ret
2181

2182
    @property
2183
    def flat_positional(self) -> Sequence[Argument]:
2184
        ret: list[Argument] = []
2185
        ret.extend(self.pre_self_positional)
2186
        if self.self_arg is not None:
2187
            ret.append(self.self_arg.argument)
2188
        ret.extend(self.post_self_positional)
2189
        return ret
2190

2191
    @property
2192
    def post_self_positional_mutable(self) -> Sequence[Argument]:
2193
        return [a for a in self.post_self_positional if a.is_write]
2194

2195
    # NB: doesn't contain out arguments
2196
    @property
2197
    def flat_kwarg_only(self) -> Sequence[Argument]:
2198
        ret: list[Argument] = []
2199
        ret.extend(self.pre_tensor_options_kwarg_only)
2200
        if self.tensor_options is not None:
2201
            ret.extend(self.tensor_options.all())
2202
        ret.extend(self.post_tensor_options_kwarg_only)
2203
        return ret
2204

2205
    @property
2206
    def flat_all(self) -> Sequence[Argument]:
2207
        ret: list[Argument] = []
2208
        ret.extend(self.flat_positional)
2209
        ret.extend(self.flat_kwarg_only)
2210
        ret.extend(self.out)
2211
        return ret
2212

2213
    @property
2214
    def non_out(
2215
        self,
2216
    ) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
2217
        ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
2218
        ret.extend(self.positional)
2219
        ret.extend(self.kwarg_only)
2220
        return ret
2221

2222
    @property
2223
    def positional(self) -> Sequence[Argument | SelfArgument]:
2224
        ret: list[Argument | SelfArgument] = []
2225
        ret.extend(self.pre_self_positional)
2226
        if self.self_arg is not None:
2227
            ret.append(self.self_arg)
2228
        ret.extend(self.post_self_positional)
2229
        return ret
2230

2231
    @property
2232
    def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]:
2233
        ret: list[Argument | TensorOptionsArguments] = []
2234
        ret.extend(self.pre_tensor_options_kwarg_only)
2235
        if self.tensor_options is not None:
2236
            ret.append(self.tensor_options)
2237
        ret.extend(self.post_tensor_options_kwarg_only)
2238
        return ret
2239

2240
    @property
2241
    def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
2242
        ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
2243
        ret.extend(self.positional)
2244
        ret.extend(self.kwarg_only)
2245
        ret.extend(self.out)
2246
        return ret
2247

2248
    def mutable_arg_names(self) -> list[str]:
2249
        return [
2250
            a.name
2251
            for a in self.flat_all
2252
            if a.annotation is not None and a.annotation.is_write
2253
        ]
2254

2255
    def has_tensor_arg(self) -> bool:
2256
        return any(a.type.is_tensor_like() for a in self.flat_non_out)
2257

2258
    def has_symint_arg(self) -> bool:
2259
        return any(a.type.is_symint_like() for a in self.flat_non_out)
2260

2261
    def has_generator_arg(self) -> bool:
2262
        return any(a.type.is_generator_like() for a in self.flat_non_out)
2263

2264
    def signature(self, *, strip_default: bool = False) -> Arguments:
2265
        # dataclasses.replace could be used here, but it is less
2266
        # type safe so for now I've opted to type everything out
2267
        def strip_arg_annotation(a: Argument) -> Argument:
2268
            return Argument(
2269
                name=a.name,
2270
                type=a.type,
2271
                default=a.default if not strip_default else None,
2272
                annotation=None,
2273
            )
2274

2275
        return Arguments(
2276
            pre_self_positional=tuple(
2277
                map(strip_arg_annotation, self.pre_self_positional)
2278
            ),
2279
            self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument))
2280
            if self.self_arg is not None
2281
            else None,
2282
            post_self_positional=tuple(
2283
                map(strip_arg_annotation, self.post_self_positional)
2284
            ),
2285
            # Since TensorOptions are dropped, the post_tensor_options_kwargs are
2286
            # converted to pre_tensor_options_kwargs
2287
            pre_tensor_options_kwarg_only=tuple(
2288
                map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)
2289
            )
2290
            + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
2291
            # TensorOptions are dropped in signature,
2292
            # so we can pair factory functions with their out= variants.
2293
            tensor_options=None,
2294
            post_tensor_options_kwarg_only=(),
2295
            # out arguments are dropped in signature
2296
            out=(),
2297
        )
2298

2299
    def remove_self_annotation(self) -> Arguments:
2300
        assert self.self_arg is not None
2301
        return dataclasses.replace(
2302
            self,
2303
            self_arg=SelfArgument(
2304
                dataclasses.replace(self.self_arg.argument, annotation=None)
2305
            ),
2306
        )
2307

2308
    def with_out_args(self, outs: list[Argument]) -> Arguments:
2309
        assert len(self.out) == 0
2310
        return dataclasses.replace(
2311
            self,
2312
            out=tuple(outs),
2313
        )
2314

2315
    @staticmethod
2316
    def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]:
2317
        positional: list[Argument] = []
2318
        kwarg_only: list[Argument] = []
2319
        out: list[Argument] = []
2320
        arguments_acc = positional
2321

2322
        # TODO: Use a real parser here; this will get bamboozled
2323
        # by signatures that contain things like std::array<bool, 2> (note the space)
2324
        for arg in args.split(", "):
2325
            if not arg:
2326
                continue
2327
            if arg == "*":
2328
                assert (
2329
                    arguments_acc is positional
2330
                ), "invalid syntax: kwarg-only specifier * can only occur once"
2331
                arguments_acc = kwarg_only
2332
                continue
2333
            parg = Argument.parse(arg)
2334
            # Currently, we rely directly on the invariant that there are NO
2335
            # kwarg-only mutating arguments.  If you want to relax this,
2336
            # we will need a more semantic way of matching that takes
2337
            # into account return arguments.  In that case, you will have
2338
            # to manage out computation a level up, in FunctionSchema.  See Note
2339
            # [is_out_fn]
2340
            if parg.annotation is not None and parg.annotation.is_write:
2341
                if arguments_acc is positional:
2342
                    pass  # do nothing
2343
                elif arguments_acc is kwarg_only:
2344
                    arguments_acc = out
2345
            else:
2346
                assert arguments_acc is not out
2347
            arguments_acc.append(parg)
2348

2349
        return positional, kwarg_only, out
2350

2351
    @staticmethod
2352
    def parse(args: str) -> Arguments:
2353
        """
2354
        Input: 'int x, int y, int z'
2355
        """
2356

2357
        # We do this in two phases.  First we parse into three
2358
        # main categories: positional, kwarg_only, out.
2359
        # Then, we reparse positional and kwarg_only to separate
2360
        # out the self argument and tensor options arguments.
2361

2362
        positional, kwarg_only, out = Arguments._preparse(args)
2363

2364
        # Split self argument
2365
        self_ix = None
2366
        for i, a in enumerate(positional):
2367
            if a.name == "self":
2368
                self_ix = i
2369
                break
2370
        pre_self_positional: list[Argument]
2371
        self_arg: SelfArgument | None
2372
        post_self_positional: list[Argument]
2373
        if self_ix is not None:
2374
            pre_self_positional = positional[:self_ix]
2375
            self_arg = SelfArgument(positional[self_ix])
2376
            post_self_positional = positional[self_ix + 1 :]
2377
        else:
2378
            pre_self_positional = []
2379
            self_arg = None
2380
            post_self_positional = positional
2381

2382
        # Group tensor options arguments
2383
        pre_tensor_options_kwarg_only: list[Argument] = []
2384
        tensor_options: TensorOptionsArguments | None = None
2385
        post_tensor_options_kwarg_only: list[Argument] = []
2386
        kwarg_only_acc = pre_tensor_options_kwarg_only
2387

2388
        def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
2389
            return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
2390

2391
        predicates = [  # order matters
2392
            pred("dtype", Type.parse("ScalarType")),
2393
            pred("layout", Type.parse("Layout")),
2394
            pred("device", Type.parse("Device")),
2395
            pred("pin_memory", Type.parse("bool")),
2396
        ]
2397

2398
        i = 0
2399
        while i < len(kwarg_only):
2400
            # If there is enough space...
2401
            if i <= len(kwarg_only) - len(predicates):
2402
                # And the next len(predicates) arguments look like TensorOptions arguments
2403
                if all(
2404
                    p(a)
2405
                    for p, a in zip(predicates, kwarg_only[i : i + len(predicates)])
2406
                ):
2407
                    assert kwarg_only_acc is pre_tensor_options_kwarg_only
2408
                    # Group them together as one argument
2409
                    tensor_options = TensorOptionsArguments(
2410
                        dtype=kwarg_only[i],
2411
                        layout=kwarg_only[i + 1],
2412
                        device=kwarg_only[i + 2],
2413
                        pin_memory=kwarg_only[i + 3],
2414
                    )
2415
                    i += len(predicates)
2416
                    kwarg_only_acc = post_tensor_options_kwarg_only
2417
                    continue
2418
            kwarg_only_acc.append(kwarg_only[i])
2419
            i += 1
2420

2421
        return Arguments(
2422
            pre_self_positional=tuple(pre_self_positional),
2423
            self_arg=self_arg,
2424
            post_self_positional=tuple(post_self_positional),
2425
            pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only),
2426
            tensor_options=tensor_options,
2427
            post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only),
2428
            out=tuple(out),
2429
        )
2430

2431
    def __str__(self) -> str:
2432
        all_arguments: list[str] = []
2433
        all_arguments.extend(map(str, self.flat_positional))
2434
        if self.flat_kwarg_only or self.out:
2435
            all_arguments.append("*")
2436
        all_arguments.extend(map(str, self.flat_kwarg_only))
2437
        all_arguments.extend(map(str, self.out))
2438
        return ", ".join(all_arguments)
2439

2440
    def __post_init__(self) -> None:
2441
        # TODO: These invariants are weirdly asymmetric?
2442
        # TODO: Fancier types?
2443
        if self.self_arg is None:
2444
            assert not self.pre_self_positional
2445
        if self.tensor_options is None:
2446
            assert not self.post_tensor_options_kwarg_only
2447

2448
        # We don't allow any of the following to have argument annotations,
2449
        # to keep things simple.
2450
        mutable_pre_self_positionals = [
2451
            a
2452
            for a in self.pre_self_positional
2453
            if a.annotation is not None and a.annotation.is_write
2454
        ]
2455
        assert (
2456
            len(mutable_pre_self_positionals) == 0
2457
        ), "mutable pre_self_positional arguments are not currently supported in the schema"
2458

2459

2460
# Names that validly are __iXXX__ indicating inplace operations.
2461
# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
2462
# NB: PyTorch hasn't actually implemented all of these
2463
AUGMENTED_ASSIGNMENT_NAMES = [
2464
    "add",
2465
    "sub",
2466
    "mul",
2467
    "div",
2468
    "mod",
2469
    "pow",
2470
    "lshift",
2471
    "rshift",
2472
    "and",
2473
    "xor",
2474
    "or",
2475
]
2476

2477

2478
# A BaseOperatorName is what we think of the operator name, without
2479
# the overload name.  Unusually, we don't represent this as just a
2480
# string; instead, we directly represent a few important semantic
2481
# bits of information we derive from the string: namely whether
2482
# or not it's inplace (add_) and whether or not it's a double-underscore
2483
# method (__add__)
2484
@dataclass(frozen=True)
2485
class BaseOperatorName:
2486
    base: str
2487
    inplace: bool
2488
    dunder_method: bool
2489
    # Note [Overload Ambiguity With Functional Variants]
2490
    # A handful of operators have both a "mutable" and a "functional" variant.
2491
    # (native_batch_norm is a good example, although this isn't the case today).
2492
    # For those operators, the mutable and functional variant take in the same set of
2493
    # arguments, but have different alias annotations.
2494
    # this makes it ambiguous when you try to resolve an OverloadPacket into an overload,
2495
    # given a set of input arguments.
2496
    #
2497
    # So instead of making the "functional" variant in this case a real overload, e.g:
2498
    #   native_batch_norm (mutable variant)
2499
    #   native_batch_norm.functional (functional variant)
2500
    # we make it a new base operator,
2501
    #   native_batch_norm_functional (functional variant)
2502
    #
2503
    # In an ideal world, we would probably invert this so the operators were:
2504
    #   native_batch_norm.mutable (mutable variant)
2505
    #   native_batch_norm (functional variant)
2506
    #
2507
    # Doing that is BC-breaking though, so we're stuck with the above modeling.
2508
    functional_overload: bool = False
2509

2510
    @staticmethod
2511
    def parse(op: str) -> BaseOperatorName:
2512
        assert op != ""
2513
        assert not op.endswith("_out"), (
2514
            "_out suffix is reserved and not permitted for operator names; "
2515
            "did you mean to specify an out overload name instead?"
2516
        )
2517
        m = re.match(r"^__([^_]+)__$", op)
2518
        if m is not None:
2519
            dunder_method = True
2520
            base = m.group(1)
2521
            if any(base == f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES):
2522
                inplace = True
2523
                base = base[1:]
2524
            else:
2525
                inplace = False
2526
                # temporary, this is not intrinsically true but
2527
                # has been historically true for dunder methods
2528
                # we support  (but, if we ever got, say, __int__, this would
2529
                # be wrong!)
2530
                assert base[0] != "i"
2531
        else:
2532
            dunder_method = False
2533
            base = op
2534
            if base[-1] == "_":
2535
                inplace = True
2536
                base = base[:-1]
2537
            else:
2538
                inplace = False
2539

2540
        # See Note [Overload Ambiguity With Functional Variants]
2541
        functional_suffix = "_functional"
2542
        if base.endswith(functional_suffix):
2543
            functional_overload = True
2544
            base = base[: -len(functional_suffix)]
2545
            # This seems complicated and unnecessary, so banning dunder methods
2546
            # for now on ops that have a functional + mutable variant (like native_batch_norm).
2547
            assert not dunder_method and not inplace
2548
        else:
2549
            functional_overload = False
2550

2551
        r = BaseOperatorName(
2552
            base=base,
2553
            inplace=inplace,
2554
            dunder_method=dunder_method,
2555
            functional_overload=functional_overload,
2556
        )
2557
        assert str(r) == op, f"{str(r)} != {op}"
2558
        return r
2559

2560
    def __str__(self) -> str:
2561
        if self.dunder_method:
2562
            i = "i" if self.inplace else ""
2563
            return f"__{i}{self.base}__"
2564
        else:
2565
            i = (
2566
                "_"
2567
                if self.inplace
2568
                else "_functional"
2569
                if self.functional_overload
2570
                else ""
2571
            )
2572
            return f"{self.base}{i}"
2573

2574

2575
# Operator name is the base operator name along with the (typically not
2576
# user visible) overload string.
2577
@dataclass(frozen=True)
2578
class OperatorName:
2579
    name: BaseOperatorName
2580
    overload_name: str
2581

2582
    @staticmethod
2583
    def parse(op_name: str) -> OperatorName:
2584
        if "." in op_name:
2585
            name, overload_name = op_name.split(".", 1)
2586
        else:
2587
            name = op_name
2588
            overload_name = ""
2589
        r = OperatorName(name=BaseOperatorName.parse(name), overload_name=overload_name)
2590
        assert str(r) == op_name, f"{str(r)} != {op_name}"
2591
        return r
2592

2593
    def __str__(self) -> str:
2594
        if self.overload_name:
2595
            return f"{self.name}.{self.overload_name}"
2596
        else:
2597
            return f"{self.name}"
2598

2599
    # NB: This must be synchronized with the naming scheme in
2600
    # aten/src/ATen/templates/Operators.h
2601
    # Given a function schema "aten::op.overload(...)",
2602
    # If there is no overload name, this returns f"{op}"
2603
    # If there is an overload name, this returns f"{op}_{overload}"
2604
    def unambiguous_name(self) -> str:
2605
        if self.overload_name:
2606
            return f"{self.name}_{self.overload_name}"
2607
        else:
2608
            return f"{self.name}"
2609

2610
    def remove_inplace(self) -> OperatorName:
2611
        return OperatorName(
2612
            name=BaseOperatorName(
2613
                base=self.name.base,
2614
                inplace=False,
2615
                dunder_method=self.name.dunder_method,
2616
            ),
2617
            overload_name=self.overload_name,
2618
        )
2619

2620
    def with_overload(self, overload: str) -> OperatorName:
2621
        return OperatorName(
2622
            name=BaseOperatorName(
2623
                base=self.name.base,
2624
                inplace=False,
2625
                dunder_method=self.name.dunder_method,
2626
            ),
2627
            overload_name=overload,
2628
        )
2629

2630

2631
def gets_generated_out_inplace_wrapper(
2632
    f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex
2633
) -> bool:
2634
    return (
2635
        f.func.kind() is not SchemaKind.functional
2636
        and not b.has_kernel(f)
2637
        and b.has_kernel(g.functional)
2638
    )
2639

2640

2641
# NativeFunction objects that are views (f.is_view_op returns True)
2642
# are added into a `NativeFunctionsViewGroup`, which we can use to
2643
# easily access the generated (optional) view_copy NativeFunction.
2644
# It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup.
2645
# See Note [Codegen'd {view}_copy Operators]
2646
#
2647
# One property of this representation is that in order for a view-like op to be part of
2648
# a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist.
2649
# There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op,
2650
# but don't have corresponding aliasing `narrow.out` op.
2651
# This means that `narrow_copy.out` won't appear as a NativeFunctionsViewGroup.
2652
@dataclass(frozen=True)
2653
class NativeFunctionsViewGroup:
2654
    view: NativeFunction
2655
    # Note: the {view}_copy operator is optional because we currently don't generate copy variants
2656
    # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
2657
    # (we already get them "for free" through decomposition)
2658
    view_copy: NativeFunction | None
2659
    # view_inplace ops are also optional, but every view_inplace op should have out-of-place variant.
2660
    view_inplace: NativeFunction | None
2661

2662
    def __post_init__(self) -> None:
2663
        assert self.view.is_view_op
2664
        if self.view_copy is None:
2665
            assert not gets_generated_view_copy(self.view), (
2666
                f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs."
2667
                " The codegen expects you to add a corresponding operator to native_functions.yaml:"
2668
                f" {get_view_copy_name(self.view)!s}."
2669
                " See Note [view_copy NativeFunctions] for details."
2670
            )
2671
        else:
2672
            assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter"))
2673
            assert self.view.func.signature() == self.view_copy.func.signature(
2674
                strip_view_copy_name=True,
2675
            )
2676
            assert "view_copy" in self.view_copy.tags, (
2677
                f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects"
2678
                " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml."
2679
                " See Note [view_copy NativeFunction] for details."
2680
            )
2681
        if self.view_inplace is not None:
2682
            assert self.view.func.signature() == self.view_inplace.func.signature()
2683

2684
        if self.view.has_composite_implicit_autograd_kernel:
2685
            if self.view_inplace is not None:
2686
                assert self.view_inplace.has_composite_implicit_autograd_kernel, (
2687
                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
2688
                    " both have CompositeImplicitAutograd kernels, or both not have composite kernels."
2689
                )
2690
        if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
2691
            if self.view_inplace is not None:
2692
                assert (
2693
                    self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
2694
                ), (
2695
                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
2696
                    " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
2697
                )
2698

2699
    def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
2700
        yield self.view
2701
        if self.view_inplace is not None:
2702
            yield self.view_inplace
2703
        if self.view_copy is not None and include_copy:
2704
            yield self.view_copy
2705

2706
    @property
2707
    def root_name(self) -> str:
2708
        return self.view.root_name
2709

2710
    @property
2711
    def composite(self) -> bool:
2712
        # We currently assert that the "group" is consistent.
2713
        # If the view op is composite, then its view_inplace op is too.
2714
        return self.view.has_composite_implicit_autograd_kernel
2715

2716

2717
def gets_generated_view_copy(f: NativeFunction) -> bool:
2718
    # Only aliasing (view) operators get a copy variant.
2719
    if not f.is_view_op:
2720
        return False
2721
    # We don't need to bother generating copy variants for CompositeImplicitAutograd ops,
2722
    # because we can let them decompose into base view ops.
2723
    if f.has_composite_implicit_autograd_kernel:
2724
        return False
2725
    # We also don't need to generate copy variants for inplace views.
2726
    if "inplace_view" in f.tags:
2727
        return False
2728
    # Assume ops ending in _inverse have manually-defined copy variants
2729
    # (e.g. slice_inverse() has the copy variant slice_scatter()).
2730
    # We -could- probably generate these as well, but the codegen will be
2731
    # slightly different, and hand-writing these few kernels keeps codegen
2732
    # complexity lower.
2733
    if f.func.name.name.base.endswith("_inverse"):
2734
        return False
2735
    return True
2736

2737

2738
# Given a NativeFunction that corresponds to a view op,
2739
# returns the OperatorName of the corresponding "copy" variant of the op.
2740
def get_view_copy_name(f: NativeFunction) -> OperatorName:
2741
    # Right now, when asking for a view op's corresponding "view_copy" name
2742
    # we assert for sanity that the op is allowed to have a generated view_copy variant.
2743
    # (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op).
2744
    # However, narrow_copy() already exists as an op directly in native_functions.yaml.
2745
    # I'm hardcoding narrow_copy here for now to maintain the assert,
2746
    # But we could also just get rid of the assert.
2747
    list_of_ops_with_explicit_view_copy_operators = ["narrow"]
2748
    if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators:
2749
        assert gets_generated_view_copy(f)
2750

2751
    base_name = f"{f.func.name.name.base}_copy"
2752
    view_copy_name = OperatorName(
2753
        name=BaseOperatorName(
2754
            base=base_name, inplace=False, dunder_method=f.func.name.name.dunder_method
2755
        ),
2756
        overload_name=f.func.name.overload_name,
2757
    )
2758
    return view_copy_name
2759

2760

2761
# Helper functions for parsing argument lists (both inputs and returns)
2762

2763

2764
def parse_returns(return_decl: str) -> tuple[Return, ...]:
2765
    """
2766
    Input: '()'
2767
    Output: []
2768
    """
2769
    if return_decl == "()":
2770
        return ()
2771
    if return_decl[0] == "(" and return_decl[-1] == ")":
2772
        return_decl = return_decl[1:-1]
2773
    return tuple(Return.parse(arg) for arg in return_decl.split(", "))
2774

2775

2776
# A Precompute instance consists of a map from kernel argument name
2777
# to the list of Argument instances that should replace that
2778
# kernel argument in the impl function.
2779
@dataclass(frozen=True)
2780
class Precompute:
2781
    # A map from kernel argument name -> a list of precomputed
2782
    # elements that replaces/supersedes it.
2783
    replace: dict[str, list[Argument]]
2784
    # List of precomputed args added without replacement
2785
    add: list[Argument]
2786

2787
    @staticmethod
2788
    def parse(src: object) -> Precompute:
2789
        assert isinstance(src, list)
2790

2791
        # src is a list of strings of the format:
2792
        #   {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
2793
        #   [{add decl}[, {add decl}, ...]]
2794
        # The last line is optional and contains the precomputed parameters that are
2795
        # added without replacement.
2796
        # The other lines are parsed to get the names of which precomputed elements
2797
        # should replace which kernel arguments.
2798
        add_args = []
2799
        if " -> " not in src[-1]:
2800
            add_list = src[-1].split(",")
2801
            add_args = [Argument.parse(name.strip()) for name in add_list]
2802
            src = src[:-1]
2803

2804
        replace = {}
2805
        for raw_replace_item in src:
2806
            assert isinstance(raw_replace_item, str)
2807
            assert " -> " in raw_replace_item, (
2808
                "precomputed parameters without replacement"
2809
                " are allowed only in the last line"
2810
            )
2811

2812
            arg, with_list_raw = raw_replace_item.split(" -> ")
2813
            assert (
2814
                " " not in arg
2815
            ), f"illegal kernel param name '{arg}' in precomputed parameters'"
2816
            with_list = with_list_raw.split(",")
2817
            with_list_args = [Argument.parse(name.strip()) for name in with_list]
2818
            replace[arg] = with_list_args
2819

2820
        r = Precompute(replace=replace, add=add_args)
2821
        assert r.to_list() == src, "r.to_list() != src"
2822
        return r
2823

2824
    def __post_init__(self) -> None:
2825
        # the template parameters are upper so if these are the
2826
        # same then it is ambiguous
2827
        for a in self.add:
2828
            assert a.name.upper() != a.name
2829
        for args in self.replace.values():
2830
            for a in args:
2831
                assert a.name.upper() != a.name
2832

2833
    def to_list(self) -> list[str]:
2834
        replace_list = []
2835
        for kernel_param, replacement_params in self.replace.items():
2836
            replacements = ", ".join(str(param) for param in replacement_params)
2837
            replace_list.append(f"{kernel_param} -> {replacements}")
2838

2839
        return replace_list
2840

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.