# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union
from weakref import WeakKeyDictionary

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


class _EffectType(Enum):
    ORDERED = "Ordered"


OpType = Union[torch._ops.HigherOrderOperator, torch._ops.OpOverload]


SIDE_EFFECTS: "WeakKeyDictionary[OpType, _EffectType]" = WeakKeyDictionary(
    {
        torch.ops.aten._print.default: _EffectType.ORDERED,
        call_torchbind: _EffectType.ORDERED,
    }
)


def _register_effectful_op(op: OpType, effect: _EffectType):
    assert isinstance(
        op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
    ) and not has_aliasing(op)
    if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
        raise RuntimeError(
            f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
            f"trying to register a different effect type {effect}."
        )
    SIDE_EFFECTS[op] = effect
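

# A minimal usage sketch (the op name below is hypothetical, not part of this
# module): registering a custom, non-aliasing op as ORDERED makes AOTAutograd
# thread an effect token through every call to it via the with_effects HOP
# defined below.
#
#     _register_effectful_op(torch.ops.mylib.record_event.default, _EffectType.ORDERED)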


def _deregister_effectful_op(op: OpType):
    if op not in SIDE_EFFECTS:
        raise RuntimeError(f"Op {op} is not registered as effectful")

    del SIDE_EFFECTS[op]


class WithEffects(HigherOrderOperator):
    """
    with_effects(token, op, args, kwargs) -> (new_token, op_results)

    This HOP helps ensure ordering between side-effectful ops like prints or ops
    using torchbind objects. This is needed to ensure that a graph traced by
    AOTAutograd is functional, so that later optimization passes do not reorder
    these operators. This is done by threading "effect tokens" through the
    graph to enforce a data dependence between side-effectful ops.

    The tokens are basically dummy values (torch.tensor([])). We create one
    token per "effect type"; the effect types are enumerated in the _EffectType
    enum.
    """

    def __init__(self) -> None:
        super().__init__("with_effects")

    def __call__(
        self,
        token,
        op: OpType,
        *args: Tuple[Any, ...],
        **kwargs: Dict[str, Any],
    ) -> Tuple[Any, ...]:
        assert isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
        assert not has_aliasing(op), "Ops with aliasing are not supported"
        assert has_effects(op, args, kwargs)
        assert isinstance(kwargs, dict)
        return super().__call__(token, op, *args, **kwargs)


with_effects = WithEffects()
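
# Illustrative sketch of the calling convention (not a separate API, just how
# the HOP is invoked): each call consumes a token and returns a fresh one ahead
# of the op's results. For example, torch.ops.aten._print.default returns
# nothing, so:
#
#     token = new_token_tensor()
#     new_token, _ = with_effects(token, torch.ops.aten._print.default, "hello")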


def has_aliasing(op: OpType):
    # NOT FOR PUBLIC USE
    if isinstance(op, torch._ops.HigherOrderOperator):
        return op not in SIDE_EFFECTS

    for arg in op._schema.arguments:
        if arg.alias_info is not None:
            return True
    for arg in op._schema.returns:
        if arg.alias_info is not None:
            return True
    return False
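

# For example, an in-place overload like torch.ops.aten.add_.Tensor carries
# alias_info on its `self` argument (and on its return), so has_aliasing()
# returns True for it and it cannot be registered as an effectful op.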


def has_effects(op, args, kwargs) -> bool:
    # Skip over the profiler's RecordFunction as they should not show up in the graph
    _skip_ops = {torch.ops.profiler._record_function_exit._RecordFunction}
    if op in _skip_ops:
        return False

    return (
        isinstance(op, (torch._ops.HigherOrderOperator, torch._ops.OpOverload))
        and not has_aliasing(op)
        and get_effect_key(op, args, kwargs) is not None
    )


def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
    if op in SIDE_EFFECTS:
        return SIDE_EFFECTS[op]

    for arg in args:
        if isinstance(arg, torch.ScriptObject):
            # Add it to the table so that next time we see the same op we don't
            # have to parse through the args again
            SIDE_EFFECTS[op] = _EffectType.ORDERED
            return _EffectType.ORDERED

    return None
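

# For example, a custom op overload whose call includes a torch.ScriptObject
# argument (e.g. a TorchBind object) need not be in SIDE_EFFECTS up front; the
# loop above classifies it as _EffectType.ORDERED on first sight and caches it.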


def new_token_tensor() -> torch.Tensor:
    # Use dtype bool to not affect Inductor dtype promotions
    return torch.tensor([], dtype=torch.bool)


@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd)
def with_effects_dense(
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    out = op(*args, **kwargs)
    new_token = new_token_tensor()
    if isinstance(out, tuple):
        return (new_token, *out)
    return (new_token, out)
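

# Note: for an op returning a single value the dense impl above yields
# (new_token, result); for a multi-output op it yields (new_token, out0, out1, ...),
# matching the "(new_token, op_results)" convention in the WithEffects docstring.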


@with_effects.py_impl(FakeTensorMode)
def with_effects_fake(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    with mode:
        result = with_effects_dense(token, op, *args, **kwargs)
        return result


@with_effects.py_impl(ProxyTorchDispatchMode)
def with_effects_proxy(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    with disable_proxy_modes_tracing():
        out = with_effects(token, op, *args, **kwargs)

    proxy_token = mode.tracer.unwrap_proxy(token)
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

    from torch.fx.node import has_side_effect

    # Mark the op as side-effectful so that graph.eliminate_dead_code does not
    # DCE it when it has no outputs or its outputs are unused.
    has_side_effect(op)

    out_proxy = mode.tracer.create_proxy(
        "call_function",
        with_effects,
        (proxy_token, op, *proxy_args),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


with_effects.fallthrough(DispatchKey.AutogradCPU)
with_effects.fallthrough(DispatchKey.AutogradCUDA)


def _get_schema(op, args) -> torch.FunctionSchema:
    if isinstance(op, torch._ops.OpOverload):
        return op._schema
    elif op == call_torchbind:
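        # For call_torchbind, args[0] is the TorchBind ScriptObject and args[1]
        # is the method name, so the schema comes from the bound method.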
        return getattr(args[0], args[1]).schema
    else:
        raise RuntimeError(f"Unable to get schema for op {op}")


def handle_effects(
    allow_token_discovery: bool,
    tokens: Dict[_EffectType, torch.Tensor],
    op: OpType,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Args:
        allow_token_discovery: Whether or not we are discovering tokens. If this
        is true, we will create a token for every side effect type seen that
        does not have a token assigned yet. If this is false, the tokens
        should've all been created ahead of time, so we will error if an
        encountered effect type has no token mapped to it.

        tokens: Map of effect type to tokens. This is used to chain operators
        with the same effect type together so that they do not get reordered in
        later optimization passes.
    """
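
    # Overall flow: look up (or, when allow_token_discovery is set, create) the
    # token for this op's effect type, functionalize the inputs, call the
    # with_effects HOP, and store the returned token back into `tokens` so the
    # next op with the same effect type is ordered after this one.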

    # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
    # this will create an empty tensor during proxy mode tracing if the token
    # doesn't exist. But the tokens should always exist during proxy mode tracing.
    key = get_effect_key(op, args, kwargs)
    assert key is not None
    if key not in tokens:
        assert (
            allow_token_discovery
        ), f"Could not find a token for effect {key} which came from the function {op}"
        proxy_tensor_mode = torch._C._get_dispatch_mode(
            torch._C._TorchDispatchModeKey.PROXY
        )
        if proxy_tensor_mode is not None:
            # If we discover a new token during tracing, we are in the backward
            # pass, so we patch the graph by adding an additional tangents_token
            # input to the joint graph.
            tracer = proxy_tensor_mode.tracer

            from torch.fx.experimental.proxy_tensor import (
                disable_proxy_modes_tracing,
                track_tensor_tree,
            )

            with disable_proxy_modes_tracing():
                token_tensor = new_token_tensor()

            token_proxy = proxy_tensor_mode.tracer.create_proxy(
                "placeholder", "tangents_token", (), {}, name="tangents_token"
            )
            track_tensor_tree(token_tensor, token_proxy, constant=None, tracer=tracer)

            tokens[key] = token_tensor
        else:
            tokens[key] = new_token_tensor()

    token = tokens[key]

    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    unwrapped_token = ctx.unwrap_tensors([token])[0]  # type: ignore[arg-type]
    unwrapped_args = ctx.unwrap_tensors(args)  # type: ignore[arg-type]
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        (new_token, *unwrapped_outs) = with_effects(
            unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    schema = _get_schema(op, unwrapped_args)
    if len(schema.returns) == 0:
        assert unwrapped_outs[0] is None
        unwrapped_outs = None  # type: ignore[assignment]
    elif len(schema.returns) == 1:
        assert len(unwrapped_outs) == 1
        unwrapped_outs = unwrapped_outs[0]
    else:
        assert len(unwrapped_outs) == len(schema.returns)

    # Add the newly created token into the tokens map for a following call to
    # use this token.
    wrapped_token = ctx.wrap_tensors(new_token)
    assert isinstance(wrapped_token, torch.Tensor)
    tokens[key] = wrapped_token

    return ctx.wrap_tensors(unwrapped_outs)  # type: ignore[arg-type]