pytorch
752 строки · 27.3 Кб
1# mypy: allow-untyped-defs
2from typing import Any, NamedTuple, Tuple
3
4import torch
5import torch.utils._pytree as pytree
6from torch._C._functorch import (
7_unwrap_for_grad,
8_wrap_for_grad,
9current_level,
10TransformType,
11)
12from torch._functorch.apis import vmap
13from torch._functorch.utils import enable_single_level_autograd_function
14from torch._functorch.vmap import (
15_add_batch_dim,
16_broadcast_to_and_flatten,
17restore_vmap,
18unwrap_batched,
19wrap_batched,
20)
21from torch._ops import HigherOrderOperator
22from torch.autograd.forward_ad import _set_fwd_grad_enabled
23
24
25# autograd.Function technically runs before the regular PyTorch dispatcher.
26# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
27# work with it. One day we might decide to change this, but until then,
28# we need to give the illusion that autograd.Function runs before those things.
29#
# We do this by creating a custom HigherOrderOperator that only functorch
# dispatches specially.
32class CustomFunctionHigherOrderOperator(HigherOrderOperator):
33def __init__(self) -> None:
34super().__init__("custom_function_call")
35
36def __call__(self, autograd_function, *args, **kwargs):
37# When custom_function_call is done dispatching through functorch,
38# it should just invoke the autograd.Function. This is consistent
39# with the autograd.Function behavior of being invoked before the
40# PyTorch dispatcher.
41#
42# This will lead us into trouble later down the line, but this is
43# pre-existing. There is an invariant that a function traced by
44# make_fx should have the same behavior when provided the same
45# Tensor. However, make_fx sees autograd.Function as a composite
46# (because autograd.Function happens before the Python dispatch key)
47# and only traces the forward pass.
48if torch._C._are_functorch_transforms_active():
49return super().__call__(autograd_function, *args, **kwargs)
50return autograd_function.apply(*args, **kwargs)
51
52
53# "custom_function_call"
54# This is the mechanism for an autograd.Function that works with functorch transforms.
55# It wraps an autograd.Function; interactions with functorch transforms are defined
56# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch
57# dispatcher.
# Module-level instance of CustomFunctionHigherOrderOperator; the per-transform
# rules in this file are registered on it through `custom_function_call.py_impl`.
custom_function_call = CustomFunctionHigherOrderOperator()
59
60
61# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
62# (autograd.Function that only works with a single layer (level) of functorch) that:
63# - unwraps the inputs
64# - redispatches to custom_function_call
65# - wraps the outputs
66# and whose backward pass calls the original autograd.Function's backward.
67#
68# Why do we need to redispatch to custom_function_call?
69# -----------------------------------------------------
70# This is consistent with how ATen operators work with functorch's grad transform:
71# they always redispatch to the original operator.
72# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
73#
74# grad1 will:
75# - set up the autograd graph
76# - unwrap the inputs
77# - redispatch to at::sin (*)
78# - rewrap the outputs on the return
79#
80# On the redispatch in (*), grad0 will:
81# - set up the autograd graph
82# - unwrap the inputs
83# - redispatch to at::sin
84# - rewrap the outputs on the return
85#
86# To "set up the autograd graph", we generate a _SingleLevelFunction
87# and apply it.
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    """Grad/Jvp rule for ``custom_function_call``.

    Builds a single-level autograd.Function for ``autograd_function`` via
    ``generate_single_level_function`` and applies it to ``operands`` while
    ``enable_single_level_autograd_function()`` is active.

    Returns:
        Whatever the generated function's ``apply`` returns.
    """
    single_level_cls = generate_single_level_function(interpreter, autograd_function)
    with enable_single_level_autograd_function():
        return single_level_cls.apply(*operands)
95
96
def generate_single_level_function(interpreter, autograd_function):
    """Create a ``_SingleLevelFunction`` subclass wrapping ``autograd_function``.

    The generated class is named ``<autograd_function.__name__>Generated``.
    Its ``forward`` unwraps tensor operands at ``interpreter.level()``,
    redispatches to ``custom_function_call`` under ``interpreter.lower()``,
    and rewraps the outputs; ``setup_context``, ``backward`` and ``jvp``
    delegate to the same-named staticmethods of ``autograd_function``.

    Args:
        interpreter: object providing ``level()`` and ``lower()``.
        autograd_function: the autograd.Function being wrapped.

    Returns:
        The dynamically created class.
    """
    level = interpreter.level()

    def _unwrap(tensor):
        return _unwrap_for_grad(tensor, level)

    # See NOTE [mark_dirty object identity check]
    def _wrap(tensor):
        return _wrap_for_grad(tensor, level)

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(torch.Tensor, _unwrap, operands)
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            unwrapped_output = custom_function_call(
                autograd_function, *unwrapped_operands
            )

        return wrap_outputs_maintaining_identity(
            unwrapped_output, unwrapped_operands, operands, _wrap
        )

    def setup_context(ctx, inputs, output):
        return autograd_function.setup_context(ctx, inputs, output)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        return autograd_function.backward(ctx, *grads)

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        return autograd_function.jvp(ctx, *tangents)

    # Dynamically generate a subclass with a given name. A Tensor's .grad_fn
    # field has a class name that is the original autograd.Function's
    # name + Backward, so we do this to generate some meaningful name.
    class_namespace = {
        "forward": staticmethod(forward),
        "backward": staticmethod(backward),
        "jvp": staticmethod(jvp),
        "setup_context": staticmethod(setup_context),
    }
    return type(
        f"{autograd_function.__name__}Generated",
        (torch.autograd.function._SingleLevelFunction,),
        class_namespace,
    )
149
150
151# wrap_outputs_maintaining_identity handles outputs from the vmap,
152# backward (vjp), and jvp staticmethod. The way it distinguishes
153# between the vmap case and the {backward, jvp} case is if the out_dims
154# are specified or not.
155#
# NB: we cannot use out_dims=None as the deciding factor. This is because
157# out_dims=None can still happen in the vmap staticmethod! What the
158# user is saying in that case is that their output does not have a
159# dimension that is being vmapped over, which is valid.
160NO_OUT_DIMS = "not specified"
161
162
163# NOTE [mark_dirty object identity check]
164# autograd.Function's ctx.mark_dirty expect a returned input
165# to have the same object identity as the input.
166# Mode-only functorch will greatly simplify this logic.
167def wrap_outputs_maintaining_identity(
168outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS
169):
170flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
171flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)
172
173unwrapped_input_to_orig_input = {
174id(unwrapped): orig
175for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
176}
177
178flat_outputs, spec = pytree.tree_flatten(outputs)
179result = []
180
181out_dims_specified = out_dims != NO_OUT_DIMS
182
183if out_dims_specified:
184flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
185# _broadcast_to_and_flatten returns None if it is unable to broadcast.
186# TODO: update following link from master to stable once that's out
187if flat_out_dims is None:
188raise RuntimeError(
189f"The autograd.Function's vmap staticmethod returned an "
190f"incompatible (output, out_dims) tuple. "
191f"Expected out_dims={out_dims} "
192f"to be compatible with the structure of `output`. "
193f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
194f"but output has structure {spec}. "
195f"For more details, please see "
196f"https://pytorch.org/docs/main/notes/extending.func.html"
197)
198
199for i, output in enumerate(flat_outputs):
200if not isinstance(output, torch.Tensor):
201result.append(output)
202continue
203if id(output) in unwrapped_input_to_orig_input:
204result.append(unwrapped_input_to_orig_input[id(output)])
205continue
206if out_dims_specified:
207result.append(wrap_fn(output, flat_out_dims[i])) # type: ignore[possibly-undefined, index]
208else:
209result.append(wrap_fn(output))
210
211return pytree.tree_unflatten(result, spec)
212
213
214# NOTE: [functorch vjp and autograd interaction]
215# There's an edge case with the functorch vjp and autograd interaction
216# that will eventually be fixed by mode-only functorch.
217# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
218# so we (the framework) need to do it manually. Regular PyTorch operators
# automatically do so, so this is consistent.
220#
221# class MyExp(torch.autograd.Function):
222# @staticmethod
223# def forward(x):
224# return x.exp()
225#
226# @staticmethod
227# def setup_context(ctx, inputs, output):
228# y = output
229# ctx.save_for_backward(y)
230#
231# @staticmethod
# def backward(ctx, gy):
# y, = ctx.saved_tensors
# return MyMul.apply(gy, y)
#
# x = torch.randn([], requires_grad=True)
# gy = torch.randn([], requires_grad=True)
# _, vjp_fn = vjp(MyExp.apply, x)
239# result = vjp_fn(gy)
240#
241# MyMul is an autograd.Function that is not shown here.
242# It saves a `y` for backward (since gy requires grad).
243#
244# in vjp_fn(gy), we get:
245# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
246# Because the y that is saved for backward by MyExp is a GradTensorWrapper
247# but is now dead since we are outside the vjp context.
248#
249# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
250# will automatically unwrap the GradTensorWrapper when applied.
251# But since autograd.Function technically sits above the regular PyTorch
252# dispatcher, it doesn't get this treatment. So we manually do
253# the unwrapping to be consistent with regular PyTorch dispatcher operations.
254
255
class VmapInfo(NamedTuple):
    """Information passed as the first argument to a vmap staticmethod."""

    # Filled from interpreter.batch_size() in custom_function_call_vmap_helper.
    batch_size: int
    # Filled from interpreter.randomness() in custom_function_call_vmap_helper.
    randomness: str
259
260
def has_overriden_vmap_rule(autograd_function):
    """Return True if ``autograd_function.vmap`` is not the base ``torch.autograd.Function.vmap``."""
    default_vmap = torch.autograd.Function.vmap
    return autograd_function.vmap is not default_vmap
263
264
def validate_vmap_returns_tuple_of_two_elements(result):
    """Check that ``result`` is a tuple of exactly two elements.

    Raises:
        RuntimeError: if ``result`` is not a tuple, or is a tuple whose
            length is not 2.
    """
    prefix = (
        "Expected the vmap staticmethod to have two returns, an output "
        "and out_dims with pytree structure compatible with the output. "
    )
    if isinstance(result, tuple):
        if len(result) == 2:
            return
        raise RuntimeError(prefix + f"Got {len(result)} returns instead")
    raise RuntimeError(prefix + f"Got a {type(result)} instead")
274
275
@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs):
    """Vmap rule for ``custom_function_call``.

    Dispatches to the generated vmap rule when
    ``autograd_function.generate_vmap_rule`` is truthy, otherwise to the
    user's ``vmap`` staticmethod.

    Raises:
        NotImplementedError: if any leaf of ``kwargs`` is a Tensor.
        RuntimeError: if ``generate_vmap_rule`` is set and a ``vmap``
            staticmethod is also overridden, or if neither is provided.
    """
    kwarg_leaves, _ = pytree.tree_flatten(kwargs)
    if any(isinstance(leaf, torch.Tensor) for leaf in kwarg_leaves):
        raise NotImplementedError(
            "Run vmap on autograd.Function with kwarg-only Tensor args. "
            "Please do not pass kwarg-only Tensors to autograd.Function. "
            f"Got: {kwargs}"
        )

    wants_generated_rule = autograd_function.generate_vmap_rule
    has_custom_rule = has_overriden_vmap_rule(autograd_function)

    if wants_generated_rule and has_custom_rule:
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            "it has both generate_vmap_rule=True and an overriden vmap "
            "staticmethod. Please set generate_vmap_rule=False or delete "
            "the overriden vmap staticmethod to avoid ambiguity. "
            "For more details, please see "
            "https://pytorch.org/docs/main/notes/extending.func.html"
        )

    if wants_generated_rule:
        return custom_function_call_vmap_generate_rule(
            interpreter, autograd_function, *operands
        )

    if not has_custom_rule:
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            "it does not have vmap support. Please override and implement the "
            "vmap staticmethod or set generate_vmap_rule=True. "
            "For more details, please see "
            "https://pytorch.org/docs/main/notes/extending.func.html"
        )

    return custom_function_call_vmap_helper(
        interpreter, autograd_function.vmap, autograd_function, *operands, **kwargs
    )
318
319
def custom_function_call_vmap_helper(
    interpreter, vmap_function, op, *operands, **kwargs
):
    """Run a user-provided vmap rule for ``op`` at the interpreter's level.

    Args:
        interpreter: provides ``level()``, ``batch_size()``, ``randomness()``
            and the ``lower()`` context manager.
        vmap_function: called as
            ``vmap_function(info, in_dims, *unwrapped_operands, **kwargs)`` and
            expected to return an ``(output, out_dims)`` tuple.
        op: the operation being vmapped; only invoked directly when no
            operand is batched at this level.
        *operands, **kwargs: arguments of the call.

    Returns:
        The outputs, re-batched at the interpreter's level according to the
        returned ``out_dims``.
    """
    level = interpreter.level()
    info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )
    unwrapped_operands, in_dims = unwrap_batched(operands, level)

    # If none of the tensors are batched at the current level, then we skip the
    # current level. This saves the user from needing to handle this case in
    # their vmap staticmethod (and is consistent with our C++ batching rule API)
    nothing_batched = pytree.tree_all(lambda dim: dim is None, in_dims)
    if nothing_batched:
        with interpreter.lower():
            if not isinstance(op, torch.autograd.function.FunctionMeta):
                return op(*operands, **kwargs)
            return custom_function_call(op, *operands)

    with interpreter.lower():
        result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs)
    validate_vmap_returns_tuple_of_two_elements(result)
    unwrapped_output, out_dims = result

    # See NOTE [mark_dirty object identity check]
    def wrap_fn(output, out_dim):
        if out_dim is None:
            return output
        return _add_batch_dim(output, out_dim, level)

    return wrap_outputs_maintaining_identity(
        unwrapped_output, unwrapped_operands, operands, wrap_fn, out_dims=out_dims
    )
355
356
def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
    """Vmap ``autograd_function`` using an automatically generated rule.

    Unwraps ``operands`` at the interpreter's level, builds a vmapped
    autograd.Function with ``vmapify_autograd_function``, runs it through
    ``custom_function_call`` under ``interpreter.lower()``, and re-batches
    the output using the out_dims recorded during that call.
    """
    plain_operands, in_dims = unwrap_batched(operands, interpreter.level())
    vmapped_cls, fetch_out_dims = vmapify_autograd_function(
        autograd_function,
        in_dims,
        interpreter.batch_size(),
        interpreter.randomness(),
    )

    with interpreter.lower():
        plain_output = custom_function_call(vmapped_cls, *plain_operands)

    return wrap_batched(plain_output, fetch_out_dims(), interpreter.level())
368
369
@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(
    interpreter, autograd_function, generate_vmap_rule, *operands
):
    """Functionalize rule for ``custom_function_call``; not yet implemented.

    Raises:
        RuntimeError: always.
    """
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")
375
376
def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
    """Build a vmapped counterpart of ``autograd_function``.

    Args:
        autograd_function: the autograd.Function whose staticmethods get wrapped.
        in_dims: batch dims of the operands, passed to ``restore_vmap``.
        batch_size: passed to ``restore_vmap`` and ``reductify``.
        randomness: passed to ``restore_vmap``.

    Returns:
        A ``(Generated, get_out_dims)`` pair. ``Generated`` is a new
        ``torch.autograd.Function`` subclass named ``Vmapped<name>`` whose
        forward/setup_context/jvp/backward run the matching staticmethod of
        ``autograd_function`` through ``restore_vmap``. ``get_out_dims``
        returns the out_dims recorded by ``forward`` and asserts if
        ``forward`` has not populated them yet.
    """
    # The following values are saved from the forward() and setup_context()
    # and used in backward().
    # Why do we save the values out here instead of on the ctx object?
    # - out_dims: There's no way to retrieve this from forward()
    # - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting
    #   vmap(vmap( but not completely sure if it is a problem. If we
    #   assigned those fields to the ctx object, the worry is that they
    #   get overwritten.
    init_val = "not populated"
    out_dims = init_val
    input_shapes: Any = init_val
    saved_tensors_bdims: Any = init_val

    def forward(*operands):
        # Records out_dims in the enclosing scope as a side effect.
        nonlocal out_dims
        outputs, out_dims = restore_vmap(
            autograd_function.forward, in_dims, batch_size, randomness
        )(*operands)
        return outputs

    def setup_context(ctx, inputs, outputs):
        # Staging slots written by `inner`, then copied to the enclosing scope.
        input_shapes_ = None
        saved_tensors_bdims_ = None

        def inner(inputs, outputs):
            # wrapped_ctx.save_for_backward will:
            # - unwrap batchedtensors into (tensor, bdim)
            # - save_for_backward(*unwrapped_tensors)
            # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
            wrapped_ctx = CtxCustomSave(ctx, current_level())
            autograd_function.setup_context(wrapped_ctx, inputs, outputs)

            # input_shapes are used for reductify later to reduce expanded gradients
            # to the correct shape.
            # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
            # for more details
            nonlocal input_shapes_
            input_shapes_ = tuple(
                inp.shape if isinstance(inp, torch.Tensor) else None for inp in inputs
            )
            nonlocal saved_tensors_bdims_
            saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims

        # See NOTE: [Why do we need to run setup_context under a vmap?]
        restore_vmap(
            inner,
            (in_dims, out_dims),
            batch_size,
            randomness,
        )(inputs, outputs)

        nonlocal input_shapes
        input_shapes = input_shapes_
        nonlocal saved_tensors_bdims
        saved_tensors_bdims = saved_tensors_bdims_

    def jvp(ctx, *tangents):
        # forward() and setup_context() must have populated these first.
        assert out_dims != init_val
        assert saved_tensors_bdims != init_val

        def jvp_no_context(saved_tensors, tangents):
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.jvp(wrapped_ctx, *tangents)

        tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
        out_tangents, out_tangents_dims = restore_vmap(
            jvp_no_context,
            (saved_tensors_bdims, tangent_in_dims),
            batch_size,
            randomness,
        )(ctx.saved_tensors, tangents)

        result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size)
        return result

    def backward(ctx, *grad_outputs):
        # forward() and setup_context() must have populated these first.
        assert out_dims != init_val
        assert input_shapes != init_val
        assert saved_tensors_bdims != init_val

        def backward_no_context(inputs):
            saved_tensors, grad_outputs = inputs
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.backward(wrapped_ctx, *grad_outputs)

        grad_ins, grad_ins_dims = restore_vmap(
            backward_no_context,
            ((saved_tensors_bdims, out_dims),),
            batch_size,
            randomness,
        )((ctx.saved_tensors, grad_outputs))
        result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes)
        return result

    name = f"Vmapped{autograd_function.__name__}"
    Generated = type(
        name,
        (torch.autograd.Function,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
            "jvp": staticmethod(jvp),
            "setup_context": staticmethod(setup_context),
            "generate_vmap_rule": True,
        },
    )

    def get_out_dims():
        assert out_dims != init_val
        return out_dims

    return Generated, get_out_dims
490
491
492# tangents might be None, so we need to replace
493# the corresponding in_dims with None.
494def get_tangents_in_dims(input_dims, tangents):
495flat_in_dims, spec = pytree.tree_flatten(input_dims)
496flat_tangents = pytree.arg_tree_leaves(*tangents)
497result = [
498None if tangent is None else in_dim
499for in_dim, tangent in zip(flat_in_dims, flat_tangents)
500]
501return pytree.tree_unflatten(result, spec)
502
503
504# NOTE: [Why do we need to run setup_context under a vmap?]
505# Consider the following autograd.Function
506#
507# class Sum(torch.autograd.Function):
508# @staticmethod
509# def forward(x):
510# return x.sum()
511# @staticmethod
512# def setup_context(ctx, inputs, outputs):
# ctx.x_shape = inputs[0].shape
514# @staticmethod
515# def backward(ctx, gy):
516# return gy.expand(ctx.x_shape)
517#
518# x = torch.randn(B, 4)
519# in_dims = 0
520# vmap(Sum.apply, in_dims)(x)
521#
522# Let's assume for a moment that we didn't vmap setup_context in VmappedSum:
523#
524# class VmappedSum(torch.autograd.Function):
525# @staticmethod
526# def forward(x):
527# return vmap(Sum.forward, in_dims)(x)
528#
529# @staticmethod
530# def setup_context(ctx, inputs, outputs):
531# Sum.setup_context(ctx, inputs, outputs)
532#
533# @staticmethod
534# def backward(ctx, gy):
535# def backward_no_context(gy):
536# return gy.expand(ctx.x_shape)
537#
538# dims = (0,)
539# gx = vmap(backward_no_context, dims)(gy)
540# return gx
541#
542# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
543# and we're doing:
544#
545# def backward_no_context(gy):
546# return gy.expand([B, 4])
547#
548# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
549#
550# This gives us the wrong result (gx has shape [B, B, 4], but it should
551# have shape [4]). Performing vmap over setup_context means the shape
552# saved has shape [4] and leads to a correct result shape for gx.
553
554
# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_reserved_attrs
class WrappedCtx:
    """Proxy around a ctx object.

    Attribute reads that are not found on the wrapper and attribute writes
    whose name is not in ``_pt_reserved_attrs`` are forwarded to the wrapped
    ctx. Names listed in ``_pt_reserved_attrs`` are stored on the wrapper
    itself.
    """

    _pt_reserved_attrs: Tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx")

    def __init__(self, ctx):
        """Wrap ``ctx``.

        Raises:
            RuntimeError: if ``ctx`` is not itself a ``WrappedCtx`` and
                already has an attribute named like one of the reserved
                attributes.
        """
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            for name in reserved_attrs:
                if not hasattr(ctx, name):
                    continue
                raise RuntimeError(
                    f"PyTorch reserves the {reserved_attrs} field on ctx. "
                    "Please name your fields on ctx something else to avoid name "
                    "collision."
                )
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup failed. If the missing
        # name is `_pt_inner_ctx`, the wrapper has not been initialized
        # (e.g. the instance was created via __new__ by copy/pickle).
        # Forwarding would read `self._pt_inner_ctx` again and recurse until
        # RecursionError, so report a plain AttributeError instead.
        if name == "_pt_inner_ctx":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        # Reserved names live on the wrapper; everything else goes to the
        # wrapped ctx.
        if name in type(self)._pt_reserved_attrs:
            self.__dict__[name] = value
            return
        return setattr(self._pt_inner_ctx, name, value)
581
582
583# Wraps ctx to create a new ctx object that overrides saved_tensors.
class CtxWithSavedTensors(WrappedCtx):
    """WrappedCtx whose ``saved_tensors`` returns a caller-supplied value.

    All other attribute access follows ``WrappedCtx``.
    """

    _pt_reserved_attrs = ("_pt_new_saved_tensors",) + WrappedCtx._pt_reserved_attrs

    def __init__(self, ctx, new_saved_tensors):
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        """The ``new_saved_tensors`` value given at construction."""
        return self._pt_new_saved_tensors
594
595
class CtxCustomSave(WrappedCtx):
    """WrappedCtx that unwraps batched tensors before saving them.

    ``save_for_backward`` / ``save_for_forward`` call ``unwrap_batched`` on
    their arguments at the level given at construction, pass the unwrapped
    values to the wrapped ctx, and record the batch dims on
    ``_pt_saved_tensors_bdims``.

    NOTE(review): both methods write the same ``_pt_saved_tensors_bdims``
    attribute, so the most recent call overwrites the earlier one. Confirm
    this is intended.
    """

    _pt_reserved_attrs = (
        "_pt_saved_tensors_bdims",
        "_pt_current_level",
    ) + WrappedCtx._pt_reserved_attrs

    def __init__(self, ctx, current_level):
        super().__init__(ctx)
        self._pt_saved_tensors_bdims = ()
        self._pt_current_level = current_level

    def save_for_backward(self, *tensors):
        plain_tensors, batch_dims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_backward(*plain_tensors)
        self._pt_saved_tensors_bdims = batch_dims

    def save_for_forward(self, *tensors):
        plain_tensors, batch_dims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_forward(*plain_tensors)
        self._pt_saved_tensors_bdims = batch_dims
617
618
def reductify(
    grad_input,
    grad_input_bdim,
    input_bdim,
    batch_size,
    target_shape_without_bdim_to_reduce_to=None,
):
    """Apply ``reductify_leaf`` element-wise and return the results as a tuple.

    ``grad_input``, ``grad_input_bdim`` and ``input_bdim`` are each promoted
    to a 1-tuple when they are not already tuples. When
    ``target_shape_without_bdim_to_reduce_to`` is None, every leaf is
    processed with a ``None`` target shape.
    """

    def _as_tuple(value):
        return value if isinstance(value, tuple) else (value,)

    grad_input = _as_tuple(grad_input)
    grad_input_bdim = _as_tuple(grad_input_bdim)
    input_bdim = _as_tuple(input_bdim)

    target_shapes = target_shape_without_bdim_to_reduce_to
    if target_shapes is None:
        target_shapes = (None,) * len(grad_input)

    per_leaf = zip(grad_input, grad_input_bdim, input_bdim, target_shapes)
    return tuple(
        reductify_leaf(grad, grad_bdim, in_bdim, batch_size, shape)
        for grad, grad_bdim, in_bdim, shape in per_leaf
    )
645
646
647def reductify_leaf(
648grad_input,
649grad_input_bdim,
650input_bdim,
651batch_size,
652target_shape_without_bdim_to_reduce_to=None,
653):
654if grad_input is None:
655return None
656
657if grad_input_bdim is None and input_bdim is None:
658return grad_input
659
660if grad_input_bdim is not None and input_bdim is None:
661return grad_input.sum(grad_input_bdim)
662
663# NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
664# For reverse-mode AD,
665# given a grad_input and input, it is valid for the user to return a
666# grad_input that has a broadcasted shape when compared to the input.
667# In this situation, autograd automatically reduces the grad_input to
668# the shape of the input.
669#
670# However, when input_bdim is not None, we have problems.
671#
672# [example 1]
673# grad_input: Tensor[3, 4], input: Tensor[B, 4]
674# We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
675# from [B, 4].
676#
677# [example 2]
678# grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
679# We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
680# from [B, 4].
681#
682# This means that we need to also reduce the grad_input to the shape of the
683# input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to` flag;
684# if not-None then we do the reducing manually, otherwise, we do not do a reduction.
685assert input_bdim is not None
686
687if grad_input_bdim is None:
688grad_input = grad_input.unsqueeze(input_bdim)
689new_shape = list(grad_input.shape)
690new_shape[input_bdim] = batch_size
691grad_input = grad_input.expand(new_shape)
692grad_input_bdim = input_bdim
693
694if target_shape_without_bdim_to_reduce_to is not None:
695return vmap(
696torch.Tensor.sum_to_size,
697in_dims=(grad_input_bdim, None),
698out_dims=input_bdim,
699)(grad_input, target_shape_without_bdim_to_reduce_to)
700
701if input_bdim != grad_input_bdim:
702grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
703return grad_input
704
705
def autograd_function_forward_rewritten(original_forward, original_setup_context):
    """Combine a ctx-less ``forward`` and a ``setup_context`` into one ``forward(ctx, ...)``.

    The returned function calls ``original_forward(*args, **kwargs)``, then
    ``original_setup_context(ctx, args, output)``, and returns the output.
    """

    def new_forward(ctx, *args, **kwargs):
        result = original_forward(*args, **kwargs)
        original_setup_context(ctx, args, result)
        return result

    return new_forward
713
714
715class AutogradFunctionApply(HigherOrderOperator):
716def __init__(self) -> None:
717super().__init__("autograd_function_apply")
718
719def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):
720saved_values = None
721args_tensor_mask = fwd_kwargs["args_tensor_mask"]
722non_differentiable_idx = fwd_kwargs["non_differentiable_idx"]
723length_of_tensor_args = sum(args_tensor_mask)
724# Filter out the original tensor args from fwd_args,
725# lifted freevars should not be args of ApplyTemplate.apply
726# since we don't need to calculate the gradients of them.
727new_fwd_args = fwd_args[:length_of_tensor_args]
728
729class ApplyTemplate(torch.autograd.Function):
730@staticmethod
731def forward(ctx, *args):
732nonlocal saved_values
733output, saved_values = fwd(None, *fwd_args)
734
735# If users call ctx.mark_non_differentiable() in the original fwd function.
736if len(non_differentiable_idx) > 0:
737non_differentiable_output = []
738for i, x in enumerate(output):
739if i in non_differentiable_idx:
740non_differentiable_output.append(x)
741ctx.mark_non_differentiable(*non_differentiable_output)
742
743return output
744
745@staticmethod
746def backward(ctx, *grad):
747return bwd(None, *grad, *saved_values)
748
749return ApplyTemplate.apply(*new_fwd_args)
750
751
# Module-level instance of the AutogradFunctionApply higher-order operator.
autograd_function_apply = AutogradFunctionApply()
753