pytorch

Форк
0
/
autograd_function.py 
752 строки · 27.3 Кб
1
# mypy: allow-untyped-defs
2
from typing import Any, NamedTuple, Tuple
3

4
import torch
5
import torch.utils._pytree as pytree
6
from torch._C._functorch import (
7
    _unwrap_for_grad,
8
    _wrap_for_grad,
9
    current_level,
10
    TransformType,
11
)
12
from torch._functorch.apis import vmap
13
from torch._functorch.utils import enable_single_level_autograd_function
14
from torch._functorch.vmap import (
15
    _add_batch_dim,
16
    _broadcast_to_and_flatten,
17
    restore_vmap,
18
    unwrap_batched,
19
    wrap_batched,
20
)
21
from torch._ops import HigherOrderOperator
22
from torch.autograd.forward_ad import _set_fwd_grad_enabled
23

24

25
# autograd.Function technically runs before the regular PyTorch dispatcher.
26
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
27
# work with it. One day we might decide to change this, but until then,
28
# we need to give the illusion that autograd.Function runs before those things.
29
#
30
# We do this by creating a custom HigherOrderOperator that only functorch
31
# dispatches specially.
32
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
33
    def __init__(self) -> None:
34
        super().__init__("custom_function_call")
35

36
    def __call__(self, autograd_function, *args, **kwargs):
37
        # When custom_function_call is done dispatching through functorch,
38
        # it should just invoke the autograd.Function. This is consistent
39
        # with the autograd.Function behavior of being invoked before the
40
        # PyTorch dispatcher.
41
        #
42
        # This will lead us into trouble later down the line, but this is
43
        # pre-existing. There is an invariant that a function traced by
44
        # make_fx should have the same behavior when provided the same
45
        # Tensor. However, make_fx sees autograd.Function as a composite
46
        # (because autograd.Function happens before the Python dispatch key)
47
        # and only traces the forward pass.
48
        if torch._C._are_functorch_transforms_active():
49
            return super().__call__(autograd_function, *args, **kwargs)
50
        return autograd_function.apply(*args, **kwargs)
51

52

53
# "custom_function_call"
54
# This is the mechanism for an autograd.Function that works with functorch transforms.
55
# It wraps an autograd.Function; interactions with functorch transforms are defined
56
# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch
57
# dispatcher.
58
# Module-level instance of CustomFunctionHigherOrderOperator; the py_impl
# registrations later in this file attach to this object.
custom_function_call = CustomFunctionHigherOrderOperator()
59

60

61
# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
62
# (autograd.Function that only works with a single layer (level) of functorch) that:
63
# - unwraps the inputs
64
# - redispatches to custom_function_call
65
# - wraps the outputs
66
# and whose backward pass calls the original autograd.Function's backward.
67
#
68
# Why do we need to redispatch to custom_function_call?
69
# -----------------------------------------------------
70
# This is consistent with how ATen operators work with functorch's grad transform:
71
# they always redispatch to the original operator.
72
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
73
#
74
# grad1 will:
75
# - set up the autograd graph
76
# - unwrap the inputs
77
# - redispatch to at::sin (*)
78
# - rewrap the outputs on the return
79
#
80
# On the redispatch in (*), grad0 will:
81
# - set up the autograd graph
82
# - unwrap the inputs
83
# - redispatch to at::sin
84
# - rewrap the outputs on the return
85
#
86
# To "set up the autograd graph", we generate a _SingleLevelFunction
87
# and apply it.
88
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    """Grad/Jvp rule for ``custom_function_call``.

    Generates a single-level function class for ``autograd_function`` at the
    interpreter's level and applies it to ``operands`` while
    ``enable_single_level_autograd_function()`` is active.

    Returns:
        The result of ``Generated.apply(*operands)``.
    """
    Generated = generate_single_level_function(interpreter, autograd_function)
    with enable_single_level_autograd_function():
        return Generated.apply(*operands)
95

96

97
def generate_single_level_function(interpreter, autograd_function):
    """Dynamically create a ``_SingleLevelFunction`` subclass for one level.

    The generated class:
    - ``forward``: unwraps Tensor operands at ``interpreter.level()``,
      re-dispatches to ``custom_function_call`` with the interpreter lowered,
      and re-wraps the outputs (reusing the original input object when an
      output is one of the unwrapped inputs).
    - ``setup_context`` / ``backward`` / ``jvp``: delegate to the same-named
      staticmethods of ``autograd_function``.

    Args:
        interpreter: object providing ``level()`` and ``lower()``.
        autograd_function: the autograd.Function class being wrapped.

    Returns:
        The generated class, named ``f"{autograd_function.__name__}Generated"``.
    """
    level = interpreter.level()

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(
            torch.Tensor, lambda x: _unwrap_for_grad(x, level), operands
        )
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            unwrapped_output = custom_function_call(
                autograd_function, *unwrapped_operands
            )

        # See NOTE [mark_dirty object identity check]
        def wrap_fn(output):
            return _wrap_for_grad(output, level)

        return wrap_outputs_maintaining_identity(
            unwrapped_output, unwrapped_operands, operands, wrap_fn
        )

    def setup_context(ctx, inputs, output):
        return autograd_function.setup_context(ctx, inputs, output)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        return autograd_function.backward(ctx, *grads)

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        return autograd_function.jvp(ctx, *tangents)

    # This is the sequence of magic words to dynamically generate a Subclass with
    # a given name. A Tensor's .grad_fn field has a class name that is the original
    # autograd.Function's name + Backward, so we do this to generate some
    # meaningful name.
    name = f"{autograd_function.__name__}Generated"
    Generated = type(
        name,
        (torch.autograd.function._SingleLevelFunction,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
            "jvp": staticmethod(jvp),
            "setup_context": staticmethod(setup_context),
        },
    )
    return Generated
149

150

151
# wrap_outputs_maintaining_identity handles outputs from the vmap,
152
# backward (vjp), and jvp staticmethod. The way it distinguishes
153
# between the vmap case and the {backward, jvp} case is if the out_dims
154
# are specified or not.
155
#
156
# NB: we cannot use out_dims=None as the deciding factor. This is because
157
# out_dims=None can still happen in the vmap staticmethod! What the
158
# user is saying in that case is that their output does not have a
159
# dimension that is being vmapped over, which is valid.
160
# Sentinel used as the default `out_dims` of wrap_outputs_maintaining_identity;
# that function compares against it with `!=` to tell whether out_dims was given.
NO_OUT_DIMS = "not specified"
161

162

163
# NOTE [mark_dirty object identity check]
164
# autograd.Function's ctx.mark_dirty expects a returned input
165
# to have the same object identity as the input.
166
# Mode-only functorch will greatly simplify this logic.
167
def wrap_outputs_maintaining_identity(
    outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS
):
    """Wrap each Tensor output, but return the original input object when an
    output *is* (by ``id``) one of the unwrapped inputs.

    Args:
        outputs: pytree of outputs.
        unwrapped_inputs: inputs after unwrapping; flattened with
            ``pytree.arg_tree_leaves``.
        orig_inputs: the inputs before unwrapping, paired leaf-by-leaf with
            ``unwrapped_inputs`` via ``zip``.
        wrap_fn: called as ``wrap_fn(output)`` when ``out_dims`` is not
            specified, or ``wrap_fn(output, out_dim)`` when it is.
        out_dims: optional per-output dims; broadcast against the structure of
            ``outputs``. ``NO_OUT_DIMS`` means "not specified".

    Returns:
        A pytree with the same structure as ``outputs``. Non-Tensor leaves are
        passed through unchanged.

    Raises:
        RuntimeError: if ``out_dims`` is specified but cannot be broadcast to
            the structure of ``outputs``.
    """
    flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
    flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)

    # Keyed by object identity: see NOTE [mark_dirty object identity check].
    unwrapped_input_to_orig_input = {
        id(unwrapped): orig
        for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
    }

    flat_outputs, spec = pytree.tree_flatten(outputs)
    result = []

    out_dims_specified = out_dims != NO_OUT_DIMS

    flat_out_dims = None
    if out_dims_specified:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
        # _broadcast_to_and_flatten returns None if it is unable to broadcast.
        # TODO: update following link from master to stable once that's out
        if flat_out_dims is None:
            raise RuntimeError(
                f"The autograd.Function's vmap staticmethod returned an "
                f"incompatible (output, out_dims) tuple. "
                f"Expected out_dims={out_dims} "
                f"to be compatible with the structure of `output`. "
                f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
                f"but output has structure {spec}. "
                f"For more details, please see "
                f"https://pytorch.org/docs/main/notes/extending.func.html"
            )

    for i, output in enumerate(flat_outputs):
        if not isinstance(output, torch.Tensor):
            result.append(output)
            continue
        if id(output) in unwrapped_input_to_orig_input:
            result.append(unwrapped_input_to_orig_input[id(output)])
            continue
        if out_dims_specified:
            result.append(wrap_fn(output, flat_out_dims[i]))  # type: ignore[index]
        else:
            result.append(wrap_fn(output))

    return pytree.tree_unflatten(result, spec)
212

213

214
# NOTE: [functorch vjp and autograd interaction]
215
# There's an edge case with the functorch vjp and autograd interaction
216
# that will eventually be fixed by mode-only functorch.
217
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
218
# so we (the framework) need to do it manually. Regular PyTorch operators
219
# automatically do so, so this is consistent.
220
#
221
# class MyExp(torch.autograd.Function):
222
#     @staticmethod
223
#     def forward(x):
224
#         return x.exp()
225
#
226
#     @staticmethod
227
#     def setup_context(ctx, inputs, output):
228
#         y = output
229
#         ctx.save_for_backward(y)
230
#
231
#     @staticmethod
232
#     def backward(ctx, gy):
233
#         y, = ctx.saved_tensors
234
#         return MyMul.apply(gy, y)
235
#
236
# x = torch.randn([], requires_grad=True)
237
# gy = torch.randn([], requires_grad=True)
238
# _, vjp_fn = vjp(MyExp.apply, x)
239
# result = vjp_fn(gy)
240
#
241
# MyMul is an autograd.Function that is not shown here.
242
# It saves a `y` for backward (since gy requires grad).
243
#
244
# in vjp_fn(gy), we get:
245
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
246
# Because the y that is saved for backward by MyExp is a GradTensorWrapper
247
# but is now dead since we are outside the vjp context.
248
#
249
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
250
# will automatically unwrap the GradTensorWrapper when applied.
251
# But since autograd.Function technically sits above the regular PyTorch
252
# dispatcher, it doesn't get this treatment. So we manually do
253
# the unwrapping to be consistent with regular PyTorch dispatcher operations.
254

255

256
class VmapInfo(NamedTuple):
    """Information passed as the first argument to a vmap staticmethod."""

    # Populated from interpreter.batch_size() by the vmap helper in this file.
    batch_size: int
    # Populated from interpreter.randomness() by the vmap helper in this file.
    randomness: str
259

260

261
def has_overriden_vmap_rule(autograd_function):
    """Return True if ``autograd_function.vmap`` is not the base
    ``torch.autograd.Function.vmap`` object (identity comparison)."""
    return autograd_function.vmap is not torch.autograd.Function.vmap
263

264

265
def validate_vmap_returns_tuple_of_two_elements(result):
    """Check that ``result`` is a tuple of exactly two elements.

    Raises:
        RuntimeError: if ``result`` is not a tuple, or is a tuple whose length
            is not 2.
    """
    base_error_msg = (
        "Expected the vmap staticmethod to have two returns, an output "
        "and out_dims with pytree structure compatible with the output. "
    )
    if not isinstance(result, tuple):
        raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
    if len(result) != 2:
        raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")
274

275

276
@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs):
    """Vmap rule for ``custom_function_call``.

    Chooses between the auto-generated rule (``generate_vmap_rule`` truthy)
    and the user's ``vmap`` staticmethod.

    Raises:
        NotImplementedError: if any leaf of ``kwargs`` is a Tensor.
        RuntimeError: if the function both sets ``generate_vmap_rule`` and
            overrides ``vmap``, or does neither.
    """
    if any(isinstance(val, torch.Tensor) for val in pytree.tree_leaves(kwargs)):
        raise NotImplementedError(
            f"Run vmap on autograd.Function with kwarg-only Tensor args. "
            f"Please do not pass kwarg-only Tensors to autograd.Function. "
            f"Got: {kwargs}"
        )

    if autograd_function.generate_vmap_rule:
        if has_overriden_vmap_rule(autograd_function):
            # TODO: Update link to stable once that's out
            # https://github.com/pytorch/pytorch/issues/92029
            raise RuntimeError(
                f"You tried to vmap over {autograd_function.__name__}, but "
                f"it has both generate_vmap_rule=True and an overriden vmap "
                f"staticmethod. Please set generate_vmap_rule=False or delete "
                f"the overriden vmap staticmethod to avoid ambiguity. "
                f"For more details, please see "
                f"https://pytorch.org/docs/main/notes/extending.func.html"
            )
        # Note: kwargs are not forwarded on this path.
        return custom_function_call_vmap_generate_rule(
            interpreter, autograd_function, *operands
        )

    if not has_overriden_vmap_rule(autograd_function):
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            f"it does not have vmap support. Please override and implement the "
            f"vmap staticmethod or set generate_vmap_rule=True. "
            f"For more details, please see "
            f"https://pytorch.org/docs/main/notes/extending.func.html"
        )

    return custom_function_call_vmap_helper(
        interpreter, autograd_function.vmap, autograd_function, *operands, **kwargs
    )
318

319

320
def custom_function_call_vmap_helper(
    interpreter, vmap_function, op, *operands, **kwargs
):
    """Run ``vmap_function`` on operands unwrapped at the interpreter's level.

    Args:
        interpreter: provides ``level()``, ``batch_size()``, ``randomness()``
            and ``lower()``.
        vmap_function: called as
            ``vmap_function(info, in_dims, *unwrapped_operands, **kwargs)`` and
            expected to return an ``(output, out_dims)`` tuple.
        op: invoked instead of ``vmap_function`` when no operand is batched at
            the current level.
        *operands, **kwargs: the call arguments.

    Returns:
        The outputs with a batch dim re-added at the current level for every
        output whose out_dim is not None.
    """
    # Named `level` so the module-level `current_level` import is not shadowed.
    level = interpreter.level()
    info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )
    unwrapped_operands, in_dims = unwrap_batched(operands, level)
    # If none of the tensors are batched at the current level, then we skip the
    # current level. This saves the user from needing to handle this case in
    # their vmap staticmethod (and is consistent with our C++ batching rule API)
    if pytree.tree_all(lambda dim: dim is None, in_dims):
        with interpreter.lower():
            if isinstance(op, torch.autograd.function.FunctionMeta):
                return custom_function_call(op, *operands)
            else:
                return op(*operands, **kwargs)

    with interpreter.lower():
        result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs)
    validate_vmap_returns_tuple_of_two_elements(result)
    unwrapped_output, out_dims = result

    # See NOTE [mark_dirty object identity check]
    def wrap_fn(output, out_dim):
        return output if out_dim is None else _add_batch_dim(output, out_dim, level)

    return wrap_outputs_maintaining_identity(
        unwrapped_output, unwrapped_operands, operands, wrap_fn, out_dims=out_dims
    )
355

356

357
def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
    """Apply the auto-generated vmap rule for ``autograd_function``.

    Unwraps ``operands`` at the interpreter's level, builds a vmapped
    autograd.Function via ``vmapify_autograd_function``, calls it through
    ``custom_function_call`` with the interpreter lowered, and re-wraps the
    output using the out_dims recorded during the forward call.
    """
    unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
    vmapped_function, get_out_dims = vmapify_autograd_function(
        autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness()
    )

    with interpreter.lower():
        output = custom_function_call(vmapped_function, *unwrapped_operands)

    # Only valid after the call above: out_dims is recorded by the generated
    # forward.
    out_dims = get_out_dims()
    return wrap_batched(output, out_dims, interpreter.level())
368

369

370
@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(
    interpreter, autograd_function, *operands, **kwargs
):
    """Functionalize rule for ``custom_function_call`` (not yet implemented).

    Accepts any operands so that every call shape reaches the explicit error
    below rather than failing on argument binding.

    Raises:
        RuntimeError: always.
    """
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")
375

376

377
def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
    """Create a vmapped version of ``autograd_function``.

    Each staticmethod of the generated class runs the corresponding
    staticmethod of ``autograd_function`` under ``restore_vmap``.

    Args:
        autograd_function: the autograd.Function class to vmap.
        in_dims: batch dims of the forward inputs.
        batch_size: passed to every ``restore_vmap`` call.
        randomness: passed to every ``restore_vmap`` call.

    Returns:
        ``(Generated, get_out_dims)``: the new autograd.Function subclass
        (named ``f"Vmapped{autograd_function.__name__}"``) and a callable
        returning the out_dims recorded by its forward. ``get_out_dims``
        asserts that forward has already run.
    """
    # The following values are saved from the forward() and setup_context()
    # and used in backward().
    # Why do we save the values out here instead of on the ctx object?
    # - out_dims: There's no way to retrieve this from forward()
    # - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting
    #   vmap(vmap( but not completely sure if it is a problem. If we
    #   assigned those fields to the ctx object, the worry is that they
    #   get overwritten.
    init_val = "not populated"
    out_dims = init_val
    input_shapes: Any = init_val
    saved_tensors_bdims: Any = init_val

    def forward(*operands):
        nonlocal out_dims
        outputs, out_dims = restore_vmap(
            autograd_function.forward, in_dims, batch_size, randomness
        )(*operands)
        return outputs

    def setup_context(ctx, inputs, outputs):
        input_shapes_ = None
        saved_tensors_bdims_ = None

        def inner(inputs, outputs):
            # wrapped_ctx.save_for_backward will:
            # - unwrap batchedtensors into (tensor, bdim)
            # - save_for_backward(*unwrapped_tensors)
            # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
            wrapped_ctx = CtxCustomSave(ctx, current_level())
            autograd_function.setup_context(wrapped_ctx, inputs, outputs)

            # input_shapes are used for reductify later to reduce expanded gradients
            # to the correct shape.
            # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
            # for more details
            nonlocal input_shapes_
            input_shapes_ = tuple(
                inp.shape if isinstance(inp, torch.Tensor) else None for inp in inputs
            )
            nonlocal saved_tensors_bdims_
            saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims

        # See NOTE: [Why do we need to run setup_context under a vmap?]
        restore_vmap(
            inner,
            (in_dims, out_dims),
            batch_size,
            randomness,
        )(inputs, outputs)

        # Publish the values captured inside the vmapped region to the
        # enclosing scope so that jvp/backward can read them.
        nonlocal input_shapes
        input_shapes = input_shapes_
        nonlocal saved_tensors_bdims
        saved_tensors_bdims = saved_tensors_bdims_

    def jvp(ctx, *tangents):
        assert out_dims != init_val
        assert saved_tensors_bdims != init_val

        def jvp_no_context(saved_tensors, tangents):
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.jvp(wrapped_ctx, *tangents)

        tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
        out_tangents, out_tangents_dims = restore_vmap(
            jvp_no_context,
            (saved_tensors_bdims, tangent_in_dims),
            batch_size,
            randomness,
        )(ctx.saved_tensors, tangents)

        return reductify(out_tangents, out_tangents_dims, out_dims, batch_size)

    def backward(ctx, *grad_outputs):
        assert out_dims != init_val
        assert input_shapes != init_val
        assert saved_tensors_bdims != init_val

        def backward_no_context(inputs):
            saved_tensors, grad_outputs = inputs
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.backward(wrapped_ctx, *grad_outputs)

        grad_ins, grad_ins_dims = restore_vmap(
            backward_no_context,
            ((saved_tensors_bdims, out_dims),),
            batch_size,
            randomness,
        )((ctx.saved_tensors, grad_outputs))
        return reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes)

    name = f"Vmapped{autograd_function.__name__}"
    Generated = type(
        name,
        (torch.autograd.Function,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
            "jvp": staticmethod(jvp),
            "setup_context": staticmethod(setup_context),
            "generate_vmap_rule": True,
        },
    )

    def get_out_dims():
        assert out_dims != init_val
        return out_dims

    return Generated, get_out_dims
490

491

492
# tangents might be None, so we need to replace
493
# the corresponding in_dims with None.
494
def get_tangents_in_dims(input_dims, tangents):
495
    flat_in_dims, spec = pytree.tree_flatten(input_dims)
496
    flat_tangents = pytree.arg_tree_leaves(*tangents)
497
    result = [
498
        None if tangent is None else in_dim
499
        for in_dim, tangent in zip(flat_in_dims, flat_tangents)
500
    ]
501
    return pytree.tree_unflatten(result, spec)
502

503

504
# NOTE: [Why do we need to run setup_context under a vmap?]
505
# Consider the following autograd.Function
506
#
507
# class Sum(torch.autograd.Function):
508
#    @staticmethod
509
#    def forward(x):
510
#        return x.sum()
511
#    @staticmethod
512
#    def setup_context(ctx, inputs, outputs):
513
#        ctx.x_shape = inputs[0].shape
514
#    @staticmethod
515
#    def backward(ctx, gy):
516
#        return gy.expand(ctx.x_shape)
517
#
518
# x = torch.randn(B, 4)
519
# in_dims = 0
520
# vmap(Sum.apply, in_dims)(x)
521
#
522
# Let's assume for a moment that we didn't vmap setup_context in VmappedSum:
523
#
524
# class VmappedSum(torch.autograd.Function):
525
#    @staticmethod
526
#    def forward(x):
527
#        return vmap(Sum.forward, in_dims)(x)
528
#
529
#    @staticmethod
530
#    def setup_context(ctx, inputs, outputs):
531
#        Sum.setup_context(ctx, inputs, outputs)
532
#
533
#    @staticmethod
534
#    def backward(ctx, gy):
535
#        def backward_no_context(gy):
536
#            return gy.expand(ctx.x_shape)
537
#
538
#        dims = (0,)
539
#        gx = vmap(backward_no_context, dims)(gy)
540
#        return gx
541
#
542
# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
543
# and we're doing:
544
#
545
# def backward_no_context(gy):
546
#     return gy.expand([B, 4])
547
#
548
# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
549
#
550
# This gives us the wrong result (gx has shape [B, B, 4], but it should
551
# have shape [4]). Performing vmap over setup_context means the shape
552
# saved has shape [4] and leads to a correct result shape for gx.
553

554

555
# Wraps a ctx object. Forwards all attr accesses to the underlying object
556
# except for the attrs in _pt_reserved_attrs
557
class WrappedCtx:
    """Proxy around a ctx object.

    Attribute reads that are not found on the wrapper, and all attribute
    writes except those named in ``_pt_reserved_attrs``, are forwarded to the
    wrapped ctx. Reserved names are stored on the wrapper itself.
    """

    _pt_reserved_attrs: Tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx")

    def __init__(self, ctx):
        """Wrap ``ctx``.

        Raises:
            RuntimeError: if ``ctx`` is not itself a WrappedCtx and already
                has an attribute with one of the reserved names.
        """
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            for name in reserved_attrs:
                if not hasattr(ctx, name):
                    continue
                raise RuntimeError(
                    f"PyTorch reserves the {reserved_attrs} field on ctx. "
                    "Please name your fields on ctx something else to avoid name "
                    "collision."
                )
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. A reserved name
        # reaching this point means it was never set on the wrapper (e.g. the
        # instance was created without __init__); forwarding would look up
        # self._pt_inner_ctx and recurse without bound, so fail cleanly.
        if name in type(self)._pt_reserved_attrs:
            raise AttributeError(name)
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        if name in type(self)._pt_reserved_attrs:
            # Reserved fields live on the wrapper, not on the wrapped ctx.
            self.__dict__[name] = value
            return
        return setattr(self._pt_inner_ctx, name, value)
581

582

583
# Wraps ctx to create a new ctx object that overrides saved_tensors.
584
class CtxWithSavedTensors(WrappedCtx):
    """WrappedCtx whose ``saved_tensors`` property returns a caller-supplied
    value instead of forwarding to the wrapped ctx."""

    _pt_reserved_attrs = ("_pt_new_saved_tensors", *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, new_saved_tensors):
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        """The ``new_saved_tensors`` value given at construction."""
        return self._pt_new_saved_tensors
594

595

596
class CtxCustomSave(WrappedCtx):
    """WrappedCtx that unwraps batched tensors before saving them.

    ``save_for_backward`` / ``save_for_forward`` unwrap their arguments at the
    level given at construction, save the unwrapped tensors on the wrapped
    ctx, and record the batch dims in ``_pt_saved_tensors_bdims``.
    """

    _pt_reserved_attrs = (
        "_pt_saved_tensors_bdims",
        "_pt_current_level",
        *WrappedCtx._pt_reserved_attrs,
    )

    def __init__(self, ctx, current_level):
        super().__init__(ctx)
        self._pt_saved_tensors_bdims = ()
        self._pt_current_level = current_level

    def save_for_backward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_backward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims

    def save_for_forward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_forward(*unwrapped_tensors)
        # NOTE(review): this writes the same field as save_for_backward, so
        # whichever of the two is called last determines the recorded bdims.
        self._pt_saved_tensors_bdims = bdims
617

618

619
def reductify(
    grad_input,
    grad_input_bdim,
    input_bdim,
    batch_size,
    target_shape_without_bdim_to_reduce_to=None,
):
    """Apply ``reductify_leaf`` element-wise over tuples of gradients.

    ``grad_input``, ``grad_input_bdim`` and ``input_bdim`` are each wrapped in
    a 1-tuple if they are not already tuples. When
    ``target_shape_without_bdim_to_reduce_to`` is None, a tuple of Nones of
    length ``len(grad_input)`` is used in its place.

    Returns:
        A tuple with one ``reductify_leaf`` result per zipped element.
    """
    if not isinstance(grad_input, tuple):
        grad_input = (grad_input,)
    if not isinstance(grad_input_bdim, tuple):
        grad_input_bdim = (grad_input_bdim,)
    if not isinstance(input_bdim, tuple):
        input_bdim = (input_bdim,)

    if target_shape_without_bdim_to_reduce_to is None:
        target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,)

    return tuple(
        reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape)
        for gi, gi_bdim, i_bdim, maybe_ishape in zip(
            grad_input,
            grad_input_bdim,
            input_bdim,
            target_shape_without_bdim_to_reduce_to,
        )
    )
645

646

647
def reductify_leaf(
648
    grad_input,
649
    grad_input_bdim,
650
    input_bdim,
651
    batch_size,
652
    target_shape_without_bdim_to_reduce_to=None,
653
):
654
    if grad_input is None:
655
        return None
656

657
    if grad_input_bdim is None and input_bdim is None:
658
        return grad_input
659

660
    if grad_input_bdim is not None and input_bdim is None:
661
        return grad_input.sum(grad_input_bdim)
662

663
    # NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
664
    # For reverse-mode AD,
665
    # given a grad_input and input, it is valid for the user to return a
666
    # grad_input that has a broadcasted shape when compared to the input.
667
    # In this situation, autograd automatically reduces the grad_input to
668
    # the shape of the input.
669
    #
670
    # However, when input_bdim is not None, we have problems.
671
    #
672
    # [example 1]
673
    # grad_input: Tensor[3, 4], input: Tensor[B, 4]
674
    # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
675
    # from [B, 4].
676
    #
677
    # [example 2]
678
    # grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
679
    # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
680
    # from [B, 4].
681
    #
682
    # This means that we need to also reduce the grad_input to the shape of the
683
    # input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to` flag;
684
    # if not-None then we do the reducing manually, otherwise, we do not do a reduction.
685
    assert input_bdim is not None
686

687
    if grad_input_bdim is None:
688
        grad_input = grad_input.unsqueeze(input_bdim)
689
        new_shape = list(grad_input.shape)
690
        new_shape[input_bdim] = batch_size
691
        grad_input = grad_input.expand(new_shape)
692
        grad_input_bdim = input_bdim
693

694
    if target_shape_without_bdim_to_reduce_to is not None:
695
        return vmap(
696
            torch.Tensor.sum_to_size,
697
            in_dims=(grad_input_bdim, None),
698
            out_dims=input_bdim,
699
        )(grad_input, target_shape_without_bdim_to_reduce_to)
700

701
    if input_bdim != grad_input_bdim:
702
        grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
703
    return grad_input
704

705

706
def autograd_function_forward_rewritten(original_forward, original_setup_context):
    """Combine a ctx-less forward and a setup_context into one forward.

    Returns:
        ``new_forward(ctx, *args, **kwargs)``, which calls
        ``original_forward(*args, **kwargs)``, then
        ``original_setup_context(ctx, args, output)`` (positional ``args``
        only; kwargs are not passed to it), and returns the forward output.
    """

    def new_forward(ctx, *args, **kwargs):
        output = original_forward(*args, **kwargs)
        original_setup_context(ctx, args, output)
        return output

    return new_forward
713

714

715
class AutogradFunctionApply(HigherOrderOperator):
716
    def __init__(self) -> None:
717
        super().__init__("autograd_function_apply")
718

719
    def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):
720
        saved_values = None
721
        args_tensor_mask = fwd_kwargs["args_tensor_mask"]
722
        non_differentiable_idx = fwd_kwargs["non_differentiable_idx"]
723
        length_of_tensor_args = sum(args_tensor_mask)
724
        # Filter out the original tensor args from fwd_args,
725
        # lifted freevars should not be args of ApplyTemplate.apply
726
        # since we don't need to calculate the gradients of them.
727
        new_fwd_args = fwd_args[:length_of_tensor_args]
728

729
        class ApplyTemplate(torch.autograd.Function):
730
            @staticmethod
731
            def forward(ctx, *args):
732
                nonlocal saved_values
733
                output, saved_values = fwd(None, *fwd_args)
734

735
                # If users call ctx.mark_non_differentiable() in the original fwd function.
736
                if len(non_differentiable_idx) > 0:
737
                    non_differentiable_output = []
738
                    for i, x in enumerate(output):
739
                        if i in non_differentiable_idx:
740
                            non_differentiable_output.append(x)
741
                    ctx.mark_non_differentiable(*non_differentiable_output)
742

743
                return output
744

745
            @staticmethod
746
            def backward(ctx, *grad):
747
                return bwd(None, *grad, *saved_values)
748

749
        return ApplyTemplate.apply(*new_fwd_args)
750

751

752
# Module-level instance of AutogradFunctionApply.
autograd_function_apply = AutogradFunctionApply()
753

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.