pytorch

Форк
0
/
gen_aoti_c_shim.py 
486 строк · 16.2 Кб
1
from __future__ import annotations
2

3
import textwrap
4
from dataclasses import dataclass
5
from typing import Sequence
6

7
from torchgen.api.types import DispatcherSignature
8
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
9
from torchgen.context import method_with_native_function
10
from torchgen.model import (
11
    Argument,
12
    BackendIndex,
13
    BaseTy,
14
    BaseType,
15
    DispatchKey,
16
    FunctionSchema,
17
    ListType,
18
    NativeFunction,
19
    NativeFunctionsGroup,
20
    OperatorName,
21
    OptionalType,
22
    Type,
23
)
24
from torchgen.utils import mapMaybe
25

26

27
# Per-BaseTy lowering table for C shim generation. Each row gives, in order:
#   * the C type used for the value in shim signatures,
#   * the ATen C++ type the value represents,
#   * the conversion applied to the C value at the ATen call site
#     ("" means the C value is passed through unchanged).
base_type_to_c_type: dict[BaseTy, str] = {}
base_type_to_aten_type: dict[BaseTy, str] = {}
base_type_to_callsite_expr: dict[BaseTy, str] = {}
for _base_ty, _c_type, _aten_type, _callsite_expr in (
    (BaseTy.Tensor, "AtenTensorHandle", "at::Tensor", "*tensor_handle_to_tensor_pointer"),
    # Use int to pass bool
    (BaseTy.bool, "int32_t", "bool", ""),
    (BaseTy.int, "int64_t", "int64_t", ""),
    # Inductor-generated code won't see a SymInt
    (BaseTy.SymInt, "int64_t", "c10::SymInt", ""),
    # Use double to pass both integer and floating point
    (BaseTy.Scalar, "double", "c10::Scalar", ""),
    # TODO: how about other floating point types?
    (BaseTy.float, "double", "double", ""),
    (BaseTy.str, "const char*", "c10::string_view", ""),
    (BaseTy.DeviceIndex, "int32_t", "c10::DeviceIndex", "static_cast<c10::DeviceIndex>"),
    # The enums below are represented as int on the C side
    (BaseTy.Layout, "int32_t", "c10::Layout", "static_cast<c10::Layout>"),
    (BaseTy.MemoryFormat, "int32_t", "c10::MemoryFormat", "static_cast<c10::MemoryFormat>"),
    (BaseTy.ScalarType, "int32_t", "c10::ScalarType", "static_cast<c10::ScalarType>"),
    (BaseTy.Generator, "AtenGeneratorHandle", "at::Generator", "*generator_handle_to_generator_pointer"),
):
    base_type_to_c_type[_base_ty] = _c_type
    base_type_to_aten_type[_base_ty] = _aten_type
    base_type_to_callsite_expr[_base_ty] = _callsite_expr
# Keep the module namespace identical to the plain-dict-literal form.
del _base_ty, _c_type, _aten_type, _callsite_expr
71

72

73
# convert args to C types, names in declarations, and expressions in function bodies
74
def convert_arg_type_and_name(
    typ: Type, name: str
) -> tuple[list[str], list[str], list[str], list[str]]:
    """Lower one schema argument to its C shim representation.

    Args:
        typ: torchgen schema type of the argument.
        name: argument name as it appears in the schema.

    Returns:
        ``(c_types, names, aten_types, callsite_exprs)``:

        * ``c_types`` / ``names`` -- C parameter types and names for the shim
          declaration (always the same length as each other);
        * ``aten_types`` -- the ATen C++ type(s) the argument maps to;
        * ``callsite_exprs`` -- C++ expressions that rebuild the ATen value
          from the C parameters at the backend call site.

        One schema argument may expand to several C parameters (a ``Device``
        becomes type + index, a list becomes pointer + length), so
        ``c_types``/``names`` can be longer than ``aten_types``/``callsite_exprs``.

    Raises:
        NotImplementedError: the type (or a nested element type) has no C shim
            lowering yet. ``gen_c_shim`` catches this and skips the operator.
        AssertionError: a list element type lowers to more than one C parameter.
    """
    if isinstance(typ, BaseType):
        if typ.name in base_type_to_c_type:
            callsite_fn = base_type_to_callsite_expr[typ.name]
            return (
                [base_type_to_c_type[typ.name]],
                [name],
                [base_type_to_aten_type[typ.name]],
                # An empty conversion means the C value is used as-is.
                [f"{callsite_fn}({name})" if callsite_fn else name],
            )
        elif typ.name == BaseTy.Device:
            # Device is passed as two C parameters: device type and device index.
            return (
                ["int32_t", "int32_t"],
                [name, name + "_index_"],
                ["c10::Device"],
                [
                    f"c10::Device(static_cast<c10::DeviceType>({name}), static_cast<c10::DeviceIndex>({name}_index_))"
                ],
            )
        else:
            # TODO: BaseTy.Dimname, etc.
            raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
        # Index into c_types/names; one ATen value may occupy 1 or 2 C params.
        j = 0
        new_aten_types: list[str] = []
        new_callsite_exprs: list[str] = []
        for aten_type in aten_types:
            # Use pointer to denote optional type
            c_types[j] = c_types[j] + "*"
            if aten_type.startswith("c10::ArrayRef<"):
                # ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
                new_aten_types.append(f"::std::optional<{aten_type}>")
                base_type = aten_type[len("c10::ArrayRef<") : -1]
                new_callsite_exprs.append(
                    f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
                )
                j += 2
            elif aten_type == "c10::Device":
                # Device is passed as device_type + device_index
                new_aten_types.append("::std::optional<c10::Device>")
                new_callsite_exprs.append(
                    f"pointer_to_optional_device({names[j]}, {names[j+1]})"
                )
                j += 2
            else:
                new_aten_types.append(f"::std::optional<{aten_type}>")
                new_callsite_exprs.append(
                    f"pointer_to_optional<{aten_type}>({names[j]})"
                )
                j += 1

        return (
            c_types,
            names,
            new_aten_types,
            new_callsite_exprs,
        )
    elif isinstance(typ, ListType):
        # Need to explicitly pass the list as pointer + length
        c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
        assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)

        # The list content should never be modified
        c_types[0] = f"const {c_types[0]}*"
        c_types.append("int64_t")
        name = names[0]
        names.append(name + "_len_")

        atype = aten_types[0]
        callsite_exprs: list[str] = []
        if atype == "bool":
            # no converter from std::vector<bool> to c10::ArrayRef<bool>
            # construct std::array<bool, N> instead
            assert typ.size is not None
            callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
        elif atype == "::std::optional<at::Tensor>":
            # convert from std::vector<::std::optional<at::Tensor>> to c10::List<::std::optional<at::Tensor>>
            callsite_exprs.append(
                f"c10::List<{atype}>(c10::ArrayRef<{atype}>(pointer_to_list<{atype}>({name}, {name}_len_)))"
            )
        else:
            callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")

        aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
        return (
            c_types,
            names,
            aten_types,
            callsite_exprs,
        )

    # Any other Type subclass has no C lowering. Raise the same error as the
    # other unsupported cases so gen_c_shim skips the operator, instead of
    # implicitly returning None and crashing the caller's tuple unpacking.
    raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
170

171

172
def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
    """Pair each C type with its parameter name, e.g. ``"int64_t self"``.

    Follows ``zip`` semantics: surplus items in the longer list are dropped.
    """
    return list(map(" ".join, zip(types, names)))
174

175

176
# Generate argument declarations and callsite expressions
177
def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
    """Lower schema arguments to C shim parameters and call-site expressions.

    Returns ``(declarations, callsite_exprs)``: ``declarations`` are the
    ``"<c_type> <name>"`` parameter strings (one argument may contribute
    several), and ``callsite_exprs`` are the C++ expressions passed to the
    backend call, in argument order.
    """
    all_c_types: list[str] = []
    all_param_names: list[str] = []
    all_exprs: list[str] = []
    for argument in flat_arguments:
        c_types, param_names, _aten_types, exprs = convert_arg_type_and_name(
            argument.type, argument.name
        )
        all_c_types += c_types
        all_param_names += param_names
        all_exprs += exprs
    return zip_type_and_name(all_c_types, all_param_names), all_exprs
189

190

191
# Return values are passed out as pointer arguments because all the C shim functions
192
# are expected to return AOTITorchError.
193
# Generate returns as declarations and callsite expressions
194
def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
    """Lower ``schema``'s return values to C out-pointer parameters.

    Returns:
        ``(declarations, assignments)``: ``declarations`` are the extra
        ``"<c_type>* retN"`` shim parameters, and ``assignments`` are C++
        statements that store each result held in ``tmp_result`` through the
        matching out pointer.

    Raises:
        NotImplementedError: a return type has no entry in ``base_type_to_c_type``.
    """
    types: list[str] = []
    for ret in schema.returns:
        if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
            types.append(base_type_to_c_type[ret.type.name] + "*")
        else:
            raise NotImplementedError(
                f"TODO: add support for return type {repr(ret.type)}"
            )
    names = [f"ret{idx}" for idx in range(len(types))]

    def convert_return(typ: BaseType, val: str) -> str:
        # C++ expression converting the ATen result `val` into the value stored
        # through the out pointer. It must not end in ";" because the caller
        # appends one; the old Tensor case did, producing ";;" in every shim.
        if typ.name == BaseTy.Tensor:
            return f"new_tensor_handle(std::move({val}))"
        elif typ.name == BaseTy.SymInt:
            return f"{val}.expect_int()"
        elif typ.name == BaseTy.Scalar:
            return f"{val}.toDouble()"
        else:
            return val

    # For these ops, each out pointer is null-checked before it is written.
    # Matching is by substring, so any op whose unambiguous name contains one
    # of these qualifies. Presumably callers may pass NULL for unused outputs.
    unambiguous_name = schema.name.unambiguous_name()
    ret_pointer_can_be_null = any(
        op_name in unambiguous_name
        for op_name in (
            "_scaled_dot_product_flash_attention",
            "_scaled_dot_product_efficient_attention",
            "_scaled_dot_product_cudnn_attention",
            "convolution_backward",
        )
    )

    # A single return is the result itself; multiple returns come back as a
    # tuple that is unpacked with std::get.
    single_return = len(names) == 1
    callsite_exprs: list[str] = []
    for idx, (ret_name, ret) in enumerate(zip(names, schema.returns)):
        tmp = "tmp_result" if single_return else f"std::get<{idx}>(tmp_result)"
        assert isinstance(ret.type, BaseType)
        assignment = f"*{ret_name} = {convert_return(ret.type, tmp)};"
        if ret_pointer_can_be_null:
            callsite_exprs.append(f"if ({ret_name}) {{ {assignment} }}")
        else:
            callsite_exprs.append(assignment)

    return zip_type_and_name(types, names), callsite_exprs
239

240

241
# Memo of (declaration, definition) pairs keyed by
# (unambiguous op name, device, backend call). gen.py generates the header
# first and then the source, so each shim is built once and reused.
declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = dict()
243

244

245
def gen_declaration_and_definition(
    schema: FunctionSchema, device: str, backend_call: str
) -> tuple[str, str]:
    """Build the C shim declaration and definition for ``schema`` on ``device``.

    The shim is named ``aoti_torch_<device>_<unambiguous op name>``, returns
    ``AOTITorchError``, and wraps a call to ``backend_call`` inside
    ``AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE``.

    Argument handling depends on the op kind:
      * Out variants take their out arguments first and ignore return values.
      * Inplace ops ignore return values.
      * All other ops get one ``retN`` out-pointer parameter per return.

    Results are memoized in ``declaration_definition_cache``.

    Raises:
        NotImplementedError: propagated from the argument/return lowering
            helpers for unsupported types.
    """
    func_name = schema.name.unambiguous_name()
    cache_key = (func_name, device, backend_call)
    cached = declaration_definition_cache.get(cache_key)
    if cached is not None:
        return cached

    ret_assignments: list[str] = []
    if schema.is_out_fn():
        # Out variants have their out arguments in front; return values are
        # ignored since C shim functions only return AOTITorchError.
        args, callsite_exprs = gen_arguments(
            [*schema.arguments.out, *schema.arguments.flat_non_out]
        )
    else:
        args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
        # Inplace ops ignore their return values.
        if not schema.name.name.inplace:
            ret_declarations, ret_assignments = gen_returns(schema)
            args.extend(ret_declarations)

    declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"

    # Only bind the backend call's result when something needs to be stored.
    if ret_assignments:
        result_binding = "auto tmp_result = "
        ret_block = textwrap.indent("\n" + "\n".join(ret_assignments), "        ")
    else:
        result_binding = ""
        ret_block = ""
    args_block = textwrap.indent(", ".join(callsite_exprs), "            ")

    definition = f"""
{declaration} {{
    AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
        {result_binding}{backend_call}(
{args_block}
        );{ret_block}
    }});
}}
"""
    declaration_definition_cache[cache_key] = (declaration, definition)
    return declaration, definition
287

288

289
def gen_static_dispatch_backend_call_signature(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
) -> CppSignature:
    """Choose the C++ signature (SymInt or plain) whose name the shim calls.

    The SymInt signature is used when the dispatcher signature is SymInt-aware
    and ``f``'s schema actually contains SymInts; otherwise the plain one is.

    NOTE(review): the ``sig`` argument is never consulted. A fresh
    ``DispatcherSignature`` is always built from ``f.func``, and its ``symint``
    flag decides. This is kept as-is to preserve current behavior; confirm
    whether callers expect ``sig`` to be honored.
    """
    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    wants_symint = dispatcher_sig.symint and f.func.has_symint()
    chosen = sig_group.symint_signature if wants_symint else sig_group.signature
    assert chosen is not None
    return chosen
303

304

305
def gen_static_dispatch_backend_call(
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    """Return the qualified ATen function the shim for ``f`` calls.

    The result has the form ``at::<lowercased dispatch key>::<C++ name>``.
    """
    cpp_sig = gen_static_dispatch_backend_call_signature(
        DispatcherSignature.from_schema(f.func), f
    )
    namespace = backend_index.dispatch_key.lower()
    return f"at::{namespace}::{cpp_sig.name()}"
312

313

314
def get_backend_index_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> BackendIndex | None:
    """Pick the backend whose kernel the C shim for ``func`` should call.

    Preference order:
      1. The requested ``dispatch_key``, if it has a kernel for ``func``
         directly or for ``func``'s structured-delegate group.
      2. CompositeExplicitAutograd.
      3. CompositeExplicitAutogradNonFunctional.
      4. CompositeImplicitAutograd.

    The composite kernels also get C shim wrappers. Returns None when none of
    these backends has a kernel.
    """
    requested = backend_indices[dispatch_key]
    delegate = func.structured_delegate
    if requested.has_kernel(func) or (
        delegate is not None
        and delegate in func_group_mapping
        and requested.has_kernel(func_group_mapping[delegate])
    ):
        return requested

    for fallback_key in (
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.CompositeImplicitAutograd,
    ):
        if backend_indices[fallback_key].has_kernel(func):
            return backend_indices[fallback_key]
    return None
343

344

345
def get_header_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> str | None:
    """Return the per-operator dispatch ``#include`` line for ``func``.

    The header belongs to whichever backend ``get_backend_index_for_aoti``
    selects; returns None when no backend qualifies.
    """
    chosen = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    if chosen is None:
        return None
    backend_dir = chosen.dispatch_key.lower()
    return f"#include <ATen/ops/{func.root_name}_{backend_dir}_dispatch.h>"
359

360

361
def get_fallback_op_name(func: NativeFunction) -> str:
    """Return ``"<namespace>.<op name>.<overload>"`` for ``func``.

    An empty overload name is spelled ``default``.
    """
    op_name = func.func.name
    overload = op_name.overload_name or "default"
    return f"{func.namespace}.{op_name.name}.{overload}"
367

368

369
def gen_c_shim(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
) -> str | None:
    """Generate the C shim text for ``func``.

    Returns the ``AOTI_TORCH_EXPORT`` declaration when ``header`` is True, and
    the full definition otherwise.

    Returns None when no backend has a kernel for ``func``, or when one of its
    argument/return types is unsupported (``NotImplementedError`` from the
    declaration/definition builder).
    """
    chosen = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
    if chosen is None:
        return None

    backend_call = gen_static_dispatch_backend_call(func, chosen)
    try:
        declaration, definition = gen_declaration_and_definition(
            func.func, dispatch_key.lower(), backend_call
        )
    except NotImplementedError:
        return None
    return f"AOTI_TORCH_EXPORT {declaration};" if header else definition
401

402

403
@dataclass(frozen=True)
class ShimGenerator:
    # Callable bundling the per-run settings of gen_c_shim so it can be
    # mapped over native functions (see gen_aoti_c_shim's use of mapMaybe).
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
    dispatch_key: DispatchKey
    backend_indices: dict[DispatchKey, BackendIndex]
    # True to generate .h declarations, False to generate .cpp definitions.
    header: bool

    @method_with_native_function
    def __call__(self, func: NativeFunction) -> str | None:
        """Return the shim text for ``func``, or None if it is skipped."""
        return gen_c_shim(
            func,
            func_group_mapping=self.func_group_mapping,
            dispatch_key=self.dispatch_key,
            backend_indices=self.backend_indices,
            header=self.header,
        )
423

424

425
def gen_aoti_c_shim(
    native_functions: Sequence[NativeFunction],
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
    includes: str = "",
) -> str:
    """Render the whole generated C shim file for ``dispatch_key``.

    Args:
        native_functions: operators to emit shims for; ops for which
            ``ShimGenerator`` returns None are omitted.
        func_group_mapping: operator name -> structured group, used to
            resolve structured delegates.
        dispatch_key: backend the shims are generated for.
        backend_indices: per-dispatch-key kernel registries.
        header: True renders the ``.h`` file, with declarations wrapped in
            ``extern "C"`` under C++. False renders the ``.cpp`` file.
        includes: per-operator include lines, emitted in the ``.cpp`` only for
            the ``AT_PER_OPERATOR_HEADERS`` build configuration.
    """
    shim_generator = ShimGenerator(
        func_group_mapping, dispatch_key, backend_indices, header
    )
    # mapMaybe presumably drops the None results of skipped operators.
    body = "\n".join(mapMaybe(shim_generator, native_functions))
    device = dispatch_key.lower()

    warning = """
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""

    if not header:
        return f"""
{warning}

#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/{str(dispatch_key)}Functions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
{includes}
#endif

using namespace torch::aot_inductor;

{body}"""

    return f"""
{warning}

#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#ifdef __cplusplus
extern "C" {{
#endif

{body}

#ifdef __cplusplus
}} // extern "C"
#endif
"""
487

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.