pytorch

Форк
0
/
native.py 
153 строки · 5.0 Кб
1
from typing import List, Optional, Sequence, Union
2

3
from torchgen import local
4
from torchgen.api import cpp
5

6
from torchgen.api.types import (
7
    ArgName,
8
    BaseCType,
9
    Binding,
10
    boolT,
11
    ConstRefCType,
12
    CType,
13
    deviceT,
14
    layoutT,
15
    ListCType,
16
    MutRefCType,
17
    NamedCType,
18
    OptionalCType,
19
    scalarT,
20
    scalarTypeT,
21
    tensorT,
22
)
23
from torchgen.model import (
24
    Argument,
25
    FunctionSchema,
26
    Return,
27
    SelfArgument,
28
    TensorOptionsArguments,
29
    Type,
30
)
31
from torchgen.utils import assert_never
32

33
# This file describes the translation of JIT schema to the native functions API.
34
# This looks a lot like the C++ API (which makes historical sense, because the
35
# idea was you wrote native functions to implement functions in the C++ API),
36
# but over time we have evolved the C++ API without actually changing our
37
# native:: kernels.  The intention is to make native API and dispatcher API
38
# line up as closely as possible, since this results in the least overhead
39
# (no translation is needed from dispatcher API to native API).
40
#
41
# NB: this is symint aware, you will get the non-SymInt variant for some
42
# dispatch entries and SymInt for others.
43

44

45
def name(func: FunctionSchema) -> str:
    """Return the identifier of the native:: kernel that implements ``func``.

    The identifier is the operator's base name, followed by ``_out`` when
    ``func`` is an out= variant, followed by ``_<overload_name>`` when the
    schema carries a non-empty overload name.
    """
    parts = [str(func.name.name)]
    # TODO: delete this!
    if func.is_out_fn():
        parts.append("_out")
    overload = func.name.overload_name
    if overload:
        parts.append(f"_{overload}")
    return "".join(parts)
53

54

55
def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
    """Translate schema type ``t`` into the C++ type used by native:: kernels.

    Only four schema types are handled specially here, because their native
    representation differs from what ``cpp.argumenttype_type`` produces:

    * ``Tensor?``   -> reference to an optional Tensor; a mutable reference
      when ``mutable`` is set, unless ``local.use_const_ref_for_mutable_tensors()``
      requests a const reference instead.
    * ``Tensor?[]`` -> const reference to a list of optional Tensors.
    * ``Scalar``    -> const reference to a Scalar.
    * ``Scalar?``   -> const reference to an optional Scalar.

    Every other type is delegated unchanged to ``cpp.argumenttype_type``.
    The returned ``NamedCType`` is named ``binds``.
    """
    type_str = str(t)

    if type_str == "Tensor?":
        optional_tensor: OptionalCType = OptionalCType(BaseCType(tensorT))
        # `local.use_const_ref_for_mutable_tensors()` is consulted only when
        # the argument is mutable (same short-circuit as `and`).
        wants_mut_ref = mutable and not local.use_const_ref_for_mutable_tensors()
        ref_ctype = (
            MutRefCType(optional_tensor)
            if wants_mut_ref
            else ConstRefCType(optional_tensor)
        )
        return NamedCType(binds, ref_ctype)

    if type_str == "Tensor?[]":
        optional_tensor_list = ListCType(OptionalCType(BaseCType(tensorT)))
        return NamedCType(binds, ConstRefCType(optional_tensor_list))

    if type_str == "Scalar":
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))

    if type_str == "Scalar?":
        optional_scalar = OptionalCType(BaseCType(scalarT))
        return NamedCType(binds, ConstRefCType(optional_scalar))

    # No native-specific override: fall back to the C++ API translation.
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
73

74

75
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
    """Return the C++ return type of a native:: kernel whose schema returns ``rs``.

    The native API uses exactly the same return type as the C++ API, so this
    delegates to ``cpp.returns_type`` with the same ``symint`` flag.
    """
    native_return_ctype = cpp.returns_type(rs, symint=symint)
    return native_return_ctype
77

78

79
def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
    """Return the native C++ type of schema argument ``a``, named ``binds``.

    The argument counts as mutable exactly when ``a.is_write`` is true; the
    translation itself is done by ``argumenttype_type``.
    """
    is_mutable = a.is_write
    return argumenttype_type(a.type, mutable=is_mutable, binds=binds, symint=symint)
81

82

83
def argument(
    a: Union[Argument, SelfArgument, TensorOptionsArguments],
    *,
    is_out: bool,
    symint: bool,
) -> List[Binding]:
    """Return the native-API bindings generated for one schema argument.

    * ``Argument``: a single binding named after the argument, typed via
      ``argument_type``; it gets a C++ default (``cpp.default_expr``) when the
      schema declares one and ``is_out`` is false.
    * ``SelfArgument``: treated exactly like the ``Argument`` it wraps.
    * ``TensorOptionsArguments``: four optional bindings, in order
      ``dtype``, ``layout``, ``device``, ``pin_memory``; each defaults to
      ``"{}"`` when ``is_out`` is false and has no default otherwise.

    Any other input is rejected via ``assert_never``.
    """
    # Ideally, we NEVER default native functions.  However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing.  So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    allow_defaults = not is_out

    if isinstance(a, Argument):
        arg_default: Optional[str] = (
            cpp.default_expr(a.default, a.type, symint=symint)
            if allow_defaults and a.default is not None
            else None
        )
        single = Binding(
            nctype=argument_type(a, binds=a.name, symint=symint),
            name=a.name,
            default=arg_default,
            argument=a,
        )
        return [single]

    if isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out, symint=symint)

    if isinstance(a, TensorOptionsArguments):
        options_default: Optional[str] = "{}" if allow_defaults else None
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces.  It seems
        # to matter
        option_fields = (
            ("dtype", scalarTypeT),
            ("layout", layoutT),
            ("device", deviceT),
            ("pin_memory", boolT),
        )
        return [
            Binding(
                nctype=NamedCType(field_name, OptionalCType(BaseCType(field_ctype))),
                name=field_name,
                default=options_default,
                argument=a,
            )
            for field_name, field_ctype in option_fields
        ]

    assert_never(a)
145

146

147
def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
    """Return the flattened native-API bindings for every argument of ``func``.

    Non-out arguments come first, followed by out arguments.  Each argument
    is expanded through ``argument``, so a single schema argument may yield
    several bindings (e.g. TensorOptions).  Whether defaults are emitted is
    decided once for the whole schema by ``func.is_out_fn()``.
    """
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    # is_out_fn() takes no per-argument input, so evaluate it once instead of
    # once per argument inside the comprehension (assumed side-effect free).
    is_out = func.is_out_fn()
    return [
        binding
        for arg in args
        for binding in argument(arg, symint=symint, is_out=is_out)
    ]
154

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.