# mypy: allow-untyped-defs
import functools
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable

import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._ops import OperatorBase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.multiprocessing.reductions import StorageWeakRef


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


def autograd_not_implemented_inner(
    operator: OperatorBase, delayed_error: bool, *args: Any, **kwargs: Any
) -> Any:
    """If autograd is enabled and any of the arguments require grad, this will either
    raise an error or return a DelayedError, depending on the value of delayed_error.

    Args:
        operator: The Operator to call with *args and **kwargs
        delayed_error: If True, return a DelayedError instead of raising an error
        args: The flattened operands to the Operator
        kwargs: The keyword arguments to the Operator

    Raises:
        RuntimeError: If autograd is enabled, any of the arguments to the Operator
            require grad, and delayed_error is False
    """
    with torch._C._AutoDispatchBelowAutograd():
        result = operator(*args, **kwargs)
        flat_operands = pytree.arg_tree_leaves(*args)
        if torch.is_grad_enabled() and any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        ):
            if delayed_error:
                err_fn = torch._C._functions.DelayedError(
                    f"Autograd not implemented for {str(operator)}",
                    1,
                )

                def fake_requires_grad(tensor):
                    if torch.is_floating_point(tensor) or torch.is_complex(tensor):
                        tensor = tensor.detach()
                        tensor.requires_grad = True
                    return tensor

                return pytree.tree_map_only(
                    torch.Tensor, lambda x: err_fn(fake_requires_grad(x)), result
                )
            else:
                raise RuntimeError(f"Autograd not implemented for {str(operator)}")
        return result


def autograd_not_implemented(op: OperatorBase, deferred_error: bool) -> Callable:
    def inner(*args, **kwargs):
        return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)

    return inner


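# Illustrative usage sketch (hypothetical, not part of the original file): wrapping an
# existing OpOverload (any OperatorBase works) with autograd_not_implemented. With
# deferred_error=True the forward call succeeds, and the "Autograd not implemented"
# error only surfaces if backward reaches the DelayedError node.
def _example_autograd_not_implemented():
    add_no_grad = autograd_not_implemented(torch.ops.aten.add.Tensor, deferred_error=True)
    x = torch.randn(3, requires_grad=True)
    y = add_no_grad(x, x)  # runs below the Autograd key; result is wrapped by DelayedError
    # y.sum().backward()   # would raise RuntimeError("Autograd not implemented for ...")
    return y

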
def _maybe_run_with_interpreter(fn):
    maybe_interpreted_fn = fn
    if isinstance(fn, torch.fx.GraphModule) and fx_traceback.has_preserved_node_meta():
        # Running the graph with an interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with fx_traceback.preserve_node_meta():
                return torch.fx.Interpreter(fn).run(*args)

        maybe_interpreted_fn = graph_with_interpreter
    return maybe_interpreted_fn


def reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    @functools.wraps(fn)
    def wrapped(*args):
        assert (
            _CURRENT_MAKE_FX_TRACER is not None
        ), "Cannot reenter make_fx when we're not under a make_fx tracing session"
        return _CURRENT_MAKE_FX_TRACER.trace_subgraph(
            _maybe_run_with_interpreter(fn), *args
        )

    return wrapped


def _maybe_reenter_make_fx(fn):
    from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

    if _CURRENT_MAKE_FX_TRACER is not None:
        return reenter_make_fx(fn)
    else:
        return make_fx(fn)


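# Illustrative sketch (hypothetical helper): outside an active make_fx tracing session,
# _maybe_reenter_make_fx simply falls back to a fresh make_fx call, so a plain function
# can be turned into a GraphModule directly.
def _example_maybe_reenter_make_fx():
    def double(x):
        return x * 2

    gm = _maybe_reenter_make_fx(double)(torch.randn(3))  # no outer tracer, so plain make_fx
    return gm  # a torch.fx.GraphModule tracing the multiplication

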
@contextmanager
def _set_compilation_env():
    _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag
    try:
        # We need to turn off the is_fx_tracing_flag. Remove this flag check from dynamo
        # once we are confident fx tracing works with dynamo.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        yield
    finally:
        torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing


def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check whether the
    resulting graph contains a mutable op on the input. This is a
    bit restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_mutation(gm):
        input_nodes = set()
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                input_nodes.add(node)
            if node.op == "call_function":
                target = node.target
                if (
                    isinstance(target, torch._ops.OpOverload)
                    and target._schema.is_mutable
                ):
                    for arg in node.args:
                        if arg in input_nodes:
                            return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule):
                if _detect_input_mutation(module):
                    return True

        return False

    return _detect_input_mutation(gm)


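# Illustrative sketch (hypothetical branches): an in-place aten op on a traced placeholder
# is reported as a potential input mutation, while a purely functional branch is not.
def _example_branch_input_mutation():
    x = torch.randn(3)

    def mutating_branch(x):
        return x.add_(1.0)  # mutable schema (aten.add_) applied to the input placeholder

    def pure_branch(x):
        return x + 1.0

    assert _has_potential_branch_input_mutation(mutating_branch, (x,))
    assert not _has_potential_branch_input_mutation(pure_branch, (x,))

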
def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False):
    """
    Dispatch-trace the branch with inputs and check whether the
    resulting graph has an output aliasing the branch input. This is a
    bit restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch, pre_dispatch=pre_dispatch)(*inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond_op is
        # functionalized
        return True
    except Exception as e:
        raise e

    def _detect_input_alias(gm):
        input_storages = set()
        for node in gm.graph.nodes:
            # We need to check for the existence of "val" because we reuse this logic
            # for the map operator, where num_mapped_args is a scalar
            # and doesn't have a "val" meta.
            if node.op == "placeholder" and "val" in node.meta:
                input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))
            if node.op == "output":

                def check_alias(out):
                    if out is not None and "val" in out.meta:
                        out_storage = StorageWeakRef(out.meta["val"]._typed_storage())
                        return out_storage in input_storages
                    return False

                if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))):
                    return True

        for _, module in gm.named_children():
            if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module):
                return True

        return False

    return _detect_input_alias(gm)


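# Illustrative sketch (hypothetical branches): a branch whose output is a view of its input
# shares storage with the traced placeholder, which is what this helper looks for. In real
# use the helper runs on the fake/functional inputs seen while tracing a HOP, where the
# placeholders' "val" meta preserves storage aliasing.
def _example_branch_input_alias():
    x = torch.randn(4)

    def aliasing_branch(x):
        return x.view(2, 2)  # output aliases the input storage

    def non_aliasing_branch(x):
        return x.clone().view(2, 2)  # fresh storage, no aliasing

    # Expected: True for the aliasing branch, False for the non-aliasing one.
    return (
        _has_potential_branch_input_alias(aliasing_branch, (x,)),
        _has_potential_branch_input_alias(non_aliasing_branch, (x,)),
    )

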
def unique_graph_id(proxy_mode, prefix):
    """Returns a unique name and id for a graph to be added to a proxy_mode tracer"""
    # There are probably better ways - I know that create_arg has some self-incrementing
    # name magic to it, but since we explicitly have to get the name for register_module,
    # I was not sure how to do that. This roughly simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{prefix}_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate
    return i, next_name


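# Illustrative sketch: unique_graph_id only probes proxy_mode.tracer.root with hasattr, so a
# stand-in namespace is used here purely for illustration; real callers pass the active
# proxy tracing mode.
def _example_unique_graph_id():
    import types

    root = torch.nn.Module()
    root.cond_graph_0 = torch.nn.Module()  # an already-taken name forces the counter forward
    stand_in = types.SimpleNamespace(tracer=types.SimpleNamespace(root=root))
    i, name = unique_graph_id(stand_in, prefix="cond_graph")
    assert (i, name) == (1, "cond_graph_1")

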
def _from_fun(t):
    from torch._functorch.aot_autograd import from_fun
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, torch.Tensor):
        if t.dtype != torch.bool:
            return torch.empty_strided(
                t.size(),
                t.stride(),
                dtype=t.dtype,
                requires_grad=t.requires_grad,
            )
        else:
            # a clone of a functional tensor produces a functional tensor,
            # but we want to avoid that, so we clone a non-functional version
            maybe_unfunc_t = t
            if isinstance(t, FunctionalTensor):
                torch._sync(t)
                maybe_unfunc_t = from_fun(t)
            elif torch._is_functional_tensor(t):
                # need to handle both types of functionalization here:
                # these are the tensors that came from the user,
                # which could be either FunctionalTensorWrapper or FunctionalTensor
                torch._sync(t)
                maybe_unfunc_t = torch._from_functional_tensor(t)
            return maybe_unfunc_t.clone()
    return t


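# Illustrative sketch: for ordinary (non-functional) tensors, _from_fun returns a fresh
# tensor with matching metadata; non-bool dtypes become uninitialized empty_strided
# buffers while bool tensors are cloned.
def _example_from_fun():
    t = torch.randn(2, 3, requires_grad=True)
    meta_like = _from_fun(t)      # same size/stride/dtype/requires_grad, no shared data
    mask = torch.tensor([True, False])
    mask_clone = _from_fun(mask)  # bool path: a plain clone of the values
    return meta_like, mask_clone

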
def clone_outputs_aliasing_inputs(args):
    input_storage = {
        StorageWeakRef(arg._typed_storage())
        for arg in args
        if isinstance(arg, torch.Tensor)
    }

    def maybe_clone(t):
        if (
            isinstance(t, torch.Tensor)
            and StorageWeakRef(t._typed_storage()) in input_storage
        ):
            return t.clone()
        return t

    return maybe_clone


def prepare_fw_with_masks(fn):
    def fw_with_masks(*args):
        fw_out = fn(*args)
        return fw_out, [
            isinstance(ret, torch.Tensor) and ret.requires_grad
            for ret in fw_out
        ]

    return fw_with_masks


# TODO: The parameter use_output_and_grad_bw is required because some operations
# that utilize this function, such as the while_loop, may require (grad, fwd_outputs)
def create_fw_bw_graph(fn, use_output_and_grad_bw, fw_inputs, fw_outputs):
    from torch._functorch.aot_autograd import AOTConfig, create_joint

    # Note: [HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
    # between the Autograd and Python keys. Currently, we only suspend functionalization, but more can be
    # added when required. We will encounter two problems if we don't suspend functionalization:
    #
    # 1. make_fx fails to capture operations on the input: the inputs are wrapped as _to_functional_tensor_wrapper,
    # but they will be unwrapped before entering ProxyTorchDispatchMode as part of the dispatching.
    # However, it's the outside wrapper that the tracer creates proxies for. This causes the tracer to fail to
    # fetch the proxy for the inputs and to fail to capture any operations on them.
    #
    # 2. make_fx fails to capture the output: the outputs after ProxyTorchDispatchMode are further
    # wrapped as FunctionalTensorWrapper in the Functionalize key after return. However, the tracer
    # only associates the inner tensor with a proxy in ProxyTorchDispatchMode. Therefore,
    # when creating the output node, it fails to associate the wrapped tensor with its proxy.
    # Instead, it will create a _tensor_constant as output.

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    example_grad = [_from_fun(out) for out in fw_outputs]
    num_grads = len(example_grad)
    fw_graph = _maybe_reenter_make_fx(fn)(*fw_inputs)

    def joint_fn(*joint_operands_grads):
        if use_output_and_grad_bw:
            grads = joint_operands_grads[0]
            inputs = joint_operands_grads[1][-1:]
        else:
            grads = joint_operands_grads[:num_grads]
            inputs = joint_operands_grads[num_grads:]

        joint = create_joint(prepare_fw_with_masks(fn), aot_config=dummy_aot_config)
        _, grads = joint(
            list(inputs),
            [grad for grad in grads if grad is not None and grad.requires_grad],
        )

        # In order to keep map functional for the backward graph,
        # we clone outputs that are aliasing inputs
        maybe_clone = clone_outputs_aliasing_inputs(joint_operands_grads)

        return pytree.tree_map(maybe_clone, grads)

    if use_output_and_grad_bw:
        example_xs_out = list(fw_inputs) + list(fw_outputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            (list(example_grad), list(example_xs_out))
        )
    else:
        example_xs_out = list(fw_inputs)
        joint_graph = _maybe_reenter_make_fx(joint_fn)(
            *(list(example_grad) + list(example_xs_out))
        )

    return fw_graph, joint_graph


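# Illustrative sketch (simplified, hypothetical): how a map-style HOP might call this
# helper. Real callers run it with functionalization suspended and pass inputs/outputs
# that went through _from_fun; here plain eager tensors are used for brevity.
def _example_create_fw_bw_graph():
    def f(x):
        return (x.sin(),)

    fw_inputs = (torch.randn(3, requires_grad=True),)
    fw_outputs = f(*fw_inputs)
    fw_graph, joint_graph = create_fw_bw_graph(f, False, fw_inputs, fw_outputs)
    # fw_graph re-traces f; joint_graph takes (grad_out, x) and returns the input grads.
    return fw_graph, joint_graph

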
def _unstack_pytree(xs):
    flat_xs, inspec = pytree.tree_flatten(xs)
    if not all(isinstance(x, torch.Tensor) for x in flat_xs):
        raise RuntimeError(f"Leaves of xs must be Tensor {flat_xs}")

    if not all(x.shape[0] == flat_xs[0].shape[0] for x in flat_xs):
        raise RuntimeError(
            f"Leaves of xs must have same leading dimension size {[x.shape for x in flat_xs]}"
        )

    pytrees = []
    for leaves in zip(*flat_xs):
        pytrees.append(pytree.tree_unflatten(leaves, inspec))
    return pytrees


def _stack_pytree(pytrees):
    flat_out = []
    out_spec = None
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)
    assert out_spec is not None
    stacked_out = []
    for leaves in zip(*flat_out):
        if all(isinstance(leaf, torch.Tensor) for leaf in leaves):
            stacked_out.append(torch.stack(leaves))
        elif all(leaf is None for leaf in leaves):
            # The backward graph can return None outputs when forward inputs don't require grad.
            # When we eagerly execute the backward graph, we need to call _stack_pytree on its
            # output, therefore we need to deal with None outputs.
            stacked_out.append(None)  # type: ignore[arg-type]
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)

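
# Illustrative sketch: _unstack_pytree slices every leaf along the leading dimension and
# _stack_pytree reassembles the slices, so the two act as inverses for tensor-only pytrees.
def _example_unstack_stack_roundtrip():
    xs = {"a": torch.arange(6).reshape(3, 2), "b": torch.zeros(3)}
    slices = _unstack_pytree(xs)       # list of 3 dicts: {"a": shape (2,), "b": 0-d tensor}
    restacked = _stack_pytree(slices)  # back to {"a": shape (3, 2), "b": shape (3,)}
    assert torch.equal(restacked["a"], xs["a"]) and torch.equal(restacked["b"], xs["b"])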