import abc
import contextlib
import functools
import logging
import threading
from collections import defaultdict, deque
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    Literal,
    MutableMapping,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import TypeAlias
from weakref import WeakKeyDictionary, WeakValueDictionary

import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle


if TYPE_CHECKING:
    from torch._ops import OpOverload


__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
    "Node",
    "GradientEdge",
    "get_gradient_edge",
    "increment_version",
]


log = logging.getLogger(__name__)


class Node(abc.ABC):
    @abc.abstractmethod
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        raise NotImplementedError

    @abc.abstractmethod
    def metadata(self) -> dict:
        r"""Return the metadata."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def _input_metadata(self) -> List[Any]:
        raise NotImplementedError

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None


        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove() # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        raise NotImplementedError

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        raise NotImplementedError

    @classmethod
    def __subclasshook__(cls, subclass: type) -> bool:
        if cls is Node and (
            (
                subclass is not None
                and subclass is getattr(torch._C._functions, subclass.__name__, None)
            )
            or issubclass(subclass, torch.autograd.function.BackwardCFunction)
        ):
            return True
        return NotImplemented


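# Illustrative sketch (not part of the upstream file): ``Node.next_functions``
# above has no usage example, so this hypothetical helper shows the most common
# way the Node API is consumed -- walking a tensor's autograd graph via
# ``grad_fn``, ``Node.name()`` and ``Node.next_functions``.
def _example_walk_autograd_graph() -> None:
    a = torch.tensor([1.0, 2.0], requires_grad=True)
    loss = (a * a).sum()
    seen = set()
    stack = [loss.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        # Prints e.g. SumBackward0, MulBackward0, torch::autograd::AccumulateGrad
        print(node.name())
        stack.extend(next_node for next_node, _ in node.next_functions)

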
def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node:
    if isinstance(t, GradientEdge):
        return t.node
    if t.requires_grad and t.grad_fn is None:
        with torch.enable_grad():
            node = t.view_as(t).grad_fn.next_functions[0][0]  # type: ignore[union-attr]
    else:
        node = t.grad_fn
    assert node is not None
    return node


class GradientEdge(NamedTuple):
    """Object representing a given gradient edge within the autograd graph.

    To get the gradient edge where a given Tensor gradient will be computed,
    you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
    """

    node: Node
    output_nr: int


def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge:
    """Get the gradient edge for computing the gradient of the given Tensor.

    In particular, calling ``g = autograd.grad(loss, input)`` is equivalent to
    calling ``g = autograd.grad(loss, get_gradient_edge(input))``.
    """
    if not tensor.requires_grad:
        raise RuntimeError(
            "It is not possible to get the gradient edge for a Tensor "
            "that does not require gradients",
        )
    grad_fn = _get_grad_fn_or_grad_acc(tensor)

    # Note that output_nr defaults to 0, which is the right value
    # for the AccumulateGrad node.
    return GradientEdge(grad_fn, tensor.output_nr)


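# Illustrative sketch (not part of the upstream file): the docstring of
# ``get_gradient_edge`` states that ``autograd.grad(loss, input)`` and
# ``autograd.grad(loss, get_gradient_edge(input))`` are equivalent; this
# hypothetical helper spells that equivalence out on a tiny graph.
def _example_gradient_edge_equivalence() -> None:
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    loss = (x * 2).sum()
    edge = get_gradient_edge(x)
    (grad_from_tensor,) = torch.autograd.grad(loss, x, retain_graph=True)
    (grad_from_edge,) = torch.autograd.grad(loss, edge)
    # Both calls produce tensor([2., 2., 2.])
    assert torch.equal(grad_from_tensor, grad_from_edge)

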
def increment_version(tensor: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None:
    """Update autograd metadata tracking whether the given Tensor was modified in place.

    This enables more accurate error checking within the autograd engine.
    It is already done automatically by PyTorch functions and within custom Functions
    when mark_dirty() is called appropriately, so you only need to call this explicitly
    if you are doing an in-place operation on the Tensor data in a way that PyTorch doesn't
    know about; for example, a custom kernel that reads the Tensor's data_ptr and modifies
    the memory in place based on this pointer. Accepts either a tensor or a list of tensors.

    Note that incrementing the version counter multiple times for a single in-place operation
    is not problematic.

    Note that if you pass in a tensor constructed under torch.inference_mode(),
    we will not bump its version counter (because your tensor does not have one).
    """
    if isinstance(tensor, torch.Tensor):
        tensor = (tensor,)
    torch._C._increment_version(tensor)


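# Illustrative sketch (not part of the upstream file): a hypothetical situation
# where ``increment_version`` is needed -- the tensor's memory is written
# through a NumPy view, which bypasses the in-place tracking PyTorch performs
# for its own operators.
def _example_increment_version() -> None:
    x = torch.zeros(3)
    x.numpy()[0] = 1.0  # mutation the PyTorch dispatcher never sees
    increment_version(x)  # manually bump the version counter afterwards

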
class saved_tensors_hooks:
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

    Use this context-manager to define how intermediary results of an operation
    should be packed before saving, and unpacked on retrieval.

    In that context, the ``pack_hook`` function will be called every time an
    operation saves a tensor for backward (this includes intermediary results
    saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation). The output of
    ``pack_hook`` is then stored in the computation graph instead of the
    original tensor.

    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
    namely when executing :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
    returned by ``pack_hook`` and should return a tensor which has the same
    content as the original tensor (passed as input to the corresponding
    ``pack_hook``).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.

    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
    of value, size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning ::
        Performing an in-place operation on the input to either hook may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """

    def __init__(
        self,
        pack_hook: Callable[[torch.Tensor], Any],
        unpack_hook: Callable[[Any], torch.Tensor],
    ) -> None:
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self) -> None:
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.pack_hook, self.unpack_hook
        )

    def __exit__(self, *args: object) -> None:
        torch._C._autograd._pop_saved_tensors_default_hooks()


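# Illustrative sketch (not part of the upstream file): the pack/unpack contract
# documented above can be used purely for observation. This hypothetical helper
# records how much data the forward pass saves for backward without changing
# what is stored.
def _example_count_saved_tensors() -> None:
    saved_bytes = []

    def pack_hook(t: torch.Tensor) -> torch.Tensor:
        saved_bytes.append(t.numel() * t.element_size())
        return t  # store the tensor unchanged

    def unpack_hook(t: torch.Tensor) -> torch.Tensor:
        return t

    a = torch.randn(16, 16, requires_grad=True)
    with saved_tensors_hooks(pack_hook, unpack_hook):
        out = (a.sin() * a).sum()
    out.backward()
    print(f"saved {len(saved_bytes)} tensors, {sum(saved_bytes)} bytes total")

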
class save_on_cpu(saved_tensors_hooks):
    """Context manager under which tensors saved by the forward pass will be stored on CPU, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.


    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b           # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a           # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward
    """

    def __init__(self, pin_memory: bool = False, device_type: str = "cuda") -> None:
        device_module = getattr(torch, device_type, torch.cuda)

        def pack_to_cpu(tensor: torch.Tensor) -> Tuple[torch.device, torch.Tensor]:
            if not pin_memory:
                return (tensor.device, tensor.cpu())
            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(device_module.is_available() and not tensor.is_sparse),
            )
            packed.copy_(tensor)
            return (tensor.device, packed)

        def unpack_from_cpu(packed: Tuple[torch.device, torch.Tensor]) -> torch.Tensor:
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)


@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, None]:
    """Context-manager that disables the saved tensors default hooks feature.

    Useful if you are creating a feature that does not work with saved
    tensors default hooks.

    Args:
        error_message (str): When saved tensors default hooks are used while they
                             are disabled, a RuntimeError with this
                             error message gets raised.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass
    """
    maybe_prev_message = None
    try:
        maybe_prev_message = (
            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        )
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)


class _MultiHandle(RemovableHandle):
    handles: Tuple[RemovableHandle, ...]

    def __init__(self, handles: Tuple[RemovableHandle, ...]) -> None:
        self.handles = handles

    def remove(self) -> None:
        for handle in self.handles:
            handle.remove()

    def __getstate__(self) -> Tuple[RemovableHandle, ...]:
        return self.handles

    def __setstate__(self, state: Tuple[RemovableHandle, ...]) -> None:
        self.handles = state


def register_multi_grad_hook(
    tensors: Sequence[torch.Tensor],
    fn: Union[
        Callable[[Sequence[Optional[torch.Tensor]]], None],
        Callable[[torch.Tensor], None],
    ],
    *,
    mode: Literal["all", "any"] = "all",
) -> RemovableHandle:
    r"""Register a multi-grad backward hook.

    There are two supported modes: ``"all"`` and ``"any"``.

    Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
    :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
    is not part of the graph, or if a tensor is not needed to compute the gradients
    for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
    this tensor will be ignored and the hook will not wait for its gradient to be
    computed.

    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
    called with those gradients. ``None`` will be passed for tensors that did not
    have their gradients computed.

    Under the ``"any"`` mode, the hook will be called after the first gradient
    with respect to a tensor in :attr:`tensors` has been computed. The hook
    will be called with that gradient as its argument.

    The hook should not modify its arguments.

    This function returns a handle with a method ``handle.remove()`` that removes the hook.

    .. note::
        See :ref:`backward-hooks-execution` for more information on when this hook
        is executed, and how its execution is ordered relative to other hooks.

    Example::

        >>> import torch
        >>>
        >>> a = torch.rand(2, 3, requires_grad=True)
        >>> b = torch.rand(2, 3, requires_grad=True)
        >>> c = a * b
        >>> d = a * b
        >>>
        >>> def fn(grads):
        ...     print([g is not None for g in grads])
        ...
        >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
        >>>
        >>> c.sum().backward(retain_graph=True)
        [True, True, True, False]
        >>> c.sum().backward(inputs=(a,), retain_graph=True)
        [True, False, True, False]
        >>>
    """
    supported_modes = ("all", "any")
    lock = threading.Lock()

    if mode not in supported_modes:
        raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")

    if mode == "all":
        count: Dict[int, int] = {}
        nb_calls = None
        buffer: Dict[int, List[Optional[torch.Tensor]]] = {}

        grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
        len_tensors = len(tensors)

        def get_inner_hook(idx: int) -> Callable[[torch.Tensor], None]:
            def inner_hook(grad: torch.Tensor) -> None:
                nonlocal count, nb_calls, buffer, fn
                id = torch._C._current_graph_task_id()
                assert (
                    id != -1
                ), "expected this hook to be called inside a backward call"
                count[id] = count.get(id, 0)
                buffer[id] = buffer.get(id, [None] * len_tensors)

                with lock:
                    curr_count, count[id] = count[id], count[id] + 1

                    if curr_count == 0:
                        # On the first call, compute the actual nb_calls and buffer
                        nb_calls = sum(
                            map(torch._C._will_engine_execute_node, grad_fns)
                        )

                buffer[id][idx] = grad

                assert nb_calls is not None
                if curr_count == nb_calls - 1:
                    fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
                    fn(buffer[id])
                    del count[id]
                    del buffer[id]

            return inner_hook

        handles = tuple(
            t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
        )
    elif mode == "any":
        fn = cast(Callable[[torch.Tensor], None], fn)
        ran_hook: Dict[int, bool] = defaultdict(bool)

        @functools.wraps(fn)
        def wrapped_fn(grad: torch.Tensor) -> None:
            nonlocal ran_hook
            id = torch._C._current_graph_task_id()
            assert id != -1, "expected this hook to be called inside a backward call"
            with lock:
                prev, ran_hook[id] = ran_hook[id], True
            if prev:
                return
            fn(grad)

        handles = tuple(
            tensor.register_hook(wrapped_fn)
            for tensor in tensors
            if tensor.requires_grad
        )

    return _MultiHandle(handles)  # type: ignore[possibly-undefined]


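# Illustrative sketch (not part of the upstream file): the docstring above only
# demonstrates the default ``"all"`` mode, so this hypothetical snippet shows
# ``mode="any"``, where the hook fires once, on the first gradient computed
# among the watched tensors.
def _example_any_mode_multi_grad_hook() -> None:
    a = torch.rand(2, 3, requires_grad=True)
    b = torch.rand(2, 3, requires_grad=True)
    c = a * b

    def fn(grad: torch.Tensor) -> None:
        print("first gradient computed, shape:", tuple(grad.shape))

    handle = register_multi_grad_hook((a, b), fn, mode="any")
    c.sum().backward()
    handle.remove()

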
# NOTE [Allow mutation on tensors saved for backward]
#
# 1. Tensor gets saved for backward
#    - remember the python object id and the version of the tensor
#    - remember aliasing information (data_ptr of base + version)
#    - save the original so we control its lifetime
# 2. Any time a tensor gets in-placed
#    - for each tensor aliased to it:
#      - check using its object id and version to see if it has been saved
#      - if it has been saved, clone it
#      - delete the reference to the original
# 3. during backward
#    - if the clone exists, the tensor must've been modified in-place
_allow_mutation_on_saved_tensors_enabled: bool = False


_TID: TypeAlias = Tuple[int, int, int]
_SID: TypeAlias = Tuple[int, int]


def _get_tid(tensor: torch.Tensor) -> _TID:
    # FIXME: This is almost definitely a bug.
    if isinstance(
        tensor,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = tensor.data_ptr()
    return (id(tensor), data_ptr, tensor._version)


def _get_sid(tensor: torch.Tensor) -> _SID:
    # FIXME: This is almost definitely a bug.
    if isinstance(
        tensor,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = tensor.data_ptr()
    return (data_ptr, tensor._version)


class _Handle:
    pass


class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None:
        def pack_hook(tensor: torch.Tensor) -> _Handle:
            tid = _get_tid(tensor)
            sid = _get_sid(tensor)
            # Tensors saved for backward have an entry in _tid_to_weakhandle
            handle: Optional[_Handle] = None

            # Save aliasing information
            ctx.sid_to_tid[sid].add(tid)

            # NB: The same tensor (of the same version) can be saved multiple times
            if tid not in ctx.tid_to_weakhandle:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = tensor
            else:
                # Store an additional strong reference to the handle
                handle = ctx.tid_to_weakhandle[tid]
            return handle

        def unpack_hook(handle: _Handle) -> torch.Tensor:
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded."
            )
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            if handle in ctx.cloned:
                res = ctx.cloned[handle]
            else:
                assert handle in ctx.original, error_msg
                res = ctx.original[handle]
            return res

        super().__init__(pack_hook, unpack_hook)


class _CloneArgBeforeMutateMode(TorchDispatchMode):
    def __init__(self, ctx: "_AllowMutationOnSavedContext") -> None:
        self.ctx = ctx

    def __torch_dispatch__(
        self,
        func: "OpOverload",
        types: Iterable[type],
        args: Tuple[Any, ...] = (),
        kwargs: Optional[Dict[Any, Any]] = None,
    ) -> Any:
        kwargs = kwargs or {}

        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is not None and arg.alias_info.is_write:
                t = kwargs["out"] if arg.is_out else args[idx]
                tid = _get_tid(t)
                sid = _get_sid(t)
                ctx = self.ctx
                if sid in ctx.sid_to_tid:
                    for tid in ctx.sid_to_tid[sid]:
                        if tid not in ctx.tid_to_weakhandle:
                            # We know that if tid is in sid_to_tid, then it must also be in
                            # tid_to_weakhandle. However, it is possible for the tensor to be
                            # saved at one point, but cleared by backward before it is modified
                            # in-place. Consider the following example:
                            #
                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
                            # >>> out = (a**2).sum()
                            # >>> out.backward()
                            # >>> a.sin_()
                            continue
                        handle = ctx.tid_to_weakhandle[tid]
                        if handle in ctx.cloned:
                            # The same exact tensor has been cloned already
                            continue
                        ctx.cloned[handle] = ctx.original[handle].clone()
                        del ctx.original[handle]

        return func(*args, **kwargs)


class _AllowMutationOnSavedContext:
    def __init__(self) -> None:
        self.cloned: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
        self.original: MutableMapping[_Handle, torch.Tensor] = WeakKeyDictionary()
        self.tid_to_weakhandle: MutableMapping[_TID, _Handle] = WeakValueDictionary()
        self.sid_to_tid: Dict[_SID, Set[_TID]] = defaultdict(set)

    def clear(self) -> None:
        self.cloned.clear()
        self.original.clear()
        self.tid_to_weakhandle.clear()
        self.sid_to_tid.clear()


@contextlib.contextmanager
def allow_mutation_on_saved_tensors() -> (
    Generator[_AllowMutationOnSavedContext, None, None]
):
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    To ensure the correct behavior, both the forward and backward should be run under
    the same context manager.

    Returns:
        An _AllowMutationOnSavedContext object storing the state managed by this
        context manager. This object can be useful for debugging purposes. The state
        managed by the context manager is automatically cleared upon exiting.

    Example::

        >>> import torch
        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
        ...     # forward
        ...     a = torch.ones(2, 3, requires_grad=True)
        ...     b = a.clone()
        ...     out = (b**2).sum()
        ...     b.sin_()
        ...     # backward
        ...     out.sum().backward()
        ...
        tensor([[0.8415, 0.8415, 0.8415],
                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
    """
    global _allow_mutation_on_saved_tensors_enabled

    ctx = _AllowMutationOnSavedContext()

    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
        try:
            if _allow_mutation_on_saved_tensors_enabled:
                raise RuntimeError(
                    "allow_mutation_on_saved_tensors contexts cannot be nested"
                )
            _allow_mutation_on_saved_tensors_enabled = True
            yield ctx
        finally:
            ctx.clear()
            _allow_mutation_on_saved_tensors_enabled = False


def _register_logging_hooks_on_whole_graph(
    t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
) -> Callable[[], None]:
    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

    def iter_graph(roots: List[Node]) -> Iterator[Node]:
        if not roots:
            return
        seen: Set[Node] = set()
        q: Deque[Node] = deque()
        for node in roots:
            if node is not None:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _ in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)

            yield node

    def fmt(t: Optional[torch.Tensor]) -> str:
        # Avoid circular import
        from torch.testing._internal.common_utils import dtype_abbrs

        if t is None:
            return "None"
        return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"

    def prehook(grad_outputs: Sequence[Optional[torch.Tensor]]) -> None:
        node = torch._C._current_autograd_node()
        grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
        log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
        log.debug(log_str)

    handles = []
    for node in iter_graph(grad_fns):
        handles.append(node.register_prehook(prehook))

    def unregister_hooks() -> None:
        for handle in handles:
            handle.remove()

    return unregister_hooks


def _engine_run_backward(
    t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
    *args: Any,
    **kwargs: Any,
) -> Tuple[torch.Tensor, ...]:
    attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
    if attach_logging_hooks:
        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    try:
        return Variable._execution_engine.run_backward(
            t_outputs, *args, **kwargs
        )  # Calls into the C++ engine to run the backward pass
    finally:
        if attach_logging_hooks:
            unregister_hooks()  # type: ignore[possibly-undefined]

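
# Illustrative sketch (not part of the upstream file): ``_engine_run_backward``
# attaches the logging prehooks only when this module's logger is at DEBUG
# level, so a hypothetical way to see per-Node execution logs is:
def _example_enable_backward_logging() -> None:
    logging.basicConfig(level=logging.DEBUG)
    log.setLevel(logging.DEBUG)
    x = torch.randn(3, requires_grad=True)
    (x * 2).sum().backward()
    # Each executed Node is logged as
    # "Executing: <node> with grad_outputs: [...]".
    log.setLevel(logging.WARNING)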