# pytorch
#
# Fork: 0
# / gen_vmap_plumbing.py
# 271 lines · 9.1 KB
1
from __future__ import annotations
2

3
import textwrap
4
from dataclasses import dataclass
5
from typing import Sequence
6

7
from torchgen.api.translate import translate
8
from torchgen.api.types import DispatcherSignature
9
from torchgen.context import method_with_native_function
10
from torchgen.model import (
11
    Argument,
12
    BaseTy,
13
    BaseType,
14
    FunctionSchema,
15
    ListType,
16
    NativeFunction,
17
    OptionalType,
18
    Return,
19
    SchemaKind,
20
    Type,
21
)
22
from torchgen.utils import mapMaybe
23

24

25
def is_tensor(typ: Type) -> bool:
    """Return True if ``typ`` is exactly the base ``Tensor`` type.

    ``Tensor?`` (an OptionalType) and ``Tensor[]`` (a ListType) do not match;
    those are handled by ``is_optional_tensor`` and ``is_tensor_list``.
    """
    return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor
27

28

29
def is_optional_tensor(typ: Type) -> bool:
    """Return True if ``typ`` is ``Tensor?`` (an OptionalType wrapping Tensor)."""
    return isinstance(typ, OptionalType) and is_tensor(typ.elem)
31

32

33
def is_tensor_list(typ: Type) -> bool:
    """Return True if ``typ`` is ``Tensor[]`` (a ListType whose element is Tensor)."""
    return isinstance(typ, ListType) and is_tensor(typ.elem)
35

36

37
def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
    """Emit C++ that unwraps a Tensor argument at the current vmap level.

    The generated statement declares ``<name>_value`` and ``<name>_bdim``
    through a structured binding on
    ``unwrapTensorAtLevel(<name>, <cur_level_var>)``.

    Args:
        name: the C++ name of the Tensor argument.
        cur_level_var: the C++ variable holding the current vmap level.

    Returns:
        The generated code as a list of lines (no trailing newlines).
    """
    result = f"""\
    auto [{name}_value, {name}_bdim] = unwrapTensorAtLevel({name}, {cur_level_var});"""
    return textwrap.dedent(result).split("\n")
41

42

43
def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
    """Emit C++ that unwraps an optional Tensor argument at the current level.

    ``<name>_value`` and ``<name>_bdim`` are declared as empty
    ``std::optional`` values and are only filled in (via
    ``unwrapTensorAtLevel``) when the optional argument holds a value.

    Args:
        name: the C++ name of the ``Tensor?`` argument.
        cur_level_var: the C++ variable holding the current vmap level.

    Returns:
        The generated code as a list of lines (no trailing newlines).
    """
    result = f"""\
    std::optional<Tensor> {name}_value;
    std::optional<int64_t> {name}_bdim;
    if ({name}) {{
        std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var});
    }}"""
    return textwrap.dedent(result).split("\n")
51

52

53
def gen_unwraps(
    flat_arguments: Sequence[Argument], cur_level_var: str
) -> tuple[str, list[str]]:
    """Generate the unwrapping prologue and the batch-rule argument list.

    Every ``Tensor`` and every ``Tensor?`` argument is unwrapped into a
    ``<name>_value`` / ``<name>_bdim`` pair. All plain-Tensor unwraps are
    emitted first, followed by all optional-Tensor unwraps.

    Args:
        flat_arguments: the operator's arguments in schema order.
        cur_level_var: the C++ variable holding the current vmap level.

    Returns:
        ``(unwrap_code, unwrapped_arg_list)``:
        - ``unwrap_code`` is the C++ unwrapping statements, newline-joined.
        - ``unwrapped_arg_list`` is the argument expressions for the batch
          rule, in schema order. Each unwrapped argument expands to its
          ``_value`` and ``_bdim`` names; all others pass through unchanged.
    """
    tensors = [a.name for a in flat_arguments if is_tensor(a.type)]
    optional_tensors = [
        a.name for a in flat_arguments if is_optional_tensor(a.type)
    ]

    unwraps: list[str] = []
    for tensor in tensors:
        unwraps += unwrap_tensor(tensor, cur_level_var)
    for opt_tensor in optional_tensors:
        unwraps += unwrap_optional_tensor(opt_tensor, cur_level_var)
    unwrap_code = "\n".join(unwraps)

    # A set keeps the membership test O(1) per argument instead of scanning
    # both lists for every argument.
    unwrapped_names = set(tensors) | set(optional_tensors)
    unwrapped_arg_list: list[str] = []
    for arg in flat_arguments:
        if arg.name in unwrapped_names:
            unwrapped_arg_list += [f"{arg.name}_value", f"{arg.name}_bdim"]
        else:
            unwrapped_arg_list.append(arg.name)
    return unwrap_code, unwrapped_arg_list
79

80

81
def gen_case_where_all_bdims_are_none(
    outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
) -> str:
    """Generate the fast path taken when no tensor input is batched at this level.

    Emits an ``if`` that checks ``!isBatchedAtLevel(arg, cur_level)`` for every
    tensor-like argument. When all checks hold, the code returns the result of
    ``at::_ops::<op>::call(...)``, with arguments translated from
    ``outer_sig``.

    Args:
        outer_sig: signature of the generated plumbing function; its
            parameters are the names in scope in the generated C++.
        schema: schema of the operator being redispatched to.
        cur_level_var: the C++ variable holding the current vmap level.

    Returns:
        The C++ ``if`` block as a string (unindented).
    """
    conditions = [
        f"!isBatchedAtLevel({arg.name}, {cur_level_var})"
        for arg in schema.arguments.flat_all
        if arg.type.is_tensor_like()
    ]
    # With no tensor-like arguments there is nothing to check, so the guard is
    # trivially satisfied. Joining an empty list would otherwise emit `if ()`,
    # which is not valid C++.
    condition = " && ".join(conditions) or "true"

    sig = DispatcherSignature.from_schema(schema)
    translated_args = ", ".join(
        e.expr for e in translate(outer_sig.arguments(), sig.arguments())
    )
    return f"""\
if ({condition}) {{
  return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
}}"""
99

100

101
def gen_returns(
    returns: tuple[Return, ...], cur_level_var: str, results_var: str
) -> str:
    """Generate the C++ ``return`` statement that re-wraps batch-rule results.

    The tuple named ``results_var`` is consumed left to right:
    - a ``Tensor`` return takes two slots (value, bdim) and is wrapped with
      ``makeBatched``;
    - a ``Tensor[]`` return takes two slots and is wrapped with
      ``makeBatchedVector``;
    - any other return takes one slot and is forwarded unchanged.

    A single return is returned directly. Several returns are packed with
    ``std::make_tuple``.
    """
    idx = 0
    wrapped_returns = []
    for ret in returns:
        if is_tensor(ret.type):
            wrapper = "makeBatched"
        elif is_tensor_list(ret.type):
            wrapper = "makeBatchedVector"
        else:
            # Non-tensor returns occupy one slot and are not re-wrapped.
            wrapped_returns.append(f"std::get<{idx}>({results_var})")
            idx += 1
            continue
        wrapped_returns.append(
            f"{wrapper}(std::get<{idx}>({results_var}), "
            f"std::get<{idx + 1}>({results_var}), {cur_level_var})"
        )
        idx += 2
    if len(wrapped_returns) == 1:
        return f"return {wrapped_returns[0]};"
    return f'return std::make_tuple({", ".join(wrapped_returns)});'
125

126

127
def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool:
    """Return True if any of the schema's arguments has a tensor-like type."""
    return any(a.type.is_tensor_like() for a in schema.arguments.flat_all)
129

130

131
def is_mutated_arg(argument: Argument) -> bool:
    """Return True if the argument carries an annotation marked as a write."""
    return argument.annotation is not None and argument.annotation.is_write
133

134

135
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
    """Generate vmap plumbing for an in-place operator, or None if unsupported.

    Supported only when all of the following hold:
    - exactly one argument is mutated in place, and it is the first argument;
    - there is at least one return, and every return is ``Tensor`` or
      ``Tensor[]``;
    - at least one argument is tensor-like.

    The generated function calls the batch rule with the unwrapped arguments
    and discards its result. It then returns the (wrapped) first argument.
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns
    flat_args = schema.arguments.flat_all

    # Precondition: only in-place schemas are handled here.
    assert schema.kind() == SchemaKind.inplace

    # Unsupported shapes return None; handling them is punted to the future.
    if not flat_args or not is_mutated_arg(flat_args[0]):
        return None
    if sum(1 for arg in flat_args if is_mutated_arg(arg)) != 1:
        return None

    # Only support cases where all returns are Tensors or vector<Tensor>.
    if not returns:
        return None
    if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
        return None
    if not accepts_at_least_one_tensor_input(schema):
        return None

    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(flat_args, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  batch_rule({', '.join(unwrapped_arg_list)});
  return {flat_args[0].name};
}}"""
177

178

179
def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
    """Generate vmap plumbing for an operator whose schema has no returns.

    The generated function unwraps the batched inputs and calls the batch
    rule. Nothing is re-wrapped or returned afterwards. This function does
    not itself check that the operator takes a tensor input.
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  batch_rule({', '.join(unwrapped_arg_list)});
}}"""
198

199

200
def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
    """Generate vmap plumbing for ``native_function``, or None if unsupported.

    Decision order:
    - no tensor-like input -> None;
    - no returns -> ``gen_vmap_plumbing_no_returns``;
    - a non-tensor-like return (outside the override list below) -> None;
    - tagged ``inplace_view`` -> None (in-place views need special handling);
    - in-place schema -> ``gen_vmap_inplace_plumbing``;
    - any other non-functional schema (mutable, out, scratch) -> None;
    - functional schema -> the generated function unwraps inputs, calls the
      batch rule and re-wraps its results via ``gen_returns``.
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns
    op_name = schema.name.unambiguous_name()

    if not accepts_at_least_one_tensor_input(schema):
        return None
    if len(returns) == 0:
        return gen_vmap_plumbing_no_returns(native_function)

    # Only support cases where all returns are tensor-like, except for these
    # ops, which are allowed non-tensor returns (presumably SymInts, per the
    # list's name).
    return_symint_overrides = (
        "_scaled_dot_product_flash_attention",
        "_scaled_dot_product_cudnn_attention",
    )
    if (
        not all(ret.type.is_tensor_like() for ret in returns)
        and op_name not in return_symint_overrides
    ):
        return None
    # in-place views need special handling
    if "inplace_view" in native_function.tags:
        return None

    if schema.kind() == SchemaKind.inplace:
        return gen_vmap_inplace_plumbing(native_function)

    # Don't support these (mutable, out, scratch)
    if schema.kind() != SchemaKind.functional:
        return None

    results_var = "results"
    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    wrapped_returns = gen_returns(returns, cur_level_var, results_var)
    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=op_name + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)});
  {wrapped_returns}
}}"""
249

250

251
@dataclass(frozen=True)
class ComputeBatchRulePlumbing:
    """Callable mapping a NativeFunction to its vmap plumbing string, or None.

    Used with ``mapMaybe`` in ``gen_all_vmap_plumbing``. A None result means
    no plumbing is generated for that operator.
    """

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        return gen_vmap_plumbing(f)
257

258

259
def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
    """Render the full C++ header holding vmap plumbing for all given operators.

    Each operator goes through ``ComputeBatchRulePlumbing``. Operators for
    which it returns None are presumably dropped by ``mapMaybe`` (see
    ``torchgen.utils``). The surviving snippets are newline-joined inside
    ``namespace at::functorch``.
    """
    # str.join consumes any iterable; materializing a list first is unnecessary.
    body = "\n".join(mapMaybe(ComputeBatchRulePlumbing(), native_functions))
    return f"""
#pragma once
#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>

namespace at {{ namespace functorch {{

{body}

}}}} // namespace at::functorch
"""
272

# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal
# data in order to improve our website and the GitVerse service and to make
# them more convenient to use.
#
# You can block the use of cookies yourself in your browser settings.