6
from collections import defaultdict
7
from itertools import count
19
from weakref import ReferenceType
22
import torch.fx.traceback as fx_traceback
23
from torch._functorch._aot_autograd.functional_utils import is_fun
24
from torch.utils._pytree import tree_map
25
from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
26
from torch.utils._python_dispatch import TorchDispatchMode
30
"checkpoint_sequential",
33
"check_backward_validity",
38
"set_checkpoint_early_stop",
40
"set_checkpoint_debug_enabled",
43
_DEFAULT_DETERMINISM_MODE = "default"
45
_checkpoint_debug_enabled: Optional[bool] = None
48
@contextlib.contextmanager
49
def set_checkpoint_debug_enabled(enabled: Optional[bool]):
51
Context manager that sets whether checkpoint should print additional debug
52
information when running. See the ``debug`` flag for
53
:func:`~torch.utils.checkpoint.checkpoint` for more information. Note that
54
when set, this context manager overrides the value of ``debug`` passed to
55
checkpoint. To defer to the local setting, pass ``None`` to this context.
58
enabled (bool): Whether checkpoint should print debug information.
61
global _checkpoint_debug_enabled
63
prev = _checkpoint_debug_enabled
64
_checkpoint_debug_enabled = enabled
67
_checkpoint_debug_enabled = prev
70
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
71
if isinstance(inputs, tuple):
74
if not isinstance(inp, torch.Tensor):
79
x.requires_grad = inp.requires_grad
84
"Only tuple of tensors is supported. Got Unsupported input type: ",
85
type(inputs).__name__,
89
def check_backward_validity(inputs: Iterable[Any]) -> None:
90
if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
92
"None of the inputs have requires_grad=True. Gradients will be None"
96
def _get_device_module(device="cuda"):
97
device_module = getattr(torch, device)
101
class DefaultDeviceType:
103
A class that manages the default device type for checkpointing.
105
If no non-CPU tensors are present, the default device type will
106
be used. The default value is 'cuda'. The device type is used in
107
the checkpointing process when determining which device states
108
to save and restore for recomputation.
111
_default_device_type = "cuda"
114
def set_device_type(device: str = "cuda"):
116
Set the default device type for checkpointing.
119
device (str): The device type to be set as default. Default is 'cuda'.
121
DefaultDeviceType._default_device_type = device
124
def get_device_type() -> str:
126
Get the current default device type for checkpointing.
129
str: The current default device type.
131
return DefaultDeviceType._default_device_type
134
def _infer_device_type(*args):
139
if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
142
if len(device_types) > 1:
144
"Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
145
"Device state will only be saved for devices of a single device type, and the remaining "
146
"devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
147
"this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
148
"detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
150
if len(device_types) == 0:
151
return DefaultDeviceType.get_device_type()
152
elif "cuda" in device_types:
155
return device_types[0]
165
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
168
fwd_device_ids = list(
172
if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu"
176
fwd_device_states = []
177
device_module = _get_device_module(_infer_device_type(*args))
179
for device_id in fwd_device_ids:
180
with device_module.device(device_id):
181
fwd_device_states.append(device_module.get_rng_state())
183
return fwd_device_ids, fwd_device_states
186
def set_device_states(devices, states) -> None:
    """Restore per-device RNG states previously captured by ``get_device_states``.

    ``devices`` and ``states`` are parallel sequences: the i-th state tensor is
    restored on the i-th device. The device module (e.g. ``torch.cuda``) is
    inferred from the state tensors themselves.
    """
    mod = _get_device_module(_infer_device_type(*states))
    for dev, rng_state in zip(devices, states):
        # Switch to the target device before restoring, since
        # set_rng_state applies to the current device.
        with mod.device(dev):
            mod.set_rng_state(rng_state)
193
def _get_autocast_kwargs(device="cuda"):
195
device_autocast_kwargs = {
196
"enabled": torch.is_autocast_enabled(),
197
"dtype": torch.get_autocast_gpu_dtype(),
198
"cache_enabled": torch.is_autocast_cache_enabled(),
200
elif _supports_autocast(device):
201
device_module = _get_device_module(device)
202
device_autocast_kwargs = {
203
"enabled": device_module.is_autocast_enabled(),
204
"dtype": device_module.get_autocast_dtype(),
205
"cache_enabled": torch.is_autocast_cache_enabled(),
208
device_autocast_kwargs = None
210
cpu_autocast_kwargs = {
211
"enabled": torch.is_autocast_cpu_enabled(),
212
"dtype": torch.get_autocast_cpu_dtype(),
213
"cache_enabled": torch.is_autocast_cache_enabled(),
216
return device_autocast_kwargs, cpu_autocast_kwargs
218
def _supports_autocast(device):
    """Return whether autocast state can be captured for ``device``.

    CUDA always supports autocast; any other device type qualifies only if
    its device module exposes the autocast query API.
    """
    mod = _get_device_module(device)
    has_autocast_api = hasattr(mod, "is_autocast_enabled") and hasattr(
        mod, "get_autocast_dtype"
    )
    return device == "cuda" or has_autocast_api
223
class CheckpointFunction(torch.autograd.Function):
225
def forward(ctx, run_function, preserve_rng_state, *args):
226
check_backward_validity(args)
227
ctx.run_function = run_function
228
ctx.preserve_rng_state = preserve_rng_state
230
ctx.device = _infer_device_type(*args)
231
ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
234
if preserve_rng_state:
235
ctx.fwd_cpu_state = torch.get_rng_state()
240
ctx.had_device_in_fwd = False
241
device_module = _get_device_module(ctx.device)
242
if getattr(device_module, "_initialized", False):
243
ctx.had_device_in_fwd = True
244
ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
249
ctx.tensor_indices = []
251
for i, arg in enumerate(args):
252
if torch.is_tensor(arg):
253
tensor_inputs.append(arg)
254
ctx.tensor_indices.append(i)
255
ctx.inputs.append(None)
257
ctx.inputs.append(arg)
259
ctx.save_for_backward(*tensor_inputs)
261
with torch.no_grad():
262
outputs = run_function(*args)
266
def backward(ctx, *args):
267
if not torch.autograd._is_checkpoint_valid():
269
"Checkpointing is not compatible with .grad() or when an `inputs` parameter"
270
" is passed to .backward(). Please use .backward() and do not pass its `inputs`"
274
inputs = list(ctx.inputs)
275
tensor_indices = ctx.tensor_indices
276
tensors = ctx.saved_tensors
277
device_module = _get_device_module(ctx.device)
280
for i, idx in enumerate(tensor_indices):
281
inputs[idx] = tensors[i]
287
if ctx.preserve_rng_state and ctx.had_device_in_fwd:
288
rng_devices = ctx.fwd_devices
289
with torch.random.fork_rng(
290
devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device
292
if ctx.preserve_rng_state:
293
torch.set_rng_state(ctx.fwd_cpu_state)
294
if ctx.had_device_in_fwd:
295
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
296
detached_inputs = detach_variable(tuple(inputs))
298
device_autocast_ctx = device_module.amp.autocast(
299
**ctx.device_autocast_kwargs
300
) if _supports_autocast(ctx.device) else contextlib.nullcontext()
301
with torch.enable_grad(), device_autocast_ctx, \
302
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
303
outputs = ctx.run_function(*detached_inputs)
305
if isinstance(outputs, torch.Tensor):
309
outputs_with_grad = []
311
for i in range(len(outputs)):
312
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
313
outputs_with_grad.append(outputs[i])
314
args_with_grad.append(args[i])
315
if len(outputs_with_grad) == 0:
317
"none of output has requires_grad=True,"
318
" this checkpoint() is not necessary"
320
torch.autograd.backward(outputs_with_grad, args_with_grad)
322
inp.grad if isinstance(inp, torch.Tensor) else None
323
for inp in detached_inputs
326
return (None, None) + grads
329
def noop_context_fn():
    """Default ``context_fn`` for checkpoint: a pair of no-op context managers.

    The first manager wraps the original forward, the second wraps the
    recomputation; both do nothing.
    """
    forward_ctx = contextlib.nullcontext()
    recompute_ctx = contextlib.nullcontext()
    return forward_ctx, recompute_ctx
342
@torch._disable_dynamo
346
use_reentrant: Optional[bool] = None,
347
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
348
determinism_check: str = _DEFAULT_DETERMINISM_MODE,
352
r"""Checkpoint a model or part of the model.
354
Activation checkpointing is a technique that trades compute for memory.
355
Instead of keeping tensors needed for backward alive until they are used in
356
gradient computation during backward, forward computation in checkpointed
357
regions omits saving tensors for backward and recomputes them during the
358
backward pass. Activation checkpointing can be applied to any part of a
361
There are currently two checkpointing implementations available, determined
362
by the :attr:`use_reentrant` parameter. It is recommended that you use
363
``use_reentrant=False``. Please refer the note below for a discussion of
368
If the :attr:`function` invocation during the backward pass differs
369
from the forward pass, e.g., due to a global variable, the checkpointed
370
version may not be equivalent, potentially causing an
371
error being raised or leading to silently incorrect gradients.
375
The ``use_reentrant`` parameter should be passed explicitly. In version
376
2.4 we will raise an exception if ``use_reentrant`` is not passed.
377
If you are using the ``use_reentrant=True`` variant, please refer to the
378
note below for important considerations and potential limitations.
382
The reentrant variant of checkpoint (``use_reentrant=True``) and
383
the non-reentrant variant of checkpoint (``use_reentrant=False``)
384
differ in the following ways:
386
* Non-reentrant checkpoint stops recomputation as soon as all needed
387
intermediate activations have been recomputed. This feature is enabled
388
by default, but can be disabled with :func:`set_checkpoint_early_stop`.
389
Reentrant checkpoint always recomputes :attr:`function` in its
390
entirety during the backward pass.
392
* The reentrant variant does not record the autograd graph during the
393
forward pass, as it runs with the forward pass under
394
:func:`torch.no_grad`. The non-reentrant version does record the
395
autograd graph, allowing one to perform backward on the graph within
396
checkpointed regions.
398
* The reentrant checkpoint only supports the
399
:func:`torch.autograd.backward` API for the backward pass without its
400
`inputs` argument, while the non-reentrant version supports all ways
401
of performing the backward pass.
403
* At least one input and output must have ``requires_grad=True`` for the
404
reentrant variant. If this condition is unmet, the checkpointed part
405
of the model will not have gradients. The non-reentrant version does
406
not have this requirement.
408
* The reentrant version does not consider tensors in nested structures
409
(e.g., custom objects, lists, dicts, etc) as participating in
410
autograd, while the non-reentrant version does.
412
* The reentrant checkpoint does not support checkpointed regions with
413
detached tensors from the computational graph, whereas the
414
non-reentrant version does. For the reentrant variant, if the
415
checkpointed segment contains tensors detached using ``detach()`` or
416
with :func:`torch.no_grad`, the backward pass will raise an error.
417
This is because ``checkpoint`` makes all the outputs require gradients
418
and this causes issues when a tensor is defined to have no gradient in
419
the model. To avoid this, detach the tensors outside of the
420
``checkpoint`` function.
423
function: describes what to run in the forward pass of the model or
424
part of the model. It should also know how to handle the inputs
425
passed as the tuple. For example, in LSTM, if user passes
426
``(activation, hidden)``, :attr:`function` should correctly use the
427
first input as ``activation`` and the second input as ``hidden``
428
preserve_rng_state(bool, optional): Omit stashing and restoring
429
the RNG state during each checkpoint. Note that under torch.compile,
430
this flag doesn't take effect and we always preserve RNG state.
433
specify whether to use the activation checkpoint variant that
434
requires reentrant autograd. This parameter should be passed
435
explicitly. In version 2.4 we will raise an exception if
436
``use_reentrant`` is not passed. If ``use_reentrant=False``,
437
``checkpoint`` will use an implementation that does not require
438
reentrant autograd. This allows ``checkpoint`` to support additional
439
functionality, such as working as expected with
440
``torch.autograd.grad`` and support for keyword arguments input into
441
the checkpointed function.
442
context_fn(Callable, optional): A callable returning a tuple of two
443
context managers. The function and its recomputation will be run
444
under the first and second context managers respectively.
445
This argument is only supported if ``use_reentrant=False``.
446
determinism_check(str, optional): A string specifying the determinism
447
check to perform. By default it is set to ``"default"`` which
448
compares the shapes, dtypes, and devices of the recomputed tensors
449
against those the saved tensors. To turn off this check, specify
450
``"none"``. Currently these are the only two supported values.
451
Please open an issue if you would like to see more determinism
452
checks. This argument is only supported if ``use_reentrant=False``,
453
if ``use_reentrant=True``, the determinism check is always disabled.
454
debug(bool, optional): If ``True``, error messages will also include
455
a trace of the operators ran during the original forward computation
456
as well as the recomputation. This argument is only supported if
457
``use_reentrant=False``.
458
args: tuple containing inputs to the :attr:`function`
461
Output of running :attr:`function` on :attr:`*args`
463
if use_reentrant is None:
465
"torch.utils.checkpoint: the use_reentrant parameter should be "
466
"passed explicitly. In version 2.4 we will raise an exception "
467
"if use_reentrant is not passed. use_reentrant=False is "
468
"recommended, but if you need to preserve the current default "
469
"behavior, you can pass use_reentrant=True. Refer to docs for more "
470
"details on the differences between the two variants."
475
preserve = kwargs.pop("preserve_rng_state", True)
476
if kwargs and use_reentrant:
478
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
482
if context_fn is not noop_context_fn or debug is not False:
484
"Passing `context_fn` or `debug` is only supported when "
485
"use_reentrant=False."
487
return CheckpointFunction.apply(function, preserve, *args)
489
gen = _checkpoint_without_reentrant_generator(
490
function, preserve, context_fn, determinism_check, debug, *args, **kwargs
494
ret = function(*args, **kwargs)
498
except StopIteration:
502
def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs):
503
r"""Checkpoint a sequential model to save memory.
505
Sequential models execute a list of modules/functions in order
506
(sequentially). Therefore, we can divide such a model in various segments
507
and checkpoint each segment. All segments except the last will not store
508
the intermediate activations. The inputs of each checkpointed segment will
509
be saved for re-running the segment in the backward pass.
512
The ``use_reentrant`` parameter should be passed explicitly. In version
513
2.4 we will raise an exception if ``use_reentrant`` is not passed.
514
If you are using the ``use_reentrant=True` variant, please see
515
:func:`~torch.utils.checkpoint.checkpoint` for
516
the important considerations and limitations of this variant. It is
517
recommended that you use ``use_reentrant=False``.
520
Since PyTorch 1.4, it allows only one Tensor as the input and
521
intermediate outputs, just like :class:`torch.nn.Sequential`.
524
functions: A :class:`torch.nn.Sequential` or the list of modules or
525
functions (comprising the model) to run sequentially.
526
segments: Number of chunks to create in the model
527
input: A Tensor that is input to :attr:`functions`
528
preserve_rng_state(bool, optional): Omit stashing and restoring
529
the RNG state during each checkpoint.
532
specify whether to use the activation checkpoint variant that
533
requires reentrant autograd. This parameter should be passed
534
explicitly. In version 2.4 we will raise an exception if
535
``use_reentrant`` is not passed. If ``use_reentrant=False``,
536
``checkpoint`` will use an implementation that does not require
537
reentrant autograd. This allows ``checkpoint`` to support additional
538
functionality, such as working as expected with
539
``torch.autograd.grad`` and support for keyword arguments input into
540
the checkpointed function.
543
Output of running :attr:`functions` sequentially on :attr:`*inputs`
546
>>> # xdoctest: +SKIP("stub")
547
>>> model = nn.Sequential(...)
548
>>> input_var = checkpoint_sequential(model, chunks, input_var)
550
if use_reentrant is None:
552
"torch.utils.checkpoint.checkpoint_sequential: the use_reentrant "
553
"parameter should be passed explicitly. "
554
"In version 2.4 we will raise an exception if use_reentrant "
555
"is not passed. use_reentrant=False is "
556
"recommended, but if you need to preserve the current default "
557
"behavior, you can pass use_reentrant=True. Refer to docs for more "
558
"details on the differences between the two variants."
563
preserve = kwargs.pop("preserve_rng_state", True)
566
"Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
569
def run_function(start, end, functions):
571
for j in range(start, end + 1):
572
input = functions[j](input)
577
if isinstance(functions, torch.nn.Sequential):
578
functions = list(functions.children())
580
segment_size = len(functions) // segments
583
for start in range(0, segment_size * (segments - 1), segment_size):
584
end = start + segment_size - 1
586
run_function(start, end, functions),
588
use_reentrant=use_reentrant,
589
preserve_rng_state=preserve,
591
return run_function(end + 1, len(functions) - 1, functions)(input)
594
def _internal_assert(cond):
596
raise AssertionError(
597
"Something went unexpectedly wrong in activation checkpoint. "
598
"Please report this bug by filing an issue to PyTorch."
733
_enable_checkpoint_early_stop = True
736
@contextlib.contextmanager
737
def set_checkpoint_early_stop(enable: bool):
738
"""Context manager that sets whether checkpoint should stop recomputation early.
740
By default, non-reentrant checkpoint stops recomputation as soon as it
741
has computed all needed Tensors. This context manager can be used to disable
742
that feature if it is problematic for your specific application.
744
This context manager only needs to be active when forward is run. It does
745
not need to be active during backward.
749
>>> # xdoctest: +SKIP(failing)
750
>>> message = "saved tensors default hooks are disabled"
751
>>> with set_checkpoint_early_stop(False):
752
... # Any checkpoint under this context manager will respect this
753
... # context manager, even if its backward is performed outside.
754
... out = checkpoint(fn, inputs)
758
global _enable_checkpoint_early_stop
760
prev = _enable_checkpoint_early_stop
761
_enable_checkpoint_early_stop = enable
764
_enable_checkpoint_early_stop = prev
773
self.handles: Dict[int, Optional[_Handle]] = dict()
776
class _NoopSaveInputs(torch.autograd.Function):
779
return torch.empty((0,))
782
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
785
tensor_indices, tensors = zip(
786
*[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
788
idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
790
args = [None if isinstance(o, torch.Tensor) else o for o in inputs]
792
def get_args(saved_tensors):
798
saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
799
for i, o in enumerate(args)
805
ctx.get_args = get_args
806
ctx.save_for_backward(*tensors)
809
def backward(ctx, *grad_outputs):
810
raise AssertionError("Did not expect to backward on this graph")
813
class _CheckpointFrame:
814
def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
815
self.recompute_fn = recompute_fn
816
self.input_saver = None
817
self.weak_holders: List[ReferenceType] = []
821
self.recomputed: DefaultDict[
822
int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
823
] = defaultdict(weakref.WeakKeyDictionary)
826
self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
827
self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)
830
self.early_stop = early_stop
833
self.metadata_fn = metadata_fn
834
self.unpack_error_cb = unpack_error_cb
835
self.x_metadatas = []
836
self.forward_completed = False
837
self.ignore_saved_mismatch = False
839
def check_recomputed_tensors_match(self, gid):
840
if self.ignore_saved_mismatch:
857
if not len(self.weak_holders) == self.recomp_counter[gid]:
863
raise CheckpointError(
864
"torch.utils.checkpoint: A different number of tensors was saved "
865
"during the original forward and recomputation.\n"
866
f"Number of tensors saved during forward: {len(self.weak_holders)}\n"
867
f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}"
872
nb_meta_different = []
873
for idx, weak_holder in enumerate(self.weak_holders):
874
holder = weak_holder()
881
_internal_assert(gid in holder.handles)
884
_internal_assert(holder.handles[gid] is not None)
886
_internal_assert(holder.handles[gid] in self.recomputed[gid])
888
x_meta = self.x_metadatas[idx]
889
recomputed_x = self.recomputed[gid][holder.handles[gid]]
890
if x_meta != self.metadata_fn(recomputed_x):
891
nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x)))
893
if len(nb_meta_different) > 0:
894
mismatched_tensors = ""
895
for idx, x_meta, recomputed_meta in nb_meta_different:
896
mismatched_tensors += (
897
f"tensor at position {idx}:\n"
898
f"saved metadata: {x_meta}\n"
899
f"recomputed metadata: {recomputed_meta}\n"
901
raise CheckpointError(
902
"torch.utils.checkpoint: Recomputed values for the following tensors "
903
"have different metadata than during the forward pass.\n"
904
f"{mismatched_tensors}"
908
_checkpoint_error_template = """ \
909
An error happened while unpacking tensors; dumping logs of latest computation
910
because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
911
Scroll all the way down for guidance on how to navigate these logs.
913
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
914
| 1. Stack traces of the operators that ran in the original forward |
915
+------------------------------------------------------------------------------+
918
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
919
| 2. Stack traces of the operators that ran during recomputation |
920
+------------------------------------------------------------------------------+
923
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
924
| 3. Log of operators in the original forward and recomputation |
925
+------------------------------------------------------------------------------+
926
(Scroll up to correlate stack traces with each operation listed below. This
927
helps identify their source in the code.)
929
IMPORTANT: Differences in "detach" calls between the original forward and the
930
recomputation are expected. They are introduced by the checkpointing
931
mechanism and can be ignored.
933
Operations executed during the original forward:
937
Operations executed during recomputation:
941
+------------------------------------------------------------------------------+
942
ERROR: Detected non-determinism while running activation checkpointing
944
You are seeing this error because you passed `debug=True` to checkpoint and
945
tensors to be saved during the original forward and differ between those saved
946
during recomputation. This can happen if different operators were ran in the
947
original forward and in the recomputation.
949
To identify where the mismatch may be coming from, you can do the following:
951
1) Compare the operators ran during original forward and recomputation to
952
see where they differ. These operators are printed above in the order they
955
2) Review the stack trace for each operator to locate its invocation source.
956
Each operator's stack trace is printed in their execution order.
958
Note that the logs can be quite long. Here's how they are structured:
959
(Tip: you can Ctrl-f for these headers)
961
1. Stack traces of the operators that ran in the original forward
962
2. Stack traces of the operators that ran during recomputation
963
3. Log of operators in the original forward and recomputation
964
4. Error message <--- You are here
965
--------------------------------------------------------------------------------
968
class CheckpointError(RuntimeError):
972
def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]:
978
cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux'
985
def get_context_manager(self):
986
@contextlib.contextmanager
988
with LoggingTensorMode(), \
989
capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
990
self.logs, self.tbs = logs_and_tb
992
return logging_mode()
994
capture_logs_fwd = CaptureLogs()
995
capture_logs_recompute = CaptureLogs()
997
def unpack_error_cb(e: CheckpointError):
998
def get_str_tb(label, capture_logs):
1000
total_len = len(capture_logs.logs)
1001
for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)):
1002
out += f"{log} ({i + 1} of {total_len} in {label})\n\n"
1003
found_torch_dispatch = False
1006
is_torch_dispatch = line['name'] == '__torch_dispatch__'
1007
if not found_torch_dispatch and not is_torch_dispatch:
1009
elif is_torch_dispatch:
1010
found_torch_dispatch = True
1012
out += f"{line['filename']}:{line['line']}:{line['name']}\n"
1015
assert capture_logs_fwd.logs is not None
1016
assert capture_logs_recompute.logs is not None
1017
raise CheckpointError(
1018
_checkpoint_error_template.format(
1019
forward_traces=get_str_tb("original", capture_logs_fwd),
1020
recompute_traces=get_str_tb("recompute", capture_logs_recompute),
1021
forward_ops="\n".join(capture_logs_fwd.logs),
1022
recompute_ops="\n".join(capture_logs_recompute.logs)
1027
return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager()
1029
return context_fn, unpack_error_cb
1031
def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]:
1039
_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = {
1040
_DEFAULT_DETERMINISM_MODE: _default_meta_extractor,
1041
"none": lambda _: None,
1045
class _StopRecomputationError(Exception):
1049
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
1050
def __init__(self, target_frame_ref: ReferenceType, gid: int):
1052
target_frame = target_frame_ref()
1053
assert target_frame is not None
1054
recomp_idx = target_frame.recomp_counter[gid]
1055
target_frame.recomp_counter[gid] += 1
1057
if recomp_idx >= len(target_frame.weak_holders):
1058
assert not target_frame.early_stop
1059
if not target_frame.forward_completed:
1065
target_frame.ignore_saved_mismatch = True
1067
raise CheckpointError(
1068
"torch.utils.checkpoint: trying to save more tensors during "
1069
"recomputation than during the original forward pass."
1072
holder = target_frame.weak_holders[recomp_idx]()
1076
if holder is not None:
1077
_internal_assert(holder.handles.get(gid, None) is None)
1078
holder.handles[gid] = _Handle()
1079
target_frame.recomputed[gid][holder.handles[gid]] = x.detach()
1081
if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
1082
target_frame.weak_holders
1084
raise _StopRecomputationError()
1093
super().__init__(pack_hook, unpack_hook)
1096
class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
1097
def __init__(self, frame):
1101
frame.weak_holders.append(weakref.ref(holder))
1103
if frame.metadata_fn is not None:
1104
with torch.no_grad():
1105
frame.x_metadatas.append(frame.metadata_fn(x))
1108
def unpack_hook(holder):
1109
gid = torch._C._current_graph_task_id()
1112
gid = int(uuid.uuid4())
1114
if not frame.is_recomputed[gid]:
1115
ctx = frame.input_saver.grad_fn
1116
args = ctx.get_args(ctx.saved_tensors)
1119
with _recomputation_hook(
1120
weakref.ref(frame), gid
1121
), torch.autograd.enable_grad():
1122
frame.recompute_fn(*args)
1123
except _StopRecomputationError:
1125
frame.is_recomputed[gid] = True
1126
frame.check_recomputed_tensors_match(gid)
1128
_internal_assert(gid in holder.handles)
1130
if holder.handles[gid] is None:
1131
raise CheckpointError(
1132
"torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
1133
"unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
1134
"so only once. Otherwise please open an issue with details on your use case."
1136
_internal_assert(holder.handles[gid] in frame.recomputed[gid])
1137
ret = frame.recomputed[gid][holder.handles[gid]]
1138
holder.handles[gid] = None
1141
if frame.unpack_error_cb is not None:
1142
def unpack_hook_with_error_cb(holder):
1144
return unpack_hook(holder)
1145
except CheckpointError as e:
1146
frame.unpack_error_cb(e)
1147
super().__init__(pack_hook, unpack_hook_with_error_cb)
1149
super().__init__(pack_hook, unpack_hook)
1152
def _is_compiling(func, args, kwargs):
1157
if isinstance(arg, torch.Tensor) and is_fun(arg):
1163
if isinstance(x, torch.Tensor):
1175
torch.ops.prim.device.default,
1176
torch.ops.aten.detach.default,
1177
} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
1180
class _CachingTorchDispatchMode(TorchDispatchMode):
1182
A :class:`TorchDispatchMode` to implement selective activation checkpointing
1183
that's compatible with torch.compile. Used together with _CachedTorchDispatchMode.
1185
def __init__(self, policy_fn, storage):
1186
self.policy_fn = policy_fn
1187
self.storage = storage
1189
def push_into_storage(self, out, func, args, kwargs):
1190
out_detached = tree_map(_detach, out)
1191
self.storage[func].append(out_detached)
1193
def _handle_compile_in_forward_ctx(self, should_not_recompute, func, args, kwargs):
1194
if func in _ignored_ops:
1195
return func(*args, **kwargs)
1196
if should_not_recompute:
1197
fx_traceback.current_meta["recompute"] = 0
1200
out = func(*args, **kwargs)
1201
self.push_into_storage(out, func, args, kwargs)
1204
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1207
should_not_recompute = self.policy_fn("forward", func, *args, **kwargs)
1208
if _is_compiling(func, args, kwargs):
1209
return self._handle_compile_in_forward_ctx(should_not_recompute, func, args, kwargs)
1211
if should_not_recompute:
1212
out = func(*args, **kwargs)
1213
self.push_into_storage(out, func, args, kwargs)
1215
out = func(*args, **kwargs)
1219
class _CachedTorchDispatchMode(TorchDispatchMode):
1221
A :class:`TorchDispatchMode` to implement selective activation checkpointing
1222
that's compatible with torch.compile. Used together with _CachingTorchDispatchMode.
1224
def __init__(self, policy_fn, storage):
1225
self.policy_fn = policy_fn
1226
self.storage = storage
1228
def pop_from_storage(self, func, args, kwargs):
1229
assert func in self.storage
1230
out = self.storage[func].pop(0)
1233
def _handle_compile_in_recompute_ctx(self, should_not_recompute, func, args, kwargs):
1234
if func in _ignored_ops:
1235
return func(*args, **kwargs)
1236
out = self.pop_from_storage(func, args, kwargs)
1239
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1242
should_not_recompute = self.policy_fn("recompute", func, *args, **kwargs)
1243
if _is_compiling(func, args, kwargs):
1244
return self._handle_compile_in_recompute_ctx(should_not_recompute, func, args, kwargs)
1246
if should_not_recompute:
1247
out = self.pop_from_storage(func, args, kwargs)
1249
out = func(*args, **kwargs)
1253
def _pt2_selective_checkpoint_context_fn_gen(policy_fn):
    """
    A helper function that generates a pair of contexts to be later passed into
    `torch.utils.checkpoint` API to implement selective checkpointing.

    .. warning::
        This context_fn is intended for use with torch.compile only.

    Args:
        policy_fn (Callable[[Callable, List[Any], Dict[str, Any]], bool]): Policy function
            to decide whether a particular op should be recomputed in backward pass or not.
            In eager mode:
                If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed.
                If policy_fn(...) returns False, the op is guaranteed to be recomputed.
            In torch.compile mode:
                If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed.
                If policy_fn(...) returns False, the op may or may not be recomputed
                (it's up to the partitioner to decide).

    Returns:
        A pair of generated contexts.

    Example:
        >>> # xdoctest: +REQUIRES(LINUX)
        >>>
        >>> def get_custom_policy():
        >>>     no_recompute_list = [
        >>>         torch.ops.aten.mm.default,
        >>>     ]
        >>>     def custom_policy(mode, func, *args, **kwargs):
        >>>         return func in no_recompute_list
        >>>     return custom_policy
        >>>
        >>> def selective_checkpointing_context_fn():
        >>>     return _pt2_selective_checkpoint_context_fn_gen(get_custom_policy())
        >>>
        >>> def gn(x, y):
        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
        >>>
        >>> def fn(x, y):
        >>>     return torch.utils.checkpoint.checkpoint(
        >>>         gn, x, y,
        >>>         use_reentrant=False,
        >>>         context_fn=selective_checkpointing_context_fn,
        >>>     )
        >>>
        >>> x = torch.randn(4, 4, requires_grad=True)
        >>> y = torch.randn(4, 4, requires_grad=True)
        >>>
        >>> compiled_fn = torch.compile(fn)
    """
    # One shared FIFO store: the forward (caching) context fills it, the
    # recompute (cached) context drains it.
    storage: Dict[Any, List[Any]] = defaultdict(list)
    return _CachingTorchDispatchMode(policy_fn, storage), _CachedTorchDispatchMode(policy_fn, storage)
def _checkpoint_without_reentrant_generator(
1313
preserve_rng_state=True,
1314
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
1315
determinism_check: str = _DEFAULT_DETERMINISM_MODE,
1316
debug: bool = False,
1320
"""Checkpointing without reentrant autograd.
1323
function: describes what to run in the forward pass of the model or
1324
part of the model. It should also know how to handle the inputs
1325
passed as the tuple. For example, in LSTM, if user passes
1326
``(activation, hidden)``, :attr:`function` should correctly use the
1327
first input as ``activation`` and the second input as ``hidden``
1328
preserve_rng_state(bool, optional): Omit stashing and restoring
1329
the RNG state during each checkpoint.
1331
context_fn(Callable, optional): A callable returning a tuple of two
1332
context managers. The function and its recomputation will be run
1333
under the first and second context managers respectively.
1334
determinism_check(str, optional): A string specifying the determinism
1335
check to perform. By default it is set to ``"default"`` which
1336
compares the shapes, dtypes, and devices of the recomputed tensors
1337
against those the saved tensors. To turn off this check, specify
1338
``"none"``. Currently these are the only two supported values.
1339
Please open an issue if you would like to see more determinism
1341
debug(bool, optional): If ``True``, error messages will also include
1342
a trace of the operators ran during the original forward computation
1343
as well as the recomputation.
1344
*args: Arguments to pass in to the given ``function``.
1345
**kwargs: Keyword arguments to pass into the given ``function``.
1347
unpack_error_cb = None
1349
if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
1350
if context_fn != noop_context_fn:
1352
"debug=True is incompatible with non-default context_fn"
1354
context_fn, unpack_error_cb = _get_debug_context_and_cb()
1356
if determinism_check in _allowed_determinism_checks_to_fns:
1357
metadata_fn = _allowed_determinism_checks_to_fns[determinism_check]
1360
f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, "
1361
f"but got {determinism_check}"
1364
device = _infer_device_type(*args)
1365
device_module = _get_device_module(device)
1366
forward_context, recompute_context = context_fn()
1367
if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
1369
isinstance(forward_context, TorchDispatchMode) and
1370
isinstance(recompute_context, TorchDispatchMode)
1372
"In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
1373
"must generate a tuple of two `TorchDispatchMode`s."
1375
device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device=device)
1377
if preserve_rng_state:
1378
fwd_cpu_state = torch.get_rng_state()
1384
had_device_in_fwd = False
1385
if getattr(device_module, "_initialized", False):
1386
had_device_in_fwd = True
1387
fwd_devices, fwd_device_states = get_device_states(*args)
1389
def recompute_fn(*inputs):
1390
kwargs, *args = inputs
1394
if preserve_rng_state and had_device_in_fwd:
1395
rng_devices = fwd_devices
1396
with torch.random.fork_rng(
1397
devices=rng_devices, enabled=preserve_rng_state, device_type=device
1399
if preserve_rng_state:
1400
torch.set_rng_state(fwd_cpu_state)
1401
if had_device_in_fwd:
1402
set_device_states(fwd_devices, fwd_device_states)
1404
device_autocast_ctx = device_module.amp.autocast(
1405
**device_autocast_kwargs
1406
) if _supports_autocast(device) else contextlib.nullcontext()
1407
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
1411
new_frame = _CheckpointFrame(
1413
_enable_checkpoint_early_stop,
1417
dummy = torch.empty((0,), requires_grad=True)
1418
new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)
1421
if new_frame.input_saver.grad_fn is None:
1425
with _checkpoint_hook(new_frame), forward_context:
1427
new_frame.forward_completed = True
1429
if getattr(device_module, "_initialized", False) and \
1430
preserve_rng_state and not had_device_in_fwd:
1434
"PyTorch's device state was initialized in the forward pass "
1435
"of a Checkpoint, which is not allowed. Please open an issue "
1436
"if you need this feature."