# Generates C++ autograd functions for the derivatives of ATen operations
#
# This writes two files:
#  Functions.h/cpp: subclasses of autograd::Node
#  python_functions.h/cpp: Python bindings for the above classes
#
from typing import Dict, List, Sequence, Tuple

from torchgen.api.autograd import (
    Derivative,
    DifferentiabilityInfo,
    SavedAttribute,
    uses_retain_variables,
    uses_single_grad,
)
from torchgen.api.types import (
    ArrayRefCType,
    BaseCppType,
    BaseCType,
    Binding,
    boolT,
    doubleT,
    intArrayRefT,
    iTensorListRefT,
    ListCType,
    longT,
    MutRefCType,
    OptionalCType,
    optionalIntArrayRefT,
    optionalSymIntArrayRefT,
    scalarT,
    stringT,
    symIntArrayRefT,
    SymIntT,
    TENSOR_LIST_LIKE_CTYPES,
    tensorListT,
    tensorT,
    VectorCType,
)
from torchgen.code_template import CodeTemplate
from torchgen.model import Argument, FunctionSchema
from torchgen.utils import FileManager

from .gen_inplace_or_view_type import VIEW_FUNCTIONS

FUNCTION_DECLARATION = CodeTemplate(
    """\
#ifdef _WIN32
struct ${op} : public ${superclass} {
  TORCH_API ${op}() = default;
#else
struct TORCH_API ${op} : public ${superclass} {
#endif
  using ${superclass}::${superclass};
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "${op}"; }
  void release_variables() override {
    ${thread_lock}
    ${release_variables}
  }
  ${will_release_variables}
  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
  ${saved_variables}
  ${saved_list_sizes}
};
"""
)
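# For illustration only (this file does not emit the snippet below verbatim):
# after substitution, a declaration generated from FUNCTION_DECLARATION for a
# hypothetical op named "MulBackward0" would look roughly like
#
#   struct TORCH_API MulBackward0 : public TraceableFunction {
#     using TraceableFunction::TraceableFunction;
#     variable_list apply(variable_list&& grads) override;
#     std::string name() const override { return "MulBackward0"; }
#     ...
#     SavedVariable self_;
#     SavedVariable other_;
#   };
#
# with ${saved_variables} expanded from the SavedAttribute list that
# process_function (below) collects for the op.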

WILL_RELEASE_VARIABLES = CodeTemplate(
    """\
bool retain_variables = true;
void will_release_variables() override {
  retain_variables = false;
}
"""
)

FUNCTION_DEFINITION = CodeTemplate(
    """\
variable_list ${op}::apply(variable_list&& grads) {
  ${thread_lock}
  ${asserts}
  IndexRangeGenerator gen;
  ${compute_index_ranges}
  variable_list grad_inputs(gen.size());
  ${body}
  return grad_inputs;
}
void ${op}::compiled_args(CompiledNodeArgs& args) {
    ${compiled_args}
}
variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
    ${apply_with_saved_before}
    variable_list result = apply(variable_list(grads));
    ${apply_with_saved_after}
    return result;
}
"""
)

GRAD_INPUT_MASK = CodeTemplate(
    """\
  auto grad_input_mask = std::array<bool, ${n}>{
    ${masks}
  };\
"""
)

DERIVATIVE_SINGLE = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
  auto grad_result = ${derivative};
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)

# note(crcrpar): The `self` argument and other optional positional arguments of
# foreach functions are essentially lists of n `Tensor`s, so we iterate over
# `grads` and apply the existing per-element derivative definitions to each
# `Tensor` of `self` and of the other list arguments.
DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
  std::vector<Tensor> grad_result;
  grad_result.reserve(grads.size());
  for (const auto & i : c10::irange(grads.size())) {
    if (grads[i].defined()) {
      grad_result.emplace_back(${derivative});
    } else {
      grad_result.emplace_back(Tensor());
    }
  }
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)
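# emit_derivative (defined inside process_function below) picks
# DERIVATIVE_SINGLE_FOREACH whenever info.name starts with "_foreach_";
# all other single-variable derivatives use DERIVATIVE_SINGLE.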

DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
    """\
  if (task_should_compute_output({ ${name}_ix })) {
    copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
  }
"""
)

DERIVATIVE_MULTI = CodeTemplate(
    """\
if (task_should_compute_output({ ${idx_ranges} })) {
  ${grad_input_mask}
  auto grad_result = ${derivative};
  ${copy_ranges}
}
"""
)

# Generates Python bindings
#
# This generates the definitions for:
#   (1) The PyTypeObject for each backward grad_fn subclassing Node
#   (2) The entry for PyTypeObject's tp_getset slot (an array of PyGetSetDef structs)
#       We generate one PyGetSetDef struct for each of grad_fn's saved inputs and outputs
#       Each PyGetSetDef holds a function pointer to a getter, also defined here (3).
#   (3) Getters for each of grad_fn's saved inputs and outputs.
#
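# For illustration only: for a hypothetical op "MulBackward0" with a saved
# attribute "other", the templates below would expand to roughly
#
#   PyObject* THPMulBackward0_other_getter(THPCppFunction *self, void *_unused) { ... }
#
#   static struct PyGetSetDef MulBackward0_properties[] = {
#     THP_FUNCTION_DEFAULT_PROPERTIES,
#     {(char*)"_saved_other", (getter)THPMulBackward0_other_getter, nullptr, nullptr, nullptr},
#     {nullptr} /* sentinel */
#   };
#
# so Python code can read node._saved_other on a grad_fn.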
PY_FUNCTION_DEFINITION = CodeTemplate(
    """\
static PyTypeObject ${op}Class;
addClass<${op}>(module, ${op}Class, "${op}", ${op}_properties);
"""
)

PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate(
    """\
${all_getter_definitions}

static struct PyGetSetDef ${op}_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  ${all_getsetdef_structs}
  {nullptr} /* sentinel */
};

"""
)

PY_GETSETDEF_STRUCT = CodeTemplate(
    """\
{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}"""
)

PY_RAW_GETSETDEF_STRUCT = CodeTemplate(
    """\
{(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}"""
)

# Getter templates
GETTER_DEFINITION = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto prop = static_cast<${op}*>(self->cdata.get())->${name};
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_OPT = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.list.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.list.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)

# Getter body
GETTER_BODY_SAVEDVAR = """\
return THPVariable_Wrap(prop.unpack(self->cdata));
"""

GETTER_BODY_RAW_SAVEDVAR = """\
pybind11::object obj = pybind11::cast(prop, pybind11::return_value_policy::reference);
return obj.release().ptr();
"""

GETTER_BODY_VEC_SAVEDVAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i: c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, THPVariable_Wrap(prop[i].unpack(self->cdata)));
}
return tup;
"""

GETTER_BODY_RAW_VEC_SAVEDVAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  pybind11::object obj = pybind11::cast(prop[i], pybind11::return_value_policy::reference);
  PyTuple_SetItem(tup, (Py_ssize_t) i, obj.release().ptr());
}
return tup;
"""

GETTER_BODY_ARRAYREF_LONG = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong((uint64_t) prop[i]));
}
return tup;
"""

GETTER_BODY_ARRAYREF_SYMINT = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
    auto si = prop[i];
    if (auto m = si.maybe_as_int()) {
      PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(*m));
    } else {
      auto py_symint = py::cast(si).release().ptr();
      PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
    }
}
return tup;
"""

GETTER_BODY_ARRAYREF_DOUBLE = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i : c10::irange(prop.size())) {
  PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble((double) prop[i]));
}
return tup;
"""

GETTER_BODY_INT64_T = """\
return PyLong_FromUnsignedLong((int64_t) prop);
"""

GETTER_BODY_SYMINT = """\
if (auto m = prop.maybe_as_int()) {
  return PyLong_FromUnsignedLong(*m);
} else {
  return py::cast(prop).release().ptr();
}
"""

GETTER_BODY_DOUBLE = """\
return PyFloat_FromDouble((double) prop);
"""

GETTER_BODY_BOOL = """\
if (prop) {
  Py_RETURN_TRUE;
} else {
  Py_RETURN_FALSE;
}
"""

GETTER_BODY_STRING = """\
return PyUnicode_FromStringAndSize(prop.data(), prop.size());
"""

GETTER_BODY_SCALAR = """\
if (prop.isComplex()) {
  auto cprop = prop.to<c10::complex<double>>();
  return PyComplex_FromDoubles(cprop.real(), cprop.imag());
} else if (prop.isFloatingPoint()) {
  return PyFloat_FromDouble(prop.to<double>());
} else if (prop.isIntegral(/*includeBool=*/false)) {
  return PyLong_FromLong(prop.to<int64_t>());
} else if (prop.isBoolean()) {
  if (prop.to<bool>()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
} else {
  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
  return nullptr;
}
"""


GETTER_BODY_VEC_SCALAR = """\
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
for (auto i: c10::irange(prop.size())) {
  if (prop[i].isComplex()) {
    auto cprop = prop[i].to<c10::complex<double>>();
    PyTuple_SetItem(tup, (Py_ssize_t) i, PyComplex_FromDoubles(cprop.real(), cprop.imag()));
  } else if (prop[i].isFloatingPoint()) {
    auto double_prop = prop[i].to<double>();
    PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble(double_prop));
  } else if (prop[i].isIntegral(/*includeBool=*/false)) {
    auto long_prop = prop[i].to<int64_t>();
    PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromLong(long_prop));
  } else if (prop[i].isBoolean()) {
    if (prop[i].to<bool>()) {
      PyTuple_SetItem(tup, (Py_ssize_t) i, Py_True);
    } else {
      PyTuple_SetItem(tup, (Py_ssize_t) i, Py_False);
    }
  } else {
    PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
    return nullptr;
  }
}
return tup;
"""


MISC_GETTER_DEFS = {
    OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    OptionalCType(BaseCType(SymIntT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SYMINT),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}
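# MISC_GETTER_DEFS maps the remaining plain-value saved-attribute C++ types to a
# (getter template, getter body) pair; process_function's save_var helper below
# consults it in its fallback branch for types without a dedicated case.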

# These functions have backward formulas that cannot be traced, and so their
# backward functions must be traced opaquely.
# VIEW_FUNCTIONS are not traceable because they use as_strided, which
# has an untraceable backward, see
# https://github.com/pytorch/pytorch/issues/4250
# TODO: This is probably not exhaustive, but it's a start
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
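# process_function uses this set to pick the generated superclass: ops listed
# here derive from Node, everything else from TraceableFunction.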


def get_infos_with_derivatives_list(
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]]
) -> List[DifferentiabilityInfo]:
    diff_info_list = [
        info
        for diffinfo_dict in differentiability_infos.values()
        for info in diffinfo_dict.values()
    ]

    return list(filter(lambda info: info.args_with_derivatives, diff_info_list))


def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function.
    """

    # Get a flat list of diff infos; they do not need to be grouped per FunctionSchema/DispatchKey here.
    # Infos with different dispatch keys but the same name will still end up in the same shard.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos]
    definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos]

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@"
                + f"generated from {fm.template_dir_for_comments()}/"
                + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )
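
# A minimal usage sketch (illustrative; the exact driver and paths may differ):
# the autograd codegen driver parses derivatives.yaml into differentiability
# infos and then calls
#
#   gen_autograd_functions_lib(out_dir, infos, template_path)
#   gen_autograd_functions_python(out_dir, infos, template_path)
#
# where `out_dir` is the generated-code directory and `template_path` points at
# the autograd templates (e.g. tools/autograd/templates). The first call emits
# Functions.h/.cpp; the second (below) emits the sharded Python bindings.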


def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5
    fm.write(
        "python_functions.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/python_functions.h",
            "shard_forward_declare": [
                f"void initialize_autogenerated_functions_{i}(PyObject* module);"
                for i in range(num_shards)
            ],
            "shard_call": [
                f"initialize_autogenerated_functions_{i}(module);"
                for i in range(num_shards)
            ],
        },
    )

    # Get a flat list of diff infos; they do not need to be grouped per FunctionSchema/DispatchKey here.
    # Infos with different dispatch keys but the same name will still end up in the same shard.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    fm.write_sharded(
        "python_functions.cpp",
        infos,
        key_fn=lambda info: info.name,
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/python_functions.cpp",
        },
        env_callable=lambda info: {
            "py_function_initializers": [
                process_function(info, PY_FUNCTION_DEFINITION)
            ],
            "py_function_props_and_getters": [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
            ],
        },
        num_shards=num_shards,
        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
    )


def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
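    """Render one DifferentiabilityInfo through `template`.

    Collects the saved variables, index ranges, Python getters, and derivative
    formulas for the op described by `info`, then substitutes them into
    `template` (one of the CodeTemplates defined above).
    """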
    saved_variables: List[str] = []
    release_variables: List[str] = []
    saved_list_sizes: List[str] = []
    unpack: List[str] = []
    asserts: List[str] = []
    compute_index_ranges: List[str] = []
    getter_definitions: List[str] = []
    py_getsetdef_structs: List[str] = []
    compiled_args: List[str] = []
    apply_with_saved_before: List[str] = []
    apply_with_saved_after: List[str] = []

    for arg in info.args_with_derivatives:
        if arg.type in TENSOR_LIST_LIKE_CTYPES:
            size = f"{arg.name}_size_"
            saved_list_sizes.append(f"size_t {arg.name}_size_;")
        else:
            size = "1"
        compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")

    def save_var(var: SavedAttribute, is_output: bool) -> None:
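        """Emit code for one saved attribute.

        For `var`, this appends the C++ member declaration, the release and
        unpack code, the Python getter definition(s), and the compiled-autograd
        bookkeeping (collect/before/after on `visit_name`).
        """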
        name = var.nctype.name
        type = var.nctype.type
        should_append_getsetdef = True
        should_append_raw_getsetdef = False
        visit_name = name

        if (
            type == BaseCType(tensorT)
            or type == OptionalCType(BaseCType(tensorT))
            or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
            or (type == BaseCType(scalarT) and is_output)
        ):
            saved_variables.append(f"SavedVariable {name}_;")
            release_variables.append(f"{name}_.reset_data();")
            ptr = "shared_from_this()" if is_output else ""
            unpack.append(f"auto {name} = {name}_.unpack({ptr});")
            getter_definitions.append(
                GETTER_DEFINITION_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
        elif (
            type == BaseCType(tensorListT)
            or type == BaseCType(iTensorListRefT)
            or type == VectorCType(BaseCType(tensorT))
        ):
            # note(crcrpar): [nuanced return type of out-of-place foreach functions]
            # When an out-of-place foreach function whose return signature is `Tensor[]`
            # spells out its backward definitions in `derivatives.yaml` and some of them
            # depend on `result`, the type of `result` is interpreted and treated as
            # `std::vector<Tensor>`. An out-of-place foreach whose backward relies on its
            # output does not run into this difference when the definitions are codegen'ed.
            # This special case is needed for `_foreach_pow.List` and `_foreach_pow.ScalarAndTensor`
            # as of https://github.com/pytorch/pytorch/pull/105504.
            if type == VectorCType(BaseCType(tensorT)):
                assert (
                    info.func.func.name.name.base.startswith("_foreach") and is_output
                )
            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
            saved_variables.append(f"bool {name}_released_ = false;")
            # Just clear() is sufficient, we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
            release_variables.append(f"{name}_.clear();")
            release_variables.append(f"{name}_released_ = true;")
            ptr = "shared_from_this()" if is_output else "nullptr"
            unpack.append(f"auto {name} = unpack_list({name}_, {ptr});")
            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
        elif type == ListCType(OptionalCType(BaseCType(tensorT))):
            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
            saved_variables.append(f"bool {name}_released_ = false;")
            # Just clear() is sufficient, we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
            release_variables.append(f"{name}_.clear();")
            release_variables.append(f"{name}_released_ = true;")
            unpack.append(f"auto {name} = unpack_opt_list({name}_);")
            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
            visit_name = f"{name}_"
        elif type == BaseCType(intArrayRefT):
            saved_variables.append(f"std::vector<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == BaseCType(symIntArrayRefT):
            saved_variables.append(f"std::vector<c10::SymInt> {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
                )
            )
        elif type == BaseCType(optionalIntArrayRefT):
            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == BaseCType(optionalSymIntArrayRefT):
            saved_variables.append(f"c10::OptionalArray<c10::SymInt> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
                )
            )
        elif type == OptionalCType(BaseCType(intArrayRefT)):
            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == OptionalCType(BaseCType(symIntArrayRefT)):
            saved_variables.append(f"c10::OptionalArray<c10::SymInt> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
                )
            )
        elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
            saved_variables.append(f"c10::OptionalArray<double> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE
                )
            )
        elif type == BaseCType(longT):
            saved_variables.append(f"{type.cpp_type()} {name} = 0;")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_INT64_T
                )
            )
        elif type == BaseCType(SymIntT):
            saved_variables.append(f"c10::SymInt {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_SYMINT
                )
            )
        elif type == BaseCType(stringT):
            saved_variables.append(f"std::string {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING
                )
            )
        elif type == OptionalCType(BaseCType(stringT)):
            saved_variables.append(f"c10::optional<std::string> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING
                )
            )
        elif type == ArrayRefCType(
            elem=BaseCType(type=BaseCppType(ns="at", name="Scalar"))
        ):
            saved_variables.append(f"std::vector<at::Scalar> {name};")
            saved_variables.append(f"bool {name}_released_ = false;")
            # Just clear() is sufficient, we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
            release_variables.append(f"{name}.clear();")
            # release_variables.append(f"{name}_released_ = true;")
            # unpack.append(f"auto {name} = unpack_list({name}_);")
            # asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            getter_definitions.append(
                CodeTemplate(
                    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name};
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
                            """
                ).substitute(
                    op=info.op,
                    name=name,
                    body=GETTER_BODY_VEC_SCALAR,
                )
            )
        else:
            # Check for indicators that you're putting a non-owning reference
            # into the saved variable field.  If this is spuriously firing,
            # edit this field.  Otherwise, you probably need to add a case
            # above.
            assert (
                "ref" not in type.cpp_type().lower()
                and "view" not in type.cpp_type().lower()
                and "*" not in type.cpp_type()
                and "&" not in type.cpp_type()
            ), f"{type.cpp_type()} looks like it contains a non-owning reference"
            saved_variables.append(f"{type.cpp_type()} {name};")

            if type in MISC_GETTER_DEFS:
                getter_def, body = MISC_GETTER_DEFS[type]
                getter_definitions.append(
                    getter_def.substitute(op=info.op, name=name, body=body)
                )
            else:
                # Types we don't expose python bindings to yet:
                #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
                #   std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
                should_append_getsetdef = False

        if should_append_getsetdef:
            py_getsetdef_structs.append(
                PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
            )
        if should_append_raw_getsetdef:
            py_getsetdef_structs.append(
                PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
            )

        compiled_args.append(f"args.collect({visit_name});")
        apply_with_saved_before.append(f"saved.before({visit_name});")
        apply_with_saved_after.append(f"saved.after({visit_name});")

    for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)):
        save_var(var, is_output=False)
    for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)):
        save_var(var, is_output=True)

    # lock the mutex when we release variables and in Node::apply to protect thread safety
    # see Note [Thread Safety on Autograd Node]
    if len(release_variables) > 0:
        thread_lock = "std::lock_guard<std::mutex> lock(mutex_);"
    else:
        thread_lock = ""

    if uses_retain_variables(info):
        will_release_variables = WILL_RELEASE_VARIABLES.substitute()
    else:
        will_release_variables = ""

    body: List[str] = []

    if uses_single_grad(info):
        body.append("const auto& grad = grads[0];")
    else:
        # Generate aliases for gradients named for returned values.
        body.extend(
            f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];"
            for name in sorted(info.used_named_gradients)
        )

    def emit_derivative(
        derivative: Derivative,
        args_with_derivatives: Sequence[Binding],
    ) -> Tuple[bool, str]:
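        """Turn one Derivative into the C++ snippet that computes it.

        Returns (checks_any_grad_defined, code): the flag tells the caller the
        formula was wrapped in an `any_grad_defined ?` guard, so the
        `bool any_grad_defined = ...` declaration must be emitted once before
        all derivative blocks.
        """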
        formula = derivative.formula
        var_names = derivative.var_names
        if len(var_names) == 1:
            checks_any_grad_defined = False
            if "not_implemented" not in formula:
                matching_args = [
                    arg for arg in args_with_derivatives if arg.name == var_names[0]
                ]
                if len(matching_args) == 1:
                    # We can add undefined grad support if the input variable is a Tensor
                    arg = matching_args[0]
                    if isinstance(arg.argument, Argument) and str(
                        arg.argument.type
                    ) in ("Tensor", "Tensor?"):
                        formula = "any_grad_defined ? (" + formula + ") : Tensor()"
                        checks_any_grad_defined = True
            if info.name.startswith("_foreach_"):
                derivative_template = DERIVATIVE_SINGLE_FOREACH
            else:
                derivative_template = DERIVATIVE_SINGLE
            return (
                checks_any_grad_defined,
                derivative_template.substitute(name=var_names[0], derivative=formula),
            )
        else:
            if "grad_input_mask" in formula:
                masks = [
                    f"task_should_compute_output({{ {n}_ix }})," for n in var_names
                ]
                grad_input_mask = GRAD_INPUT_MASK.substitute(
                    masks=masks, n=len(var_names)
                )
            else:
                grad_input_mask = ""
            idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
            copy_ranges: List[str] = []
            for i, n in enumerate(var_names):
                copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
            return False, DERIVATIVE_MULTI.substitute(
                idx_ranges=idx_ranges,
                copy_ranges=copy_ranges,
                derivative=formula,
                grad_input_mask=grad_input_mask,
            )

    body.extend(unpack)
    need_any_grad_defined_var = False
    for derivative in info.derivatives:
        checks_any_grad_defined, derivative_text = emit_derivative(
            derivative, info.args_with_derivatives
        )
        body.append(derivative_text)
        need_any_grad_defined_var |= checks_any_grad_defined
    # Since single-output derivative formulas need to check if grads are
    # defined, only perform the check once, before all the formulas
    if need_any_grad_defined_var:
        body.insert(
            -len(info.derivatives),
            "bool any_grad_defined = any_variable_defined(grads);",
        )

    if info.name in UNTRACEABLE_FUNCTIONS:
        superclass = "Node"
    else:
        superclass = "TraceableFunction"

    all_getsetdef_structs = (
        ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else ""
    )
    all_getter_definitions = "\n".join(getter_definitions)

    return template.substitute(
        op=info.op,
        compute_index_ranges=compute_index_ranges,
        saved_variables=saved_variables,
        release_variables=release_variables,
        saved_list_sizes=saved_list_sizes,
        asserts=asserts,
        thread_lock=thread_lock,
        will_release_variables=will_release_variables,
        body=body,
        superclass=superclass,
        all_getter_definitions=all_getter_definitions,
        all_getsetdef_structs=all_getsetdef_structs,
        compiled_args=compiled_args,
        apply_with_saved_before=apply_with_saved_before,
        apply_with_saved_after=apply_with_saved_after,
    )