import torch
from torch._ops import HigherOrderOperator
from torch._C._functorch import TransformType
from torch._functorch.utils import enable_single_level_autograd_function
import torch.utils._pytree as pytree
from torch._C._functorch import (
    _wrap_for_grad,
    _unwrap_for_grad,
    current_level,
)
from torch._functorch.vmap import (
    wrap_batched,
    unwrap_batched,
    restore_vmap,
    _add_batch_dim,
)
from torch._functorch.apis import vmap
from torch._functorch.vmap import _broadcast_to_and_flatten
from torch.autograd.forward_ad import _set_fwd_grad_enabled
from typing import Any, NamedTuple, Tuple
# autograd.Function technically runs before the regular PyTorch dispatcher.
# This is how features like autocast and torch_dispatch (e.g. PythonTLSSnapshot)
# work with it. One day we might decide to change this, but until then,
# we need to give the illusion that autograd.Function runs before those things.
#
# We do this by creating a custom HigherOrderOperator that only functorch
# dispatches specially.
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
    def __init__(self):
        super().__init__('custom_function_call')

    def __call__(self, autograd_function, *args, **kwargs):
        # When custom_function_call is done dispatching through functorch,
        # it should just invoke the autograd.Function. This is consistent
        # with the autograd.Function behavior of being invoked before the
        # PyTorch dispatcher.
        #
        # This will lead us into trouble later down the line, but this is
        # pre-existing. There is an invariant that a function traced by
        # make_fx should have the same behavior when provided the same
        # Tensor. However, make_fx sees autograd.Function as a composite
        # (because autograd.Function happens before the Python dispatch key)
        # and only traces the forward pass.
        if torch._C._are_functorch_transforms_active():
            return super().__call__(autograd_function, *args, **kwargs)
        return autograd_function.apply(*args, **kwargs)

# "custom_function_call"
# This is the mechanism for an autograd.Function that works with functorch transforms.
# It wraps an autograd.Function; interactions with functorch transforms are defined
# via PyDispatcher and HigherOrderOperator rather than through the traditional PyTorch
# dispatcher.
custom_function_call = CustomFunctionHigherOrderOperator()
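
# For illustration (a hypothetical example; `MyCube` is not defined in this
# module): given a user-defined autograd.Function, custom_function_call
# behaves like a plain `.apply(...)` when no functorch transforms are active,
# and otherwise routes to one of the py_impl rules registered below
# (Grad/Jvp/Vmap/Functionalize).
#
#   class MyCube(torch.autograd.Function):
#       @staticmethod
#       def forward(x):
#           return x ** 3
#
#       @staticmethod
#       def setup_context(ctx, inputs, output):
#           ctx.save_for_backward(inputs[0])
#
#       @staticmethod
#       def backward(ctx, gy):
#           x, = ctx.saved_tensors
#           return 3 * x ** 2 * gy
#
#   x = torch.randn(3, requires_grad=True)
#   custom_function_call(MyCube, x)  # same as MyCube.apply(x) outside transforms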

# The grad rule for custom_function_call is to construct a new _SingleLevelFunction
# (autograd.Function that only works with a single layer (level) of functorch) that:
# - unwraps the inputs
# - redispatches to custom_function_call
# - wraps the outputs
# and whose backward pass calls the original autograd.Function's backward.
#
# Why do we need to redispatch to custom_function_call?
# -----------------------------------------------------
# This is consistent with how ATen operators work with functorch's grad transform:
# they always redispatch to the original operator.
# Consider torch.sin, and let's say we do grad0(grad1(torch.sin))(x)
#
# grad1 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin (*)
# - rewrap the outputs on the return
#
# On the redispatch in (*), grad0 will:
# - set up the autograd graph
# - unwrap the inputs
# - redispatch to at::sin
# - rewrap the outputs on the return
#
# To "set up the autograd graph", we generate a _SingleLevelFunction
# and apply it.
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    Generated = generate_single_level_function(interpreter, autograd_function)
    with enable_single_level_autograd_function():
        flat_out = Generated.apply(*operands)
    return flat_out
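
# For illustration (using the hypothetical `MyCube` from the sketch above):
# nesting grad transforms hits this rule once per functorch level, mirroring
# the at::sin redispatch described in the comment above. Each level unwraps
# its inputs, applies a freshly generated `MyCubeGenerated` single-level
# function, and rewraps on the way out.
#
#   x = torch.randn([])
#   torch.func.grad(torch.func.grad(MyCube.apply))(x)  # second derivative, 6 * x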

def generate_single_level_function(interpreter, autograd_function):
    level = interpreter.level()

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(
            torch.Tensor,
            lambda x: _unwrap_for_grad(x, level),
            operands)
        # Both enable_grad() and _set_fwd_grad_enabled() are necessary no matter
        # the transform. _SingleLevelFunction will turn off both fwd and bwd
        # gradient computation and we need to turn it back on here.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            unwrapped_output = custom_function_call(autograd_function, *unwrapped_operands)

        # See NOTE [mark_dirty object identity check]
        def wrap_fn(output):
            return _wrap_for_grad(output, level)

        return wrap_outputs_maintaining_identity(
            unwrapped_output,
            unwrapped_operands,
            operands,
            wrap_fn)

    def setup_context(ctx, inputs, output):
        return autograd_function.setup_context(ctx, inputs, output)

    # backward is only used if the transform is TransformType.Grad
    def backward(ctx, *grads):
        result = autograd_function.backward(ctx, *grads)
        return result

    # jvp is only used if the transform is TransformType.Jvp
    def jvp(ctx, *tangents):
        result = autograd_function.jvp(ctx, *tangents)
        return result

    # This is the sequence of magic words to dynamically generate a Subclass with
    # a given name. A Tensor's .grad_fn field has a class name that is the original
    # autograd.Function's name + Backward, so we do this to generate some
    # meaningful name.
    name = f'{autograd_function.__name__}Generated'
    Generated = type(
        name,
        (torch.autograd.function._SingleLevelFunction,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
            'jvp': staticmethod(jvp),
            'setup_context': staticmethod(setup_context),
        },
    )
    return Generated

# wrap_outputs_maintaining_identity handles outputs from the vmap,
# backward (vjp), and jvp staticmethod. The way it distinguishes
# between the vmap case and the {backward, jvp} case is if the out_dims
# are specified or not.
#
# NB: we cannot use out_dims=None as the deciding factor. This is because
# out_dims=None can still happen in the vmap staticmethod! What the
# user is saying in that case is that their output does not have a
# dimension that is being vmapped over, which is valid.
NO_OUT_DIMS = "not specified"
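
# For illustration (a hypothetical example; `SumOverBatch` is not defined in
# this module): a vmap staticmethod may legitimately return out_dims=None when
# its output carries no batched dimension, which is why None cannot serve as
# the "not specified" sentinel.
#
#   class SumOverBatch(torch.autograd.Function):
#       @staticmethod
#       def vmap(info, in_dims, x):
#           (x_bdim,) = in_dims
#           # Reducing over the batched dim yields an output with no batch
#           # dimension, so out_dims is None -- a perfectly valid return.
#           return x.sum(dim=x_bdim), None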

# NOTE [mark_dirty object identity check]
# autograd.Function's ctx.mark_dirty expects a returned input
# to have the same object identity as the input.
# Mode-only functorch will greatly simplify this logic.
def wrap_outputs_maintaining_identity(
        outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS):
    flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
    flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)

    unwrapped_input_to_orig_input = {
        id(unwrapped): orig
        for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
    }

    flat_outputs, spec = pytree.tree_flatten(outputs)
    result = []

    out_dims_specified = out_dims != NO_OUT_DIMS

    if out_dims_specified:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
        # _broadcast_to_and_flatten returns None if it is unable to broadcast.
        # TODO: update following link from master to stable once that's out
        if flat_out_dims is None:
            raise RuntimeError(
                f"The autograd.Function's vmap staticmethod returned an "
                f"incompatible (output, out_dims) tuple. "
                f"Expected out_dims={out_dims} "
                f"to be compatible with the structure of `output`. "
                f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
                f"but output has structure {spec}. "
                f"For more details, please see "
                f"https://pytorch.org/docs/master/notes/extending.func.html"
            )

    for i, output in enumerate(flat_outputs):
        if not isinstance(output, torch.Tensor):
            result.append(output)
            continue
        if id(output) in unwrapped_input_to_orig_input:
            result.append(unwrapped_input_to_orig_input[id(output)])
            continue
        if out_dims_specified:
            result.append(wrap_fn(output, flat_out_dims[i]))  # type: ignore[possibly-undefined, index]
        else:
            result.append(wrap_fn(output))

    return pytree.tree_unflatten(result, spec)
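
# For illustration (hypothetical names): if a staticmethod returns one of its
# inputs unchanged, the id-based map above hands back the *original* wrapper
# object, so ctx.mark_dirty's object-identity check still passes.
#
#   out = wrap_outputs_maintaining_identity(
#       unwrapped_x, (unwrapped_x,), (x,), wrap_fn)
#   assert out is x  # identity preserved instead of re-wrapping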

# NOTE: [functorch vjp and autograd interaction]
# There's an edge case with the functorch vjp and autograd interaction
# that will eventually be fixed by mode-only functorch.
# The TL;DR is that there's no way to unwrap a dead GradTensorWrapper,
# so we (the framework) need to do it manually. Regular PyTorch operators
# unwrap it automatically, so doing it manually here keeps autograd.Function
# consistent with them.
#
# class MyExp(torch.autograd.Function):
#     @staticmethod
#     def forward(x):
#         return x.exp()
#
#     @staticmethod
#     def setup_context(ctx, inputs, output):
#         y = output
#         ctx.save_for_backward(y)
#
#     @staticmethod
#     def backward(ctx, gy):
#         y, = ctx.saved_tensors
#         return MyMul.apply(gy, y)
#
# x = torch.randn([], requires_grad=True)
# gy = torch.randn([], requires_grad=True)
# _, vjp_fn = vjp(MyExp.apply, x)
# result = vjp_fn(gy)
#
# MyMul is an autograd.Function that is not shown here.
# It saves a `y` for backward (since gy requires grad).
#
# In vjp_fn(gy), we get:
# > MyMul.apply(gy, GradTensorWrapper(y, level=dead))
# because the y that was saved for backward by MyExp is a GradTensorWrapper,
# and it is now dead since we are outside the vjp context.
#
# PyTorch dispatcher operations, upon seeing a dead GradTensorWrapper,
# will automatically unwrap the GradTensorWrapper when applied.
# But since autograd.Function technically sits above the regular PyTorch
# dispatcher, it doesn't get this treatment. So we manually do
# the unwrapping to be consistent with regular PyTorch dispatcher operations.

class VmapInfo(NamedTuple):
    batch_size: int
    randomness: str


def has_overriden_vmap_rule(autograd_function):
    return autograd_function.vmap is not torch.autograd.Function.vmap


def validate_vmap_returns_tuple_of_two_elements(result):
    base_error_msg = (
        "Expected the vmap staticmethod to have two returns, an output "
        "and out_dims with pytree structure compatible with the output. "
    )
    if not isinstance(result, tuple):
        raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
    if not len(result) == 2:
        raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")
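
# For illustration (a hypothetical example; `MyPairwiseMul` is not defined in
# this module): a vmap staticmethod is expected to return exactly
# (output, out_dims), which is what the validation above checks.
#
#   class MyPairwiseMul(torch.autograd.Function):
#       @staticmethod
#       def vmap(info, in_dims, x, y):
#           x_bdim, y_bdim = in_dims
#           # Move any batch dim to the end so the elementwise multiply
#           # broadcasts per batch element, then report the output bdim.
#           x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
#           y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
#           result = x * y
#           return result.movedim(-1, 0), 0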

@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands):
    if autograd_function.generate_vmap_rule:
        if has_overriden_vmap_rule(autograd_function):
            # TODO: Update link to stable once that's out
            # https://github.com/pytorch/pytorch/issues/92029
            raise RuntimeError(
                f"You tried to vmap over {autograd_function.__name__}, but "
                f"it has both generate_vmap_rule=True and an overridden vmap "
                f"staticmethod. Please set generate_vmap_rule=False or delete "
                f"the overridden vmap staticmethod to avoid ambiguity. "
                f"For more details, please see "
                f"https://pytorch.org/docs/master/notes/extending.func.html")
        return custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands)

    if not has_overriden_vmap_rule(autograd_function):
        # TODO: Update link to stable once that's out
        # https://github.com/pytorch/pytorch/issues/92029
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            f"it does not have vmap support. Please override and implement the "
            f"vmap staticmethod or set generate_vmap_rule=True. "
            f"For more details, please see "
            f"https://pytorch.org/docs/master/notes/extending.func.html")

    current_level = interpreter.level()
    info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )
    unwrapped_operands, in_dims = unwrap_batched(operands, current_level)

    # If none of the tensors are batched at the current level, then we skip the
    # current level. This saves the user from needing to handle this case in
    # their vmap staticmethod (and is consistent with our C++ batching rule API)
    if pytree.tree_all(lambda dim: dim is None, in_dims):
        with interpreter.lower():
            return custom_function_call(autograd_function, *operands)

    with interpreter.lower():
        result = autograd_function.vmap(info, in_dims, *unwrapped_operands)
    validate_vmap_returns_tuple_of_two_elements(result)
    unwrapped_output, out_dims = result

    # See NOTE [mark_dirty object identity check]
    def wrap_fn(output, out_dim):
        return output if out_dim is None else _add_batch_dim(output, out_dim, current_level)

    return wrap_outputs_maintaining_identity(
        unwrapped_output,
        unwrapped_operands,
        operands,
        wrap_fn,
        out_dims=out_dims)

def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
    unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
    vmapped_function, get_out_dims = vmapify_autograd_function(
        autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness())

    with interpreter.lower():
        output = custom_function_call(vmapped_function, *unwrapped_operands)

    out_dims = get_out_dims()
    return wrap_batched(output, out_dims, interpreter.level())


@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(interpreter, autograd_function, generate_vmap_rule, *operands):
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")

def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
    # The following values are saved from the forward() and setup_context()
    # and used in backward().
    # Why do we save the values out here instead of on the ctx object?
    # - out_dims: There's no way to retrieve this from forward()
    # - input_shapes, saved_tensors_bdims: I'm a bit scared of nesting
    #   vmap(vmap( but not completely sure if it is a problem. If we
    #   assigned those fields to the ctx object, the worry is that they
    #   get overwritten.
    init_val = "not populated"
    out_dims = init_val
    input_shapes: Any = init_val
    saved_tensors_bdims: Any = init_val

    def forward(*operands):
        nonlocal out_dims
        outputs, out_dims = restore_vmap(
            autograd_function.forward, in_dims, batch_size, randomness)(*operands)
        return outputs

    def setup_context(ctx, inputs, outputs):
        input_shapes_ = None
        saved_tensors_bdims_ = None

        def inner(inputs, outputs):
            # wrapped_ctx.save_for_backward will:
            # - unwrap batchedtensors into (tensor, bdim)
            # - save_for_backward(*unwrapped_tensors)
            # - assign the bdims to wrapped_ctx._pt_saved_tensors_bdims
            wrapped_ctx = CtxCustomSave(ctx, current_level())
            autograd_function.setup_context(wrapped_ctx, inputs, outputs)

            # input_shapes are used for reductify later to reduce expanded gradients
            # to the correct shape.
            # See NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
            # for more details
            nonlocal input_shapes_
            input_shapes_ = tuple(inp.shape if isinstance(inp, torch.Tensor) else None
                                  for inp in inputs)
            nonlocal saved_tensors_bdims_
            saved_tensors_bdims_ = wrapped_ctx._pt_saved_tensors_bdims

        # See NOTE: [Why do we need to run setup_context under a vmap?]
        restore_vmap(
            inner,
            (in_dims, out_dims),
            batch_size,
            randomness,
        )(inputs, outputs)

        nonlocal input_shapes
        input_shapes = input_shapes_
        nonlocal saved_tensors_bdims
        saved_tensors_bdims = saved_tensors_bdims_

    def jvp(ctx, *tangents):
        assert out_dims != init_val
        assert saved_tensors_bdims != init_val

        def jvp_no_context(saved_tensors, tangents):
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.jvp(wrapped_ctx, *tangents)

        tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
        out_tangents, out_tangents_dims = restore_vmap(
            jvp_no_context, (saved_tensors_bdims, tangent_in_dims), batch_size, randomness)(
                ctx.saved_tensors, tangents)

        result = reductify(out_tangents, out_tangents_dims, out_dims, batch_size)
        return result

    def backward(ctx, *grad_outputs):
        assert out_dims != init_val
        assert input_shapes != init_val
        assert saved_tensors_bdims != init_val

        def backward_no_context(inputs):
            saved_tensors, grad_outputs = inputs
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.backward(wrapped_ctx, *grad_outputs)

        grad_ins, grad_ins_dims = restore_vmap(
            backward_no_context, ((saved_tensors_bdims, out_dims),), batch_size, randomness)(
                (ctx.saved_tensors, grad_outputs))
        result = reductify(grad_ins, grad_ins_dims, in_dims, batch_size, input_shapes)
        return result

    name = f'Vmapped{autograd_function.__name__}'
    Generated = type(
        name,
        (torch.autograd.Function,),
        {
            'forward': staticmethod(forward),
            'backward': staticmethod(backward),
            'jvp': staticmethod(jvp),
            'setup_context': staticmethod(setup_context),
            'generate_vmap_rule': True
        }
    )

    def get_out_dims():
        assert out_dims != init_val
        return out_dims

    return Generated, get_out_dims
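
# For illustration (a hypothetical sketch using the `MyCube` example above):
# vmapify_autograd_function is what powers generate_vmap_rule=True.
# Conceptually,
#
#   VmappedMyCube, get_out_dims = vmapify_autograd_function(
#       MyCube, in_dims=(0,), batch_size=B, randomness="error")
#
# produces a new autograd.Function whose forward/setup_context/backward rerun
# MyCube's staticmethods under restore_vmap, and get_out_dims() reports where
# the batch dimension ended up in the outputs.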

# tangents might be None, so we need to replace
# the corresponding in_dims with None.
def get_tangents_in_dims(input_dims, tangents):
    flat_in_dims, spec = pytree.tree_flatten(input_dims)
    flat_tangents = pytree.arg_tree_leaves(*tangents)
    result = [None if tangent is None else in_dim
              for in_dim, tangent in zip(flat_in_dims, flat_tangents)]
    return pytree.tree_unflatten(result, spec)

# NOTE: [Why do we need to run setup_context under a vmap?]
# Consider the following autograd.Function
#
# class Sum(torch.autograd.Function):
#    @staticmethod
#    def forward(x):
#        return x.sum()
#    @staticmethod
#    def setup_context(ctx, inputs, outputs):
#        ctx.x_shape = inputs[0].shape
#    @staticmethod
#    def backward(ctx, gy):
#        return gy.expand(ctx.x_shape)
#
# x = torch.randn(B, 4)
# in_dims = 0
# vmap(Sum.apply, in_dims)(x)
#
# Let's assume for a moment that we didn't vmap setup_context in VmappedSum:
#
# class VmappedSum(torch.autograd.Function):
#    @staticmethod
#    def forward(x):
#        return vmap(Sum.forward, in_dims)(x)
#
#    @staticmethod
#    def setup_context(ctx, inputs, outputs):
#        Sum.setup_context(ctx, inputs, outputs)
#
#    @staticmethod
#    def backward(ctx, gy):
#        def backward_no_context(gy):
#            return gy.expand(ctx.x_shape)
#
#        dims = (0,)
#        gx = vmap(backward_no_context, dims)(gy)
#        return gx
#
# We end up saving [B, 4] as x_shape. In the backward, gy has shape [B],
# and we're doing:
#
# def backward_no_context(gy):
#     return gy.expand([B, 4])
#
# gx = vmap(backward_no_context, dims)(gy: "Tensor[B]")
#
# This gives us the wrong result (gx has shape [B, B, 4], but it should
# have shape [B, 4]). Performing vmap over setup_context means the shape
# saved has shape [4] and leads to a correct result shape for gx.

# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_reserved_attrs.
class WrappedCtx:
    _pt_reserved_attrs: Tuple[str, ...] = ('_pt_reserved_attrs', '_pt_inner_ctx')

    def __init__(self, ctx):
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            for name in reserved_attrs:
                if not hasattr(ctx, name):
                    continue
                raise RuntimeError(
                    f'PyTorch reserves the {reserved_attrs} field on ctx. '
                    'Please name your fields on ctx something else to avoid name '
                    'collision.')
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        if name in type(self)._pt_reserved_attrs:
            self.__dict__[name] = value
            return
        return setattr(self._pt_inner_ctx, name, value)
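
# For illustration (a hypothetical sketch): attribute reads and writes on a
# WrappedCtx transparently reach the underlying ctx, while the reserved
# `_pt_*` attributes live on the wrapper itself.
#
#   class _FakeCtx:  # stand-in for an autograd ctx object
#       pass
#
#   ctx = _FakeCtx()
#   wrapped = WrappedCtx(ctx)
#   wrapped.foo = 1          # stored on the inner ctx
#   assert ctx.foo == 1
#   assert wrapped.foo == 1  # read forwarded to the inner ctx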

# Wraps ctx to create a new ctx object that overrides saved_tensors.
class CtxWithSavedTensors(WrappedCtx):
    _pt_reserved_attrs = ('_pt_new_saved_tensors', *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, new_saved_tensors):
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        return self._pt_new_saved_tensors


class CtxCustomSave(WrappedCtx):
    _pt_reserved_attrs = ('_pt_saved_tensors_bdims', '_pt_current_level',
                          *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, current_level):
        super().__init__(ctx)
        self._pt_saved_tensors_bdims = ()
        self._pt_current_level = current_level

    def save_for_backward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_backward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims

    def save_for_forward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_forward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims

def reductify(grad_input, grad_input_bdim, input_bdim, batch_size,
              target_shape_without_bdim_to_reduce_to=None):
    if not isinstance(grad_input, tuple):
        grad_input = (grad_input,)
    if not isinstance(grad_input_bdim, tuple):
        grad_input_bdim = (grad_input_bdim,)
    if not isinstance(input_bdim, tuple):
        input_bdim = (input_bdim,)

    if target_shape_without_bdim_to_reduce_to is None:
        target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,)
    result = tuple(
        reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape)
        for gi, gi_bdim, i_bdim, maybe_ishape in
        zip(grad_input, grad_input_bdim, input_bdim, target_shape_without_bdim_to_reduce_to)
    )
    return result


def reductify_leaf(grad_input, grad_input_bdim, input_bdim, batch_size,
                   target_shape_without_bdim_to_reduce_to=None):
    if grad_input is None:
        return None

    if grad_input_bdim is None and input_bdim is None:
        return grad_input

    if grad_input_bdim is not None and input_bdim is None:
        return grad_input.sum(grad_input_bdim)

    # NOTE: [Why can't we rely on autograd to reduce expanded gradients?]
    # For reverse-mode AD,
    # given a grad_input and input, it is valid for the user to return a
    # grad_input that has a broadcasted shape when compared to the input.
    # In this situation, autograd automatically reduces the grad_input to
    # the shape of the input.
    #
    # However, when input_bdim is not None, we have problems.
    #
    # [example 1]
    # grad_input: Tensor[3, 4], input: Tensor[B, 4]
    # We can expand grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # [example 2]
    # grad_input: Tensor[3, B, 4], input: Tensor[B, 4]
    # We can swizzle grad_input to Tensor[B, 3, 4], but that isn't broadcastable
    # from [B, 4].
    #
    # This means that we need to also reduce the grad_input to the shape of the
    # input. This behavior is controlled by the `target_shape_without_bdim_to_reduce_to`
    # argument; if it is not None then we do the reducing manually, otherwise
    # we do not do a reduction.
    assert input_bdim is not None

    if grad_input_bdim is None:
        grad_input = grad_input.unsqueeze(input_bdim)
        new_shape = list(grad_input.shape)
        new_shape[input_bdim] = batch_size
        grad_input = grad_input.expand(new_shape)
        grad_input_bdim = input_bdim

    if target_shape_without_bdim_to_reduce_to is not None:
        return vmap(torch.Tensor.sum_to_size, in_dims=(grad_input_bdim, None), out_dims=input_bdim)(
            grad_input, target_shape_without_bdim_to_reduce_to)

    if input_bdim != grad_input_bdim:
        grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
    return grad_input
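
# For illustration (hypothetical shapes, with batch_size=B): reductify_leaf
# handles the cases from the NOTE above roughly as follows.
#
#   # grad_input: Tensor[3, 4], grad_input_bdim=None, input_bdim=0
#   # -> unsqueeze dim 0 and expand to Tensor[B, 3, 4]; if the input's
#   #    per-example shape [4] is passed as
#   #    target_shape_without_bdim_to_reduce_to, each [3, 4] slice is
#   #    sum_to_size'd down to [4], giving Tensor[B, 4].
#   #
#   # grad_input: Tensor[3, B, 4], grad_input_bdim=1, input_bdim=0
#   # -> movedim(1, 0) gives Tensor[B, 3, 4] (or sum_to_size per example
#   #    when a target shape is provided, giving Tensor[B, 4]).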

class AutogradFunctionApply(HigherOrderOperator):
    def __init__(self):
        super().__init__("autograd_function_apply")

    def __call__(self, fwd, bwd, *fwd_args):
        saved_values = None

        class ApplyTemplate(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                nonlocal saved_values
                output, saved_values = fwd(None, *args)
                return output

            @staticmethod
            def backward(ctx, *grad):
                return bwd(None, *grad, *saved_values)

        return ApplyTemplate.apply(*fwd_args)


autograd_function_apply = AutogradFunctionApply()
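
# For illustration (a hypothetical sketch): autograd_function_apply stitches a
# pair of plain callables into an autograd.Function. `fwd` receives a ctx
# placeholder plus the inputs and returns (output, saved_values); `bwd`
# receives the ctx placeholder, the incoming gradients, and then the saved
# values.
#
#   def fwd(ctx, x):
#       return x.sin(), [x]
#
#   def bwd(ctx, grad, x):
#       return grad * x.cos()
#
#   x = torch.randn(3, requires_grad=True)
#   y = autograd_function_apply(fwd, bwd, x)
#   y.sum().backward()  # x.grad == x.cos()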