pytorch

Форк
0
240 строк · 9.5 Кб
1
# mypy: allow-untyped-defs
2
import inspect
3
import itertools
4
import logging
5
from typing import Optional
6

7
from torch._logging import warning_once
8
from torch._ops import HigherOrderOperator
9
from torch.types import _dtype
10
from torch.utils.checkpoint import checkpoint, CheckpointPolicy
11

12

13
# Module-level logger; TagActivationCheckpoint.__call__ emits a one-time warning through it.
log = logging.getLogger(__name__)

# Counter starting at 1. TagActivationCheckpoint.tag_nodes draws one value per
# tagged graph and stores it in every tagged node's meta["ac_graph_id"], so
# nodes of the same checkpointed graph share an id distinct from other graphs.
uid = itertools.count(1)


# Used for testing the HigherOrderOperator mechanism
19
class Wrap(HigherOrderOperator):
20
    def __init__(self) -> None:
21
        super().__init__("wrap")
22

23
    def __call__(self, func, *args, **kwargs):
24
        # Dynamo already traces the body of HigherOrderOp beforehand when it
25
        # so no need to trace into it.
26
        import torch._dynamo  # noqa: F401
27
        from torch._dynamo import disable
28

29
        @disable
30
        def wrapper():
31
            result = func(*args, **kwargs)
32
            return result
33

34
        return wrapper()
35

36

37
wrap = Wrap()
38

39

40
class WrapWithSetGradEnabled(HigherOrderOperator):
41
    def __init__(self) -> None:
42
        super().__init__("wrap_with_set_grad_enabled")
43

44
    def __call__(self, enable_grad, wrapped_func, *args, **kwargs):
45
        # Dynamo already traces the body of HigherOrderOp beforehand when it
46
        # so no need to trace into it.
47
        import torch._dynamo  # noqa: F401
48
        from torch._dynamo import disable
49

50
        @disable
51
        def wrapper():
52
            with torch.set_grad_enabled(enable_grad):
53
                return wrapped_func(*args, **kwargs)
54

55
        return wrapper()
56

57

58
wrap_with_set_grad_enabled = WrapWithSetGradEnabled()
59

60

61
class WrapWithAutocast(HigherOrderOperator):
62
    def __init__(self):
63
        super().__init__("wrap_with_autocast")
64

65
    def __call__(
66
        self,
67
        device_type: str,
68
        dtype: Optional[_dtype],
69
        enabled: bool,
70
        cache_enabled: Optional[bool],
71
        wrapped_func,
72
        *args,
73
        **kwargs,
74
    ):
75
        # Dynamo already traces the body of HigherOrderOp beforehand when it
76
        # so no need to trace into it.
77
        import torch._dynamo  # noqa: F401
78
        from torch._dynamo import disable
79

80
        @disable
81
        def wrapper():
82
            with torch.autocast(device_type, dtype, enabled, cache_enabled):
83
                return wrapped_func(*args, **kwargs)
84

85
        return wrapper()
86

87

88
wrap_with_autocast = WrapWithAutocast()
89

90

91
class WrapActivationCheckpoint(HigherOrderOperator):
92
    """
93
    This operator is used to wrap torch.utils.checkpoint. This avoids
94
    TorchDynamo to look into saved tensor hooks and directly passes the control
95
    to AOT Autograd, which is ok with tracing saved tensor hooks. As a result of
96
    AOT tracing torch.utils.checkpoint code, we have a backward graph with
97
    recomputed forward nodes.
98

99
    However, we might deprecate this operator soon. The difficulty arises in the
100
    functionalization of rng ops. Today, there are two different
101
    functionalization of rng ops - one at AOT autograd and other at Inductor.
102
    And they are difficult to map to each other. The rng states also complicate
103
    pattern matching in Inductor. Due to the ease of implementation, we are
104
    currently inclined towards functionalization at Inductor level, which means
105
    that duplication/recomputation is done as a compiler pass in the
106
    partitioners. See TagActivationCheckpoint for more information.
107
    """
108

109
    def __init__(self) -> None:
110
        super().__init__("wrap_activation_checkpoint")
111

112
    def __call__(self, function, *args, **kwargs):
113
        # use_reentrant is set to False because this op is going to be traced.
114
        # And we ensure that AOT Autograd traces through the non reentrant
115
        # version of checkpointing.
116
        import torch.fx.traceback as fx_traceback
117
        from torch.fx import Interpreter
118

119
        kwargs["use_reentrant"] = False
120
        kwargs["preserve_rng_state"] = False
121
        # Using interpreter allows preservation of metadata through torch.compile stack.
122
        with fx_traceback.preserve_node_meta():
123
            return checkpoint(Interpreter(function).run, *args, **kwargs)
124

125

126
wrap_activation_checkpoint = WrapActivationCheckpoint()
127

128

129
class TagActivationCheckpoint(HigherOrderOperator):
130
    """
131
    This operator is supposed to be used only with torch.compile stack. This
132
    accepts a Fx graph module which needs to be checkpointed. This operator adds
133
    "recomputable" tag to the nodes of the Fx graph that should be recomputed.
134

135
    The goal is to:
136
    1. Avoid using Dynamo to trace through saved tensor hooks.
137
    2. For selective checkpointing case, let AOTAutograd trace through
138
       saved tensor hooks but has special logic with TorchDispatchMode to override
139
       the usual saved_tensor_hooks fn logic in order to tag the nodes.
140
    3. Rely on the partitioners to actually duplicate the nodes.
141
    This sits well in the torch.compile stack, because by the time graph
142
    reaches partitioner, inductor has already run its functionalization of rng
143
    ops (by setting fixed seed for each random op, see `replace_random_passes`).
144
    Therefore, the duplication of nodes, by design, respects the rng states in
145
    the forward and recomputed forward in backward.
146
    """
147

148
    def __init__(self) -> None:
149
        super().__init__("tag_activation_checkpoint")
150

151
    @staticmethod
152
    def divide_kwargs(kwargs):
153
        """
154
        checkpoint fn can have mixed kwargs between checkpointed fn and
155
        checkpoint fn itself. For example
156
        >> def gn(x, y, z=None):
157
        >>     a = torch.matmul(x, y)
158
        >>     if z is not None:
159
        >>         return torch.matmul(a, z)
160
        >>     return a
161
        >> def fn(x, y, z):
162
        >>     return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
163
        In the above case, z belongs to checkpointed function gn, but
164
        use_reentrant belongs to the checkpoint function. This function splits
165
        the kwargs into checkpoint_kwargs and gmod_kwargs (or
166
        checkpointed_fn_kwargs).
167
        We do sorting to ensure same graph from run to run for better
168
        debuggability. It is not required for correctness.
169
        """
170
        ckpt_signature = inspect.signature(checkpoint)
171
        checkpoint_keys = set()
172
        for name in ckpt_signature.parameters:
173
            if name in ("function", "args", "kwargs"):
174
                continue
175
            checkpoint_keys.add(name)
176

177
        # `preserve_rng_state` is not a regular kwarg
178
        checkpoint_keys.add("preserve_rng_state")
179

180
        checkpoint_kwargs = {
181
            name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys
182
        }
183
        gmod_kwargs = {
184
            name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys
185
        }
186
        return checkpoint_kwargs, gmod_kwargs
187

188
    def tag_nodes(self, gmod, is_sac):
189
        unique_graph_id = next(uid)
190
        for node in gmod.graph.nodes:
191
            if node.op in ("call_function", "call_method", "call_module"):
192
                node.meta["ac_graph_id"] = unique_graph_id
193
                if is_sac:
194
                    # For selective checkpointing, we will populate this tag later in _CachingTorchDispatchMode.
195
                    node.meta["recompute"] = None
196
                else:
197
                    # Under vanilla activation checkpointing, all nodes should be recomputed.
198
                    node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
199
        return gmod
200

201
    def __call__(self, gmod, *args, **kwargs):
202
        import torch.fx.traceback as fx_traceback
203
        from torch.fx import Interpreter
204

205
        if "_checkpoint_context_fn" in gmod.meta:
206
            warning_once(
207
                log,
208
                """
209
Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
210
Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
211
""",
212
            )
213
            # use_reentrant is set to False because this op is going to be traced.
214
            # And we ensure that AOT Autograd traces through the non reentrant
215
            # version of checkpointing.
216
            kwargs["use_reentrant"] = False
217
            # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through
218
            # `torch.random.fork_rng` op (which is not supported yet under CUDA).
219
            # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state
220
            # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor
221
            # instead of in AOTAutograd).
222
            kwargs["preserve_rng_state"] = False
223
            kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"]
224
            # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag
225
            # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py.
226
            gmod = self.tag_nodes(gmod, is_sac=True)
227
            # Using interpreter allows preservation of metadata through torch.compile stack.
228
            with fx_traceback.preserve_node_meta():
229
                return checkpoint(Interpreter(gmod).run, *args, **kwargs)
230
        else:
231
            gmod = self.tag_nodes(gmod, is_sac=False)
232
            # Using interpreter allows preservation of metadata through torch.compile stack.
233
            # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here
234
            # as the `context_fn != None` case, but that depends on in-place op support in TorchDispatchMode + torch.compile.
235
            # (for details on in-place op issue, run `test_compile_selective_checkpoint_inplace_op` unit test)
236
            with fx_traceback.preserve_node_meta():
237
                return Interpreter(gmod).run(*args)
238

239

240
tag_activation_checkpoint = TagActivationCheckpoint()
241

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.