# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator, OperatorBase, OpOverload
from torch._prims_common import clone_preserve_strides
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


# NOTE: [auto-functionalizing custom ops]
# Users may wish to torch.compile custom ops that mutate their inputs.
# torch.compile will automatically support such an op without anyone needing
# to provide a functionalization kernel for it. Here's how.
#
# Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> ()
# op. First, when FakeTensor sees this op:
# - If the schema says it returns nothing, we can generate a trivial
#   FakeTensor rule for it (that returns nothing).
# - Otherwise, the user needs to provide a FakeTensor impl (fake impl).
#
# Next, when Python FunctionalTensor sees the op, it will functionalize
# it by emitting a call to the auto_functionalized(op, **kwargs) HOP
# and replacing the mutated inputs with the corresponding outputs of this HOP.
# This HOP effectively runs the functional version of the op when
# called: it clones inputs that will be mutated, runs the op, and
# then returns (output, Tensors with the new values).
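#
# A minimal sketch of the user-facing flow (illustrative only: mylib::sin_ is a
# hypothetical op, registered here via the torch.library.Library APIs):
#
#     lib = torch.library.Library("mylib", "FRAGMENT")
#     lib.define("sin_(Tensor(a!) x) -> ()")
#
#     def sin_(x):
#         x.sin_()
#
#     lib.impl("sin_", sin_, "CompositeExplicitAutograd")
#
#     @torch.compile(fullgraph=True)
#     def f(x):
#         torch.ops.mylib.sin_(x)
#         return x + 1
#
# Because sin_ returns nothing, no fake impl is needed. Tracing f produces a
# graph that calls auto_functionalized(torch.ops.mylib.sin_.default, x=x)
# instead of the in-place op, and the new value of x is read back from the
# HOP's outputs.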


class AutoFunctionalized(HigherOrderOperator):
    """auto_functionalized(_mutable_op, **kwargs)

    This HOP runs a "functional" version of _mutable_op.

    Concretely, it looks at all the arguments that are mutable through
    _mutable_op's operator schema, clones those kwargs, runs
    `out = _mutable_op(**kwargs)` with the cloned values, and then returns the
    operator output concatenated with the cloned values that were mutated.

    We have some restrictions on `_mutable_op`.
    See `can_auto_functionalize` for the restrictions. We can likely lift
    many of these if users request it.

    The reason why _mutable_op is prefixed with an
    underscore is to prevent collisions with kwarg names in **kwargs.
    """

    def __init__(self) -> None:
        super().__init__("auto_functionalized")

    def __call__(
        self,
        /,
        _mutable_op: OpOverload,
        **kwargs: Any,
    ) -> Tuple[Any, Tuple[Tensor, ...]]:
        assert can_auto_functionalize(_mutable_op)
        assert isinstance(kwargs, dict)
        return super().__call__(_mutable_op, **kwargs)


auto_functionalized = AutoFunctionalized()
auto_functionalized.__module__ = "torch.ops.higher_order"
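
# Illustrative calling convention (mylib::sin_ is a hypothetical op with schema
# "sin_(Tensor(a!) x) -> ()"):
#
#     out = auto_functionalized(torch.ops.mylib.sin_.default, x=x)
#
# Since sin_ has no returns, `out` is (None, new_x): the op output followed by
# the cloned-and-mutated value of `x`. Callers such as do_auto_functionalize
# then write new_x back over the original `x`.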


def can_auto_functionalize(op: OperatorBase) -> bool:
    if not isinstance(op, OpOverload):
        return False

    if torch._library.utils.is_builtin(op):
        # We control the built-ins. These may (in rare cases)
        # do input metadata mutation (which we have banned on custom ops)
        return False
    schema = op._schema
    if not schema.is_mutable:
        return False

    for arg in schema.arguments:
        if arg.alias_info is None:
            continue
        if not arg.alias_info.is_write:
            continue
        if type(arg.type) is torch.TensorType:
            continue
        if (
            type(arg.type) is torch.OptionalType
            and type(arg.type.getElementType()) is torch.TensorType
        ):
            continue
        if (
            type(arg.type) is torch.ListType
            and type(arg.type.getElementType()) is torch.TensorType
        ):
            continue
        # Not yet supported: other Tensor types. This includes things like
        # Tensor?[], Tensor[]?.
        return False

    if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType):
        # Skip schema returns -> None
        return True
    # The returns must not alias anything
    for ret in schema.returns:
        if ret.alias_info is None and type(ret.type) is torch.TensorType:
            continue
        # Not yet supported: List[Tensor] return.
        return False
    if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), "Functionalize"):
        return False
    return True
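
# Illustrative schemas (hypothetical "mylib" ops; aten::add_ is real):
#
#   "sin_(Tensor(a!) x) -> ()"                  -> True
#   "foo(Tensor(a!) x, Tensor y) -> Tensor"     -> True  (non-aliasing Tensor return)
#   "bar(Tensor(a!)[]? x) -> ()"                -> False (Tensor[]? arg not yet supported)
#   "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"
#                                               -> False (builtin, aliasing return)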


@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_dense(
    _mutable_op: OpOverload,
    _only_clone_these_tensors: Optional[Tuple[str, ...]] = None,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
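    # Output layout (illustrative; hypothetical op): for a schema like
    # "foo(Tensor(a!) x, Tensor(b!) y) -> (Tensor, Tensor)", this returns
    # (out0, out1, new_x, new_y), where new_x/new_y are the cloned copies of
    # x and y after the op has mutated them.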
    new_kwargs = dict(**kwargs)
    result = []

    _mutable_args_names = get_mutable_arg_names(_mutable_op)
    for name in _mutable_args_names:
        if (
            _only_clone_these_tensors is not None
            and name not in _only_clone_these_tensors
        ):
            new_kwargs[name] = kwargs[name]
        else:
            new_kwargs[name] = (
                [clone_preserve_strides(x) for x in kwargs[name]]
                if kwargs[name] is not None and isinstance(kwargs[name], list)
                else clone_preserve_strides(kwargs[name])
                if kwargs[name] is not None
                else None
            )
        result.append(new_kwargs[name])
    out = _mutable_op(**new_kwargs)

    if isinstance(out, tuple):
        return (*out, *result)  # type: ignore[return-value]
    else:
        return (out, *result)  # type: ignore[return-value]


@auto_functionalized.py_impl(FakeTensorMode)
def auto_functionalized_fake(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with mode:
        result = auto_functionalized_dense(_mutable_op, **kwargs)
        return result


@auto_functionalized.py_impl(ProxyTorchDispatchMode)
def auto_functionalized_proxy(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with disable_proxy_modes_tracing():
        out = auto_functionalized(_mutable_op, **kwargs)

    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
    out_proxy = mode.tracer.create_proxy(
        "call_function",
        auto_functionalized,
        (_mutable_op,),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


auto_functionalized.fallthrough(DispatchKey.AutogradCPU)
auto_functionalized.fallthrough(DispatchKey.AutogradCUDA)


def get_mutable_arg_names(op: OpOverload) -> List[str]:
    """
    Returns the list of argument names that get mutated according to the
    schema.
    """
    mutable_args_names = [
        arg.name
        for arg in op._schema.arguments
        if arg.alias_info is not None and arg.alias_info.is_write
    ]
    return mutable_args_names


def do_auto_functionalize(
    op: OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """Functionalizes a call to op(*args, **kwargs) by emitting a call to
    `outs = auto_functionalized(op, **normalized_kwargs)`
    and replacing the mutated (args, kwargs) with the corresponding outputs.

    The normalized_kwargs are just the (args, kwargs), but all in kwarg form.
    This makes handling easier for the auto_functionalized HOP.
    """
    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    # All of the (args, kwargs), but all as kwargs. The names for the
    # args come from the schema. This makes it easier for us to work with them.
    normalized_kwargs = {}
    schema = op._schema
    for idx, arg in enumerate(schema.arguments):
        # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema
        if arg.name in kwargs:
            normalized_kwargs[arg.name] = kwargs[arg.name]
        elif idx < len(args):
            # If idx is out of bounds, we don't need to do anything: it means
            # the optional arg was passed with its default value.
            normalized_kwargs[arg.name] = args[idx]
        else:
            normalized_kwargs[arg.name] = arg.default_value

    unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs)  # type: ignore[arg-type]
    if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs:
        warnings.warn(
            "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. "
            "Please consider using a different name for this argument to avoid potential issues."
        )
    with ctx.redispatch_to_next():
        unwrapped_outs = auto_functionalized(
            op, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    # List of the names of args that get mutated (according to the schema)
    mutable_args_names = get_mutable_arg_names(op)

    unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[
        : -len(mutable_args_names)
    ]
    unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :]

    if len(op._schema.returns) == 0:
        assert unwrapped_actual_out[0] is None
        unwrapped_actual_out = None
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_actual_out) == 1
        unwrapped_actual_out = unwrapped_actual_out[0]
    else:
        assert len(unwrapped_actual_out) == len(op._schema.returns)

    for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out):
        # Can be None if input was `Tensor(a!)?`
        if unwrapped_out is None:
            continue

        # We only handle Tensor or List[Tensor] here for now.
        def sync_update(o, orig_arg):
            ctx.replace(orig_arg, o)
            ctx.commit_update(orig_arg)
            ctx.sync(orig_arg)

        orig_arg = normalized_kwargs[name]

        if isinstance(unwrapped_out, torch.Tensor):
            sync_update(unwrapped_out, orig_arg)
        elif isinstance(unwrapped_out, list) and all(
            isinstance(o, torch.Tensor) for o in unwrapped_out
        ):
            assert len(orig_arg) == len(unwrapped_out)
            for orig_a, o in zip(orig_arg, unwrapped_out):
                sync_update(o, orig_a)
        else:
            raise RuntimeError(
                f"unsupported type for auto-functionalization: {unwrapped_out}"
            )

    return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]


@auto_functionalized.py_functionalize_impl
def auto_functionalized_func(ctx, _mutable_op, **kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        result = auto_functionalized(_mutable_op, **unwrapped_kwargs)
    return ctx.wrap_tensors(result)