# mypy: allow-untyped-defs
import contextlib
import logging

import torch
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._C._functorch import (
    _add_batch_dim,
    get_unwrapped,
    is_batchedtensor,
    maybe_get_bdim,
)
from torch._dispatch.python import suspend_functionalization
from torch._functorch.utils import exposed_in
from torch._guards import detect_fake_mode
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    _set_compilation_env,
    reenter_make_fx,
    unique_graph_id,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_pre_dispatch_torch_function_mode,
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import _get_current_dispatch_mode

from .utils import _from_fun, create_fw_bw_graph


log = logging.getLogger(__name__)

"""
We're going to define a `cond_op` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""


class CondOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("cond")

    def __call__(self, pred, true_fn, false_fn, operands):
        return super().__call__(pred, true_fn, false_fn, operands)


cond_op = CondOp()


@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
    r"""
    Conditionally applies `true_fn` or `false_fn`.

    .. warning::
        `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
        doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

    `cond` is a structured control flow operator. That is, it is like a Python if-statement,
    but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be
    capturable using torch.compile and torch.export.

    Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following::

        def cond(pred, true_branch, false_branch, operands):
            if pred:
                return true_branch(*operands)
            else:
                return false_branch(*operands)

    Args:
        pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element,
          indicating which branch function to apply.

        true_fn (Callable): A callable function (a -> b) that is within the
          scope that is being traced.

        false_fn (Callable): A callable function (a -> b) that is within the
          scope that is being traced. The true branch and false branch must
          have consistent inputs and outputs, meaning the inputs have to be
          the same, and the outputs have to be of the same type and shape.

        operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions.

    Example::

        def true_fn(x: torch.Tensor):
            return x.cos()

        def false_fn(x: torch.Tensor):
            return x.sin()

        def f(x: torch.Tensor):
            return cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    Restrictions:
        - The conditional statement (aka `pred`) must meet one of the following constraints:

          - It's a `torch.Tensor` with only one element and torch.bool dtype

          - It's a boolean expression, e.g. `x.shape[0] > 10` or `x.dim() > 1 and x.shape[1] > 10`

        - The branch function (aka `true_fn`/`false_fn`) must meet all of the following constraints:

          - The function signature must match the operands.

          - The function must return a tensor with the same metadata as the other
            branch, e.g. shape, dtype, etc.

          - The function cannot have in-place mutations on inputs or global variables.
            (Note: in-place tensor operations such as `add_` for intermediate results
            are allowed in a branch)

    .. warning::
        Temporary Limitations:

        - The **output** of branches must be a **single Tensor**. Pytrees of tensors will be supported in the future.

    """
    if torch.compiler.is_dynamo_compiling():
        return cond_op(pred, true_fn, false_fn, operands)

    if isinstance(pred, (bool, int, float)):
        log.warning(
            "Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
            " If you want torch.cond to preserve both branches, please make the predicate a boolean tensor or a SymBool."
        )
        if pred:
            return true_fn(*operands)
        else:
            return false_fn(*operands)

    def _validate_input(pred, true_fn, false_fn, operands):
        if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
            raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")

        if isinstance(pred, torch.Tensor) and pred.numel() != 1:
            raise RuntimeError(
                f"Expected pred to be bool or single-element tensor, but got {pred}."
            )

        if not callable(true_fn) or not callable(false_fn):
            raise RuntimeError("Expect both branches to be callable.")

        if not isinstance(operands, (tuple, list)) or pytree.tree_any(
            lambda t: not isinstance(t, torch.Tensor), operands
        ):
            raise RuntimeError(
                "Expect operands to be a tuple of possibly nested dict/list/tuple that only "
                f"consists of tensor leaves, but got {operands}."
            )

    _validate_input(pred, true_fn, false_fn, operands)

    if not torch._dynamo.is_dynamo_supported():
        raise RuntimeError("torch.cond requires dynamo support.")

    # Dynamo is expecting a callable with a "__code__" attribute.
    # We cannot directly pass cond_op to it. So we wrap it in a dummy function.
    def _cond_op_wrapper(*args, **kwargs):
        return cond_op(*args, **kwargs)

    with _set_compilation_env():
        with torch._dynamo.utils.disable_cache_limit():
            with _temp_remove_pre_dispatch_torch_function_mode():
                return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)(
                    pred, true_fn, false_fn, operands
                )


def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
    # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            fw_inputs = pytree.tree_map(_from_fun, operands)

            fw_outputs_true = pytree.tree_map(_from_fun, true_fn(*fw_inputs))
            if any(
                not isinstance(out, torch.Tensor)
                for out in fw_outputs_true
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of true_fn to only contain tensors or None. "
                    f"Got types {[type(out) for out in fw_outputs_true]}."
                )
            fw_outputs_false = pytree.tree_map(_from_fun, false_fn(*fw_inputs))
            if any(
                not isinstance(out, torch.Tensor)
                for out in fw_outputs_false
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of false_fn to only contain tensors or None. "
                    f"Got types {[type(out) for out in fw_outputs_false]}."
                )

            # TODO: There is a major issue in that create_fw_bw_graph in the higher_order_op is invoked twice:
            # once in the forward path (as it should be) and once in the backward path, where it shouldn't be called.
            # If we can get rid of the second invocation, it would simplify this function.
            fw_true_graph, joint_true_graph = create_fw_bw_graph(
                true_fn, False, fw_inputs, fw_outputs_true
            )
            fw_false_graph, joint_false_graph = create_fw_bw_graph(
                false_fn, False, fw_inputs, fw_outputs_false
            )

        return fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph
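

# Conceptually (a sketch for orientation, not the traced graphs themselves): for a
# single operand `x` and a branch like `true_fn(x) = x.cos()`, the pair returned by
# create_fw_bw_graph behaves roughly like
#
#     def fw_true_graph(x):
#         return x.cos()
#
#     def joint_true_graph(grad_out, x):
#         # gradient of cos(x) w.r.t. x, scaled by the incoming gradient
#         return (-x.sin() * grad_out,)
#
# with the (grads, *operands) argument order matching how CondAutogradOp.backward
# re-dispatches to cond_op further below.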


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    assert isinstance(
        operands, (list, tuple)
    ), "Cond operands must be a list or tuple of tensors"
    assert all(
        isinstance(o, torch.Tensor) for o in operands
    ), "Cond operands must be a list of tensors"

    true_graph = reenter_make_fx(true_fn)(*operands)
    false_graph = reenter_make_fx(false_fn)(*operands)

    true_outs = []
    false_outs = []
    for node in true_graph.graph.nodes:
        if node.op == "output":
            true_outs.extend(node.args)

    for node in false_graph.graph.nodes:
        if node.op == "output":
            false_outs.extend(node.args)

    flat_true_outs = pytree.arg_tree_leaves(*true_outs)
    flat_false_outs = pytree.arg_tree_leaves(*false_outs)
    if len(flat_true_outs) != len(flat_false_outs):
        raise torch._dynamo.exc.CondOpArgsMismatchError(
            "Expected both branches to return the same number of outputs, but got:"
            f"\n  true branch returns {len(flat_true_outs)} item(s)"
            f"\n  false branch returns {len(flat_false_outs)} item(s)"
        )

    for i in range(0, len(flat_true_outs)):
        true_out = flat_true_outs[i]
        false_out = flat_false_outs[i]

        # Note that we need to skip the check for requires_grad because we're past
        # the autograd key during tracing, so the requires_grad attribute of the
        # tensors is no longer accurate. See Note [invariants for node meta 'val']
        def _same_meta_except_requires_grad(true_out, false_out):
            if true_out is None and false_out is None:
                return True
            elif true_out is None or false_out is None:
                # Consider the following case:
                # def true_fn(x, y):
                #   return x * y
                #
                # def false_fn(x, y):
                #   return x.sin()
                #
                # We'll get the following graphs for backward:
                # def backward_true_fn(x, y, grad_out):
                #  return grad_out * y, grad_out * x
                #
                # def backward_false_fn(x, y, grad_out):
                #  return grad_out, None
                #
                # This means that when we make_fx the backward graph, one branch can
                # produce a None output where the other produces a tensor with
                # metadata, which is undesirable.
                #
                # Ideally, we should provide an optional type to indicate that one of the branches might
                # return None. But we'll just let it pass for now and let downstream/runtime handle it.
                #
                # Note that this corner case should **only** happen when the user wants to trace the
                # backward graph, because if it's the forward, dynamo will error.
                return True
            true_meta = true_out.meta.get("tensor_meta", None)
            false_meta = false_out.meta.get("tensor_meta", None)
            return (
                true_meta.shape == false_meta.shape
                and true_meta.dtype == false_meta.dtype
                and true_meta.stride == false_meta.stride
            )

        if not _same_meta_except_requires_grad(true_out, false_out):
            raise torch._dynamo.exc.CondOpArgsMismatchError(
                "Expected each tensor to have the same metadata, but got:"
                f"\n  {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
                f"\n  {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
            )

    i, true_name = unique_graph_id(proxy_mode, prefix="true_graph")

    false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)

    args = (pred, true_graph, false_graph, operands)

    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}
    )

    # At this point, we're *guaranteed* that the outputs of the true and false
    # branches are indistinguishable. So, as this is just for tracing purposes, we
    # only need to run one branch (the false branch below; see the TODO about
    # true_fn).

    # TODO: the unbacked symbol allocations MUST NOT leak out; if you want to
    # support this we need to arrange for the reenter_make_fx unbacked SymInts
    # to be used, AND we need to arrange for some sort of unification between
    # the two branches (but not really unification; e.g., if one branch
    # returns [u0] and the other returns [5] this is OK but you MUST NOT
    # conclude the result is 5.  Also if one branch returns [3] and another
    # branch returns [5] you can make it work by immediately allocating a new
    # unbacked SymInt here).
    ignore_fresh_unbacked = contextlib.nullcontext()
    if (fake_mode := detect_fake_mode()) and fake_mode.shape_env:
        ignore_fresh_unbacked = fake_mode.shape_env.ignore_fresh_unbacked_symbols()

    # TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in
    # a FakeTensorMode error:
    # `Current active mode <class 'torch._subclasses.fake_tensor.FakeTensorMode'> not registered`
    # TODO: Sometimes the operands are not completely FakeTensor; something seems to have gone wrong
    # in dynamo? Because of that it sometimes runs real computation and re-triggers downstream
    # dispatch keys.
    with ignore_fresh_unbacked:
        out = false_fn(*operands)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
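

# After trace_cond, the parent FX graph contains the two branch submodules plus a
# call to the higher-order op. Roughly (a sketch of the printed graph; exact node
# and attribute names vary by tracer):
#
#     true_graph_0 = self.true_graph_0
#     false_graph_0 = self.false_graph_0
#     out = torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, [x])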


@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def cond_op_dense(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)


class CondAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        pred,
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
        *operands,
    ):
        ctx._pred = pred
        ctx._joint_true_graph = joint_true_graph
        ctx._joint_false_graph = joint_false_graph
        ctx.save_for_backward(*operands)

        with torch._C._AutoDispatchBelowAutograd():
            return cond_op(pred, fw_true_graph, fw_false_graph, operands)

    @staticmethod
    def backward(ctx, *flat_grads):
        operands = ctx.saved_tensors

        grads = cond_op(
            ctx._pred,
            ctx._joint_true_graph,
            ctx._joint_false_graph,
            flat_grads + operands,
        )
        return None, None, None, None, None, *grads


@cond_op.py_impl(DispatchKey.Autograd)
def cond_autograd(pred, true_fn, false_fn, operands):
    # A shortcut for the case where none of the inputs requires gradient:
    # we skip tracing the forward and backward graphs.
    if pytree.tree_all_only(
        torch.Tensor,
        lambda t: not t.requires_grad,  # type: ignore[union-attr]
        (pred, operands),
    ):
        with torch._C._AutoDispatchBelowAutograd():
            return cond_op(pred, true_fn, false_fn, operands)

    (
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
    ) = create_fw_bw_graph_branches(true_fn, false_fn, *operands)
    flat_out = CondAutogradOp.apply(
        pred,
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
        *operands,
    )
    return flat_out
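

# A conceptual sketch of what the autograd path above wires up (names are
# illustrative): for a single operand `x` and a single output gradient `grad_out`,
# backward re-dispatches to cond_op with the joint graphs and the gradients
# prepended to the saved operands, roughly
#
#     grads = cond_op(pred, joint_true_graph, joint_false_graph, (grad_out, x))
#     # -> (grad_x,), aligned with the order of `operands`
#
# The five leading Nones returned by CondAutogradOp.backward correspond to pred and
# the four graph arguments, which do not receive gradients.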


@cond_op.py_impl(ProxyTorchDispatchMode)
def inner(mode, pred, true_fn, false_fn, operands):
    return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)


@cond_op.py_impl(FakeTensorMode)
def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
    # Ignore here, because if you've gotten here but you're not manually
    # tracing the inner graphs, that means that you intend to reuse the graph
    # directly.  Which means the old unbacked symbol bindings are appropriate.
    # This strategy will not work if unbacked symbols can escape.
    ignore_fresh_unbacked = contextlib.nullcontext()
    if mode.shape_env:
        ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols()

    with mode, ignore_fresh_unbacked:
        true_outs = true_fn(*operands)
        flat_true_outs = pytree.tree_leaves(true_outs)
        flat_false_outs = pytree.tree_leaves(false_fn(*operands))
    if len(flat_true_outs) != len(flat_false_outs):
        raise RuntimeError("Unmatched number of outputs from cond() branches.")

    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        true_meta = _extract_tensor_metadata(true_out)
        false_meta = _extract_tensor_metadata(false_out)
        if true_meta != false_meta:
            raise torch._dynamo.exc.CondOpArgsMismatchError(
                "Expected each tensor to have the same metadata, but got:"
                f"\n  {true_fn.__name__} returns {true_meta}"
                f"\n  {false_fn.__name__} returns {false_meta}"
            )
    return true_outs


@cond_op.py_functionalize_impl
def cond_func(ctx, pred, true_fn, false_fn, inputs):
    unwrapped_inputs = ctx.unwrap_tensors(inputs)
    unwrapped_pred = ctx.unwrap_tensors(pred)
    with ctx.redispatch_to_next() as m:
        functional_true = ctx.functionalize(true_fn)
        functional_false = ctx.functionalize(false_fn)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        for branch in [functional_true, functional_false]:
            if _has_potential_branch_input_mutation(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be modifying the input!"
                )
        for branch in [true_fn, false_fn]:
            if _has_potential_branch_input_alias(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be aliasing the input!"
                )

        cond_return = cond_op(
            unwrapped_pred, functional_true, functional_false, unwrapped_inputs
        )
        return ctx.wrap_tensors(cond_return)
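

# For example (an illustrative sketch, mirroring the docstring restrictions), a
# branch like the following would be rejected by the mutation check above because
# it writes to its input in place:
#
#     def bad_true_fn(x):
#         x.add_(1)  # in-place mutation of an input is not allowed
#         return x.cos()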


@cond_op.py_impl(torch._C._functorch.TransformType.Vmap)
def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs):
    assert isinstance(
        inputs, (list, tuple)
    ), "Cond inputs must be a list or tuple of tensors"
    assert all(
        isinstance(i, torch.Tensor) for i in inputs
    ), "Cond inputs must be a list of tensors"

    pred_ = get_unwrapped(pred) if is_batchedtensor(pred) else pred

    # unbatched tensors are not vmapped
    tensors, in_dims = zip(
        *[
            (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None)
            for t in inputs
        ]
    )

    if is_batchedtensor(pred):
        # prepend "pred" and vmap everything
        tensors = (pred_,) + tensors
        in_dims = (0,) + in_dims

        def fn(p, *args):
            t = true_fn(*args)
            f = false_fn(*args)
            return torch.where(p, t[0], f[0])

        with interpreter.lower():
            result = torch.vmap(fn, in_dims=in_dims)(*tensors)

    else:
        # predicate is known at this stage and it is a boolean expression or a
        # tensor with one element.
        true_fn = torch.vmap(true_fn, in_dims=in_dims)
        false_fn = torch.vmap(false_fn, in_dims=in_dims)

        with interpreter.lower():
            result = cond_op(pred, true_fn, false_fn, tensors)

    if not isinstance(result, tuple):
        result = (result,)
    lvl = interpreter.level()
    return tuple([_add_batch_dim(r, 0, lvl) for r in result])
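

# When `pred` is itself batched, the rule above runs both branches and selects
# per-example results, conceptually (a sketch, assuming single-tensor-returning
# branches and batch dimension 0):
#
#     result[i] = true_fn(x[i]) if pred[i] else false_fn(x[i])
#
# which is what the torch.where(p, t[0], f[0]) lowering computes elementwise.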