import collections
import contextlib
import logging
import threading
import weakref
from abc import ABC
from collections import defaultdict, namedtuple
from typing import Any, Callable, cast, Deque, Dict, List, Optional, Sequence, Set, Tuple, Union

import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle

log = logging.getLogger(__name__)

__all__ = [
    "saved_tensors_hooks",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
]
class Node(ABC):
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """

    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        ...

    def metadata(self) -> dict:
        r"""Return the metadata."""
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        ...

    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the Node.

        See :ref:`backward-hooks-execution` for more information on when this hook
        is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()  # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the Node.

        See :ref:`backward-hooks-execution` for more information on when this hook
        is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> handle.remove()
            >>> b.sum().backward(retain_graph=True)
        """
    @classmethod
    def __subclasshook__(cls, C):
        if cls is Node and (
            (C is not None and C is getattr(torch._C._functions, C.__name__, None))
            or issubclass(C, torch.autograd.function.BackwardCFunction)
        ):
            return True
        return NotImplemented
def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        # Leaf tensor: materialize its AccumulateGrad node by creating a view of it.
        return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn
GradientEdge = namedtuple("GradientEdge", ("node", "output_nr"))
GradientEdge.__doc__ = """\
Object representing a given gradient edge within the autograd graph.
To get the gradient edge where a given Tensor gradient will be computed,
you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
"""
def get_gradient_edge(tensor):
    """Get the gradient edge for computing the gradient of the given Tensor.

    In particular, ``g = autograd.grad(loss, input)`` is equivalent to
    ``g = autograd.grad(loss, get_gradient_edge(input))``.
    """
    if not tensor.requires_grad:
        raise RuntimeError(
            "It is not possible to get the gradient edge for a Tensor that does not require gradients"
        )
    grad_fn = _get_grad_fn_or_grad_acc(tensor)

    # Note that output_nr defaults to 0, which is the right value for the AccumulateGrad node.
    return GradientEdge(grad_fn, tensor.output_nr)
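
# Illustrative sketch (not part of the original module): per the docstring above, a
# GradientEdge returned by ``get_gradient_edge`` can be passed to ``torch.autograd.grad``
# in place of the Tensor itself. The helper name below is hypothetical.
def _example_get_gradient_edge():
    x = torch.randn(3, requires_grad=True)
    loss = (x * 2).sum()
    edge = get_gradient_edge(x)
    (gx,) = torch.autograd.grad(loss, [edge])  # same gradient as torch.autograd.grad(loss, [x])
    return gx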
def increment_version(tensor):
    """Update autograd metadata tracking whether the given Tensor was modified in place.

    This is to enable more accurate error checking within the autograd engine.
    It is already done automatically by PyTorch functions and within custom Functions
    when ``mark_dirty()`` is called appropriately, so you only need to call this explicitly
    if you are doing an in-place operation on the Tensor data in a way that PyTorch doesn't
    know about; for example, a custom kernel that reads the Tensor's ``data_ptr`` and modifies
    the memory in place based on this pointer.

    Note that incrementing the version counter multiple times for a single in-place operation
    is not problematic.
    """
    torch._C._increment_version(tensor)
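
# Illustrative sketch (not part of the original module): as the docstring explains,
# ``increment_version`` should be called after mutating a Tensor's memory through a path
# autograd cannot see. ``my_custom_inplace_kernel`` below is a hypothetical extension.
def _example_increment_version(t):
    # my_custom_inplace_kernel(t.data_ptr())  # out-of-band in-place write autograd can't track
    increment_version(t)  # tell autograd that the Tensor's data has changed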
class saved_tensors_hooks:
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

    Use this context-manager to define how intermediary results of an operation
    should be packed before saving, and unpacked on retrieval.

    In that context, the ``pack_hook`` function will be called every time an
    operation saves a tensor for backward (this includes intermediary results
    saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation). The output of
    ``pack_hook`` is then stored in the computation graph instead of the
    original tensor.

    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
    namely when executing :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
    returned by ``pack_hook`` and should return a tensor which has the same
    content as the original tensor (passed as input to the corresponding
    ``pack_hook``).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.

    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
    of value, size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning ::
        Performing an inplace operation on the input to either hook may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """
    def __init__(
        self,
        pack_hook: Callable[[torch.Tensor], Any],
        unpack_hook: Callable[[Any], torch.Tensor],
    ):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.pack_hook, self.unpack_hook
        )

    def __exit__(self, *args: object):
        torch._C._autograd._pop_saved_tensors_default_hooks()
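
# Illustrative sketch (not part of the original module): ``saved_tensors_hooks`` can be
# subclassed by passing a custom pack/unpack pair to ``super().__init__``, the same pattern
# ``save_on_cpu`` below uses. This temp-file variant is hypothetical and only demonstrates
# the ``unpack_hook(pack_hook(t)) == t`` contract described in the docstring.
class _example_save_to_disk(saved_tensors_hooks):
    def __init__(self):
        import tempfile

        def pack_to_disk(tensor):
            f = tempfile.NamedTemporaryFile(delete=False)  # file must outlive the forward pass
            torch.save(tensor, f.name)
            return f.name

        def unpack_from_disk(name):
            return torch.load(name)

        super().__init__(pack_to_disk, unpack_from_disk)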
class save_on_cpu(saved_tensors_hooks):
    """Context manager under which tensors saved by the forward pass will be stored on CPU, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True``, tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b           # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a           # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward
    """

    def __init__(self, pin_memory=False, device_type="cuda"):
        device_module = getattr(torch, device_type, torch.cuda)

        def pack_to_cpu(tensor):
            if not pin_memory:
                return (tensor.device, tensor.cpu())
            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(device_module.is_available() and not tensor.is_sparse),
            )
            packed.copy_(tensor)
            return (tensor.device, packed)

        def unpack_from_cpu(packed):
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)
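
# Illustrative sketch (not part of the original module): wrapping a forward pass in
# ``save_on_cpu(pin_memory=True)`` offloads saved activations to pinned host memory, as
# described in the class docstring. ``model``, ``inputs``, ``targets`` and ``loss_fn``
# are hypothetical placeholders.
def _example_save_on_cpu(model, inputs, targets, loss_fn):
    with save_on_cpu(pin_memory=True):
        loss = loss_fn(model(inputs), targets)  # activations saved for backward live on CPU
    loss.backward()  # saved tensors are copied back to their original device as needed
    return loss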
@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
    """Context-manager that disables the saved tensors default hooks feature.

    Useful if you are creating a feature that does not work with saved
    tensors default hooks.

    Args:
        error_message (str): When saved tensors default hooks are used while they
                             have been disabled, a RuntimeError with this
                             error message gets raised.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass
    """
    maybe_prev_message = None
    try:
        maybe_prev_message = (
            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        )
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # Restore the previous state: re-enable hooks, or restore the prior error message.
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
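
# Illustrative sketch (not part of the original module): a feature that cannot run with
# saved-tensor hooks installed can guard its region with ``disable_saved_tensors_hooks``;
# entering e.g. ``save_on_cpu()`` inside then raises a RuntimeError with the given message.
# The wrapper name below is hypothetical.
def _example_disable_saved_tensors_hooks(fn, *args):
    with disable_saved_tensors_hooks("saved tensors hooks are not supported here"):
        return fn(*args)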
def register_multi_grad_hook(
    tensors: Sequence[torch.Tensor],
    fn: Union[
        Callable[[Sequence[Optional[torch.Tensor]]], None],
        Callable[[torch.Tensor], None],
    ],
    *,
    mode: str = "all",
) -> RemovableHandle:
    r"""Register a multi-grad backward hook.

    There are two supported modes: ``"all"`` and ``"any"``.

    Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
    :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
    is not part of the graph, or if a tensor is not needed to compute the gradients
    for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
    this tensor will be ignored and the hook will not wait for its gradient to be
    computed.

    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
    called with those gradients. ``None`` will be passed for tensors that did not
    have their gradients computed.

    Under the ``"any"`` mode, the hook will be called after the first gradient
    with respect to a tensor in :attr:`tensors` has been computed. The hook
    will be called with that gradient as its argument.

    The hook should not modify its arguments.

    This function returns a handle with a method ``handle.remove()`` that removes the hook.

    See :ref:`backward-hooks-execution` for more information on when this hook
    is executed, and how its execution is ordered relative to other hooks.

    Example::

        >>> a = torch.rand(2, 3, requires_grad=True)
        >>> b = torch.rand(2, 3, requires_grad=True)
        >>> c = a * b
        >>> d = a * b
        >>>
        >>> def fn(grads):
        ...     print([g is not None for g in grads])
        ...
        >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
        >>>
        >>> c.sum().backward(retain_graph=True)
        [True, True, True, False]
        >>> c.sum().backward(inputs=(a,), retain_graph=True)
        [True, False, True, False]
    """
    supported_modes = ("all", "any")
    if mode not in supported_modes:
        raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")
    class Handle(RemovableHandle):
        handles: Tuple[RemovableHandle, ...]

        def __init__(self, handles: Tuple[RemovableHandle, ...]):
            self.handles = handles

        def remove(self):
            for handle in self.handles:
                handle.remove()

        def __getstate__(self):
            return self.handles

        def __setstate__(self, state):
            self.handles = state

    if mode == "all":
        count: Dict[int, int] = dict()
        nb_calls = None
        buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()

        grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
        len_tensors = len(tensors)

        def get_inner_hook(idx):
            def inner_hook(grad: torch.Tensor):
                nonlocal count, nb_calls, buffer, fn
                id = torch._C._current_graph_task_id()
                assert (
                    id != -1
                ), "expected this hook to be called inside a backward call"
                count[id] = count.get(id, 0)
                buffer[id] = buffer.get(id, [None] * len_tensors)

                if count[id] == 0:
                    # On the first call, compute how many of the registered nodes
                    # the engine will actually execute for this graph task.
                    nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns)

                buffer[id][idx] = grad
                count[id] += 1

                if count[id] == nb_calls:
                    fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
                    fn(buffer[id])
                    del count[id]
                    del buffer[id]

            return inner_hook

        handles: Tuple[RemovableHandle, ...] = tuple(
            t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
        )
    elif mode == "any":
        fn = cast(Callable[[torch.Tensor], None], fn)
        lock = threading.Lock()
        ran_hook: Dict[int, bool] = defaultdict(bool)

        def wrapped_fn(grad: torch.Tensor):
            nonlocal ran_hook
            id = torch._C._current_graph_task_id()
            assert id != -1, "expected this hook to be called inside a backward call"
            with lock:
                prev, ran_hook[id] = ran_hook[id], True
            if prev:
                return
            fn(grad)

        handles = tuple(
            tensor.register_hook(wrapped_fn)
            for tensor in tensors
            if tensor.requires_grad
        )

    return Handle(handles)
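
# Illustrative sketch (not part of the original module): the ``"any"`` mode described in
# the docstring fires the hook once per backward call, with the first gradient computed
# among the watched tensors. The tensor names below are hypothetical.
def _example_register_multi_grad_hook_any():
    a = torch.rand(2, 3, requires_grad=True)
    b = torch.rand(2, 3, requires_grad=True)
    handle = register_multi_grad_hook(
        (a, b), lambda grad: print("first grad computed:", grad.shape), mode="any"
    )
    (a * b).sum().backward()  # prints once, for whichever of a/b has its gradient ready first
    handle.remove()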
_allow_mutation_on_saved_tensors_enabled = False


def _get_tid(t) -> Tuple[int, int, int]:
    # Unique id of a tensor object: (python id, storage pointer, version counter)
    return (id(t), t.data_ptr(), t._version)


def _get_sid(t) -> Tuple[int, int]:
    # Id of the underlying storage: (storage pointer, version counter)
    return (t.data_ptr(), t._version)
class _Handle:
    pass


class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx):
        def pack_hook(t):
            tid = _get_tid(t)
            sid = _get_sid(t)
            handle: Optional[_Handle] = None

            # Record aliasing information
            ctx.sid_to_tid[sid].add(tid)

            # The same tensor (of the same version) may be saved multiple times
            if tid not in ctx.tid_to_weakhandle:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = t
            else:
                handle = ctx.tid_to_weakhandle[tid]
            return handle

        def unpack_hook(tup):
            handle = tup
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded."
            )
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            if handle in ctx.cloned:
                res = ctx.cloned[handle]
            else:
                assert handle in ctx.original, error_msg
                res = ctx.original[handle]
            return res

        super().__init__(pack_hook, unpack_hook)
class _CloneArgBeforeMutateMode(TorchDispatchMode):
    def __init__(self, ctx):
        self.ctx = ctx

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        ctx = self.ctx

        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is not None and arg.alias_info.is_write:
                t = kwargs["out"] if arg.is_out else args[idx]
                sid = _get_sid(t)
                if sid in ctx.sid_to_tid:
                    for tid in ctx.sid_to_tid[sid]:
                        if tid not in ctx.tid_to_weakhandle:
                            # The tensor may have been saved but already cleared by backward
                            continue
                        handle = ctx.tid_to_weakhandle[tid]
                        if handle in ctx.cloned:
                            # This tensor has already been cloned
                            continue
                        # Clone the original before it gets mutated in place
                        ctx.cloned[handle] = ctx.original[handle].clone()
                        del ctx.original[handle]

        rs = func(*args, **kwargs)
        return rs
class _AllowMutationOnSavedContext:
    def __init__(self):
        self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.tid_to_weakhandle: weakref.WeakValueDictionary = (
            weakref.WeakValueDictionary()
        )
        self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(
            set
        )

    def clear(self):
        self.cloned.clear()
        self.original.clear()
        self.tid_to_weakhandle.clear()
        self.sid_to_tid.clear()
@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    To ensure the correct behavior, both the forward and backward should be run under
    the same context manager.

    Returns:
        An _AllowMutationOnSavedContext object storing the state managed by this
        context manager. This object can be useful for debugging purposes. The state
        managed by the context manager is automatically cleared upon exiting.

    Example::

        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
        ...     # forward
        ...     a = torch.ones(2, 3, requires_grad=True)
        ...     b = a.clone()
        ...     out = (b**2).sum()
        ...     b.sin_()
        ...     # backward
        ...     out.sum().backward()
        ...
        >>> print(b)
        tensor([[0.8415, 0.8415, 0.8415],
                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
    """
    global _allow_mutation_on_saved_tensors_enabled
    ctx = _AllowMutationOnSavedContext()

    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
        try:
            if _allow_mutation_on_saved_tensors_enabled:
                raise RuntimeError(
                    "allow_mutation_on_saved_tensors contexts cannot be nested"
                )
            _allow_mutation_on_saved_tensors_enabled = True
            yield ctx
        finally:
            ctx.clear()
            _allow_mutation_on_saved_tensors_enabled = False
def _register_logging_hooks_on_whole_graph(t_outputs: List[torch.Tensor]):
    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

    def iter_graph(roots):
        # Breadth-first traversal of the autograd graph starting from the root nodes
        seen = set()
        q: Deque = collections.deque(node for node in roots if node is not None)
        seen.update(q)
        while q:
            node = q.popleft()
            for fn, _idx in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)
            yield node

    def fmt(t):
        # Imported here to avoid a circular import
        from torch.testing._internal.common_utils import dtype_abbrs

        if t is None:
            return "None"
        return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"

    def prehook(grad_outputs):
        node = torch._C._current_autograd_node()
        grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
        log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
        log.debug(log_str)

    handles = []
    for node in iter_graph(grad_fns):
        handles.append(node.register_prehook(prehook))

    def unregister_hooks():
        for handle in handles:
            handle.remove()

    return unregister_hooks
def _engine_run_backward(t_outputs, *args, **kwargs):
    attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
    if attach_logging_hooks:
        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    try:
        # Calls into the C++ engine to run the backward pass
        return Variable._execution_engine.run_backward(
            t_outputs, *args, **kwargs
        )
    finally:
        if attach_logging_hooks:
            unregister_hooks()
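
# Illustrative sketch (not part of the original module): per ``_engine_run_backward`` above,
# per-node logging hooks are attached whenever this module's logger is at DEBUG level, so
# enabling the logging looks roughly like this (a handler must also be configured to see output).
def _example_enable_backward_debug_logging():
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger(__name__).setLevel(logging.DEBUG)
    x = torch.randn(3, requires_grad=True)
    (x * 2).sum().backward()  # each executed Node and its grad_outputs are logged at DEBUG level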