# pytorch (fork: 0) / autograd.py — 870 lines · 38.0 KB
from __future__ import annotations
2

3
import re
4
from dataclasses import dataclass
5
from typing import cast, Sequence
6

7
from torchgen import local
8
from torchgen.api import cpp
9
from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
10
from torchgen.model import (
11
    BaseTy,
12
    BaseType,
13
    FunctionSchema,
14
    ListType,
15
    NativeFunction,
16
    NativeFunctionsViewGroup,
17
    SchemaKind,
18
    Type,
19
)
20
from torchgen.utils import IDENT_REGEX
21

22

23
@dataclass(frozen=True)
class SavedAttribute:
    """A value captured at forward time for use in a backward formula.

    The captured value may be a derived property of an input argument
    rather than the argument itself, e.g. ``other.scalar_type()`` instead of
    the whole ``other`` tensor.
    """

    # Updated name and C++ type of the attribute. For a derived property the
    # name carries a suffix, e.g. `other_scalar_type`.
    nctype: NamedCType

    # Expression evaluated at save time to read the value,
    # e.g. `other.scalar_type()`.
    expr: str
35

36

37
@dataclass(frozen=True)
class Derivative:
    """A backward formula computing derivatives for one or more tensors."""

    # The formula as a C++ expression, with references to input arguments
    # rewritten to the corresponding saved attributes. For example:
    #   original: `mul_tensor_backward(grad, self, other.scalar_type())`
    #   here:     `mul_tensor_backward(grad, self, other_scalar_type)`
    formula: str

    # The same formula before input arguments were rewritten.
    original_formula: str

    # Names of the arguments whose derivatives this formula computes.
    var_names: tuple[str, ...]

    # Saved inputs referenced by the formula.
    saved_inputs: tuple[SavedAttribute, ...]

    # Saved outputs referenced by the formula.
    saved_outputs: tuple[SavedAttribute, ...]

    # Gradients the formula refers to by name.
    named_gradients: set[str]
63

64

65
@dataclass(frozen=True)
class ForwardDerivative:
    """A forward-mode formula computing the forward derivative of one tensor."""

    # The formula as a C++ expression. Special keywords such as "linear" or
    # "element_wise" have already been expanded into generated formulas.
    formula: str

    # Names of the output arguments whose forward derivatives this formula
    # computes.
    var_names: tuple[str, ...]

    # Types of those output arguments, aligned with `var_names`.
    var_types: tuple[Type, ...]

    # Inputs whose forward gradients the formula needs (None if unspecified).
    required_inputs_fw_grad: tuple[str, ...] | None

    # Inputs whose primal values the formula needs (None if unspecified).
    required_inputs_primal: tuple[str, ...] | None

    # Whether the formula needs the original value of self; only relevant
    # for in-place operations.
    required_original_self_value: bool

    # Distinguishes a formula written in derivatives.yaml from an
    # out-of-place formula reused for an in-place op. NOTE(review): presumably
    # True means "reused" given the name — confirm.
    is_reusing_outplace_formula: bool
95

96

97
# Represents differentiability info for a NativeFunction.
@dataclass(frozen=True)
class DifferentiabilityInfo:
    # The base name read from derivatives.yaml.
    name: str

    # The matching native function.
    #
    # There can be multiple NativeFunction having the same base name:
    #  - different overloads with different types of input arguments;
    #  - in-place/out/functional variants of the same function;
    #
    # We first use the schema string (under the 'name' key) in derivatives.yaml
    # to find the NativeFunction having the same schema string.
    # Then we find the in-place/out/functional variants of the matching function.
    # Among these variants, we choose the one having the same name as the
    # derivatives.yaml entry. If there is no exact match, then we choose the
    # in-place variant.
    # TODO: maybe the logic to search for all variants is no longer necessary?
    func: NativeFunction

    # The name of the generated autograd function.
    # It's set only if we will calculate a derivative, i.e.
    # 'args_with_derivatives' is not empty.
    op: str | None

    # The backward derivative formulas for this function.
    # Note that the length of this sequence is the number of differentiable inputs.
    derivatives: Sequence[Derivative]

    # The forward derivative formulas for this function.
    # Note that the length of this sequence is the number of differentiable outputs.
    forward_derivatives: Sequence[ForwardDerivative]

    # The union of 'saved_inputs' of all 'derivatives'.
    all_saved_inputs: Sequence[SavedAttribute]

    # The union of 'saved_outputs' of all 'derivatives'.
    all_saved_outputs: Sequence[SavedAttribute]

    # All named gradients that are available for use, in the same
    # order as in the grads vector.
    available_named_gradients: Sequence[str]

    # The named gradients that are used in any of the derivatives.
    # Invariant: all(name in available_named_gradients for name in used_named_gradients)
    used_named_gradients: set[str]

    # The function's input arguments for which it calculates derivatives.
    # It's the union of 'var_names' of all 'derivatives', sorted by the
    # argument order in the function schema.
    args_with_derivatives: Sequence[Binding]

    # Names of arguments whose derivative formula is 'non_differentiable'.
    non_differentiable_arg_names: Sequence[str]

    # Raw data read from derivatives.yaml.
    output_differentiability: list[bool] | None

    # output_differentiability in derivatives.yaml can be a list of
    # conditions that express if the output is differentiable. In this case,
    # the number of conditions must match the number of outputs
    # (NB: we only support one condition right now).
    # output_differentiability gets populated with True for each condition,
    # while output_differentiability_conditions gets populated with the conditions
    output_differentiability_conditions: list[str] | None

    @property
    def has_derivatives(self) -> bool:
        """True if at least one input argument has a derivative formula."""
        return len(self.args_with_derivatives) > 0

    def create_view_copy_from_view_derivative(
        self, g: NativeFunctionsViewGroup
    ) -> DifferentiabilityInfo | None:
        """Derive the info for the "copy" variant of a view op.

        The result reuses this info's derivative information unchanged; only
        the name, native function and generated op name switch to their
        "_copy" counterparts. This is possible because "copy" variants of view
        ops can use the exact same derivative formula as the original view op.
        See Note [Codegen'd {view}_copy Operators].

        Args:
            g: the view group; its ``view_copy`` becomes the new ``func``.

        Returns:
            The new info, or None if ``g`` has no ``view_copy`` function.
        """
        # Function-scope import: the module header only imports `dataclass`.
        from dataclasses import replace

        if g.view_copy is None:
            return None

        # Append "_copy" to the base name but keep any overload name, e.g.
        # "transpose.int" -> "transpose_copy.int". A name without an overload
        # gets no trailing period: "alias" -> "alias_copy".
        base_name, sep, overload_name = self.name.partition(".")

        # Every field not listed here (all derivative information) is copied
        # over unchanged.
        return replace(
            self,
            name=f"{base_name}_copy{sep}{overload_name}",
            func=g.view_copy,
            op=None if self.op is None else f"{self.op}_copy",
        )
204

205

206
def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
207
    if info is None:
208
        return False
209
    for derivative in info.derivatives:
210
        formula = derivative.formula
211
        if re.search(IDENT_REGEX.format(ident), formula):
212
            return True
213
    return False
214

215

216
def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
    """Return True if a backward formula of ``info`` uses ``retain_variables``."""
    return uses_ident(info, ident="retain_variables")
218

219

220
def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
    """Return True if a backward formula of ``info`` uses the identifier ``grad``."""
    return uses_ident(info, ident="grad")
222

223

224
@dataclass(frozen=True)
class DifferentiableInput:
    """A differentiable ``Argument``, as used by the autograd codegen.

    Unlike ``Argument``, it holds only already-processed arguments that are
    differentiable, and it is used only within the autograd codegen. It can
    stand for a SelfArgument or a regular Argument, but never for a
    TensorOptionsArgument.
    """

    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str
236

237

238
@dataclass(frozen=True)
class DifferentiableOutput:
    """A differentiable ``Return``, as used by the autograd codegen.

    How it differs from ``Return``:
    - ``Return.name`` is optional; here ``name`` is always populated, using
      the same ``cpp.return_names()`` method.
      TODO: some cpp naming logic (e.g. resolving name conflicts) might be irrelevant?
    - It holds only processed returns that are differentiable, in compliance
      with the ``output_differentiability`` field of derivatives.yaml (if
      specified), and it is used only within the autograd codegen.
    """

    name: str
    type: Type

    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
    cpp_type: str
253

254

255
@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
    """A native function paired with its (optional) derivative information."""

    # The native function itself.
    func: NativeFunction

    # Differentiability info keyed by a string key (e.g. "Default" is used as
    # a dispatch key elsewhere in this file), or None when nothing matched.
    info: dict[str, DifferentiabilityInfo] | None

    # Forward derivatives keyed the same way as `info`, or None.
    # NOTE(review): "keyed the same way" is inferred from the parallel types — confirm.
    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
260

261

262
# TODO: Update comment below since it is out of date.
263
def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
264
    """How are we going to call the underlying implementation of a
265
    declaration?  There are two strategies:
266
        - use_derived: we want to call the implementation on CPUDoubleType
267
          (or a similar, derived Type instance).  Because these derived
268
          instances deal in Tensors, not Variables (it's a completely different
269
          object, so it doesn't dispatch back to VariableType), code on
270
          this dispatch path needs to wrap/unwrap tensors.  If the
271
          derived implementation takes and returns tensors, the
272
          implementation is usually differentiable (although we also use
273
          the derived dispatch path for non-differentiable functions
274
          that we still want to dispatch on the derived Type instance;
275
          e.g., size())
276
        - use_type: we want to call the implementation on Type, because
277
          it is implemented concretely, and the functions it invokes will
278
          get dispatched back to VariableType (which will ensure that they
279
          are differentiable.)
280
    """
281
    # fn is derived as long as any of its per-key differentiability infos
282
    # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
283
    # and ADInplaceOrViewType. We want to generate these functions as long as a
284
    # derivative is defined for ANY dispatch key.
285
    if fn.func.is_abstract or (
286
        fn.info is not None and any(info.has_derivatives for info in fn.info.values())
287
    ):
288
        # If the function is abstract (not implemented on at::Type), we must
289
        # call the implementation on the derived type with unpacked tensors.
290

291
        # If the function has a derivative specified and is concrete, we could
292
        # call either implementation. We prefer the calling the derived
293
        # type's implementation with unpacked tensors because it is more
294
        # performant in some cases: any internal calls to other ATen functions
295
        # won't have the history tracked.
296

297
        # If the function has a type dispatched argument (i.e. is a factory),
298
        # we prefer calling the derived type's implementation both because it is
299
        # more performant and to ensure factory functions return tensors with _version
300
        # of 0 (probably not strictly necessary, but nice to have to keeps versions simple
301
        # to understand.
302

303
        return "use_derived"
304
    else:
305
        # If the function is concrete (we don't have to override it) and we
306
        # didn't declare it in derivatives.yaml, we'll assume that it is
307
        # actually implemented out of differentiable functions. (This
308
        # assumption might not hold, but then you'll see gradcheck fail.)
309
        return "use_type"
310

311

312
def is_foreach_func(f: NativeFunction) -> bool:
313
    return f.func.name.name.base.startswith("_foreach_")
314

315

316
# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind
# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
# they would find such one in `functional_info_by_signature`. There however are some exceptions:

# Foreach ops that may use an in-place (non-functional) native op as their reference.
_foreach_with_inplace_ref = {"_foreach_zero_"}

# Foreach overloads whose argument may be a single Tensor rather than a TensorList.
_foreach_with_tensor_overload = {
    "_foreach_div.Tensor",
    "_foreach_mul.Tensor",
    "_foreach_add.Tensor",
}

# The following do not support the alpha kwarg, which the nonforeach versions
# support, so their argument count is not compared against the reference's.
_skip_argument_len_check = {
    "_foreach_sub.Scalar",
    "_foreach_sub_.Scalar",
    "_foreach_sub.ScalarList",
    "_foreach_sub_.ScalarList",
    "_foreach_add.Scalar",
    "_foreach_add_.Scalar",
    "_foreach_add.ScalarList",
    "_foreach_add_.ScalarList",
}
336

337

338
# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function
339
# reference to generate derivatives.
340
def is_reference_for_foreach(
341
    f: NativeFunction,
342
    function_schema: FunctionSchema,
343
) -> bool:
344
    return (
345
        f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
346
        and (
347
            not function_schema.name.name.inplace
348
            or str(f.func.name) in _foreach_with_inplace_ref
349
        )
350
        and (
351
            str(f.func.name) in _skip_argument_len_check
352
            or len(f.func.arguments.flat_non_out)
353
            == len(function_schema.arguments.flat_non_out)
354
        )
355
        and all(
356
            ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
357
            for arg, ref_arg in zip(
358
                f.func.arguments.flat_non_out,
359
                function_schema.arguments.flat_non_out,
360
            )
361
        )
362
    )
363

364

365
def _find_foreach_reference(
    foreach_function: NativeFunction,
    info_by_signature: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
    dispatch_key: str,
) -> tuple[FunctionSchema, DifferentiabilityInfo] | None:
    """Return the first ``(schema, info)`` pair ``foreach_function`` can reuse.

    Only schemas accepted by ``is_reference_for_foreach`` are considered, and
    entries whose info for ``dispatch_key`` is None are skipped. Returns None
    if nothing matches. Raises KeyError if a matching entry has no info for
    ``dispatch_key``.
    """
    for function_schema, diff_info in info_by_signature.items():
        if not is_reference_for_foreach(foreach_function, function_schema):
            continue
        ref_diff_info = diff_info[dispatch_key]
        if ref_diff_info is not None:
            return function_schema, ref_diff_info
    return None


# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
    foreach_function: NativeFunction,
    functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    non_functional_info_by_signature: dict[
        FunctionSchema, dict[str, DifferentiabilityInfo]
    ],
    dispatch_key: str = "Default",
) -> tuple[DifferentiabilityInfo | None, bool]:
    """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.

    The reference derivative is searched among functional schemas first. For
    in-place foreach ops listed in ``_foreach_with_inplace_ref``, it falls back
    to the non-functional schemas.

    Args:
        foreach_function: the ``_foreach_*`` native function.
        functional_info_by_signature: infos of functional schemas, by signature.
        non_functional_info_by_signature: infos of non-functional schemas, by signature.
        dispatch_key: key used to select the info from each per-schema dict.

    Returns:
        ``(info, generated)``. ``info`` is None when no reference was found.
        ``generated`` is True only when a new info was built here (out-of-place
        foreach functions); it is False when the reference info is returned
        unchanged or nothing was found.

    Raises:
        RuntimeError: if a reference backward formula saves an output other
            than ``result``.
        KeyError: e.g. if a matching reference has no info for ``dispatch_key``.
    """
    reference = _find_foreach_reference(
        foreach_function, functional_info_by_signature, dispatch_key
    )
    # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
    # while the info of `zero_` is in non_functional_info_by_signature
    if (
        reference is None
        and foreach_function.func.kind() == SchemaKind.inplace
        and str(foreach_function.func.name) in _foreach_with_inplace_ref
    ):
        reference = _find_foreach_reference(
            foreach_function, non_functional_info_by_signature, dispatch_key
        )
    if reference is None:
        return None, False
    ref_schema, ref_diff_info = reference

    # non out-place uses the existing Derivative.
    if foreach_function.func.kind() == SchemaKind.inplace:
        return ref_diff_info, False

    foreach_args = foreach_function.func.arguments.flat_non_out
    # Pair foreach arguments with the reference's positionally. zip stops at
    # the shorter list (some foreach ops lack e.g. the `alpha` kwarg).
    arg_pairs = list(zip(foreach_args, ref_schema.arguments.flat_non_out))
    map_refarg2foreacharg = {ref_arg.name: arg.name for arg, ref_arg in arg_pairs}
    map_name2arg = {arg.name: arg for arg, _ in arg_pairs}

    all_saved_inputs: list[SavedAttribute] = []
    all_saved_outputs: list[SavedAttribute] = []
    all_var_names: list[str] = []
    modified_derivative_formulas: list[Derivative] = []
    # note(crcrpar): This context seems necessary to call `cpp.argument_type`
    with local.parametrize(
        use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
    ):
        for derivative in ref_diff_info.derivatives:
            # Foreach gradients and results are lists: index them with `i`
            # (presumably the loop index of the generated C++ code).
            modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
                "result", "result[i]"
            )

            saved_inputs: list[SavedAttribute] = []
            for ref_input in derivative.saved_inputs:
                # e.g. "other.scalar_type()" -> "other"
                ref_input_jit_name = ref_input.expr.split(".")[0]
                mapped_name = map_refarg2foreacharg[ref_input_jit_name]
                mapped_arg = map_name2arg[mapped_name]
                if isinstance(mapped_arg.type, ListType):
                    mapped_expr = mapped_name + "[i]"
                else:
                    mapped_expr = mapped_name
                new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
                modified_formula = modified_formula.replace(
                    cast(str, ref_input.nctype.name), new_expr
                )

                nctype = cpp.argument_type(mapped_arg, binds=mapped_name)
                canonical_nctype = NamedCType(
                    nctype.name, nctype.type.remove_const_ref()
                )
                saved_inputs.append(
                    SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
                )

            saved_outputs: list[SavedAttribute] = []
            for ref_output in derivative.saved_outputs:
                if ref_output.nctype.name != "result":
                    raise RuntimeError(
                        f"{foreach_function.func.name}: unsupported saved output "
                        f"'{ref_output.nctype.name}' in the reference backward "
                        "formula; only 'result' is supported for foreach functions"
                    )
                saved_outputs.append(
                    SavedAttribute(
                        nctype=NamedCType(name="result", type=BaseCType(tensorListT)),
                        expr="result",
                    )
                )

            grad_var_names = [
                map_refarg2foreacharg[var] for var in derivative.var_names
            ]
            all_var_names.extend(grad_var_names)
            all_saved_inputs.extend(saved_inputs)
            all_saved_outputs.extend(saved_outputs)
            modified_derivative_formulas.append(
                Derivative(
                    formula=modified_formula,
                    original_formula=derivative.formula,
                    var_names=tuple(grad_var_names),
                    saved_inputs=tuple(saved_inputs),
                    saved_outputs=tuple(saved_outputs),
                    named_gradients=set(),
                )
            )

        args_with_derivatives = [
            Binding(
                name=arg.name,
                nctype=cpp.argument_type(arg, binds=arg.name),
                argument=arg,
                default=None,
            )
            for arg in foreach_args
            if arg.name in all_var_names
        ]

    forward_derivatives: list[ForwardDerivative] = []
    for fw_derivative in ref_diff_info.forward_derivatives:
        fw_var_names = list(fw_derivative.var_names)
        fw_var_types: list[Type] = list(fw_derivative.var_types)
        required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad or ())
        required_inputs_primal = list(fw_derivative.required_inputs_primal or ())

        # Foreach's result is TensorList
        modified_formula = fw_derivative.formula.replace("result", "result[i]")

        for foreach_arg, ref_arg in zip(
            foreach_args,
            ref_diff_info.func.func.arguments.flat_non_out,
        ):
            # Modify reference forward formula
            if (
                isinstance(foreach_arg.type, ListType)
                and not foreach_arg.type.is_tensor_like()
            ):
                # Assuming ScalarList
                modified_formula = modified_formula.replace(
                    ref_arg.name, foreach_arg.name + "[i]"
                )
            elif foreach_arg.type.is_tensor_like():
                # Assuming TensorList / Tensor
                assert isinstance(foreach_arg.type, ListType) or (
                    foreach_arg.type == BaseType(BaseTy.Tensor)
                    and str(foreach_function.func.name) in _foreach_with_tensor_overload
                ), f"{foreach_function.func.name}, {foreach_arg.type}"
                for suffix in ("_p", "_t"):
                    curr_expr = ref_arg.name + suffix
                    if curr_expr in modified_formula:
                        new_expr = foreach_arg.name + suffix
                        modified_formula = modified_formula.replace(curr_expr, new_expr)
            else:
                # Assuming Scalar
                if foreach_arg.name != ref_arg.name:
                    modified_formula = modified_formula.replace(
                        ref_arg.name, foreach_arg.name
                    )

            # Rename references to the reference argument. This is applied
            # argument by argument, in order, so a name renamed here can be
            # renamed again by a later pair.
            # note(crcrpar): there should exist a cooler way...
            for idx, name in enumerate(fw_var_names):
                if name == ref_arg.name:
                    fw_var_names[idx] = foreach_arg.name
                    fw_var_types[idx] = foreach_arg.type
            for idx, name in enumerate(required_inputs_fw_grad):
                if name == ref_arg.name:
                    required_inputs_fw_grad[idx] = foreach_arg.name
            for idx, name in enumerate(required_inputs_primal):
                if name == ref_arg.name:
                    required_inputs_primal[idx] = foreach_arg.name

        forward_derivatives.append(
            ForwardDerivative(
                formula=modified_formula,
                var_names=tuple(fw_var_names),
                var_types=tuple(fw_var_types),
                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
                required_inputs_primal=tuple(required_inputs_primal),
                required_original_self_value=fw_derivative.required_original_self_value,
                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
            )
        )

    return (
        DifferentiabilityInfo(
            name=foreach_function.func.name.name.base,
            func=foreach_function,
            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
            derivatives=modified_derivative_formulas,
            forward_derivatives=forward_derivatives,
            # Deduplicate while keeping first-seen order. A set's iteration
            # order depends on the string hash seed, which made the generated
            # code order vary between runs.
            all_saved_inputs=tuple(dict.fromkeys(all_saved_inputs)),
            all_saved_outputs=tuple(dict.fromkeys(all_saved_outputs)),
            available_named_gradients=(),
            used_named_gradients=set(),
            args_with_derivatives=args_with_derivatives,
            non_differentiable_arg_names=[],
            output_differentiability=None,
            output_differentiability_conditions=None,
        ),
        True,
    )
580

581

582
def match_differentiability_info(
583
    native_functions: list[NativeFunction],
584
    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
585
) -> list[NativeFunctionWithDifferentiabilityInfo]:
586
    """Sets the "derivative" key on declarations to matching autograd function
587
    In-place functions will use the out-of-place derivative definition if there
588
    is no in-place specific derivative.
589
    """
590

591
    functional_info_by_signature = {
592
        schema.signature(strip_default=True): info_dict
593
        for schema, info_dict in differentiability_infos.items()
594
        if schema.kind() == SchemaKind.functional
595
    }
596
    non_functional_info_by_signature = {
597
        schema.signature(strip_default=True): info_dict
598
        for schema, info_dict in differentiability_infos.items()
599
        if schema.kind() != SchemaKind.functional
600
    }
601

602
    def find_info(
603
        f: NativeFunction,
604
    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
605
        # Don't bother matching info to generated out= variants
606
        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
607
            return None, False
608

609
        # (1) Check for an exact match
610
        if f.func in differentiability_infos:
611
            return differentiability_infos[f.func], True
612

613
        # (2) If no exact match, check if the out-of-place variant
614
        # of this operator has a match.
615
        # i.e mul() for mul_() or mul_out()
616
        # note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing
617
        # native functions instead of the out-place counterparts.
618
        f_sig = f.func.signature(strip_default=True)
619
        if f_sig in functional_info_by_signature and not is_foreach_func(f):
620
            return functional_info_by_signature[f_sig], False
621

622
        # (3) Some operators have a derivative explicitly defined for the mutable
623
        # variant, but get a code-generated out-of-place variant which does *not*
624
        # come with a derivative formula.
625
        # For the generated out-of-place variant, use the mutable variant's formula
626
        # if it exists.
627
        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
628
            info_dict = non_functional_info_by_signature[f_sig]
629
            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
630
            assert not any(
631
                any("self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs)
632
                for info in info_dict.values()
633
            ), f"""\
634
Attempted to convert a derivative formula for a mutable operator
635
 to be used by automatically by its functional variant ("{str(f.func)}").
636
 this is not currently supported (we'd need to fix up the formula in the codegen)."""
637
            return info_dict, False
638

639
        # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
640
        if is_foreach_func(f):
641
            assert f.func not in differentiability_infos
642
            diff_info, is_generated = gen_foreach_derivativeinfo(
643
                f,
644
                functional_info_by_signature,
645
                non_functional_info_by_signature,
646
            )
647
            if diff_info is None:
648
                return None, False
649
            # TODO(crcrpar): Avoid hard coding "Default" ideally.
650
            diff_info_dict = {"Default": diff_info}
651
            if is_generated:
652
                differentiability_infos[f.func] = diff_info_dict
653
                functional_info_by_signature[f.func] = diff_info_dict
654
            return diff_info_dict, is_generated
655

656
        return None, False
657

658
    result: list[NativeFunctionWithDifferentiabilityInfo] = []
659
    for f in native_functions:
660
        info_dict, is_exact_match = find_info(f)
661

662
        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
663
        # 'self' derivatives of an inplace function, so we must check for this case.
664
        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
665
            for info in info_dict.values():
666
                for derivative in info.derivatives:
667
                    if "self" in derivative.var_names:
668
                        for saved_input in derivative.saved_inputs:
669
                            assert "strides_or_error" not in saved_input.expr, (
670
                                "Calling '.strides()' in the 'self' derivative formula of an "
671
                                f"in-place function is not supported: {f.func}"
672
                            )
673

674
        if not info_dict:
675
            result.append(
676
                NativeFunctionWithDifferentiabilityInfo(
677
                    func=f, info=None, fw_derivatives=None
678
                )
679
            )
680
            continue
681

682
        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
683
        for key, info in info_dict.items():
684
            if not info.forward_derivatives:
685
                fw_derivative_dict[key] = []
686
                continue
687

688
            forward_derivatives = info.forward_derivatives
689

690
            # For functions that have a single def for out-of-place and inplace (like abs())
691
            if f.func.kind() == SchemaKind.inplace:
692
                # For inplace functions there is a little bit of work to do:
693
                #  1) Validate the formula and make sure the input that is modified in not used:
694
                #    - If there is a formula for the inplace variant of the function (is_exact_match == True) then
695
                #      we make sure that the original value of the input that is being modified inplace (self_p) is
696
                #      not used in the formula. Note that the formula can use "original_self_p" here and that would
697
                #      trigger a clone of the original input.
698
                #    - If we are re-using the out of place formula (is_exact_match == False) then we replace every
699
                #      occurrence of self_p and self_t by original_self_p and original_self_t. These will be
700
                #      populated by cloned version of the original input (either the clone done by the backward AD
701
                #      logic if self is also used in a backward formula or a special clone that we add).
702
                #  2) At this point, there cannot be a self_p in the formula.
703
                #  3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
704
                #     simply called self (as it is modified inplace).
705
                #  4) Update the required primals data in case it used to contain "result" but should now contain
706
                #     "self"
707
                #  5) If it is not an exact match, the user formula is not modifying the existing forward grad
708
                #     inplace as it should. So add some code that makes sure that we do so if the forward grad
709
                #     already exists.
710

711
                assert (
712
                    len(info.forward_derivatives) == 1
713
                )  # Only single output inplace should exist
714
                fw_info = info.forward_derivatives[0]
715
                formula = fw_info.formula
716

717
                def replace_self_with_original_self(formula: str, postfix: str) -> str:
718
                    def repl(m: re.Match[str]) -> str:
719
                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"
720

721
                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
722

723
                if re.search(IDENT_REGEX.format("self_p"), formula):
724
                    if is_exact_match:
725
                        # For manually defined formulas, don't allow the original value to be used
726
                        raise RuntimeError(
727
                            f'The formula for "{f.func.name}" is using the original value of self '
728
                            "that is being modified inplace. This would lead to wrong forward gradients. "
729
                            'Please use "result" in the formula only.'
730
                        )
731
                    else:
732
                        # When the original formula is out of place, we save a clone of the primal
733
                        # value to be able to access this value if needed
734
                        # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
735
                        formula = replace_self_with_original_self(formula, "_p")
736
                        formula = replace_self_with_original_self(formula, "_t")
737

738
                # replace "result" from the formula by "self_p"
739
                def repl(m: re.Match[str]) -> str:
740
                    return f"{m.group(1)}self_p{m.group(2)}"
741

742
                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
743

744
                required_primals = fw_info.required_inputs_primal
745
                if re.search(IDENT_REGEX.format("self_p"), formula):
746
                    required_primals = (
747
                        required_primals + ("self",) if required_primals else ("self",)
748
                    )
749

750
                if not is_exact_match:
751
                    # NOTE [In-place forward AD formula Optimization]
752
                    #
753
                    # This optimization transforms the formula to directly do inplace, i.e.
754
                    # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
755
                    #
756
                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
757
                    # 2) "op" in (1) needs to be the same as the op the derivative is for
758
                    #
759
                    # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
760
                    # If there is a need, we can relax (2) to allow any op that has an in-place variant
761
                    is_single_method_on_self_t = False
762
                    directly_do_inplace = False
763
                    op_name: str | None = None
764
                    between_parens: str | None = None
765
                    match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
766
                    if match:
767
                        op_name, between_parens = match.group(1), match.group(2)
768

769
                        # We want to...
770
                        #   Match: self_t.op1(other_p.op2(arg))
771
                        #   Avoid: self_t.op1(args) + self_t.op2(args)
772
                        #   Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
773
                        def check_parens_nest_level_gt_zero(s: str) -> bool:
774
                            level = 1
775
                            for ch in s:
776
                                if ch == ")":
777
                                    level -= 1
778
                                    if level == 0:
779
                                        return False
780
                                if ch == "(":
781
                                    level += 1
782
                            return True
783

784
                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
785
                            between_parens
786
                        )
787
                        directly_do_inplace = (
788
                            is_single_method_on_self_t and op_name == info.name
789
                        )
790

791
                    if directly_do_inplace:
792
                        assert op_name is not None
793
                        assert between_parens is not None
794
                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
795
                    else:
796
                        # Make sure that the forward grad is modified inplace when the original formula
797
                        # is out of place
798
                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"
799

800
                required_original_self_value = bool(
801
                    re.search(IDENT_REGEX.format("original_self_p"), formula)
802
                ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))
803

804
                forward_derivatives = [
805
                    ForwardDerivative(
806
                        formula=formula,
807
                        var_names=("self",),
808
                        var_types=fw_info.var_types,
809
                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
810
                        required_inputs_primal=required_primals,
811
                        required_original_self_value=required_original_self_value,
812
                        is_reusing_outplace_formula=not is_exact_match,
813
                    ),
814
                ]
815

816
            fw_derivative_dict[key] = forward_derivatives
817

818
        result.append(
819
            NativeFunctionWithDifferentiabilityInfo(
820
                func=f, info=info_dict, fw_derivatives=fw_derivative_dict
821
            )
822
        )
823

824
    return result
825

826

827
def is_differentiable(
828
    name: str, type: Type, info: DifferentiabilityInfo | None
829
) -> bool:
830
    return type.is_tensor_like() and (
831
        info is None or name not in info.non_differentiable_arg_names
832
    )
833

834

835
def gen_differentiable_outputs(
836
    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
837
) -> list[DifferentiableOutput]:
838
    f = fn.func
839
    info = fn.info[key] if fn.info else None
840
    outputs: list[DifferentiableOutput] = [
841
        DifferentiableOutput(
842
            name=name,
843
            type=ret.type,
844
            cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
845
        )
846
        for name, ret in zip(cpp.return_names(f), f.func.returns)
847
    ]
848
    output_differentiability = info.output_differentiability if info else None
849
    if output_differentiability is not None:
850
        if len(output_differentiability) != len(outputs):
851
            raise RuntimeError(
852
                f"The length of output_differentiability ({len(output_differentiability)}), "
853
                f"does not match the number of outputs ({len(outputs)})."
854
            )
855
        differentiable_outputs: list[DifferentiableOutput] = []
856
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
857
            raise RuntimeError(
858
                "output_differentiability=False for inplace operation (version_counter won't get updated)"
859
            )
860
        for differentiable, output in zip(output_differentiability, outputs):
861
            if differentiable:
862
                differentiable_outputs.append(output)
863
        return differentiable_outputs
864
    candidate_differentiable_outputs = list(
865
        filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
866
    )
867
    if uses_single_grad(info):
868
        return candidate_differentiable_outputs[:1]
869
    else:
870
        return candidate_differentiable_outputs
871

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.