# pytorch / auto_functionalize.py (298 lines, 10.4 KB)
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator, OperatorBase, OpOverload
from torch._prims_common import clone_preserve_strides
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


# NOTE: [auto-functionalizing custom ops]
# Users may wish to torch.compile custom ops that mutate their inputs.
# torch.compile will automatically support this op without anyone needing
# to provide a functionalization kernel for it. Here's how.
#
# Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> ()
# op. First, when FakeTensor sees this op:
# - If the schema says it returns nothing, we can generate a trivial
#   FakeTensor rule for it (that returns nothing).
# - Otherwise, the user needs to provide a FakeTensor impl (fake impl)
#
# Next, when Python FunctionalTensor sees the op, it will functionalize
# it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...})
# HOP and replacing the mutated inputs with corresponding outputs of this HOP.
# This HOP effectively runs the functional version of the op when
# called: it clones inputs that will be mutated, runs the op, and
# then returns (output, Tensors with the new values)
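#
# Illustrative sketch of the rewrite (the exact graph produced by
# functionalization may differ in detail). For the hypothetical
# mylib::sin_(Tensor(a!) x) -> () op, a traced call
#
#     torch.ops.mylib.sin_.default(x)
#
# becomes, roughly,
#
#     new_x = torch.ops.higher_order.auto_functionalized(
#         torch.ops.mylib.sin_.default, x=x
#     )[1]
#
# and all later uses of x are redirected to new_x, so the resulting graph
# contains no input mutation.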


class AutoFunctionalized(HigherOrderOperator):
    """auto_functionalized(_mutable_op, **kwargs)

    This HOP runs a "functional" version of _mutable_op.

    Concretely, it looks at all the arguments that are mutable through
    _mutable_op's operator schema, clones those kwargs, runs
    `out = _mutable_op(**kwargs)` with the cloned values, and then returns the
    operator output concatenated with the cloned values that were mutated.

    We have some restrictions on `_mutable_op`.
    See `can_auto_functionalize` for the restrictions. We can likely lift
    many of these if users request it.

    The reason why _mutable_op is prefixed with an
    underscore is to prevent collisions with kwarg names in **kwargs.
    """

    def __init__(self) -> None:
        super().__init__("auto_functionalized")

    def __call__(
        self,
        /,
        _mutable_op: OpOverload,
        **kwargs: Any,
    ) -> Tuple[Any, Tuple[Tensor, ...]]:
        assert can_auto_functionalize(_mutable_op)
        assert isinstance(kwargs, dict)
        return super().__call__(_mutable_op, **kwargs)


auto_functionalized = AutoFunctionalized()
auto_functionalized.__module__ = "torch.ops.higher_order"
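
# Illustrative sketch of the calling convention, again using the hypothetical
# mylib::sin_(Tensor(a!) x) -> () op from the NOTE above:
#
#     out = auto_functionalized(torch.ops.mylib.sin_.default, x=x)
#
# x is left unchanged; out is (None, new_x): the op's return value (None,
# since the op returns nothing) followed by one new tensor per mutated
# argument, where new_x holds the value an eager call would have written
# into x.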


def can_auto_functionalize(op: OperatorBase) -> bool:
    if not isinstance(op, OpOverload):
        return False

    if torch._library.utils.is_builtin(op):
        # We control the built-ins. These may (in rare cases)
        # do input metadata mutation (which we have banned on custom ops)
        return False
    schema = op._schema
    if not schema.is_mutable:
        return False

    for arg in schema.arguments:
        if arg.alias_info is None:
            continue
        if not arg.alias_info.is_write:
            continue
        if type(arg.type) is torch.TensorType:
            continue
        if (
            type(arg.type) is torch.OptionalType
            and type(arg.type.getElementType()) is torch.TensorType
        ):
            continue
        if (
            type(arg.type) is torch.ListType
            and type(arg.type.getElementType()) is torch.TensorType
        ):
            continue
        # Not yet supported: other Tensor types. This includes things like
        # Tensor?[], Tensor[]?.
        return False

    if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType):
        # Skip schema returns -> None
        return True
    # The returns must not alias anything
    for ret in schema.returns:
        if ret.alias_info is None and type(ret.type) is torch.TensorType:
            continue
        # Not yet supported: List[Tensor] return.
        return False
    if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), "Functionalize"):
        return False
    return True
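

# Illustrative examples (hypothetical op names; whether a schema qualifies is
# decided purely by the checks in can_auto_functionalize above):
#
#     mylib::sin_(Tensor(a!) x) -> ()                # supported: mutable Tensor, no returns
#     mylib::fill_(Tensor(a!)? x, Tensor v) -> ()    # supported: Tensor? mutable arg
#     mylib::scale_(Tensor(a!)[] xs, float s) -> ()  # supported: Tensor[] mutable arg
#     mylib::add_out(Tensor x, Tensor(a!) out) -> Tensor(a!)
#         # not supported: the return aliases an input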


@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_dense(
    _mutable_op: OpOverload,
    _only_clone_these_tensors: Optional[Tuple[str, ...]] = None,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    new_kwargs = dict(**kwargs)
    result = []

    _mutable_args_names = get_mutable_arg_names(_mutable_op)
    for name in _mutable_args_names:
        if (
            _only_clone_these_tensors is not None
            and name not in _only_clone_these_tensors
        ):
            new_kwargs[name] = kwargs[name]
        else:
            new_kwargs[name] = (
                [clone_preserve_strides(x) for x in kwargs[name]]
                if kwargs[name] is not None and isinstance(kwargs[name], list)
                else clone_preserve_strides(kwargs[name])
                if kwargs[name] is not None
                else None
            )
        result.append(new_kwargs[name])
    out = _mutable_op(**new_kwargs)

    if isinstance(out, tuple):
        return (*out, *result)  # type: ignore[return-value]
    else:
        return (out, *result)  # type: ignore[return-value]
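

# Illustrative sketch (hypothetical mylib::sin_(Tensor(a!) x) -> () op from the
# NOTE above): auto_functionalized_dense(torch.ops.mylib.sin_.default, x=t)
# clones t, runs sin_ on the clone, and returns (None, cloned_t); t itself is
# left untouched. Passing _only_clone_these_tensors=() skips the clone, so the
# op mutates t in place and t is returned as the "new" value.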


@auto_functionalized.py_impl(FakeTensorMode)
def auto_functionalized_fake(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with mode:
        result = auto_functionalized_dense(_mutable_op, **kwargs)
        return result


@auto_functionalized.py_impl(ProxyTorchDispatchMode)
def auto_functionalized_proxy(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with disable_proxy_modes_tracing():
        out = auto_functionalized(_mutable_op, **kwargs)

    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
    out_proxy = mode.tracer.create_proxy(
        "call_function",
        auto_functionalized,
        (_mutable_op,),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


auto_functionalized.fallthrough(DispatchKey.AutogradCPU)
auto_functionalized.fallthrough(DispatchKey.AutogradCUDA)


def get_mutable_arg_names(op: OpOverload) -> List[str]:
    """
    Returns the list of argument names that get mutated according to the
    schema.
    """
    mutable_args_names = [
        arg.name
        for arg in op._schema.arguments
        if arg.alias_info is not None and arg.alias_info.is_write
    ]
    return mutable_args_names
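

# Illustrative sketch (hypothetical schema): for an op declared as
# mylib::axpy_(Tensor(a!) y, Tensor x, float alpha=1.0) -> (), this returns
# ["y"], since y is the only argument with a write alias annotation (a!).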


def do_auto_functionalize(
    op: OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """Functionalizes a call to op(*args, **kwargs) by emitting a call to
    `outs = auto_functionalized(op, normalized_kwargs)`
    and replacing the mutated (args, kwargs) with the corresponding outputs.

    The normalized_kwargs are just the (args, kwargs), but all in kwarg form.
    This makes handling easier for the auto_functionalized HOP.
    """
    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    # All of the (args, kwargs), but all as kwargs. The names for the
    # args come from the schema. This makes it easier for us to work with them.
    normalized_kwargs = {}
    schema = op._schema
    for idx, arg in enumerate(schema.arguments):
        # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema
        if arg.name in kwargs:
            normalized_kwargs[arg.name] = kwargs[arg.name]
        elif idx < len(args):
            # if it's out of bounds we don't need to do anything
            # as it means the optional arg was passed with its default
            # value
            normalized_kwargs[arg.name] = args[idx]
        else:
            normalized_kwargs[arg.name] = arg.default_value

    unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs)  # type: ignore[arg-type]
    if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs:
        warnings.warn(
            "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. "
            "Please consider using a different name for this argument to avoid potential issues."
        )
    with ctx.redispatch_to_next():
        unwrapped_outs = auto_functionalized(
            op, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    # List of the names of args that get mutated (according to the schema)
    mutable_args_names = get_mutable_arg_names(op)

    unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[
        : -len(mutable_args_names)
    ]
    unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :]

    if len(op._schema.returns) == 0:
        assert unwrapped_actual_out[0] is None
        unwrapped_actual_out = None
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_actual_out) == 1
        unwrapped_actual_out = unwrapped_actual_out[0]
    else:
        assert len(unwrapped_actual_out) == len(op._schema.returns)

    for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out):
        # Can be None if input was `Tensor(a!)?`
        if unwrapped_out is None:
            continue

        # We only handle Tensor or List[Tensor] here for now.
        def sync_update(o, orig_arg):
            ctx.replace(orig_arg, o)
            ctx.commit_update(orig_arg)
            ctx.sync(orig_arg)

        orig_arg = normalized_kwargs[name]

        if isinstance(unwrapped_out, torch.Tensor):
            sync_update(unwrapped_out, orig_arg)
        elif isinstance(unwrapped_out, list) and all(
            isinstance(o, torch.Tensor) for o in unwrapped_out
        ):
            assert len(orig_arg) == len(unwrapped_out)
            for orig_a, o in zip(orig_arg, unwrapped_out):
                sync_update(o, orig_a)
        else:
            raise RuntimeError(
                f"unsupported type for auto-functionalization: {unwrapped_out}"
            )

    return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]


@auto_functionalized.py_functionalize_impl
def auto_functionalized_func(ctx, _mutable_op, **kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        result = auto_functionalized(_mutable_op, **unwrapped_kwargs)
    return ctx.wrap_tensors(result)
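

# Illustrative end-to-end sketch, kept as a comment so that importing this
# module stays side-effect free. It assumes the torch.library.custom_op API
# (available in recent PyTorch releases); "mylib::sin_" is a hypothetical op
# name, not a real library op.
#
#     import torch
#
#     @torch.library.custom_op("mylib::sin_", mutates_args={"x"})
#     def sin_(x: torch.Tensor) -> None:
#         x.copy_(x.sin())
#
#     @torch.compile(fullgraph=True)
#     def f(x):
#         sin_(x)
#         return x + 1
#
#     # Because sin_ mutates its input and returns nothing, functionalization
#     # rewrites the call into torch.ops.higher_order.auto_functionalized in
#     # the captured graph; no manual functionalization kernel is needed.
#     f(torch.randn(3))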