from __future__ import annotations

import textwrap
from dataclasses import dataclass
from typing import Sequence

from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
from torchgen.context import method_with_native_function
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    DispatchKey,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
    OptionalType,
    Type,
)
from torchgen.utils import mapMaybe
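
# This module generates AOTInductor's C shim layer: per-operator C entry points
# named aoti_torch_<device>_<op> that wrap ATen kernels behind a plain-C ABI,
# so AOTInductor-compiled code can call fallback ops without depending on
# unstable C++ interfaces.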

base_type_to_c_type = {
    BaseTy.Tensor: "AtenTensorHandle",
    BaseTy.bool: "int32_t",  # Use int to pass bool
    BaseTy.int: "int64_t",
    BaseTy.SymInt: "int64_t",  # Inductor-generated code won't see a SymInt
    BaseTy.Scalar: "double",  # Use double to pass both integer and floating-point values
    BaseTy.float: "double",  # TODO: floating-point arguments are converted to double
    BaseTy.str: "const char*",
    BaseTy.DeviceIndex: "int32_t",
    BaseTy.Layout: "int32_t",  # Represent enum as int
    BaseTy.MemoryFormat: "int32_t",  # Represent enum as int
    BaseTy.ScalarType: "int32_t",  # Represent enum as int
    BaseTy.Generator: "AtenGeneratorHandle",
}
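
# ATen C++ types that the C arguments are converted to before calling into the
# backend; keys mirror base_type_to_c_type above.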
base_type_to_aten_type = {
    BaseTy.Tensor: "at::Tensor",
    BaseTy.bool: "bool",
    BaseTy.int: "int64_t",
    BaseTy.SymInt: "c10::SymInt",
    BaseTy.Scalar: "c10::Scalar",
    BaseTy.float: "double",
    BaseTy.str: "c10::string_view",
    BaseTy.DeviceIndex: "c10::DeviceIndex",
    BaseTy.Layout: "c10::Layout",
    BaseTy.MemoryFormat: "c10::MemoryFormat",
    BaseTy.ScalarType: "c10::ScalarType",
    BaseTy.Generator: "at::Generator",
}

# Conversion or cast applied to each argument at the callsite; an empty string
# means the value is passed through unchanged.
base_type_to_callsite_expr = {
    BaseTy.Tensor: "*tensor_handle_to_tensor_pointer",
    BaseTy.bool: "",
    BaseTy.int: "",
    BaseTy.SymInt: "",
    BaseTy.Scalar: "",
    BaseTy.float: "",
    BaseTy.str: "",
    BaseTy.DeviceIndex: "static_cast<c10::DeviceIndex>",
    BaseTy.Layout: "static_cast<c10::Layout>",
    BaseTy.MemoryFormat: "static_cast<c10::MemoryFormat>",
    BaseTy.ScalarType: "static_cast<c10::ScalarType>",
    BaseTy.Generator: "*generator_handle_to_generator_pointer",
}
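
# Convert an argument's schema type and name into four parallel lists: C shim
# types, C argument names, ATen C++ types, and callsite expressions used in the
# shim body. Optionals are passed via an extra pointer level (nullptr encodes
# nullopt), lists are flattened into (pointer, length) pairs, and Device becomes
# a (device_type, device_index) pair of ints.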
def convert_arg_type_and_name(
    typ: Type, name: str
) -> tuple[list[str], list[str], list[str], list[str]]:
    if isinstance(typ, BaseType):
        if typ.name in base_type_to_c_type:
            return (
                [base_type_to_c_type[typ.name]],
                [name],
                [base_type_to_aten_type[typ.name]],
                [
                    f"{base_type_to_callsite_expr[typ.name]}({name})"
                    if base_type_to_callsite_expr[typ.name]
                    else name
                ],
            )
        elif typ.name == BaseTy.Device:
            return (
                ["int32_t", "int32_t"],
                [name, name + "_index_"],
                ["c10::Device"],
                [
                    f"c10::Device(static_cast<c10::DeviceType>({name}), static_cast<c10::DeviceIndex>({name}_index_))"
                ],
            )
        else:
            raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
            typ.elem, name
        )
        j = 0  # index into names, which can hold more entries than aten_types
        new_aten_types = []
        new_callsite_exprs = []
        for aten_type in aten_types:
            # Use a pointer to denote an optional type
            c_types[j] = c_types[j] + "*"
            if aten_type.startswith("c10::ArrayRef<"):
                # An optional ArrayRef is passed as pointer + size; the size
                # argument itself does not become a pointer
                new_aten_types.append(f"::std::optional<{aten_type}>")
                base_type = aten_type[len("c10::ArrayRef<") : -1]
                new_callsite_exprs.append(
                    f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
                )
                j += 2
            elif aten_type == "c10::Device":
                # Device is passed as a pair of (device_type, device_index) args
                new_aten_types.append("::std::optional<c10::Device>")
                new_callsite_exprs.append(
                    f"pointer_to_optional_device({names[j]}, {names[j+1]})"
                )
                j += 2
            else:
                new_aten_types.append(f"::std::optional<{aten_type}>")
                new_callsite_exprs.append(
                    f"pointer_to_optional<{aten_type}>({names[j]})"
                )
                j += 1

        return (
            c_types,
            names,
            new_aten_types,
            new_callsite_exprs,
        )
    elif isinstance(typ, ListType):
        # A list is passed explicitly as pointer + length
        c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
        assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)

        # The list content should never be modified
        c_types[0] = f"const {c_types[0]}*"
        c_types.append("int64_t")
        name = names[0]
        names.append(name + "_len_")

        atype = aten_types[0]
        callsite_exprs = []
        if atype == "bool":
            # no converter from std::vector<bool> to c10::ArrayRef<bool>;
            # construct a std::array<bool, N> instead
            assert typ.size is not None
            callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
        elif atype == "::std::optional<at::Tensor>":
            # convert from std::vector<::std::optional<at::Tensor>> to
            # c10::List<::std::optional<at::Tensor>>
            callsite_exprs.append(
                f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
            )
        else:
            callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")

        aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
        return (
            c_types,
            names,
            aten_types,
            callsite_exprs,
        )

    raise NotImplementedError(f"Argument type {repr(typ)} not supported!")
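
# Illustration (not emitted by the generator): a "Tensor?[] indices" argument
# flattens to the C parameters "const AtenTensorHandle** indices, int64_t indices_len_",
# with ATen type c10::ArrayRef<::std::optional<at::Tensor>> and a callsite
# expression built from pointer_to_list<::std::optional<at::Tensor>>.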


def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
    return [typ + " " + name for typ, name in zip(types, names)]


# Generate the argument declarations and callsite expressions for a flat
# argument list
def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
    types = []
    new_names = []
    callsite_exprs = []
    for arg in flat_arguments:
        new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
            arg.type, arg.name
        )
        types.extend(new_types)
        new_names.extend(names)
        callsite_exprs.extend(new_callsite_exprs)
    return zip_type_and_name(types, new_names), callsite_exprs


# Return values are passed back through out-pointer arguments appended after
# the regular arguments
def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
    types = []
    names = []
    for idx, ret in enumerate(schema.returns):
        names.append(f"ret{idx}")
        if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
            types.append(base_type_to_c_type[ret.type.name] + "*")
        else:
            raise NotImplementedError(
                f"TODO: add support for return type {repr(ret.type)}"
            )

    def convert_return(typ: BaseType, val: str) -> str:
        if typ.name == BaseTy.Tensor:
            return f"new_tensor_handle(std::move({val}));"
        elif typ.name == BaseTy.SymInt:
            return f"{val}.expect_int()"
        elif typ.name == BaseTy.Scalar:
            return f"{val}.toDouble()"
        else:
            return val

    # These ops may return tensors the caller does not want, so their ret
    # pointers must be null-checked before assignment
    ret_pointer_can_be_null = False
    unambiguous_name = schema.name.unambiguous_name()
    for name in [
        "_scaled_dot_product_flash_attention",
        "_scaled_dot_product_efficient_attention",
        "_scaled_dot_product_cudnn_attention",
        "convolution_backward",
    ]:
        if name in unambiguous_name:
            ret_pointer_can_be_null = True
            break

    callsite_exprs: list[str] = []
    for idx, ret in enumerate(schema.returns):
        tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
        assert isinstance(ret.type, BaseType)
        rval = convert_return(ret.type, tmp)
        if ret_pointer_can_be_null:
            callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}")
        else:
            callsite_exprs.append(f"*{names[idx]} = {rval};")

    return zip_type_and_name(types, names), callsite_exprs
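
# Illustration (not emitted by the generator): a single Tensor return becomes an
# "AtenTensorHandle* ret0" parameter, assigned in the shim body via
# *ret0 = new_tensor_handle(std::move(tmp_result)).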


# gen.py generates the header first and then the source, so cache the result
# here to avoid duplicating work
declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}


def gen_declaration_and_definition(
    schema: FunctionSchema, device: str, backend_call: str
) -> tuple[str, str]:
    func_name = schema.name.unambiguous_name()

    global declaration_definition_cache
    if (func_name, device, backend_call) in declaration_definition_cache:
        return declaration_definition_cache[(func_name, device, backend_call)]

    if schema.is_out_fn():
        # An out variant has its out arguments in the front, and it is ok to
        # ignore returned values because C shim functions only return AOTITorchError
        args, callsite_exprs = gen_arguments(
            [*schema.arguments.out, *schema.arguments.flat_non_out]
        )
        ret_assignments: list[str] = []
    else:
        args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
        # ignore return values for inplace ops
        ret_declarations, ret_assignments = (
            ([], []) if schema.name.name.inplace else gen_returns(schema)
        )
        args.extend(ret_declarations)

    declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"

    tmp_result = "auto tmp_result = " if ret_assignments else ""
    ret_assignments_str = "\n" + "\n".join(ret_assignments) if ret_assignments else ""
    definition = f"""
{declaration} {{
    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
        {tmp_result}{backend_call}(
{textwrap.indent(', '.join(callsite_exprs), "            ")}
        );{textwrap.indent(ret_assignments_str, "        ")}
    }});
}}
"""

    declaration_definition_cache[(func_name, device, backend_call)] = (
        declaration,
        definition,
    )
    return declaration, definition
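
# Illustration (approximate; exact text depends on the op's schema): for
# "fill_.Scalar(Tensor(a!) self, Scalar value)" on CPU this yields the declaration
# "AOTITorchError aoti_torch_cpu_fill__Scalar(AtenTensorHandle self, double value)"
# and a definition that wraps the at::cpu::fill_ call in
# AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE.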


def gen_static_dispatch_backend_call_signature(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
) -> CppSignature:
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    assert cpp_sig is not None
    return cpp_sig


def gen_static_dispatch_backend_call(
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    sig = DispatcherSignature.from_schema(f.func)
    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
    return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"


def get_backend_index_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> BackendIndex | None:
    backend_index = None
    if backend_indices[dispatch_key].has_kernel(func) or (
        func.structured_delegate is not None
        and func.structured_delegate in func_group_mapping
        and backend_indices[dispatch_key].has_kernel(
            func_group_mapping[func.structured_delegate]
        )
    ):
        backend_index = backend_indices[dispatch_key]
    # Otherwise fall back to composite kernels, from most to least specific
    elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
        backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
    elif backend_indices[DispatchKey.CompositeExplicitAutogradNonFunctional].has_kernel(
        func
    ):
        backend_index = backend_indices[
            DispatchKey.CompositeExplicitAutogradNonFunctional
        ]
    elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
        backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]

    return backend_index
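
# Illustration: for func.root_name "add" under DispatchKey.CPU, the function
# below returns "#include <ATen/ops/add_cpu_dispatch.h>".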
def get_header_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> str | None:
    backend_index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    return (
        None
        if backend_index is None
        else f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
    )


def get_fallback_op_name(func: NativeFunction) -> str:
    return (
        f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}"
        if func.func.name.overload_name
        else f"{func.namespace}.{func.func.name.name}.default"
    )


def gen_c_shim(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
) -> str | None:
    backend_index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    if backend_index is None:
        return None

    schema = func.func
    device = dispatch_key.lower()
    backend_call = gen_static_dispatch_backend_call(
        func,
        backend_index,
    )

    try:
        if header:
            declaration, _ = gen_declaration_and_definition(
                schema, device, backend_call
            )
            return f"AOTI_TORCH_EXPORT {declaration};"
        else:
            _, definition = gen_declaration_and_definition(schema, device, backend_call)
            return definition
    except NotImplementedError:
        # Skip ops whose argument or return types are not yet supported
        return None


@dataclass(frozen=True)
class ShimGenerator:
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
    dispatch_key: DispatchKey
    backend_indices: dict[DispatchKey, BackendIndex]
    header: bool  # True to generate .h and False to generate .cpp

    @method_with_native_function
    def __call__(
        self,
        func: NativeFunction,
    ) -> str | None:
        result = gen_c_shim(
            func,
            self.func_group_mapping,
            self.dispatch_key,
            self.backend_indices,
            self.header,
        )
        return result


def gen_aoti_c_shim(
    native_functions: Sequence[NativeFunction],
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
    includes: str = "",
) -> str:
    body = "\n".join(
        list(
            mapMaybe(
                ShimGenerator(
                    func_group_mapping, dispatch_key, backend_indices, header
                ),
                native_functions,
            )
        )
    )
    device = dispatch_key.lower()
    warning = """
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""

    if header:
        return f"""{warning}

#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#ifdef __cplusplus
extern "C" {{
#endif

{body}

#ifdef __cplusplus
}} // extern "C"
#endif
"""
    else:
        return f"""{warning}

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/{str(dispatch_key)}Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
{includes}
#endif

using namespace torch::aot_inductor;

{body}"""
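
# Note: gen.py invokes gen_aoti_c_shim() twice per dispatch key, first with
# header=True for c_shim_<device>.h and then with header=False for
# c_shim_<device>.cpp, which is why gen_declaration_and_definition() caches
# its per-op results.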