from __future__ import annotations

import argparse
import functools
import json
import os
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Literal, Sequence, TypeVar

import yaml

import torchgen.api.dispatcher as dispatcher
import torchgen.api.meta as meta
import torchgen.api.native as native
import torchgen.api.structured as structured
import torchgen.dest as dest
from torchgen.aoti.fallback_ops import inductor_fallback_ops
from torchgen.api import cpp
from torchgen.api.translate import translate
from torchgen.api.types import (
    Binding,
    CppSignature,
    CppSignatureGroup,
    DispatcherSignature,
    NamedCType,
    NativeSignature,
    SpecialArgName,
)
from torchgen.context import (
    method_with_native_function,
    native_function_manager,
    with_native_function,
    with_native_function_and_indices,
)
from torchgen.gen_aoti_c_shim import (
    gen_aoti_c_shim,
    gen_static_dispatch_backend_call_signature,
    get_fallback_op_name,
    get_header_for_aoti,
)
from torchgen.gen_functionalization_type import (
    gen_functionalization_definition,
    gen_functionalization_registration,
    gen_functionalization_view_inverse_declaration,
    GenCompositeViewCopyKernel,
)
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
from torchgen.model import (
    Argument,
    BackendIndex,
    BackendMetadata,
    BaseOperatorName,
    DEFAULT_KERNEL_NAMESPACE,
    DispatchKey,
    FRAGMENT_NAMESPACES,
    FunctionSchema,
    is_cuda_dispatch_key,
    is_generic_dispatch_key,
    is_ufunc_dispatch_key,
    Location,
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    OperatorName,
    OptionalType,
    SchemaKind,
    SelfArgument,
    STRUCTURED_DISPATCH_KEYS,
    TensorOptionsArguments,
    Type,
    Variant,
    ViewSchemaKind,
)
from torchgen.native_function_generation import (
    add_generated_native_functions,
    gen_composite_functional_kernel,
    gen_composite_out_kernel,
    pre_group_native_functions,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import (
    assert_never,
    concatMap,
    context,
    FileManager,
    make_file_manager,
    mapMaybe,
    NamespaceHelper,
    Target,
)
from torchgen.yaml_utils import YamlDumper, YamlLoader


T = TypeVar("T")

# Welcome to the ATen code generator v2!  The ATen code generator is
# responsible for parsing native_functions.yaml and then generating
# various generated files (e.g., TypeDefault.cpp) based on the operators
# defined in this file.  This means that the code generator knows how to
# parse function schema, and then translate this into various C++ types
# and boilerplate code.
#
# Some things to know about this file when you modify it:
#
# - This file has STRICT mypy typechecking.  Typecheck it with
#   `mypy --config mypy-strict.ini` in the root source directory
#
# - Most of the heavy lifting lives in external modules:
#   - 'model' has the data model for native_functions.yaml.  The classes
#     in that file represent what you see when you look at
#     a native_functions.yaml
#   - 'api' has conversions for how to translate JIT schema into
#     the various C++ APIs that the codegen interacts with.  There
#     are in fact THREE different C++ APIs: the public C++ API,
#     the dispatcher API, and the legacy dispatcher API.  See each
#     of these respective files for more information

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                         HELPER FUNCTIONS
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# A custom loader for YAML to let us also keep track of line numbers
# of each entry in the YAML file
class LineLoader(YamlLoader):
    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # Add 1 so line numbering starts at 1
        mapping["__line__"] = node.start_mark.line + 1
        return mapping
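
# A minimal illustration of LineLoader (a sketch, not part of the original
# file; the schema strings are hypothetical):
#
#   es = yaml.load("- func: foo(Tensor self) -> Tensor\n"
#                  "- func: bar(Tensor self) -> Tensor", Loader=LineLoader)
#   assert es[0]["__line__"] == 1 and es[1]["__line__"] == 2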


# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])


_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}


def parse_native_yaml_struct(
    es: object,
    valid_tags: set[str],
    ignore_keys: set[DispatchKey] | None = None,
    path: str = "<stdin>",
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    assert isinstance(es, list)
    rs: list[NativeFunction] = []
    bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict)
    for e in es:
        assert isinstance(e, dict), f"expected to be dict: {e}"
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
        funcs = e.get("func")
        assert funcs is not None, f"missed 'func' in {e}"
        with context(lambda: f"in {loc}:\n  {funcs}"):
            func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys)
            rs.append(func)
            BackendIndex.grow_index(bs, m)
    error_check_native_functions(rs)
    # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
    indices: dict[DispatchKey, BackendIndex] = defaultdict(
        lambda: BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            # I'm actually not sure about this; undefined could be hit on
            # empty TensorList, hypothetically that could have sizes in it
            index={},
        )
    )
    if not skip_native_fns_gen:
        add_generated_native_functions(rs, bs)
    for k, v in bs.items():
        # All structured in-tree operators are implemented in terms of their out operator.
        indices[k] = BackendIndex(
            dispatch_key=k,
            use_out_as_primary=True,
            external=False,
            # Only cuda-like devices in tree require device guards
            device_guard=is_cuda_dispatch_key(k),
            index=v,
        )
    return ParsedYaml(rs, indices)


def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
    assert isinstance(es, list)
    rs: set[str] = set()
    for e in es:
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
        tags = e.get("tag")
        with context(lambda: f"in {loc}:\n  {tags}"):
            e_i = e.copy()
            name = e_i.pop("tag")
            desc = e_i.pop("desc", "")
            # ensure that each tag has a non-empty description
            assert desc != ""
            rs.add(name)
    return rs
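
# For illustration (a sketch, not part of the original file):
#
#   es = yaml.load("- tag: inplace_view\n  desc: operator is an inplace view",
#                  Loader=LineLoader)
#   assert parse_tags_yaml_struct(es) == {"inplace_view"}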


@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> set[str]:
    global _GLOBAL_PARSE_TAGS_YAML_CACHE
    if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
        with open(path) as f:
            es = yaml.load(f, Loader=LineLoader)
            _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)

    return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]


def parse_native_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    *,
    skip_native_fns_gen: bool = False,
    loaded_yaml: object | None = None,
) -> ParsedYaml:
    global _GLOBAL_PARSE_NATIVE_YAML_CACHE
    if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
        valid_tags = parse_tags_yaml(tags_yaml_path)

        # if a loaded yaml is provided, use that instead of reading from path
        if loaded_yaml is None:
            with open(path) as f:
                es = yaml.load(f, Loader=LineLoader)
        else:
            es = loaded_yaml

        _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(
            es,
            valid_tags,
            ignore_keys,
            path=path,
            skip_native_fns_gen=skip_native_fns_gen,
        )

    return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
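
# Typical usage (illustrative; these are the in-repo paths):
#
#   parsed = parse_native_yaml(
#       "aten/src/ATen/native/native_functions.yaml",
#       "aten/src/ATen/native/tags.yaml",
#   )
#   native_functions, backend_indices = parsed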


# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    func_map: dict[OperatorName, NativeFunction] = {}
    base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)
    for f in funcs:
        if f.structured_delegate is not None:
            delegate_func = func_map.get(f.structured_delegate)
            assert delegate_func is not None, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is missing."
            )
            assert delegate_func.structured, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                f"Consider adding 'structured=True' to the delegated operator"
            )
        # See Note [resize_ in Functionalization]
        # resize_() is technically an inplace view op (and therefore needs the tag),
        # but it would be overkill to add a true "view" variant of resize.
        # Instead, resize_() gets special treatment in functionalization,
        # and we have a resize() op that is non-aliasing + functional.
        if (
            "inplace_view" in f.tags
            and str(f.func.name) != "resize_"
            and str(f.func.name) != "resize_as_"
            and str(f.func.name.name) != "set_"
        ):
            base_name = f.func.name.name
            assert base_name.inplace, (
                f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
                "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
            )
            out_of_place_base_name = BaseOperatorName(
                base_name.base, False, base_name.dunder_method
            )
            assert len(base_func_map[out_of_place_base_name]) > 0, (
                f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
                f"out-of-place view op with the name '{base_name}' and matching schema, but it didn't find one. "
            )


def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal"""
    s = s.replace("\\", "\\\\")
    s = s.replace('"', '\\"')
    s = s.replace("\a", "\\a")
    s = s.replace("\b", "\\b")
    s = s.replace("\f", "\\f")
    s = s.replace("\n", "\\n")
    s = s.replace("\v", "\\v")
    s = s.replace("\t", "\\t")
    return f'"{s}"'
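
# For example (illustrative):
#
#   cpp_string("add.Tensor(Tensor self) -> Tensor")
#       == '"add.Tensor(Tensor self) -> Tensor"'
#   cpp_string('say "hi"\n') == '"say \\"hi\\"\\n"'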


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                        C++ CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# Most functions in this section are curried: they consist of a function
# that takes some parameters (e.g., what is to be generated) which itself
# returns a function that actually maps NativeFunction to the code
# to be generated.  This pattern makes it convenient to use map, concatMap
# and similar functional combinators.


def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
    if len(backends) == 0:
        return []
    else:
        return [backend.dispatch_key for backend in backends] + [
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ]


def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> DispatchKey | None:
    if f.structured_delegate is not None or backend_index.has_kernel(f):
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
        # so we always dispatch to the `backend`, but this could be wrong when we
        # migrate math/default_backend ops to use structured delegate.
        return backend_index.dispatch_key
    elif f.has_composite_explicit_autograd_kernel:
        return DispatchKey.CompositeExplicitAutograd
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        return DispatchKey.CompositeExplicitAutogradNonFunctional
    elif f.has_composite_implicit_autograd_kernel:
        return DispatchKey.CompositeImplicitAutograd
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        return DispatchKey.CompositeImplicitAutogradNestedTensor
    return None


def static_dispatch_ops_header(
    f: NativeFunction, backend_index: list[BackendIndex]
) -> str | None:
    if backend_index is None or f.manual_kernel_registration:
        return None

    output = []
    for index in backend_index:
        dispatch_key = get_static_dispatch_backend(f, index)
        if dispatch_key is not None:
            output.append(
                f"#include <ATen/ops/{f.root_name}_{dispatch_key.lower()}_dispatch.h>"
            )
    return "\n".join(output)


def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
    return [
        f"#include <ATen/{dispatch_key}Functions.h>"
        for dispatch_key in static_dispatch_keys(backends)
    ]


# Translates arguments of `sig` to CppSignature bindings.
# Note that we have a special case for the `memory_format` argument; this case is
# not covered by torchgen.api.translate() yet, as its application is limited to
# static dispatch.
def translate_args(
    sig: CppSignature | DispatcherSignature,
    cpp_sig: CppSignature,
) -> str:
    # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
    def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
        output_bindings: list[Binding] = []
        for binding in input_bindings:
            if binding.name == "memory_format":
                spl_mem_format_binding = Binding(
                    nctype=NamedCType(
                        SpecialArgName.possibly_redundant_memory_format,
                        binding.nctype.type,
                    ),
                    name=binding.name,
                    default=binding.default,
                    argument=binding.argument,
                )
                output_bindings.append(spl_mem_format_binding)
            else:
                output_bindings.append(binding)
        return output_bindings

    src_bindings = list(sig.arguments())
    goal_bindings = list(cpp_sig.arguments())
    # When the last argument of the CPP signature has the
    # SpecialArgName.possibly_redundant_memory_format NCType, make the
    # memory_format bindings of the dispatcher signature use the same NCType as well
    for arg in goal_bindings:
        if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
            src_bindings = add_spl_memory_format_binding(src_bindings)
            break
    exprs = translate(src_bindings, goal_bindings)
    return ", ".join(a.expr for a in exprs)


def generate_static_dispatch_backend_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    backend_metadata = backend_index.get_kernel(f)
    kernel_ns = (
        backend_metadata.cpp_namespace
        if backend_metadata and backend_metadata.cpp_namespace
        else DEFAULT_KERNEL_NAMESPACE
    )
    ns = kernel_ns.replace("::native", "")
    return f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});"


def generate_static_dispatch_fallback_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
    if f.has_composite_explicit_autograd_kernel:
        return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
    elif f.has_composite_implicit_autograd_kernel:
        return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
    else:
        return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
{', '.join([str(index.dispatch_key) for index in backend_indices])} ");"""


def static_dispatch(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find the corresponding backend and dispatch to it. If more than one
    backend exists, fall back to determining the dispatch key from the inputs.
    Arguments:
        sig: A CppSignature or DispatcherSignature for this native function we want to use.
        f: NativeFunction for which to generate static dispatch.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);"
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    keys = [
        b
        for b in backend_indices
        if b.has_kernel(f)
        or (
            f.structured_delegate is not None
            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
        )
    ]
    if len(keys) == 1:
        return generate_static_dispatch_backend_call(sig, f, keys[0])
    elif len(keys) == 0:
        return generate_static_dispatch_fallback_call(sig, f, backend_indices)

    native_tensor_args = [
        a.name
        for a in sig.arguments()
        if isinstance(a.argument, SelfArgument)
        or isinstance(a.argument, Argument)
        and a.argument.type.is_tensor_like()
    ]
    tensor_args = ", ".join(native_tensor_args)
    tensor_opts = f.func.arguments.tensor_options

    stmts = []
    subexprs: list[str] = []
    if tensor_opts is not None:
        subexprs.append(
            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
        )
    if tensor_args != "":
        subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
    stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
    stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")

    dispatch_code = []
    for index in keys:
        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
        dispatch_code.append(
            f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
        )

    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
    connector = "\n\t\t"

    return f"""
    {connector.join(stmts)}
    switch (_dk) {{
        {connector.join(dispatch_code)}
        default:
            {fallback}
    }}
    """


# Generates RegisterSchema.cpp.  Depending on the selector, either
# all schemas are registered, or only some are (in the case of
# selective build)
@dataclass(frozen=True)
class RegisterSchema:
    selector: SelectiveBuilder
    known_tags: dict[str, int] = field(default_factory=dict)

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if not self.selector.is_native_function_selected(f):
            return None
        tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
        if tags == "{}":
            return f"m.def({cpp_string(str(f.func))}, {{}});\n"
        maybe_tags = ""
        if tags not in self.known_tags:
            idx = len(self.known_tags)
            self.known_tags[tags] = idx
            maybe_tags = f"const std::vector<at::Tag> tags_{idx} = {tags};\n"
        return f"{maybe_tags}m.def({cpp_string(str(f.func))}, tags_{self.known_tags[tags]});\n"


# Generates Operators.h and Operators.cpp.
# These provide macros that, given an operator and overload name, allow users
# to access an "un-overloaded" function version of the operator. This
# is useful for extension writers who (1) want to decltype the operator
# and (2) don't want to worry about method-only operators.
@dataclass(frozen=True)
class ComputeOperators:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        sig = DispatcherSignature.from_schema(f.func)
        name = f.func.name.unambiguous_name()

        if self.target is Target.DECLARATION:
            # Note [The ATen Operators API]
            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
            # metadata about each operator + entry points into the Dispatcher.
            # The C++ function, method, and redispatch API's are all implemented as wrappers
            # into various bits of the structs defined here.
            #
            # Important characteristics about the Operators API:
            # (1) It follows the Dispatcher API.
            #     This is kind of necessary to avoid overhead.
            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Overload names are disambiguated.
            #     This is helpful for pytorch extenders who would like to decltype() an aten operator
            #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
            # (3) No argument defaulting is allowed.
            #     This is more of an implementation detail to avoid #include cycles,
            #     since TensorBody.h (which defines the Tensor class) needs to include this file.
            # (4) manual_cpp_bindings and faithful names are not included in the API.
            #     This applies to stuff like __dispatch__is_complex(), and add_outf().
            #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
            #     They're implemented as wrappers in Functions.h that call into the actual operators
            #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
            #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
            return f"""
struct TORCH_API {name} {{
  using schema = {sig.type()};
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
  static {sig.defn(name="call", is_redispatching_fn=False)};
  static {sig.defn(name="redispatch", is_redispatching_fn=True)};
}};"""

        elif self.target is Target.DEFINITION:
            defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})

// aten::{f.func}
static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow({name}::name, {name}::overload_name)
      .typed<{name}::schema>();
}}
"""
            for is_redispatching_fn in [False, True]:
                if is_redispatching_fn:
                    dispatcher_exprs_str = ", ".join(
                        ["dispatchKeySet"] + [a.name for a in sig.arguments()]
                    )
                    method_base = "redispatch"
                else:
                    dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
                    method_base = "call"

                dispatcher_call = method_base
                method_name = f"{name}::{method_base}"

                fn_body = f"""
    static auto op = create_{name}_typed_handle();
    return op.{dispatcher_call}({dispatcher_exprs_str});"""

                if (
                    not is_redispatching_fn
                    and len(self.static_dispatch_backend_indices) > 0
                ):
                    # call() should go through static dispatch
                    fn_body = static_dispatch(
                        sig, f, backend_indices=self.static_dispatch_backend_indices
                    )
                defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
    {fn_body}
}}
"""
            return defns
        else:
            assert_never(self.target)


# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        has_symint = f.func.has_symint()

        result = ""
        for sig in sig_group.signatures():
            # See Note [The ATen Operators API]
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ", ".join([e.expr for e in exprs])

            if sig.symint:
                intlike_t = "c10::SymInt"
            else:
                intlike_t = "int64_t"

            if Variant.function in f.variants:
                result += f"""
// aten::{f.func}
inline {sig.decl()} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}"""

            # The template function can be used from template situations
            # where you want to switch between the symint and non-symint versions
            # depending on a template argument
            #
            # NB: we ALWAYS generate this even for methods.  But we put it in
            # this header so it can take advantage of per-op headers
            if has_symint:
                result += f"""
namespace symint {{
  template <typename T, typename = std::enable_if_t<std::is_same<T, {intlike_t}>::value>>
  {sig.decl(suppress_symint_suffix=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
  }}
}}
"""
        return result


# Generates TensorBody.h. This file provides the object-oriented (method-based)
# public C++ API, and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeTensorMethod:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if Variant.method not in f.variants:
            return None

        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        sig_group = CppSignatureGroup.from_native_function(
            f, method=True, fallback_binding=f.manual_cpp_binding
        )

        if self.target is Target.DECLARATION:
            result = ""
            for sig in sig_group.signatures():
                result += f"{sig.decl()} const;\n"
            return result

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        result = ""

        for sig in sig_group.signatures():
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
            exprs_str = ", ".join([e.expr for e in exprs])

            result += f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""

        return result


# Generates RedispatchFunctions.h.
# This is similar to the C++ API defined in Functions.h, but provides access
# to the dispatcher's redispatch API.
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        # We unconditionally generate function variants of the redispatch API.
        # This is mainly because we can namespace functions separately, but not methods.
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )

        result = ""
        for sig in sig_group.signatures():
            target_sig = DispatcherSignature.from_schema(f.func)
            exprs = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])

            result += f"""
// aten::{f.func}
inline {sig.decl(is_redispatching_fn=True)} {{
    return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
}}
"""

        return result


# Generates ATenOpList.cpp, a runtime accessible list of all aten
# operators.
# TODO: This was historically used to help some JIT interop code
# figure out whether or not to treat aten namespace'd operators
# one way or another; we should reevaluate whether this is actually needed.
@with_native_function
def compute_aten_op(f: NativeFunction) -> str:
    return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},'


# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
    if not g.structured:
        return None
    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ", ".join(a.decl() for a in args)
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        meta_return = "void"
        precomputed = g.out.precomputed if g.structured else None

        if precomputed:
            # Generate the template declaration with one bool parameter for each
            # precomputed element. Each parameter is true if the corresponding (in
            # terms of position) precomputed element has been set.
            precomputed_values = [*precomputed.replace.values(), precomputed.add]
            precomputed_elements = [
                elem for replace_list in precomputed_values for elem in replace_list
            ]
            precomputed_template_parameters = [
                elem.name.upper() for elem in precomputed_elements
            ]
            precomputed_template_params_str = ", ".join(
                f"bool {param} = false" for param in precomputed_template_parameters
            )
            precompute_template_decl = f"template <{precomputed_template_params_str}>"

            # Generate a string containing declarations of all precomputed elements.
            precomputed_elements_with_cpp_types = [
                structured.argument_type(elem, binds=elem.name)
                for elem in precomputed_elements
            ]

            precomputed_elements_decl = ";\n".join(
                f"{elem.cpp_type(strip_ref=True)} {elem.name}"
                for elem in precomputed_elements_with_cpp_types
            )

            # Generate "setter" methods for each precomputed element. Each method will return
            # a new instance of precompute_out with the template parameter that corresponds to
            # the member set by the method to true (to indicate that it has been set).
            setter_methods = []
            for i, elem in enumerate(precomputed_elements):
                # Generate the signature. The return type will be the same
                # as the type of `this` but with the template parameter
                # corresponding to the element set by this method set to true.
                # The assert generated below will ensure that this template
                # parameter is false on the type of `this`.
                return_ty_templates = ", ".join(
                    precomputed_template_parameters[:i]
                    + ["true"]
                    + precomputed_template_parameters[i + 1 :]
                )
                return_ty = f"precompute_out<{return_ty_templates}>"
                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
                    strip_ref=True
                )
                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"

                # Generate an assert which checks that the
                # template parameter corresponding to the precomputed
                # element that is set by this method is false on the
                # class corresponding to the object that `this` points to.
                # This ensures that each element can be set only once.
                assert_msg = f'"{elem.name} already set"'
                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"

                # Generate the new object construction block. All state
                # except the element that this method sets is copied from the
                # object that `this` points to. The value for the element that
                # the method sets is taken from a method parameter.
                construction_stmts = []
                construction_stmts.append(f"{return_ty} ret;")

                for j, elem in enumerate(precomputed_elements):
                    if i == j:
                        construction_stmts.append(f"ret.{elem.name} = value;")
                    else:
                        construction_stmts.append(
                            f"ret.{elem.name} = this->{elem.name};"
                        )

                construction_stmts.append("return ret;")
                construction_block = "\n".join(construction_stmts)

                setter_methods.append(
                    f"""
                    {signature} {{
                        {assert_stmt}
                        {construction_block}
                    }}
                """
                )
            setter_methods_decl = "\n".join(setter_methods)

            # Meta should return an instance of the struct containing the precomputed elements.
            meta_return_template_params = ", ".join(
                ["true"] * len(precomputed_template_parameters)
            )
            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
            # type (which has a variable number of template parameters).
            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
            meta_return = "meta_return_ty"
            precomputed_decl = f"""
                {precompute_template_decl}
                struct TORCH_API precompute_out {{
                    {setter_methods_decl}
                    {precomputed_elements_decl};
            }};"""
        else:
            meta_return_typedef = ""
            precomputed_decl = ""

        return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
    {precomputed_decl}
    {meta_return_typedef}
    {meta_return} meta({args_str});
}};
"""


def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
    name = str(f.func.name.name)
    if name.endswith("_like") or name.startswith("new_"):
        return False
    if f.func.arguments.tensor_options is None:
        return False
    return selector.is_native_function_selected(f)
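
# For example, factory ops like `empty` (which take TensorOptions) need backend
# select, while `empty_like` and `new_empty` are excluded by the name check above.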


# Generates RegisterBackendSelect.cpp, a series of kernels which provide
# specialized computation of dispatch key for operator signatures which cannot
# be easily done automatically using templating.
@dataclass(frozen=True)
class ComputeBackendSelect:
    target: Literal[Target.DEFINITION, Target.REGISTRATION]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if not needs_backend_select(f, self.selector):
            return None

        name = native.name(f.func)
        # BackendSelect can go to Meta, so it must preserve symints
        native_sig = NativeSignature(f.func, symint=True)

        native_tensor_args = [
            a
            for a in native_sig.arguments()
            if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
        ]

        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        sig: NativeSignature | DispatcherSignature
        sig = dispatcher_sig
        dispatcher_exprs = dispatcher_sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently
            # The first case could probably be improved though- it calls computeDispatchKeySet(),
            # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
            if native_tensor_args:
                assert f.func.arguments.has_tensor_arg()
                tensor_args = ", ".join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                assert not f.func.arguments.has_tensor_arg()
                compute_dk = (
                    f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
                )
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
  {compute_dk}
  return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
      _dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        else:
            assert_never(self.target)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                       YAML CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def format_yaml(data: object) -> str:
    # Ignore aliases in the Dumper
    YamlDumper.ignore_aliases = lambda self, data: True  # type: ignore[assignment]

    # Support serializing OrderedDict
    def dict_representer(dumper: Any, data: Any) -> Any:
        return dumper.represent_dict(data.items())

    YamlDumper.add_representer(OrderedDict, dict_representer)  # type: ignore[no-untyped-call]
    # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
    # width=1e9 turns off optional line breaks and improves
    # the portability of the outputted yaml.
    return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return, call-overload]
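
# For example (illustrative):
#
#   format_yaml(OrderedDict([("name", "add"), ("deprecated", False)]))
#       == "name: add\ndeprecated: false\n"
#
# with keys preserved in insertion order.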


# For some reason, some defaults we write to YAML are written as native
# YAML objects, rather than doing them uniformly as strings.  This
# function detects those cases and converts them into native Python
# objects.
def pythonify_default(s: str) -> object:
    if s == "true":
        return True
    elif s == "false":
        return False

    try:
        return int(s)
    except ValueError:
        try:
            return float(s)
        except ValueError:
            return s
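
# e.g. pythonify_default("true") is True, pythonify_default("1") == 1,
# pythonify_default("1.5") == 1.5, and anything else (such as
# "contiguous_format") is returned unchanged as a string.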


# What is a dynamic type?  Over time, the semantic meaning of
# dynamic type has degraded to meaninglessness (in the old days,
# it captured dtype-ness of types, but that has gone away with
# the removal of TH).  These days, it's mostly the same thing as
# the C++ API argument type, except that Tensor and Tensor?
# arguments simply present as Tensor.
#
# TODO: Get rid of dynamic_type, after getting tools/autograd
# to use the new codegen framework
def dynamic_type(t: Type) -> str:
    if isinstance(t, OptionalType):
        return dynamic_type(t.elem)
    # Note we don't use t.is_tensor_like() here because it would
    # also include Tensor[]
    if str(t) == "Tensor":
        return "at::Tensor"
    # This is a legacy concept, so never report SymInt
    return cpp.argumenttype_type(
        t, mutable=False, binds="__placeholder__", symint=False
    ).cpp_type()
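
# e.g. dynamic_type for both `Tensor` and `Tensor?` is "at::Tensor", while
# `Tensor[]` falls through to the ordinary C++ argument type.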


def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
    # This is written out explicitly to ensure that Tensor and
    # namespace are put into the list in the right order
    method_of = ["Type"]
    if Variant.method in variants:
        method_of.append("Tensor")
    if Variant.function in variants:
        method_of.append("namespace")
    return method_of
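
# e.g. an operator declared with `variants: function, method` yields
# ["Type", "Tensor", "namespace"].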


def compute_returns_yaml(
    f: NativeFunction,
) -> tuple[list[dict[str, str]], dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: dict[str, str] = {}

    # Compute the returns field of the YAML entry
    names = cpp.return_names(f)
    returns = []
    for i, (r, name) in enumerate(zip(f.func.returns, names)):
        ret = {
            "dynamic_type": dynamic_type(r.type),
            "name": name,
            # legacy, report ints
            "type": cpp.return_type(r, symint=False).cpp_type(),
        }

        if r.name:
            # See Note [name and field_name]
            ret["field_name"] = r.name
            if f.func.is_out_fn():
                name_to_field_name[f.func.arguments.out[i].name] = r.name

        returns.append(ret)

    return returns, name_to_field_name


# arguments in yaml roughly correspond to the public C++ API
def compute_cpp_argument_yaml(
    cpp_a: Binding,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    if isinstance(cpp_a.argument, TensorOptionsArguments):
        arg: dict[str, object] = {
            "annotation": None,
            "dynamic_type": "at::TensorOptions",
            "is_nullable": False,
            "name": cpp_a.name,
            "type": cpp_a.type,
            "kwarg_only": True,
        }
        if cpp_a.default is not None:
            arg["default"] = cpp_a.default
        return arg
    elif isinstance(cpp_a.argument, SelfArgument):
        raise AssertionError
    elif isinstance(cpp_a.argument, Argument):
        return compute_argument_yaml(
            cpp_a.argument,
            schema_order=schema_order,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )


def compute_argument_yaml(
    a: Argument,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    arg: dict[str, object] = {
        "annotation": str(a.annotation) if a.annotation else None,
        "dynamic_type": dynamic_type(a.type),
        "is_nullable": a.type.is_nullable(),
        "name": a.name,
        # legacy, report ints
        "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
    }
    if a.default is not None:
        arg["default"] = pythonify_default(
            cpp.default_expr(a.default, a.type, symint=False)
        )
    if a.name in kwarg_only_set:
        arg["kwarg_only"] = True
    if a.name in out_arg_set:
        arg["output"] = True
        arg["allocate"] = True
        # See Note [name and field_name]
        if a.name in name_to_field_name:
            arg["field_name"] = name_to_field_name[a.name]
    # Historically, booleans don't get their size recorded, because it
    # is already built into the cpp type (e.g., std::array<bool, 4>)
    l = a.type.is_list_like()
    if l is not None and l.size is not None and str(l.elem) != "bool":
        arg["size"] = l.size
    return arg


@with_native_function
def compute_declaration_yaml(f: NativeFunction) -> object:
    returns, name_to_field_name = compute_returns_yaml(f)

    # These sets are used to conveniently test if an argument is a
    # kwarg-only or out argument
    kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only}
    out_arg_set = {a.name for a in f.func.arguments.out}

    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    cpp_args = sig_group.signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(
            cpp_a,
            schema_order=False,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for cpp_a in cpp_args
    ]

    schema_order_jit_arguments = list(f.func.schema_order_arguments())

    schema_order_arguments = [
        compute_argument_yaml(
            a,
            schema_order=True,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for a in schema_order_jit_arguments
    ]

    cpp_schema_order_types = [
        # NB: method here doesn't matter
        r.type
        for a in schema_order_jit_arguments
        for r in cpp.argument(
            a,
            method=False,
            cpp_no_default_args=set(),
            faithful=False,
            symint=False,
            has_tensor_options=False,
        )
    ]

    # legacy, report ints
    cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"

    is_factory_method = (
        any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args)
        and Variant.method not in f.variants
    )

    return OrderedDict(
        [
            ("name", cpp.name(f.func)),
            ("operator_name", str(f.func.name.name)),
            ("overload_name", str(f.func.name.overload_name)),
            ("manual_kernel_registration", f.manual_kernel_registration),
            (
                "category_override",
                f.category_override if f.category_override is not None else "",
            ),
            ("schema_string", f"aten::{f.func}"),
            ("arguments", arguments),
            ("schema_order_cpp_signature", schema_order_cpp_signature),
            ("schema_order_arguments", schema_order_arguments),
            ("method_of", compute_method_of_yaml(f.variants)),
            ("mode", "native"),
            ("python_module", "" if f.python_module is None else f.python_module),
            ("returns", returns),
            ("inplace", f.func.name.name.inplace),
            ("is_factory_method", is_factory_method),
            ("abstract", f.is_abstract),
            ("device_guard", f.device_guard),
            ("with_gil", False),
            ("deprecated", False),
            ("has_math_kernel", f.has_composite_implicit_autograd_kernel),
        ]
    )


# See Note [Auto generated composite kernels]
def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
    return (f.structured or f.structured_delegate is not None) and (
        f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace
    )
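
# e.g. a functional op with a structured_delegate (such as add.Tensor, which
# delegates to add.out) gets an autogenerated composite kernel, while the
# structured out variant itself does not.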


@with_native_function_and_indices
def compute_registration_declarations(
    f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
) -> str:
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(
        f.func.returns
    ).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)
    comment_data: dict[str, str] = {
        "schema": f"aten::{f.func}",
        # TODO: What exactly is the semantics of the 'dispatch' field?
        "dispatch": str(
            {k for k, v in backend_indices.items() if v.has_kernel(f)}
            != {DispatchKey.CompositeImplicitAutograd}
            and {k for k, v in backend_indices.items() if v.has_kernel(f)}
            != {
                DispatchKey.CompositeImplicitAutograd,
                DispatchKey.CompositeImplicitAutogradNestedTensor,
            }
        ),
        "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
"""


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           RUN IT ALL
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def get_custom_build_selector(
    provided_op_registration_allowlist: list[str] | None,
    op_selection_yaml_path: str | None,
) -> SelectiveBuilder:
    assert not (
        provided_op_registration_allowlist is not None
        and op_selection_yaml_path is not None
    ), (
        "Both provided_op_registration_allowlist and "
        + "op_selection_yaml_path can NOT be provided at the "
        + "same time."
    )

    op_registration_allowlist: set[str] | None = None
    if provided_op_registration_allowlist is not None:
        op_registration_allowlist = set(provided_op_registration_allowlist)

    if op_registration_allowlist is not None:
        selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
            op_registration_allowlist,
            True,
            False,
        )
    elif op_selection_yaml_path is not None:
        selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
    else:
        selector = SelectiveBuilder.get_nop_selector()

    return selector
1358

1359

1360
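# A minimal usage sketch of get_custom_build_selector (hypothetical values;
# the two sources of selectivity are mutually exclusive, as the assert above
# enforces):
#
#   selector = get_custom_build_selector(
#       provided_op_registration_allowlist=["aten::add", "aten::mul"],
#       op_selection_yaml_path=None,
#   )

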
def get_grouped_by_view_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
    def maybe_create_view_group(
        d: dict[ViewSchemaKind | SchemaKind, NativeFunction]
    ) -> list[NativeFunction | NativeFunctionsViewGroup]:
        funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
        if ViewSchemaKind.aliasing in d:
            view = d.pop(ViewSchemaKind.aliasing)
            view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
            view_copy = d.pop(SchemaKind.functional, None)

            funcs.append(
                NativeFunctionsViewGroup(
                    view=view,
                    view_copy=view_copy,
                    view_inplace=view_inplace,
                )
            )
        # Take the remaining functions that weren't part of the view group
        # and emit them separately
        funcs.extend(d.values())
        return funcs

    grouped_by_views: dict[
        FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        schema = f.func.view_signature()
        view_kind: ViewSchemaKind = f.view_schema_kind
        # We need to group up ops relevant to the same "view", consisting of:
        # view op (ViewSchemaKind.aliasing)
        # view_inplace op (ViewSchemaKind.aliasing_inplace)
        # view_copy op (SchemaKind.functional)
        if view_kind == ViewSchemaKind.non_aliasing:
            kind = f.func.kind()
            assert kind not in grouped_by_views[schema]
            grouped_by_views[schema][kind] = f
        else:
            assert (
                view_kind not in grouped_by_views[schema]
            ), f"{view_kind} already in {grouped_by_views[schema].keys()}"
            grouped_by_views[schema][view_kind] = f

    return list(concatMap(maybe_create_view_group, grouped_by_views.values()))

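# For example (illustrative, not exhaustive), the three related operators
#
#   transpose(Tensor(a) self, ...)    -> ViewSchemaKind.aliasing
#   transpose_(Tensor(a!) self, ...)  -> ViewSchemaKind.aliasing_inplace
#   transpose_copy(Tensor self, ...)  -> SchemaKind.functional
#
# would be collapsed into a single NativeFunctionsViewGroup by the logic above.

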
def get_grouped_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
    def flatten_pre_group(
        d: dict[SchemaKind, NativeFunction]
    ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
        r = NativeFunctionsGroup.from_dict(d)
        if r is None:
            # Invariant: any NativeFunctions that are code-generated
            # should have been grouped into NativeFunctionsGroup objects
            assert not any("generated" in f.tags for f in d.values())
            return list(d.values())
        else:
            return [r]

    # TODO: how come ValuesView isn't a Sequence lol
    pre_grouped_native_functions = pre_group_native_functions(native_functions)
    return list(
        concatMap(flatten_pre_group, list(pre_grouped_native_functions.values()))
    )

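# For example (illustrative), the overloads
#
#   add.Tensor   (SchemaKind.functional)
#   add_.Tensor  (SchemaKind.inplace)
#   add.out      (SchemaKind.out)
#
# would be folded into one NativeFunctionsGroup, while an operator with no
# related variants passes through as a bare NativeFunction.

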
def get_ns_grouped_kernels(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
    ] = dest.compute_native_function_declaration,
) -> dict[str, list[str]]:
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
    for f in grouped_native_functions:
        native_function_namespaces = set()
        dispatch_keys = set()
        for dispatch_key, backend_idx in backend_indices.items():
            backend_metadata = backend_idx.get_kernel(f)
            if backend_metadata:
                namespace = backend_metadata.cpp_namespace
                dispatch_keys.add(dispatch_key)
                native_function_namespaces.add(namespace)
            else:
                namespace = DEFAULT_KERNEL_NAMESPACE
            assert (
                len(native_function_namespaces) <= 1
            ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
            ns_grouped_kernels[namespace].extend(
                native_function_decl_gen(f, backend_idx)
            )
    return ns_grouped_kernels


def get_native_function_declarations_from_ns_grouped_kernels(
    *,
    ns_grouped_kernels: dict[str, list[str]],
) -> list[str]:
    declarations: list[str] = []
    newline = "\n"
    for namespace, kernels in ns_grouped_kernels.items():
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=4,
        )
        # De-duplicate the kernel names while preserving order. Backends are
        # allowed to repeat kernel names; only generate the declaration once!
        ordered_kernels = list(OrderedDict.fromkeys(kernels))
        declarations.extend(
            f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
        """.split(
                newline
            )
        )
    return declarations

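# A minimal sketch of the transformation above (hypothetical input): given
#
#   ns_grouped_kernels = {
#       "at::native": ["TORCH_API Tensor my_op(const Tensor & self);"],
#   }
#
# the helper wraps each group of declarations in its namespace block, roughly:
#
#   namespace at {
#   namespace native {
#   TORCH_API Tensor my_op(const Tensor & self);
#   } // namespace native
#   } // namespace at

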
# Return native function declarations grouped by their namespaces.
def get_native_function_declarations(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
    ] = dest.compute_native_function_declaration,
) -> list[str]:
    """
    Generate kernel declarations, in `NativeFunction(s).h`.
    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionsGroup`.
    :param backend_indices: kernel collections grouped by dispatch key.
    :param native_function_decl_gen: callable to generate the kernel declaration for each `NativeFunction`.
    :return: a list of strings containing all declarations, grouped by namespace and split by newline.
    """

    ns_grouped_kernels = get_ns_grouped_kernels(
        grouped_native_functions=grouped_native_functions,
        backend_indices=backend_indices,
        native_function_decl_gen=native_function_decl_gen,
    )
    return get_native_function_declarations_from_ns_grouped_kernels(
        ns_grouped_kernels=ns_grouped_kernels
    )


def get_kernel_namespace(
    *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
) -> str:
    backend_metadata = backend_idx.get_kernel(f)
    assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
        f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
        f"with dispatch key {backend_idx.dispatch_key}"
        f" has namespace {backend_metadata.cpp_namespace}, which does not end with '::native'."
    )
    return (
        backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE
    )

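# A minimal sketch of how callers below use get_kernel_namespace (hypothetical
# metadata with cpp_namespace "at::native"):
#
#   get_kernel_namespace(f=f, backend_idx=idx)           # -> "at::native"
#   ...replace("::native", "")                           # -> "at" (definitions)
#   ...replace("native", dispatch_key.lower())           # -> "at::cpu" (declarations)

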
# Return native function definitions grouped by dispatch key and custom namespace.
# Used in RegisterDispatchKey.cpp, etc.
def get_native_function_definitions(
    *,
    fm: FileManager,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
    skip_dispatcher_op_registration: bool,
    gen_dispatch_helpers: bool,
) -> list[str]:
    definitions: list[str] = []
    ns_definitions: dict[str, list[str]] = defaultdict(list)
    anonymous_definitions: dict[str, list[str]] = defaultdict(list)
    registrations: dict[str, dict[str, list[str]]] = defaultdict(dict)
    newline = "\n"
    ns_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DEFINITION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    anonymous_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.ANONYMOUS_DEFINITION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    reg_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.REGISTRATION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    for f in grouped_native_functions:
        kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "::native", ""
        )

        ns_definitions[kernel_namespace].extend(
            ns_gen(f),
        )
        anonymous_definitions[kernel_namespace].extend(
            anonymous_gen(f),
        )
        namespace = (
            f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
        )
        if namespace not in registrations[kernel_namespace]:
            registrations[kernel_namespace] = defaultdict(list)
        registrations[kernel_namespace][namespace].extend(
            reg_gen(f),
        )

    for kernel_namespace in ns_definitions:
        if len(ns_definitions[kernel_namespace]) == 0:
            continue
        ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
        registration_body = ""
        for namespace in registrations[kernel_namespace]:
            if not registrations[kernel_namespace][namespace]:
                continue
            registration_body += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
    {newline.join(registrations[kernel_namespace][namespace])}
}};"""
        definitions.extend(
            fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue": ns_helper.prologue,
                    "ns_epilogue": ns_helper.epilogue,
                    "dispatch_helpers": dest.gen_registration_helpers(backend_idx)
                    if gen_dispatch_helpers
                    else [],
                    "dispatch_anonymous_definitions": anonymous_definitions[
                        kernel_namespace
                    ],
                    "static_init_dispatch_registrations": ""
                    if skip_dispatcher_op_registration
                    else registration_body,
                    "deferred_dispatch_registrations": "",
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
                },
            ).split(newline)
        )

    return definitions

# Return native function declarations grouped by dispatch key and custom namespace.
# Used in CPUFunctions_inl.h, etc.
def get_namespaced_declaration(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
) -> list[str]:
    declarations: list[str] = []
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
    newline = "\n"
    func = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DECLARATION,
        selector,
        rocm=rocm,
        class_method_name=None,
        skip_dispatcher_op_registration=False,
        symint=symint,
    )
    for f in grouped_native_functions:
        namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "native", dispatch_key.lower()
        )

        ns_grouped_kernels[namespace].extend(
            func(f),
        )

    for namespace, kernels in ns_grouped_kernels.items():
        if len(kernels) == 0:
            continue
        ns_helper = NamespaceHelper(
            namespace_str=namespace, entity_name="", max_level=3
        )
        ordered_kernels = list(OrderedDict.fromkeys(kernels))
        declarations.extend(
            f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
        """.split(
                newline
            )
        )
    return declarations

# Return native function schema registration code for aten and other namespaces.
def get_native_function_schema_registrations(
    *,
    native_functions: Sequence[NativeFunction],
    schema_selector: SelectiveBuilder,
) -> tuple[list[str], str]:
    ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_native_functions[native_function.namespace].append(native_function)
    schema_registrations = ""
    aten_schema_registrations = []
    custom_namespace = None
    for namespace, funcs in ns_native_functions.items():
        schema_registrations_body = list(
            mapMaybe(RegisterSchema(schema_selector), funcs)
        )
        # NB: we have to separate aten namespace registration from other namespaces,
        # because in the template we hardcoded an operator for ATen already.
        if namespace == "aten":
            aten_schema_registrations = schema_registrations_body
        else:
            custom_namespace = namespace
            tab = "\t"
            # If the namespace is predefined, we should define a library fragment
            # instead of a new library
            torch_library_macro = (
                "TORCH_LIBRARY_FRAGMENT"
                if namespace in FRAGMENT_NAMESPACES
                else "TORCH_LIBRARY"
            )
            schema_registrations += f"""
{torch_library_macro}({custom_namespace}, m) {{
  {tab.join(schema_registrations_body)}
}};"""
    return (aten_schema_registrations, schema_registrations)

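# For illustration, a non-aten namespace produces a block roughly like the
# following (hypothetical "custom" namespace with a single schema):
#
#   TORCH_LIBRARY(custom, m) {
#     m.def("my_op(Tensor self) -> Tensor");
#   };
#
# while aten schemas are returned separately, since the template already
# hardcodes the ATen library block for them (see the NB comment above).

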
def gen_aggregated_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    # Buck doesn't support dynamic output files, so we aggregate all operator
    # headers into a single file
    cpu_fm.write(
        "NativeMetaFunctions.h",
        lambda: {
            "NativeMetaFunctions_includes": [],
            "NativeMetaFunctions_declarations": list(
                mapMaybe(compute_meta_function_declaration, structured_native_functions)
            ),
        },
    )
    method_native_functions = [
        fn for fn in native_functions if Variant.method in fn.variants
    ]
    non_method_native_functions = [
        fn for fn in native_functions if fn not in method_native_functions
    ]
    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": [],
            "MethodOperators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Operators.h",
        lambda: {
            "Operators_includes": ["#include <ATen/MethodOperators.h>"],
            "Operators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    non_method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
            "Functions_includes": ["#include <ATen/Operators.h>"],
            "Functions_declarations": list(
                mapMaybe(
                    ComputeFunction(),
                    native_functions,
                )
            ),
        },
    )
    declarations = get_native_function_declarations(
        grouped_native_functions=grouped_native_functions,
        backend_indices=backend_indices,
    )
    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
            "NativeFunctions_declarations": declarations,
        },
    )

    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        if dispatch_key in functions_keys:
            inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

            fm.write_with_template(
                f"{dispatch_key}Functions.h",
                "DispatchKeyFunctions.h",
                lambda: {
                    "dispatch_key": str(dispatch_key),
                    "inline_headers": inl_headers,
                },
            )
            fm.write_with_template(
                f"{dispatch_key}Functions_inl.h",
                "DispatchKeyFunctions_inl.h",
                lambda: {
                    "DispatchKeyFunctions_inl_includes": [],
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_declarations": get_namespaced_declaration(
                        grouped_native_functions=grouped_native_functions,
                        dispatch_key=dispatch_key,
                        backend_idx=backend_indices[dispatch_key],
                        selector=selector,
                        rocm=rocm,
                        symint=True,
                    ),
                },
            )

        del fm

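# For comparison with the per-operator mode below, aggregated mode emits a
# small fixed set of umbrella headers (illustrative list, drawn from the
# writes above):
#
#   ATen/NativeMetaFunctions.h    -- all structured meta declarations
#   ATen/MethodOperators.h        -- operators with Tensor-method variants
#   ATen/Operators.h              -- all operator declarations
#   ATen/Functions.h              -- the public function API
#   ATen/NativeFunctions.h        -- all kernel declarations
#   ATen/{Key}Functions(_inl).h   -- per-dispatch-key declarations

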
def gen_per_operator_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    # For CMake builds, split operator declarations into separate headers in
    # the ATen/ops folder to split up header dependencies
    functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
    for fn in native_functions:
        functions_by_root_name[fn.root_name].append(fn)

    grouped_functions_by_root_name: dict[
        str, list[NativeFunction | NativeFunctionsGroup]
    ] = defaultdict(list)
    for group in grouped_native_functions:
        name = group.root_name
        grouped_functions_by_root_name[name].append(group)

    for name, functions in functions_by_root_name.items():
        ops_fm.write_with_template(
            f"{name}_ops.h",
            "Operator.h",
            lambda: {
                "declarations": list(
                    mapMaybe(
                        ComputeOperators(
                            Target.DECLARATION,
                            static_dispatch_backend_indices=static_dispatch_idx,
                        ),
                        functions,
                    )
                ),
            },
        )

        ops_fm.write_with_template(
            f"{name}.h",
            "Function.h",
            lambda: {
                "static_dispatch_ops_headers": list(
                    mapMaybe(
                        lambda fn: static_dispatch_ops_header(
                            fn, backend_index=static_dispatch_idx
                        ),
                        functions,
                    )
                ),
                "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
                "function_definitions": list(
                    mapMaybe(
                        ComputeFunction(),
                        functions,
                    )
                ),
            },
        )

        grouped_functions = grouped_functions_by_root_name.get(name, [])
        structured_functions = [
            fn
            for fn in grouped_functions
            if isinstance(fn, NativeFunctionsGroup) and fn.structured
        ]
        is_structured = len(structured_functions) > 0

        if is_structured:
            ops_fm.write_with_template(
                f"{name}_meta.h",
                "NativeMetaFunction.h",
                lambda: {
                    "meta_function_declarations": list(
                        mapMaybe(
                            compute_meta_function_declaration, structured_functions
                        )
                    ),
                },
            )
        declarations = get_native_function_declarations(
            grouped_native_functions=grouped_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        ops_fm.write_with_template(
            f"{name}_native.h",
            "NativeFunction.h",
            lambda: {
                "extra_includes": (
                    f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
                ),
                "native_function_declarations": declarations,
            },
        )

    for category, suffix in [
        ("Functions", ""),
        ("Operators", "_ops"),
        ("NativeMetaFunctions", "_meta"),
        ("NativeFunctions", "_native"),
    ]:
        cpu_fm.write(
            f"{category}.h",
            lambda: {
                f"{category}_includes": [
                    f"#include <ATen/ops/{name}{suffix}.h>"
                    for name in sorted(functions_by_root_name.keys())
                ],
                f"{category}_declarations": [],
            },
        )

    for dispatch_key in dispatch_keys:
        if dispatch_key not in functions_keys:
            continue

        dispatch_namespace = dispatch_key.lower()
        dispatch_names = []

        for name, functions in functions_by_root_name.items():
            grouped_functions = grouped_functions_by_root_name.get(name, [])
            declarations = list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_indices[dispatch_key],
                        Target.NAMESPACED_DECLARATION,
                        selector,
                        rocm=rocm,
                        symint=True,
                        class_method_name=None,
                        skip_dispatcher_op_registration=False,
                    ),
                    grouped_functions,
                )
            )

            if len(declarations) == 0:
                continue

            dispatch_names.append(name)
            ops_fm.write_with_template(
                f"{name}_{dispatch_namespace}_dispatch.h",
                "DispatchKeyFunction.h",
                lambda: {
                    "dispatch_namespace": dispatch_namespace,
                    "dispatch_namespaced_declarations": declarations,
                },
            )

        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

        fm.write_with_template(
            f"{dispatch_key}Functions.h",
            "DispatchKeyFunctions.h",
            lambda: {
                "dispatch_key": str(dispatch_key),
                "inline_headers": inl_headers,
            },
        )
        fm.write_with_template(
            f"{dispatch_key}Functions_inl.h",
            "DispatchKeyFunctions_inl.h",
            lambda: {
                "dispatch_namespace": dispatch_namespace,
                "DispatchKeyFunctions_inl_includes": [
                    f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
                    for name in sorted(dispatch_names)
                ],
                "dispatch_namespaced_declarations": [],
            },
        )
        del fm

    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": sorted(
                f"#include <ATen/ops/{name}_ops.h>"
                for name, functions in functions_by_root_name.items()
                if any(Variant.method in fn.variants for fn in functions)
            ),
            "MethodOperators_declarations": [],
        },
    )

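# For a single root name such as "add" (illustrative), per-operator mode
# produces a family of headers under ATen/ops:
#
#   ATen/ops/add.h              -- function API (Function.h template)
#   ATen/ops/add_ops.h          -- operator declarations (Operator.h template)
#   ATen/ops/add_native.h       -- kernel declarations (NativeFunction.h template)
#   ATen/ops/add_meta.h         -- meta declarations, only if structured
#   ATen/ops/add_cpu_dispatch.h -- per-dispatch-key declarations, if registered

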
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    valid_tags: set[str],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    rocm: bool,
    per_operator_headers: bool,
) -> None:
    if per_operator_headers:
        gen_per_operator_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )
    else:
        gen_aggregated_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )

    core_fm.write(
        "TensorBody.h",
        lambda: {
            "tensor_method_declarations": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
            "tensor_method_definitions": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DEFINITION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
        },
    )

    cpu_fm.write(
        "RedispatchFunctions.h",
        lambda: {
            "function_redispatch_definitions": list(
                mapMaybe(ComputeRedispatchFunction(), native_functions)
            ),
        },
    )

    cpu_fm.write(
        "RegistrationDeclarations.h",
        lambda: {
            "registration_declarations": [
                compute_registration_declarations(f, backend_indices)
                for f in native_functions
            ],
        },
    )

    cpu_fm.write(
        "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
    )

    def gen_aten_interned_strings() -> dict[str, str]:
        attrs: set[str] = set()  # All function argument names
        names = set()  # All ATen function names
        for func in native_functions:
            names.add(str(func.func.name.name))
            # Some operators don't have a functional variant but we still create a
            # symbol without the underscore
            names.add(func.func.name.name.base)

            attrs.update(arg.name for arg in func.func.schema_order_arguments())

        # These are keywords in C++, so aren't valid symbol names
        # https://en.cppreference.com/w/cpp/language/operator_alternative
        names -= {
            "and",
            "and_eq",
            "bitand",
            "bitor",
            "compl",
            "not",
            "not_eq",
            "or",
            "or_eq",
            "xor",
            "xor_eq",
        }

        return {
            "aten_symbols": " \\\n".join(
                [f"_(aten, {name})" for name in sorted(names)]
            ),
            "attr_symbols": " \\\n".join(
                [f"_(attr, {name})" for name in sorted(attrs)]
            ),
        }

    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)

    def gen_tags_enum() -> dict[str, str]:
        return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}

    core_fm.write("enum_tag.h", gen_tags_enum)

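# For illustration, gen_aten_interned_strings above produces backslash-joined
# macro invocations that get substituted into the aten_interned_strings.h
# template; a hypothetical excerpt:
#
#   _(aten, abs) \
#   _(aten, add) \
#   ...
#   _(attr, alpha) \
#   _(attr, dim)

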
def gen_source_files(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    view_groups: Sequence[NativeFunctionsViewGroup],
    selector: SelectiveBuilder,
    static_dispatch_idx: list[BackendIndex],
    backend_indices: dict[DispatchKey, BackendIndex],
    aoti_fm: FileManager,
    core_fm: FileManager,
    cpu_fm: FileManager,
    cpu_vec_fm: FileManager,
    cuda_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    rocm: bool,
    force_schema_registration: bool,
    per_operator_headers: bool,
    skip_dispatcher_op_registration: bool,
    update_aoti_c_shim: bool,
) -> None:
    extra_cuda_headers = """\
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/CUDAContext.h>"""
    if rocm:
        extra_cuda_headers = """\
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/ATenHIPGeneral.h>
#include <ATen/hip/HIPDevice.h>
#include <ATen/hip/HIPContext.h>"""

    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm

        if per_operator_headers:

            def operator_headers() -> list[str]:
                headers = []
                for g in grouped_native_functions:
                    is_registered = False
                    if backend_index.has_kernel(g):
                        is_registered = True
                    # The above has_kernel test on a group will only test for
                    # the existence of out dispatch, because that's how
                    # structured kernels work. But sometimes functions can be
                    # grouped but not be structured, and then you need to check
                    # each individual piece, as they may have manual dispatch
                    # entries.
                    elif isinstance(g, NativeFunctionsGroup) and any(
                        backend_index.has_kernel(fn) for fn in g.functions()
                    ):
                        is_registered = True
                    # TODO: this condition is a bit questionable
                    # (It has to do with the fact that structured kernels get generated kernels
                    # to the Meta + CompositeExplicitAutogradNonFunctional keys).
                    elif g.structured and dispatch_key in (
                        DispatchKey.Meta,
                        DispatchKey.CompositeExplicitAutogradNonFunctional,
                    ):
                        is_registered = True
                    if not is_registered:
                        continue

                    headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
                    if (
                        dispatch_key
                        == DispatchKey.CompositeExplicitAutogradNonFunctional
                    ):
                        headers.append(f"#include <ATen/ops/{g.root_name}.h>")
                    if dispatch_key in functions_keys:
                        headers.append(
                            f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
                        )

                return sorted(set(headers))

        else:

            def operator_headers() -> list[str]:
                headers = ["#include <ATen/NativeFunctions.h>"]
                if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
                    headers.append("#include <ATen/Functions.h>")
                if dispatch_key in functions_keys:
                    headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
                return headers

        backend_index = backend_indices[dispatch_key]
        ns_grouped_native_functions = defaultdict(list)
        for grouped_native_function in grouped_native_functions:
            namespace = (
                grouped_native_function.namespace
                if isinstance(grouped_native_function, NativeFunction)
                else grouped_native_function.functional.namespace
            )
            ns_grouped_native_functions[namespace].append(grouped_native_function)

        dispatch_namespace = str(dispatch_key).lower()

        # CompositeImplicitAutogradNestedTensor does not currently use the generated
        # dispatch helpers; compilation will fail when the `-Werror=unused-function`
        # flag is set
        gen_dispatch_helpers: bool = (
            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
        )

        dispatch_definitions = get_native_function_definitions(
            fm=fm,
            grouped_native_functions=grouped_native_functions,
            dispatch_key=dispatch_key,
            backend_idx=backend_index,
            selector=selector,
            rocm=rocm,
            symint=True,
            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
            gen_dispatch_helpers=gen_dispatch_helpers,
        )
        fm.write_with_template(
            f"Register{dispatch_key}.cpp",
            "RegisterDispatchKey.cpp",
            lambda: {
                "extra_cuda_headers": extra_cuda_headers
                if is_cuda_dispatch_key(dispatch_key)
                else "",
                "external_backend_headers": "",
                "dispatch_headers": dest.gen_registration_headers(
                    backend_index, per_operator_headers, rocm
                ),
                "ops_headers": operator_headers(),
                "dispatch_helpers": "",
                "dispatch_definitions": dispatch_definitions,
            },
        )

        for g in structured_native_functions:
            if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
                continue
            name = g.functional.func.name.name
            if dispatch_key is DispatchKey.CPU:
                assert fm is cpu_fm
                fm.write_with_template(
                    f"UfuncCPU_{name}.cpp",
                    "UfuncCPU.cpp",
                    lambda: {
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cpu(g),
                    },
                )
                cpu_vec_fm.write_with_template(
                    f"UfuncCPUKernel_{name}.cpp",
                    "UfuncCPUKernel.cpp",
                    lambda: {
                        "name": name,
                        "native_definitions": dest.compute_ufunc_cpu_kernel(g),
                    },
                )
            elif dispatch_key is DispatchKey.CUDA:
                cuda_headers = "#include <ATen/native/cuda/Loops.cuh>"
                if rocm:
                    cuda_headers = "#include <ATen/native/hip/Loops.cuh>"
                fm.write_with_template(
                    f"UfuncCUDA_{name}.cu",
                    "UfuncCUDA.cu",
                    lambda: {
                        "name": name,
                        "cuda_headers": cuda_headers,
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cuda(g),
                    },
                )
            else:
                raise AssertionError(f"unrecognized {dispatch_key} for ufunc")

        structured_func_group_dict = {}
        for func_group in structured_native_functions:
            for func in func_group.functions():
                if func.structured_delegate is not None:
                    structured_func_group_dict[func.structured_delegate] = func_group
                    break

        if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
            fallbacks = {}
            for func in native_functions:
                op_name = get_fallback_op_name(func)
                if op_name in inductor_fallback_ops:
                    fallbacks[op_name] = func
            fallback_native_functions = tuple(
                value for _, value in sorted(fallbacks.items())
            )

            # header files were checked in for ABI-compatibility checking
            header_file_name = f"c_shim_{dispatch_key.lower()}.h"
            new_header = gen_aoti_c_shim(
                fallback_native_functions,
                structured_func_group_dict,
                dispatch_key,
                backend_indices,
                header=True,
                includes="",
            )
            if update_aoti_c_shim:
                aoti_fm.write(
                    header_file_name,
                    lambda: new_header,
                )
            else:
                try:
                    with open(
                        os.path.join(aoti_fm.install_dir, header_file_name)
                    ) as old_file:
                        old_header = old_file.read()
                        assert (
                            old_header == new_header
                        ), """

WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
This is allowed only in a limited number of situations:

1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing
C shim header files.

2. You added a new default argument to an existing fallback op. This is clearly a BC-breaking
change in AOTInductor land. In this case, you need to keep a manual copy of that existing
fallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version
number of that fallback op in the newly generated C shim files, and update the cpp wrapper
codegen to generate the correct cpp call for this op. Contact the AOTInductor team for assistance.

                        """
                except FileNotFoundError:
                    print(
                        f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
                    )

            # cpp files are always generated on-the-fly
            def headers_for_aoti() -> str:
                headers = []
                for func in fallback_native_functions:
                    header = get_header_for_aoti(
                        func, structured_func_group_dict, dispatch_key, backend_indices
                    )
                    if header is not None:
                        headers.append(header)
                return "\n".join(sorted(set(headers)))

            extra_headers = (
                extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
            )

            aoti_fm.write(
                f"c_shim_{dispatch_key.lower()}.cpp",
                lambda: gen_aoti_c_shim(
                    fallback_native_functions,
                    structured_func_group_dict,
                    dispatch_key,
                    backend_indices,
                    header=False,
                    includes=headers_for_aoti() + "\n" + extra_headers,
                ),
            )

        del fm

    # BackendSelect is generated specially
    def gen_backend_select() -> dict[str, list[str]]:
        relevant_fns = [
            fn for fn in native_functions if needs_backend_select(fn, selector)
        ]
        return {
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
            ],
            "backend_select_method_definitions": list(
                mapMaybe(
                    ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
                )
            ),
            "backend_select_function_registrations": list(
                mapMaybe(
                    ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
                )
            ),
        }

    cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)

    schema_selector = selector
    if force_schema_registration:
        schema_selector = SelectiveBuilder.get_nop_selector()

    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions, schema_selector=schema_selector
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "aten_schema_registrations": []
            if skip_dispatcher_op_registration
            else aten_schema_registrations,
            "schema_registrations": []
            if skip_dispatcher_op_registration
            else schema_registrations,
        },
    )

    def key_func(
        fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> str:
        return fn.root_name

    cpu_fm.write_sharded(
        "Operators.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
            "definitions": [
                ComputeOperators(
                    Target.DEFINITION,
                    static_dispatch_backend_indices=static_dispatch_idx,
                )(fn)
            ],
        },
        base_env={
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
        },
        num_shards=5,
        sharded_keys={
            "operator_headers",
            "definitions",
            "static_dispatch_extra_headers",
        },
    )

    cpu_fm.write("Functions.cpp", dict)

    core_fm.write("TensorMethods.cpp", dict)

    core_fm.write(
        "ATenOpList.cpp",
        lambda: {
            "aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
        },
    )

    def functionalization_env_callable(
        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> dict[str, list[str]]:
        def gen_op_headers(
            g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
        ) -> list[str]:
            if isinstance(g, NativeFunctionsViewGroup):
                # view ops always get a functionalization kernel
                headers = [
                    f"#include <ATen/ops/{g.view.root_name}_native.h>",
                    f"#include <ATen/ops/{g.view.root_name}_ops.h>",
                ]
                if g.view_copy is not None:
                    headers += [
                        f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
                        f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
                    ]
                return headers
            elif isinstance(g, NativeFunctionsGroup):
                headers = [
                    f"#include <ATen/ops/{g.functional.root_name}_native.h>",
                    f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
                    f"#include <ATen/ops/{g.out.root_name}_native.h>",
                    f"#include <ATen/ops/{g.out.root_name}_ops.h>",
                ]
                if g.inplace is not None:
                    headers += [
                        f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
                        f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
                    ]
                if g.mutable is not None:
                    headers += [
                        f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
                        f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
                    ]
                return headers
            else:
                return [
                    f"#include <ATen/ops/{g.root_name}_native.h>",
                    f"#include <ATen/ops/{g.root_name}_ops.h>",
                ]

        return {
            "ops_headers": gen_op_headers(g),
            "func_definitions": gen_functionalization_definition(
                selector,
                g,
            ),
            "func_registrations": gen_functionalization_registration(
                selector,
                g,
                backend_indices[DispatchKey.CompositeImplicitAutograd],
            ),
        }

    all_groups: list[
        NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
    ] = list(structured_native_functions) + list(
        view_groups  # type: ignore[assignment, arg-type, operator]
    )
    # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
    # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
    # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
    # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
    #     Although this could go away long-term if we add a dedicated dispatch key for decompositions.
    structured_map: dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
    }
    view_map: dict[OperatorName, NativeFunction] = {
        f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
    }
    for f in native_functions:
        if f.func.name not in structured_map and f.func.name not in view_map:
            all_groups.append(f)

    cpu_fm.write_sharded(
        "RegisterFunctionalization.cpp",
        all_groups,
        key_fn=key_func,
        env_callable=functionalization_env_callable,
        num_shards=4,
        sharded_keys={
            "ops_headers",
            "func_definitions",
            "func_registrations",
            "func_add_back_views_definitions",
            "func_add_back_views_registrations",
        },
    )

    cpu_fm.write(
2620
        "FunctionalInverses.h",
2621
        lambda: {
2622
            "view_inverse_declarations": list(
2623
                mapMaybe(
2624
                    lambda g: gen_functionalization_view_inverse_declaration(
2625
                        selector, g
2626
                    ),
2627
                    view_groups,
2628
                )
2629
            )
2630
        },
2631
    )
2632

2633
    # Note [view_copy NativeFunctions]
2634
    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
2635
    # needs to have a corresponding non-aliasing {view}_copy variant.
2636
    # Backends that use functionalization and don't know how to handle aliasing ops
2637
    # are expected to implement kernels for these {view}_copy kernels instead.
2638
    # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
2639
    # so we codegen the following:
2640
    # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
2641
    #     These are never explicitly invoked by the functionalization pass,
2642
    #     but they could theoretically be called from user code (I added these kernels for completeness,
2643
    #     since the ops are part of the public API).
2644
    # (2) A derivative formula for every {view}_copy operator
2645
    #     {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts,
2646
    #     so rather than stamping all of the entries out in derivatives.yaml,
2647
    #     we codegen them in.
2648
    #     This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
2649
    cpu_fm.write(
        "CompositeViewCopyKernels.cpp",
        lambda: {
            "ops_headers": [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
                    # NB: this include is important as it ensures we
                    # set the visibility on generated view_copy kernels
                    # correctly
                    f"#include <ATen/ops/{f.root_name}_native.h>"
                    for f in (
                        [g.view] if g.view_copy is None else [g.view, g.view_copy]
                    )
                )
                for g in view_groups
            ]
            + [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
                    # NB: this include is also important for correct visibility
                    f"#include <ATen/ops/{f.root_name}_native.h>"
                    for f in [g.inplace, g.mutable, g.functional]
                    if f is not None and "generated" not in f.tags
                )
                for g in structured_native_functions
            ],
            "CompositeViewCopyKernel_Definitions": list(
                mapMaybe(
                    GenCompositeViewCopyKernel(
                        backend_indices[
                            DispatchKey.CompositeExplicitAutogradNonFunctional
                        ]
                    ),
                    view_groups,
                )
            ),
            "GeneratedCompositeFunctional_Definitions": list(
                mapMaybe(
                    gen_composite_functional_kernel,
                    structured_native_functions,
                )
            ),
            "GeneratedCompositeOut_Definitions": list(
                mapMaybe(
                    gen_composite_out_kernel,
                    structured_native_functions,
                )
            ),
        },
    )
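    # NB: roughly speaking, gen_composite_functional_kernel emits a kernel for
    # a generated functional op in terms of its mutable counterpart (clone the
    # inputs, run the mutation, return the clones), and gen_composite_out_kernel
    # does the analogous thing for generated out= ops. Both return None for
    # groups that don't need a generated kernel, which mapMaybe filters out.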


def gen_declarations_yaml(
    cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
) -> None:
    cpu_fm.write(
        "Declarations.yaml",
        lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]),
    )
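
# NB (assumption): Declarations.yaml is a legacy, eagerly-serialized dump of
# every operator declaration; it is believed to be consumed mostly by
# out-of-tree tooling at this point, so changes to its format are externally
# visible.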


def get_torchgen_root() -> Path:
    """
    If you're depending on torchgen out-of-tree, you can use the root to figure
    out the path to native_functions.yaml
    """
    return Path(__file__).parent.resolve()
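
# Illustrative out-of-tree usage (the "packaged" layout below is how the
# torchgen wheel ships its YAML files; adjust if your distribution differs):
#
#   from torchgen.gen import get_torchgen_root
#
#   native_yaml = (
#       get_torchgen_root() / "packaged" / "ATen" / "native" / "native_functions.yaml"
#   )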


def main() -> None:
    parser = argparse.ArgumentParser(description="Generate ATen source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--per-operator-headers",
        action="store_true",
        help="generate separate headers per operator in ATen/ops",
    )
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/aten/src/ATen",
    )
    parser.add_argument(
        "--aoti-install-dir",
        "--aoti_install_dir",
        help="output directory for AOTInductor shim",
        default="torch/csrc/inductor/aoti_torch/generated",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--mps",
        action="store_true",
        help="Generate MPS registration code when set",
    )
    # TODO: --op-registration-whitelist will be removed when all call-sites
    # for gen.py are moved over to using the operator YAML file for mobile
    # custom build.
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--backend-whitelist",
        "--backend_whitelist",
        nargs="*",
        help="filter dispatch backend by the whitelist (if set), "
        "e.g.: CPU CUDA QuantizedCPU ...",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--skip-dispatcher-op-registration",
        "--skip_dispatcher_op_registration",
        action="store_true",
        help="Avoid registering operators into the dispatcher.",
    )
    parser.add_argument(
        "--force-schema-registration",
        "--force_schema_registration",
        action="store_true",
        help="force it to generate schema-only registrations for all ops, including "
        "those that are not listed on --op-registration-whitelist",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources", "declarations_yaml"],
        default=["headers", "sources", "declarations_yaml"],
        help="Generate only a subset of files",
    )
    parser.add_argument(
        "--update-aoti-c-shim",
        action="store_true",
        help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. "
        "WARNING: Do not use this unless you are sure what you are doing!!!",
    )

    options = parser.parse_args()
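    # Illustrative invocation (the flags just spell out the defaults above;
    # any checkout where `torchgen` is importable should work):
    #
    #   python -m torchgen.gen -s aten/src/ATen -d build/aten/src/ATen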

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
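    # NB: when neither --op-registration-whitelist nor --op-selection-yaml-path
    # is given, the selector is (to my understanding) a no-op that selects
    # every operator, so a default build generates everything.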

    from torchgen.model import dispatch_keys

    # TODO: stop generating CUDA kernels for non-CUDA builds
    ignore_keys = set()
    if not options.mps:
        ignore_keys.add(DispatchKey.MPS)

        if DispatchKey.MPS in dispatch_keys:
            del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
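    # NB: dispatch_keys is a module-level list, so the `del` above mutates it
    # in place; every later consumer (gen_source_files, gen_headers) sees the
    # MPS key removed rather than having to filter it out again.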
    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
    valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )

    grouped_native_functions = get_grouped_native_functions(native_functions)

    structured_native_functions = [
        g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
    ]
    native_functions_with_view_groups = get_grouped_by_view_native_functions(
        native_functions
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]
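    # NB: a NativeFunctionsGroup bundles the functional/inplace/out (and
    # mutable) variants of one operator, while a NativeFunctionsViewGroup
    # bundles a view op with its {view}_copy counterpart; entries that don't
    # fit either grouping stay behind as plain NativeFunctions.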

    # NB: It is mandatory to NOT use os.path.join here, as the install directory
    # will eventually be ingested by cmake, which does not respect Windows style
    # path slashes.  If you switch this to use os.path.join, you'll get an error
    # like:
    #
    #   Syntax error in cmake code when parsing string
    #
    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
    #
    #   Invalid character escape '\c'.
    core_install_dir = f"{options.install_dir}/core"
    Path(core_install_dir).mkdir(parents=True, exist_ok=True)
    ops_install_dir = f"{options.install_dir}/ops"
    Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
    aoti_install_dir = options.aoti_install_dir
    Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)

    core_fm = make_file_manager(options=options, install_dir=core_install_dir)
    cpu_fm = make_file_manager(options=options)
    cpu_vec_fm = make_file_manager(options=options)
    cuda_fm = make_file_manager(options=options)
    ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
    aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)
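    # NB: each FileManager is pinned to one install directory (defaulting to
    # options.install_dir) and records every file it writes, which is what the
    # --output-dependencies handling at the bottom of main() relies on.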

    # Only a limited set of dispatch keys get {DispatchKey}Functions.h headers
    # (e.g. CPUFunctions.h) generated for them; this is that set.
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.Meta,
    }
    if options.mps:
        functions_keys.add(DispatchKey.MPS)

    if options.backend_whitelist:
        dispatch_keys = [
            k
            for k in dispatch_keys
            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
        ]
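    # NB: generic dispatch keys (the Composite* keys and friends) always
    # survive the whitelist, since every backend build needs them.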

    static_dispatch_idx: list[BackendIndex] = []
    if options.static_dispatch_backend:
        static_dispatch_idx = [
            backend_indices[DispatchKey.parse(key)]
            for key in options.static_dispatch_backend
        ]
        for key in options.static_dispatch_backend:
            functions_keys.add(DispatchKey.parse(key))
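    # NB: static dispatch generates direct calls into the chosen backend's
    # kernels instead of going through the dispatcher, which is why each
    # requested backend also needs its {DispatchKey}Functions.h header.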

    if "sources" in options.generate:
        gen_source_files(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            view_groups=view_groups,
            selector=selector,
            static_dispatch_idx=static_dispatch_idx,
            backend_indices=backend_indices,
            aoti_fm=aoti_fm,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cpu_vec_fm=cpu_vec_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            force_schema_registration=options.force_schema_registration,
            per_operator_headers=options.per_operator_headers,
            skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
            update_aoti_c_shim=options.update_aoti_c_shim,
        )

    if "headers" in options.generate:
        gen_headers(
            native_functions=native_functions,
            valid_tags=valid_tags,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            per_operator_headers=options.per_operator_headers,
        )

    if "declarations_yaml" in options.generate:
        gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)

    if options.output_dependencies:
        depfile_path = Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

        for fm, prefix in [
            (cpu_fm, ""),
            (cpu_vec_fm, "cpu_vec_"),
            (core_fm, "core_"),
            (cuda_fm, "cuda_"),
            (ops_fm, "ops_"),
        ]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))
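    # Worked example: with --output-dependencies build/deps.list (a made-up
    # path), this loop writes build/deps.list, build/cpu_vec_deps.list,
    # build/core_deps.list, build/cuda_deps.list, and build/ops_deps.list,
    # one per FileManager. Note that aoti_fm does not participate here.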


if __name__ == "__main__":
    main()