import abc
import collections
import contextlib
import functools
import logging
import threading
import weakref
from collections import defaultdict, namedtuple
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle

log = logging.getLogger(__name__)


__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
    "Node",
    "GradientEdge",
    "get_gradient_edge",
    "increment_version",
]


class Node(abc.ABC):
    @abc.abstractmethod
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        ...

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        ...

    @abc.abstractmethod
    def metadata(self) -> dict:
        r"""Return the metadata."""
        ...

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        ...

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove() # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @classmethod
    def __subclasshook__(cls, C):
        if cls is Node:
            if (
                C is not None and C is getattr(torch._C._functions, C.__name__, None)
            ) or issubclass(C, torch.autograd.function.BackwardCFunction):
                return True
        return NotImplemented


def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        # Leaf tensors have no grad_fn; materialize the AccumulateGrad node by
        # taking a trivial view and following the view's grad_fn back to it.
        return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn


GradientEdge = namedtuple("GradientEdge", ("node output_nr"))
GradientEdge.__doc__ = """\
Object representing a given gradient edge within the autograd graph.
To get the gradient edge where a given Tensor gradient will be computed,
you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
"""


def get_gradient_edge(tensor):
    """Get the gradient edge for computing the gradient of the given Tensor.

    In particular, calling ``g = autograd.grad(loss, input)`` is equivalent to calling
    ``g = autograd.grad(loss, get_gradient_edge(input))``.
    """
    if not tensor.requires_grad:
        raise RuntimeError(
            "It is not possible to get the gradient edge for a Tensor that does not require gradients"
        )
    grad_fn = _get_grad_fn_or_grad_acc(tensor)

    # Note that output_nr defaults to 0, which is the right value
    # for the AccumulateGrad node.
    return GradientEdge(grad_fn, tensor.output_nr)


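# Usage sketch based on the docstring above (illustrative, not part of the public docs):
# a GradientEdge can be passed to ``torch.autograd.grad`` in place of the tensor it was
# taken from, e.g.::
#
#     >>> x = torch.randn(3, requires_grad=True)
#     >>> loss = (x * 2).sum()
#     >>> edge = torch.autograd.graph.get_gradient_edge(x)
#     >>> torch.autograd.grad(loss, edge)  # same gradient as torch.autograd.grad(loss, x)

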
def increment_version(tensor):
    """Update autograd metadata tracking whether the given Tensor was modified in place.

    This enables more accurate error checking within the autograd engine.
    It is already done automatically by PyTorch functions and within custom Functions
    when mark_dirty() is called appropriately, so you only need to call this explicitly
    if you modify the Tensor data in place in a way that PyTorch doesn't know about,
    for example with a custom kernel that reads the Tensor's data_ptr and modifies
    the memory in place based on this pointer.

    Note that incrementing the version counter multiple times for a single inplace operation
    is not problematic.
    """
    torch._C._increment_version(tensor)

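# Usage sketch (illustrative; ``my_custom_inplace_kernel`` is a hypothetical routine that
# writes through the data pointer without going through PyTorch ops)::
#
#     >>> buf = torch.zeros(4)
#     >>> my_custom_inplace_kernel(buf.data_ptr())  # out-of-band in-place write
#     >>> torch.autograd.graph.increment_version(buf)  # keep the version counter accurate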

class saved_tensors_hooks:
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

    Use this context-manager to define how intermediary results of an operation
    should be packed before saving, and unpacked on retrieval.

    In that context, the ``pack_hook`` function will be called every time an
    operation saves a tensor for backward (this includes intermediary results
    saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation). The output of
    ``pack_hook`` is then stored in the computation graph instead of the
    original tensor.

    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
    namely when executing :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
    returned by ``pack_hook`` and should return a tensor which has the same
    content as the original tensor (passed as input to the corresponding
    ``pack_hook``).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.

    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
    of value, size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning ::
        Performing an inplace operation on the input to either hook may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """

    def __init__(
        self,
        pack_hook: Callable[[torch.Tensor], Any],
        unpack_hook: Callable[[Any], torch.Tensor],
    ):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.pack_hook, self.unpack_hook
        )

    def __exit__(self, *args: object):
        torch._C._autograd._pop_saved_tensors_default_hooks()


class save_on_cpu(saved_tensors_hooks):
    """Context manager under which tensors saved by the forward pass will be stored on CPU, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.


    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b           # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a           # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward

    """

    def __init__(self, pin_memory=False, device_type="cuda"):
        device_module = getattr(torch, device_type, torch.cuda)

        def pack_to_cpu(tensor):
            if not pin_memory:
                return (tensor.device, tensor.cpu())
            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(device_module.is_available() and not tensor.is_sparse),
            )
            packed.copy_(tensor)
            return (tensor.device, packed)

        def unpack_from_cpu(packed):
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)


@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
    """Context-manager that disables the saved tensors default hooks feature.

    Useful if you are creating a feature that does not work with saved
    tensors default hooks.

    Args:
        error_message (str): When saved tensors default hooks are used while they
                             are disabled, a RuntimeError with this error message
                             is raised.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass

    """
    try:
        maybe_prev_message = (
            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        )
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)


def register_multi_grad_hook(
    tensors: Sequence[torch.Tensor],
    fn: Union[
        Callable[[Sequence[Optional[torch.Tensor]]], None],
        Callable[[torch.Tensor], None],
    ],
    *,
    mode: str = "all",
):
    r"""Register a multi-grad backward hook.

    There are two supported modes: ``"all"`` and ``"any"``.

    Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
    :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
    is not part of the graph, or if a tensor is not needed to compute the gradients
    for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
    this tensor will be ignored and the hook will not wait for its gradient to be
    computed.

    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
    called with those gradients. ``None`` will be passed for tensors that did not
    have their gradients computed.

    Under the ``"any"`` mode, the hook will be called after the first gradient
    with respect to a tensor in :attr:`tensors` has been computed. The hook
    will be called with that gradient as its argument.

    The hook should not modify its arguments.

    This function returns a handle with a method ``handle.remove()`` that removes the hook.

    .. note::
        See :ref:`backward-hooks-execution` for more information on when this hook
        is executed, and how its execution is ordered relative to other hooks.

    Example::

        >>> import torch
        >>>
        >>> a = torch.rand(2, 3, requires_grad=True)
        >>> b = torch.rand(2, 3, requires_grad=True)
        >>> c = a * b
        >>> d = a * b
        >>>
        >>> def fn(grads):
        ...     print([g is not None for g in grads])
        ...
        >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
        >>>
        >>> c.sum().backward(retain_graph=True)
        [True, True, True, False]
        >>> c.sum().backward(inputs=(a,), retain_graph=True)
        [True, False, True, False]
        >>>
    """
    supported_modes = ("all", "any")
    if mode not in supported_modes:
        raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")

    class Handle(RemovableHandle):
        handles: Tuple[RemovableHandle, ...]

        def __init__(self, handles: Tuple[RemovableHandle, ...]):
            self.handles = handles

        def remove(self):
            for handle in self.handles:
                handle.remove()

        def __getstate__(self):
            return self.handles

        def __setstate__(self, state):
            self.handles = state

    if mode == "all":
        count: Dict[int, int] = dict()
        nb_calls = None
        buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()

        grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
        len_tensors = len(tensors)

        def get_inner_hook(idx):
            def inner_hook(grad: torch.Tensor):
                nonlocal count, nb_calls, buffer, fn
                id = torch._C._current_graph_task_id()
                assert (
                    id != -1
                ), "expected this hook to be called inside a backward call"
                count[id] = count.get(id, 0)
                buffer[id] = buffer.get(id, [None] * len_tensors)

                if count[id] == 0:
                    # On the first call, compute the actual nb_calls and buffer
                    nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns)  # type: ignore[attr-defined]

                buffer[id][idx] = grad
                count[id] += 1

                if count[id] == nb_calls:
                    fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
                    fn(buffer[id])
                    del count[id]
                    del buffer[id]

            return inner_hook

        handles: Tuple[RemovableHandle] = tuple(
            t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
        )
    elif mode == "any":
        fn = cast(Callable[[torch.Tensor], None], fn)
        lock = threading.Lock()
        ran_hook: Dict[int, bool] = defaultdict(bool)

        @functools.wraps(fn)
        def wrapped_fn(grad: torch.Tensor):
            nonlocal ran_hook
            id = torch._C._current_graph_task_id()
            assert id != -1, "expected this hook to be called inside a backward call"
            with lock:
                prev, ran_hook[id] = ran_hook[id], True
            if prev:
                return
            fn(grad)

        handles = tuple(
            tensor.register_hook(wrapped_fn)
            for tensor in tensors
            if tensor.requires_grad
        )

    return Handle(handles)  # type: ignore[possibly-undefined]

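# Illustrative sketch of ``mode="any"`` (complementing the ``"all"``-mode example in the
# docstring above): the hook fires once per backward call, with the first gradient computed::
#
#     >>> a = torch.rand(2, 3, requires_grad=True)
#     >>> b = torch.rand(2, 3, requires_grad=True)
#     >>> handle = torch.autograd.graph.register_multi_grad_hook(
#     ...     (a, b), lambda grad: print("got a grad of shape", grad.shape), mode="any"
#     ... )
#     >>> (a * b).sum().backward()
#     got a grad of shape torch.Size([2, 3])
#     >>> handle.remove()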

# NOTE [Allow mutation on tensors saved for backward]
#
# 1. Tensor gets saved for backward
#    - remember the python object id and the version of the tensor
#    - remember aliasing information (data_ptr of base + version)
#    - save the original so we control its lifetime
# 2. Any time a tensor gets in-placed
#    - for each tensor aliased to it:
#      - check using its object id and version to see if it has been saved
#      - if it has been saved, clone it
#      - delete the reference to the original
# 3. during backward
#    - if the clone exists, the tensor must've been modified in-place
_allow_mutation_on_saved_tensors_enabled = False


def _get_tid(t) -> Tuple[int, int, int]:
    return (id(t), t.data_ptr(), t._version)


def _get_sid(t) -> Tuple[int, int]:
    return (t.data_ptr(), t._version)

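# Illustrative note on the two identifiers above (a sketch, not normative): ``_get_tid``
# identifies one particular saved tensor (Python object + data pointer + version), while
# ``_get_sid`` identifies the underlying storage so aliases of the same memory match::
#
#     >>> base = torch.zeros(4)
#     >>> view = base[:2]                   # shares base's storage at the same offset
#     >>> _get_tid(base) == _get_tid(view)  # False: distinct Python objects
#     >>> _get_sid(base) == _get_sid(view)  # True: same data_ptr and same version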

class _Handle:
    pass


class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx):
        def pack_hook(t):
            tid = _get_tid(t)
            sid = _get_sid(t)
            # Tensors saved for backward have an entry in _tid_to_weakhandle
            handle: Optional[_Handle] = None

            # Save aliasing information
            ctx.sid_to_tid[sid].add(tid)

            # NB: The same tensor (of the same version) can be saved multiple times
            if tid not in ctx.tid_to_weakhandle:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = t
            else:
                # Store an additional strong reference to the handle
                handle = ctx.tid_to_weakhandle[tid]
            return handle

        def unpack_hook(tup):
            handle = tup
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded."
            )
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            if handle in ctx.cloned:
                res = ctx.cloned[handle]
            else:
                assert handle in ctx.original, error_msg
                res = ctx.original[handle]
            return res

        super().__init__(pack_hook, unpack_hook)


class _CloneArgBeforeMutateMode(TorchDispatchMode):
    def __init__(self, ctx):
        self.ctx = ctx

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is not None and arg.alias_info.is_write:
                t = kwargs["out"] if arg.is_out else args[idx]
                tid = _get_tid(t)
                sid = _get_sid(t)
                ctx = self.ctx
                if sid in ctx.sid_to_tid:
                    for tid in ctx.sid_to_tid[sid]:
                        if tid not in ctx.tid_to_weakhandle:
                            # We know that if tid is in sid_to_tid, then it must also be in
                            # tid_to_weakhandle. However, it is possible for the tensor to be
                            # saved at one point, but cleared by backward before it is modified
                            # in-place. Consider the following example:
                            #
                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
                            # >>> out = (a**2).sum()
                            # >>> out.backward()
                            # >>> a.sin_()
                            continue
                        handle = ctx.tid_to_weakhandle[tid]
                        if handle in ctx.cloned:
                            # The same exact tensor has been cloned already
                            continue
                        ctx.cloned[handle] = ctx.original[handle].clone()
                        del ctx.original[handle]

        rs = func(*args, **kwargs)
        return rs


class _AllowMutationOnSavedContext:
    def __init__(self):
        self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.tid_to_weakhandle: weakref.WeakValueDictionary = (
            weakref.WeakValueDictionary()
        )
        self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(
            set
        )

    def clear(self):
        self.cloned.clear()
        self.original.clear()
        self.tid_to_weakhandle.clear()
        self.sid_to_tid.clear()


@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    To ensure the correct behavior, both the forward and backward should be run under
    the same context manager.

    Returns:
        An _AllowMutationOnSavedContext object storing the state managed by this
        context manager. This object can be useful for debugging purposes. The state
        managed by the context manager is automatically cleared upon exiting.

    Example::

        >>> import torch
        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
        ...     # forward
        ...     a = torch.ones(2, 3, requires_grad=True)
        ...     b = a.clone()
        ...     out = (b**2).sum()
        ...     b.sin_()
        ...     # backward
        ...     out.sum().backward()
        ...
        tensor([[0.8415, 0.8415, 0.8415],
                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
    """
    global _allow_mutation_on_saved_tensors_enabled

    ctx = _AllowMutationOnSavedContext()

    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
        try:
            if _allow_mutation_on_saved_tensors_enabled:
                raise RuntimeError(
                    "allow_mutation_on_saved_tensors contexts cannot be nested"
                )
            _allow_mutation_on_saved_tensors_enabled = True
            yield ctx
        finally:
            ctx.clear()
            _allow_mutation_on_saved_tensors_enabled = False


def _register_logging_hooks_on_whole_graph(t_outputs: List[torch.Tensor]):
    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

    def iter_graph(roots):
        if not roots:
            return
        seen = set()
        q: Deque = collections.deque()
        for node in roots:
            if node is not None:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _idx in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)

            yield node

    def fmt(t):
        # Avoid circular import
        from torch.testing._internal.common_utils import dtype_abbrs

        if t is None:
            return "None"
        return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"

    def prehook(grad_outputs):
        node = torch._C._current_autograd_node()
        grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
        log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
        log.debug(log_str)

    handles = []
    for node in iter_graph(grad_fns):
        handles.append(node.register_prehook(prehook))

    def unregister_hooks():
        for handle in handles:
            handle.remove()

    return unregister_hooks


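# Illustrative note (an assumption about typical usage, not an official API): the
# debug-logging path below is driven purely by the standard logging level of this
# module's logger, e.g.::
#
#     >>> import logging
#     >>> logging.getLogger("torch.autograd.graph").setLevel(logging.DEBUG)
#     >>> # subsequent backward calls log "Executing: <node> with grad_outputs: ..." per Node

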
def _engine_run_backward(t_outputs, *args, **kwargs):
    attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
    if attach_logging_hooks:
        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    try:
        # Calls into the C++ engine to run the backward pass
        return Variable._execution_engine.run_backward(
            t_outputs, *args, **kwargs
        )
    finally:
        if attach_logging_hooks:
            unregister_hooks()  # type: ignore[possibly-undefined]