# pytorch (fork count: 0) / structured.py — 157 lines · 6.0 KB
1
from __future__ import annotations
2

3
from torchgen.api import cpp
4
from torchgen.api.types import (
5
    ArgName,
6
    ArrayRefCType,
7
    BaseCType,
8
    Binding,
9
    ConstRefCType,
10
    dimnameListT,
11
    intArrayRefT,
12
    iOptTensorListRefT,
13
    iTensorListRefT,
14
    NamedCType,
15
    OptionalCType,
16
    optionalIntArrayRefT,
17
    optionalScalarRefT,
18
    optionalTensorRefT,
19
    scalarT,
20
    tensorT,
21
)
22
from torchgen.model import (
23
    Argument,
24
    BaseTy,
25
    BaseType,
26
    ListType,
27
    NativeFunctionsGroup,
28
    OptionalType,
29
    SelfArgument,
30
    TensorOptionsArguments,
31
    Type,
32
)
33
from torchgen.utils import assert_never
34

35

36
# This file describes the translation of JIT schema to the structured functions API.
# This is similar to the native API, but a number of historical problems with the
# native API have been fixed.
39

40

41
# Translate a type occurring in a JIT argument into the C++ argument type used
# by the structured kernel API.
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Map the JIT schema type ``t`` to a structured-kernel C++ argument type.

    Value types are delegated to ``cpp.valuetype_type``; this function only
    handles the reference and container types that remain.

    Args:
        t: the schema type to translate.
        mutable: forwarded to ``cpp.valuetype_type`` and to the recursive calls
            for optional/list element types; the branches below do not
            otherwise consult it.
        binds: the name the returned ``NamedCType`` is bound to.

    Returns:
        A ``NamedCType`` whose name is ``binds``.

    Raises:
        AssertionError: for a ``BaseType`` other than Tensor/Scalar that
            ``cpp.valuetype_type`` did not translate, or for a type that is not
            a ``BaseType``, ``OptionalType`` or ``ListType``.
    """
    # SymInt is always disabled: structured kernels are real kernels that need
    # concrete ints. CompositeExplicitAutograd and meta functions could in
    # principle accept SymInt, but for simplicity those are meant to be
    # handled in Python.
    as_value = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable)
    if as_value is not None:
        return as_value

    if isinstance(t, BaseType):
        # Tensor and Scalar are passed by const reference.
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        inner = t.elem
        if inner == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        if inner == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        if isinstance(inner, ListType) and str(inner.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        # Generic case: wrap the translated element type in an optional.
        inner_nctype = argumenttype_type(inner, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(inner_nctype.type))

    if isinstance(t, ListType):
        item = t.elem
        if item == BaseType(BaseTy.Tensor):
            return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
        if item == OptionalType(BaseType(BaseTy.Tensor)):
            return NamedCType(binds, BaseCType(iOptTensorListRefT))
        # TODO: delete the int/Dimname special cases; see torchgen.api.cpp --
        # the two must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        if str(item) == "int":
            return NamedCType(binds, BaseCType(intArrayRefT))
        if str(item) == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        # Generic case: an array-ref over the translated element type.
        item_nctype = argumenttype_type(item, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(item_nctype.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
87

88

89
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Return the structured-API C++ type for schema argument ``a``.

    The argument is translated as mutable exactly when ``a.is_write`` is set;
    the result is bound to the name ``binds``.
    """
    is_mutable = a.is_write
    return argumenttype_type(a.type, mutable=is_mutable, binds=binds)
91

92

93
# returns_type is intentionally omitted because structured kernels never
# "return"; instead, they always report their outputs indirectly (a meta
# function by calling set_output, an impl function by writing directly into
# the provided out argument).
97

98

99
# Structured kernels are never defaulted
100
def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
101
    if isinstance(a, Argument):
102
        return [
103
            Binding(
104
                nctype=argument_type(a, binds=a.name),
105
                name=a.name,
106
                default=None,
107
                argument=a,
108
            )
109
        ]
110
    elif isinstance(a, SelfArgument):
111
        return argument(a.argument)
112
    elif isinstance(a, TensorOptionsArguments):
113
        raise AssertionError("structured kernels don't support TensorOptions yet")
114
    else:
115
        assert_never(a)
116

117

118
def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the Bindings for the impl function of structured group ``g``.

    The parameter order is:
      1. the non-out arguments of ``g.out.func``. When ``g.out.precomputed`` is
         set, each ``Argument`` whose name is a key of
         ``precomputed.replace`` is substituted by that entry's parameters.
      2. when ``g.out.precomputed`` is set, the extra ``precomputed.add``
         parameters (added without replacing anything).
      3. the out arguments of ``g.out.func``.
    """
    out_func = g.out.func
    precomputed = g.out.precomputed
    params: list[Argument | TensorOptionsArguments | SelfArgument] = []

    if precomputed:
        # Substitutions come from native_functions.yaml's precomputed section.
        replacements = precomputed.replace
        for param in out_func.arguments.non_out:
            if isinstance(param, Argument) and param.name in replacements:
                params.extend(replacements[param.name])
            else:
                params.append(param)
        # Extra precomputed parameters go after the non-out args and just
        # before the out args.
        params.extend(precomputed.add)
    else:
        params.extend(out_func.arguments.non_out)

    params.extend(out_func.arguments.out)
    return [binding for param in params for binding in argument(param)]
146

147

148
def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the Bindings for the meta function of structured group ``g``.

    These are the non-out arguments of the functional variant
    (``g.functional.func``), in order.
    """
    bindings: list[Binding] = []
    for param in g.functional.func.arguments.non_out:
        bindings.extend(argument(param))
    return bindings
152

153

154
def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Return the Bindings for the out arguments of ``g.out.func``, in order."""
    bindings: list[Binding] = []
    for param in g.out.func.arguments.out:
        bindings.extend(argument(param))
    return bindings
158

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.