from __future__ import annotations

import textwrap
from dataclasses import dataclass
from typing import Sequence

from torchgen.api.translate import translate
from torchgen.api.types import DispatcherSignature
from torchgen.context import method_with_native_function
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    SchemaKind,
    Type,
)
from torchgen.utils import mapMaybe

def is_tensor(typ: Type) -> bool:
    """Return True if `typ` is a plain (non-optional, non-list) Tensor type."""
    return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor
def is_optional_tensor(typ: Type) -> bool:
    """Return True if `typ` is `Tensor?` (an optional wrapping a Tensor)."""
    return isinstance(typ, OptionalType) and is_tensor(typ.elem)
def is_tensor_list(typ: Type) -> bool:
    """Return True if `typ` is `Tensor[]` (a list whose element type is Tensor)."""
    return isinstance(typ, ListType) and is_tensor(typ.elem)
def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
    """Emit C++ lines that unwrap a Tensor arg into (value, bdim) at a vmap level.

    Returns the generated code as a list of lines (dedented), so callers can
    splice them into a larger code block.
    """
    result = f"""\
    auto [{name}_value, {name}_bdim] = unwrapTensorAtLevel({name}, {cur_level_var});"""
    return textwrap.dedent(result).split("\n")
def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
    """Emit C++ lines that unwrap an optional<Tensor> arg at a vmap level.

    Declares `<name>_value` / `<name>_bdim` and fills them only when the
    optional is engaged (hence the `if ({name})` guard around `.value()`).
    Returns the generated code as a list of dedented lines.
    """
    result = f"""\
    std::optional<Tensor> {name}_value;
    std::optional<int64_t> {name}_bdim;
    if ({name}) {{
        std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var});
    }}"""
    return textwrap.dedent(result).split("\n")
def gen_unwraps(
    flat_arguments: Sequence[Argument], cur_level_var: str
) -> tuple[str, list[str]]:
    """Generate unwrapping code and the batch-rule argument list.

    For every Tensor / optional<Tensor> argument, emits C++ that unwraps it
    into a `<name>_value, <name>_bdim` pair at `cur_level_var`; all other
    arguments pass through by name unchanged.

    Returns (unwrap_code, unwrapped_arg_list): the joined C++ unwrap lines,
    and the flattened argument names to pass to the batch rule.
    """
    arg_names = [a.name for a in flat_arguments]
    arg_types = [a.type for a in flat_arguments]

    tensors = [name for typ, name in zip(arg_types, arg_names) if is_tensor(typ)]
    optional_tensors = [
        name for typ, name in zip(arg_types, arg_names) if is_optional_tensor(typ)
    ]

    unwraps: list[str] = []
    for tensor in tensors:
        unwraps += unwrap_tensor(tensor, cur_level_var)

    for opt_tensor in optional_tensors:
        unwraps += unwrap_optional_tensor(opt_tensor, cur_level_var)
    unwrap_code = "\n".join(unwraps)

    unwrapped_arg_list = []
    for arg in arg_names:
        if arg in tensors or arg in optional_tensors:
            # Tensor-like args expand into a (value, bdim) pair.
            unwrapped_arg_list += [f"{arg}_value", f"{arg}_bdim"]
        else:
            unwrapped_arg_list.append(arg)
    return unwrap_code, unwrapped_arg_list
def gen_case_where_all_bdims_are_none(
    outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
) -> str:
    """Generate the C++ fast path taken when no argument is batched.

    Builds one `!isBatchedAtLevel(...)` condition per tensor-like argument;
    when all hold, the generated code dispatches straight to the original op
    via `at::_ops::<name>::call(...)` instead of running the batch rule.
    """
    conditions = []
    flat_args = schema.arguments.flat_all
    for arg in flat_args:
        if not arg.type.is_tensor_like():
            continue
        conditions.append(f"!isBatchedAtLevel({arg.name}, {cur_level_var})")

    sig = DispatcherSignature.from_schema(schema)
    translated_args = ", ".join(
        e.expr for e in translate(outer_sig.arguments(), sig.arguments())
    )
    return f"""\
if ({' && '.join(conditions)}) {{
  return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
}}"""
def gen_returns(
    returns: tuple[Return, ...], cur_level_var: str, results_var: str
) -> str:
    """Generate the C++ return statement that re-wraps batch-rule outputs.

    The batch rule returns a tuple interleaving (value, bdim) pairs for each
    Tensor / Tensor[] output; `idx` walks that tuple, consuming two slots for
    tensor-like returns (wrapped via makeBatched / makeBatchedVector) and one
    slot for everything else (passed through unchanged).
    """
    idx = 0
    wrapped_returns = []
    for ret in returns:
        if is_tensor(ret.type):
            wrapped_returns.append(
                f"makeBatched(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
            )
            idx += 2
        elif is_tensor_list(ret.type):
            wrapped_returns.append(
                f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
            )
            idx += 2
        else:
            wrapped_returns.append(f"std::get<{idx}>({results_var})")
            idx += 1
    if len(wrapped_returns) == 1:
        result = f"return {wrapped_returns[0]};"
    else:
        result = f'return std::make_tuple({", ".join(wrapped_returns)});'
    return result
def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool:
    """Return True if any flattened argument of `schema` is tensor-like."""
    return any(a.type.is_tensor_like() for a in schema.arguments.flat_all)
def is_mutated_arg(argument: Argument) -> bool:
    """Return True if `argument` carries a write annotation (mutated in-place)."""
    return argument.annotation is not None and argument.annotation.is_write
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
    """Generate vmap plumbing for an in-place operator, or None if unsupported.

    Assumptions:
    - only one argument is being modified in-place
    - the argument that is being modified in-place is the first argument
    - all returns are either Tensor, tuple of Tensor, or TensorList
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Check assumptions. If these are invalid we return None
    # and punt the work to handle them to the future.
    assert schema.kind() == SchemaKind.inplace
    if not is_mutated_arg(schema.arguments.flat_all[0]):
        return None
    if len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) != 1:
        return None

    # Only support cases where all returns are Tensors or vector<Tensor>
    if len(returns) == 0:
        return None
    if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
        return None
    if not accepts_at_least_one_tensor_input(schema):
        return None

    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    # In-place ops return the (mutated) first argument rather than the
    # batch rule's result.
    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  batch_rule({', '.join(unwrapped_arg_list)});
  return {schema.arguments.flat_all[0].name};
}}"""
def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
    """Generate vmap plumbing for an operator with no returns.

    Same shape as the general plumbing, but the batch rule's result is
    discarded and the generated C++ function returns nothing.
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  batch_rule({', '.join(unwrapped_arg_list)});
}}"""
def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
    """Generate the C++ vmap plumbing for a native function, or None if unsupported.

    Dispatches to the no-returns / in-place variants where applicable and
    bails out (returns None) for schemas the plumbing cannot handle yet
    (non-tensor returns without an override, inplace views, mutable/out/scratch).
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Only support cases where all returns are Tensors or vector<Tensor>
    if not accepts_at_least_one_tensor_input(schema):
        return None
    if len(returns) == 0:
        return gen_vmap_plumbing_no_returns(native_function)
    # These ops also return non-tensor values (e.g. symint metadata) but are
    # explicitly allowed through anyway.
    return_symint_overrides = [
        "_scaled_dot_product_flash_attention",
        "_scaled_dot_product_cudnn_attention",
    ]
    if (
        not all(ret.type.is_tensor_like() for ret in returns)
        and schema.name.unambiguous_name() not in return_symint_overrides
    ):
        return None
    # in-place views need special handling
    if "inplace_view" in native_function.tags:
        return None

    if schema.kind() == SchemaKind.inplace:
        return gen_vmap_inplace_plumbing(native_function)

    # Don't support these (mutable, out, scratch)
    if schema.kind() != SchemaKind.functional:
        return None

    results_var = "results"
    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    wrapped_returns = gen_returns(returns, cur_level_var, results_var)
    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)});
  {wrapped_returns}
}}"""
@dataclass(frozen=True)
class ComputeBatchRulePlumbing:
    """Callable mapping a NativeFunction to its generated plumbing (or None)."""

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        result = gen_vmap_plumbing(f)
        return result
def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
    """Render the complete generated C++ file of vmap plumbing.

    Functions whose plumbing cannot be generated are skipped (mapMaybe drops
    the None results from ComputeBatchRulePlumbing).
    """
    body = "\n".join(list(mapMaybe(ComputeBatchRulePlumbing(), native_functions)))
    return f"""
#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>

namespace at {{ namespace functorch {{

{body}

}}}} // namespace at::functorch
"""