1
from __future__ import annotations
7
from collections import defaultdict, namedtuple, OrderedDict
8
from dataclasses import dataclass, field
9
from pathlib import Path
10
from typing import Any, Callable, Literal, Sequence, TypeVar
14
import torchgen.api.dispatcher as dispatcher
15
import torchgen.api.meta as meta
16
import torchgen.api.native as native
17
import torchgen.api.structured as structured
18
import torchgen.dest as dest
19
from torchgen.aoti.fallback_ops import inductor_fallback_ops
20
from torchgen.api import cpp
21
from torchgen.api.translate import translate
22
from torchgen.api.types import (
31
from torchgen.context import (
32
method_with_native_function,
33
native_function_manager,
35
with_native_function_and_indices,
37
from torchgen.gen_aoti_c_shim import (
39
gen_static_dispatch_backend_call_signature,
43
from torchgen.gen_functionalization_type import (
44
gen_functionalization_definition,
45
gen_functionalization_registration,
46
gen_functionalization_view_inverse_declaration,
47
GenCompositeViewCopyKernel,
49
from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
50
from torchgen.model import (
55
DEFAULT_KERNEL_NAMESPACE,
60
is_generic_dispatch_key,
61
is_ufunc_dispatch_key,
65
NativeFunctionsViewGroup,
70
STRUCTURED_DISPATCH_KEYS,
71
TensorOptionsArguments,
76
from torchgen.native_function_generation import (
77
add_generated_native_functions,
78
gen_composite_functional_kernel,
79
gen_composite_out_kernel,
80
pre_group_native_functions,
82
from torchgen.selective_build.selector import SelectiveBuilder
83
from torchgen.utils import (
93
from torchgen.yaml_utils import YamlDumper, YamlLoader
98
# Welcome to the ATen code generator v2! The ATen code generator is
99
# responsible for parsing native_functions.yaml and then generating
100
# various generated files (e.g., TypeDefault.cpp) based on the operators
101
# defined in this file. This means that the code generator knows how to
102
# parse function schema, and then translate this into various C++ types
103
# and boilerplate code.
105
# Some things to know about this file when you modify it:
107
# - This file has STRICT mypy typechecking. Typecheck it with
108
# `mypy --config mypy-strict.ini` in the root source directory
110
# - Most of the heavy lifting lives in external modules:
111
# - 'model' has the data model for native_functions.yaml. The classes
112
# in that module represent what you see when you look at
113
# a native_functions.yaml
114
# - 'api' has conversions for how to translate JIT schema into
115
# the various C++ APIs that the codegen interacts with. There
116
# are in fact THREE different C++ APIs: the public C++ API,
117
# the dispatcher API, and the legacy dispatcher API. See each
118
# of these respective files for more information
120
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
124
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
127
# A custom loader for YAML to let us also keep track of line numbers
# of each entry in the YAML file
class LineLoader(YamlLoader):
    """YAML loader that tags every constructed mapping with its source line.

    Each mapping produced while loading gains a ``"__line__"`` key holding the
    1-based line number of the mapping node in the YAML input. The parsers in
    this module assert on that key to build error locations.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # Add 1 so line numbering starts at 1
        mapping["__line__"] = node.start_mark.line + 1
        # The loader uses the returned object as the constructed value; without
        # this return every mapping would be loaded as None.
        return mapping
137
# Result of parsing native_functions.yaml: the list of NativeFunctions plus a
# mapping from dispatch key to its BackendIndex.
ParsedYaml = namedtuple("ParsedYaml", "native_functions backend_indices")

# Per-path memo tables, filled by parse_native_yaml / parse_tags_yaml.
_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
145
def parse_native_yaml_struct(
147
valid_tags: set[str],
148
ignore_keys: set[DispatchKey] | None = None,
149
path: str = "<stdin>",
150
skip_native_fns_gen: bool = False,
152
assert isinstance(es, list)
153
rs: list[NativeFunction] = []
154
bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict)
156
assert isinstance(e, dict), f"expected to be dict: {e}"
157
assert isinstance(e.get("__line__"), int), e
158
loc = Location(path, e["__line__"])
159
funcs = e.get("func")
160
assert funcs is not None, f"missed 'func' in {e}"
161
with context(lambda: f"in {loc}:\n {funcs}"):
162
func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys)
164
BackendIndex.grow_index(bs, m)
165
error_check_native_functions(rs)
166
# Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
167
indices: dict[DispatchKey, BackendIndex] = defaultdict(
168
lambda: BackendIndex(
169
dispatch_key=DispatchKey.Undefined,
170
use_out_as_primary=True,
173
# I'm actually not sure about this; undefined could be hit on
174
# empty TensorList, hypothetically that could have sizes in it
178
if not skip_native_fns_gen:
179
add_generated_native_functions(rs, bs)
180
for k, v in bs.items():
181
# All structured in-tree operators are implemented in terms of their out operator.
182
indices[k] = BackendIndex(
184
use_out_as_primary=True,
186
# Only cuda-like devices in tree require device guards
187
device_guard=is_cuda_dispatch_key(k),
190
return ParsedYaml(rs, indices)
193
def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
194
assert isinstance(es, list)
197
assert isinstance(e.get("__line__"), int), e
198
loc = Location(path, e["__line__"])
200
with context(lambda: f"in {loc}:\n {tags}"):
202
name = e_i.pop("tag")
203
desc = e_i.pop("desc", "")
204
# ensure that each tag has a non-empty description
210
@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> set[str]:
    """Return the set of tag names declared in the tags YAML file at ``path``.

    The parsed result is stored in ``_GLOBAL_PARSE_TAGS_YAML_CACHE`` under
    ``path``, so each path is read and parsed at most once. The ``lru_cache``
    decorator additionally memoizes calls to this function.
    """
    global _GLOBAL_PARSE_TAGS_YAML_CACHE
    cache = _GLOBAL_PARSE_TAGS_YAML_CACHE
    if path in cache:
        return cache[path]
    # Load with LineLoader so every entry carries its "__line__" number.
    with open(path) as tags_file:
        raw_entries = yaml.load(tags_file, Loader=LineLoader)
    cache[path] = parse_tags_yaml_struct(raw_entries, path=path)
    return cache[path]
221
def parse_native_yaml(
224
ignore_keys: set[DispatchKey] | None = None,
226
skip_native_fns_gen: bool = False,
227
loaded_yaml: object | None = None,
229
global _GLOBAL_PARSE_NATIVE_YAML_CACHE
230
if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
231
valid_tags = parse_tags_yaml(tags_yaml_path)
233
# if a loaded yaml is provided, use that instead of reading from path
234
if loaded_yaml is None:
235
with open(path) as f:
236
es = yaml.load(f, Loader=LineLoader)
240
_GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(
245
skip_native_fns_gen=skip_native_fns_gen,
248
return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
251
# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
252
# Assertions here are meant to be performed across NativeFunctions.
253
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
254
func_map: dict[OperatorName, NativeFunction] = {}
255
base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
257
func_map[f.func.name] = f
258
base_func_map[f.func.name.name].append(f)
260
if f.structured_delegate is not None:
261
delegate_func = func_map.get(f.structured_delegate)
262
assert delegate_func is not None, (
263
f"{f.func.name} is marked as a structured_delegate pointing to "
264
f"{f.structured_delegate}, but {f.structured_delegate} is missing."
266
assert delegate_func.structured, (
267
f"{f.func.name} is marked as a structured_delegate pointing to "
268
f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
269
f"Consider adding 'structured=True' to the delegated operator"
271
# See Note [resize_ in Functionalization]
272
# resize_() is technically an inplace view op (and therefore needs the tag),
273
# but it would be overkill to add a true "view" variant of resize.
274
# Instead, resize_() gets special treatment in functionalization,
275
# and we have a resize() op that is non-aliasing + functional.
277
"inplace_view" in f.tags
278
and str(f.func.name) != "resize_"
279
and str(f.func.name) != "resize_as_"
280
and str(f.func.name.name) != "set_"
282
base_name = f.func.name.name
283
assert base_name.inplace, (
284
f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
285
"convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
287
out_of_place_base_name = BaseOperatorName(
288
base_name.base, False, base_name.dunder_method
290
assert len(base_func_map[out_of_place_base_name]) > 0, (
291
f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
292
f"out-of-place view op with the name '{base_name}' and matching schema, but it didn't find one. "
296
# Translation table used by cpp_string: each character listed here is replaced
# by its C++ escape sequence. Applying it in a single str.translate pass means
# inserted backslashes are never re-escaped, so no replacement order matters.
_CPP_STRING_ESCAPES = str.maketrans(
    {
        "\\": "\\\\",
        '"': '\\"',
        "\a": "\\a",
        "\b": "\\b",
        "\f": "\\f",
        "\n": "\\n",
        "\r": "\\r",
        "\t": "\\t",
        "\v": "\\v",
    }
)


def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal.

    Backslashes, double quotes and the control characters \\a, \\b, \\f, \\n,
    \\r, \\t, \\v are escaped. The result includes the surrounding double
    quotes, so it can be spliced directly into generated C++ source.

    Args:
        s: The text to embed.

    Returns:
        A quoted C++ string literal, e.g. ``cpp_string('a"b')`` -> ``"a\\"b"``.
    """
    return f'"{s.translate(_CPP_STRING_ESCAPES)}"'
309
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
313
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
315
# Most functions in this section are curried: they consist of a function
316
# that takes some parameters (e.g., what is to be generated) which itself
317
# returns a function that actually maps NativeFunction to the code
318
# to be generated. This pattern makes it convenient to use map, concatMap
319
# and similar functional combinators.
322
def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
323
if len(backends) == 0:
326
return [backend.dispatch_key for backend in backends] + [
327
DispatchKey.CompositeImplicitAutograd,
328
DispatchKey.CompositeImplicitAutogradNestedTensor,
329
DispatchKey.CompositeExplicitAutograd,
330
DispatchKey.CompositeExplicitAutogradNonFunctional,
334
def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> DispatchKey | None:
    """Choose the dispatch key static dispatch should call for ``f``.

    Returns ``backend_index.dispatch_key`` when ``f`` has a structured delegate
    or the backend has a kernel for it. Otherwise returns the first composite
    key for which ``f`` has a kernel, checked in this order:
    CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional,
    CompositeImplicitAutograd, CompositeImplicitAutogradNestedTensor.
    Returns ``None`` if none of these apply.
    """
    targets_backend = (
        f.structured_delegate is not None or backend_index.has_kernel(f)
    )
    if targets_backend:
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
        # so we always dispatch to the `backend`, but this could be wrong when we
        # migrate math/default_backend ops to use structured delegate.
        return backend_index.dispatch_key
    if f.has_composite_explicit_autograd_kernel:
        return DispatchKey.CompositeExplicitAutograd
    if f.has_composite_explicit_autograd_non_functional_kernel:
        return DispatchKey.CompositeExplicitAutogradNonFunctional
    if f.has_composite_implicit_autograd_kernel:
        return DispatchKey.CompositeImplicitAutograd
    if f.has_composite_implicit_autograd_nested_tensor_kernel:
        return DispatchKey.CompositeImplicitAutogradNestedTensor
    return None
354
def static_dispatch_ops_header(
355
f: NativeFunction, backend_index: list[BackendIndex]
357
if backend_index is None or f.manual_kernel_registration:
361
for index in backend_index:
362
dispatch_key = get_static_dispatch_backend(f, index)
363
if dispatch_key is not None:
365
f"#include <ATen/ops/{f.root_name}_{dispatch_key.lower()}_dispatch.h>"
367
return "\n".join(output)
370
def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
372
f"#include <ATen/{dispatch_key}Functions.h>"
373
for dispatch_key in static_dispatch_keys(backends)
377
# Translates arguments of `sig` to CppSignature bindings.
378
# Note that we have a special case for `memory_format` argument and this case is not covered by
379
# tools.codegen.api.translate() yet as its application is limited to static dispatch.
381
sig: CppSignature | DispatcherSignature,
382
cpp_sig: CppSignature,
384
# Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
385
def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
386
output_bindings: list[Binding] = []
387
for binding in input_bindings:
388
if binding.name == "memory_format":
389
spl_mem_format_binding = Binding(
391
SpecialArgName.possibly_redundant_memory_format,
395
default=binding.default,
396
argument=binding.argument,
398
output_bindings.append(spl_mem_format_binding)
400
output_bindings.append(binding)
401
return output_bindings
403
src_bindings = list(sig.arguments())
404
goal_bindings = list(cpp_sig.arguments())
405
# When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
406
# get memory_format bindings of dispatcher signature to have the same NCType as well
407
for arg in goal_bindings:
408
if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
409
src_bindings = add_spl_memory_format_binding(src_bindings)
411
exprs = translate(src_bindings, goal_bindings)
412
return ", ".join(a.expr for a in exprs)
415
def generate_static_dispatch_backend_call(
416
sig: CppSignature | DispatcherSignature,
418
backend_index: BackendIndex,
420
cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
421
name = cpp_sig.name()
422
exprs = translate_args(sig, cpp_sig)
423
backend_metadata = backend_index.get_kernel(f)
425
backend_metadata.cpp_namespace
426
if backend_metadata and backend_metadata.cpp_namespace
427
else DEFAULT_KERNEL_NAMESPACE
429
ns = kernel_ns.replace("::native", "")
430
return f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});"
433
def generate_static_dispatch_fallback_call(
434
sig: CppSignature | DispatcherSignature,
436
backend_indices: list[BackendIndex],
438
cpp_sigs = CppSignatureGroup.from_native_function(
439
f, method=False, fallback_binding=False
441
if sig.symint and f.func.has_symint():
442
cpp_sig = cpp_sigs.symint_signature
444
cpp_sig = cpp_sigs.signature
445
assert cpp_sig is not None
446
name = cpp_sig.name()
447
exprs = translate_args(sig, cpp_sig)
448
ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
449
if f.has_composite_explicit_autograd_kernel:
450
return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
451
elif f.has_composite_explicit_autograd_non_functional_kernel:
452
return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
453
elif f.has_composite_implicit_autograd_kernel:
454
return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
455
elif f.has_composite_implicit_autograd_nested_tensor_kernel:
456
return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
458
return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
459
{', '.join([str(index.dispatch_key)for index in backend_indices])} ");"""
463
sig: CppSignature | DispatcherSignature,
465
backend_indices: list[BackendIndex],
468
For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one
469
backends exsit, fallback to static dispatch by determining dispatch key from inputs.
471
sig: A CppSignature or DispatcherSignature for this native function we want to use.
472
f: NativeFunction to generate static dispatch.
473
backend_indices: All available backends.
475
C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);"
477
if len(backend_indices) == 0 or f.manual_kernel_registration:
482
for b in backend_indices
485
f.structured_delegate is not None
486
and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
490
return generate_static_dispatch_backend_call(sig, f, keys[0])
492
return generate_static_dispatch_fallback_call(sig, f, backend_indices)
494
native_tensor_args = [
496
for a in sig.arguments()
497
if isinstance(a.argument, SelfArgument)
498
or isinstance(a.argument, Argument)
499
and a.argument.type.is_tensor_like()
501
tensor_args = ", ".join(native_tensor_args)
502
tensor_opts = f.func.arguments.tensor_options
505
subexprs: list[str] = []
506
if tensor_opts is not None:
508
"DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
510
if tensor_args != "":
511
subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
512
stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
513
stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")
517
dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
518
dispatch_code.append(
519
f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
522
fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
526
{connector.join(stmts)}
528
{connector.join(dispatch_code)}
535
# Generates RegisterSchema.cpp. Depending on the selector, either
536
# all schemas are registered, or only some are (in the case of
538
@dataclass(frozen=True)
540
selector: SelectiveBuilder
541
known_tags: dict[str, int] = field(default_factory=dict)
543
@method_with_native_function
544
def __call__(self, f: NativeFunction) -> str | None:
545
if not self.selector.is_native_function_selected(f):
547
tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
549
return f"m.def({cpp_string(str(f.func))}, {{}});\n"
551
if tags not in self.known_tags:
552
idx = len(self.known_tags)
553
self.known_tags[tags] = idx
554
maybe_tags = f"const std::vector<at::Tag> tags_{idx} = {tags};\n"
555
return f"{maybe_tags}m.def({cpp_string(str(f.func))}, tags_{self.known_tags[tags]});\n"
558
# Generates Operators.h and Operators.cpp.
559
# These provide macros that, given an operator and overload name, allow users
560
# to access an "un-overloaded" function version of the operator. This
561
# is useful for extension writers who want to (1) want to decltype the operator
562
# and (2) don't want to worry about method-only operators.
563
@dataclass(frozen=True)
564
class ComputeOperators:
565
target: Literal[Target.DECLARATION, Target.DEFINITION]
566
static_dispatch_backend_indices: list[BackendIndex]
568
@method_with_native_function
569
def __call__(self, f: NativeFunction) -> str:
570
sig = DispatcherSignature.from_schema(f.func)
571
name = f.func.name.unambiguous_name()
573
if self.target is Target.DECLARATION:
574
# Note [The ATen Operators API]
575
# The ATen Operators API lives in the at::_ops namespace, and contains compile-time
576
# metadata about each operator + entry points into the Dispatcher.
577
# The C++ function, method, and redispatch API's are all implemented as wrappers
578
# into various bits of the structs defined here.
580
# Important characteristics about the Operators API:
581
# (1) It follows the Dispatcher API.
582
# This is kind of necessary to avoid overhead.
583
# For example: if it followed the C++ API, then all of the faithful C++ factory functions
584
# would need to wrap their arguments into TensorOptions only to unwrap them again.
585
# (2) Overload names are disambiguated.
586
# This is helpful for pytorch extenders who would like to decltype() an aten operator,
587
# that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
588
# (3) No argument defaulting is allowed.
589
# This is more of an implementation detail to avoid #include cycles,
590
# since TensorBody.h (which defines the Tensor class) needs to include this file.
591
# (4) manual_cpp_bindings and faithful names are not included in the API.
592
# This applies to stuff like __dispatch__is_complex(), and add_outf().
593
# These aren't "real aten ops", they're just additional functions provided by the C++ API.
594
# They're implemented as wrappers in Functions.h that call into the actual operators
595
# defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
596
# This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
598
struct TORCH_API {name} {{
599
using schema = {sig.type()};
600
using ptr_schema = schema*;
601
// See Note [static constexpr char* members for windows NVCC]
602
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
603
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
604
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
605
static {sig.defn(name="call", is_redispatching_fn=False)};
606
static {sig.defn(name="redispatch", is_redispatching_fn=True)};
609
elif self.target is Target.DEFINITION:
611
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")
612
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
613
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})
616
static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
617
return c10::Dispatcher::singleton()
618
.findSchemaOrThrow({name}::name, {name}::overload_name)
619
.typed<{name}::schema>();
622
for is_redispatching_fn in [False, True]:
623
if is_redispatching_fn:
624
dispatcher_exprs_str = ", ".join(
625
["dispatchKeySet"] + [a.name for a in sig.arguments()]
627
method_base = "redispatch"
629
dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
632
dispatcher_call = method_base
633
method_name = f"{name}::{method_base}"
636
static auto op = create_{name}_typed_handle();
637
return op.{dispatcher_call}({dispatcher_exprs_str});"""
640
not is_redispatching_fn
641
and len(self.static_dispatch_backend_indices) > 0
643
# call() should go through static dispatch
644
fn_body = static_dispatch(
645
sig, f, backend_indices=self.static_dispatch_backend_indices
649
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
655
assert_never(self.target)
658
# Generates Functions.h, which provides the functional public C++ API,
659
# and the scaffolding to call into the dispatcher from these functions.
660
@dataclass(frozen=True)
661
class ComputeFunction:
662
@method_with_native_function
663
def __call__(self, f: NativeFunction) -> str | None:
664
sig_group = CppSignatureGroup.from_native_function(
665
f, method=False, fallback_binding=f.manual_cpp_binding
667
has_symint = f.func.has_symint()
670
for sig in sig_group.signatures():
671
# See Note [The ATen Operators API]
672
target_sig = DispatcherSignature.from_schema(f.func)
673
exprs = translate(sig.arguments(), target_sig.arguments())
674
exprs_str = ", ".join([e.expr for e in exprs])
677
intlike_t = "c10::SymInt"
679
intlike_t = "int64_t"
681
if Variant.function in f.variants:
684
inline {sig.decl()} {{
685
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
688
# The template function can be used from template situations
689
# where you want to switch between the symint or not version
690
# depending on a template argument
692
# NB: we ALWAYS generate this even for methods. But we put it in
693
# this header so it can take advantage of per-op headers
697
template <typename T, typename = std::enable_if_t<std::is_same<T, {intlike_t}>::value>>
698
{sig.decl(suppress_symint_suffix=True)} {{
699
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
706
# Generates TensorBody.h. This file provides the object-oriented (method-based)
707
# public C++ API, and the scaffolding to call into the dispatcher from these functions.
708
@dataclass(frozen=True)
709
class ComputeTensorMethod:
710
target: Literal[Target.DECLARATION, Target.DEFINITION]
711
static_dispatch_backend_indices: list[BackendIndex]
713
@method_with_native_function
714
def __call__(self, f: NativeFunction) -> str | None:
715
if Variant.method not in f.variants:
718
assert not f.func.is_out_fn()
719
assert f.func.arguments.self_arg is not None
721
sig_group = CppSignatureGroup.from_native_function(
722
f, method=True, fallback_binding=f.manual_cpp_binding
725
if self.target is Target.DECLARATION:
727
for sig in sig_group.signatures():
728
result += f"{sig.decl()} const;\n"
731
if self.target is not Target.DEFINITION:
732
assert_never(self.target)
736
for sig in sig_group.signatures():
737
target_sig = DispatcherSignature.from_schema(f.func)
738
exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
739
exprs_str = ", ".join([e.expr for e in exprs])
743
inline {sig.defn(prefix="Tensor::")} const {{
744
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
751
# Generates RedispatchFunctions.h.
752
# This is similar to the C++ API defined in Functions.h, but provides access
753
# to the dispatcher's redispatch API.
754
@dataclass(frozen=True)
755
class ComputeRedispatchFunction:
756
@method_with_native_function
757
def __call__(self, f: NativeFunction) -> str | None:
758
# We unconditionally generate function variants of the redispatch API.
759
# This is mainly because we can namespace functions separately, but not methods,
760
sig_group = CppSignatureGroup.from_native_function(
761
f, method=False, fallback_binding=f.manual_cpp_binding
765
for sig in sig_group.signatures():
766
target_sig = DispatcherSignature.from_schema(f.func)
767
exprs = translate(sig.arguments(), target_sig.arguments())
768
exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])
772
inline {sig.decl(is_redispatching_fn=True)} {{
773
return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
780
# Generates ATenOpList.cpp, a runtime accessible list of all aten
782
# TODO: This was historically used to help some JIT interop code
783
# figure out whether or not to treat aten namespace'd operators
784
# one way or another, we should reevaluate if this is actually needed.
786
def compute_aten_op(f: NativeFunction) -> str:
    """Render ``f`` as one ``{"aten::<name>", "<overload>"},`` list entry."""
    op_name = f.func.name
    entry = f'"aten::{op_name.name}", "{op_name.overload_name}"'
    return "{" + entry + "},"
790
# Generates MetaFunctions.h
791
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
794
with native_function_manager(g.out):
796
args = structured.meta_arguments(g)
797
args_str = ", ".join(a.decl() for a in args)
798
parent_class = g.out.structured_inherits
799
if parent_class is None:
800
parent_class = "at::impl::MetaBase"
802
precomputed = g.out.precomputed if g.structured else None
805
# Generate the template declaration with one bool parameter for each
806
# precomputed element. Each parameter is true if the corresponding (in
807
# terms of position) precomputed element has been set.
808
precomputed_values = [*precomputed.replace.values(), precomputed.add]
809
precomputed_elements = [
810
elem for replace_list in precomputed_values for elem in replace_list
812
precomputed_template_parameters = [
813
elem.name.upper() for elem in precomputed_elements
815
precomputed_template_params_str = ", ".join(
816
f"bool {param} = false" for param in precomputed_template_parameters
818
precompute_template_decl = f"template <{precomputed_template_params_str}>"
820
# Generate a string containing declarations of all precomputed elements.
821
precomputed_elements_with_cpp_types = [
822
structured.argument_type(elem, binds=elem.name)
823
for elem in precomputed_elements
826
precomputed_elements_decl = ";\n".join(
827
f"{elem.cpp_type(strip_ref=True)} {elem.name}"
828
for elem in precomputed_elements_with_cpp_types
831
# Generate "setter" methods for each precomputed element. Each method will return
832
# a new instance of precompute_out with the template parameter that corresponds to
833
# the member set by the method to true (to indicate that it has been set).
835
for i, elem in enumerate(precomputed_elements):
836
# Generate the signature. The return type will be the same
837
# as the type of `this` but with the template parameter
838
# corresponding to the element set by this method set to true.
839
# The assert generated below will ensure that this template
840
# parameter is false on the type of `this`.
841
return_ty_templates = ", ".join(
842
precomputed_template_parameters[:i]
844
+ precomputed_template_parameters[i + 1 :]
846
return_ty = f"precompute_out<{return_ty_templates}>"
847
elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
850
signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"
852
# Generate an assert which checks that the
853
# template parameter corresponding to the precomputed
854
# element that is set by this method is false on the
855
# class corresponding to the object that `this` points to.
856
# This ensures that each element can be set only once.
857
assert_msg = f'"{elem.name} already set"'
858
assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"
860
# Generate the new object construction block. All state
861
# except the element that this method sets is copied from the
862
# object that `this` points to. The value for the element that
863
# the method sets is taken from a method parameter.
864
construction_stmts = []
865
construction_stmts.append(f"{return_ty} ret;")
867
for j, elem in enumerate(precomputed_elements):
869
construction_stmts.append(f"ret.{elem.name} = value;")
871
construction_stmts.append(
872
f"ret.{elem.name} = this->{elem.name};"
875
construction_stmts.append("return ret;")
876
construction_block = "\n".join(construction_stmts)
878
setter_methods.append(
886
setter_methods_decl = "\n".join(setter_methods)
888
# Meta should return an instance of the struct containing the precomputed elements.
889
meta_return_template_params = ", ".join(
890
["true"] * len(precomputed_template_parameters)
892
# This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
893
# type (which has a variable number of template parameters).
894
meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
895
meta_return = "meta_return_ty"
896
precomputed_decl = f"""
897
{precompute_template_decl}
898
struct TORCH_API precompute_out {{
899
{setter_methods_decl}
900
{precomputed_elements_decl};
903
meta_return_typedef = ""
904
precomputed_decl = ""
907
struct TORCH_API structured_{name} : public {parent_class} {{
909
{meta_return_typedef}
910
{meta_return} meta({args_str});
915
def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
916
name = str(f.func.name.name)
917
if name.endswith("_like") or name.startswith("new_"):
919
if f.func.arguments.tensor_options is None:
921
return selector.is_native_function_selected(f)
924
# Generates RegisterBackendSelect.cpp, a series of kernels which provide
925
# specialized computation of dispatch key for operator signatures which cannot
926
# be easily done automatically using templating.
927
@dataclass(frozen=True)
928
class ComputeBackendSelect:
929
target: Literal[Target.DEFINITION, Target.REGISTRATION]
931
# Selector object to determine which operators to generate
932
# registration code for.
933
selector: SelectiveBuilder
935
@method_with_native_function
936
def __call__(self, f: NativeFunction) -> str | None:
937
if not needs_backend_select(f, self.selector):
940
name = native.name(f.func)
941
# BackendSelect can go to Meta, so it must preserve symints
942
native_sig = NativeSignature(f.func, symint=True)
944
native_tensor_args = [
946
for a in native_sig.arguments()
947
if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
950
dispatcher_sig = DispatcherSignature.from_schema(f.func)
952
sig: NativeSignature | DispatcherSignature
954
dispatcher_exprs = dispatcher_sig.exprs()
955
dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
957
if self.target is Target.DEFINITION:
958
# I don't think there's actually a good reason to generate
959
# these two cases differently
960
# The first case could probably be improved though- it calls computeDispatchKeySet(),
961
# which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
962
if native_tensor_args:
963
assert f.func.arguments.has_tensor_arg()
964
tensor_args = ", ".join(a.name for a in native_tensor_args)
966
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
967
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
968
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
970
assert not f.func.arguments.has_tensor_arg()
972
f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
979
return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
980
_dk, {', '.join(a.expr for a in dispatcher_exprs)});
983
elif self.target is Target.REGISTRATION:
984
return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
986
assert_never(self.target)
989
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
991
# YAML CODE GENERATION
993
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
996
def format_yaml(data: object) -> str:
997
# Ignore alias in Dumper
998
YamlDumper.ignore_aliases = lambda self, data: True # type: ignore[assignment]
1000
# Support serializing OrderedDict
1001
def dict_representer(dumper: Any, data: Any) -> Any:
1002
return dumper.represent_dict(data.items())
1004
YamlDumper.add_representer(OrderedDict, dict_representer) # type: ignore[no-untyped-call]
1005
# Some yaml parsers (e.g. Haskell's) don't understand line breaks.
1006
# width=1e9 turns off optional line breaks and improves
1007
# the portability of the outputted yaml.
1008
return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9) # type: ignore[no-any-return, call-overload]
1011
# For some reason, some defaults we write to YAML are written as native
1012
# YAML objects, rather than doing them uniformly as strings. This
1013
# function detects those cases and converts them into native Python
1015
def pythonify_default(s: str) -> object:
1030
# What is a dynamic type? Over time, the semantic meaning of
1031
# dynamic type has degraded to meaninglessness (in the old days,
1032
# it captured dtype-ness of types, but that has gone away with
1033
# the removal of TH). These days, it's mostly the same thing as
1034
# the C++ API argument type, except that Tensor and Tensor?
1035
# arguments simply present as Tensor.
1037
# TODO: Get rid of dynamic_type, after getting tools/autograd
1038
# to use the new codegen framework
1039
def dynamic_type(t: Type) -> str:
1040
if isinstance(t, OptionalType):
1041
return dynamic_type(t.elem)
1042
# Note we don't use t.is_tensor_like() here because it would
1043
# also include Tensor[]
1044
if str(t) == "Tensor":
1046
# This is a legacy concept, so never report SymInt
1047
return cpp.argumenttype_type(
1048
t, mutable=False, binds="__placeholder__", symint=False
1052
def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
1053
# This is written out explicitly to ensure that Tensor and
1054
# namespace are put into the list in the right order
1055
method_of = ["Type"]
1056
if Variant.method in variants:
1057
method_of.append("Tensor")
1058
if Variant.function in variants:
1059
method_of.append("namespace")
1063
def compute_returns_yaml(
1065
) -> tuple[list[dict[str, str]], dict[str, str]]:
1066
# Note [name and field_name]
1067
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
1068
# To understand name_to_field_name, we must first talk about this
1071
# lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
1073
# There is something very odd about this schema: it is an out
1074
# variant of the function (that is to say, it will convert into
1075
# at::lstsq_out() in the C++ API), but the names of the output
1076
# return arguments don't match the keyword argument names of
1077
# the inputs. It TURNS OUT that in this situation, the historical
1078
# Declarations.yaml we want to output is this (abbreviated to
1079
# only show relevant fields):
1083
# - field_name: solution
1090
# - field_name: solution
1095
# The name of the return fields is stored in 'field_name', and the
1096
# name of the arguments is stored in 'name'. So when we process
1097
# arguments, we need a way to get at the corresponding return. At
1098
# the moment, this is most conveniently done by constructing a
1099
# mapping from name (the argument concept) to field_name (the
1100
# return concept) while processing return arguments, since we don't
1101
# directly maintain this correspondence in the modeling of function
1104
# See also https://github.com/pytorch/pytorch/issues/43114
1105
name_to_field_name: dict[str, str] = {}
1107
# Compute the returns field of the YAML entry
1108
names = cpp.return_names(f)
1110
for i, (r, name) in enumerate(zip(f.func.returns, names)):
1112
"dynamic_type": dynamic_type(r.type),
1114
# legacy, report ints
1115
"type": cpp.return_type(r, symint=False).cpp_type(),
1119
# See Note [name and field_name]
1120
ret["field_name"] = r.name
1121
if f.func.is_out_fn():
1122
name_to_field_name[f.func.arguments.out[i].name] = r.name
1126
return returns, name_to_field_name
1129
# arguments in yaml roughly corresponds to the public C++ API
1130
def compute_cpp_argument_yaml(
1134
kwarg_only_set: set[str],
1135
out_arg_set: set[str],
1136
name_to_field_name: dict[str, str],
1138
if isinstance(cpp_a.argument, TensorOptionsArguments):
1139
arg: dict[str, object] = {
1141
"dynamic_type": "at::TensorOptions",
1142
"is_nullable": False,
1147
if cpp_a.default is not None:
1148
arg["default"] = cpp_a.default
1150
elif isinstance(cpp_a.argument, SelfArgument):
1151
raise AssertionError
1152
elif isinstance(cpp_a.argument, Argument):
1153
return compute_argument_yaml(
1155
schema_order=schema_order,
1156
kwarg_only_set=kwarg_only_set,
1157
out_arg_set=out_arg_set,
1158
name_to_field_name=name_to_field_name,
1162
def compute_argument_yaml(
1166
kwarg_only_set: set[str],
1167
out_arg_set: set[str],
1168
name_to_field_name: dict[str, str],
1170
arg: dict[str, object] = {
1171
"annotation": str(a.annotation) if a.annotation else None,
1172
"dynamic_type": dynamic_type(a.type),
1173
"is_nullable": a.type.is_nullable(),
1175
# legacy, report ints
1176
"type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
1178
if a.default is not None:
1179
arg["default"] = pythonify_default(
1180
cpp.default_expr(a.default, a.type, symint=False)
1182
if a.name in kwarg_only_set:
1183
arg["kwarg_only"] = True
1184
if a.name in out_arg_set:
1185
arg["output"] = True
1186
arg["allocate"] = True
1187
# See Note [name and field_name]
1188
if a.name in name_to_field_name:
1189
arg["field_name"] = name_to_field_name[a.name]
1190
# Historically, booleans don't get their size recorded, because it
1191
# is already built into the cpp type (e.g., std::array<bool, 4>)
1192
l = a.type.is_list_like()
1193
if l is not None and l.size is not None and str(l.elem) != "bool":
1194
arg["size"] = l.size
1198
@with_native_function
1199
def compute_declaration_yaml(f: NativeFunction) -> object:
1200
returns, name_to_field_name = compute_returns_yaml(f)
1202
# These sets are used to conveniently test if an argument is a
1203
# kwarg-only or out argument
1204
kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only}
1205
out_arg_set = {a.name for a in f.func.arguments.out}
1207
sig_group = CppSignatureGroup.from_native_function(
1208
f, method=False, fallback_binding=False
1210
cpp_args = sig_group.signature.arguments()
1212
compute_cpp_argument_yaml(
1215
kwarg_only_set=kwarg_only_set,
1216
out_arg_set=out_arg_set,
1217
name_to_field_name=name_to_field_name,
1219
for cpp_a in cpp_args
1222
schema_order_jit_arguments = list(f.func.schema_order_arguments())
1224
schema_order_arguments = [
1225
compute_argument_yaml(
1228
kwarg_only_set=kwarg_only_set,
1229
out_arg_set=out_arg_set,
1230
name_to_field_name=name_to_field_name,
1232
for a in schema_order_jit_arguments
1235
cpp_schema_order_types = [
1236
# NB: method here doesn't matter
1238
for a in schema_order_jit_arguments
1239
for r in cpp.argument(
1242
cpp_no_default_args=set(),
1245
has_tensor_options=False,
1249
# legacy, report ints
1250
cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
1251
schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"
1253
is_factory_method = (
1254
any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args)
1255
and Variant.method not in f.variants
1260
("name", cpp.name(f.func)),
1261
("operator_name", str(f.func.name.name)),
1262
("overload_name", str(f.func.name.overload_name)),
1263
("manual_kernel_registration", f.manual_kernel_registration),
1265
"category_override",
1266
f.category_override if f.category_override is not None else "",
1268
("schema_string", f"aten::{f.func}"),
1269
("arguments", arguments),
1270
("schema_order_cpp_signature", schema_order_cpp_signature),
1271
("schema_order_arguments", schema_order_arguments),
1272
("method_of", compute_method_of_yaml(f.variants)),
1274
("python_module", "" if f.python_module is None else f.python_module),
1275
("returns", returns),
1276
("inplace", f.func.name.name.inplace),
1277
("is_factory_method", is_factory_method),
1278
("abstract", f.is_abstract),
1279
("device_guard", f.device_guard),
1280
("with_gil", False),
1281
("deprecated", False),
1282
("has_math_kernel", f.has_composite_implicit_autograd_kernel),
1287
# See Note [Auto generated composite kernels]
1288
def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
1289
return (f.structured or f.structured_delegate is not None) and (
1290
f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace
1294
@with_native_function_and_indices
1295
def compute_registration_declarations(
1296
f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
1298
name = dispatcher.name(f.func)
1299
returns_type = dispatcher.returns_type(
1301
).cpp_type_registration_declarations()
1302
args = dispatcher.arguments(f.func)
1303
args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)
1304
comment_data: dict[str, str] = {
1305
"schema": f"aten::{f.func}",
1306
# TODO: What exactly is the semantics of the 'dispatch' field?
1308
{k for k, v in backend_indices.items() if v.has_kernel(f)}
1309
!= {DispatchKey.CompositeImplicitAutograd}
1310
and {k for k, v in backend_indices.items() if v.has_kernel(f)}
1312
DispatchKey.CompositeImplicitAutograd,
1313
DispatchKey.CompositeImplicitAutogradNestedTensor,
1316
"default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
1318
return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
1322
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1326
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
1329
def get_custom_build_selector(
1330
provided_op_registration_allowlist: list[str] | None,
1331
op_selection_yaml_path: str | None,
1332
) -> SelectiveBuilder:
1334
provided_op_registration_allowlist is not None
1335
and op_selection_yaml_path is not None
1337
"Both provided_op_registration_allowlist and "
1338
+ "op_selection_yaml_path can NOT be provided at the "
1342
op_registration_allowlist: set[str] | None = None
1343
if provided_op_registration_allowlist is not None:
1344
op_registration_allowlist = set(provided_op_registration_allowlist)
1346
if op_registration_allowlist is not None:
1347
selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
1348
op_registration_allowlist,
1352
elif op_selection_yaml_path is not None:
1353
selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
1355
selector = SelectiveBuilder.get_nop_selector()
1360
def get_grouped_by_view_native_functions(
1361
native_functions: Sequence[NativeFunction],
1362
) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
1363
def maybe_create_view_group(
1364
d: dict[ViewSchemaKind | SchemaKind, NativeFunction]
1365
) -> list[NativeFunction | NativeFunctionsViewGroup]:
1366
funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
1367
if ViewSchemaKind.aliasing in d:
1368
view = d.pop(ViewSchemaKind.aliasing)
1369
view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
1370
view_copy = d.pop(SchemaKind.functional, None)
1373
NativeFunctionsViewGroup(
1375
view_copy=view_copy,
1376
view_inplace=view_inplace,
1379
# Take the remaining functions that weren't part of the view group
1380
# and emit them separately
1381
funcs.extend(d.values())
1384
grouped_by_views: dict[
1385
FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
1386
] = defaultdict(dict)
1387
for f in native_functions:
1388
schema = f.func.view_signature()
1389
view_kind: ViewSchemaKind = f.view_schema_kind
1390
# We need to group up ops relevant to the same "view", consisting of:
1391
# view op (ViewSchemaKind.aliasing)
1392
# view_inplace op (ViewSchemaKind.aliasing_inplace)
1393
# view_copy op (SchemaKind.functional)
1394
if view_kind == ViewSchemaKind.non_aliasing:
1395
kind = f.func.kind()
1396
assert kind not in grouped_by_views[schema]
1397
grouped_by_views[schema][kind] = f
1400
view_kind not in grouped_by_views[schema]
1401
), f"{view_kind} already in {grouped_by_views[schema].keys()}"
1402
grouped_by_views[schema][view_kind] = f
1404
return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
1407
def get_grouped_native_functions(
1408
native_functions: Sequence[NativeFunction],
1409
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
1410
def flatten_pre_group(
1411
d: dict[SchemaKind, NativeFunction]
1412
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
1413
r = NativeFunctionsGroup.from_dict(d)
1415
# Invariant: any NativeFunctions that are code-generated
1416
# should have been grouped into NativeFunctionsGroup objects
1417
assert not any("generated" in f.tags for f in d.values())
1418
return list(d.values())
1422
# TODO: how come ValuesView isn't a Sequence lol
1423
pre_grouped_native_functions = pre_group_native_functions(native_functions)
1425
concatMap(flatten_pre_group, list(pre_grouped_native_functions.values()))
1429
def get_ns_grouped_kernels(
1431
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1432
backend_indices: dict[DispatchKey, BackendIndex],
1433
native_function_decl_gen: Callable[
1434
[NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
1435
] = dest.compute_native_function_declaration,
1436
) -> dict[str, list[str]]:
1437
ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
1438
for f in grouped_native_functions:
1439
native_function_namespaces = set()
1440
dispatch_keys = set()
1441
for dispatch_key, backend_idx in backend_indices.items():
1442
backend_metadata = backend_idx.get_kernel(f)
1443
if backend_metadata:
1444
namespace = backend_metadata.cpp_namespace
1445
dispatch_keys.add(dispatch_key)
1446
native_function_namespaces.add(namespace)
1448
namespace = DEFAULT_KERNEL_NAMESPACE
1450
len(native_function_namespaces) <= 1
1451
), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
1452
ns_grouped_kernels[namespace].extend(
1453
native_function_decl_gen(f, backend_idx)
1455
return ns_grouped_kernels
1458
def get_native_function_declarations_from_ns_grouped_kernels(
1460
ns_grouped_kernels: dict[str, list[str]],
1462
declarations: list[str] = []
1464
for namespace, kernels in ns_grouped_kernels.items():
1465
ns_helper = NamespaceHelper(
1466
namespace_str=namespace,
1470
# Convert to a set first to remove duplicate kernel names. Backends are
1471
# allowed to repeat kernel names; only generate the declaration once!
1472
ordered_kernels = list(OrderedDict.fromkeys(kernels))
1473
declarations.extend(
1476
{newline.join(ordered_kernels)}
1485
# Return native function declarations grouped by their namespaces.
1486
def get_native_function_declarations(
1488
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1489
backend_indices: dict[DispatchKey, BackendIndex],
1490
native_function_decl_gen: Callable[
1491
[NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
1492
] = dest.compute_native_function_declaration,
1495
Generate kernel declarations, in `NativeFunction(s).h`.
1496
:param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
1497
:param backend_indices: kernel collections grouped by dispatch key.
1498
:param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
1499
:return: a list of string, from the string with all declarations, grouped by namespaces, split by newline.
1502
ns_grouped_kernels = get_ns_grouped_kernels(
1503
grouped_native_functions=grouped_native_functions,
1504
backend_indices=backend_indices,
1505
native_function_decl_gen=native_function_decl_gen,
1507
return get_native_function_declarations_from_ns_grouped_kernels(
1508
ns_grouped_kernels=ns_grouped_kernels
1512
def get_kernel_namespace(
1513
*, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
1515
backend_metadata = backend_idx.get_kernel(f)
1516
assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
1517
f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
1518
f"with dispatch key {backend_idx.dispatch_key}"
1519
f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'."
1522
backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE
1526
# Return native function definitions grouped by dispatch key and custom namespace.
1527
# Used in RegisterDispatchKey.cpp and etc.
1528
def get_native_function_definitions(
1531
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1532
dispatch_key: DispatchKey,
1533
backend_idx: BackendIndex,
1534
selector: SelectiveBuilder,
1537
skip_dispatcher_op_registration: bool,
1538
gen_dispatch_helpers: bool,
1540
definitions: list[str] = []
1541
ns_definitions: dict[str, list[str]] = defaultdict(list)
1542
anonymous_definitions: dict[str, list[str]] = defaultdict(list)
1543
registrations: dict[str, dict[str, list[str]]] = defaultdict(dict)
1545
ns_gen = dest.RegisterDispatchKey(
1547
Target.NAMESPACED_DEFINITION,
1551
class_method_name=None,
1552
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
1554
anonymous_gen = dest.RegisterDispatchKey(
1556
Target.ANONYMOUS_DEFINITION,
1560
class_method_name=None,
1561
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
1563
reg_gen = dest.RegisterDispatchKey(
1565
Target.REGISTRATION,
1569
class_method_name=None,
1570
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
1572
for f in grouped_native_functions:
1573
kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
1577
ns_definitions[kernel_namespace].extend(
1580
anonymous_definitions[kernel_namespace].extend(
1584
f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
1586
if namespace not in registrations[kernel_namespace]:
1587
registrations[kernel_namespace] = defaultdict(list)
1588
registrations[kernel_namespace][namespace].extend(
1592
for kernel_namespace in ns_definitions:
1593
if len(ns_definitions[kernel_namespace]) == 0:
1595
ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
1596
registration_body = ""
1597
for namespace in registrations[kernel_namespace]:
1598
if not registrations[kernel_namespace][namespace]:
1600
registration_body += f"""
1601
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
1602
{newline.join(registrations[kernel_namespace][namespace])}
1605
fm.substitute_with_template(
1606
"RegisterDispatchDefinitions.ini",
1608
"ns_prologue": ns_helper.prologue,
1609
"ns_epilogue": ns_helper.epilogue,
1610
"dispatch_helpers": dest.gen_registration_helpers(backend_idx)
1611
if gen_dispatch_helpers
1613
"dispatch_anonymous_definitions": anonymous_definitions[
1616
"static_init_dispatch_registrations": ""
1617
if skip_dispatcher_op_registration
1618
else registration_body,
1619
"deferred_dispatch_registrations": "",
1620
"dispatch_namespace": dispatch_key.lower(),
1621
"dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
1629
# Return native function declarations grouped by dispatch key and custom namespace.
1630
# Used in CPUFunctions_inl.h and etc.
1631
def get_namespaced_declaration(
1633
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1634
dispatch_key: DispatchKey,
1635
backend_idx: BackendIndex,
1636
selector: SelectiveBuilder,
1640
declarations: list[str] = []
1641
ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
1643
func = dest.RegisterDispatchKey(
1645
Target.NAMESPACED_DECLARATION,
1648
class_method_name=None,
1649
skip_dispatcher_op_registration=False,
1652
for f in grouped_native_functions:
1653
namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
1654
"native", dispatch_key.lower()
1657
ns_grouped_kernels[namespace].extend(
1661
for namespace, kernels in ns_grouped_kernels.items():
1662
if len(kernels) == 0:
1664
ns_helper = NamespaceHelper(
1665
namespace_str=namespace, entity_name="", max_level=3
1667
ordered_kernels = list(OrderedDict.fromkeys(kernels))
1668
declarations.extend(
1671
{newline.join(ordered_kernels)}
1680
# Return native function schema registration code for aten and other namespaces.
1681
def get_native_function_schema_registrations(
1683
native_functions: Sequence[NativeFunction],
1684
schema_selector: SelectiveBuilder,
1685
) -> tuple[list[str], str]:
1686
ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
1687
for native_function in native_functions:
1688
ns_native_functions[native_function.namespace].append(native_function)
1689
schema_registrations = ""
1690
aten_schema_registrations = []
1691
custom_namespace = None
1692
for namespace, funcs in ns_native_functions.items():
1693
schema_registrations_body = list(
1694
mapMaybe(RegisterSchema(schema_selector), funcs)
1696
# NB: we have to separate aten namespace registration from other namespaces,
1697
# because in the template we hardcoded an operator for ATen already.
1698
if namespace == "aten":
1699
aten_schema_registrations = schema_registrations_body
1701
custom_namespace = namespace
1703
# if the namespace is predefined, we should use define a library fragment
1704
# instead of a new library
1705
torch_library_macro = (
1706
"TORCH_LIBRARY_FRAGMENT"
1707
if namespace in FRAGMENT_NAMESPACES
1708
else "TORCH_LIBRARY"
1710
schema_registrations += f"""
1711
{torch_library_macro}({custom_namespace}, m) {{
1712
{tab.join(schema_registrations_body)}
1714
return (aten_schema_registrations, schema_registrations)
1717
def gen_aggregated_headers(
1719
native_functions: Sequence[NativeFunction],
1720
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1721
structured_native_functions: Sequence[NativeFunctionsGroup],
1722
static_dispatch_idx: list[BackendIndex],
1723
selector: SelectiveBuilder,
1724
backend_indices: dict[DispatchKey, BackendIndex],
1725
cpu_fm: FileManager,
1726
cuda_fm: FileManager,
1727
functions_keys: set[DispatchKey],
1728
dispatch_keys: Sequence[DispatchKey],
1731
# Buck doesn't support dynamic output files, so we aggregate all operator
1732
# headers into a single file
1734
"NativeMetaFunctions.h",
1736
"NativeMetaFunctions_includes": [],
1737
"NativeMetaFunctions_declarations": list(
1738
mapMaybe(compute_meta_function_declaration, structured_native_functions)
1742
method_native_functions = [
1743
fn for fn in native_functions if Variant.method in fn.variants
1745
non_method_native_functions = [
1746
fn for fn in native_functions if fn not in method_native_functions
1749
"MethodOperators.h",
1751
"MethodOperators_includes": [],
1752
"MethodOperators_declarations": list(
1756
static_dispatch_backend_indices=static_dispatch_idx,
1758
method_native_functions,
1766
"Operators_includes": ["#include <ATen/MethodOperators.h>"],
1767
"Operators_declarations": list(
1771
static_dispatch_backend_indices=static_dispatch_idx,
1773
non_method_native_functions,
1781
"static_dispatch_extra_headers": static_dispatch_extra_headers(
1784
"Functions_includes": ["#include <ATen/Operators.h>"],
1785
"Functions_declarations": list(
1793
declarations = get_native_function_declarations(
1794
grouped_native_functions=grouped_native_functions,
1795
backend_indices=backend_indices,
1798
"NativeFunctions.h",
1800
"NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
1801
"NativeFunctions_declarations": declarations,
1805
for dispatch_key in dispatch_keys:
1806
fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
1807
if dispatch_key in functions_keys:
1808
inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
1810
fm.write_with_template(
1811
f"{dispatch_key}Functions.h",
1812
"DispatchKeyFunctions.h",
1814
"dispatch_key": str(dispatch_key),
1815
"inline_headers": inl_headers,
1818
fm.write_with_template(
1819
f"{dispatch_key}Functions_inl.h",
1820
"DispatchKeyFunctions_inl.h",
1822
"DispatchKeyFunctions_inl_includes": [],
1823
"dispatch_namespace": dispatch_key.lower(),
1824
"dispatch_namespaced_declarations": get_namespaced_declaration(
1825
grouped_native_functions=grouped_native_functions,
1826
dispatch_key=dispatch_key,
1827
backend_idx=backend_indices[dispatch_key],
1838
def gen_per_operator_headers(
1840
native_functions: Sequence[NativeFunction],
1841
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
1842
static_dispatch_idx: list[BackendIndex],
1843
selector: SelectiveBuilder,
1844
backend_indices: dict[DispatchKey, BackendIndex],
1845
cpu_fm: FileManager,
1846
cuda_fm: FileManager,
1847
ops_fm: FileManager,
1848
functions_keys: set[DispatchKey],
1849
dispatch_keys: Sequence[DispatchKey],
1852
# For CMake builds, split operator declarations into separate headers in
1853
# the ATen/ops folder to split up header dependencies
1854
functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
1855
for fn in native_functions:
1856
functions_by_root_name[fn.root_name].append(fn)
1858
grouped_functions_by_root_name: dict[
1859
str, list[NativeFunction | NativeFunctionsGroup]
1860
] = defaultdict(list)
1861
for group in grouped_native_functions:
1862
name = group.root_name
1863
grouped_functions_by_root_name[name].append(group)
1865
for name, functions in functions_by_root_name.items():
1866
ops_fm.write_with_template(
1870
"declarations": list(
1874
static_dispatch_backend_indices=static_dispatch_idx,
1882
ops_fm.write_with_template(
1886
"static_dispatch_ops_headers": list(
1888
lambda fn: static_dispatch_ops_header(
1889
fn, backend_index=static_dispatch_idx
1894
"operator_includes": f"#include <ATen/ops/{name}_ops.h>",
1895
"function_definitions": list(
1904
grouped_functions = grouped_functions_by_root_name.get(name, [])
1905
structured_functions = [
1907
for fn in grouped_functions
1908
if isinstance(fn, NativeFunctionsGroup) and fn.structured
1910
is_structured = len(structured_functions) > 0
1913
ops_fm.write_with_template(
1915
"NativeMetaFunction.h",
1917
"meta_function_declarations": list(
1919
compute_meta_function_declaration, structured_functions
1924
declarations = get_native_function_declarations(
1925
grouped_native_functions=grouped_functions,
1926
backend_indices=backend_indices,
1927
native_function_decl_gen=dest.compute_native_function_declaration,
1929
ops_fm.write_with_template(
1934
f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
1936
"native_function_declarations": declarations,
1940
for category, suffix in [
1942
("Operators", "_ops"),
1943
("NativeMetaFunctions", "_meta"),
1944
("NativeFunctions", "_native"),
1949
f"{category}_includes": [
1950
f"#include <ATen/ops/{name}{suffix}.h>"
1951
for name in sorted(functions_by_root_name.keys())
1953
f"{category}_declarations": [],
1957
for dispatch_key in dispatch_keys:
1958
if dispatch_key not in functions_keys:
1961
dispatch_namespace = dispatch_key.lower()
1964
for name, functions in functions_by_root_name.items():
1965
grouped_functions = grouped_functions_by_root_name.get(name, [])
1966
declarations = list(
1968
dest.RegisterDispatchKey(
1969
backend_indices[dispatch_key],
1970
Target.NAMESPACED_DECLARATION,
1974
class_method_name=None,
1975
skip_dispatcher_op_registration=False,
1981
if len(declarations) == 0:
1984
dispatch_names.append(name)
1985
ops_fm.write_with_template(
1986
f"{name}_{dispatch_namespace}_dispatch.h",
1987
"DispatchKeyFunction.h",
1989
"dispatch_namespace": dispatch_namespace,
1990
"dispatch_namespaced_declarations": declarations,
1994
fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
1995
inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
1997
fm.write_with_template(
1998
f"{dispatch_key}Functions.h",
1999
"DispatchKeyFunctions.h",
2001
"dispatch_key": str(dispatch_key),
2002
"inline_headers": inl_headers,
2005
fm.write_with_template(
2006
f"{dispatch_key}Functions_inl.h",
2007
"DispatchKeyFunctions_inl.h",
2009
"dispatch_namespace": dispatch_namespace,
2010
"DispatchKeyFunctions_inl_includes": [
2011
f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
2012
for name in sorted(dispatch_names)
2014
"dispatch_namespaced_declarations": [],
2020
"MethodOperators.h",
2022
"MethodOperators_includes": sorted(
2023
f"#include <ATen/ops/{name}_ops.h>"
2024
for name, functions in functions_by_root_name.items()
2025
if any(Variant.method in fn.variants for fn in functions)
2027
"MethodOperators_declarations": [],
2034
native_functions: Sequence[NativeFunction],
2035
valid_tags: set[str],
2036
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
2037
structured_native_functions: Sequence[NativeFunctionsGroup],
2038
static_dispatch_idx: list[BackendIndex],
2039
selector: SelectiveBuilder,
2040
backend_indices: dict[DispatchKey, BackendIndex],
2041
core_fm: FileManager,
2042
cpu_fm: FileManager,
2043
cuda_fm: FileManager,
2044
ops_fm: FileManager,
2045
dispatch_keys: Sequence[DispatchKey],
2046
functions_keys: set[DispatchKey],
2048
per_operator_headers: bool,
2050
if per_operator_headers:
2051
gen_per_operator_headers(
2052
native_functions=native_functions,
2053
grouped_native_functions=grouped_native_functions,
2054
static_dispatch_idx=static_dispatch_idx,
2056
backend_indices=backend_indices,
2060
dispatch_keys=dispatch_keys,
2061
functions_keys=functions_keys,
2065
gen_aggregated_headers(
2066
native_functions=native_functions,
2067
grouped_native_functions=grouped_native_functions,
2068
structured_native_functions=structured_native_functions,
2069
static_dispatch_idx=static_dispatch_idx,
2071
backend_indices=backend_indices,
2074
dispatch_keys=dispatch_keys,
2075
functions_keys=functions_keys,
2082
"tensor_method_declarations": list(
2084
ComputeTensorMethod(
2085
target=Target.DECLARATION,
2086
static_dispatch_backend_indices=static_dispatch_idx,
2091
"tensor_method_definitions": list(
2093
ComputeTensorMethod(
2094
target=Target.DEFINITION,
2095
static_dispatch_backend_indices=static_dispatch_idx,
2104
"RedispatchFunctions.h",
2106
"function_redispatch_definitions": list(
2107
mapMaybe(ComputeRedispatchFunction(), native_functions)
2113
"RegistrationDeclarations.h",
2115
"registration_declarations": [
2116
compute_registration_declarations(f, backend_indices)
2117
for f in native_functions
2123
"VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
2126
def gen_aten_interned_strings() -> dict[str, str]:
2127
attrs: set[str] = set() # All function argument names
2128
names = set() # All ATen function names
2129
for func in native_functions:
2130
names.add(str(func.func.name.name))
2131
# Some operators don't have a functional variant but we still create a
2132
# symbol without the underscore
2133
names.add(func.func.name.name.base)
2135
attrs.update(arg.name for arg in func.func.schema_order_arguments())
2137
# These are keywords in C++, so aren't valid symbol names
2138
# https://en.cppreference.com/w/cpp/language/operator_alternative
2154
"aten_symbols": " \\\n".join(
2155
[f"_(aten, {name})" for name in sorted(names)]
2157
"attr_symbols": " \\\n".join(
2158
[f"_(attr, {name})" for name in sorted(attrs)]
2162
core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)
2164
def gen_tags_enum() -> dict[str, str]:
2165
return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}
2167
core_fm.write("enum_tag.h", gen_tags_enum)
2170
def gen_source_files(
2172
native_functions: Sequence[NativeFunction],
2173
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
2174
structured_native_functions: Sequence[NativeFunctionsGroup],
2175
view_groups: Sequence[NativeFunctionsViewGroup],
2176
selector: SelectiveBuilder,
2177
static_dispatch_idx: list[BackendIndex],
2178
backend_indices: dict[DispatchKey, BackendIndex],
2179
aoti_fm: FileManager,
2180
core_fm: FileManager,
2181
cpu_fm: FileManager,
2182
cpu_vec_fm: FileManager,
2183
cuda_fm: FileManager,
2184
dispatch_keys: Sequence[DispatchKey],
2185
functions_keys: set[DispatchKey],
2187
force_schema_registration: bool,
2188
per_operator_headers: bool,
2189
skip_dispatcher_op_registration: bool,
2190
update_aoti_c_shim: bool,
2192
extra_cuda_headers = """\
2193
#include <c10/cuda/CUDAGuard.h>
2194
#include <ATen/cuda/ATenCUDAGeneral.h>
2195
#include <ATen/cuda/CUDADevice.h>
2196
#include <ATen/cuda/CUDAContext.h>"""
2198
extra_cuda_headers = """\
2199
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
2200
#include <ATen/hip/ATenHIPGeneral.h>
2201
#include <ATen/hip/HIPDevice.h>
2202
#include <ATen/hip/HIPContext.h>"""
2204
for dispatch_key in dispatch_keys:
2205
fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
2207
if per_operator_headers:
2209
def operator_headers() -> list[str]:
2211
for g in grouped_native_functions:
2212
is_registered = False
2213
if backend_index.has_kernel(g):
2214
is_registered = True
2215
# The above has_kernel test on a group will only test for
2216
# the existence of out dispatch, because that's how
2217
# structured kernels work. But sometimes functions can be
2218
# grouped but not be structured, and then you need to check
2219
# each individual piece, as they may have manual dispatch
2221
elif isinstance(g, NativeFunctionsGroup) and any(
2222
backend_index.has_kernel(fn) for fn in g.functions()
2224
is_registered = True
2225
# TODO: this condition is a bit questionable
2226
# (It has to do with the fact that structured kernels get generated kernels
2227
# to the Meta + CompositeExplicitAutogradNonFunctional keys).
2228
elif g.structured and dispatch_key in (
2230
DispatchKey.CompositeExplicitAutogradNonFunctional,
2232
is_registered = True
2233
if not is_registered:
2236
headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
2239
== DispatchKey.CompositeExplicitAutogradNonFunctional
2241
headers.append(f"#include <ATen/ops/{g.root_name}.h>")
2242
if dispatch_key in functions_keys:
2244
f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
2247
return sorted(set(headers))
2251
def operator_headers() -> list[str]:
2252
headers = ["#include <ATen/NativeFunctions.h>"]
2253
if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
2254
headers.append("#include <ATen/Functions.h>")
2255
if dispatch_key in functions_keys:
2256
headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
2259
backend_index = backend_indices[dispatch_key]
2260
ns_grouped_native_functions = defaultdict(list)
2261
for grouped_native_function in grouped_native_functions:
2263
grouped_native_function.namespace
2264
if isinstance(grouped_native_function, NativeFunction)
2265
else grouped_native_function.functional.namespace
2267
ns_grouped_native_functions[namespace].append(grouped_native_function)
2269
dispatch_namespace = str(dispatch_key).lower()
2271
# CompositeImplicitAutogradNestdTensor does not currently user the helpers generated
2272
# compilation will fail when `-Werror=unused-function` flag is set
2273
gen_dispatch_helpers: bool = (
2274
dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
2277
dispatch_definitions = get_native_function_definitions(
2279
grouped_native_functions=grouped_native_functions,
2280
dispatch_key=dispatch_key,
2281
backend_idx=backend_index,
2285
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
2286
gen_dispatch_helpers=gen_dispatch_helpers,
2288
fm.write_with_template(
2289
f"Register{dispatch_key}.cpp",
2290
"RegisterDispatchKey.cpp",
2292
"extra_cuda_headers": extra_cuda_headers
2293
if is_cuda_dispatch_key(dispatch_key)
2295
"external_backend_headers": "",
2296
"dispatch_headers": dest.gen_registration_headers(
2297
backend_index, per_operator_headers, rocm
2299
"ops_headers": operator_headers(),
2300
"dispatch_helpers": "",
2301
"dispatch_definitions": dispatch_definitions,
2305
for g in structured_native_functions:
2306
if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
2308
name = g.functional.func.name.name
2309
if dispatch_key is DispatchKey.CPU:
2311
fm.write_with_template(
2312
f"UfuncCPU_{name}.cpp",
2315
"meta_declaration": compute_meta_function_declaration(g),
2316
"native_declaration": dest.compute_native_function_declaration(
2317
g, backend_indices[dispatch_key]
2319
"native_definitions": dest.compute_ufunc_cpu(g),
2322
cpu_vec_fm.write_with_template(
2323
f"UfuncCPUKernel_{name}.cpp",
2324
"UfuncCPUKernel.cpp",
2327
"native_definitions": dest.compute_ufunc_cpu_kernel(g),
2330
elif dispatch_key is DispatchKey.CUDA:
2331
cuda_headers = "#include <ATen/native/cuda/Loops.cuh>"
2333
cuda_headers = "#include <ATen/native/hip/Loops.cuh>"
2334
fm.write_with_template(
2335
f"UfuncCUDA_{name}.cu",
2339
"cuda_headers": cuda_headers,
2340
"meta_declaration": compute_meta_function_declaration(g),
2341
"native_declaration": dest.compute_native_function_declaration(
2342
g, backend_indices[dispatch_key]
2344
"native_definitions": dest.compute_ufunc_cuda(g),
2348
raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
2350
structured_func_group_dict = {}
2351
for func_group in structured_native_functions:
2352
for func in func_group.functions():
2353
if func.structured_delegate is not None:
2354
structured_func_group_dict[func.structured_delegate] = func_group
2357
if dispatch_key in (DispatchKey.CPU, DispatchKey.CUDA):
2359
for func in native_functions:
2360
op_name = get_fallback_op_name(func)
2361
if op_name in inductor_fallback_ops:
2362
fallbacks[op_name] = func
2363
fallback_native_functions = tuple(
2364
value for _, value in sorted(fallbacks.items())
2367
# header files were checked in for ABI-compatiblilty checking
2368
header_file_name = f"c_shim_{dispatch_key.lower()}.h"
2369
new_header = gen_aoti_c_shim(
2370
fallback_native_functions,
2371
structured_func_group_dict,
2377
if update_aoti_c_shim:
2385
os.path.join(aoti_fm.install_dir, header_file_name)
2387
old_header = old_file.read()
2389
old_header == new_header
2392
WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This
2393
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
2394
Only in a limited number of situations, this is allowed:
2396
1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
2397
If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to update the existing
2400
2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
2401
change in the AOTInductor land. In this case, you need to keep a manual copy of that existing
2402
fallback op in a file, e.g. torch/csrc/inductor/aoti_torch/c/shim.h, bump up the version
2403
number of that fallback op in the newly generated C shim files, and update the cpp wrapper
2404
codegen to generate the correct cpp call for this op. Contact AOTInductor team for assistance.
2407
except FileNotFoundError:
2409
f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
2412
# cpp files are always generated on-the-fly
2413
def headers_for_aoti() -> str:
2415
for func in fallback_native_functions:
2416
header = get_header_for_aoti(
2417
func, structured_func_group_dict, dispatch_key, backend_indices
2419
if header is not None:
2420
headers.append(header)
2421
return "\n".join(sorted(set(headers)))
2424
extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
2428
f"c_shim_{dispatch_key.lower()}.cpp",
2429
lambda: gen_aoti_c_shim(
2430
fallback_native_functions,
2431
structured_func_group_dict,
2435
includes=headers_for_aoti() + "\n" + extra_headers,
2441
# BackendSelect is generated specially
2442
def gen_backend_select() -> dict[str, list[str]]:
2444
fn for fn in native_functions if needs_backend_select(fn, selector)
2448
f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
2450
"backend_select_method_definitions": list(
2452
ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
2455
"backend_select_function_registrations": list(
2457
ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
2462
cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)
2464
schema_selector = selector
2465
if force_schema_registration:
2466
schema_selector = SelectiveBuilder.get_nop_selector()
2469
aten_schema_registrations,
2470
schema_registrations,
2471
) = get_native_function_schema_registrations(
2472
native_functions=native_functions, schema_selector=schema_selector
2475
"RegisterSchema.cpp",
2477
"aten_schema_registrations": []
2478
if skip_dispatcher_op_registration
2479
else aten_schema_registrations,
2480
"schema_registrations": []
2481
if skip_dispatcher_op_registration
2482
else schema_registrations,
2487
fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
2491
cpu_fm.write_sharded(
2495
env_callable=lambda fn: {
2496
"operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
2500
static_dispatch_backend_indices=static_dispatch_idx,
2505
"static_dispatch_extra_headers": static_dispatch_extra_headers(
2513
"static_dispatch_extra_headers",
2517
cpu_fm.write("Functions.cpp", dict)
2519
core_fm.write("TensorMethods.cpp", dict)
2524
"aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
2528
def functionalization_env_callable(
2529
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
2530
) -> dict[str, list[str]]:
2532
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
2534
if isinstance(g, NativeFunctionsViewGroup):
2535
# view ops always get a functionalization kernel
2537
f"#include <ATen/ops/{g.view.root_name}_native.h>",
2538
f"#include <ATen/ops/{g.view.root_name}_ops.h>",
2540
if g.view_copy is not None:
2542
f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
2543
f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
2546
elif isinstance(g, NativeFunctionsGroup):
2548
f"#include <ATen/ops/{g.functional.root_name}_native.h>",
2549
f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
2550
f"#include <ATen/ops/{g.out.root_name}_native.h>",
2551
f"#include <ATen/ops/{g.out.root_name}_ops.h>",
2553
if g.inplace is not None:
2555
f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
2556
f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
2558
if g.mutable is not None:
2560
f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
2561
f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
2566
f"#include <ATen/ops/{g.root_name}_native.h>",
2567
f"#include <ATen/ops/{g.root_name}_ops.h>",
2571
"ops_headers": gen_op_headers(g),
2572
"func_definitions": gen_functionalization_definition(
2576
"func_registrations": gen_functionalization_registration(
2579
backend_indices[DispatchKey.CompositeImplicitAutograd],
2584
NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
2585
] = list(structured_native_functions) + list(
2586
view_groups # type: ignore[assignment, arg-type, operator]
2588
# Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
2589
# The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
2590
# (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
2591
# (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
2592
# Although this could go away long-term if we add a dedicated dispatch key for decompositions.
2593
structured_map: dict[OperatorName, NativeFunction] = {
2595
for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
2597
view_map: dict[OperatorName, NativeFunction] = {
2598
f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
2600
for f in native_functions:
2601
if f.func.name not in structured_map and f.func.name not in view_map:
2602
all_groups.append(f)
2604
cpu_fm.write_sharded(
2605
"RegisterFunctionalization.cpp",
2608
env_callable=functionalization_env_callable,
2613
"func_registrations",
2614
"func_add_back_views_definitions",
2615
"func_add_back_views_registrations",
2620
"FunctionalInverses.h",
2622
"view_inverse_declarations": list(
2624
lambda g: gen_functionalization_view_inverse_declaration(
2633
# Note [view_copy NativeFunctions]
2634
# Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
2635
# needs to have a corresponding non-aliasing {view}_copy variant.
2636
# Backends that use functionalization and don't know how to handle aliasing ops
2637
# are expected to implement kernels for these {view}_copy kernels instead.
2638
# The code for {view}_copy operators in core is pretty boilerplate-heavy however,
2639
# so we codegen the following:
2640
# (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
2641
# These are never explicitly invoked by the functionalization pass,
2642
# but they could theoretically be called from user code (I added these kernels for completeness,
2643
# since the ops are part of the public API).
2644
# (2) A derivative formula for every {view}_copy operator
2645
# {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts,
2646
# so rather than stamping all of the entries out in derivatives.yaml,
2647
# we codegen them in.
2648
# This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
2650
"CompositeViewCopyKernels.cpp",
2654
f"#include <ATen/ops/{f.root_name}_ops.h>\n"
2655
# NB: this include is important as it ensures we
2656
# set the visibility on generated view_copy kernels
2658
f"#include <ATen/ops/{f.root_name}_native.h>"
2660
[g.view] if g.view_copy is None else [g.view, g.view_copy]
2663
for g in view_groups
2667
f"#include <ATen/ops/{f.root_name}_ops.h>\n"
2668
# NB: this include is also important for correct visibility
2669
f"#include <ATen/ops/{f.root_name}_native.h>"
2670
for f in [g.inplace, g.mutable, g.functional]
2671
if f is not None and "generated" not in f.tags
2673
for g in structured_native_functions
2675
"CompositeViewCopyKernel_Definitions": list(
2677
GenCompositeViewCopyKernel(
2679
DispatchKey.CompositeExplicitAutogradNonFunctional
2685
"GeneratedCompositeFunctional_Definitions": list(
2687
gen_composite_functional_kernel,
2688
structured_native_functions,
2691
"GeneratedCompositeOut_Definitions": list(
2693
gen_composite_out_kernel,
2694
structured_native_functions,
2701
def gen_declarations_yaml(
2702
cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
2705
"Declarations.yaml",
2706
lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]),
2710
def get_torchgen_root() -> Path:
    """
    Return the absolute path of the directory that contains this module.

    If you're depending on torchgen out-of-tree, you can use the root to figure
    out the path to native_functions.yaml
    """
    # Take the parent first and then resolve it, so the directory itself is
    # made absolute and has symlinks resolved.
    module_dir = Path(__file__).parent
    return module_dir.resolve()
2719
parser = argparse.ArgumentParser(description="Generate ATen source files")
2720
parser.add_argument(
2723
help="path to source directory for ATen",
2724
default="aten/src/ATen",
2726
parser.add_argument(
2728
"--output-dependencies",
2729
help="output a list of dependencies into the given file and exit",
2731
parser.add_argument(
2733
action="store_true",
2734
help="run without writing any files (still updates outputs)",
2736
parser.add_argument(
2737
"--per-operator-headers",
2738
action="store_true",
2739
help="generate separate headers per operator in ATen/ops",
2741
parser.add_argument(
2745
help="output directory",
2746
default="build/aten/src/ATen",
2748
parser.add_argument(
2749
"--aoti-install-dir",
2750
"--aoti_install_dir",
2751
help="output directory for AOTInductor shim",
2752
default="torch/csrc/inductor/aoti_torch/generated",
2754
parser.add_argument(
2756
action="store_true",
2757
help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
2759
parser.add_argument(
2761
action="store_true",
2762
help="Generate MPS registration code when set",
2764
# TODO: --op-registration-whitelist will be removed when all call-sites
2765
# for gen.py are moved over to using the operator YAML file for mobile
2767
parser.add_argument(
2768
"--op-registration-whitelist",
2769
"--op_registration_whitelist",
2771
help="filter op registrations by the whitelist (if set); "
2772
"each item is `namespace`::`operator name` without overload name; "
2773
"e.g.: aten::empty aten::conv2d ...",
2775
parser.add_argument(
2776
"--op-selection-yaml-path",
2777
"--op_selection_yaml_path",
2778
help="Provide a path to the operator selection (for custom build) YAML "
2779
"that contains the information about the set of selected operators "
2780
"and their categories (training, ...). Each operator is either a "
2781
"full operator name with overload or just a bare operator name. "
2782
"The operator names also contain the namespace prefix (e.g. aten::)",
2784
parser.add_argument(
2785
"--backend-whitelist",
2786
"--backend_whitelist",
2788
help="filter dispatch backend by the whitelist (if set), "
2789
"e.g.: CPU CUDA QuantizedCPU ...",
2791
parser.add_argument(
2792
"--static-dispatch-backend",
2793
"--static_dispatch_backend",
2795
help="generate static dispatch code for the specific backend (if set)",
2797
parser.add_argument(
2798
"--skip-dispatcher-op-registration",
2799
"--skip_dispatcher_op_registration",
2800
action="store_true",
2801
help="Avoid registering operators into the dispatcher.",
2803
parser.add_argument(
2804
"--force-schema-registration",
2805
"--force_schema_registration",
2806
action="store_true",
2807
help="force it to generate schema-only registrations for all ops, including"
2808
"those that are not listed on --op-registration-whitelist",
2810
parser.add_argument(
2814
choices=["headers", "sources", "declarations_yaml"],
2815
default=["headers", "sources", "declarations_yaml"],
2816
help="Generate only a subset of files",
2818
parser.add_argument(
2819
"--update-aoti-c-shim",
2820
action="store_true",
2821
help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. "
2822
"WARNING: Do not use this unless you are sure what you are doing!!!",
2825
options = parser.parse_args()
2827
selector = get_custom_build_selector(
2828
options.op_registration_whitelist,
2829
options.op_selection_yaml_path,
2832
native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
2833
tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
2835
from torchgen.model import dispatch_keys
2837
# TODO: stop generating CUDA kernels for non-CUDA builds
2840
ignore_keys.add(DispatchKey.MPS)
2842
if DispatchKey.MPS in dispatch_keys:
2843
del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
2845
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
2846
valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
2847
native_functions, backend_indices = (
2848
parsed_yaml.native_functions,
2849
parsed_yaml.backend_indices,
2852
grouped_native_functions = get_grouped_native_functions(native_functions)
2854
structured_native_functions = [
2855
g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
2857
native_functions_with_view_groups = get_grouped_by_view_native_functions(
2862
for g in native_functions_with_view_groups
2863
if isinstance(g, NativeFunctionsViewGroup)
2866
# NB: It is mandatory to NOT use os.path.join here, as the install directory
2867
# will eventually be ingested by cmake, which does not respect Windows style
2868
# path slashes. If you switch this to use os.path.join, you'll get an error
2871
# Syntax error in cmake code when parsing string
2873
# C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
2875
# Invalid character escape '\c'.
2876
core_install_dir = f"{options.install_dir}/core"
2877
Path(core_install_dir).mkdir(parents=True, exist_ok=True)
2878
ops_install_dir = f"{options.install_dir}/ops"
2879
Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
2880
aoti_install_dir = f"{options.aoti_install_dir}"
2881
Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
2883
core_fm = make_file_manager(options=options, install_dir=core_install_dir)
2884
cpu_fm = make_file_manager(options=options)
2885
cpu_vec_fm = make_file_manager(options=options)
2886
cuda_fm = make_file_manager(options=options)
2887
ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
2888
aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)
2890
# Only a limited set of dispatch keys get CPUFunctions.h headers generated
2891
# for them; this is the set
2895
DispatchKey.CompositeImplicitAutograd,
2896
DispatchKey.CompositeImplicitAutogradNestedTensor,
2897
DispatchKey.CompositeExplicitAutograd,
2898
DispatchKey.CompositeExplicitAutogradNonFunctional,
2902
functions_keys.add(DispatchKey.MPS)
2904
if options.backend_whitelist:
2907
for k in dispatch_keys
2908
if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
2911
static_dispatch_idx: list[BackendIndex] = []
2912
if options.static_dispatch_backend:
2913
static_dispatch_idx = [
2914
backend_indices[DispatchKey.parse(key)]
2915
for key in options.static_dispatch_backend
2917
for key in options.static_dispatch_backend:
2918
dp_key = DispatchKey.parse(key)
2919
if dp_key not in functions_keys:
2920
functions_keys.add(dp_key)
2922
if "sources" in options.generate:
2924
native_functions=native_functions,
2925
grouped_native_functions=grouped_native_functions,
2926
structured_native_functions=structured_native_functions,
2927
view_groups=view_groups,
2929
static_dispatch_idx=static_dispatch_idx,
2930
backend_indices=backend_indices,
2934
cpu_vec_fm=cpu_vec_fm,
2936
dispatch_keys=dispatch_keys,
2937
functions_keys=functions_keys,
2939
force_schema_registration=options.force_schema_registration,
2940
per_operator_headers=options.per_operator_headers,
2941
skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
2942
update_aoti_c_shim=options.update_aoti_c_shim,
2945
if "headers" in options.generate:
2947
native_functions=native_functions,
2948
valid_tags=valid_tags,
2949
grouped_native_functions=grouped_native_functions,
2950
structured_native_functions=structured_native_functions,
2951
static_dispatch_idx=static_dispatch_idx,
2953
backend_indices=backend_indices,
2958
dispatch_keys=dispatch_keys,
2959
functions_keys=functions_keys,
2961
per_operator_headers=options.per_operator_headers,
2964
if "declarations_yaml" in options.generate:
2965
gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)
2967
if options.output_dependencies:
2968
depfile_path = Path(options.output_dependencies).resolve()
2969
depfile_name = depfile_path.name
2970
depfile_stem = depfile_path.stem
2974
(cpu_vec_fm, "cpu_vec_"),
2979
varname = prefix + depfile_stem
2980
path = depfile_path.parent / (prefix + depfile_name)
2981
fm.write_outputs(varname, str(path))
2984
if __name__ == "__main__":