pytorch

Форк
0
289 строк · 9.4 Кб
1
# mypy: allow-untyped-decorators
2
# mypy: allow-untyped-defs
3
from enum import Enum
4
from typing import Any, Dict, Optional, Tuple, Union
5
from weakref import WeakKeyDictionary
6

7
import torch
8
import torch.utils._pytree as pytree
9
from torch._C import DispatchKey
10
from torch._higher_order_ops.torchbind import call_torchbind
11
from torch._ops import HigherOrderOperator
12
from torch._subclasses.fake_tensor import FakeTensorMode
13
from torch.fx.experimental.proxy_tensor import (
14
    disable_proxy_modes_tracing,
15
    ProxyTorchDispatchMode,
16
    track_tensor_tree,
17
)
18

19

20
class _EffectType(Enum):
    """Kinds of side effect an operator can be registered with.

    One effect token is kept per member, so operators that share a member are
    chained together through the same token.
    """

    ORDERED = "Ordered"
22

23

24
# Operators handled by this module: either a higher-order operator or a single
# overload of a regular operator.
OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]


# Registry mapping each effectful operator to its effect type. The mapping
# holds its keys weakly, so an entry does not keep its operator alive.
SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary(
    {
        torch.ops.aten._print.default: _EffectType.ORDERED,
        call_torchbind: _EffectType.ORDERED,
    }
)
33

34

35
def _register_effectful_op(op: OpType, effect: _EffectType):
    """Record ``op`` in ``SIDE_EFFECTS`` with the given effect type.

    Args:
        op: Operator to mark as effectful. It must be an ``OpOverload`` or a
            ``HigherOrderOperator`` for which ``has_aliasing`` returns False.
        effect: Effect type to associate with ``op``.

    Raises:
        RuntimeError: If ``op`` is already registered with a different effect
            type. Registering again with the same effect type is allowed.
    """
    assert isinstance(
        op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
    ) and not has_aliasing(op)
    if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
        raise RuntimeError(
            f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
            f"trying to register a different effect type {effect}."
        )
    SIDE_EFFECTS[op] = effect
45

46

47
def _deregister_effectful_op(op: OpType):
    """Remove ``op`` from ``SIDE_EFFECTS``.

    Raises:
        RuntimeError: If ``op`` is not currently registered.
    """
    if op not in SIDE_EFFECTS:
        raise RuntimeError(f"Op {op} is not registered as effectful")

    del SIDE_EFFECTS[op]
52

53

54
class WithEffects(HigherOrderOperator):
    """
    with_effects(token, op, args, kwargs) -> (new_token, op_results)

    This HOP helps ensure ordering between side effectful ops like prints or ops
    using torchbind objects. This is needed to ensure a traced graph from
    AOTAutograd is functional so that future optimization passes do not reorder
    these operators. This is done through threading "effect tokens" through the
    graph to enforce data dependence between side effectful ops.

    The tokens are basically dummy values (torch.tensor([])). We create a token
    per "effect type", which are enumerated in the _EffectType enum.
    """

    def __init__(self) -> None:
        super().__init__("with_effects")

    def __call__(
        self,
        token,
        op: OpType,
        *args: Tuple[Any, ...],
        **kwargs: Dict[str, Any],
    ) -> Tuple[Any, ...]:
        """Check that ``op`` is a supported effectful op, then dispatch.

        Args:
            token: Effect token passed through to the dispatched implementation.
            op: The wrapped operator. It must be an ``OpOverload`` or a
                ``HigherOrderOperator``, must not alias, and must have effects.
            *args: Positional arguments for ``op``.
            **kwargs: Keyword arguments for ``op``.

        Returns:
            Whatever the dispatched implementation returns.
        """
        assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
        assert not has_aliasing(op), "Ops with aliasing is not supported"
        assert has_effects(op, args, kwargs)
        assert isinstance(kwargs, dict)
        return super().__call__(token, op, *args, **kwargs)
83

84

85
# Module-level instance of the WithEffects operator. The per-mode
# implementations in this module are registered on this object.
with_effects = WithEffects()
86

87

88
def has_aliasing(op: OpType):
    """Report whether ``op`` is treated as aliasing.

    A ``HigherOrderOperator`` is treated as aliasing unless it is present in
    ``SIDE_EFFECTS``. Any other op is aliasing when at least one argument or
    return in its schema carries alias info.
    """
    # NOT FOR PUBLIC USE
    if isinstance(op, torch._ops.HigherOrderOperator):
        return op not in SIDE_EFFECTS

    schema = op._schema
    # Arguments are checked before returns; `or` stops at the first hit.
    return any(arg.alias_info is not None for arg in schema.arguments) or any(
        ret.alias_info is not None for ret in schema.returns
    )
100

101

102
def has_effects(op, args, kwargs) -> bool:
103
    # Skip over the profiler's RecordFunction as they should not show up in the graph
104
    _skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction}
105
    if op in _skip_ops:
106
        return False
107

108
    return (
109
        isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
110
        and not has_aliasing(op)
111
        and get_effect_key(op, args, kwargs) is not None
112
    )
113

114

115
def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
    """Return the effect type of ``op``, or None if it has none.

    An op already present in ``SIDE_EFFECTS`` returns its registered type.
    Otherwise the op is considered ORDERED when any positional argument is a
    ``torch.ScriptObject``; ``kwargs`` is not inspected. In that case the op is
    also added to ``SIDE_EFFECTS`` as a side effect of this call.
    """
    if op in SIDE_EFFECTS:
        return SIDE_EFFECTS[op]

    if any(isinstance(arg, torch.ScriptObject) for arg in args):
        # Add it to the table so that next time we see the same op we don't
        # have to parse through the args again
        SIDE_EFFECTS[op] = _EffectType.ORDERED
        return _EffectType.ORDERED

    return None
127

128

129
def new_token_tensor() -> torch.Tensor:
    """Create a fresh effect token: an empty tensor of dtype bool."""
    # Use dtype bool to not affect Inductor dtype promotions
    return torch.tensor([], dtype=torch.bool)
132

133

134
@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd)
def with_effects_dense(
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """Run ``op`` and return a new token followed by its outputs.

    The incoming ``token`` is not read. A tuple result from ``op`` is
    flattened into the returned tuple; any other result (including None)
    becomes the single element after the token.
    """
    out = op(*args, **kwargs)
    new_token = new_token_tensor()
    if isinstance(out, tuple):
        return (new_token, *out)
    return (new_token, out)
146

147

148
@with_effects.py_impl(FakeTensorMode)
def with_effects_fake(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """Run the dense implementation with ``mode`` entered as a context manager."""
    with mode:
        return with_effects_dense(token, op, *args, **kwargs)
159

160

161
@with_effects.py_impl(ProxyTorchDispatchMode)
def with_effects_proxy(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """Trace a ``with_effects`` call into the graph owned by ``mode.tracer``.

    The op is first run with proxy tracing disabled to obtain real outputs.
    A ``call_function`` node targeting ``with_effects`` is then created and
    the outputs are associated with it via ``track_tensor_tree``.
    """
    with disable_proxy_modes_tracing():
        out = with_effects(token, op, *args, **kwargs)

    proxy_token = mode.tracer.unwrap_proxy(token)
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

    from torch.fx.node import has_side_effect

    # Mark the op as side-effectful so that graph.eliminate_dead_code does not
    # remove it when it has no outputs or its outputs are unused.
    has_side_effect(op)

    out_proxy = mode.tracer.create_proxy(
        "call_function",
        with_effects,
        (proxy_token, op, *proxy_args),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result
190

191

192
# Register fallthroughs for the autograd dispatch keys on with_effects.
with_effects.fallthrough(DispatchKey.AutogradCPU)
with_effects.fallthrough(DispatchKey.AutogradCUDA)
194

195

196
def _get_schema(op, args) -> torch.FunctionSchema:
197
    if isinstance(op, torch._ops.OpOverload):
198
        return op._schema
199
    elif op == call_torchbind:
200
        return getattr(args[0], args[1]).schema
201
    else:
202
        raise RuntimeError(f"Unable to get schema for op {op}")
203

204

205
def handle_effects(
    allow_token_discovery: bool,
    tokens: Dict[_EffectType, torch.Tensor],
    op: OpType,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Call ``op`` through ``with_effects``, threading the token for its effect
    type, and store the resulting new token back into ``tokens``.

    Args:
        allow_token_discovery: Whether or not we are discovering tokens. If this
        is true, we will create a token for every side effect type seen that
        does not have a token assigned yet.  If this is false, the tokens
        should've all been created ahead of time, so we will error if there is
        no token mapping to every effect type.

        tokens: Map of effect type to tokens. This is to chain operators of the
        same effects together so that they do not get reordered in later
        optimization passes. This dict is updated in place.

        op: The effectful operator to call.

        args: Positional arguments for ``op``.

        kwargs: Keyword arguments for ``op``.

    Returns:
        The wrapped outputs of ``op``: None when its schema has no returns, a
        single value when it has one return, and a sequence otherwise.
    """

    # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
    # this will create an empty tensor during proxy mode tracing if the token
    # doesn't exist. But the tokens should always exist during proxy mode tracing.
    key = get_effect_key(op, args, kwargs)
    assert key is not None
    if key not in tokens:
        assert (
            allow_token_discovery
        ), f"Could not find a token for effect {key} which came from the function {op}"
        proxy_tensor_mode = torch._C._get_dispatch_mode(
            torch._C._TorchDispatchModeKey.PROXY
        )
        if proxy_tensor_mode is not None:
            # If we discovered a new token during tracing, we are in backward.
            # Then we patch the graph, adding additional tangents_token as input to the joint graph.
            tracer = proxy_tensor_mode.tracer

            # Create the token outside of proxy tracing so that its creation
            # is not itself recorded.
            with disable_proxy_modes_tracing():
                token_tensor = new_token_tensor()

            token_proxy = tracer.create_proxy(
                "placeholder", "tangents_token", (), {}, name="tangents_token"
            )
            track_tensor_tree(token_tensor, token_proxy, constant=None, tracer=tracer)

            tokens[key] = token_tensor
        else:
            tokens[key] = new_token_tensor()

    token = tokens[key]

    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    unwrapped_token = ctx.unwrap_tensors([token])[0]  # type: ignore[arg-type]
    unwrapped_args = ctx.unwrap_tensors(args)  # type: ignore[arg-type]
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        (new_token, *unwrapped_outs) = with_effects(
            unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    # Reshape the outputs to match the number of returns declared by the schema.
    schema = _get_schema(op, unwrapped_args)
    if len(schema.returns) == 0:
        assert unwrapped_outs[0] is None
        unwrapped_outs = None  # type: ignore[assignment]
    elif len(schema.returns) == 1:
        assert len(unwrapped_outs) == 1
        unwrapped_outs = unwrapped_outs[0]
    else:
        assert len(unwrapped_outs) == len(schema.returns)

    # Add the newly created token into the tokens map for a following call to
    # use this token.
    wrapped_token = ctx.wrap_tensors(new_token)
    assert isinstance(wrapped_token, torch.Tensor)
    tokens[key] = wrapped_token

    return ctx.wrap_tensors(unwrapped_outs)  # type: ignore[arg-type]
290

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.