# Source: pytorch (fork) — checkpoint.py
# 1439 lines · 58.9 KB
import contextlib
import platform
import uuid
import warnings
import weakref
from collections import defaultdict
from itertools import count
from typing import (
    Any,
    Callable,
    ContextManager,
    DefaultDict,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
)
from weakref import ReferenceType

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch._aot_autograd.functional_utils import is_fun
from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
# Public names exported by ``from <this module> import *``.
__all__ = [
    "checkpoint",
    "checkpoint_sequential",
    "CheckpointError",
    "CheckpointFunction",
    "check_backward_validity",
    "detach_variable",
    "get_device_states",
    "set_device_states",
    "noop_context_fn",
    "set_checkpoint_early_stop",
    "DefaultDeviceType",
    "set_checkpoint_debug_enabled",
]

# Value of ``determinism_check`` used by ``checkpoint`` when the caller does
# not pass one explicitly.
_DEFAULT_DETERMINISM_MODE = "default"
# Process-wide override for checkpoint's ``debug`` flag. ``None`` means "no
# override": each checkpoint call uses its own ``debug`` argument.
_checkpoint_debug_enabled: Optional[bool] = None


@contextlib.contextmanager
def set_checkpoint_debug_enabled(enabled: Optional[bool]):
    """Context manager that overrides checkpoint's ``debug`` flag while active.

    When ``enabled`` is ``True`` or ``False`` it takes precedence over the
    ``debug`` argument passed to :func:`~torch.utils.checkpoint.checkpoint`
    (see that argument for what debug mode reports). Pass ``None`` to defer
    to each call's own ``debug`` argument. The previous setting is restored
    on exit, including when the body raises.

    Args:
        enabled (Optional[bool]): The override to apply, or ``None`` for no
            override. This argument is required.
    """
    global _checkpoint_debug_enabled
    saved = _checkpoint_debug_enabled
    _checkpoint_debug_enabled = enabled
    try:
        yield
    finally:
        # Always restore, so nested or failing contexts never leak a setting.
        _checkpoint_debug_enabled = saved
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[Any, ...]:
    """Detach every tensor in ``inputs`` from the autograd graph.

    Each tensor is replaced by ``tensor.detach()`` with the original
    ``requires_grad`` flag copied over, so the result is a fresh leaf that can
    accumulate its own ``.grad``. Non-tensor entries are passed through
    unchanged and the original order is preserved.

    Args:
        inputs: A tuple whose tensor elements should be detached.

    Returns:
        A new tuple with the same length and order as ``inputs``.

    Raises:
        RuntimeError: If ``inputs`` is not a tuple.
    """
    if not isinstance(inputs, tuple):
        # Build one message string: passing the type name as a second
        # exception argument made ``str(err)`` render as a tuple.
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: "
            f"{type(inputs).__name__}"
        )

    def _detach(inp: Any) -> Any:
        if not isinstance(inp, torch.Tensor):
            return inp
        detached = inp.detach()
        detached.requires_grad = inp.requires_grad
        return detached

    return tuple(_detach(inp) for inp in inputs)
def check_backward_validity(inputs: Iterable[Any]) -> None:
    """Warn when no tensor in ``inputs`` requires grad.

    Non-tensor entries are ignored. If none of the tensor entries has
    ``requires_grad=True`` (or there are no tensors at all), a warning is
    emitted because the checkpointed region will then produce no gradients.
    """
    tensors_needing_grad = (
        item
        for item in inputs
        if isinstance(item, torch.Tensor) and item.requires_grad
    )
    # ``next`` stops at the first match, just like ``any`` would.
    if next(tensors_needing_grad, None) is None:
        warnings.warn(
            "None of the inputs have requires_grad=True. Gradients will be None"
        )
def _get_device_module(device="cuda"):
97
    device_module = getattr(torch, device)
98
    return device_module
99

100

101
class DefaultDeviceType:
    r"""
    Process-wide fallback device type used by checkpointing.

    Checkpointing uses this device type when none of the tensors it inspects
    live on a non-CPU device. The device type decides which device's state is
    saved and restored for recomputation. The initial value is ``"cuda"``.
    """

    # Shared class-level setting; changed only through ``set_device_type``.
    _default_device_type = "cuda"

    @staticmethod
    def set_device_type(device: str = "cuda"):
        """
        Change the fallback device type used for checkpointing.

        Args:
            device (str): Device type to use from now on. Default is 'cuda'.
        """
        DefaultDeviceType._default_device_type = device

    @staticmethod
    def get_device_type() -> str:
        """
        Report the fallback device type used for checkpointing.

        Returns:
            str: The device type most recently set (initially ``"cuda"``).
        """
        return DefaultDeviceType._default_device_type
def _infer_device_type(*args):
135
    device_types = list(
136
        {
137
            arg.device.type
138
            for arg in args
139
            if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
140
        }
141
    )
142
    if len(device_types) > 1:
143
        warnings.warn(
144
            "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
145
            "Device state will only be saved for devices of a single device type, and the remaining "
146
            "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
147
            "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
148
            "detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
149
        )
150
    if len(device_types) == 0:
151
        return DefaultDeviceType.get_device_type()
152
    elif "cuda" in device_types:
153
        return "cuda"
154
    else:
155
        return device_types[0]
156

157

158
# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases.  Compromise:  Stash the RNG state for
# the device of all Tensor args.
#
# To consider:  maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    """Snapshot the RNG state of every non-CPU device used by a tensor in ``args``.

    Returns:
        ``(device_ids, states)``: the distinct device indices of the non-CPU
        tensor arguments and, at the same positions, the RNG state captured on
        each of those devices with the device module chosen by
        ``_infer_device_type(*args)``.
    """
    # Non-tensors and CPU tensors are filtered out before ``get_device`` is
    # ever called, so they cannot cause an error here.
    unique_ids = set()
    for candidate in args:
        if isinstance(candidate, torch.Tensor) and candidate.device.type != "cpu":
            unique_ids.add(candidate.get_device())
    device_ids = list(unique_ids)

    rng_states = []
    module = _get_device_module(_infer_device_type(*args))
    for device_id in device_ids:
        with module.device(device_id):
            rng_states.append(module.get_rng_state())

    return device_ids, rng_states
def set_device_states(devices, states, *, device_type: Optional[str] = None) -> None:
    """Restore per-device RNG states captured by :func:`get_device_states`.

    Args:
        devices: Device indices, as returned by ``get_device_states``.
        states: RNG state tensors, in the same order as ``devices``.
        device_type: Device type (e.g. ``"cuda"``) whose ``torch.<device_type>``
            module should restore the states. Pass the device type in effect
            when the states were captured. If ``None`` (the default), the type
            is inferred from ``states`` as in earlier versions.
    """
    if device_type is None:
        # NOTE(review): RNG state tensors returned by ``get_rng_state`` are
        # typically CPU byte tensors, so this inference usually falls back to
        # ``DefaultDeviceType`` rather than the device the states came from.
        # Callers that know the device type should pass ``device_type``.
        device_type = _infer_device_type(*states)
    device_module = _get_device_module(device_type)
    for device, state in zip(devices, states):
        with device_module.device(device):
            device_module.set_rng_state(state)
def _get_autocast_kwargs(device="cuda"):
194
    if device == "cuda":
195
        device_autocast_kwargs = {
196
            "enabled": torch.is_autocast_enabled(),
197
            "dtype": torch.get_autocast_gpu_dtype(),
198
            "cache_enabled": torch.is_autocast_cache_enabled(),
199
        }
200
    elif _supports_autocast(device):
201
        device_module = _get_device_module(device)
202
        device_autocast_kwargs = {
203
            "enabled": device_module.is_autocast_enabled(),
204
            "dtype": device_module.get_autocast_dtype(),
205
            "cache_enabled": torch.is_autocast_cache_enabled(),
206
        }
207
    else:
208
        device_autocast_kwargs = None
209

210
    cpu_autocast_kwargs = {
211
        "enabled": torch.is_autocast_cpu_enabled(),
212
        "dtype": torch.get_autocast_cpu_dtype(),
213
        "cache_enabled": torch.is_autocast_cache_enabled(),
214
    }
215

216
    return device_autocast_kwargs, cpu_autocast_kwargs
217

218
def _supports_autocast(device):
    """Return whether checkpoint can capture and restore autocast state for ``device``.

    ``"cuda"`` is always supported. Any other device type needs its
    ``torch.<device>`` module to expose both ``is_autocast_enabled`` and
    ``get_autocast_dtype``.
    """
    module = _get_device_module(device)
    if device == "cuda":
        return True
    required_attrs = ("is_autocast_enabled", "get_autocast_dtype")
    return all(hasattr(module, attr) for attr in required_attrs)
class CheckpointFunction(torch.autograd.Function):
    """Autograd Function behind the reentrant (``use_reentrant=True``) checkpoint.

    ``forward`` runs ``run_function`` under ``torch.no_grad()`` and keeps only
    its inputs. ``backward`` re-runs ``run_function`` with grad enabled, using
    the RNG and autocast state captured during forward, then backpropagates
    through that recomputed graph.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        """Run ``run_function(*args)`` without recording an autograd graph.

        Args:
            ctx: Autograd context used to stash what ``backward`` needs.
            run_function: The callable being checkpointed.
            preserve_rng_state: If true, capture the CPU RNG state. The device
                RNG state is also captured when that device module is already
                initialized. Recomputation then sees the same random numbers.
            *args: Positional inputs for ``run_function``.

        Returns:
            Whatever ``run_function(*args)`` returns.
        """
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.device = _infer_device_type(*args)
        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
            ctx.device
        )
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_device_in_fwd = False
            device_module = _get_device_module(ctx.device)
            if getattr(device_module, "_initialized", False):
                ctx.had_device_in_fwd = True
                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward with grad enabled and backpropagate through it.

        Args:
            ctx: The context populated by ``forward``.
            *args: Gradients w.r.t. each output of ``forward``.

        Returns:
            ``(None, None, *grads)``: no gradient for ``run_function`` or
            ``preserve_rng_state``, then one entry per forward input. Each entry
            is the recomputed input's ``.grad`` for tensors and ``None`` otherwise.

        Raises:
            RuntimeError: If reentrant checkpointing is invalid for the current
                backward call (``.grad()`` or ``.backward(inputs=...)``). Also
                raised if no recomputed output requires grad.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument."
            )
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors
        device_module = _get_device_module(ctx.device)

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
            rng_devices = ctx.fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device
        ):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_device_in_fwd:
                    # Restore with the module of ``ctx.device``. That is the
                    # module the states were captured with: ``get_device_states``
                    # and ``ctx.device`` both infer it from the same ``args``.
                    # ``set_device_states(devices, states)`` would instead infer
                    # the type from the RNG state tensors. Those are typically
                    # CPU tensors, so it would fall back to the default type.
                    for device, state in zip(ctx.fwd_devices, ctx.fwd_device_states):
                        with device_module.device(device):
                            device_module.set_rng_state(state)
            detached_inputs = detach_variable(tuple(inputs))

            if _supports_autocast(ctx.device):
                device_autocast_ctx = device_module.amp.autocast(
                    **ctx.device_autocast_kwargs
                )
            else:
                device_autocast_ctx = contextlib.nullcontext()
            with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(
                **ctx.cpu_autocast_kwargs
            ):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = []
        args_with_grad = []
        for i, out in enumerate(outputs):
            if torch.is_tensor(out) and out.requires_grad:
                outputs_with_grad.append(out)
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary"
            )
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        return (None, None) + grads
def noop_context_fn():
    """Default ``context_fn`` for :func:`checkpoint`: two do-nothing context managers.

    Returns:
        A ``(forward_context, recompute_context)`` pair of fresh
        ``contextlib.nullcontext`` instances.
    """
    forward_ctx = contextlib.nullcontext()
    recompute_ctx = contextlib.nullcontext()
    return forward_ctx, recompute_ctx
# TorchDynamo does not step inside utils.checkpoint function.  The flow
# looks like this
#  1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
#     speculatively checking if the forward function is safe to trace.
#  2) If yes, then Dynamo-generated Fx graph has the wrapped higher
#     order op. As a result, TorchDynamo does not look inside utils.checkpoint.
#  3) If not, then TorchDynamo falls back to eager by performing a graph
#     break. And here, the following disable wrapper ensures that
#     TorchDynamo does not trigger again on the frames created by
#     utils.checkpoint innards.
@torch._disable_dynamo
def checkpoint(
    function,
    *args,
    use_reentrant: Optional[bool] = None,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    **kwargs
):
    r"""Checkpoint a model or part of the model.

    Activation checkpointing is a technique that trades compute for memory.
    Instead of keeping tensors needed for backward alive until they are used in
    gradient computation during backward, forward computation in checkpointed
    regions omits saving tensors for backward and recomputes them during the
    backward pass. Activation checkpointing can be applied to any part of a
    model.

    There are currently two checkpointing implementations available, determined
    by the :attr:`use_reentrant` parameter. It is recommended that you use
    ``use_reentrant=False``. Please refer to the note below for a discussion of
    their differences.

    .. warning::

        If the :attr:`function` invocation during the backward pass differs
        from the forward pass, e.g., due to a global variable, the checkpointed
        version may not be equivalent, potentially causing an
        error being raised or leading to silently incorrect gradients.

    .. warning::

        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.4 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please refer to the
        note below for important considerations and potential limitations.

    .. note::

        The reentrant variant of checkpoint (``use_reentrant=True``) and
        the non-reentrant variant of checkpoint (``use_reentrant=False``)
        differ in the following ways:

        * Non-reentrant checkpoint stops recomputation as soon as all needed
          intermediate activations have been recomputed. This feature is enabled
          by default, but can be disabled with :func:`set_checkpoint_early_stop`.
          Reentrant checkpoint always recomputes :attr:`function` in its
          entirety during the backward pass.

        * The reentrant variant does not record the autograd graph during the
          forward pass, as it runs with the forward pass under
          :func:`torch.no_grad`. The non-reentrant version does record the
          autograd graph, allowing one to perform backward on the graph within
          checkpointed regions.

        * The reentrant checkpoint only supports the
          :func:`torch.autograd.backward` API for the backward pass without its
          ``inputs`` argument, while the non-reentrant version supports all ways
          of performing the backward pass.

        * At least one input and output must have ``requires_grad=True`` for the
          reentrant variant. If this condition is unmet, the checkpointed part
          of the model will not have gradients. The non-reentrant version does
          not have this requirement.

        * The reentrant version does not consider tensors in nested structures
          (e.g., custom objects, lists, dicts, etc) as participating in
          autograd, while the non-reentrant version does.

        * The reentrant checkpoint does not support checkpointed regions with
          detached tensors from the computational graph, whereas the
          non-reentrant version does. For the reentrant variant, if the
          checkpointed segment contains tensors detached using ``detach()`` or
          with :func:`torch.no_grad`, the backward pass will raise an error.
          This is because ``checkpoint`` makes all the outputs require gradients
          and this causes issues when a tensor is defined to have no gradient in
          the model. To avoid this, detach the tensors outside of the
          ``checkpoint`` function.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint. Must be passed as
            a keyword argument. Note that under torch.compile, this flag
            doesn't take effect and we always preserve RNG state.
            Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.4 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
            This argument is only supported if ``use_reentrant=False``.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks. This argument is only supported if ``use_reentrant=False``,
            if ``use_reentrant=True``, the determinism check is always disabled.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators run during the original forward computation
            as well as the recomputation. This argument is only supported if
            ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`
        kwargs: keyword arguments forwarded to :attr:`function`. Only
            supported if ``use_reentrant=False``; with ``use_reentrant=True``
            any keyword other than ``preserve_rng_state`` raises ``ValueError``.

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    if use_reentrant is None:
        warnings.warn(
            "torch.utils.checkpoint: the use_reentrant parameter should be "
            "passed explicitly. In version 2.4 we will raise an exception "
            "if use_reentrant is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants."
        )
        use_reentrant = True

    # ``preserve_rng_state`` travels through ``**kwargs`` rather than being a
    # named parameter; the remaining kwargs belong to ``function``.
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs and use_reentrant:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))

    if use_reentrant:
        if context_fn is not noop_context_fn or debug is not False:
            raise ValueError(
                "Passing `context_fn` or `debug` is only supported when "
                "use_reentrant=False."
            )
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        gen = _checkpoint_without_reentrant_generator(
            function, preserve, context_fn, determinism_check, debug, *args, **kwargs
        )
        # Runs pre-forward logic
        next(gen)
        try:
            ret = function(*args, **kwargs)
        except BaseException:
            # Close the suspended generator so whatever its pre-forward step set
            # up is unwound now. Otherwise that waits until the generator is
            # garbage-collected, and the exception's traceback (which references
            # this frame and hence ``gen``) can keep it alive indefinitely.
            gen.close()
            raise
        # Runs post-forward logic
        try:
            next(gen)
        except StopIteration:
            return ret
def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs):
    r"""Checkpoint a sequential model to save memory.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will not store
    the intermediate activations. The inputs of each checkpointed segment will
    be saved for re-running the segment in the backward pass.

    .. warning::
        The ``use_reentrant`` parameter should be passed explicitly. In version
        2.4 we will raise an exception if ``use_reentrant`` is not passed.
        If you are using the ``use_reentrant=True`` variant, please see
        :func:`~torch.utils.checkpoint.checkpoint` for
        the important considerations and limitations of this variant. It is
        recommended that you use ``use_reentrant=False``.

    .. warning::
        Since PyTorch 1.4, it allows only one Tensor as the input and
        intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model. If it exceeds the
            number of functions, every function becomes its own chunk.
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.4 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`input`

    Raises:
        ValueError: If any keyword other than ``preserve_rng_state`` is passed.

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    if use_reentrant is None:
        warnings.warn(
            "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant "
            "parameter should be passed explicitly. "
            "In version 2.4 we will raise an exception if use_reentrant "
            "is not passed. use_reentrant=False is "
            "recommended, but if you need to preserve the current default "
            "behavior, you can pass use_reentrant=True. Refer to docs for more "
            "details on the differences between the two variants."
        )
        use_reentrant = True

    # ``preserve_rng_state`` is the only keyword accepted through ``**kwargs``.
    preserve = kwargs.pop("preserve_rng_state", True)
    if kwargs:
        raise ValueError(
            "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
        )

    def run_function(start, end, functions):
        # Build a callable that applies functions[start..end] (inclusive) in order.
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input

        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    if segment_size == 0 and segments > 0:
        # Fewer functions than requested segments (including an empty list).
        # A zero segment size would make ``range`` below fail with a zero
        # step, so give every function its own segment instead.
        segments = len(functions)
        segment_size = 1
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve,
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)
def _internal_assert(cond):
595
    if not cond:
596
        raise AssertionError(
597
            "Something went unexpectedly wrong in activation checkpoint. "
598
            "Please report this bug by filing an issue to PyTorch."
599
        )
600

601

602
# NOTE [ Nestable Checkpoint ]
#
# The semantics of nested checkpoint can be defined by two basic rules.
# Following the two rules leads to an important implication that is central
# to motivating the design.
#
# Rule 1. Saved tensors are managed by the inner-most checkpoint only and are
#         hidden from any outer layers of checkpoint.
#
# Rule 2. The inputs of inner checkpoints are treated as tensors saved to their
#         parent checkpoint.
#
# Implication: To recompute any given saved tensor, we need to recompute all of
#              the checkpoints wrapping it.
#
# Why is this implied? To unpack a saved tensor X during backward we need to
# recompute the inner-most checkpoint (#1), and in order to recompute that
# checkpoint we need to have its inputs, which are managed by that checkpoint's
# parent (#2), which thus also needs to be recomputed first. Continuing this
# line of reasoning, we see that in order to unpack X, all checkpoints that
# were active at the time X was saved need to be recomputed (unless we have
# already done so in that backward for some other saved tensor).
#
# In practice, we use a noop autograd Function to save inputs as saved tensors.
# During unpack, accessing ctx.saved_tensors triggers the parent checkpoint to
# recompute.
#
# Rule 3. We should start recomputation as if there are no checkpoints currently
#         active. Checkpoints encountered during recomputation are still
#         respected.
#
# When we start recomputation, we push the saved variable hook meant for
# recomputation onto the stack. See the examples in Rule 6 for more context.
#
#                                  * * * *
#
# Beyond the basic semantics specific to nested checkpoint, we impose several
# more constraints that may apply to checkpointing in general.
#
# Rule 4. Lifetime of recomputed tensors
#
#         Recomputed tensors are considered specific to particular invocations
#         of backward and are always cleared immediately as they are unpacked.
#         In particular, we require this to happen even if retain_graph=True.
#
# [ Implementation details of Rule 4 ]
#
# If we were okay with recomputed tensors staying alive after backward is run
# with retain_graph=True, we would store recomputed variables as the values of a
# WeakKeyDictionary and pack strong references to the keys, so that as we
# run backward, those packed keys would be cleared as long as retain_graph=False.
# Clearing the packed key clears the corresponding entry in the WKD.
#
# If we wish recomputed variables to be immediately cleared as we unpack them in
# the retain_graph=True case, we cannot rely on the packed keys to be cleared by
# backward automatically. Instead of packing the strong reference to the key
# directly, we pack a container object, which we manually clear as we unpack.
#
# An important detail is that if a second backward happens, the second
# recomputation needs to reset the container with a newly created key.
#
# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we
#         know we need.
#
# [ Implementation details of Rule 5 ]
#
# During recomputation, raise an exception once the number of recomputed tensors
# matches the number of tensors that we expected to recompute. We wrap the
# recomputation call with a try-except to catch this specific exception. See
# Rule #6 below for some examples.
#
# Rule 6. We support doing backward inside checkpoint context
#
# [ retain_graph is True ]
#
# def fn(x):
#   y = x.sin()
#   z = y.cos()
#   gx, = torch.autograd.grad(z, x, retain_graph=True)
#   return gx, z
#
# out = checkpoint(fn, inp, use_reentrant=False)
# out.backward()
#
# Because x and y (the tensors saved by sin and cos) are saved while checkpoint
# is enabled, they are not actually saved, and so the .grad() call inside must
# trigger a recomputation.
#
# During recomputation the "inner pack hook" has two responsibilities:
#
# 1) As usual, populating the WeakKeyDictionary storing recomputed tensors
# 2) Pack the actual tensor (detached) so that one may perform backward on the
#    recomputed graph. The tensors saved to this graph will live until the end
#    of recomputation, or die earlier if someone performs backward with
#    retain_graph=False.
#
# More generally, performing backward on the recomputed graph occurs in the
# following cases:
# - If backward is performed inside forward,
#   - During the original forward IF early-stop is disabled
#   - During the original backward
# - If there are multiple .grad()/.backward() calls, we would perform backward
#   on the recomputed graph even if early-stop is enabled (see the example below)
#
# [ retain_graph is False ]
#
# The example below shows what happens if during recomputation we find that some
# of the tensors we are trying to recompute have already been cleared.
#
# Spoiler: we don't do anything special, we just skip over them!
#
# def fn(x):
#   y = x.sin()                           # (1)
#   z = y.cos()                           # (2)
#   gx, = torch.autograd.grad(z, x)       # (3)
#   return x.cos() * gx                   # (4)
#
# out = checkpoint(fn, inp, use_reentrant=False)
# out.backward()                          # (5)
#
# 1, 2. Don't save x and y since we are inside a checkpoint.
# 3. Trigger a recompute of fn since x and y weren't saved.
#    And depending on whether early stop is enabled, either stop at (2) or
#    continue running the function.
#    Because we are running backward with retain_graph=False, we clear x and y's
#    holders.
# 4. Don't save x since we are inside a checkpoint.
# 5. Calling backward triggers another recompute of fn. During recompute, we see
#    that x and y have already been cleared in the original graph as indicated
#    by holder=None. We skip over them. We still save x at (4) (since its holder
#    is still alive).

# Whether non-reentrant checkpoint stops recomputation as soon as every needed
# tensor has been recomputed ("Rule 5" in NOTE [ Nestable Checkpoint ]).
# Toggled by ``set_checkpoint_early_stop``.
_enable_checkpoint_early_stop = True


@contextlib.contextmanager
def set_checkpoint_early_stop(enable: bool):
    """Context manager that sets whether checkpoint should stop recomputation early.

    By default, non-reentrant checkpoint stops recomputation as soon as it
    has computed all needed Tensors. Use this context manager to disable that
    behavior if it is problematic for your specific application.

    The setting only needs to be active while forward runs; it does not need
    to be active during backward. The previous setting is restored on exit,
    including when the body raises.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> with set_checkpoint_early_stop(False):
        ...     # Any checkpoint under this context manager will respect this
        ...     # context manager, even if its backward is performed outside.
        ...     out = checkpoint(fn, inputs, use_reentrant=False)
        ...
        >>> out.backward()
    """
    global _enable_checkpoint_early_stop
    previous = _enable_checkpoint_early_stop
    _enable_checkpoint_early_stop = enable
    try:
        yield
    finally:
        _enable_checkpoint_early_stop = previous
class _Handle:
768
    pass
769

770

771
class _Holder:
772
    def __init__(self):
773
        self.handles: Dict[int, Optional[_Handle]] = dict()
774

775

776
class _NoopSaveInputs(torch.autograd.Function):
777
    @staticmethod
778
    def forward(*args):
779
        return torch.empty((0,))
780

781
    @staticmethod
782
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
783
        # Only tensors can be saved with ctx.save_for_backward, everything else
784
        # is captured by get_args, which is saved directly on ctx
785
        tensor_indices, tensors = zip(
786
            *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
787
        )
788
        idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
789
        # args but with tensors replaced with None as placeholders
790
        args = [None if isinstance(o, torch.Tensor) else o for o in inputs]
791

792
        def get_args(saved_tensors):
793
            # restore the placeholders with the original tensors grabbed from
794
            # ctx.saved_tensors (which may be saved on a parent checkpoint if
795
            # this checkpoint is nested, and that would trigger a recursive
796
            # unpack!)
797
            ret = [
798
                saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
799
                for i, o in enumerate(args)
800
            ]
801
            # grab the tail since we also saved the dummy to avoid having to explicitly
802
            # handle the case where there are no tensor inputs
803
            return ret[1:]
804

805
        ctx.get_args = get_args
806
        ctx.save_for_backward(*tensors)
807

808
    @staticmethod
809
    def backward(ctx, *grad_outputs):
810
        raise AssertionError("Did not expect to backward on this graph")
811

812

813
class _CheckpointFrame:
814
    def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
815
        self.recompute_fn = recompute_fn
816
        self.input_saver = None
817
        self.weak_holders: List[ReferenceType] = []
818
        # We store this as a weakkeydictionary so that in the case of a partial
819
        # backward, the entries in the dict are cleared alongside the Holder
820
        # which will be removed when the SavedVariable is cleared.
821
        self.recomputed: DefaultDict[
822
            int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
823
        ] = defaultdict(weakref.WeakKeyDictionary)
824
        # We need both recomp_counter and recomputed since they can diverge
825
        # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
826
        self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
827
        self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)
828

829
        # See Rule 5
830
        self.early_stop = early_stop
831

832
        # Debugging
833
        self.metadata_fn = metadata_fn
834
        self.unpack_error_cb = unpack_error_cb
835
        self.x_metadatas = []
836
        self.forward_completed = False
837
        self.ignore_saved_mismatch = False
838

839
    def check_recomputed_tensors_match(self, gid):
840
        if self.ignore_saved_mismatch:
841
            # TODO: we can probably make this check stricter by checking that
842
            #       the metadata of the first tensors still match.
843
            return
844
        # NOTE [ Error handling for checkpoint ]
845
        #
846
        # At a high level, we need to check that the tensors saved
847
        # during original forward matches tensors saved during recompute
848
        # This means handling 3 cases:
849
        #
850
        # 1. During recompute, more tensors were saved.
851
        #
852
        #    Usually this is hidden due to the StopRecomputationError
853
        #    but if early stop is not enabled, or we would have errored
854
        #    anyway because there aren't enough weak_holders. But we
855
        #    do want to have a nice error. See the _recomputation_hook
856
        #    for details.
857
        if not len(self.weak_holders) == self.recomp_counter[gid]:
858
            # 2. During recompute, fewer tensors were saved
859
            #
860
            # We know that everytime we save something do original forward
861
            # we append to weak_holder, and every time we save a tensor
862
            # during recompute we increment recompute_counter.
863
            raise CheckpointError(
864
                "torch.utils.checkpoint: A different number of tensors was saved "
865
                "during the original forward and recomputation.\n"
866
                f"Number of tensors saved during forward: {len(self.weak_holders)}\n"
867
                f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}"
868
            )
869

870
        # 3. During recompute, the same tensors were saved, but they
871
        #    have different metadata
872
        nb_meta_different = []
873
        for idx, weak_holder in enumerate(self.weak_holders):
874
            holder = weak_holder()
875
            if holder is None:
876
                continue
877
            # We've seen all holders since we iterate over them in order
878
            # For every holder that is still alive now, it must've been
879
            # alive when we saw it during recompute, therefore, the
880
            # gid must be set.
881
            _internal_assert(gid in holder.handles)
882
            # We know this is the first unpack, so it couldn't have been set
883
            # to None yet.
884
            _internal_assert(holder.handles[gid] is not None)
885
            # We always set these together in the recomputation hook
886
            _internal_assert(holder.handles[gid] in self.recomputed[gid])
887
            # see pack hook, x_metadata is 1:1 with weak_holders.
888
            x_meta = self.x_metadatas[idx]
889
            recomputed_x = self.recomputed[gid][holder.handles[gid]]
890
            if x_meta != self.metadata_fn(recomputed_x):
891
                nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x)))
892

893
        if len(nb_meta_different) > 0:
894
            mismatched_tensors = ""
895
            for idx, x_meta, recomputed_meta in nb_meta_different:
896
                mismatched_tensors += (
897
                    f"tensor at position {idx}:\n"
898
                    f"saved metadata: {x_meta}\n"
899
                    f"recomputed metadata: {recomputed_meta}\n"
900
                )
901
            raise CheckpointError(
902
                "torch.utils.checkpoint: Recomputed values for the following tensors "
903
                "have different metadata than during the forward pass.\n"
904
                f"{mismatched_tensors}"
905
            )
906

907

908
_checkpoint_error_template = """ \
909
An error happened while unpacking tensors; dumping logs of latest computation
910
because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
911
Scroll all the way down for guidance on how to navigate these logs.
912

913
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
914
|        1. Stack traces of the operators that ran in the original forward     |
915
+------------------------------------------------------------------------------+
916

917
{forward_traces}
918
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
919
|        2. Stack traces of the operators that ran during recomputation        |
920
+------------------------------------------------------------------------------+
921

922
{recompute_traces}
923
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
924
|       3. Log of operators in the original forward and recomputation          |
925
+------------------------------------------------------------------------------+
926
(Scroll up to correlate stack traces with each operation listed below. This
927
 helps identify their source in the code.)
928

929
IMPORTANT: Differences in "detach" calls between the original forward and the
930
           recomputation are expected. They are introduced by the checkpointing
931
           mechanism and can be ignored.
932

933
Operations executed during the original forward:
934

935
{forward_ops}
936

937
Operations executed during recomputation:
938

939
{recompute_ops}
940

941
+------------------------------------------------------------------------------+
942
 ERROR: Detected non-determinism while running activation checkpointing
943

944
 You are seeing this error because you passed `debug=True` to checkpoint and
945
 tensors to be saved during the original forward and differ between those saved
946
 during recomputation. This can happen if different operators were ran in the
947
 original forward and in the recomputation.
948

949
 To identify where the mismatch may be coming from, you can do the following:
950

951
 1) Compare the operators ran during original forward and recomputation to
952
    see where they differ. These operators are printed above in the order they
953
    were executed.
954

955
 2) Review the stack trace for each operator to locate its invocation source.
956
    Each operator's stack trace is printed in their execution order.
957

958
 Note that the logs can be quite long. Here's how they are structured:
959
 (Tip: you can Ctrl-f for these headers)
960

961
 1. Stack traces of the operators that ran in the original forward
962
 2. Stack traces of the operators that ran during recomputation
963
 3. Log of operators in the original forward and recomputation
964
 4. Error message                                             <--- You are here
965
--------------------------------------------------------------------------------
966
"""
967

968
class CheckpointError(RuntimeError):
969
    pass
970

971

972
def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]:
973
    # This function returns the context_fn and error_cb to be used by the
974
    # checkpointing mechanism. error_cb is invoked when an error is detected
975
    # during unpack.
976

977
    # record_context_cpp is not support on non-linux non-x86_64 platforms
978
    cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux'
979

980
    class CaptureLogs:
981
        def __init__(self):
982
            self.logs = None
983
            self.tbs = None
984

985
        def get_context_manager(self):
986
            @contextlib.contextmanager
987
            def logging_mode():
988
                with LoggingTensorMode(), \
989
                     capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
990
                    self.logs, self.tbs = logs_and_tb
991
                    yield logs_and_tb
992
            return logging_mode()
993

994
    capture_logs_fwd = CaptureLogs()
995
    capture_logs_recompute = CaptureLogs()
996

997
    def unpack_error_cb(e: CheckpointError):
998
        def get_str_tb(label, capture_logs):
999
            out = ""
1000
            total_len = len(capture_logs.logs)
1001
            for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)):
1002
                out += f"{log}   ({i + 1} of {total_len} in {label})\n\n"
1003
                found_torch_dispatch = False
1004
                for line in tb:
1005
                    # Start printing stack trace only after __torch_dispatch__ is found
1006
                    is_torch_dispatch = line['name'] == '__torch_dispatch__'
1007
                    if not found_torch_dispatch and not is_torch_dispatch:
1008
                        continue
1009
                    elif is_torch_dispatch:
1010
                        found_torch_dispatch = True
1011
                        continue
1012
                    out += f"{line['filename']}:{line['line']}:{line['name']}\n"
1013
                out += "\n\n"
1014
            return out
1015
        assert capture_logs_fwd.logs is not None
1016
        assert capture_logs_recompute.logs is not None
1017
        raise CheckpointError(
1018
            _checkpoint_error_template.format(
1019
                forward_traces=get_str_tb("original", capture_logs_fwd),
1020
                recompute_traces=get_str_tb("recompute", capture_logs_recompute),
1021
                forward_ops="\n".join(capture_logs_fwd.logs),
1022
                recompute_ops="\n".join(capture_logs_recompute.logs)
1023
            )
1024
        ) from e
1025

1026
    def context_fn():
1027
        return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager()
1028

1029
    return context_fn, unpack_error_cb
1030

1031
def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]:
1032
    # These properties are fast to check, easy to understand
1033
    return {
1034
        "shape": x.shape,
1035
        "dtype": x.dtype,
1036
        "device": x.device
1037
    }
1038

1039
_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = {
1040
    _DEFAULT_DETERMINISM_MODE: _default_meta_extractor,
1041
    "none": lambda _: None,
1042
}
1043

1044
# See Rule 5
1045
class _StopRecomputationError(Exception):
1046
    pass
1047

1048

1049
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
1050
    def __init__(self, target_frame_ref: ReferenceType, gid: int):
1051
        def pack_hook(x):
1052
            target_frame = target_frame_ref()
1053
            assert target_frame is not None  # appease mypy
1054
            recomp_idx = target_frame.recomp_counter[gid]
1055
            target_frame.recomp_counter[gid] += 1
1056

1057
            if recomp_idx >= len(target_frame.weak_holders):
1058
                assert not target_frame.early_stop
1059
                if not target_frame.forward_completed:
1060
                    # We run into this case when early stop is not enabled and do
1061
                    # grad within checkpoint.
1062
                    # We need to set this flag, so we don't error out later when
1063
                    # we check if the number of tensors saved during forward and
1064
                    # recomputation match.
1065
                    target_frame.ignore_saved_mismatch = True
1066
                    return x.detach()
1067
                raise CheckpointError(
1068
                    "torch.utils.checkpoint: trying to save more tensors during "
1069
                    "recomputation than during the original forward pass."
1070
                )
1071

1072
            holder = target_frame.weak_holders[recomp_idx]()
1073

1074
            # This holder may have been cleared because someone may have called
1075
            # backward within forward. If so, we don't need to save.
1076
            if holder is not None:
1077
                _internal_assert(holder.handles.get(gid, None) is None)
1078
                holder.handles[gid] = _Handle()
1079
                target_frame.recomputed[gid][holder.handles[gid]] = x.detach()
1080

1081
            if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
1082
                target_frame.weak_holders
1083
            ):
1084
                raise _StopRecomputationError()
1085
            # See Rule 6: [ retain_graph is True ] above
1086
            return x.detach()
1087

1088
        def unpack_hook(x):
1089
            # See Rule 6: [ retain_graph is True ] above for an example of when
1090
            # the graph created during recomputation could be backwarded.
1091
            return x
1092

1093
        super().__init__(pack_hook, unpack_hook)
1094

1095

1096
class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
1097
    def __init__(self, frame):
1098
        def pack_hook(x):
1099
            # See Rule 4 above
1100
            holder = _Holder()
1101
            frame.weak_holders.append(weakref.ref(holder))
1102
            # Save metadata to detect non-determinism
1103
            if frame.metadata_fn is not None:
1104
                with torch.no_grad():
1105
                    frame.x_metadatas.append(frame.metadata_fn(x))
1106
            return holder
1107

1108
        def unpack_hook(holder):
1109
            gid = torch._C._current_graph_task_id()
1110
            if gid == -1:
1111
                # generate a temporary id if we trigger unpack outside of a backward call
1112
                gid = int(uuid.uuid4())
1113

1114
            if not frame.is_recomputed[gid]:
1115
                ctx = frame.input_saver.grad_fn
1116
                args = ctx.get_args(ctx.saved_tensors)
1117

1118
                try:
1119
                    with _recomputation_hook(
1120
                        weakref.ref(frame), gid
1121
                    ), torch.autograd.enable_grad():
1122
                        frame.recompute_fn(*args)
1123
                except _StopRecomputationError:
1124
                    pass
1125
                frame.is_recomputed[gid] = True
1126
                frame.check_recomputed_tensors_match(gid)
1127

1128
            _internal_assert(gid in holder.handles)
1129

1130
            if holder.handles[gid] is None:
1131
                raise CheckpointError(
1132
                    "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
1133
                    "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
1134
                    "so only once. Otherwise please open an issue with details on your use case."
1135
                )
1136
            _internal_assert(holder.handles[gid] in frame.recomputed[gid])
1137
            ret = frame.recomputed[gid][holder.handles[gid]]
1138
            holder.handles[gid] = None
1139
            return ret
1140

1141
        if frame.unpack_error_cb is not None:
1142
            def unpack_hook_with_error_cb(holder):
1143
                try:
1144
                    return unpack_hook(holder)
1145
                except CheckpointError as e:
1146
                    frame.unpack_error_cb(e)
1147
            super().__init__(pack_hook, unpack_hook_with_error_cb)
1148
        else:
1149
            super().__init__(pack_hook, unpack_hook)
1150

1151

1152
def _is_compiling(func, args, kwargs):
1153
    # Check if we are under AOTAutograd tracing
1154
    # There should probably be a better way to do this...
1155
    # TODO: unify _is_compiling across all compile stacks
1156
    for arg in args:
1157
        if isinstance(arg, torch.Tensor) and is_fun(arg):
1158
            return True
1159
    return False
1160

1161

1162
def _detach(x):
1163
    if isinstance(x, torch.Tensor):
1164
        return x.detach()
1165
    return x
1166

1167

1168
uid = count(1)
1169

1170

1171
# NOTE: torch.utils.checkpoint internal logic will call these two functions unknown number of times
1172
# (i.e. there could be _CachedTorchDispatchMode calls that doesn't map to a _CachingTorchDispatchMode call),
1173
# so we ignore these ops and just always recompute them.
1174
_ignored_ops = {
1175
    torch.ops.prim.device.default,
1176
    torch.ops.aten.detach.default,
1177
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
1178

1179

1180
class _CachingTorchDispatchMode(TorchDispatchMode):
1181
    r"""
1182
    A :class:`TorchDispatchMode` to implement selective activation checkpointing
1183
    that's compatible with torch.compile. Used together with _CachedTorchDispatchMode.
1184
    """
1185
    def __init__(self, policy_fn, storage):
1186
        self.policy_fn = policy_fn
1187
        self.storage = storage
1188

1189
    def push_into_storage(self, out, func, args, kwargs):
1190
        out_detached = tree_map(_detach, out)
1191
        self.storage[func].append(out_detached)
1192

1193
    def _handle_compile_in_forward_ctx(self, should_not_recompute, func, args, kwargs):
1194
        if func in _ignored_ops:
1195
            return func(*args, **kwargs)
1196
        if should_not_recompute:
1197
            fx_traceback.current_meta["recompute"] = 0
1198
        # NOTE: Here we just store and reuse output of all ops, since in torch.compile mode
1199
        # we decide and handle recomputation in the partitioner.
1200
        out = func(*args, **kwargs)
1201
        self.push_into_storage(out, func, args, kwargs)
1202
        return out
1203

1204
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1205
        if kwargs is None:
1206
            kwargs = {}
1207
        should_not_recompute = self.policy_fn("forward", func, *args, **kwargs)
1208
        if _is_compiling(func, args, kwargs):
1209
            return self._handle_compile_in_forward_ctx(should_not_recompute, func, args, kwargs)
1210
        else:
1211
            if should_not_recompute:
1212
                out = func(*args, **kwargs)
1213
                self.push_into_storage(out, func, args, kwargs)
1214
            else:
1215
                out = func(*args, **kwargs)
1216
            return out
1217

1218

1219
class _CachedTorchDispatchMode(TorchDispatchMode):
1220
    r"""
1221
    A :class:`TorchDispatchMode` to implement selective activation checkpointing
1222
    that's compatible with torch.compile. Used together with _CachingTorchDispatchMode.
1223
    """
1224
    def __init__(self, policy_fn, storage):
1225
        self.policy_fn = policy_fn
1226
        self.storage = storage
1227

1228
    def pop_from_storage(self, func, args, kwargs):
1229
        assert func in self.storage
1230
        out = self.storage[func].pop(0)
1231
        return out
1232

1233
    def _handle_compile_in_recompute_ctx(self, should_not_recompute, func, args, kwargs):
1234
        if func in _ignored_ops:
1235
            return func(*args, **kwargs)
1236
        out = self.pop_from_storage(func, args, kwargs)
1237
        return out
1238

1239
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1240
        if kwargs is None:
1241
            kwargs = {}
1242
        should_not_recompute = self.policy_fn("recompute", func, *args, **kwargs)
1243
        if _is_compiling(func, args, kwargs):
1244
            return self._handle_compile_in_recompute_ctx(should_not_recompute, func, args, kwargs)
1245
        else:
1246
            if should_not_recompute:
1247
                out = self.pop_from_storage(func, args, kwargs)
1248
            else:
1249
                out = func(*args, **kwargs)
1250
            return out
1251

1252

1253
def _pt2_selective_checkpoint_context_fn_gen(policy_fn):
1254
    """
1255
    A helper function that generates a pair of contexts to be later passed into
1256
    `torch.utils.checkpoint` API to implment selective checkpointing.
1257

1258
    .. warning::
1259
        This is context_fn is intended for use with torch.compile only.
1260

1261
    Args:
1262
        policy_fn (Callable[[Callable, List[Any], Dict[str, Any]], bool]): Policy function
1263
            to decide whether a particular op should be recomputed in backward pass or not.
1264
            In eager mode:
1265
                If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed.
1266
                If policy_fn(...) returns False, the op is guaranteed to be recomputed.
1267
            In torch.compile mode:
1268
                If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed.
1269
                If policy_fn(...) returns False, the op may or may not be recomputed
1270
                (it's up to the partitioner to decide).
1271

1272
    Returns:
1273
        A pair of generated contexts.
1274

1275
    Example:
1276
        >>> # xdoctest: +REQUIRES(LINUX)
1277
        >>>
1278
        >>> def get_custom_policy():
1279
        >>>     no_recompute_list = [
1280
        >>>         torch.ops.aten.mm.default,
1281
        >>>     ]
1282
        >>>     def custom_policy(mode, func, *args, **kwargs):
1283
        >>>         return func in no_recompute_list
1284
        >>>     return custom_policy
1285
        >>>
1286
        >>> def selective_checkpointing_context_fn():
1287
        >>>     return _pt2_selective_checkpoint_context_fn_gen(get_custom_policy())
1288
        >>>
1289
        >>> def gn(x, y):
1290
        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
1291
        >>>
1292
        >>> def fn(x, y):
1293
        >>>     return torch.utils.checkpoint.checkpoint(
1294
        >>>         gn, x, y,
1295
        >>>         use_reentrant=False,
1296
        >>>         context_fn=selective_checkpointing_context_fn,
1297
        >>>     )
1298
        >>>
1299
        >>> x = torch.randn(4, 4, requires_grad=True)
1300
        >>> y = torch.randn(4, 4, requires_grad=True)
1301
        >>>
1302
        >>> compiled_fn = torch.compile(fn)
1303
    """
1304
    storage: Dict[Any, List[Any]] = defaultdict(list)
1305
    return _CachingTorchDispatchMode(policy_fn, storage), _CachedTorchDispatchMode(policy_fn, storage)
1306

1307

1308
# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
1309
#     saving/restoring of global state is handled here.
1310

1311
def _checkpoint_without_reentrant_generator(
    fn,
    preserve_rng_state=True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    *args,
    **kwargs
):
    """Checkpointing without reentrant autograd.

    This function is a generator with a single ``yield``. Code before the
    ``yield`` validates the arguments and stashes the state needed to
    recompute ``fn`` later. While the generator is suspended, whatever the
    caller runs executes inside ``_checkpoint_hook(new_frame)`` and
    ``forward_context``. Resuming the generator marks the forward as
    completed and checks that the device was not initialized during the
    forward. When the saved-inputs node has no ``grad_fn`` (the ambient grad
    mode is disabled), the generator yields once without entering the hook or
    ``forward_context``.

    Args:
        fn: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`fn` should correctly use the
            first input as ``activation`` and the second input as ``hidden``.
        preserve_rng_state(bool, optional): If ``True``, stash the CPU RNG
            state (and, if the device module was already initialized, the
            device RNG states returned by ``get_device_states(*args)``)
            before the forward and restore them before recomputation. If
            ``False``, RNG state is neither stashed nor restored.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively. When
            compiling with a non-default ``context_fn``, both must be
            ``TorchDispatchMode`` instances.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators ran during the original forward computation
            as well as the recomputation. A non-``None`` value set through
            :func:`set_checkpoint_debug_enabled` overrides this flag.
        *args: Arguments to pass in to the given ``fn``.
        **kwargs: Keyword arguments to pass into the given ``fn``.

    Raises:
        ValueError: if debugging is enabled together with a non-default
            ``context_fn``, or if ``determinism_check`` is not a supported mode.
        AssertionError: when compiling with a non-default ``context_fn`` whose
            two contexts are not both ``TorchDispatchMode`` instances.
        RuntimeError: if ``preserve_rng_state`` is set and the device module
            became initialized during the forward pass.
    """
    unpack_error_cb = None

    # The global override (set via set_checkpoint_debug_enabled) wins over the
    # local ``debug`` flag whenever it is not None.
    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
        if context_fn != noop_context_fn:
            raise ValueError(
                "debug=True is incompatible with non-default context_fn"
            )
        # Debug mode replaces the contexts and provides a callback that is
        # handed to _CheckpointFrame below.
        context_fn, unpack_error_cb = _get_debug_context_and_cb()

    if determinism_check in _allowed_determinism_checks_to_fns:
        metadata_fn = _allowed_determinism_checks_to_fns[determinism_check]
    else:
        raise ValueError(
            f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, "
            f"but got {determinism_check}"
        )

    device = _infer_device_type(*args)
    device_module = _get_device_module(device)
    forward_context, recompute_context = context_fn()
    if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
        # NOTE(review): this validation uses ``assert`` and is therefore
        # skipped when Python runs with -O.
        assert (
            isinstance(forward_context, TorchDispatchMode) and
            isinstance(recompute_context, TorchDispatchMode)
        ), \
            "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
            "must generate a tuple of two `TorchDispatchMode`s."
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device=device)

    if preserve_rng_state:
        # Captured before fn runs so recomputation replays the same random ops.
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_device_in_fwd = False
        if getattr(device_module, "_initialized", False):
            had_device_in_fwd = True
            fwd_devices, fwd_device_states = get_device_states(*args)

    def recompute_fn(*inputs):
        # ``inputs`` is unpacked as (kwargs, *args); presumably these are the
        # values saved by _NoopSaveInputs.apply(dummy, kwargs, *args) below,
        # without ``dummy`` -- confirm against _CheckpointFrame.
        kwargs, *args = inputs
        # This will be called later during recomputation. This wrapping enables
        # the necessary global state to be captured.
        rng_devices = []
        if preserve_rng_state and had_device_in_fwd:
            rng_devices = fwd_devices
        # fork_rng restores the caller's RNG state after recomputation, so
        # replaying the forward RNG state here does not leak out.
        with torch.random.fork_rng(
            devices=rng_devices, enabled=preserve_rng_state, device_type=device
        ):
            if preserve_rng_state:
                torch.set_rng_state(fwd_cpu_state)
                if had_device_in_fwd:
                    set_device_states(fwd_devices, fwd_device_states)

            # Re-enter the autocast settings that were active during the forward.
            device_autocast_ctx = device_module.amp.autocast(
                **device_autocast_kwargs
            ) if _supports_autocast(device) else contextlib.nullcontext()
            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                 recompute_context:
                fn(*args, **kwargs)

    new_frame = _CheckpointFrame(
        recompute_fn,
        _enable_checkpoint_early_stop,
        unpack_error_cb,
        metadata_fn
    )
    # A zero-size tensor that requires grad gives _NoopSaveInputs a grad_fn
    # whenever grad mode is enabled.
    dummy = torch.empty((0,), requires_grad=True)
    new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

    # When ambient grad_mode is False
    if new_frame.input_saver.grad_fn is None:
        yield
        return

    # The caller's forward computation runs while suspended at this yield.
    with _checkpoint_hook(new_frame), forward_context:
        yield
    new_frame.forward_completed = True

    if getattr(device_module, "_initialized", False) and \
       preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined]
        # Device was not initialized before running the forward, so we didn't
        # stash the device state.
        raise RuntimeError(
            "PyTorch's device state was initialized in the forward pass "
            "of a Checkpoint, which is not allowed. Please open an issue "
            "if you need this feature."
        )

    return
1440

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.