# Parses derivatives.yaml into autograd functions
#
# Each autograd function is represented by `DifferentiabilityInfo` containing
# a list of `Derivative`. See `torchgen.api.autograd` for the data models.
import re
from collections import defaultdict
from typing import Any, Counter, Dict, List, Match, Optional, Sequence, Set, Tuple

import yaml

from torchgen.api import cpp
from torchgen.api.autograd import (
    Derivative,
    DifferentiabilityInfo,
    ForwardDerivative,
    SavedAttribute,
)
from torchgen.api.types import (
    BaseCType,
    Binding,
    boolT,
    CppSignatureGroup,
    layoutT,
    longT,
    NamedCType,
    OptionalCType,
    scalarTypeT,
    SpecialArgName,
    stringT,
    symIntArrayRefT,
    SymIntT,
    tensorGeometryT,
    tensorOptionsT,
    typeAndSizeT,
    VectorCType,
)
from torchgen.context import with_native_function
from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
from torchgen.model import (
    AUTOGRAD_KEYS,
    FunctionSchema,
    NativeFunction,
    NativeFunctionsViewGroup,
    OperatorName,
    SchemaKind,
    Type,
    Variant,
)
from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
from torchgen.yaml_utils import YamlLoader

DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]

_GLOBAL_LOAD_DERIVATIVE_CACHE: Dict[Tuple[str, str], DerivativeRet] = {}

_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)


# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op.
# Since every {view} and {view}_copy op shares the same derivative formula,
# we generate them here instead of duplicating them in the yaml.
# See Note [Codegen'd {view}_copy Operators]
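# For example (illustrative, not from the original file): `slice.Tensor` is a view op
# with a generated `slice_copy.Tensor` counterpart, so the derivative entry written
# for `slice` in derivatives.yaml is reused here to register one for `slice_copy`.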
def add_view_copy_derivatives(
    infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    view_groups: List[NativeFunctionsViewGroup],
) -> None:
    # Get the map from each view op's name to its corresponding view group
    view_name_to_group: Dict[OperatorName, NativeFunctionsViewGroup] = {
        g.view.func.name: g for g in view_groups
    }

    view_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] = {}
    for info_dispatch_dict in infos.values():
        # maybe_view_group only needs to be calculated once per info_dispatch_dict
        maybe_view_group = None
        view_copy_differentiability_infos = {}
        for dispatch_key, info in info_dispatch_dict.items():
            maybe_view_group = view_name_to_group.get(info.func.func.name, None)
            if maybe_view_group is not None and maybe_view_group.view_copy is not None:
                view_copy_info = info.create_view_copy_from_view_derivative(
                    maybe_view_group
                )
                if view_copy_info is not None:
                    fn_schema = view_copy_info.func.func
                    view_copy_differentiability_infos[dispatch_key] = view_copy_info
        # prefer manually-defined derivatives if any
        if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
            assert fn_schema is not None
            view_infos[fn_schema] = view_copy_differentiability_infos

    infos.update(view_infos)


def load_derivatives(
    derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str
) -> DerivativeRet:
    # Do some caching as this is a deterministic function
    global _GLOBAL_LOAD_DERIVATIVE_CACHE
    key = (derivatives_yaml_path, native_yaml_path)
    if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:
        with open(derivatives_yaml_path) as f:
            definitions = yaml.load(f, Loader=YamlLoader)

        funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
        # From the parsed native functions, separate out the (generated) view_copy functions,
        # so we can generate derivatives for them separately.
        native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
        native_functions = concatMap(
            lambda g: [g]
            if isinstance(g, NativeFunction)
            else list(g.functions(include_copy=True)),
            native_functions_with_view_groups,
        )
        view_groups = [
            g
            for g in native_functions_with_view_groups
            if isinstance(g, NativeFunctionsViewGroup)
        ]

        # What's the difference between a function schema and a signature?
        # A function schema is the complete declaration, including mutability annotations, default values, etc.
        # A signature is the canonical schema for a group of functions (in-place/out/functional variants)
        # that are semantically related.
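        # Illustrative example (not part of the original comment; schemas paraphrased):
        #   schema:    "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"
        #   signature: "add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor"
        # so add.Tensor, add_.Tensor and add.out all map to the same signature bucket.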
        functions_by_signature: Dict[
            FunctionSchema, List[NativeFunction]
        ] = defaultdict(list)
        functions_by_schema: Dict[str, NativeFunction] = {}
        for function in native_functions:
            functions_by_signature[function.func.signature()].append(function)
            assert str(function.func) not in functions_by_schema
            functions_by_schema[str(function.func)] = function

        # Keep track of how many of which ops we've seen so we can
        # disambiguate them with a numeric suffix.
        op_counter = Counter[str]()
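        # e.g. (illustrative): two entries that both produce the op prefix "AddBackward"
        # end up with the op names "AddBackward0" and "AddBackward1" (see how op_counter
        # is consumed in create_differentiability_info below).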

        # infos is a dict that maps FunctionSchema -> a dict of per-dispatch-key DifferentiabilityInfos.
        # This is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
        # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema.
        infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] = {}
        used_dispatch_keys: Set[str] = set()
        for defn_dict in definitions:
            # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
            if "dispatch" not in defn_dict:
                specification = defn_dict.pop("name")
                output_differentiability = defn_dict.pop(
                    "output_differentiability", None
                )
                defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}}
                if output_differentiability:
                    defn_dict["output_differentiability"] = output_differentiability
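                # Illustrative example (not from the original file): an old-style entry
                #   {"name": "foo(Tensor self) -> Tensor", "self": "grad"}
                # is rewritten above to
                #   {"name": "foo(Tensor self) -> Tensor", "dispatch": {"Default": {"self": "grad"}}}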
            name, per_dispatch_diffinfos = create_differentiability_info(
                defn_dict,
                functions_by_signature,
                functions_by_schema,
                op_counter,
                used_dispatch_keys,
            )
            infos[name] = per_dispatch_diffinfos

        add_view_copy_derivatives(infos, view_groups)

        # cache both the loaded infos as well as a set of all the dispatch_keys/aliases
        # that appear in derivatives.yaml. used_dispatch_keys is useful for generating
        # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used
        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos, used_dispatch_keys

    return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]


# TODO: Why is this going through CppSignatureGroup, that doesn't make sense...
def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
    sigs = CppSignatureGroup.from_native_function(f, method=False)
    if sigs.symint_signature is not None:
        return sigs.symint_signature.arguments()
    else:
        return sigs.signature.arguments()


def create_derivative(
    f: NativeFunction,
    formula: str,
    var_names: Tuple[str, ...],
    available_named_gradients: Sequence[str],
) -> Derivative:
    original_formula = formula
    arguments: List[NamedCType] = [
        a.nctype.remove_const_ref() for a in cpp_arguments(f)
    ]

    return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f))
    return_types = tuple(
        cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns
    )

    named_returns = [
        NamedCType(name, type) for name, type in zip(return_names, return_types)
    ]

    formula, saved_inputs = saved_variables(formula, arguments, var_names)
    formula, saved_outputs = saved_variables(formula, named_returns, var_names)

    used_named_gradients = {
        name
        for name in available_named_gradients
        if re.search(IDENT_REGEX.format(name), formula)
    }

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(formula):
        if i >= len(f.func.returns):
            raise RuntimeError(
                f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
                f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs."
            )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=used_named_gradients,
    )


def create_forward_derivative(
    f: NativeFunction, formula: str, names: Tuple[str, ...]
) -> ForwardDerivative:
    var_names = names
    var_types: Optional[Tuple[Type, ...]] = None
    for r in f.func.returns:
        if r.name in var_names:
            if var_types is None:
                var_types = tuple()
            var_types = var_types + (r.type,)

    # Handle default return names
    if var_types is None:
        if var_names == ("result",):
            assert len(f.func.returns) == 1
            var_types = (f.func.returns[0].type,)
        else:
            for var_name in var_names:
                res = re.findall(r"^result(\d+)$", var_name)
                if len(res) == 1:
                    if var_types is None:
                        var_types = tuple()
                    arg_idx = int(res[0])
                    var_types = var_types + (f.func.returns[arg_idx].type,)

    assert var_types is not None, "No matching output for forward derivative definition"
    return ForwardDerivative(
        formula=formula,
        var_names=var_names,
        var_types=var_types,
        required_inputs_fw_grad=None,
        required_inputs_primal=None,
        required_original_self_value=False,
        is_reusing_outplace_formula=False,
    )


def postprocess_forward_derivatives(
    f: NativeFunction,
    defn_name: str,
    all_arg_names: List[str],
    derivatives: List[Derivative],
    forward_derivatives: List[ForwardDerivative],
    args_with_derivatives: Sequence[Binding],
) -> List[ForwardDerivative]:
    def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
        is_foreach = f.func.name.name.base.startswith("_foreach_")
        required_inputs = set()
        for arg in args_with_derivatives:
            if (
                arg.type in ("at::TensorList", "const at::ITensorListRef &")
                and not is_foreach
            ):
                # The functions taking TensorList handle everything internally
                continue
            arg_name = arg.name

            found = re.search(IDENT_REGEX.format(arg_name), formula)
            if found:
                raise RuntimeError(
                    f"The forward formula for {defn_name} is using the base name of the {arg_name} "
                    f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
                    f"value and {arg_name}_t to access the tangent."
                )

            found = re.search(IDENT_REGEX.format(arg_name + postfix), formula)
            if found:
                required_inputs.add(arg_name)

        return tuple(required_inputs)
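    # Illustrative example (not from the original file): with "self" and "other" in
    # args_with_derivatives, find_required_inputs("other_t * self_p", "_t") returns
    # ("other",) and find_required_inputs("other_t * self_p", "_p") returns ("self",).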

    updated_derivatives: List[ForwardDerivative] = []

    for defn in forward_derivatives:
        formula = defn.formula
        required_inputs_tangent = find_required_inputs(formula, "_t")
        if formula == "auto_element_wise":
            assert (
                f.func.kind() != SchemaKind.inplace
            ), f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant"
            if (
                (not len(args_with_derivatives) == 1)
                or len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but this only "
                    "works for functions with a single differentiable input and a "
                    "single differentiable output."
                )
            if not len(derivatives) == 1:
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but it does not "
                    "define the gradient formula for its argument, which is required."
                )
            # This transformation is based on the observation that for element-wise functions, the Jacobian
            # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions)
            # For the complex case, we use hermitian transpose and get (v.conj() J).conj()
            # So here we are going to re-use the backward formula and replace three things:
            # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input.
            # 2) all usage of an original input "foo" with its primal value "foo_p".
            # 3) conjugate the final result
            # For example, for abs, the backward formula is:
            #   grad * self.sgn()
            # And this function generates a forward formula that is:
            #   (self_t.conj() * self_p.sgn()).conj()

            backward_formula = derivatives[0].original_formula
            input_name = args_with_derivatives[0].name

            # Do replacement 1) of the grad
            def repl(m: Any) -> str:
                return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}"

            fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula)

            # Do replacement 2) of the input variables
            for arg in args_with_derivatives:
                arg_name = arg.name

                def repl(m: Any) -> str:
                    return f"{m.group(1)}{arg_name}_p{m.group(2)}"

                fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula)

            # Do the final conjugate 3)
            fw_formula = f"({fw_formula}).conj()"

            # Since there is a single differentiable input and we necessarily need its
            # tangent, we can simply require the tangents of all differentiable inputs.
            required_inputs_tangent = tuple(all_arg_names)
            formula = fw_formula
        elif formula == "auto_linear":
            if (
                len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as linear but this only works "
                    "for functions with a single differentiable output."
                )
            # This transformation is based on the observation that linear functions can be written as:
            #   y = f(x) = A * x
            # for some matrix A, and the Jacobian of the function f is also A.
            # So doing J * v = A * v = f(v).
            # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x.
            # We do this by calling the forward again, replacing any occurrence of the differentiable
            # input "foo" by its tangent "foo_t".
            # Note that multiple inputs are not a problem as long as the function is truly linear wrt
            # the vector where all the differentiable inputs are stacked.
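            # Illustrative example (not from the original file): for an entry such as
            #   cat(Tensor[] tensors, int dim=0) -> Tensor   with   result: auto_linear
            # the generated forward formula would be roughly "at::cat(tensors_t, dim)".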
385
diff_arg_names = [arg.name for arg in args_with_derivatives]
386
assert len(diff_arg_names) > 0
388
# Do replacement of input variables
390
for arg_name in all_arg_names:
391
if arg_name in diff_arg_names:
392
arg_name = arg_name + "_t"
393
new_args.append(arg_name)
395
# TODO we are trolling
396
if f.func.has_symint():
397
defn_name += "_symint"
399
# Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
400
if Variant.function in f.variants:
401
fw_formula = f"at::{defn_name}({', '.join(new_args)})"
403
assert Variant.method in f.variants
404
fw_formula = f"{new_args[0]}.{defn_name}({', '.join(new_args[1:])})"
406
# All of the input tangents are always used so all of them are required here.
407
required_inputs_tangent = tuple(diff_arg_names)
410
# At this point, the formula is final and is not modified anymore.
412
# During forward formula, we use the primal instead of the input Tensors.
413
# This call inspects the formula to find for which input's primal are used.
414
required_inputs_primal = find_required_inputs(formula, "_p")
416
updated_derivatives.append(
419
var_names=defn.var_names,
420
var_types=defn.var_types,
421
required_inputs_fw_grad=required_inputs_tangent,
422
required_inputs_primal=required_inputs_primal,
423
required_original_self_value=False,
424
is_reusing_outplace_formula=False,
428
return updated_derivatives


def is_forward_derivative_definition(
    all_arg_names: List[str], names: Tuple[str, ...]
) -> bool:
    for name in names:
        if name not in all_arg_names:
            return True
        else:
            return False
    raise RuntimeError("Expected `names` to be non-empty")


def create_differentiability_info(
    defn_dict: Dict[Any, Any],
    functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
    functions_by_schema: Dict[str, NativeFunction],
    op_counter: Counter[str],
    used_dispatch_keys: Set[str],
) -> Tuple[FunctionSchema, Dict[str, DifferentiabilityInfo]]:
    """Processes a single entry `defn` in derivatives.yaml"""

    def canonical_function(
        functions: Sequence[NativeFunction], name: str
    ) -> NativeFunction:
        for f in functions:
            if (
                not f.func.is_functional_fn()
                and not f.func.is_out_fn()
                and name == str(f.func.name.name)
            ):
                return f
        # some functions only have in-place variants
        assert name + "_" == cpp.name(functions[0].func)
        return functions[0]

    def split_names(raw_names: str) -> Tuple[str, ...]:
        """Given "foo, bar", return ("foo", "bar")."""
        return tuple(x.strip() for x in raw_names.split(","))

    def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
        """
        Check for some subtle mistakes one might make when writing derivatives.
        These mistakes will compile, but will be latent until a function is
        used with double backwards.
        """

        uses_grad = False  # true if any derivative uses "grad"
        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
        uses_named_grads = False  # true if any derivative uses "grad_{name}"
        used_grads_indices: List[int] = []  # which indices of grads are used
        for d in derivatives:
            formula = d.formula
            uses_grad = uses_grad or bool(
                re.findall(IDENT_REGEX.format("grad"), formula)
            )
            num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
            uses_named_grads = uses_named_grads or bool(d.named_gradients)
            used_grads_indices.extend(used_gradient_indices(formula))
        # This is a basic sanity check: the number of places we see
        # "grads" should be no fewer than the number of indices we see
        # inside "grads". They may not be equal because we may use
        # "grads" without an index.
        assert num_grads_uses >= len(used_grads_indices)
        # Thus if the number is equal, every use of grads is also
        # indexed.
        only_used_grads_indices = num_grads_uses == len(used_grads_indices)

        if uses_grad and num_grads_uses > 0:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                "mixes use of 'grad' and 'grads'. Consider replacing "
                "occurrences of 'grad' with 'grads[0]'"
            )

        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml solely "
                "refers to 'grads[0]'. If the first output is indeed the "
                "only differentiable output, replace 'grads[0]' with 'grad'; "
                "otherwise, there is a likely error in your derivatives "
                "declaration."
            )

        if uses_named_grads and (uses_grad or num_grads_uses > 0):
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
                "only one method for identifying gradients."
            )

    @with_native_function
    def set_up_derivatives(
        f: NativeFunction,
    ) -> Tuple[
        Sequence[Derivative],
        Sequence[ForwardDerivative],
        Sequence[Binding],
        Sequence[str],
        Sequence[str],
    ]:
        # Set up the derivative information
        derivatives: List[Derivative] = []
        forward_derivatives: List[ForwardDerivative] = []
        non_differentiable_arg_names: List[str] = []
        args_with_derivatives_set: Set[str] = set()

        all_arg_names = [a.name for a in cpp_arguments(f)]
        all_ret_names = [
            r.name for r in f.func.returns
        ]  # only used for the assert below
        # output_differentiability is captured from the enclosed
        # scope. Don't modify it.
        #
        # If it is not present, then no output is explicitly
        # non-differentiable.
        #
        # It may be present and shorter than the length of return
        # values. If that's the case, any return value that does not
        # have a corresponding entry is considered not differentiable.
        differentiability = output_differentiability or [True] * len(f.func.returns)
        # A return is available as a named gradient ...
        available_named_gradients = [
            f"grad_{ret.name}"
            for ret, differentiable in zip(f.func.returns, differentiability)
            # if it has not been explicitly made undifferentiable
            if differentiable
            # and if it has a name
            and ret.name is not None
            # and if its type is differentiable
            and ret.type.is_tensor_like()
        ]

        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)

            for name in names:
                assert not (name in all_arg_names and name in all_ret_names), (
                    f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
                    f"expected '{name}' to not be both an input arg and named return. "
                )

            if is_forward_derivative_definition(all_arg_names, names):
                forward_derivatives.append(create_forward_derivative(f, formula, names))
            else:
                if formula.lower().strip() == "non_differentiable":
                    non_differentiable_arg_names += names
                else:
                    derivative = create_derivative(
                        f, formula, names, available_named_gradients
                    )
                    derivatives.append(derivative)
                    args_with_derivatives_set |= set(names)

        overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
        if overlap:
            raise RuntimeError(
                f"derivatives definition for {defn} has overlapping non_differentiable "
                f"and differentiable variables: {overlap}"
            )

        # Next, let us determine the list of inputs in order.
        # TODO: do we need to eagerly calculate and save it here? Can it be derived
        # from NativeFunction and `derivatives` on callsites instead?
        args_with_derivatives = [
            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
        ]

        # Postprocess forward derivatives definitions now that we know the differentiable arguments
        forward_derivatives = postprocess_forward_derivatives(
            f,
            defn_name,
            all_arg_names,
            derivatives,
            forward_derivatives,
            args_with_derivatives,
        )

        # Test to see if the use of 'grads' makes sense.
        check_grad_usage(defn_name, derivatives)

        return (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        )

    # NB: Removes 'name' from defn dictionary
    specification = defn_dict.pop("name")
    defn_name, _ = split_name_params(specification)
    # NB: Removes 'output_differentiability' from defn dictionary
    # `None` means all differentiable.
    output_differentiability = defn_dict.pop("output_differentiability", None)
    output_differentiability_conditions = None
    if output_differentiability and any(
        isinstance(diff, str) for diff in output_differentiability
    ):
        if len(output_differentiability) != 1:
            raise RuntimeError(
                f"Not supported: for {specification},"
                f"output_differentiability must either be "
                f"List[bool] or a List[str] where each str is a "
                f"condition. In the case where it is a condition, "
                f"we only support single-output functions. "
                f"Please file us an issue. "
            )
        output_differentiability_conditions = output_differentiability
        output_differentiability = [True]

    schema_function = functions_by_schema.get(specification)
    if not schema_function:
        avail = "\n".join(
            k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for schema: {specification} "
            f". Available signatures:\n{avail}"
        )

    # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
    # to map in-place schemas to the out-of-place variants.
    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
    signature = schema_function.func.signature()
    functions = functions_by_signature[signature]
    if len(functions) == 0:
        avail = "\n".join(
            str(k)
            for k, v in functions_by_signature.items()
            if cpp.name(k) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for legacy signature: {signature} "
            f"corresponding to schema {specification}. Please report a bug to PyTorch. "
            f"Available signatures:\n{avail}"
        )

    canonical = canonical_function(functions, defn_name)
    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named grad_input_mask, "
            "but this name would be shadowed by our codegen. "
            "Please use a different name in native_functions.yaml."
        )

    if "result" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named result, "
            "but this is only allowed for outputs. "
            "Please use a different name in native_functions.yaml."
        )

    diffinfo_dict = {}
    for key, defn in defn_dict["dispatch"].items():
        if key != "Default" and key not in _VALID_AUTOGRAD_KEYS:
            raise RuntimeError(
                f"Invalid dispatch key {key} in derivatives.yaml for {specification},"
                f" expected key to be one of {_VALID_AUTOGRAD_KEYS}"
            )
        if key not in used_dispatch_keys:
            used_dispatch_keys.add(key)

        (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        ) = set_up_derivatives(canonical)

        used_named_gradients: Set[str] = set()
        for d in derivatives:
            used_named_gradients |= d.named_gradients

        # only assign an op name if we are actually going to calculate a derivative
        op = None
        if args_with_derivatives:
            op_prefix = _create_op_prefix(defn_name)
            if key != "Default":
                op_prefix = op_prefix + key
            op = f"{op_prefix}{op_counter[op_prefix]}"
            op_counter[op_prefix] += 1

        diffinfo_dict[key] = DifferentiabilityInfo(
            name=defn_name,
            func=canonical,
            op=op,
            derivatives=derivatives,
            forward_derivatives=forward_derivatives,
            all_saved_inputs=dedup_vars(
                [v for d in derivatives for v in d.saved_inputs]
            ),
            all_saved_outputs=dedup_vars(
                [v for d in derivatives for v in d.saved_outputs]
            ),
            available_named_gradients=available_named_gradients,
            used_named_gradients=used_named_gradients,
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=non_differentiable_arg_names,
            output_differentiability=output_differentiability,
            output_differentiability_conditions=output_differentiability_conditions,
        )

    return canonical.func, diffinfo_dict


GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"
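# e.g. (illustrative) "grads[0] + grads[2]" yields indices 0 and 2, while "foo_grads[0]"
# does not match because "grads" may not be preceded by a word character.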


def used_gradient_indices(formula: str) -> List[int]:
    """Determine a list of gradient indices (the i in grads[i]) that
    are used by the formula.

    >>> used_gradient_indices("foo(grads[0], grads[1])")
    [0, 1]
    """
    return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)]


def saved_variables(
    formula: str,
    nctypes: List[NamedCType],
    var_names: Tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]:
    def stride_expr(name: str) -> str:
        assert var_names == (name,), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.'
        )
        return f'strides_or_error({name}, "{name}")'
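    # Illustrative example (not from the original file): for an nctype named "self" and the
    # formula "grad * self.sym_sizes()[0]", the replacement pass below records a saved
    # attribute "self_sym_sizes" (evaluated when the autograd node is created) and rewrites
    # the formula to "grad * self_sym_sizes[0]"; since "self" no longer appears afterwards,
    # the whole tensor is not saved by the fallback loop further down.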
762
REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
763
# replace self.sym_sizes() with self_sym_sizes
767
"suffix": "_sym_sizes",
768
"nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
771
# replace self->sym_sizes() with self_sym_sizes_opt
773
r"{}->sym_sizes\(\)",
775
"suffix": "_sym_sizes_opt",
776
"nctype": lambda name: NamedCType(
777
name, OptionalCType(BaseCType(symIntArrayRefT))
779
"expr": lambda name: f"{name}.has_value() ? c10::optional<c10::SymIntArrayRef>({name}->sym_sizes()) : c10::nullopt",
782
# replace self.sym_blocksize() with self_sym_blocksize_opt
784
r"{}.sym_blocksize\(\)",
786
"suffix": "_self_sym_blocksize_opt",
787
"nctype": lambda name: NamedCType(
788
name, OptionalCType(BaseCType(symIntArrayRefT))
790
"expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})",
793
# replace self.options() with self_options
797
"suffix": "_options",
798
"nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
801
# replace zeros_like(self) with self_info
806
"nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
807
"expr": lambda name: name, # at save-time
808
"res": lambda name: name + "_info.zeros()", # at eval-time
811
# replace self.sym_size(2) with self_sym_size_2
813
r"{}.sym_size\((-?\w+)\)",
815
"suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}",
816
"nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
819
# replace self.numel() with self_numel
824
"nctype": lambda name: NamedCType(name, BaseCType(longT)),
827
# replace self.sym_numel() with self_sym_numel
831
"suffix": "_sym_numel",
832
"nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
835
# replace to_args_sizes(self) with self_args_sizes
837
r"to_args_sizes\({}\)",
839
"suffix": "_args_sizes",
840
"nctype": lambda name: NamedCType(
841
name, VectorCType(VectorCType(BaseCType(longT)))
845
# replace to_args_sizes_symint(self) with self_args_sizes
847
r"to_args_sizes_symint\({}\)",
849
"suffix": "_args_sizes_symint",
850
"nctype": lambda name: NamedCType(
851
name, VectorCType(VectorCType(BaseCType(SymIntT)))
855
# replace to_args_scalartypes(self) with self_args_scalartypes
857
r"to_args_scalartypes\({}\)",
859
"suffix": "_args_scalartypes",
860
"nctype": lambda name: NamedCType(
861
name, VectorCType(BaseCType(scalarTypeT))
865
# replace TensorGeometry(self) with self_geometry
867
r"TensorGeometry\({}\)",
869
"suffix": "_geometry",
870
"nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
874
r"{}.scalar_type\(\)",
876
"suffix": "_scalar_type",
877
"nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
880
# replace self.dim() with self_dim
885
"nctype": lambda name: NamedCType(name, BaseCType(longT)),
888
# replace self.sym_strides() with self_sym_strides
890
r"{}.sym_strides\(\)",
892
"suffix": "_sym_strides",
893
"nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
897
# replace self.layout() with self_layout
902
"nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
905
# replace self.is_conj() with self_conjugate
909
"suffix": "_conjugate",
910
"nctype": lambda name: NamedCType(name, BaseCType(boolT)),

    # find which arguments need to be saved
    saved: List[SavedAttribute] = []

    if ".sizes()" in formula or "->sizes()" in formula:
        raise RuntimeError(
            ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version,"
            + f".sym_sizes(), which returned a c10::SymIntArrayRef. formula={formula}"
        )
    if re.search(r"\.size\([-]?\d+\)", formula) or re.search(
        r"->size\([-]?\d+\)", formula
    ):
        raise RuntimeError(
            ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version,"
            + f".sym_size(int), which returned a c10::SymIntArrayRef. formula={formula}"
        )
    if ".strides()" in formula or "->strides()" in formula:
        raise RuntimeError(
            ".strides() is not supported in derivative formulas. Instead, please use the SymInt version,"
            + f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}"
        )
    for nctype in nctypes:
        name = (
            nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
        )
        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:

            def repl(m: Match[str]) -> str:
                suffix: str = (
                    info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
                )
                expr: str = info["expr"](name) if "expr" in info else m.group(0)
                saved.append(
                    SavedAttribute(
                        nctype=info["nctype"](name + suffix),
                        expr=expr,
                    )
                )
                if "res" in info:
                    replacement: str = info["res"](name)
                    return replacement
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # c10::optional<std::string> types stored in Backward nodes must be
        # converted to c10::optional<c10::string_view> before being passed into
        # the backward function
        if nctype.type == OptionalCType(BaseCType(stringT)):
            formula = re.sub(
                rf"\b{name}\b",
                f"{name}.has_value() ? c10::optional<c10::string_view>({name}.value()) : c10::nullopt",
                formula,
            )

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(
                SavedAttribute(
                    nctype=nctype,
                    expr=name,
                )
            )

    return formula, tuple(saved)


def _create_op_prefix(name: str) -> str:
    """Takes a native function name and converts it to an op prefix name.

    Note that the "name" parameter must be the native function name
    without the optional variant suffix, so "add" instead of
    "add_out".

    OP names correspond to classes, hence the change to title case.

    >>> _create_op_prefix('add')
    'AddBackward'
    """
    camel_case = "".join([p.title() for p in name.split("_")])
    return (camel_case + "Backward").replace("ForwardBackward", "Backward")


def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
    seen: Set[str] = set()
    saved: List[SavedAttribute] = []
    for var in vars:
        name = (
            var.nctype.name.name
            if isinstance(var.nctype.name, SpecialArgName)
            else var.nctype.name