from typing import Dict, List, Sequence, Tuple

from torchgen.api.autograd import (
    Derivative,
    DifferentiabilityInfo,
    SavedAttribute,
    uses_retain_variables,
    uses_single_grad,
)
from torchgen.api.types import (
    ArrayRefCType,
    BaseCppType,
    BaseCType,
    Binding,
    boolT,
    doubleT,
    intArrayRefT,
    iTensorListRefT,
    ListCType,
    longT,
    MutRefCType,
    OptionalCType,
    optionalIntArrayRefT,
    optionalSymIntArrayRefT,
    scalarT,
    stringT,
    symIntArrayRefT,
    SymIntT,
    TENSOR_LIST_LIKE_CTYPES,
    tensorListT,
    tensorT,
    VectorCType,
)
from torchgen.code_template import CodeTemplate
from torchgen.model import Argument, FunctionSchema
from torchgen.utils import FileManager

from .gen_inplace_or_view_type import VIEW_FUNCTIONS
# C++ declaration emitted into Functions.h for each generated autograd Node
# subclass.  The two struct headers are the MSVC / non-MSVC dllexport split.
# NOTE(review): this literal was reconstructed from a corrupted extraction
# (interleaved line numbers, dropped lines); the ${thread_lock},
# ${release_variables}, ${saved_variables} and ${saved_list_sizes} slots are
# grounded in the substitute() call in process_function below -- verify exact
# whitespace against upstream before regenerating committed output.
FUNCTION_DECLARATION = CodeTemplate(
    """\
#ifdef _WIN32
struct ${op} : public ${superclass} {
  TORCH_API ${op}() = default;
#else
struct TORCH_API ${op} : public ${superclass} {
#endif
  using ${superclass}::${superclass};
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "${op}"; }
  void release_variables() override {
    ${thread_lock}
    ${release_variables}
  }
  ${will_release_variables}
  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
  ${saved_variables}
  ${saved_list_sizes}
};
"""
)
# Optional member block spliced into FUNCTION_DECLARATION when the derivative
# formula needs retain_graph tracking (see uses_retain_variables below).
# Defect fixed: the extraction dropped the string delimiters; every code line
# here is visible in the fragment.
WILL_RELEASE_VARIABLES = CodeTemplate(
    """\
bool retain_variables = true;
void will_release_variables() override {
  retain_variables = false;
}
"""
)
# C++ definitions emitted into Functions.cpp: apply(), compiled_args() and
# apply_with_saved() for one Node subclass.  NOTE(review): ${thread_lock},
# ${asserts} and ${body} plus the return statements were dropped by the
# extraction and are reconstructed; each slot is grounded in a kwarg of the
# substitute() call in process_function -- confirm against upstream.
FUNCTION_DEFINITION = CodeTemplate(
    """\
variable_list ${op}::apply(variable_list&& grads) {
  ${thread_lock}
  ${asserts}
  IndexRangeGenerator gen;
  ${compute_index_ranges}
  variable_list grad_inputs(gen.size());
  ${body}
  return grad_inputs;
}
void ${op}::compiled_args(CompiledNodeArgs& args) {
    ${compiled_args}
}
variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
    ${apply_with_saved_before}
    variable_list result = apply(variable_list(grads));
    ${apply_with_saved_after}
    return result;
}
"""
)
# Boolean mask telling a multi-output derivative formula which of its outputs
# actually need computing; filled by the emit_derivative multi-var path
# (substitute(masks=..., n=len(var_names))).
GRAD_INPUT_MASK = CodeTemplate(
    """\
  auto grad_input_mask = std::array<bool, ${n}>{
    ${masks}
  };\
"""
)
# Gradient computation for a formula with a single differentiable input.
DERIVATIVE_SINGLE = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
  auto grad_result = ${derivative};
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)
# Per-element variant of DERIVATIVE_SINGLE used for _foreach_* ops: the
# formula is evaluated once per incoming grad, with undefined grads mapped to
# an empty Tensor so output positions stay aligned.
DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
    """\
if (task_should_compute_output({ ${name}_ix })) {
  std::vector<Tensor> grad_result;
  grad_result.reserve(grads.size());
  for (const auto & i : c10::irange(grads.size())) {
    if (grads[i].defined()) {
      grad_result.emplace_back(${derivative});
    } else {
      grad_result.emplace_back(Tensor());
    }
  }
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
"""
)
# Copies one element of the tuple produced by a multi-output derivative
# formula into its slot of grad_inputs (expanded once per output variable).
DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
    """\
  if (task_should_compute_output({ ${name}_ix })) {
    copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
  }
"""
)
# Gradient computation for a formula returning several inputs' grads at once.
# NOTE(review): ${grad_input_mask} and ${copy_ranges} slots reconstructed;
# both are grounded in the DERIVATIVE_MULTI.substitute(...) call visible
# below (idx_ranges=, copy_ranges=, grad_input_mask=).
DERIVATIVE_MULTI = CodeTemplate(
    """\
if (task_should_compute_output({ ${idx_ranges} })) {
  ${grad_input_mask}
  auto grad_result = ${derivative};
  ${copy_ranges}
}
"""
)
# Registration snippet binding one generated Node subclass into the Python
# module (python_functions.cpp initializer shards).
PY_FUNCTION_DEFINITION = CodeTemplate(
    """\
static PyTypeObject ${op}Class;
addClass<${op}>(module, ${op}Class, "${op}", ${op}_properties);
"""
)
# Per-op getter definitions plus the PyGetSetDef property table consumed by
# addClass<> above; ${all_getsetdef_structs} already carries its own trailing
# comma (see all_getsetdef_structs assembly in process_function).
PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate(
    """\
${all_getter_definitions}

static struct PyGetSetDef ${op}_properties[] = {
  THP_FUNCTION_DEFAULT_PROPERTIES,
  ${all_getsetdef_structs}
  {nullptr} /* sentinel */
};

"""
)
# One PyGetSetDef entry exposing a packed saved variable as `_saved_<name>`.
# Deliberately no trailing newline: entries are joined with ",\n" later.
PY_GETSETDEF_STRUCT = CodeTemplate(
    """\
{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}"""
)
# One PyGetSetDef entry exposing the raw (unpacked-by-caller) SavedVariable
# as `_raw_saved_<name>`.  No trailing newline; entries are ",\n"-joined.
PY_RAW_GETSETDEF_STRUCT = CodeTemplate(
    """\
{(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}"""
)
# Getter for an attribute stored by value on the Node.  NOTE(review): the
# HANDLE_TH_ERRORS/END_HANDLE_TH_ERRORS lines and closing brace were dropped
# by the extraction and are reconstructed -- confirm against upstream.
GETTER_DEFINITION = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto prop = static_cast<${op}*>(self->cdata.get())->${name};
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)
# Getter for a single SavedVariable member (`<name>_`); ${body} unpacks it.
# NOTE(review): error-handling macro lines reconstructed -- confirm upstream.
GETTER_DEFINITION_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)
# Raw getter for a single SavedVariable member: same access as the SAVEDVAR
# getter, but ${body} returns the SavedVariable itself without unpacking.
GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)
# Getter for a std::vector<SavedVariable> member; raises if the variables
# were already released by a prior backward (`<name>_released_` flag).
# NOTE(review): `return nullptr;` after PyErr_SetString and the macro lines
# were dropped by the extraction and are reconstructed.
GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)
# Raw variant of the vector-of-SavedVariable getter (no unpacking in ${body});
# same released-variables guard as GETTER_DEFINITION_VEC_SAVEDVAR.
GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate(
    """\
PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  const auto *node = static_cast<${op}*>(self->cdata.get());
  const auto& prop = node->${name}_;
  if (node->${name}_released_) {
    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
    return nullptr;
  }
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)
# Getter for an optional-valued attribute: returns Python None when unset,
# otherwise unwraps and hands the value to ${body}.
GETTER_DEFINITION_OPT = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)
# Getter for a c10::OptionalArray attribute: the optional lives in `.list`
# (unlike GETTER_DEFINITION_OPT which tests the attribute itself).
GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate(
    """\
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
  HANDLE_TH_ERRORS
  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
  if (!opt_prop.list.has_value()) {
    Py_RETURN_NONE;
  }
  auto prop = opt_prop.list.value();
  ${body}
  END_HANDLE_TH_ERRORS
}
"""
)
GETTER_BODY_SAVEDVAR = """\
295
return THPVariable_Wrap(prop.unpack(self->cdata));
298
GETTER_BODY_RAW_SAVEDVAR = """\
299
pybind11::object obj = pybind11::cast(prop, pybind11::return_value_policy::reference);
300
return obj.release().ptr();
303
GETTER_BODY_VEC_SAVEDVAR = """\
304
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
305
for (auto i: c10::irange(prop.size())) {
306
PyTuple_SetItem(tup, (Py_ssize_t) i, THPVariable_Wrap(prop[i].unpack(self->cdata)));
311
GETTER_BODY_RAW_VEC_SAVEDVAR = """\
312
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
313
for (auto i : c10::irange(prop.size())) {
314
pybind11::object obj = pybind11::cast(prop[i], pybind11::return_value_policy::reference);
315
PyTuple_SetItem(tup, (Py_ssize_t) i, obj.release().ptr());
320
GETTER_BODY_ARRAYREF_LONG = """\
321
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
322
for (auto i : c10::irange(prop.size())) {
323
PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong((uint64_t) prop[i]));
328
GETTER_BODY_ARRAYREF_SYMINT = """\
329
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
330
for (auto i : c10::irange(prop.size())) {
332
if (auto m = si.maybe_as_int()) {
333
PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(*m));
335
auto py_symint = py::cast(si).release().ptr();
336
PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
342
GETTER_BODY_ARRAYREF_DOUBLE = """\
343
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
344
for (auto i : c10::irange(prop.size())) {
345
PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble((double) prop[i]));
350
GETTER_BODY_INT64_T = """\
351
return PyLong_FromUnsignedLong((int64_t) prop);
354
GETTER_BODY_SYMINT = """\
355
if (auto m = prop.maybe_as_int()) {
356
return PyLong_FromUnsignedLong(*m);
358
return py::cast(prop).release().ptr();
362
GETTER_BODY_DOUBLE = """\
363
return PyFloat_FromDouble((double) prop);
366
GETTER_BODY_BOOL = """\
374
GETTER_BODY_STRING = """\
375
return PyUnicode_FromStringAndSize(prop.data(), prop.size());
378
GETTER_BODY_SCALAR = """\
379
if (prop.isComplex()) {
380
auto cprop = prop.to<c10::complex<double>>();
381
return PyComplex_FromDoubles(cprop.real(), cprop.imag());
382
} else if (prop.isFloatingPoint()) {
383
return PyFloat_FromDouble(prop.to<double>());
384
} else if (prop.isIntegral(/*includeBool=*/false)) {
385
return PyLong_FromLong(prop.to<int64_t>());
386
} else if (prop.isBoolean()) {
387
if (prop.to<bool>()) {
393
PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
399
GETTER_BODY_VEC_SCALAR = """\
400
PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
401
for (auto i: c10::irange(prop.size())) {
402
if (prop[i].isComplex()) {
403
auto cprop = prop[i].to<c10::complex<double>>();
404
PyTuple_SetItem(tup, (Py_ssize_t) i, PyComplex_FromDoubles(cprop.real(), cprop.imag()));
405
} else if (prop[i].isFloatingPoint()) {
406
auto double_prop = prop[i].to<double>();
407
PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble(double_prop));
408
} else if (prop[i].isIntegral(/*includeBool=*/false)) {
409
auto long_prop = prop[i].to<int64_t>();
410
PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromLong(long_prop));
411
} else if (prop[i].isBoolean()) {
412
if (prop[i].to<bool>()) {
413
PyTuple_SetItem(tup, (Py_ssize_t) i, Py_True);
415
PyTuple_SetItem(tup, (Py_ssize_t) i, Py_False);
418
PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
427
# Dispatch table used by save_var(): saved-attribute C++ type -> (getter
# definition template, getter body) for the generically-handled value types.
# NOTE(review): the `MISC_GETTER_DEFS = {` opener and the plain
# BaseCType(longT) entry were dropped by the extraction; reconstructed,
# grounded by the `if type in MISC_GETTER_DEFS` lookup below -- verify the
# longT entry against upstream.
MISC_GETTER_DEFS = {
    BaseCType(longT): (GETTER_DEFINITION, GETTER_BODY_INT64_T),
    OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
    OptionalCType(BaseCType(SymIntT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SYMINT),
    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
}
# Ops whose generated autograd Node must not be marked traceable.
# process_function() tests membership here when choosing the C++ superclass:
# names NOT in this set get "TraceableFunction"; members presumably get the
# plain "Node" base (the true-branch assignment is not visible in this
# chunk -- TODO confirm).  Currently an alias for the view-function set.
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
445
def get_infos_with_derivatives_list(
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]]
) -> List[DifferentiabilityInfo]:
    """Flatten the per-schema {dispatch key -> info} dicts into one list,
    keeping only infos that actually have args with derivatives.

    NOTE(review): the list-comprehension opener/closer lines were dropped by
    the extraction and are reconstructed around the two visible `for`
    clauses.
    """
    diff_info_list = [
        info
        for diffinfo_dict in differentiability_infos.values()
        for info in diffinfo_dict.values()
    ]

    # Infos with no args_with_derivatives generate no Node subclass.
    return list(filter(lambda info: info.args_with_derivatives, diff_info_list))
457
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for each every differentiable torch function.
    """
    # NOTE(review): the `out`/`template_path` parameters and the write loop
    # were reconstructed from the visible fragments (FileManager(
    # install_dir=out, template_dir=template_path, ...)) -- confirm the
    # parameter order against callers before relying on it.
    infos = get_infos_with_derivatives_list(differentiability_infos)
    declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos]
    definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos]

    file_basename = "Functions"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@"
                + f"generated from {fm.template_dir_for_comments()}/"
                + fname,
                "autograd_function_declarations": declarations,
                "autograd_function_definitions": definitions,
            },
        )
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    template_path: str,
) -> None:
    """Generate python_functions.h / python_functions.cpp (sharded) exposing
    the generated Nodes and their saved-attribute getters to Python.

    NOTE(review): parameter order, `num_shards = 5`, the fm.write call shape
    and the "shard_call" key were reconstructed around the visible fragments
    -- confirm against upstream before regenerating committed output.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5
    fm.write(
        "python_functions.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/python_functions.h",
            "shard_forward_declare": [
                f"void initialize_autogenerated_functions_{i}(PyObject* module);"
                for i in range(num_shards)
            ],
            "shard_call": [
                f"initialize_autogenerated_functions_{i}(module);"
                for i in range(num_shards)
            ],
        },
    )

    infos = get_infos_with_derivatives_list(differentiability_infos)
    fm.write_sharded(
        "python_functions.cpp",
        infos,
        key_fn=lambda info: info.name,
        base_env={
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/python_functions.cpp",
        },
        env_callable=lambda info: {
            "py_function_initializers": [
                process_function(info, PY_FUNCTION_DEFINITION)
            ],
            "py_function_props_and_getters": [
                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
            ],
        },
        num_shards=num_shards,
        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
    )
538
def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
539
saved_variables: List[str] = []
540
release_variables: List[str] = []
541
saved_list_sizes: List[str] = []
542
unpack: List[str] = []
543
asserts: List[str] = []
544
compute_index_ranges: List[str] = []
545
getter_definitions: List[str] = []
546
py_getsetdef_structs: List[str] = []
547
compiled_args: List[str] = []
548
apply_with_saved_before: List[str] = []
549
apply_with_saved_after: List[str] = []
551
for arg in info.args_with_derivatives:
552
if arg.type in TENSOR_LIST_LIKE_CTYPES:
553
size = f"{arg.name}_size_"
554
saved_list_sizes.append(f"size_t {arg.name}_size_;")
557
compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")
559
def save_var(var: SavedAttribute, is_output: bool) -> None:
560
name = var.nctype.name
561
type = var.nctype.type
562
should_append_getsetdef = True
563
should_append_raw_getsetdef = False
567
type == BaseCType(tensorT)
568
or type == OptionalCType(BaseCType(tensorT))
569
or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
570
or (type == BaseCType(scalarT) and is_output)
572
saved_variables.append(f"SavedVariable {name}_;")
573
release_variables.append(f"{name}_.reset_data();")
574
ptr = "shared_from_this()" if is_output else ""
575
unpack.append(f"auto {name} = {name}_.unpack({ptr});")
576
getter_definitions.append(
577
GETTER_DEFINITION_SAVEDVAR.substitute(
578
op=info.op, name=name, body=GETTER_BODY_SAVEDVAR
581
getter_definitions.append(
582
GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
583
op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR
586
should_append_raw_getsetdef = True
587
visit_name = f"{name}_"
589
type == BaseCType(tensorListT)
590
or type == BaseCType(iTensorListRefT)
591
or type == VectorCType(BaseCType(tensorT))
601
if type == VectorCType(BaseCType(tensorT)):
603
info.func.func.name.name.base.startswith("_foreach") and is_output
605
saved_variables.append(f"std::vector<SavedVariable> {name}_;")
606
saved_variables.append(f"bool {name}_released_ = false;")
609
release_variables.append(f"{name}_.clear();")
610
release_variables.append(f"{name}_released_ = true;")
611
ptr = "shared_from_this()" if is_output else "nullptr"
612
unpack.append(f"auto {name} = unpack_list({name}_, {ptr});")
613
asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
614
getter_definitions.append(
615
GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
616
op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
619
getter_definitions.append(
620
GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
621
op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
624
should_append_raw_getsetdef = True
625
visit_name = f"{name}_"
626
elif type == ListCType(OptionalCType(BaseCType(tensorT))):
627
saved_variables.append(f"std::vector<SavedVariable> {name}_;")
628
saved_variables.append(f"bool {name}_released_ = false;")
631
release_variables.append(f"{name}_.clear();")
632
release_variables.append(f"{name}_released_ = true;")
633
unpack.append(f"auto {name} = unpack_opt_list({name}_);")
634
asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
635
getter_definitions.append(
636
GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
637
op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
640
getter_definitions.append(
641
GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
642
op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
645
should_append_raw_getsetdef = True
646
visit_name = f"{name}_"
647
elif type == BaseCType(intArrayRefT):
648
saved_variables.append(f"std::vector<int64_t> {name};")
649
getter_definitions.append(
650
GETTER_DEFINITION.substitute(
651
op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
654
elif type == BaseCType(symIntArrayRefT):
655
saved_variables.append(f"std::vector<c10::SymInt> {name};")
656
getter_definitions.append(
657
GETTER_DEFINITION.substitute(
658
op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
661
elif type == BaseCType(optionalIntArrayRefT):
662
saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
663
getter_definitions.append(
664
GETTER_DEFINITION_OPT_ARRAYREF.substitute(
665
op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
668
elif type == BaseCType(optionalSymIntArrayRefT):
669
saved_variables.append(f"c10::OptionalArray<c10::SymInt> {name};")
670
getter_definitions.append(
671
GETTER_DEFINITION_OPT_ARRAYREF.substitute(
672
op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
675
elif type == OptionalCType(BaseCType(intArrayRefT)):
676
saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
677
getter_definitions.append(
678
GETTER_DEFINITION_OPT_ARRAYREF.substitute(
679
op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
682
elif type == OptionalCType(BaseCType(symIntArrayRefT)):
683
saved_variables.append(f"c10::OptionalArray<c10::SymInt> {name};")
684
getter_definitions.append(
685
GETTER_DEFINITION_OPT_ARRAYREF.substitute(
686
op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
689
elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
690
saved_variables.append(f"c10::OptionalArray<double> {name};")
691
getter_definitions.append(
692
GETTER_DEFINITION_OPT_ARRAYREF.substitute(
693
op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE
696
elif type == BaseCType(longT):
697
saved_variables.append(f"{type.cpp_type()} {name} = 0;")
698
getter_definitions.append(
699
GETTER_DEFINITION.substitute(
700
op=info.op, name=name, body=GETTER_BODY_INT64_T
703
elif type == BaseCType(SymIntT):
704
saved_variables.append(f"c10::SymInt {name};")
705
getter_definitions.append(
706
GETTER_DEFINITION.substitute(
707
op=info.op, name=name, body=GETTER_BODY_SYMINT
710
elif type == BaseCType(stringT):
711
saved_variables.append(f"std::string {name};")
712
getter_definitions.append(
713
GETTER_DEFINITION.substitute(
714
op=info.op, name=name, body=GETTER_BODY_STRING
717
elif type == OptionalCType(BaseCType(stringT)):
718
saved_variables.append(f"c10::optional<std::string> {name};")
719
getter_definitions.append(
720
GETTER_DEFINITION_OPT.substitute(
721
op=info.op, name=name, body=GETTER_BODY_STRING
724
elif type == ArrayRefCType(
725
elem=BaseCType(type=BaseCppType(ns="at", name="Scalar"))
727
saved_variables.append(f"std::vector<at::Scalar> {name};")
728
saved_variables.append(f"bool {name}_released_ = false;")
731
release_variables.append(f"{name}.clear();")
735
getter_definitions.append(
738
PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
740
const auto *node = static_cast<${op}*>(self->cdata.get());
741
const auto& prop = node->${name};
742
if (node->${name}_released_) {
743
PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
753
body=GETTER_BODY_VEC_SCALAR,
762
"ref" not in type.cpp_type().lower()
763
and "view" not in type.cpp_type().lower()
764
and "*" not in type.cpp_type()
765
and "&" not in type.cpp_type()
766
), f"{type.cpp_type()} looks like it contains a non-owning reference"
767
saved_variables.append(f"{type.cpp_type()} {name};")
769
if type in MISC_GETTER_DEFS:
770
getter_def, body = MISC_GETTER_DEFS[type]
771
getter_definitions.append(
772
getter_def.substitute(op=info.op, name=name, body=body)
778
should_append_getsetdef = False
780
if should_append_getsetdef:
781
py_getsetdef_structs.append(
782
PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
784
if should_append_raw_getsetdef:
785
py_getsetdef_structs.append(
786
PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
789
compiled_args.append(f"args.collect({visit_name});")
790
apply_with_saved_before.append(f"saved.before({visit_name});")
791
apply_with_saved_after.append(f"saved.after({visit_name});")
793
for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)):
794
save_var(var, is_output=False)
795
for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)):
796
save_var(var, is_output=True)
800
if len(release_variables) > 0:
801
thread_lock = "std::lock_guard<std::mutex> lock(mutex_);"
805
if uses_retain_variables(info):
806
will_release_variables = WILL_RELEASE_VARIABLES.substitute()
808
will_release_variables = ""
812
if uses_single_grad(info):
813
body.append("const auto& grad = grads[0];")
817
f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];"
818
for name in sorted(info.used_named_gradients)
822
derivative: Derivative,
823
args_with_derivatives: Sequence[Binding],
824
) -> Tuple[bool, str]:
825
formula = derivative.formula
826
var_names = derivative.var_names
827
if len(var_names) == 1:
828
checks_any_grad_defined = False
829
if "not_implemented" not in formula:
831
arg for arg in args_with_derivatives if arg.name == var_names[0]
833
if len(matching_args) == 1:
835
arg = matching_args[0]
836
if isinstance(arg.argument, Argument) and str(
838
) in ("Tensor", "Tensor?"):
839
formula = "any_grad_defined ? (" + formula + ") : Tensor()"
840
checks_any_grad_defined = True
841
if info.name.startswith("_foreach_"):
842
derivative_template = DERIVATIVE_SINGLE_FOREACH
844
derivative_template = DERIVATIVE_SINGLE
846
checks_any_grad_defined,
847
derivative_template.substitute(name=var_names[0], derivative=formula),
850
if "grad_input_mask" in formula:
852
f"task_should_compute_output({{ {n}_ix }})," for n in var_names
854
grad_input_mask = GRAD_INPUT_MASK.substitute(
855
masks=masks, n=len(var_names)
859
idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
860
copy_ranges: List[str] = []
861
for i, n in enumerate(var_names):
862
copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
863
return False, DERIVATIVE_MULTI.substitute(
864
idx_ranges=idx_ranges,
865
copy_ranges=copy_ranges,
867
grad_input_mask=grad_input_mask,
871
need_any_grad_defined_var = False
872
for derivative in info.derivatives:
873
checks_any_grad_defined, derivative_text = emit_derivative(
874
derivative, info.args_with_derivatives
876
body.append(derivative_text)
877
need_any_grad_defined_var |= checks_any_grad_defined
880
if need_any_grad_defined_var:
882
-len(info.derivatives),
883
"bool any_grad_defined = any_variable_defined(grads);",
886
if info.name in UNTRACEABLE_FUNCTIONS:
889
superclass = "TraceableFunction"
891
all_getsetdef_structs = (
892
",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else ""
894
all_getter_definitions = "\n".join(getter_definitions)
896
return template.substitute(
898
compute_index_ranges=compute_index_ranges,
899
saved_variables=saved_variables,
900
release_variables=release_variables,
901
saved_list_sizes=saved_list_sizes,
903
thread_lock=thread_lock,
904
will_release_variables=will_release_variables,
906
superclass=superclass,
907
all_getter_definitions=all_getter_definitions,
908
all_getsetdef_structs=all_getsetdef_structs,
909
compiled_args=compiled_args,
910
apply_with_saved_before=apply_with_saved_before,
911
apply_with_saved_after=apply_with_saved_after,