from __future__ import annotations

import re
from dataclasses import dataclass
from typing import cast, Sequence

from torchgen import local
from torchgen.api import cpp
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
from torchgen.model import (
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsViewGroup,
    SchemaKind,
    Type,
)
from torchgen.utils import IDENT_REGEX


# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.
@dataclass(frozen=True)
class SavedAttribute:
    # The NamedCType holds the updated name and cpp type of the attribute.
    # For the name, a suffix is appended if it's a derived property, e.g.: `other_scalar_type`
    nctype: NamedCType

    # The expression to read the derived property at save time, e.g.:
    # `other.scalar_type()`.
    expr: str
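
# Purely illustrative sketch (not generated code; assumes torchgen.api.types
# also provides a `scalarTypeT` C++ type): saving `other.scalar_type()` instead
# of the whole `other` tensor would roughly correspond to
#   SavedAttribute(
#       nctype=NamedCType(name="other_scalar_type", type=BaseCType(scalarTypeT)),
#       expr="other.scalar_type()",
#   )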


# Represents a backward formula that calculates derivatives for one
# or more tensors.
@dataclass(frozen=True)
class Derivative:
    # The formula string (legit C++ expression).
    # Note that expressions against input arguments have been replaced with the
    # corresponding saved attributes.
    # E.g.:
    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
    #         here: `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # The formula string before input argument replacement
    original_formula: str

    # Names of the arguments for which this formula calculates derivatives.
    var_names: tuple[str, ...]

    # Saved inputs that are referenced by the formula.
    saved_inputs: tuple[SavedAttribute, ...]

    # Saved outputs that are referenced by the formula.
    saved_outputs: tuple[SavedAttribute, ...]

    # Gradients that are referenced by name in the formula.
    named_gradients: set[str]


# Represents a forward formula that calculates forward derivatives
# for one tensor.
@dataclass(frozen=True)
class ForwardDerivative:
    # The formula string (legit C++ expression).
    # Note that special keywords such as "linear" or "element_wise" have been
    # replaced by the automatically generated formula.
    formula: str

    # Name of the output arguments for which this formula calculates forward
    # derivatives
    var_names: tuple[str, ...]

    # Type of the output arguments for which this formula calculates forward
    # derivatives
    var_types: tuple[Type, ...]

    # Inputs for which the forward derivatives are required for this formula
    required_inputs_fw_grad: tuple[str, ...] | None

    # Inputs for which the primal is required for this formula
    required_inputs_primal: tuple[str, ...] | None

    # Flag to specify if this formula requires the original value of self
    # This is only used by inplace operations
    required_original_self_value: bool

    # If this formula is specified in derivatives.yaml or if we are re-using the
    # out of place formula for inplace
    is_reusing_outplace_formula: bool


# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunction having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function;
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # out-of-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: str | None

    # The derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable inputs
    derivatives: Sequence[Derivative]

    # The forward derivatives formulae for this function.
    # Note that the length of this sequence is the number of differentiable outputs
    forward_derivatives: Sequence[ForwardDerivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # All named gradients that are available for use, in the same
    # order as in the grads vector.
    available_named_gradients: Sequence[str]

    # The named gradients that are used in any of the derivatives.
    # Invariant: all(name in available_named_gradients for name in used_named_gradients)
    used_named_gradients: set[str]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: list[bool] | None

    # output_differentiability in derivatives.yaml can be a list of
    # conditions that express if the output is differentiable. In this case,
    # the number of conditions must match the number of outputs
    # (NB: we only support one condition right now).
    # output_differentiability gets populated with True for each condition,
    # while output_differentiability_conditions gets populated with the conditions
    output_differentiability_conditions: list[str] | None
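
    # Purely illustrative: a derivatives.yaml entry declaring
    #   output_differentiability: [some_cpp_condition]
    # (where `some_cpp_condition` stands for an arbitrary C++ boolean expression)
    # is loaded as output_differentiability == [True] and
    # output_differentiability_conditions == ["some_cpp_condition"].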

    @property
    def has_derivatives(self) -> bool:
        return len(self.args_with_derivatives) > 0

    # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
    # but with a new operator name.
    # This is used when generating "copy" variants of view ops,
    # which are able to use the exact same derivative formula as the original view op
    # See Note [Codegen'd {view}_copy Operators]
    def create_view_copy_from_view_derivative(
        self, g: NativeFunctionsViewGroup
    ) -> DifferentiabilityInfo | None:
        if g.view_copy is None:
            return None

        name_split_by_period = self.name.split(".", maxsplit=2)
        # Append a "_copy" to the base name of the operator (but keep the overload name the same)
        view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
            name_split_by_period[1:]
        )
        view_copy_op_name = None if self.op is None else f"{self.op}_copy"

        return DifferentiabilityInfo(
            # Use the "_copy" version of name/func/op
            name=view_copy_name,
            func=g.view_copy,
            op=view_copy_op_name,
            # But keep all derivative info the same
            derivatives=self.derivatives,
            forward_derivatives=self.forward_derivatives,
            all_saved_inputs=self.all_saved_inputs,
            all_saved_outputs=self.all_saved_outputs,
            available_named_gradients=self.available_named_gradients,
            used_named_gradients=self.used_named_gradients,
            args_with_derivatives=self.args_with_derivatives,
            non_differentiable_arg_names=self.non_differentiable_arg_names,
            output_differentiability=self.output_differentiability,
            output_differentiability_conditions=self.output_differentiability_conditions,
        )
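
    # Illustrative example (hypothetical names): for a view entry named
    # "transpose.int" with op "TransposeBackward0", the generated copy variant
    # gets name "transpose_copy.int" and op "TransposeBackward0_copy".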


def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
    if info is None:
        return False
    for derivative in info.derivatives:
        formula = derivative.formula
        if re.search(IDENT_REGEX.format(ident), formula):
            return True
    return False


def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
    return uses_ident(info, "retain_variables")


def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
    return uses_ident(info, "grad")
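

# NB: IDENT_REGEX only matches `ident` as a standalone identifier, so e.g.
# uses_single_grad() is not triggered by formulas that merely mention
# `grads` or `grad_output`.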


# Represents a differentiable `Argument`.
# How is it different from the `Argument` type?
# - It's processed Arguments which are differentiable and only used in the
#   context of the autograd codegen;
# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
@dataclass(frozen=True)
class DifferentiableInput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


# Represents a differentiable `Return`.
# How is it different from the `Return` type?
# - The name in `Return` is optional. Here it is always populated using the same
#   `cpp.return_names()` method.
#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
# - It's processed Returns which are differentiable, in compliance with the
#   `output_differentiability` field defined in derivatives.yaml (if specified),
#   and are only used in the context of the autograd codegen;
@dataclass(frozen=True)
class DifferentiableOutput:
    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str


@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    func: NativeFunction
    info: dict[str, DifferentiabilityInfo] | None
    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None


# TODO: Update comment below since it is out of date.
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
    """How are we going to call the underlying implementation of a
    declaration? There are two strategies:
        - use_derived: we want to call the implementation on CPUDoubleType
          (or a similar, derived Type instance). Because these derived
          instances deal in Tensors, not Variables (it's a completely different
          object, so it doesn't dispatch back to VariableType), code on
          this dispatch path needs to wrap/unwrap tensors. If the
          derived implementation takes and returns tensors, the
          implementation is usually differentiable (although we also use
          the derived dispatch path for non-differentiable functions
          that we still want to dispatch on the derived Type instance;
          e.g., size())
        - use_type: we want to call the implementation on Type, because
          it is implemented concretely, and the functions it invokes will
          get dispatched back to VariableType (which will ensure that they
          are differentiable.)
    """
    # fn is derived as long as any of its per-key differentiability infos
    # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
    # and ADInplaceOrViewType. We want to generate these functions as long as a
    # derivative is defined for ANY dispatch key.
    if fn.func.is_abstract or (
        fn.info is not None and any(info.has_derivatives for info in fn.info.values())
    ):
        # If the function is abstract (not implemented on at::Type), we must
        # call the implementation on the derived type with unpacked tensors.

        # If the function has a derivative specified and is concrete, we could
        # call either implementation. We prefer calling the derived
        # type's implementation with unpacked tensors because it is more
        # performant in some cases: any internal calls to other ATen functions
        # won't have the history tracked.

        # If the function has a type dispatched argument (i.e. is a factory),
        # we prefer calling the derived type's implementation both because it is
        # more performant and to ensure factory functions return tensors with _version
        # of 0 (probably not strictly necessary, but nice to have to keep versions simple
        # to understand).
        return "use_derived"
    else:
        # If the function is concrete (we don't have to override it) and we
        # didn't declare it in derivatives.yaml, we'll assume that it is
        # actually implemented out of differentiable functions. (This
        # assumption might not hold, but then you'll see gradcheck fail.)
        return "use_type"


def is_foreach_func(f: NativeFunction) -> bool:
    return f.func.name.name.base.startswith("_foreach_")


# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# such a reference can be found in `functional_info_by_signature`. There are, however, some exceptions:
_foreach_with_inplace_ref = {"_foreach_zero_"}
_foreach_with_tensor_overload = {
    "_foreach_add.Tensor",
    "_foreach_mul.Tensor",
    "_foreach_div.Tensor",
}
# The following do not support the alpha kwarg, which the nonforeach versions support.
_skip_argument_len_check = {
    "_foreach_add.Scalar",
    "_foreach_add_.Scalar",
    "_foreach_add.ScalarList",
    "_foreach_add_.ScalarList",
    "_foreach_sub.Scalar",
    "_foreach_sub_.Scalar",
    "_foreach_sub.ScalarList",
    "_foreach_sub_.ScalarList",
}


# Checks if `function_schema` is a native, non-foreach function that `f`, a foreach function,
# references to generate derivatives.
def is_reference_for_foreach(
    f: NativeFunction,
    function_schema: FunctionSchema,
) -> bool:
    return (
        f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
        and (
            not function_schema.name.name.inplace
            or str(f.func.name) in _foreach_with_inplace_ref
        )
        and (
            str(f.func.name) in _skip_argument_len_check
            or len(f.func.arguments.flat_non_out)
            == len(function_schema.arguments.flat_non_out)
        )
        and all(
            ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
            for arg, ref_arg in zip(
                f.func.arguments.flat_non_out,
                function_schema.arguments.flat_non_out,
            )
        )
    )
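

# Illustrative example (hypothetical pairing, following the checks above):
# `_foreach_mul.List(Tensor[] self, Tensor[] other)` can reference
# `mul.Tensor(Tensor self, Tensor other)`: the base names match after stripping
# "_foreach_", the reference is not in-place, the argument counts agree, and each
# Tensor argument matches the element type of the corresponding TensorList argument.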


# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
    foreach_function: NativeFunction,
    functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    non_functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    dispatch_key: str = "Default",
) -> tuple[DifferentiabilityInfo | None, bool]:
    """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.

    The second return value indicates whether the info is generated in this function.
    """
    ref_diff_info: DifferentiabilityInfo | None = None

    for function_schema, diff_info in functional_info_by_signature.items():
        if not is_reference_for_foreach(foreach_function, function_schema):
            continue
        ref_diff_info = diff_info[dispatch_key]
        if ref_diff_info is not None:
            break
    # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
    # while the info of `zero_` is in non_functional_info_by_signature
    if (
        ref_diff_info is None
        and foreach_function.func.kind() == SchemaKind.inplace
        and str(foreach_function.func.name) in _foreach_with_inplace_ref
    ):
        for function_schema, diff_info in non_functional_info_by_signature.items():
            if not is_reference_for_foreach(foreach_function, function_schema):
                continue
            ref_diff_info = diff_info[dispatch_key]
            if ref_diff_info is not None:
                break
    if ref_diff_info is None:
        return None, False

    # The in-place variant simply reuses the existing Derivative.
    if foreach_function.func.kind() == SchemaKind.inplace:
        return ref_diff_info, False

    map_refarg2foreacharg, map_name2arg = {}, {}
    for i, (arg, ref_arg) in enumerate(
        zip(
            foreach_function.func.arguments.flat_non_out,
            function_schema.arguments.flat_non_out,
        )
    ):
        map_refarg2foreacharg[ref_arg.name] = arg.name
        map_name2arg[arg.name] = arg

    all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
    modified_derivative_formulas = []
    for i, derivative in enumerate(ref_diff_info.derivatives):
        modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
            "result", "result[i]"
        )
        saved_inputs, saved_outputs = [], []
        # note(crcrpar): This context seems necessary to call `cpp.argument_type`
        with local.parametrize(
            use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
            use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
        ):
            for ref_input in derivative.saved_inputs:
                ref_input_jit_name = ref_input.expr.split(".")[0]
                mapped_name = map_refarg2foreacharg[ref_input_jit_name]
                if isinstance(map_name2arg[mapped_name].type, ListType):
                    mapped_expr = mapped_name + "[i]"
                else:
                    mapped_expr = mapped_name
                new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
                modified_formula = modified_formula.replace(
                    cast(str, ref_input.nctype.name), new_expr
                )

                nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
                canonical_nctype = NamedCType(
                    nctype.name, nctype.type.remove_const_ref()
                )
                saved_inputs.append(
                    SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
                )
            for ref_output in derivative.saved_outputs:
                if ref_output.nctype.name == "result":
                    saved_outputs.append(
                        SavedAttribute(
                            nctype=NamedCType(
                                name="result", type=BaseCType(tensorListT)
                            ),
                            expr="result",
                        )
                    )
                else:
                    raise RuntimeError("")
        var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
        all_var_names.extend(var_names)
        all_saved_inputs.extend(saved_inputs)
        all_saved_outputs.extend(saved_outputs)
        modified_derivative = Derivative(
            formula=modified_formula,
            original_formula=derivative.formula,
            var_names=tuple(var_names),
            saved_inputs=tuple(saved_inputs),
            saved_outputs=tuple(saved_outputs),
            named_gradients=set(),
        )
        modified_derivative_formulas.append(modified_derivative)
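
    # Illustrative example (hypothetical formula): a reference backward formula
    # `grad * other.conj()` whose saved input `other` maps to a TensorList argument
    # is rewritten above to `grads[i] * other[i].conj()` for the foreach function.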

    with local.parametrize(
        use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
    ):
        args_with_derivatives = [
            Binding(
                name=arg.name,
                nctype=cpp.argument_type(arg, binds=arg.name),
                argument=arg,
                default=None,
            )
            for arg in foreach_function.func.arguments.flat_non_out
            if arg.name in all_var_names
        ]

    forward_derivatives: list[ForwardDerivative] = []
    fw_derivative: ForwardDerivative
    for fw_derivative in ref_diff_info.forward_derivatives:
        var_names: list[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
        var_types: list[Type] = list(fw_derivative.var_types)
        required_inputs_fw_grad: list[str] = []
        required_inputs_primal: list[str] = []
        if fw_derivative.required_inputs_fw_grad is not None:
            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
        if fw_derivative.required_inputs_primal:
            required_inputs_primal = list(fw_derivative.required_inputs_primal)
        modified_formula = fw_derivative.formula

        # Foreach's result is TensorList
        if "result" in modified_formula:
            modified_formula = fw_derivative.formula.replace("result", "result[i]")

        for foreach_arg, ref_arg in zip(
            foreach_function.func.arguments.flat_non_out,
            ref_diff_info.func.func.arguments.flat_non_out,
        ):
            # Modify reference forward formula
            if (
                isinstance(foreach_arg.type, ListType)
                and not foreach_arg.type.is_tensor_like()
            ):
                # Assuming ScalarList
                modified_formula = modified_formula.replace(
                    ref_arg.name, foreach_arg.name + "[i]"
                )
            elif foreach_arg.type.is_tensor_like():
                # Assuming TensorList / Tensor
                # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
                assert isinstance(foreach_arg.type, ListType) or (
                    foreach_arg.type == BaseType(BaseTy.Tensor)
                    and str(foreach_function.func.name) in _foreach_with_tensor_overload
                ), f"{foreach_function.func.name}, {foreach_arg.type}"
                for suffix in ("_p", "_t"):
                    curr_expr = ref_arg.name + suffix
                    if curr_expr in modified_formula:
                        new_expr = foreach_arg.name + suffix
                        modified_formula = modified_formula.replace(curr_expr, new_expr)
            else:
                # Assuming Scalar
                if foreach_arg.name != ref_arg.name:
                    modified_formula = modified_formula.replace(
                        ref_arg.name, foreach_arg.name
                    )

            # note(crcrpar): there should exist a cooler way...
            for i, name in enumerate(var_names):
                if name == ref_arg.name:
                    var_names[i] = foreach_arg.name
                    var_types[i] = foreach_arg.type
            for i, name in enumerate(required_inputs_fw_grad):
                if name == ref_arg.name:
                    required_inputs_fw_grad[i] = foreach_arg.name
            for i, name in enumerate(required_inputs_primal):
                if name == ref_arg.name:
                    required_inputs_primal[i] = foreach_arg.name
        forward_derivatives.append(
            ForwardDerivative(
                formula=modified_formula,
                var_names=tuple(var_names),
                var_types=tuple(var_types),
                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
                required_inputs_primal=tuple(required_inputs_primal),
                required_original_self_value=fw_derivative.required_original_self_value,
                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
            )
        )

    return (
        DifferentiabilityInfo(
            name=foreach_function.func.name.name.base,
            func=foreach_function,
            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
            derivatives=modified_derivative_formulas,
            forward_derivatives=forward_derivatives,
            all_saved_inputs=tuple(set(all_saved_inputs)),
            all_saved_outputs=tuple(set(all_saved_outputs)),
            available_named_gradients=(),
            used_named_gradients=set(),
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=[],
            output_differentiability=None,
            output_differentiability_conditions=None,
        ),
        True,
    )
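

# Illustrative example (hypothetical names): if the reference info's op is
# "SqrtBackward0" and the foreach overload name is empty, the generated op above
# is "ForeachSqrtBackward0".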


def match_differentiability_info(
    native_functions: list[NativeFunction],
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
) -> list[NativeFunctionWithDifferentiabilityInfo]:
    """Sets the "derivative" key on declarations to matching autograd function.

    In-place functions will use the out-of-place derivative definition if there
    is no in-place specific derivative.
    """

    functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() == SchemaKind.functional
    }
    non_functional_info_by_signature = {
        schema.signature(strip_default=True): info_dict
        for schema, info_dict in differentiability_infos.items()
        if schema.kind() != SchemaKind.functional
    }

    def find_info(
        f: NativeFunction,
    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
        # Don't bother matching info to generated out= variants
        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
            return None, False

        # (1) Check for an exact match
        if f.func in differentiability_infos:
            return differentiability_infos[f.func], True

        # (2) If no exact match, check if the out-of-place variant
        # of this operator has a match.
        # i.e. mul() for mul_() or mul_out()
        # note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing
        # native functions instead of the out-place counterparts.
        f_sig = f.func.signature(strip_default=True)
        if f_sig in functional_info_by_signature and not is_foreach_func(f):
            return functional_info_by_signature[f_sig], False

        # (3) Some operators have a derivative explicitly defined for the mutable
        # variant, but get a code-generated out-of-place variant which does *not*
        # come with a derivative formula.
        # For the generated out-of-place variant, use the mutable variant's formula
        # if it exists.
        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
            info_dict = non_functional_info_by_signature[f_sig]
            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
            assert not any(
                any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
                for info in info_dict.values()
            ), f"""\
Attempted to convert a derivative formula for a mutable operator
to be used automatically by its functional variant ("{str(f.func)}").
This is not currently supported (we'd need to fix up the formula in the codegen)."""
            return info_dict, False

        # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
        if is_foreach_func(f):
            assert f.func not in differentiability_infos
            diff_info, is_generated = gen_foreach_derivativeinfo(
                f,
                functional_info_by_signature,
                non_functional_info_by_signature,
            )
            if diff_info is None:
                return None, False
            # TODO(crcrpar): Avoid hard coding "Default" ideally.
            diff_info_dict = {"Default": diff_info}
            if is_generated:
                differentiability_infos[f.func] = diff_info_dict
                functional_info_by_signature[f.func] = diff_info_dict
            return diff_info_dict, is_generated

        return None, False

    result: list[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
        info_dict, is_exact_match = find_info(f)

        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
        # 'self' derivatives of an inplace function, so we must check for this case.
        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
            for info in info_dict.values():
                for derivative in info.derivatives:
                    if "self" in derivative.var_names:
                        for saved_input in derivative.saved_inputs:
                            assert "strides_or_error" not in saved_input.expr, (
                                "Calling '.strides()' in the 'self' derivative formula of an "
                                f"in-place function is not supported: {f.func}"
                            )

        if info_dict is None:
            result.append(
                NativeFunctionWithDifferentiabilityInfo(
                    func=f, info=None, fw_derivatives=None
                )
            )
            continue

        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
        for key, info in info_dict.items():
            if not info.forward_derivatives:
                fw_derivative_dict[key] = []
                continue

            forward_derivatives = info.forward_derivatives

            # For functions that have a single def for out-of-place and inplace (like abs())
            if f.func.kind() == SchemaKind.inplace:
                # For inplace functions there is a little bit of work to do:
                #  1) Validate the formula and make sure the input that is modified in place is not used:
                #    - If there is a formula for the inplace variant of the function (is_exact_match == True) then
                #      we make sure that the original value of the input that is being modified inplace (self_p) is
                #      not used in the formula. Note that the formula can use "original_self_p" here and that would
                #      trigger a clone of the original input.
                #    - If we are re-using the out of place formula (is_exact_match == False) then we replace every
                #      occurrence of self_p and self_t by original_self_p and original_self_t. These will be
                #      populated by cloned version of the original input (either the clone done by the backward AD
                #      logic if self is also used in a backward formula or a special clone that we add).
                #  2) At this point, there cannot be a self_p in the formula.
                #  3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
                #     simply called self (as it is modified inplace).
                #  4) Update the required primals data in case it used to contain "result" but should now contain
                #     "self_p".
                #  5) If it is not an exact match, the user formula is not modifying the existing forward grad
                #     inplace as it should. So add some code that makes sure that we do so if the forward grad
                #     already exists.

                assert (
                    len(info.forward_derivatives) == 1
                )  # Only single output inplace should exist
                fw_info = info.forward_derivatives[0]
                formula = fw_info.formula

                def replace_self_with_original_self(formula: str, postfix: str) -> str:
                    def repl(m: re.Match[str]) -> str:
                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"

                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)

                if re.search(IDENT_REGEX.format("self_p"), formula):
                    if is_exact_match:
                        # For manually defined formulas, don't allow the original value to be used
                        raise RuntimeError(
                            f'The formula for "{f.func.name}" is using the original value of self '
                            "that is being modified inplace. This would lead to wrong forward gradients. "
                            'Please use "result" in the formula only.'
                        )
                    else:
                        # When the original formula is out of place, we save a clone of the primal
                        # value to be able to access this value if needed
                        # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
                        formula = replace_self_with_original_self(formula, "_p")
                        formula = replace_self_with_original_self(formula, "_t")

                # replace "result" from the formula by "self_p"
                def repl(m: re.Match[str]) -> str:
                    return f"{m.group(1)}self_p{m.group(2)}"

                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)

                required_primals = fw_info.required_inputs_primal
                if re.search(IDENT_REGEX.format("self_p"), formula):
                    required_primals = (
                        required_primals + ("self",) if required_primals else ("self",)
                    )

                if not is_exact_match:
                    # NOTE [In-place forward AD formula Optimization]
                    #
                    # This optimization transforms the formula to directly do inplace, i.e.
                    # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
                    #
                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
                    # 2) "op" in (1) needs to be the same as the op the derivative is for
                    #
                    # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
                    # If there is a need, we can relax (2) to allow any op that has an in-place variant
                    is_single_method_on_self_t = False
                    directly_do_inplace = False
                    op_name: str | None = None
                    between_parens: str | None = None
                    match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
                    if match:
                        op_name, between_parens = match.group(1), match.group(2)

                        # Match: self_t.op1(other_p.op2(arg))
                        # Avoid: self_t.op1(args) + self_t.op2(args)
                        # Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
                        def check_parens_nest_level_gt_zero(s: str) -> bool:
                            level = 1
                            for ch in s:
                                if ch == ")":
                                    level -= 1
                                    if level == 0:
                                        return False
                                elif ch == "(":
                                    level += 1
                            return True

                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
                            between_parens
                        )
                    directly_do_inplace = (
                        is_single_method_on_self_t and op_name == info.name
                    )

                    if directly_do_inplace:
                        assert op_name is not None
                        assert between_parens is not None
                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
                    else:
                        # Make sure that the forward grad is modified inplace when the original formula
                        # is out of place
                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"

                required_original_self_value = bool(
                    re.search(IDENT_REGEX.format("original_self_p"), formula)
                ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))

                forward_derivatives = [
                    ForwardDerivative(
                        formula=formula,
                        var_names=("self",),
                        var_types=fw_info.var_types,
                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
                        required_inputs_primal=required_primals,
                        required_original_self_value=required_original_self_value,
                        is_reusing_outplace_formula=not is_exact_match,
                    )
                ]

            fw_derivative_dict[key] = forward_derivatives

        result.append(
            NativeFunctionWithDifferentiabilityInfo(
                func=f, info=info_dict, fw_derivatives=fw_derivative_dict
            )
        )

    return result


def is_differentiable(
    name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
    return type.is_tensor_like() and (
        info is None or name not in info.non_differentiable_arg_names
    )


def gen_differentiable_outputs(
    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> list[DifferentiableOutput]:
    f = fn.func
    info = fn.info[key] if fn.info else None
    outputs: list[DifferentiableOutput] = [
        DifferentiableOutput(
            name=name,
            type=ret.type,
            cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
        )
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        if len(output_differentiability) != len(outputs):
            raise RuntimeError(
                f"The length of output_differentiability ({len(output_differentiability)}), "
                f"does not match the number of outputs ({len(outputs)})."
            )
        differentiable_outputs: list[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation (version_counter won't get updated)"
            )
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs

    candidate_differentiable_outputs = list(
        filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
    )
    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs
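

# Illustrative example (hypothetical entry): with output_differentiability: [True, False]
# in derivatives.yaml, gen_differentiable_outputs returns only the first output;
# with no output_differentiability, every tensor-like output not listed as
# non_differentiable is a candidate (truncated to one if the formula only uses "grad").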