pytorch

Форк
0
/
load_derivatives.py 
1013 строк · 39.4 Кб
1
# Parses derivatives.yaml into autograd functions
2
#
3
# Each autograd function is represented by `DifferentiabilityInfo` containing
4
# a list of `Derivative`. See `torchgen.api.autograd` for the data models.
5
import re
6
from collections import defaultdict
7
from typing import Any, Counter, Dict, List, Match, Optional, Sequence, Set, Tuple
8

9
import yaml
10
from torchgen.api import cpp
11

12
from torchgen.api.autograd import (
13
    Derivative,
14
    DifferentiabilityInfo,
15
    ForwardDerivative,
16
    SavedAttribute,
17
)
18
from torchgen.api.types import (
19
    BaseCType,
20
    Binding,
21
    boolT,
22
    CppSignatureGroup,
23
    layoutT,
24
    longT,
25
    NamedCType,
26
    OptionalCType,
27
    scalarTypeT,
28
    SpecialArgName,
29
    stringT,
30
    symIntArrayRefT,
31
    SymIntT,
32
    tensorGeometryT,
33
    tensorOptionsT,
34
    typeAndSizeT,
35
    VectorCType,
36
)
37
from torchgen.context import with_native_function
38
from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
39
from torchgen.model import (
40
    AUTOGRAD_KEYS,
41
    FunctionSchema,
42
    NativeFunction,
43
    NativeFunctionsViewGroup,
44
    OperatorName,
45
    SchemaKind,
46
    Type,
47
    Variant,
48
)
49
from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
50
from torchgen.yaml_utils import YamlLoader
51

52
DerivativeRet = Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]

# Memoizes load_derivatives(). The key holds every input path that influences
# the result: (derivatives.yaml, native_functions.yaml, tags.yaml).
_GLOBAL_LOAD_DERIVATIVE_CACHE: Dict[Tuple[str, str, str], DerivativeRet] = {}

_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)


# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op.
# Since every {view} and {view}_copy op shares the same derivative formula,
# we generate them here instead of duplicating them in the yaml.
# See Note [Codegen'd {view}_copy Operators]
def add_view_copy_derivatives(
    infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
    view_groups: List[NativeFunctionsViewGroup],
) -> None:
    """Add derivative entries for the {view}_copy variants of view ops to `infos`.

    `infos` is updated in place. For each entry, its dispatch keys are walked
    in order until one whose op is not a view op with a {view}_copy variant is
    found. For the dispatch keys visited before that, the info derived by
    `create_view_copy_from_view_derivative` (when not None) is collected. The
    collected infos are added under the {view}_copy schema, unless `infos`
    already has an entry for that schema (manually-defined derivatives win).

    Args:
        infos: map from function schema to per-dispatch-key differentiability infos.
        view_groups: the view groups parsed from native_functions.yaml.
    """
    # Get the map from each view op's name to its corresponding view group
    view_name_to_group: Dict[OperatorName, NativeFunctionsViewGroup] = {
        g.view.func.name: g for g in view_groups
    }

    view_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] = {}

    for info_dispatch_dict in infos.values():
        # Reset for every entry so a schema found while processing a previous
        # entry can never be attributed to this one.
        fn_schema: Optional[FunctionSchema] = None
        view_copy_differentiability_infos: Dict[str, DifferentiabilityInfo] = {}
        for dispatch_key, info in info_dispatch_dict.items():
            maybe_view_group = view_name_to_group.get(info.func.func.name, None)
            if maybe_view_group is None or maybe_view_group.view_copy is None:
                # Not a view op with a {view}_copy variant: stop looking.
                break
            view_copy_info = info.create_view_copy_from_view_derivative(
                maybe_view_group
            )
            if view_copy_info is not None:
                fn_schema = view_copy_info.func.func
                view_copy_differentiability_infos[dispatch_key] = view_copy_info
        # fn_schema is set iff at least one view_copy info was collected above.
        # Prefer manually-defined derivatives if any.
        if fn_schema is not None and fn_schema not in infos:
            view_infos[fn_schema] = view_copy_differentiability_infos

    infos.update(view_infos)


def load_derivatives(
    derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str
) -> DerivativeRet:
    """Parse derivatives.yaml into per-schema, per-dispatch-key DifferentiabilityInfos.

    The result is memoized in `_GLOBAL_LOAD_DERIVATIVE_CACHE`, keyed on all
    three input paths, because it is a deterministic function of those files.

    Returns:
        A pair ``(infos, used_dispatch_keys)``. `infos` maps each function schema
        to a dict of DifferentiabilityInfo per dispatch key, including the
        generated {view}_copy entries. `used_dispatch_keys` is the set of
        dispatch keys (including "Default") used by the derivatives.yaml entries.
    """
    # tags_yaml_path is consumed by parse_native_yaml() below, so it must be
    # part of the key; otherwise a call with a different tags file would get a
    # stale cached result.
    key = (derivatives_yaml_path, native_yaml_path, tags_yaml_path)
    cached = _GLOBAL_LOAD_DERIVATIVE_CACHE.get(key)
    if cached is not None:
        return cached

    with open(derivatives_yaml_path) as f:
        definitions = yaml.load(f, Loader=YamlLoader)

    funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
    # From the parsed native functions, separate out the (generated) view_copy functions,
    # so we can generate derivatives for them separately.
    native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
    native_functions = concatMap(
        lambda g: [g]
        if isinstance(g, NativeFunction)
        else list(g.functions(include_copy=True)),
        native_functions_with_view_groups,
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]

    # What's the difference between function schema v.s. signature?
    # function schema is the complete declaration including mutability annotation / default value and etc.
    # signature is the canonical schema for a group of functions (in-place/out/functional variants)
    # that are semantically related.
    functions_by_signature: Dict[
        FunctionSchema, List[NativeFunction]
    ] = defaultdict(list)
    functions_by_schema: Dict[str, NativeFunction] = {}
    for function in native_functions:
        functions_by_signature[function.func.signature()].append(function)
        assert str(function.func) not in functions_by_schema
        functions_by_schema[str(function.func)] = function

    # Keep track of how many of which ops we've seen so we can
    # disambiguate them with a numeric suffix.
    op_counter = Counter[str]()

    # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos
    # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
    # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
    infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] = {}
    used_dispatch_keys: Set[str] = set()
    for defn_dict in definitions:
        # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
        if "dispatch" not in defn_dict:
            specification = defn_dict.pop("name")
            output_differentiability = defn_dict.pop(
                "output_differentiability", None
            )
            defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}}
            if output_differentiability:
                defn_dict["output_differentiability"] = output_differentiability
        name, per_dispatch_diffinfos = create_differentiability_info(
            defn_dict,
            functions_by_signature,
            functions_by_schema,
            op_counter,
            used_dispatch_keys,
        )
        infos[name] = per_dispatch_diffinfos

    add_view_copy_derivatives(infos, view_groups)

    # cache both loaded infos as well as a set of all the dispatch_keys/aliases
    # that appear in derivatives.yaml. used_dispatch_keys is useful for generating
    # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used
    result: DerivativeRet = (infos, used_dispatch_keys)
    _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = result
    return result
172

173

174
# TODO: Why is this going through CppSignatureGroup, that doesn't make sense...
@with_native_function
def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
    """Return the C++ argument bindings of `f`, treated as a free (non-method) function.

    The SymInt-aware signature is used when the signature group provides one;
    otherwise the group's regular signature is used.
    """
    group = CppSignatureGroup.from_native_function(f, method=False)
    chosen = (
        group.signature if group.symint_signature is None else group.symint_signature
    )
    return chosen.arguments()
182

183

184
def create_derivative(
    f: NativeFunction,
    formula: str,
    var_names: Tuple[str, ...],
    available_named_gradients: Sequence[str],
) -> Derivative:
    """Build the backward `Derivative` of `f` with respect to `var_names`.

    The inputs and outputs of `f` that the formula references are turned into
    saved attributes via saved_variables(), which also rewrites the formula.
    The named gradients from `available_named_gradients` that the rewritten
    formula mentions are recorded as well.

    Raises:
        RuntimeError: if the formula reads ``grads[i]`` for an ``i`` that is not
            smaller than the number of outputs of `f`.
    """
    arguments: List[NamedCType] = [
        binding.nctype.remove_const_ref() for binding in cpp_arguments(f)
    ]

    # A return named "self" is referred to as "result" inside formulas.
    ret_names = ["result" if n == "self" else n for n in cpp.return_names(f)]
    ret_types = [
        cpp.return_type(ret, symint=True).remove_const_ref() for ret in f.func.returns
    ]
    named_returns = [NamedCType(n, t) for n, t in zip(ret_names, ret_types)]

    rewritten, saved_inputs = saved_variables(formula, arguments, var_names)
    rewritten, saved_outputs = saved_variables(rewritten, named_returns, var_names)

    used_named_gradients = {
        candidate
        for candidate in available_named_gradients
        if re.search(IDENT_REGEX.format(candidate), rewritten)
    }

    # Check that the referenced derivatives in the formula are in bounds
    num_returns = len(f.func.returns)
    first_bad_index = next(
        (i for i in used_gradient_indices(rewritten) if i >= num_returns), None
    )
    if first_bad_index is not None:
        i = first_bad_index
        raise RuntimeError(
            f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
            f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs."
        )

    return Derivative(
        formula=rewritten,
        original_formula=formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=used_named_gradients,
    )
229

230

231
def create_forward_derivative(
    f: NativeFunction, formula: str, names: Tuple[str, ...]
) -> ForwardDerivative:
    """Build a not-yet-postprocessed `ForwardDerivative` for the outputs `names` of `f`.

    Each name is resolved to the type of the output it refers to. A name is
    either an explicitly named return of `f`, or one of the default names
    ``result`` (single-output functions) or ``result<i>`` (the i-th output).
    ``var_types`` is built in the same order as ``var_names``, so
    ``var_types[i]`` is the type of ``var_names[i]``. The ``required_*``
    fields are placeholders here; postprocess_forward_derivatives() builds the
    final values.
    """
    var_names = names
    var_types: Optional[Tuple[Type, ...]] = None

    # Explicitly named returns. Look the types up in the order the names were
    # written (not in return order) so that var_types stays aligned with
    # var_names, matching the default-name branch below.
    return_types_by_name = {
        r.name: r.type for r in f.func.returns if r.name is not None
    }
    named_types = tuple(
        return_types_by_name[name]
        for name in var_names
        if name in return_types_by_name
    )
    if named_types:
        var_types = named_types

    # Handle default return names
    if var_types is None:
        if var_names == ("result",):
            assert len(f.func.returns) == 1
            var_types = (f.func.returns[0].type,)
        else:
            indexed_types: List[Type] = []
            for var_name in var_names:
                m = re.fullmatch(r"result(\d+)", var_name)
                if m is not None:
                    indexed_types.append(f.func.returns[int(m.group(1))].type)
            if indexed_types:
                var_types = tuple(indexed_types)

    assert var_types is not None, "No matching output for forward derivative definition"
    return ForwardDerivative(
        formula=formula,
        var_names=var_names,
        var_types=var_types,
        required_inputs_fw_grad=None,
        required_inputs_primal=None,
        required_original_self_value=False,
        is_reusing_outplace_formula=False,
    )
266

267

268
def postprocess_forward_derivatives(
    f: NativeFunction,
    defn_name: str,
    all_arg_names: List[str],
    derivatives: List[Derivative],
    forward_derivatives: List[ForwardDerivative],
    args_with_derivatives: Sequence[Binding],
) -> List[ForwardDerivative]:
    """Finalize the forward derivatives of `f` once its differentiable inputs are known.

    The special formulas ``auto_element_wise`` and ``auto_linear`` are expanded
    into concrete formulas. For every formula, the function records which
    differentiable inputs' tangents (``<name>_t``) and primals (``<name>_p``)
    it references.

    Args:
        f: the native function the derivatives belong to.
        defn_name: the operator name from derivatives.yaml. Used in error
            messages and to build the ``auto_linear`` call.
        all_arg_names: the C++ argument names of `f`, in order.
        derivatives: the backward derivatives of `f` (``auto_element_wise``
            reuses the single backward formula).
        forward_derivatives: the forward derivatives from create_forward_derivative().
        args_with_derivatives: the arguments of `f` that have a backward formula.

    Returns:
        A new list with one ForwardDerivative per entry of `forward_derivatives`.

    Raises:
        RuntimeError: if a formula uses the bare name of a differentiable input,
            or if ``auto_element_wise``/``auto_linear`` is used where it is not
            supported.
    """
    is_foreach = f.func.name.name.base.startswith("_foreach_")

    def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
        # Names are collected in argument order (rather than through a set,
        # whose iteration order depends on per-process string hashing) so the
        # generated code is deterministic.
        required_inputs: List[str] = []
        for arg in args_with_derivatives:
            if (
                arg.type in ("at::TensorList", "const at::ITensorListRef &")
                and not is_foreach
            ):
                # The functions taking TensorList handle everything internally
                continue
            arg_name = arg.name

            found = re.search(IDENT_REGEX.format(arg_name), formula)
            if found:
                raise RuntimeError(
                    f"The forward formula for {defn_name} is using the base name of the {arg_name} "
                    f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
                    f"value and {arg_name}_t to access the tangent."
                )

            found = re.search(IDENT_REGEX.format(arg_name + postfix), formula)
            if found and arg_name not in required_inputs:
                required_inputs.append(arg_name)

        return tuple(required_inputs)

    updated_derivatives: List[ForwardDerivative] = []

    for defn in forward_derivatives:
        formula = defn.formula
        required_inputs_tangent = find_required_inputs(formula, "_t")
        if formula == "auto_element_wise":
            assert (
                f.func.kind() != SchemaKind.inplace
            ), f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant"
            if (
                len(args_with_derivatives) != 1
                or len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but this only "
                    "works for functions with a single differentiable input and a "
                    "single differentiable output."
                )
            if len(derivatives) != 1:
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but it does not "
                    "defines the gradient formula for its argument which is required."
                )
            # This transformation is based on the observation that for element-wise functions, the Jacobian
            # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions)
            # For the complex case, we use hermitian transpose and get (v.conj() J).conj()
            # So here we are going to re-use the backward formula and do three things:
            # 1) replace all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input.
            # 2) replace all usages of an original input "foo" with its primal value "foo_p".
            # 3) conjugate the final result
            # For example, for abs, the backward formula is:
            #   grad * self.sgn()
            # And this function generates a forward formula that is:
            #   (self_t.conj() * self_p.sgn()).conj()

            backward_formula = derivatives[0].original_formula
            input_name = args_with_derivatives[0].name

            # Do replacement 1) of the grad
            def tangent_repl(m: Any) -> str:
                return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}"

            fw_formula = re.sub(IDENT_REGEX.format("grad"), tangent_repl, backward_formula)

            # Do replacement 2) of the input variables
            for arg in args_with_derivatives:
                arg_name = arg.name

                # Called immediately by re.sub, so the late-bound arg_name is current.
                def primal_repl(m: Any) -> str:
                    return f"{m.group(1)}{arg_name}_p{m.group(2)}"

                fw_formula = re.sub(IDENT_REGEX.format(arg_name), primal_repl, fw_formula)

            # Do the final conjugate 3)
            fw_formula = f"({fw_formula}).conj()"

            # Since there is a single differentiable inputs and we necessarily need its tangent we can
            # simply require all differentiable input's tangent.
            # NOTE(review): this lists every argument name of `f`, not only the
            # differentiable one; presumably consumers only look up differentiable
            # inputs in it -- TODO confirm.
            required_inputs_tangent = tuple(all_arg_names)
            formula = fw_formula
        elif formula == "auto_linear":
            if (
                len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as linear but this only works "
                    "for functions with a single differentiable output."
                )
            # This transformation is based on the observation that linear functions can be written as:
            #   y = f(x) = A * x
            # For some matrix A and the Jacobian of the function f is also A.
            # So doing J * v = A * v = f(v).
            # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x.
            # We do this by calling the forward again by replacing any occurrence of the differentiable
            # input "foo" by it's tangent "foo_t".
            # Note that multiple inputs are not a problem as long as the function is truly linear wrt to
            # the vector where all the differentiable inputs are stacked.

            diff_arg_names = [arg.name for arg in args_with_derivatives]
            assert len(diff_arg_names) > 0

            # Do replacement of input variables
            new_args = [
                f"{name}_t" if name in diff_arg_names else name
                for name in all_arg_names
            ]

            # TODO we are trolling
            # Use a local name so `defn_name` (used by error messages, including
            # those raised by find_required_inputs) keeps the yaml name.
            op_name = f"{defn_name}_symint" if f.func.has_symint() else defn_name

            # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
            if Variant.function in f.variants:
                fw_formula = f"at::{op_name}({', '.join(new_args)})"
            else:
                assert Variant.method in f.variants
                fw_formula = f"{new_args[0]}.{op_name}({', '.join(new_args[1:])})"

            # All of the input tangents are always used so all of them are required here.
            required_inputs_tangent = tuple(diff_arg_names)
            formula = fw_formula

        # At this point, the formula is final and is not modified anymore.

        # During forward formula, we use the primal instead of the input Tensors.
        # This call inspects the formula to find for which input's primal are used.
        required_inputs_primal = find_required_inputs(formula, "_p")

        updated_derivatives.append(
            ForwardDerivative(
                formula=formula,
                var_names=defn.var_names,
                var_types=defn.var_types,
                required_inputs_fw_grad=required_inputs_tangent,
                required_inputs_primal=required_inputs_primal,
                required_original_self_value=False,
                is_reusing_outplace_formula=False,
            )
        )

    return updated_derivatives
429

430

431
def is_forward_derivative_definition(
    all_arg_names: List[str], names: Tuple[str, ...]
) -> bool:
    """Return whether a derivatives.yaml key `names` defines a forward derivative.

    A key listing input argument names defines backward derivatives for those
    inputs. A key listing names that are not input arguments (outputs such as
    ``result``) defines a forward derivative.

    Raises:
        RuntimeError: if `names` is empty, or mixes input argument names with
            names that are not input arguments.
    """
    if not names:
        raise RuntimeError("Expected `names` to be non-empty")
    is_forward = names[0] not in all_arg_names
    # A single key must refer only to inputs or only to outputs. Classifying by
    # the first name alone would silently misclassify a mixed key.
    if any((name not in all_arg_names) != is_forward for name in names[1:]):
        raise RuntimeError(
            f"Expected the derivative names {names} to be either all input "
            "arguments or all outputs, but they mix both"
        )
    return is_forward
440

441

442
def create_differentiability_info(
    defn_dict: Dict[Any, Any],
    functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
    functions_by_schema: Dict[str, NativeFunction],
    op_counter: Counter[str],
    used_dispatch_keys: Set[str],
) -> Tuple[FunctionSchema, Dict[str, DifferentiabilityInfo]]:
    """Processes a single entry `defn` in derivatives.yaml

    `defn_dict` must be in the per-dispatch-key form
    ``{"name": ..., "dispatch": {key: {names: formula, ...}}, ...}``. Its
    "name" and "output_differentiability" entries are popped, which mutates it.

    Side effects: every dispatch key of the entry is added to
    `used_dispatch_keys`. `op_counter` is incremented for each generated
    op name.

    Returns:
        The schema of the canonical native function the entry refers to, and a
        dict mapping each dispatch key of the entry to its DifferentiabilityInfo.

    Raises:
        RuntimeError: if the entry is malformed, refers to an unknown schema,
            uses an invalid dispatch key, or the schema uses a reserved argument
            name.
    """

    def canonical_function(
        functions: Sequence[NativeFunction], name: str
    ) -> NativeFunction:
        """Pick the function of a signature group that the derivative is attached to.

        This is the first function for which neither `is_functional_fn()` nor
        `is_out_fn()` holds and whose base name is `name`. Otherwise the group's
        first function is used, which must be the in-place variant `name + "_"`.
        """
        for f in functions:
            if (
                not f.func.is_functional_fn()
                and not f.func.is_out_fn()
                and name == str(f.func.name.name)
            ):
                return f
        # some functions only have in-place variants
        assert name + "_" == cpp.name(functions[0].func)
        return functions[0]

    def split_names(raw_names: str) -> Tuple[str, ...]:
        """Given "foo, bar", return ("foo", "bar")."""
        return tuple(x.strip() for x in raw_names.split(","))

    def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
        """
        Check for some subtle mistakes one might make when writing derivatives.
        These mistakes will compile, but will be latent until a function is
        used with double backwards.
        """

        uses_grad = False  # true if any derivative uses "grad"
        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
        uses_named_grads = False  # true if any derivative uses "grad_{name}"
        used_grads_indices: List[int] = []  # which indices of grads are used
        for d in derivatives:
            formula = d.formula
            uses_grad = uses_grad or bool(
                re.findall(IDENT_REGEX.format("grad"), formula)
            )
            num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
            uses_named_grads = uses_named_grads or bool(d.named_gradients)
            used_grads_indices.extend(used_gradient_indices(formula))
        # This is a basic sanity check: the number of places we see
        # "grads" should be no fewer than the number of indices we see
        # inside "grads". They may not be equal because we may use
        # "grads" without an index.
        assert num_grads_uses >= len(used_grads_indices)
        # Thus if the number is equal, every use of grads is also
        # indexed.
        only_used_grads_indices = num_grads_uses == len(used_grads_indices)

        if uses_grad and num_grads_uses > 0:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                "mixes use of 'grad' and 'grads'. Consider replacing "
                "occurrences of 'grad' with 'grads[0]'"
            )

        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml solely "
                "refers to 'grads[0]'.  If the first output is indeed the "
                "only differentiable output, replace 'grads[0]' with 'grad'; "
                "otherwise, there is a likely error in your derivatives "
                "declaration."
            )

        if uses_named_grads and (uses_grad or num_grads_uses > 0):
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
                "only one method for identifying gradients."
            )

    @with_native_function
    def set_up_derivatives(
        f: NativeFunction,
    ) -> Tuple[
        Sequence[Derivative],
        Sequence[ForwardDerivative],
        Sequence[Binding],
        Sequence[str],
        Sequence[str],
    ]:
        # NB: `defn`, `defn_name` and `output_differentiability` are read from
        # the enclosing scope. `defn` is the per-dispatch-key entry bound by the
        # `for key, defn in ...` loop at the end of create_differentiability_info,
        # so this must only be called from inside that loop.

        # Set up the derivative information
        derivatives: List[Derivative] = []
        forward_derivatives: List[ForwardDerivative] = []
        non_differentiable_arg_names: List[str] = []
        args_with_derivatives_set: Set[str] = set()

        all_arg_names = [a.name for a in cpp_arguments(f)]
        all_ret_names = [
            r.name for r in f.func.returns
        ]  # only used for the assert below
        # output_differentiability is captured from the enclosed
        # scope. Don't modify it.
        #
        # If it is not present, then no output is explicitly
        # undifferentiable.
        #
        # It may be present and shorter than the length of return
        # values. If that's the case, any return value that does not
        # have a corresponding entry is considered not differentiable.
        differentiability = output_differentiability or [True] * len(f.func.returns)
        # A return is available as a named gradient ...
        available_named_gradients = [
            f"grad_{ret.name}"
            for ret, differentiable in zip(f.func.returns, differentiability)
            # if it has not been explicitly made undifferentiable
            if differentiable
            # and if it has a name
            and ret.name is not None
            # and if its type is differentiable
            and ret.type.is_tensor_like()
        ]

        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)

            for name in names:
                assert not (name in all_arg_names and name in all_ret_names), (
                    f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
                    f"expected '{name}' to not be both an input arg and named return. "
                )

            if is_forward_derivative_definition(all_arg_names, names):
                forward_derivatives.append(create_forward_derivative(f, formula, names))
            else:
                if formula.lower().strip() == "non_differentiable":
                    non_differentiable_arg_names += names
                else:
                    derivative = create_derivative(
                        f, formula, names, available_named_gradients
                    )
                    derivatives.append(derivative)
                    args_with_derivatives_set |= set(names)

        overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
        if overlap:
            raise RuntimeError(
                f"derivatives definition for {defn} have overlapped non_differentiable "
                f"and differentiable variables: {overlap}"
            )

        # Next, let us determine the list of inputs in order.
        # TODO: do we need eagerly calculate and save it here? Can it be derived
        # from NativeFunction and `derivatives` on callsites instead?
        args_with_derivatives = [
            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
        ]

        # Postprocess forward derivatives definitions now that we know the differentiable arguments
        forward_derivatives = postprocess_forward_derivatives(
            f,
            defn_name,
            all_arg_names,
            derivatives,
            forward_derivatives,
            args_with_derivatives,
        )

        # Test to see if the use of 'grads' makes sense.
        check_grad_usage(defn_name, derivatives)

        return (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        )

    # NB: Removes 'name' from defn dictionary
    specification = defn_dict.pop("name")
    defn_name, _ = split_name_params(specification)
    # NB: Removes 'output_differentiability' from defn dictionary
    #     `None` means all differentiable.
    output_differentiability = defn_dict.pop("output_differentiability", None)
    output_differentiability_conditions = None
    if output_differentiability and any(
        isinstance(diff, str) for diff in output_differentiability
    ):
        if len(output_differentiability) != 1:
            raise RuntimeError(
                f"Not supported: for {specification}, "
                f"output_differentiability must either be "
                f"List[bool] or a List[str] where each str is a "
                f"condition. In the case where it is a condition, "
                f"we only support single-output functions. "
                f"Please file us an issue. "
            )
        output_differentiability_conditions = output_differentiability
        output_differentiability = [True]

    schema_function = functions_by_schema.get(specification)
    if not schema_function:
        avail = "\n".join(
            k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for schema: {specification} "
            f".  Available signatures:\n{avail}"
        )

    # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
    # to map in-place schemas to the out-of-place variants.
    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
    signature = schema_function.func.signature()
    functions = functions_by_signature[signature]
    if len(functions) == 0:
        avail = "\n".join(
            str(k)
            for k, v in functions_by_signature.items()
            if cpp.name(k) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for legacy signature: {signature} "
            f"corresponding to schema {specification}.  Please report a bug to PyTorch. "
            f"Available signatures:\n{avail}"
        )

    canonical = canonical_function(functions, defn_name)
    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named grad_input_mask, "
            "but this name would be shadowed by our codegen. "
            "Please use a different name in native_functions.yaml."
        )

    if "result" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named result, "
            "but this is only allowed for outputs. "
            "Please use a different name in native_functions.yaml."
        )

    diffinfo_dict: Dict[str, DifferentiabilityInfo] = {}
    for key, defn in defn_dict["dispatch"].items():
        if key != "Default" and key not in _VALID_AUTOGRAD_KEYS:
            raise RuntimeError(
                f"Invalid dispatch key {key} in derivatives.yaml for {specification},"
                f" expected key to be one of {_VALID_AUTOGRAD_KEYS}"
            )
        used_dispatch_keys.add(key)

        # set_up_derivatives reads the loop variable `defn` bound just above.
        (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        ) = set_up_derivatives(canonical)

        used_named_gradients: Set[str] = set()
        for d in derivatives:
            used_named_gradients |= d.named_gradients

        # only assign an op name if we are actually going to calculate a derivative
        op = None
        if args_with_derivatives:
            op_prefix = _create_op_prefix(defn_name)
            if key != "Default":
                op_prefix = op_prefix + key
            op = f"{op_prefix}{op_counter[op_prefix]}"
            op_counter[op_prefix] += 1

        diffinfo_dict[key] = DifferentiabilityInfo(
            name=defn_name,
            func=canonical,
            op=op,
            derivatives=derivatives,
            forward_derivatives=forward_derivatives,
            all_saved_inputs=dedup_vars(
                [v for d in derivatives for v in d.saved_inputs]
            ),
            all_saved_outputs=dedup_vars(
                [v for d in derivatives for v in d.saved_outputs]
            ),
            available_named_gradients=available_named_gradients,
            used_named_gradients=used_named_gradients,
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=non_differentiable_arg_names,
            output_differentiability=output_differentiability,
            output_differentiability_conditions=output_differentiability_conditions,
        )

    return canonical.func, diffinfo_dict
735

736

737
GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"


def used_gradient_indices(formula: str) -> List[int]:
    """Return the index ``i`` of every ``grads[i]`` occurrence in `formula`.

    Indices are listed in order of appearance, and repeats are kept.
    ``grads`` must be a standalone identifier, so e.g. ``my_grads[0]`` does
    not count.

    >>> used_gradient_indices("foo(grads[0], grads[1])")
    [0, 1]
    """
    index_strings = re.findall(GRAD_INDEX_REGEX, formula)
    return list(map(int, index_strings))
748

749

750
def saved_variables(
    formula: str,
    nctypes: List[NamedCType],
    var_names: Tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]:
    """Rewrite a derivative formula so that it captures as little state as possible.

    Sub-expressions over a variable that can already be evaluated when the
    autograd Function is created (``self.sym_sizes()``, ``self.options()``,
    ``zeros_like(self)``, ...) are replaced by a dedicated saved attribute
    (``self_sym_sizes``, ``self_options``, ``self_info.zeros()``, ...). Any
    variable from ``nctypes`` that is still referenced afterwards is saved
    as-is.

    Args:
        formula: the C++ derivative formula.
        nctypes: named C++ types of the variables the formula may reference.
        var_names: the variables this formula computes the derivative for;
            only used to validate ``.sym_strides()`` replacements.

    Returns:
        The rewritten formula, and the attributes that must be saved in the
        order they were discovered (one entry per match, so duplicates are
        possible).

    Raises:
        RuntimeError: if the formula uses ``.sizes()``, ``.size(int)`` or
            ``.strides()`` instead of their SymInt counterparts.
    """

    def stride_expr(name: str) -> str:
        assert var_names == (name,), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.'
        )
        return f'strides_or_error({name}, "{name}")'

    # Each entry is (regex template, info). The template's "{}" is filled with
    # a pattern matching the variable name. info keys:
    #   suffix: appended to the variable name to form the saved attribute name
    #           (a str, or a callable taking the match)
    #   nctype: builds the saved attribute's NamedCType from its name
    #   expr:   (optional) save-time expression; defaults to the matched text
    #   res:    (optional) eval-time replacement; defaults to name + suffix
    REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
        # replace self.sym_sizes() with self_sym_sizes
        (
            r"{}.sym_sizes\(\)",
            {
                "suffix": "_sym_sizes",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
            },
        ),
        # replace self->sym_sizes() with self_sym_sizes_opt
        (
            r"{}->sym_sizes\(\)",
            {
                "suffix": "_sym_sizes_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(symIntArrayRefT))
                ),
                "expr": lambda name: f"{name}.has_value() ? c10::optional<c10::SymIntArrayRef>({name}->sym_sizes()) : c10::nullopt",
            },
        ),
        # replace self.sym_blocksize() with self_self_sym_blocksize_opt
        # NOTE(review): the suffix itself starts with "_self", so the variable
        # name appears twice in the generated attribute name; kept as-is since
        # renaming would change the generated field names.
        (
            r"{}.sym_blocksize\(\)",
            {
                "suffix": "_self_sym_blocksize_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(symIntArrayRefT))
                ),
                "expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})",
            },
        ),
        # replace self.options() with self_options
        (
            r"{}.options\(\)",
            {
                "suffix": "_options",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
            },
        ),
        # replace zeros_like(self) with self_info
        (
            r"zeros_like\({}\)",
            {
                "suffix": "_info",
                "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
                "expr": lambda name: name,  # at save-time
                "res": lambda name: name + "_info.zeros()",  # at eval-time
            },
        ),
        # replace self.sym_size(2) with self_sym_argsize_2
        # (a negative index such as -1 becomes self_sym_argsize_minus_1)
        (
            r"{}.sym_size\((-?\w+)\)",
            {
                "suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}",
                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
            },
        ),
        # replace self.numel() with self_numel
        (
            r"{}.numel\(\)",
            {
                "suffix": "_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.sym_numel() with self_sym_numel
        (
            r"{}.sym_numel\(\)",
            {
                "suffix": "_sym_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
            },
        ),
        # replace to_args_sizes(self) with self_args_sizes
        (
            r"to_args_sizes\({}\)",
            {
                "suffix": "_args_sizes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(longT)))
                ),
            },
        ),
        # replace to_args_sizes_symint(self) with self_args_sizes_symint
        (
            r"to_args_sizes_symint\({}\)",
            {
                "suffix": "_args_sizes_symint",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(SymIntT)))
                ),
            },
        ),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (
            r"to_args_scalartypes\({}\)",
            {
                "suffix": "_args_scalartypes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(BaseCType(scalarTypeT))
                ),
            },
        ),
        # replace TensorGeometry(self) with self_geometry
        (
            r"TensorGeometry\({}\)",
            {
                "suffix": "_geometry",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
            },
        ),
        # replace self.scalar_type() with self_scalar_type
        (
            r"{}.scalar_type\(\)",
            {
                "suffix": "_scalar_type",
                "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
            },
        ),
        # replace self.dim() with self_dim
        (
            r"{}.dim\(\)",
            {
                "suffix": "_dim",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.sym_strides() with self_sym_strides
        (
            r"{}.sym_strides\(\)",
            {
                "suffix": "_sym_strides",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
                "expr": stride_expr,
            },
        ),
        # replace self.layout() with self_layout
        (
            r"{}.layout\(\)",
            {
                "suffix": "_layout",
                "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
            },
        ),
        # replace self.is_conj() with self_conjugate
        (
            r"{}.is_conj\(\)",
            {
                "suffix": "_conjugate",
                "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
            },
        ),
    ]

    # find which arguments need to be saved
    saved: List[SavedAttribute] = []

    # The non-SymInt accessors are rejected outright.
    if ".sizes()" in formula or "->sizes()" in formula:
        raise RuntimeError(
            ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_sizes(), which returns a c10::SymIntArrayRef. formula={formula}"
        )
    if re.search(r"(?:\.|->)size\(-?\d+\)", formula):
        raise RuntimeError(
            ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_size(int), which returns a c10::SymInt. formula={formula}"
        )
    if ".strides()" in formula or "->strides()" in formula:
        raise RuntimeError(
            ".strides() is not supported in derivative formulas. Instead, please use the SymInt version, "
            + f".sym_strides(), which returns a c10::SymIntArrayRef. formula={formula}"
        )

    for nctype in nctypes:
        name = (
            nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
        )
        # Match `name` only as a whole identifier. Without the lookbehind the
        # pattern for e.g. `output` would also fire inside
        # `grad_output.sym_sizes()`, rewriting it to an undefined
        # `grad_output_sym_sizes` while saving the wrong attribute. The
        # lookbehind is zero-width and non-capturing, so m.group(0) and the
        # pattern's own capture groups are unaffected.
        name_pattern = rf"(?<!\w){re.escape(name)}"

        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:

            # `repl` closes over the current `name`/`info`; it is only used by
            # the re.sub call right below, within the same iteration.
            def repl(m: Match[str]) -> str:
                suffix: str = (
                    info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
                )
                expr: str = info["expr"](name) if "expr" in info else m.group(0)
                saved.append(
                    SavedAttribute(
                        nctype=info["nctype"](name + suffix),
                        expr=expr,
                    )
                )
                if "res" in info:
                    replacement: str = info["res"](name)
                    return replacement
                return name + suffix

            formula = re.sub(regex.format(name_pattern), repl, formula)

        # c10::optional<std::string> types stored in Backward nodes must be
        # converted to c10::optional<c10::string_view> before being passed into
        # the backward function
        if nctype.type == OptionalCType(BaseCType(stringT)):
            formula = re.sub(
                rf"\b{name}\b",
                f"{name}.has_value() ? c10::optional<c10::string_view>({name}.value()) : c10::nullopt",
                formula,
            )

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(
                SavedAttribute(
                    nctype=nctype,
                    expr=name,
                )
            )

    return formula, tuple(saved)
981

982

983
def _create_op_prefix(name: str) -> str:
984
    """Takes a native function name converts to a op prefix name.
985

986
    Note that the "name" parameter must be the native function name
987
    without the optional variant suffix, so "add" instead of
988
    "add.out".
989

990
    OP names correspond to classes, hence the change to title case.
991

992
    Example::
993
    >>> _create_op_prefix('add')
994
    'AddBackward'
995
    """
996
    camel_case = "".join([p.title() for p in name.split("_")])
997
    return (camel_case + "Backward").replace("ForwardBackward", "Backward")
998

999

1000
def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
    """Remove saved attributes whose name has already been seen.

    Attributes are keyed by their C++ name (unwrapping a ``SpecialArgName``
    to its ``.name``). The first attribute for each name wins, and the
    survivors are returned as a new list in their original relative order.
    """
    first_by_name: Dict[str, SavedAttribute] = {}
    for attr in vars:
        raw_name = attr.nctype.name
        key = raw_name.name if isinstance(raw_name, SpecialArgName) else raw_name
        # dicts keep insertion order, so first occurrences stay in place
        if key not in first_by_name:
            first_by_name[key] = attr
    return list(first_by_name.values())
1014

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.