pytorch

Форк
0
/
gen_view_funcs.py 
334 строки · 11.3 Кб
1
# Generates ViewFuncs.h/cpp
2
#
3
# NOTE: If any changes are being made to the ViewFunc codegen please also check
4
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
5
# The fallback is expected to mimic this codegen, so we should keep the two in sync.
6

7
from typing import List, Tuple
8

9
import torchgen.api.dispatcher as dispatcher
10
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
11
from torchgen.api.translate import translate
12
from torchgen.api.types import (
13
    BaseCType,
14
    Binding,
15
    NamedCType,
16
    SymIntT,
17
    tensorT,
18
    VectorCType,
19
)
20
from torchgen.code_template import CodeTemplate
21
from torchgen.model import Argument, NativeFunction, OptionalType
22
from torchgen.utils import FileManager
23

24
from .gen_inplace_or_view_type import (
25
    CALL_DISPATCH,
26
    extract_bindings,
27
    get_view_info,
28
    modifies_arguments,
29
    use_derived,
30
)
31

32
# Template for the per-op ViewFunc struct declaration emitted into ViewFuncs.h.
# Substitutions:
#   ${op}               - struct name (e.g. AsStridedViewFunc)
#   ${uppercase_op}     - used for the *_AVAILABLE feature-test macro
#   ${superclass}       - base class (torch::autograd::ViewFunc)
#   ${constructor_args} - arg declarations for all non-self view-op args
#   ${initializer_list} - member initializer list storing those args
#   ${state}            - member variable declarations holding the saved state
FUNCTION_DECLARATION = CodeTemplate(
    """\
#define ${uppercase_op}_AVAILABLE
struct ${op} : public ${superclass} {
  ${op}(${constructor_args}) ${initializer_list}
  {};
  virtual ~${op}() override {};
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;

protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

private:
  ${state}
};

"""
)
58

59
# Template for the per-op ViewFunc method definitions emitted into ViewFuncs.cpp.
# ${get_symints}/${set_symints}/${num_symints} (and the tensor analogues) are
# lists of statements produced by generate_state_getter_setter(); ${op_call}
# is the redispatch expression invoking the underlying view op.
FUNCTION_DEFINITION = CodeTemplate(
    """\
std::vector<c10::SymInt> ${op}::get_symints() const {
  ${get_symints}
}

size_t ${op}::num_symints() const {
  return static_cast<size_t>(${num_symints});
}

void ${op}::set_symints(std::vector<c10::SymInt> ${symints_vec}) {
  TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints());
  ${set_symints}
}

std::vector<at::Tensor> ${op}::get_tensors() const {
  ${get_tensors}
}

size_t ${op}::num_tensors() const {
  return static_cast<size_t>(${num_tensors});
}

void ${op}::set_tensors(std::vector<at::Tensor> ${tensors_vec}) {
  TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors());
  ${set_tensors}
}

at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const {
  return ${op_call};
}

std::unique_ptr<ViewFunc> ${op}::clone_and_set(
    std::optional<std::vector<c10::SymInt>> ${symints_vec},
    std::optional<std::vector<at::Tensor>> ${tensors_vec}) const {
  auto output = std::make_unique<${op}>(${clone_args});
  if (${symints_vec}.has_value()) {
    output->set_symints(std::move(*(${symints_vec})));
  }
  if (${tensors_vec}.has_value()) {
    output->set_tensors(std::move(*(${tensors_vec})));
  }
  return output;
}

"""
)
106

107

108
# e.g. as_strided -> AsStridedViewFunc for camel case or
# as_strided_view_func otherwise
def view_func_name(
    f: NativeFunction, include_namespace: bool = False, camel_case: bool = True
) -> str:
    """Return the name of the generated ViewFunc struct for a view op."""
    base = f.func.name.unambiguous_name()
    result = f"{base.replace('.', '_')}_view_func"
    if camel_case:
        had_leading_underscore = result.startswith("_")
        pieces = result.replace(".", "_").split("_")
        result = "".join(piece.title() for piece in pieces)
        if had_leading_underscore:
            # the split/join above drops the leading underscore; restore it
            result = f"_{result}"
    prefix = "torch::autograd::generated::" if include_namespace else ""
    return f"{prefix}{result}"
125

126

127
def is_symint_or_tensor(arg: Argument) -> bool:
    """True if the argument is tensor-like or SymInt-like (i.e. part of ViewFunc state)."""
    tensorish = arg.type.is_tensor_like()
    return tensorish or arg.type.is_symint_like()
129

130

131
def remove_const_ref(binding: Binding) -> Binding:
    """Return a copy of ``binding`` with const-ref stripped from its C++ type.

    Used so the generated struct stores owned values rather than references.
    """
    stripped = binding.nctype.remove_const_ref()
    return Binding(
        name=binding.name,
        nctype=stripped,
        argument=binding.argument,
        default=binding.default,
    )
138

139

140
def returns_multi_tensor(fn: NativeFunction) -> bool:
    """True if the op's single return is a list of tensors (multi-output view)."""
    returns = fn.func.returns
    assert len(returns) == 1
    ret_type = returns[0].type
    # is_list_like() yields the list type or None, hence the None check.
    is_list = ret_type.is_list_like() is not None
    return is_list and ret_type.is_tensor_like()
146

147

148
# Generates strings with logic for getting / setting state of a particular type.
#
# Args:
#   bindings (list): List of state bindings of interest (may be empty)
#   state_vec_type (NamedCType): Type of vector to either return or copy from
#
# Returns:
#   tuple: (list of getter logic strings, list of setter logic strings, string
#     with num items expression)
def generate_state_getter_setter(
    bindings: List[Binding],
    state_vec_type: NamedCType,
) -> Tuple[List[str], List[str], str]:
    vec_name = state_vec_type.name
    getters = [f"{state_vec_type.cpp_type()} {vec_name};"]
    # The setter walks the incoming vector with a running index i.
    setters = ["auto i = 0;"] if bindings else []
    count_exprs = []

    last = len(bindings) - 1
    for idx, binding in enumerate(bindings):
        assert isinstance(binding.argument, Argument)
        arg_name = binding.name
        arg_type = binding.argument.type
        if arg_type.is_list_like():
            # List-likes contribute one state item per element.
            count_expr = f"{arg_name}.size()"
            get_line = (
                f"{vec_name}.insert({vec_name}.end(), "
                f"{arg_name}.begin(), {arg_name}.end());"
            )
            set_line = (
                f"std::copy({vec_name}.begin() + i, "
                f"{vec_name}.begin() + i + {arg_name}.size(), {arg_name}.begin());"
            )
        elif isinstance(arg_type, OptionalType):
            # Optionals contribute an item only when a value is present.
            count_expr = f"({arg_name}.has_value() ? 1 : 0)"
            guard = f"if({arg_name}.has_value())"
            get_line = f"{guard} {vec_name}.insert({vec_name}.end(), *({arg_name}));"
            set_line = f"{guard} {arg_name} = {vec_name}[i];"
        else:
            # Plain scalars contribute exactly one item.
            count_expr = "1"
            get_line = f"{vec_name}.push_back({arg_name});"
            set_line = f"{arg_name} = {vec_name}[i];"

        count_exprs.append(count_expr)
        getters.append(get_line)
        setters.append(set_line)
        if idx < last:
            setters.append(f"i += {count_expr};")

    # Total item count as a C++ expression; used for reserve() and num_*().
    num_items = " + ".join(count_exprs) if count_exprs else "0"
    if bindings:
        getters.insert(1, f"{vec_name}.reserve({num_items});")

    getters.append(f"return {vec_name};")

    return getters, setters, num_items
206

207

208
def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
    """Instantiate ``template`` (FUNCTION_DECLARATION or FUNCTION_DEFINITION)
    for a single out-of-place view op and return the resulting C++ source.
    """
    bindings = extract_bindings(fn)
    # The base tensor ("self") is passed at call time, not stored as state.
    non_self_bindings = [b for b in bindings if b.name != "self"]

    # flat_all[1:] skips the first argument — presumably "self" for view ops
    # (matches the name-based filter above).
    non_self_args = fn.func.arguments.flat_all[1:]
    # Owning (value) types for stored state, e.g. ArrayRef -> vector.
    non_self_value_bindings = [
        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
    ]

    # Generate constructor / clone args for the generated struct.
    constructor_args = [b.defn() for b in non_self_bindings]
    clone_args = [b.name for b in non_self_bindings]

    # Generate state variable declarations for the generated struct.
    state_variables = [
        f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings
    ]

    # Generate initializer list expressions for the generated struct.
    # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as
    # vector<SymInt>s.
    init_exprs = translate(
        non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True
    )
    initializers = []
    for b, init_expr in zip(non_self_bindings, init_exprs):
        name = b.nctype.name
        assert isinstance(name, str)
        initializers.append(f"{name}({init_expr.expr})")

    # Generate call to underlying view op
    call_input_name = "input_base"
    op_call_args = [call_input_name, *(b.name for b in non_self_bindings)]
    op_call = CALL_DISPATCH.substitute(
        unambiguous_name=fn.func.name.unambiguous_name(),
        unpacked_args=op_call_args,
    )

    # Multi-output views additionally require a view_idx for disambiguation.
    if returns_multi_tensor(fn):
        view_idx_name = "view_idx"
        view_idx_typename = "int64_t"
        view_idx_decl = f"{view_idx_typename} {view_idx_name}"
        constructor_args.append(view_idx_decl)
        clone_args.append(view_idx_name)
        state_variables.append(f"{view_idx_decl};")
        initializers.append(f"{view_idx_name}({view_idx_name})")
        # Index into the multi-tensor result of the redispatch call.
        op_call += f"[{view_idx_name}]"

    # Generate initializer list for the generated struct.
    initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else ""

    # Generate getter / setter logic for any symints.
    symint_bindings = [
        b
        for b in non_self_bindings
        if isinstance(b.argument, Argument) and b.argument.type.is_symint_like()
    ]
    symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT)))
    get_symints, set_symints, num_symints = generate_state_getter_setter(
        symint_bindings, symints_vec_type
    )

    # Generate getter / setter logic for any tensors.
    tensor_bindings = [
        b
        for b in non_self_bindings
        if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like()
    ]
    tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT)))
    get_tensors, set_tensors, num_tensors = generate_state_getter_setter(
        tensor_bindings, tensors_vec_type
    )

    return template.substitute(
        op=view_func_name(fn),
        uppercase_op=view_func_name(fn, camel_case=False).upper(),
        superclass="torch::autograd::ViewFunc",
        initializer_list=initializer_list,
        state=state_variables,
        constructor_args=constructor_args,
        clone_args=clone_args,
        symints_vec=symints_vec_type.name,
        get_symints=get_symints,
        set_symints=set_symints,
        num_symints=num_symints,
        tensors_vec=tensors_vec_type.name,
        get_tensors=get_tensors,
        set_tensors=set_tensors,
        num_tensors=num_tensors,
        call_input_name=call_input_name,
        op_call=op_call,
    )
301

302

303
def gen_view_funcs(
    out: str,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write ViewFuncs.h / ViewFuncs.cpp into ``out`` from the templates in
    ``template_path``, covering every out-of-place view op.
    """
    # Differentiability info is not needed here; keep just the functions.
    fns = [pair.func for pair in fns_with_infos if use_derived(pair)]
    # Restrict to out-of-place views (view info present, no arg mutation).
    view_fns = []
    for fn in fns:
        if get_view_info(fn) is not None and not modifies_arguments(fn):
            view_fns.append(fn)

    declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns]
    definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns]
    ops_headers = [f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in view_fns]

    basename = "ViewFuncs"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for ext in (".h", ".cpp"):
        filename = basename + ext
        fm.write_with_template(
            filename,
            filename,
            # bind filename as a default so the env callable stays lazy
            lambda fname=filename: {
                "generated_comment": "@"
                + f"generated from {fm.template_dir_for_comments()}/"
                + fname,
                "view_func_declarations": declarations,
                "view_func_definitions": definitions,
                "ops_headers": ops_headers,
            },
        )
335

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.