# mypy: allow-untyped-defs
import inspect
import itertools
import logging
from typing import Optional

from torch._logging import warning_once
from torch._ops import HigherOrderOperator
from torch.types import _dtype
from torch.utils.checkpoint import checkpoint, CheckpointPolicy


log = logging.getLogger(__name__)

uid = itertools.count(1)


# Used for testing the HigherOrderOperator mechanism
class Wrap(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("wrap")

    def __call__(self, func, *args, **kwargs):
        # Dynamo already traces the body of the HigherOrderOp beforehand when
        # it encounters this op, so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            result = func(*args, **kwargs)
            return result

        return wrapper()


wrap = Wrap()
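

# Minimal usage sketch (illustrative only; this helper is not used anywhere in
# the module): when called eagerly, `wrap` simply invokes the wrapped callable,
# so the result matches a direct call. The op only becomes meaningful once
# Dynamo captures it into a graph.
def _example_wrap_usage():
    import torch

    x = torch.randn(3)
    # Equivalent to torch.sin(x) outside of torch.compile.
    return wrap(torch.sin, x)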


class WrapWithSetGradEnabled(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("wrap_with_set_grad_enabled")

    def __call__(self, enable_grad, wrapped_func, *args, **kwargs):
        # Dynamo already traces the body of the HigherOrderOp beforehand when
        # it encounters this op, so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            with torch.set_grad_enabled(enable_grad):
                return wrapped_func(*args, **kwargs)

        return wrapper()


wrap_with_set_grad_enabled = WrapWithSetGradEnabled()
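

# Illustrative sketch (assumed example, not exercised in this module): the op
# runs `wrapped_func` under `torch.set_grad_enabled(enable_grad)`, so passing
# False gives a no-grad region.
def _example_wrap_with_set_grad_enabled_usage():
    import torch

    x = torch.randn(3, requires_grad=True)
    out = wrap_with_set_grad_enabled(False, torch.sin, x)
    # `out` does not require grad because the call ran with grad disabled.
    return out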


class WrapWithAutocast(HigherOrderOperator):
    def __init__(self):
        super().__init__("wrap_with_autocast")

    def __call__(
        self,
        device_type: str,
        dtype: Optional[_dtype],
        enabled: bool,
        cache_enabled: Optional[bool],
        wrapped_func,
        *args,
        **kwargs,
    ):
        # Dynamo already traces the body of the HigherOrderOp beforehand when
        # it encounters this op, so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            with torch.autocast(device_type, dtype, enabled, cache_enabled):
                return wrapped_func(*args, **kwargs)

        return wrapper()


wrap_with_autocast = WrapWithAutocast()
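

# Illustrative sketch (assumed example, not part of the original module): the
# wrapped call runs inside `torch.autocast(device_type, dtype, enabled,
# cache_enabled)`, so autocast-eligible ops typically run in the lower-precision
# dtype.
def _example_wrap_with_autocast_usage():
    import torch

    a = torch.randn(4, 4)
    b = torch.randn(4, 4)
    # CPU autocast to bfloat16; matmul is autocast-eligible.
    return wrap_with_autocast("cpu", torch.bfloat16, True, True, torch.matmul, a, b)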


class WrapActivationCheckpoint(HigherOrderOperator):
    """
    This operator is used to wrap torch.utils.checkpoint. It prevents
    TorchDynamo from looking into saved tensor hooks and directly passes the
    control to AOT Autograd, which is ok with tracing saved tensor hooks. As a
    result of AOT tracing the torch.utils.checkpoint code, we get a backward
    graph with recomputed forward nodes.

    However, we might deprecate this operator soon. The difficulty arises in
    the functionalization of rng ops. Today, there are two different
    functionalizations of rng ops - one at AOT Autograd and the other at
    Inductor. And they are difficult to map to each other. The rng states also
    complicate pattern matching in Inductor. Due to the ease of implementation,
    we are currently inclined towards functionalization at the Inductor level,
    which means that duplication/recomputation is done as a compiler pass in
    the partitioners. See TagActivationCheckpoint for more information.
    """

    def __init__(self) -> None:
        super().__init__("wrap_activation_checkpoint")

    def __call__(self, function, *args, **kwargs):
        # use_reentrant is set to False because this op is going to be traced,
        # and we ensure that AOT Autograd traces through the non-reentrant
        # version of checkpointing.
        import torch.fx.traceback as fx_traceback
        from torch.fx import Interpreter

        kwargs["use_reentrant"] = False
        kwargs["preserve_rng_state"] = False
        # Using the interpreter allows preservation of metadata through the torch.compile stack.
        with fx_traceback.preserve_node_meta():
            return checkpoint(Interpreter(function).run, *args, **kwargs)


wrap_activation_checkpoint = WrapActivationCheckpoint()
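

# Illustrative sketch (assumed example; `f` is a made-up function): `function`
# is expected to be an FX graph module, which is run through an Interpreter
# under non-reentrant checkpointing.
def _example_wrap_activation_checkpoint_usage():
    import torch
    from torch.fx import symbolic_trace

    def f(x):
        return torch.sin(x).cos()

    gm = symbolic_trace(f)
    x = torch.randn(3, requires_grad=True)
    # Forward runs normally; the intermediate activations are recomputed in backward.
    return wrap_activation_checkpoint(gm, x)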


class TagActivationCheckpoint(HigherOrderOperator):
    """
    This operator is supposed to be used only with the torch.compile stack. It
    accepts an FX graph module that needs to be checkpointed. This operator
    adds a "recomputable" tag to the nodes of the FX graph that should be
    recomputed.

    The goal is to:
    1. Avoid using Dynamo to trace through saved tensor hooks.
    2. For the selective checkpointing case, let AOTAutograd trace through
       saved tensor hooks, but with special logic in a TorchDispatchMode to
       override the usual saved_tensor_hooks fn logic in order to tag the nodes.
    3. Rely on the partitioners to actually duplicate the nodes.
    This sits well in the torch.compile stack, because by the time the graph
    reaches the partitioner, Inductor has already run its functionalization of
    rng ops (by setting a fixed seed for each random op, see
    `replace_random_passes`). Therefore, the duplication of nodes, by design,
    respects the rng states in the forward and the recomputed forward in the
    backward.
    """

    def __init__(self) -> None:
        super().__init__("tag_activation_checkpoint")

    @staticmethod
    def divide_kwargs(kwargs):
        """
        The checkpoint fn can receive kwargs that are mixed between the
        checkpointed fn and the checkpoint fn itself. For example

        >> def gn(x, y, z=None):
        >>     a = torch.matmul(x, y)
        >>     if z is not None:
        >>         return torch.matmul(a, z)
        >>     return a

        >> def fn(x, y, z):
        >>     return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))

        In the above case, z belongs to the checkpointed function gn, while
        use_reentrant belongs to the checkpoint function itself. This function
        splits the kwargs into checkpoint_kwargs and gmod_kwargs (i.e.
        checkpointed_fn_kwargs).
        We do sorting to ensure the same graph from run to run for better
        debuggability. It is not required for correctness.
        """
        ckpt_signature = inspect.signature(checkpoint)
        checkpoint_keys = set()
        for name in ckpt_signature.parameters:
            if name in ("function", "args", "kwargs"):
                continue
            checkpoint_keys.add(name)

        # `preserve_rng_state` is not a regular kwarg
        checkpoint_keys.add("preserve_rng_state")

        checkpoint_kwargs = {
            name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys
        }
        gmod_kwargs = {
            name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys
        }
        return checkpoint_kwargs, gmod_kwargs
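
    # Illustrative sketch of the split (assumed values, mirroring the docstring
    # example above): kwargs accepted by torch.utils.checkpoint go into the
    # first dict, everything else is forwarded to the checkpointed callable:
    #
    #   >> TagActivationCheckpoint.divide_kwargs(
    #   >>     {"use_reentrant": False, "z": some_tensor}
    #   >> )
    #   >> # -> ({"use_reentrant": False}, {"z": some_tensor})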

    def tag_nodes(self, gmod, is_sac):
        unique_graph_id = next(uid)
        for node in gmod.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                node.meta["ac_graph_id"] = unique_graph_id
                if is_sac:
                    # For selective checkpointing, we will populate this tag later in _CachingTorchDispatchMode.
                    node.meta["recompute"] = None
                else:
                    # Under vanilla activation checkpointing, all nodes should be recomputed.
                    node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
        return gmod

    def __call__(self, gmod, *args, **kwargs):
        import torch.fx.traceback as fx_traceback
        from torch.fx import Interpreter

        if "_checkpoint_context_fn" in gmod.meta:
            warning_once(
                log,
                """
Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
""",
            )
            # use_reentrant is set to False because this op is going to be traced,
            # and we ensure that AOT Autograd traces through the non-reentrant
            # version of checkpointing.
            kwargs["use_reentrant"] = False
            # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through
            # the `torch.random.fork_rng` op (which is not supported yet under CUDA).
            # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state
            # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor
            # instead of in AOTAutograd).
            kwargs["preserve_rng_state"] = False
            kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"]
            # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag
            # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py.
            gmod = self.tag_nodes(gmod, is_sac=True)
            # Using the interpreter allows preservation of metadata through the torch.compile stack.
            with fx_traceback.preserve_node_meta():
                return checkpoint(Interpreter(gmod).run, *args, **kwargs)
        else:
            gmod = self.tag_nodes(gmod, is_sac=False)
            # Using the interpreter allows preservation of metadata through the torch.compile stack.
            # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here
            # as in the `context_fn != None` case, but that depends on in-place op support in
            # TorchDispatchMode + torch.compile. (For details on the in-place op issue, run the
            # `test_compile_selective_checkpoint_inplace_op` unit test.)
            with fx_traceback.preserve_node_meta():
                return Interpreter(gmod).run(*args)


tag_activation_checkpoint = TagActivationCheckpoint()
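

# Illustrative sketch (assumed example; `f` is made up): with no
# `_checkpoint_context_fn` entry in the graph module's meta, the op takes the
# plain activation checkpointing path, tagging every call node with
# "ac_graph_id" and a PREFER_RECOMPUTE policy before interpreting the graph.
def _example_tag_activation_checkpoint_usage():
    import torch
    from torch.fx import symbolic_trace

    def f(x):
        return torch.sin(x).cos()

    gm = symbolic_trace(f)
    x = torch.randn(3, requires_grad=True)
    out = tag_activation_checkpoint(gm, x)
    # Call nodes now carry the recompute tag; placeholder/output nodes do not.
    tags = [node.meta.get("recompute") for node in gm.graph.nodes]
    return out, tags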