import abc
import contextlib
import logging
import threading
from collections import defaultdict, deque
from typing import (
    Any, Callable, cast, Deque, Dict, Generator, Iterable, Iterator, List,
    Literal, MutableMapping, NamedTuple, Optional, Sequence, Set, Tuple, Union,
)
from typing_extensions import TypeAlias
from weakref import WeakKeyDictionary, WeakValueDictionary

import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle
from torch._ops import OpOverload

__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
]

log = logging.getLogger(__name__)

class Node(abc.ABC):
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        raise NotImplementedError

    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        raise NotImplementedError

    def metadata(self) -> dict:
        r"""Return the metadata."""
        raise NotImplementedError

    def _input_metadata(self) -> List[Any]:
        raise NotImplementedError

    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        raise NotImplementedError

    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its arguments, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        See :ref:`backward-hooks-execution` for more information on when this hook
        is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()  # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        raise NotImplementedError

    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        See :ref:`backward-hooks-execution` for more information on when this hook
        is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([3., 3., 3.])
        """
        raise NotImplementedError

    @classmethod
    def __subclasshook__(cls, subclass: type) -> bool:
        if cls is Node and (
            (
                subclass is not None
                and subclass is getattr(torch._C._functions, subclass.__name__, None)
            )
            or issubclass(subclass, torch.autograd.function.BackwardCFunction)
        ):
            return True
        return NotImplemented

def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node:
    if isinstance(t, GradientEdge):
        return t.node
    if t.requires_grad and t.grad_fn is None:
        # Force creation of the AccumulateGrad node for a leaf that does not
        # have a grad_fn yet by taking a trivial view of it.
        with torch.enable_grad():
            node = t.view_as(t).grad_fn.next_functions[0][0]  # type: ignore[union-attr]
    else:
        node = t.grad_fn
    assert node is not None
    return node

class GradientEdge(NamedTuple):
    """Object representing a given gradient edge within the autograd graph.

    To get the gradient edge where a given Tensor gradient will be computed,
    you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
    """

    node: Node
    output_nr: int

def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge:
    """Get the gradient edge for computing the gradient of the given Tensor.

    In particular, calling ``g = autograd.grad(loss, input)`` is equivalent to
    ``g = autograd.grad(loss, get_gradient_edge(input))``.
    """
    if not tensor.requires_grad:
        raise RuntimeError(
            "It is not possible to get the gradient edge for a Tensor "
            "that does not require gradients",
        )
    grad_fn = _get_grad_fn_or_grad_acc(tensor)

    # Note that output_nr defaults to 0, which is the right value
    # for the AccumulateGrad node.
    return GradientEdge(grad_fn, tensor.output_nr)
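
# Illustrative sketch of the equivalence stated in the docstring above; the
# names `x` and `loss` are placeholders, not part of this module.
#
#     >>> x = torch.randn(3, requires_grad=True)
#     >>> loss = (x * 2).sum()
#     >>> edge = torch.autograd.graph.get_gradient_edge(x)
#     >>> (g,) = torch.autograd.grad(loss, edge)  # same result as torch.autograd.grad(loss, x)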

def increment_version(tensor: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None:
    """Update autograd metadata tracking whether the given Tensor was modified in place.

    This is to enable more accurate error checking within the autograd engine.
    It is already done automatically by PyTorch functions and within custom Functions
    when mark_dirty() is called appropriately, so you only need to call this explicitly
    if you are doing an in-place operation on the Tensor data in a way that PyTorch
    doesn't know about; for example, a custom kernel that reads the Tensor's data_ptr
    and modifies the memory in place based on that pointer. This function can accept
    either a single tensor or a list of tensors.

    Note that incrementing the version counter multiple times for a single in-place
    operation is not problematic.

    Note that if you pass in a tensor constructed under torch.inference_mode(),
    we will not bump its version counter (because your tensor does not have one).
    """
    if isinstance(tensor, torch.Tensor):
        torch._C._increment_version(tensor)
    else:
        for t in tensor:
            torch._C._increment_version(t)
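
# Illustrative sketch of when to call increment_version by hand: a hypothetical
# external kernel (`my_custom_inplace_kernel` is a placeholder) writes through the
# tensor's data_ptr without autograd knowing, so the version is bumped manually.
#
#     >>> buf = torch.zeros(4)
#     >>> my_custom_inplace_kernel(buf)  # placeholder for an out-of-band in-place write
#     >>> torch.autograd.graph.increment_version(buf)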

class saved_tensors_hooks:
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

    Use this context-manager to define how intermediary results of an operation
    should be packed before saving, and unpacked on retrieval.

    In that context, the ``pack_hook`` function will be called every time an
    operation saves a tensor for backward (this includes intermediary results
    saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation). The output of
    ``pack_hook`` is then stored in the computation graph instead of the
    original tensor.

    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
    namely when executing :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
    returned by ``pack_hook`` and should return a tensor which has the same
    content as the original tensor (passed as input to the corresponding
    ``pack_hook``).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.

    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
    of value, size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning::
        Performing an inplace operation on the input to either hook may lead
        to undefined behavior.

    .. warning::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """

    def __init__(
        self,
        pack_hook: Callable[[torch.Tensor], Any],
        unpack_hook: Callable[[Any], torch.Tensor],
    ) -> None:
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self) -> None:
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.pack_hook, self.unpack_hook
        )

    def __exit__(self, *args: object) -> None:
        torch._C._autograd._pop_saved_tensors_default_hooks()
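
# Illustrative sketch: a pack/unpack pair that keeps saved activations in
# bfloat16 to reduce memory, restoring the original dtype on unpack (this trades
# a bit of gradient precision for memory; assumes bfloat16 is acceptable).
#
#     >>> def pack_bf16(x):
#     ...     return (x.dtype, x.to(torch.bfloat16))
#     >>> def unpack_bf16(packed):
#     ...     dtype, x = packed
#     ...     return x.to(dtype)
#     >>> with torch.autograd.graph.saved_tensors_hooks(pack_bf16, unpack_bf16):
#     ...     y = (torch.randn(8, requires_grad=True) * 3).sin()
#     >>> y.sum().backward()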

class save_on_cpu(saved_tensors_hooks):
    """Context manager under which tensors saved by the forward pass will be stored on CPU, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b  # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a  # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward
    """

    def __init__(self, pin_memory: bool = False, device_type: str = "cuda") -> None:
        device_module = getattr(torch, device_type, torch.cuda)

        def pack_to_cpu(tensor: torch.Tensor) -> Tuple[torch.device, torch.Tensor]:
            if not pin_memory:
                return (tensor.device, tensor.cpu())
            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(device_module.is_available() and not tensor.is_sparse),
            )
            packed.copy_(tensor)
            return (tensor.device, packed)

        def unpack_from_cpu(packed: Tuple[torch.device, torch.Tensor]) -> torch.Tensor:
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)
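
# Illustrative sketch: enabling pinned host memory so the copies back to GPU
# during backward can overlap with compute (assumes a CUDA device; `model` and
# `inputs` below are placeholders).
#
#     >>> with torch.autograd.graph.save_on_cpu(pin_memory=True):
#     ...     loss = model(inputs).sum()
#     >>> loss.backward()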

@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, None]:
    """Context-manager that disables the saved tensors default hooks feature.

    Useful if you are creating a feature that does not work with saved
    tensors default hooks.

    Args:
        error_message (str): When saved tensors default hooks are used while they
                             are disabled, a RuntimeError with this error message
                             is raised.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass
    """
    maybe_prev_message = None
    try:
        maybe_prev_message = (
            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        )
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
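
# Illustrative sketch: nesting is allowed; on exit, the inner context restores the
# outer error message (this is what the maybe_prev_message bookkeeping above does).
#
#     >>> with torch.autograd.graph.disable_saved_tensors_hooks("outer disabled"):
#     ...     with torch.autograd.graph.disable_saved_tensors_hooks("inner disabled"):
#     ...         pass  # hooks are disabled with the "inner disabled" message here
#     ...     # back to the "outer disabled" message here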

class _MultiHandle(RemovableHandle):
    handles: Tuple[RemovableHandle, ...]

    def __init__(self, handles: Tuple[RemovableHandle, ...]) -> None:
        self.handles = handles

    def remove(self) -> None:
        for handle in self.handles:
            handle.remove()

    def __getstate__(self) -> Tuple[RemovableHandle, ...]:
        return self.handles

    def __setstate__(self, state: Tuple[RemovableHandle, ...]) -> None:
        self.handles = state

def register_multi_grad_hook(
    tensors: Sequence[torch.Tensor],
    fn: Union[
        Callable[[Sequence[Optional[torch.Tensor]]], None],
        Callable[[torch.Tensor], None],
    ],
    *,
    mode: Literal["all", "any"] = "all",
) -> RemovableHandle:
    r"""Register a multi-grad backward hook.

    There are two supported modes: ``"all"`` and ``"any"``.

    Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
    :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
    is not part of the graph, or if a tensor is not needed to compute the gradients
    for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
    this tensor will be ignored and the hook will not wait for its gradient to be
    computed.

    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
    called with those gradients. ``None`` will be passed for tensors that did not
    have their gradients computed.

    Under the ``"any"`` mode, the hook will be called after the first gradient
    with respect to a tensor in :attr:`tensors` has been computed. The hook
    will be called with that gradient as its argument.

    The hook should not modify its arguments.

    This function returns a handle with a method ``handle.remove()`` that removes the hook.

    See :ref:`backward-hooks-execution` for more information on when this hook
    is executed, and how its execution is ordered relative to other hooks.

    Example::

        >>> import torch
        >>>
        >>> a = torch.rand(2, 3, requires_grad=True)
        >>> b = torch.rand(2, 3, requires_grad=True)
        >>> c = a * b
        >>> d = a * b
        >>>
        >>> def fn(grads):
        ...     print([g is not None for g in grads])
        ...
        >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
        >>>
        >>> c.sum().backward(retain_graph=True)
        [True, True, True, False]
        >>> c.sum().backward(inputs=(a,), retain_graph=True)
        [True, False, True, False]
    """
    supported_modes = ("all", "any")
    lock = threading.Lock()

    if mode not in supported_modes:
        raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")

    if mode == "all":
        count: Dict[int, int] = {}
        nb_calls: Optional[int] = None
        buffer: Dict[int, List[Optional[torch.Tensor]]] = {}

        grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
        len_tensors = len(tensors)

        def get_inner_hook(idx: int) -> Callable[[torch.Tensor], None]:
            def inner_hook(grad: torch.Tensor) -> None:
                nonlocal count, nb_calls, buffer, fn
                id = torch._C._current_graph_task_id()
                assert (
                    id != -1
                ), "expected this hook to be called inside a backward call"
                count[id] = count.get(id, 0)
                buffer[id] = buffer.get(id, [None] * len_tensors)

                with lock:
                    curr_count, count[id] = count[id], count[id] + 1

                    if curr_count == 0:
                        # On the first call, compute the actual nb_calls and buffer
                        nb_calls = sum(
                            map(torch._C._will_engine_execute_node, grad_fns)
                        )

                buffer[id][idx] = grad

                assert nb_calls is not None
                if curr_count == nb_calls - 1:
                    fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
                    fn(buffer[id])
                    del count[id]
                    del buffer[id]

            return inner_hook

        handles = tuple(
            t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
        )
    elif mode == "any":
        fn = cast(Callable[[torch.Tensor], None], fn)
        ran_hook: Dict[int, bool] = defaultdict(bool)

        def wrapped_fn(grad: torch.Tensor) -> None:
            id = torch._C._current_graph_task_id()
            assert id != -1, "expected this hook to be called inside a backward call"
            with lock:
                prev, ran_hook[id] = ran_hook[id], True
            if prev:
                return
            fn(grad)

        handles = tuple(
            tensor.register_hook(wrapped_fn)
            for tensor in tensors
            if tensor.requires_grad
        )

    return _MultiHandle(handles)  # type: ignore[possibly-undefined]
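
# Illustrative sketch of the "any" mode described above: the hook fires once per
# backward call, with the first gradient computed among the registered tensors.
#
#     >>> a = torch.rand(2, 3, requires_grad=True)
#     >>> b = torch.rand(2, 3, requires_grad=True)
#     >>> handle = torch.autograd.graph.register_multi_grad_hook(
#     ...     (a, b), lambda grad: print("first grad:", grad.shape), mode="any"
#     ... )
#     >>> (a * b).sum().backward()
#     first grad: torch.Size([2, 3])
#     >>> handle.remove()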

# NOTE [Allow mutation on tensors saved for backward]
#
# 1. Tensor gets saved for backward
#    - remember the python object id and the version of the tensor
#    - remember aliasing information (data_ptr of base + version)
#    - save the original so we control its lifetime
# 2. Any time a tensor gets in-placed
#    - for each tensor aliased to it:
#      - check using its object id and version to see if it has been saved
#      - if it has been saved, clone it
#      - delete the reference to the original
# 3. During backward
#    - if the clone exists, the tensor must've been modified in-place
_allow_mutation_on_saved_tensors_enabled: bool = False


_TID: TypeAlias = Tuple[int, int, int]
_SID: TypeAlias = Tuple[int, int]

def _get_tid(tensor: torch.Tensor) -> _TID:
    # FIXME: This is almost definitely a bug.
    if isinstance(
        tensor,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = tensor.data_ptr()
    return (id(tensor), data_ptr, tensor._version)


def _get_sid(tensor: torch.Tensor) -> _SID:
    # FIXME: This is almost definitely a bug.
    if isinstance(
        tensor,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = tensor.data_ptr()
    return (data_ptr, tensor._version)

class _Handle:
    pass


class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None:
        def pack_hook(tensor: torch.Tensor) -> _Handle:
            tid = _get_tid(tensor)
            sid = _get_sid(tensor)
            # Tensors saved for backward have an entry in _tid_to_weakhandle
            handle: Optional[_Handle] = None

            # Save aliasing information
            ctx.sid_to_tid[sid].add(tid)

            # NB: The same tensor (of the same version) can be saved multiple times
            if tid not in ctx.tid_to_weakhandle:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = tensor
            else:
                # Store an additional strong reference to the handle
                handle = ctx.tid_to_weakhandle[tid]
            return handle

        def unpack_hook(handle: _Handle) -> torch.Tensor:
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' "
                "context in which the graph was originally recorded."
            )
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            if handle in ctx.cloned:
                res = ctx.cloned[handle]
            else:
                assert handle in ctx.original, error_msg
                res = ctx.original[handle]
            return res

        super().__init__(pack_hook, unpack_hook)

class _CloneArgBeforeMutateMode(TorchDispatchMode):
    def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None:
        self.ctx = ctx

    def __torch_dispatch__(
        self,
        func: "OpOverload",
        types: Iterable[type],
        args: Tuple[Any, ...] = (),
        kwargs: Optional[Dict[Any, Any]] = None,
    ) -> Any:
        kwargs = kwargs or {}

        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is not None and arg.alias_info.is_write:
                t = kwargs["out"] if arg.is_out else args[idx]
                tid = _get_tid(t)
                sid = _get_sid(t)
                ctx = self.ctx
                if sid in ctx.sid_to_tid:
                    for tid in ctx.sid_to_tid[sid]:
                        if tid not in ctx.tid_to_weakhandle:
                            # We know that if tid is in sid_to_tid, then it must also be in
                            # tid_to_weakhandle. However, it is possible for the tensor to be
                            # saved at one point, but cleared by backward before it is modified
                            # in-place. Consider the following example:
                            #
                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
                            # >>> out = (a**2).sum()
                            # >>> out.backward()
                            # >>> a.sin_()
                            continue
                        handle = ctx.tid_to_weakhandle[tid]
                        if handle in ctx.cloned:
                            # The same exact tensor has been cloned already
                            continue
                        ctx.cloned[handle] = ctx.original[handle].clone()
                        del ctx.original[handle]

        return func(*args, **kwargs)

class _AllowMutationOnSavedContext:
    def __init__(self) -> None:
        self.cloned: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
        self.original: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
        self.tid_to_weakhandle: MutableMapping[_TID, _Handle] = WeakValueDictionary()
        self.sid_to_tid: Dict[_SID, Set[_TID]] = defaultdict(set)

    def clear(self) -> None:
        self.cloned.clear()
        self.original.clear()
        self.tid_to_weakhandle.clear()
        self.sid_to_tid.clear()

@contextlib.contextmanager
def allow_mutation_on_saved_tensors() -> (
    Generator[_AllowMutationOnSavedContext, None, None]
):
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    To ensure the correct behavior, both the forward and backward should be run under
    the same context manager.

    Returns:
        An _AllowMutationOnSavedContext object storing the state managed by this
        context manager. This object can be useful for debugging purposes. The state
        managed by the context manager is automatically cleared upon exiting.

    Example::

        >>> import torch
        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
        ...     # forward
        ...     a = torch.ones(2, 3, requires_grad=True)
        ...     b = a.clone()
        ...     out = (b**2).sum()
        ...     b.sin_()
        ...     # backward
        ...     out.sum().backward()
        ...
        >>> print(b)
        tensor([[0.8415, 0.8415, 0.8415],
                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
    """
    global _allow_mutation_on_saved_tensors_enabled

    ctx = _AllowMutationOnSavedContext()

    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
        try:
            if _allow_mutation_on_saved_tensors_enabled:
                raise RuntimeError(
                    "allow_mutation_on_saved_tensors contexts cannot be nested"
                )
            _allow_mutation_on_saved_tensors_enabled = True
            yield ctx
        finally:
            ctx.clear()
            _allow_mutation_on_saved_tensors_enabled = False
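
# For contrast, an illustrative sketch of the failure mode this context manager
# avoids: without it, mutating a tensor saved for backward makes the backward
# call fail with a version-counter error.
#
#     >>> a = torch.ones(2, 3, requires_grad=True)
#     >>> b = a.clone()
#     >>> out = (b**2).sum()
#     >>> b.sin_()
#     >>> out.backward()  # raises a RuntimeError about an in-place modification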

def _register_logging_hooks_on_whole_graph(
    t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
) -> Callable[[], None]:
    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

    def iter_graph(roots: List[Node]) -> Iterator[Node]:
        if not roots:
            return
        seen: Set[Node] = set()
        q: Deque[Node] = deque()
        for node in roots:
            if node is not None:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _ in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)

            yield node

    def fmt(t: Optional[torch.Tensor]) -> str:
        # Avoid circular import
        from torch.testing._internal.common_utils import dtype_abbrs

        if t is None:
            return "None"
        return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"

    def prehook(grad_outputs: Sequence[Optional[torch.Tensor]]) -> None:
        node = torch._C._current_autograd_node()
        grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
        log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
        log.debug(log_str)

    handles = []
    for node in iter_graph(grad_fns):
        handles.append(node.register_prehook(prehook))

    def unregister_hooks() -> None:
        for handle in handles:
            handle.remove()

    return unregister_hooks

def _engine_run_backward(
    t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
    *args: Any,
    **kwargs: Any,
) -> Tuple[torch.Tensor, ...]:
    attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
    if attach_logging_hooks:
        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    try:
        # Calls into the C++ engine to run the backward pass
        return Variable._execution_engine.run_backward(
            t_outputs, *args, **kwargs
        )
    finally:
        if attach_logging_hooks:
            unregister_hooks()  # type: ignore[possibly-undefined]
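
# Illustrative sketch: because attach_logging_hooks checks this module's logger,
# turning on DEBUG for "torch.autograd.graph" makes every backward call log each
# executed node via the prehooks registered above.
#
#     >>> import logging
#     >>> logging.basicConfig()
#     >>> logging.getLogger("torch.autograd.graph").setLevel(logging.DEBUG)
#     >>> torch.ones(2, requires_grad=True).sum().backward()  # logs "Executing: ..." per node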