pytorch / autograd_function.py · 659 lines · 25.3 KB
import torch
from torch._ops import HigherOrderOperator
from torch._C._functorch import TransformType
from torch._functorch.utils import enable_single_level_autograd_function
import torch.utils._pytree as pytree
from torch._C._functorch import (
    _wrap_for_grad,
    _unwrap_for_grad,
    current_level,
)
from torch._functorch.vmap import (
    wrap_batched,
    unwrap_batched,
    restore_vmap,
    _add_batch_dim,
)
from torch._functorch.apis import vmap
from torch._functorch.vmap import _broadcast_to_and_flatten
from torch.autograd.forward_ad import _set_fwd_grad_enabled
from typing import Any, NamedTuple, Tuple

# autograd.Function technically runs before the regular PyTorch dispatcher.
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
# work with it. One day we might decide to change this, but until then,
# we need to give the illusion that autograd.Function runs before those things.
#
# We do this by creating a custom HigherOrderOperator that only functorch
# dispatches specially.
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
    def __init__(self):
        super().__init__('custom_function_call')

    def __call__(self, autograd_function, *args, **kwargs):
        # When custom_function_call is done dispatching through functorch,
        # it should just invoke the autograd.Function. This is consistent
        # with the autograd.Function behavior of being invoked before the
        # PyTorch dispatcher.
        #
        # This will lead us into trouble later down the line, but this is
        # pre-existing. There is an invariant that a function traced by
        # make_fx should have the same behavior when provided the same
        # Tensor. However, make_fx sees autograd.Function as a composite
        # (because autograd.Function happens before the Python dispatch key)
        # and only traces the forward pass.
        if torch._C._are_functorch_transforms_active():
            return super().__call__(autograd_function, *args, **kwargs)
        return autograd_function.apply(*args, **kwargs)


# "custom_function_call"
# This is the mechanism for an autograd.Function that works with functorch transforms.
# It wraps an autograd.Function; interactions with functorch transforms are defined
# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch
# dispatcher.
custom_function_call = CustomFunctionHigherOrderOperator()
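
# Illustrative sketch of how a call reaches custom_function_call. ``MyCube``
# here is a hypothetical user-defined autograd.Function, not part of this file:
#
# class MyCube(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x ** 3
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         ctx.save_for_backward(inputs[0])
#
#     @staticmethod
#     def backward(ctx, gy):
#         x, = ctx.saved_tensors
#         return gy * 3 * x ** 2
#
# With no functorch transforms active, custom_function_call(MyCube, x) simply
# calls MyCube.apply(x). Under a functorch transform (grad, jvp, vmap), the
# interpreter stack is active, so the call is instead routed through the
# PyDispatcher to the py_impl rules registered below for each TransformType.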


# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
# (autograd.Function that only works with a single layer (level) of functorch) that:
# - unwraps the inputs
# - redispatches to custom_function_call
# - wraps the outputs
# and whose backward pass calls the original autograd.Function's backward.
#
# Why do we need to redispatch to custom_function_call?
# -----------------------------------------------------
# This is consistent with how ATen operators work with functorch's grad transform:
# they always redispatch to the original operator.
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
#
# grad1 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin
# - rewrap the outputs on the return
#
# To "set up the autograd graph", we generate a _SingleLevelFunction
# and apply it.
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    Generated = generate_single_level_function(interpreter, autograd_function)
    with enable_single_level_autograd_function():
        flat_out = Generated.apply(*operands)
    return flat_out


def generate_single_level_function(interpreter, autograd_function):
    level = interpreter.level()

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(
            torch.Tensor,
            lambda x: _unwrap_for_grad(x, level),
            operands)
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            unwrapped_output = custom_function_call(autograd_function, *unwrapped_operands)

        # See NOTE [mark_dirty object identity check]
        def wrap_fn(output):
            return _wrap_for_grad(output, level)

        return wrap_outputs_maintaining_identity(
            unwrapped_output,
            unwrapped_operands,
            operands,
            wrap_fn)

    def setup_context(ctx, inputs, output):
        return autograd_function.setup_context(ctx, inputs, output)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        result = autograd_function.backward(ctx, *grads)
        return result

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        result = autograd_function.jvp(ctx, *tangents)
        return result

    # This is the sequence of magic words to dynamically generate a Subclass with
    # a given name. A Tensor's .grad_fn field has a class name that is the original
    # autograd.Function's name + Backward, so we do this to generate some
    # meaningful name.
    name = f'{autograd_function.__name__}Generated'
    Generated = type(
        name,
        (torch.autograd.function._SingleLevelFunction,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
            'jvp': staticmethod(jvp),
            'setup_context': staticmethod(setup_context),
        },
    )
    return Generated
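
# For a hypothetical ``MySin`` autograd.Function at functorch level L, the class
# produced above is roughly equivalent to the hand-written sketch below (names
# are illustrative; the real class is built dynamically via type(), which is why
# a .grad_fn of an output reads as "MySinGeneratedBackward"):
#
# class MySinGenerated(torch.autograd.function._SingleLevelFunction):
#     @staticmethod
#     def forward(*operands):
#         # unwrap GradTensorWrappers at level L, redispatch to
#         # custom_function_call, rewrap the result at level L
#         ...
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         return MySin.setup_context(ctx, inputs, output)
#
#     @staticmethod
#     def backward(ctx, *grads):
#         return MySin.backward(ctx, *grads)
#
#     @staticmethod
#     def jvp(ctx, *tangents):
#         return MySin.jvp(ctx, *tangents)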

# wrap_outputs_maintaining_identity handles outputs from the vmap,
# backward (vjp), and jvp staticmethod. The way it distinguishes
# between the vmap case and the {backward, jvp} case is if the out_dims
# are specified or not.
#
# NB: we cannot use out_dims=None as the deciding factor. This is because
# out_dims=None can still happen in the vmap staticmethod! What the
# user is saying in that case is that their output does not have a
# dimension that is being vmapped over, which is valid.
NO_OUT_DIMS = "not specified"

# NOTE [mark_dirty object identity check]
# autograd.Function's ctx.mark_dirty expects a returned input
# to have the same object identity as the input.
# Mode-only functorch will greatly simplify this logic.
def wrap_outputs_maintaining_identity(
        outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS):
    flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
    flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)

    unwrapped_input_to_orig_input = {
        id(unwrapped): orig
        for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
    }

    flat_outputs, spec = pytree.tree_flatten(outputs)
    result = []

    out_dims_specified = out_dims != NO_OUT_DIMS

    if out_dims_specified:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
        # _broadcast_to_and_flatten returns None if it is unable to broadcast.
        # TODO: update following link from master to stable once that's out
        if flat_out_dims is None:
            raise RuntimeError(
                f"The autograd.Function's vmap staticmethod returned an "
                f"incompatible (output, out_dims) tuple. "
                f"Expected out_dims={out_dims} "
                f"to be compatible with the structure of `output`. "
                f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
                f"but output has structure {spec}. "
                f"For more details, please see "
                f"https://pytorch.org/docs/master/notes/extending.func.html"
            )

    for i, output in enumerate(flat_outputs):
        if not isinstance(output, torch.Tensor):
            result.append(output)
            continue
        if id(output) in unwrapped_input_to_orig_input:
            result.append(unwrapped_input_to_orig_input[id(output)])
            continue
        if out_dims_specified:
            result.append(wrap_fn(output, flat_out_dims[i]))  # type: ignore[possibly-undefined, index]
        else:
            result.append(wrap_fn(output))

    return pytree.tree_unflatten(result, spec)
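
# Minimal sketch of why the id()-based lookup above matters. ``MyRelu_`` is a
# hypothetical in-place autograd.Function (backward omitted), not part of this
# file:
#
# class MyRelu_(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.relu_()   # returns the very same Tensor object
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         ctx.mark_dirty(inputs[0])
#     ...
#
# Inside a transform, forward sees the *unwrapped* input and returns it
# unchanged. If we naively called wrap_fn on that output we would produce a
# fresh wrapper object, and ctx.mark_dirty's object-identity check would fail;
# the lookup above returns the original wrapped input instead, preserving
# identity.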


# NOTE: [functorch vjp and autograd interaction]
# There's an edge case with the functorch vjp and autograd interaction
# that will eventually be fixed by mode-only functorch.
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
# so we (the framework) need to do it manually. Regular PyTorch operators
# automatically do so, so this is consistent.
#
# class MyExp(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.exp()
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         y = output
#         ctx.save_for_backward(y)
#
#     @staticmethod
#     def backward(ctx, gy):
#         y, = ctx.saved_tensors
#         return MyMul.apply(gy, y)
#
# x = torch.randn([], requires_grad=True)
# gy = torch.randn([], requires_grad=True)
# _, vjp_fn = vjp(MyExp.apply, x)
# result = vjp_fn(gy)
#
# MyMul is an autograd.Function that is not shown here.
# It saves a `y` for backward (since gy requires grad).
#
# in vjp_fn(gy), we get:
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
# Because the y that is saved for backward by MyExp is a GradTensorWrapper
# but is now dead since we are outside the vjp context.
#
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
# will automatically unwrap the GradTensorWrapper when applied.
# But since autograd.Function technically sits above the regular PyTorch
# dispatcher, it doesn't get this treatment. So we manually do
# the unwrapping to be consistent with regular PyTorch dispatcher operations.


class VmapInfo(NamedTuple):
    batch_size: int
    randomness: str


def has_overriden_vmap_rule(autograd_function):
    return autograd_function.vmap is not torch.autograd.Function.vmap
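
# Illustrative sketch (hypothetical classes, not part of this file):
#
# class NoRule(torch.autograd.Function):
#     ...                                  # inherits torch.autograd.Function.vmap
#
# class WithRule(torch.autograd.Function):
#     @staticmethod
#     def vmap(info, in_dims, x):
#         ...
#
# has_overriden_vmap_rule(NoRule)    # False: vmap is the inherited default
# has_overriden_vmap_rule(WithRule)  # True:  vmap was overridden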


def validate_vmap_returns_tuple_of_two_elements(result):
    base_error_msg = (
        "Expected the vmap staticmethod to have two returns, an output "
        "and out_dims with pytree structure compatible with the output. "
    )
    if not isinstance(result, tuple):
        raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
    if not len(result) == 2:
        raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")

@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands):
    if autograd_function.generate_vmap_rule:
        if has_overriden_vmap_rule(autograd_function):
            # TODO: Update link to stable once that's out
            # https://github.com/pytorch/pytorch/issues/92029
            raise RuntimeError(
                f"You tried to vmap over {autograd_function.__name__}, but "
                f"it has both generate_vmap_rule=True and an overriden vmap "
                f"staticmethod. Please set generate_vmap_rule=False or delete "
                f"the overriden vmap staticmethod to avoid ambiguity. "
                f"For more details, please see "
                f"https://pytorch.org/docs/master/notes/extending.func.html")
        return custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands)

    if not has_overriden_vmap_rule(autograd_function):
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            f"it does not have vmap support. Please override and implement the "
            f"vmap staticmethod or set generate_vmap_rule=True. "
            f"For more details, please see "
            f"https://pytorch.org/docs/master/notes/extending.func.html")

    current_level = interpreter.level()
    info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )
    unwrapped_operands, in_dims = unwrap_batched(operands, current_level)

    # If none of the tensors are batched at the current level, then we skip the
    # current level. This saves the user from needing to handle this case in
    # their vmap staticmethod (and is consistent with our C++ batching rule API)
    if pytree.tree_all(lambda dim: dim is None, in_dims):
        with interpreter.lower():
            return custom_function_call(autograd_function, *operands)

    with interpreter.lower():
        result = autograd_function.vmap(info, in_dims, *unwrapped_operands)
    validate_vmap_returns_tuple_of_two_elements(result)
    unwrapped_output, out_dims = result

    # See NOTE [mark_dirty object identity check]
    def wrap_fn(output, out_dim):
        return output if out_dim is None else _add_batch_dim(output, out_dim, current_level)

    return wrap_outputs_maintaining_identity(
        unwrapped_output,
        unwrapped_operands,
        operands,
        wrap_fn,
        out_dims=out_dims)
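
# Illustrative sketch of the contract this rule dispatches to. ``MyExpCustom``
# is hypothetical; the (output, out_dims) return shape is what
# validate_vmap_returns_tuple_of_two_elements checks above:
#
# class MyExpCustom(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.exp()
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         ctx.save_for_backward(output)
#
#     @staticmethod
#     def backward(ctx, gy):
#         y, = ctx.saved_tensors
#         return gy * y
#
#     @staticmethod
#     def vmap(info, in_dims, x):
#         # x arrives unwrapped; in_dims[0] is its batch dim (or None).
#         # info.batch_size / info.randomness come from the VmapInfo above.
#         return x.exp(), in_dims[0]
#
# vmap(MyExpCustom.apply)(torch.randn(3, 4)) then calls
# MyExpCustom.vmap(info, (0,), x) and rewraps the output at dim 0.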


def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
    unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
    vmapped_function, get_out_dims = vmapify_autograd_function(
        autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness())

    with interpreter.lower():
        output = custom_function_call(vmapped_function, *unwrapped_operands)

    out_dims = get_out_dims()
    return wrap_batched(output, out_dims, interpreter.level())


@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(interpreter, autograd_function, generate_vmap_rule, *operands):
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")


def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
    # The following values are saved from the forward() and setup_context()
    # and used in backward().
    # Why do we save the values out here instead of on the ctx object?
    # - out_dims: There's no way to retrieve this from forward()
    # - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting
    #   vmap(vmap( but not completely sure if it is a problem. If we
    #   assigned those fields to the ctx object, the worry is that they
    #   get overwritten.
    init_val = "not populated"
    out_dims = init_val
    input_shapes: Any = init_val
    saved_tensors_bdims: Any = init_val

    def forward(*operands):
        nonlocal out_dims
        outputs, out_dims = restore_vmap(
            autograd_function.forward, in_dims, batch_size, randomness)(*operands)
        return outputs

    def setup_context(ctx, inputs, outputs):
        input_shapes_ = None
        saved_tensors_bdims_ = None

        def inner(inputs, outputs):
            # wrapped_ctx.save_for_backward will:
            # - unwrap batched tensors into (tensor, bdim)
            # - save_for_backward(*unwrapped_tensors)
            # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
            wrapped_ctx = CtxCustomSave(ctx, current_level())
            autograd_function.setup_context(wrapped_ctx, inputs, outputs)

            # input_shapes are used for reductify later to reduce expanded gradients
            # to the correct shape.
            # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
            # for more details
            nonlocal input_shapes_
            input_shapes_ = tuple(inp.shape if isinstance(inp, torch.Tensor) else None
                                  for inp in inputs)
            nonlocal saved_tensors_bdims_
            saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims

        # See NOTE: [Why do we need to run setup_context under a vmap?]
        restore_vmap(
            inner,
            (in_dims, out_dims),
            batch_size,
            randomness,
        )(inputs, outputs)

        nonlocal input_shapes
        input_shapes = input_shapes_
        nonlocal saved_tensors_bdims
        saved_tensors_bdims = saved_tensors_bdims_

    def jvp(ctx, *tangents):
        assert out_dims != init_val
        assert saved_tensors_bdims != init_val

        def jvp_no_context(saved_tensors, tangents):
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.jvp(wrapped_ctx, *tangents)

        tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
        out_tangents, out_tangents_dims = restore_vmap(
            jvp_no_context, (saved_tensors_bdims, tangent_in_dims), batch_size, randomness)(
                ctx.saved_tensors, tangents)

        result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size)
        return result

    def backward(ctx, *grad_outputs):
        assert out_dims != init_val
        assert input_shapes != init_val
        assert saved_tensors_bdims != init_val

        def backward_no_context(inputs):
            saved_tensors, grad_outputs = inputs
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.backward(wrapped_ctx, *grad_outputs)

        grad_ins, grad_ins_dims = restore_vmap(
            backward_no_context, ((saved_tensors_bdims, out_dims),), batch_size, randomness)(
                (ctx.saved_tensors, grad_outputs))
        result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes)
        return result

    name = f'Vmapped{autograd_function.__name__}'
    Generated = type(
        name,
        (torch.autograd.Function,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
            'jvp': staticmethod(jvp),
            'setup_context': staticmethod(setup_context),
            'generate_vmap_rule': True
        }
    )

    def get_out_dims():
        assert out_dims != init_val
        return out_dims

    return Generated, get_out_dims


# tangents might be None, so we need to replace
# the corresponding in_dims with None.
def get_tangents_in_dims(input_dims, tangents):
    flat_in_dims, spec = pytree.tree_flatten(input_dims)
    flat_tangents = pytree.arg_tree_leaves(*tangents)
    result = [None if tangent is None else in_dim
              for in_dim, tangent in zip(flat_in_dims, flat_tangents)]
    return pytree.tree_unflatten(result, spec)
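
# Example: for input_dims = (0, 0) and tangents = (t, None), i.e. the second
# input has no tangent, the result is (0, None), so restore_vmap in jvp above
# treats the missing tangent as unbatched.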


# NOTE: [Why do we need to run setup_context under a vmap?]
# Consider the following autograd.Function
#
# class Sum(torch.autograd.Function):
#    @staticmethod
#    def forward(x):
#        return x.sum()
#    @staticmethod
#    def setup_context(ctx, inputs, outputs):
#        ctx.x_shape = inputs[0].shape
#    @staticmethod
#    def backward(ctx, gy):
#        return gy.expand(ctx.x_shape)
#
# x = torch.randn(B, 4)
# in_dims = 0
# vmap(Sum.apply, in_dims)(x)
#
# Let’s assume for a moment that we didn’t vmap setup_context in VmappedSum:
#
# class VmappedSum(torch.autograd.Function):
#    @staticmethod
#    def forward(x):
#        return vmap(Sum.forward, in_dims)(x)
#
#    @staticmethod
#    def setup_context(ctx, inputs, outputs):
#        Sum.setup_context(ctx, inputs, outputs)
#
#    @staticmethod
#    def backward(ctx, gy):
#        def backward_no_context(gy):
#            return gy.expand(ctx.x_shape)
#
#        dims = (0,)
#        gx = vmap(backward_no_context, dims)(gy)
#        return gx
#
# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
# and we’re doing:
#
# def backward_no_context(gy):
#     return gy.expand([B, 4])
#
# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
#
# This gives us the wrong result (gx has shape [B, B, 4], but it should
# have shape [4]). Performing vmap over setup_context means the shape
# saved has shape [4] and leads to a correct result shape for gx.

# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_reserved_attrs
class WrappedCtx:
    _pt_reserved_attrs: Tuple[str, ...] = ('_pt_reserved_attrs', '_pt_inner_ctx')

    def __init__(self, ctx):
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            for name in reserved_attrs:
                if not hasattr(ctx, name):
                    continue
                raise RuntimeError(
                    f'PyTorch reserves the {reserved_attrs} field on ctx. '
                    'Please name your fields on ctx something else to avoid name '
                    'collision.')
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        if name in type(self)._pt_reserved_attrs:
            self.__dict__[name] = value
            return
        return setattr(self._pt_inner_ctx, name, value)

# Wraps ctx to create a new ctx object that overrides saved_tensors.
class CtxWithSavedTensors(WrappedCtx):
    _pt_reserved_attrs = ('_pt_new_saved_tensors', *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, new_saved_tensors):
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        return self._pt_new_saved_tensors

class CtxCustomSave(WrappedCtx):
    _pt_reserved_attrs = ('_pt_saved_tensors_bdims', '_pt_current_level',
                          *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, current_level):
        super().__init__(ctx)
        self._pt_saved_tensors_bdims = ()
        self._pt_current_level = current_level

    def save_for_backward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_backward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims

    def save_for_forward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_forward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims
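
# Illustrative sketch of how these wrappers compose (``batched_y`` and
# ``per_sample_saved`` are hypothetical values):
#
# In setup_context (see vmapify_autograd_function above):
#     wrapped = CtxCustomSave(ctx, current_level())
#     wrapped.save_for_backward(batched_y)   # saves the unwrapped tensor on the
#                                            # real ctx, remembers bdims here
#     wrapped._pt_saved_tensors_bdims        # e.g. (0,)
#
# Later, in backward/jvp:
#     override = CtxWithSavedTensors(ctx, per_sample_saved)
#     override.saved_tensors                 # -> per_sample_saved, not ctx's
#     override.some_user_attr                # everything else forwards to ctx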


def reductify(grad_input, grad_input_bdim, input_bdim, batch_size,
              target_shape_without_bdim_to_reduce_to=None):
    if not isinstance(grad_input, tuple):
        grad_input = (grad_input,)
    if not isinstance(grad_input_bdim, tuple):
        grad_input_bdim = (grad_input_bdim,)
    if not isinstance(input_bdim, tuple):
        input_bdim = (input_bdim,)

    if target_shape_without_bdim_to_reduce_to is None:
        target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,)
    result = tuple(
        reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape)
        for gi, gi_bdim, i_bdim, maybe_ishape in
        zip(grad_input, grad_input_bdim, input_bdim, target_shape_without_bdim_to_reduce_to)
    )
    return result


def reductify_leaf(grad_input, grad_input_bdim, input_bdim, batch_size,
                   target_shape_without_bdim_to_reduce_to=None):
    if grad_input is None:
        return None

    if grad_input_bdim is None and input_bdim is None:
        return grad_input

    if grad_input_bdim is not None and input_bdim is None:
        return grad_input.sum(grad_input_bdim)

    # NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
    # For reverse-mode AD,
    # given a grad_input and input, it is valid for the user to return a
    # grad_input that has a broadcasted shape when compared to the input.
    # In this situation, autograd automatically reduces the grad_input to
    # the shape of the input.
    #
    # However, when input_bdim is not None, we have problems.
    #
    # [example 1]
    # grad_input: Tensor[3, 4], input: Tensor[B, 4]
    # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # [example 2]
    # grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
    # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # This means that we need to also reduce the grad_input to the shape of the
    # input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to` flag;
    # if it is not None, we do the reduction manually; otherwise, we do not.
    assert input_bdim is not None

    if grad_input_bdim is None:
        grad_input = grad_input.unsqueeze(input_bdim)
        new_shape = list(grad_input.shape)
        new_shape[input_bdim] = batch_size
        grad_input = grad_input.expand(new_shape)
        grad_input_bdim = input_bdim

    if target_shape_without_bdim_to_reduce_to is not None:
        return vmap(torch.Tensor.sum_to_size, in_dims=(grad_input_bdim, None), out_dims=input_bdim)(
            grad_input, target_shape_without_bdim_to_reduce_to)

    if input_bdim != grad_input_bdim:
        grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
    return grad_input
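
# Worked example (shapes only, following [example 1] above):
#   grad_input: Tensor[3, 4], grad_input_bdim=None
#   input:      Tensor[B, 4], input_bdim=0, batch_size=B
#   target_shape_without_bdim_to_reduce_to=(4,)
# grad_input is unsqueezed and expanded to [B, 3, 4] with bdim 0, then each
# per-sample [3, 4] slice is sum_to_size'd to [4], giving a [B, 4] gradient
# that lines up with the input.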


class AutogradFunctionApply(HigherOrderOperator):
    def __init__(self):
        super().__init__("autograd_function_apply")

    def __call__(self, fwd, bwd, *fwd_args):
        saved_values = None

        class ApplyTemplate(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                nonlocal saved_values
                output, saved_values = fwd(None, *args)
                return output

            @staticmethod
            def backward(ctx, *grad):
                return bwd(None, *grad, *saved_values)

        return ApplyTemplate.apply(*fwd_args)


autograd_function_apply = AutogradFunctionApply()
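
# Illustrative sketch of the calling convention. ``fwd``/``bwd`` here are
# hypothetical callables written to match ApplyTemplate above: fwd receives a
# None placeholder ctx and returns (output, saved_values); bwd receives the
# placeholder ctx, then *grads, then *saved_values:
#
# def fwd(ctx, x):
#     return x.sin(), [x]
#
# def bwd(ctx, grad_out, x):
#     return grad_out * x.cos()
#
# x = torch.randn(3, requires_grad=True)
# y = autograd_function_apply(fwd, bwd, x)   # forward: x.sin(), saves [x]
# y.sum().backward()                         # backward: grad_out * x.cos()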