# mypy: allow-untyped-defs
import contextlib
import logging

import torch
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._C._functorch import (
    _add_batch_dim,
    get_unwrapped,
    is_batchedtensor,
    maybe_get_bdim,
)
from torch._dispatch.python import suspend_functionalization
from torch._functorch.utils import exposed_in
from torch._guards import detect_fake_mode
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    _set_compilation_env,
    reenter_make_fx,
    unique_graph_id,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_pre_dispatch_torch_function_mode,
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import _get_current_dispatch_mode

from .utils import _from_fun, create_fw_bw_graph


log = logging.getLogger(__name__)

"""
We're going to define a `cond_op` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""


class CondOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("cond")

    def __call__(self, pred, true_fn, false_fn, operands):
        return super().__call__(pred, true_fn, false_fn, operands)


cond_op = CondOp()
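# Illustrative sketch (not part of the module's logic): each dispatch key gets its
# own Python implementation registered on `cond_op`, as done further below for the
# real keys. A hypothetical minimal registration would look like:
#
#   @cond_op.py_impl(DispatchKey.CompositeExplicitAutograd)
#   def my_dense_impl(pred, true_fn, false_fn, operands):
#       return true_fn(*operands) if pred else false_fn(*operands)
#
# The registrations in this file (dense, Autograd, ProxyTorchDispatchMode,
# FakeTensorMode, functionalization, Vmap) all follow this pattern.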


@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
    r"""
    Conditionally applies `true_fn` or `false_fn`.

    .. warning::
        `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
        doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

    `cond` is a structured control flow operator. That is, it is like a Python if-statement,
    but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be
    capturable using torch.compile and torch.export.

    Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following::

        def cond(pred, true_branch, false_branch, operands):
            if pred:
                return true_branch(*operands)
            else:
                return false_branch(*operands)

    Args:
        pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element,
          indicating which branch function to apply.

        true_fn (Callable): A callable function (a -> b) that is within the
          scope that is being traced.

        false_fn (Callable): A callable function (a -> b) that is within the
          scope that is being traced. The true branch and false branch must
          have consistent inputs and outputs, meaning the inputs have to be
          the same, and the outputs have to be the same type and shape.

        operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions.

    Example::

        def true_fn(x: torch.Tensor):
            return x.cos()

        def false_fn(x: torch.Tensor):
            return x.sin()

        return cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    Restrictions:
        - The conditional statement (aka `pred`) must meet one of the following constraints:

          - It's a `torch.Tensor` with only one element and torch.bool dtype

          - It's a boolean expression, e.g. `x.shape[0] > 10` or `x.dim() > 1 and x.shape[1] > 10`

        - The branch functions (aka `true_fn`/`false_fn`) must meet all of the following constraints:

          - The function signature must match with operands.

          - The function must return a tensor with the same metadata, e.g. shape,
            dtype, etc.

          - The function cannot have in-place mutations on inputs or global variables.
            (Note: in-place tensor operations such as `add_` on intermediate results
            are allowed in a branch)

    .. warning::
        Temporary limitations:

        - The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.

    """
    if torch.compiler.is_dynamo_compiling():
        return cond_op(pred, true_fn, false_fn, operands)

    if isinstance(pred, (bool, int, float)):
        log.warning(
            "Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
            " If you want torch.cond to preserve two branches, please make the predicate a boolean tensor or a SymBool."
        )
        if pred:
            return true_fn(*operands)
        else:
            return false_fn(*operands)

    def _validate_input(pred, true_fn, false_fn, operands):
        if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
            raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")

        if isinstance(pred, torch.Tensor) and pred.numel() != 1:
            raise RuntimeError(
                f"Expected pred to be bool or single-element tensor, but got {pred}."
            )

        if not callable(true_fn) or not callable(false_fn):
            raise RuntimeError("Expect both branches to be callable.")

        if not isinstance(operands, (tuple, list)) or pytree.tree_any(
            lambda t: not isinstance(t, torch.Tensor), operands
        ):
            raise RuntimeError(
                "Expect operands to be a tuple of possibly nested dict/list/tuple that only "
                f"consists of tensor leaves, but got {operands}."
            )

    _validate_input(pred, true_fn, false_fn, operands)

    if not torch._dynamo.is_dynamo_supported():
        raise RuntimeError("torch.cond requires dynamo support.")

    # Dynamo expects a callable with a "__code__" attribute.
    # We cannot directly pass cond_op to it, so we wrap it in a dummy function.
    def _cond_op_wrapper(*args, **kwargs):
        return cond_op(*args, **kwargs)

    with _set_compilation_env():
        with torch._dynamo.utils.disable_cache_limit():
            with _temp_remove_pre_dispatch_torch_function_mode():
                return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)(
                    pred, true_fn, false_fn, operands
                )
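# Hedged usage sketch (shapes and values are illustrative only): with a boolean
# *tensor* predicate, both branches are kept and remain capturable by
# torch.compile / torch.export; with a plain Python bool, only the taken branch
# runs, as the warning above notes.
#
#   def true_fn(x):
#       return x.cos()
#
#   def false_fn(x):
#       return x.sin()
#
#   x = torch.randn(4)
#   out = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))     # bool tensor predicate
#   out = torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))  # plain bool in eager,
#                                                               # SymBool under tracing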


def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
    # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            fw_inputs = pytree.tree_map(_from_fun, operands)

            fw_outputs_true = pytree.tree_map(_from_fun, true_fn(*fw_inputs))
            if any(
                not isinstance(out, torch.Tensor)
                for out in fw_outputs_true
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of true_fn to only contain tensors or None. "
                    f"Got types {[type(out) for out in fw_outputs_true]}."
                )
            fw_outputs_false = pytree.tree_map(_from_fun, false_fn(*fw_inputs))
            if any(
                not isinstance(out, torch.Tensor)
                for out in fw_outputs_false
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of false_fn to only contain tensors or None. "
                    f"Got types {[type(out) for out in fw_outputs_false]}."
                )

            # TODO: There is a major issue in that create_fw_bw_graph is invoked twice for this
            # higher_order_op: once in the forward path (as it should be) and once in the backward
            # path, where it shouldn't be called. If we can get rid of the second invocation, it
            # would simplify this function.
            fw_true_graph, joint_true_graph = create_fw_bw_graph(
                true_fn, False, fw_inputs, fw_outputs_true
            )
            fw_false_graph, joint_false_graph = create_fw_bw_graph(
                false_fn, False, fw_inputs, fw_outputs_false
            )

        return fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph
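# Rough sketch of what the paired graphs compute (illustrative only; the exact
# graph signatures come from create_fw_bw_graph in utils.py). For a branch like
# true_fn(x, y) = x * y, the forward graph reproduces the branch, and the joint
# graph takes the output gradients followed by the operands and returns the input
# gradients, conceptually:
#
#   def fw_true_graph(x, y):
#       return x * y
#
#   def joint_true_graph(grad_out, x, y):
#       return grad_out * y, grad_out * x
#
# CondAutogradOp below dispatches cond_op over the joint graphs in backward().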


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    assert isinstance(
        operands, (list, tuple)
    ), "Cond operands must be a list or tuple of tensors"
    assert all(
        isinstance(o, torch.Tensor) for o in operands
    ), "Cond operands must be a list of tensors"

    true_graph = reenter_make_fx(true_fn)(*operands)
    false_graph = reenter_make_fx(false_fn)(*operands)

    true_outs = []
    false_outs = []
    for node in true_graph.graph.nodes:
        if node.op == "output":
            true_outs.extend(node.args)

    for node in false_graph.graph.nodes:
        if node.op == "output":
            false_outs.extend(node.args)

    flat_true_outs = pytree.arg_tree_leaves(*true_outs)
    flat_false_outs = pytree.arg_tree_leaves(*false_outs)
    if len(flat_true_outs) != len(flat_false_outs):
        raise torch._dynamo.exc.CondOpArgsMismatchError(
            f"Expected to return same number of outputs but got:"
            f"\n true branch returns {len(flat_true_outs)} item(s)"
            f"\n false branch returns {len(flat_false_outs)} item(s)"
        )

    for i in range(0, len(flat_true_outs)):
        true_out = flat_true_outs[i]
        false_out = flat_false_outs[i]

        # Note that we need to skip the check for requires_grad because we're past
        # the autograd key during tracing, so the requires_grad attribute of the
        # tensors is no longer meaningful. See Note [invariants for node meta 'val']
        def _same_meta_except_requires_grad(true_out, false_out):
            if true_out is None and false_out is None:
                return True
            elif true_out is None or false_out is None:
                # Consider the following case:
                # def true_fn(x, y):
                #   return x * y
                #
                # def false_fn(x, y):
                #   return x.sin()
                #
                # We'll get the following graphs for backward:
                # def backward_true_fn(x, y, grad_out):
                #   return grad_out * y, grad_out * x
                #
                # def backward_false_fn(x, y, grad_out):
                #   return grad_out, None
                #
                # This means that when we make_fx into the backward graph, one branch
                # produces a None output that carries no tensor metadata, which is undesirable.
                #
                # Ideally, we should provide an optional type to indicate that one of the branches might
                # return None. But we'll just let it pass for now and let downstream/runtime handle it.
                #
                # Note that this corner case should **only** happen when the user wants to trace the
                # backward graph, because if it's forward, dynamo will error.
                return True
            true_meta = true_out.meta.get("tensor_meta", None)
            false_meta = false_out.meta.get("tensor_meta", None)
            return (
                true_meta.shape == false_meta.shape
                and true_meta.dtype == false_meta.dtype
                and true_meta.stride == false_meta.stride
            )

        if not _same_meta_except_requires_grad(true_out, false_out):
            raise torch._dynamo.exc.CondOpArgsMismatchError(
                f"Expected each tensor to have same metadata but got:"
                f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
                f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
            )

    i, true_name = unique_graph_id(proxy_mode, prefix="true_graph")

    false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)

    args = (pred, true_graph, false_graph, operands)

    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}
    )

    # At this point, we're *guaranteed* that whether an output came from the
    # true or false branch is indistinguishable. So, as this is just for tracing
    # purposes, choose the true branch.

    # TODO: the unbacked symbol allocations MUST NOT leak out, if you want to
    # support this we need to arrange for the reenter_make_fx unbacked SymInts
    # to be used, AND we need to arrange for some sort of unification between
    # the two branches (but not really unification; e.g., if one branch
    # returns [u0] and the other returns [5] this is OK but you MUST NOT
    # conclude the result is 5. Also if one branch returns [3] and another
    # branch returns [5] you can make it work by immediately allocating a new
    # unbacked SymInt here).
    ignore_fresh_unbacked = contextlib.nullcontext()
    if (fake_mode := detect_fake_mode()) and fake_mode.shape_env:
        ignore_fresh_unbacked = fake_mode.shape_env.ignore_fresh_unbacked_symbols()

    # TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in
    # a FakeTensorMode error:
    # `Current active mode <class 'torch._subclasses.fake_tensor.FakeTensorMode'> not registered`
    # TODO: Sometimes the operands are not completely FakeTensors; something seems to have gone
    # wrong in dynamo? Because of that it sometimes runs real computation and re-triggers
    # downstream dispatch keys.
    with ignore_fresh_unbacked:
        out = false_fn(*operands)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
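# Illustrative note (a hand-written sketch, not captured tracer output): after
# trace_cond runs, the two traced branches live as submodules on the tracer root
# under names produced by unique_graph_id (e.g. "true_graph_0"/"false_graph_0"),
# and the call site becomes a single call_function node over cond_op, roughly:
#
#   %cond_out = call_function[target=cond](
#       args = (%pred, %true_graph_0, %false_graph_0, %operands), kwargs = {})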


@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def cond_op_dense(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)


class CondAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        pred,
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
        *operands,
    ):
        ctx._pred = pred
        ctx._joint_true_graph = joint_true_graph
        ctx._joint_false_graph = joint_false_graph
        ctx.save_for_backward(*operands)

        with torch._C._AutoDispatchBelowAutograd():
            return cond_op(pred, fw_true_graph, fw_false_graph, operands)

    @staticmethod
    def backward(ctx, *flat_grads):
        operands = ctx.saved_tensors

        grads = cond_op(
            ctx._pred,
            ctx._joint_true_graph,
            ctx._joint_false_graph,
            flat_grads + operands,
        )
        return None, None, None, None, None, *grads
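# Hedged sketch of the backward dispatch above: backward() re-enters cond_op with
# the *joint* graphs, so the same predicate picks which joint graph computes the
# input gradients. Conceptually:
#
#   grads = cond_op(pred, joint_true_graph, joint_false_graph, (*flat_grads, *operands))
#   # one gradient (or None) per entry in `operands`
#
# The five leading Nones returned by backward() correspond to pred and the four
# graph arguments of forward(), which do not receive gradients.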


@cond_op.py_impl(DispatchKey.Autograd)
def cond_autograd(pred, true_fn, false_fn, operands):
    # A shortcut for the case where none of the inputs require gradient:
    # we skip tracing the forward and backward graphs.
    if pytree.tree_all_only(
        torch.Tensor,
        lambda t: not t.requires_grad,  # type: ignore[union-attr]
        (pred, operands),
    ):
        with torch._C._AutoDispatchBelowAutograd():
            return cond_op(pred, true_fn, false_fn, operands)

    (
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
    ) = create_fw_bw_graph_branches(true_fn, false_fn, *operands)
    flat_out = CondAutogradOp.apply(
        pred,
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
        *operands,
    )
    return flat_out


@cond_op.py_impl(ProxyTorchDispatchMode)
def inner(mode, pred, true_fn, false_fn, operands):
    return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)


@cond_op.py_impl(FakeTensorMode)
def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
    # Ignore here, because if you've gotten here but you're not manually
    # tracing the inner graphs, that means that you intend to reuse the graph
    # directly. Which means the old unbacked symbol bindings are appropriate.
    # This strategy will not work if unbacked symbols can escape.
    ignore_fresh_unbacked = contextlib.nullcontext()
    if mode.shape_env:
        ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols()

    with mode, ignore_fresh_unbacked:
        true_outs = true_fn(*operands)
        flat_true_outs = pytree.tree_leaves(true_outs)
        flat_false_outs = pytree.tree_leaves(false_fn(*operands))
    if len(flat_true_outs) != len(flat_false_outs):
        raise RuntimeError("Unmatched number of outputs from cond() branches.")

    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        true_meta = _extract_tensor_metadata(true_out)
        false_meta = _extract_tensor_metadata(false_out)
        if true_meta != false_meta:
            raise torch._dynamo.exc.CondOpArgsMismatchError(
                f"Expected each tensor to have same metadata but got:"
                f"\n {true_fn.__name__} returns {true_meta}"
                f"\n {false_fn.__name__} returns {false_meta}"
            )
    return true_outs


@cond_op.py_functionalize_impl
def cond_func(ctx, pred, true_fn, false_fn, inputs):
    unwrapped_inputs = ctx.unwrap_tensors(inputs)
    unwrapped_pred = ctx.unwrap_tensors(pred)
    with ctx.redispatch_to_next() as m:
        functional_true = ctx.functionalize(true_fn)
        functional_false = ctx.functionalize(false_fn)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        for branch in [functional_true, functional_false]:
            if _has_potential_branch_input_mutation(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be modifying the input!"
                )
        for branch in [true_fn, false_fn]:
            if _has_potential_branch_input_alias(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be aliasing the input!"
                )

        cond_return = cond_op(
            unwrapped_pred, functional_true, functional_false, unwrapped_inputs
        )
        return ctx.wrap_tensors(cond_return)
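# Illustrative sketch of what the functionalization checks above reject
# (hypothetical branches, shown only to clarify the restriction):
#
#   def bad_mutating_branch(x):
#       x.add_(1)          # in-place mutation of an input -> UnsupportedAliasMutationException
#       return x.cos()
#
#   def bad_aliasing_branch(x):
#       return x.view(-1)  # output aliases the input -> UnsupportedAliasMutationException
#
# In-place ops on *intermediate* tensors created inside a branch remain allowed,
# as noted in the `cond` docstring.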


@cond_op.py_impl(torch._C._functorch.TransformType.Vmap)
def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs):
    assert isinstance(
        inputs, (list, tuple)
    ), "Cond inputs must be a list or tuple of tensors"
    assert all(
        isinstance(i, torch.Tensor) for i in inputs
    ), "Cond inputs must be a list of tensors"

    pred_ = get_unwrapped(pred) if is_batchedtensor(pred) else pred

    # unbatched tensors are not vmapped
    tensors, in_dims = zip(
        *[
            (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None)
            for t in inputs
        ]
    )

    if is_batchedtensor(pred):
        # prepend "pred" and vmap everything
        tensors = (pred_,) + tensors
        in_dims = (0,) + in_dims

        def fn(p, *args):
            t = true_fn(*args)
            f = false_fn(*args)
            return torch.where(p, t[0], f[0])

        with interpreter.lower():
            result = torch.vmap(fn, in_dims=in_dims)(*tensors)

    else:
        # predicate is known at this stage and it is a boolean expression or a
        # tensor with one element.
        true_fn = torch.vmap(true_fn, in_dims=in_dims)
        false_fn = torch.vmap(false_fn, in_dims=in_dims)

        with interpreter.lower():
            result = cond_op(pred, true_fn, false_fn, tensors)

    if not isinstance(result, tuple):
        result = (result,)
    lvl = interpreter.level()
    return tuple([_add_batch_dim(r, 0, lvl) for r in result])
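# Hedged sketch of the batched-predicate path above (illustrative values only,
# not asserting exact public-API behavior): when `pred` carries a batch dimension
# there is no single branch to select for the whole batch, so the rule evaluates
# both branches under vmap and merges them element-wise with torch.where.
# Conceptually, per batch element i:
#
#   preds = torch.tensor([True, False, True])
#   xs = torch.randn(3, 4)
#   # out[i] = true_fn(xs[i]) if preds[i] else false_fn(xs[i])
#   # which the rule realizes as torch.where(preds[i], true_fn(xs[i]), false_fn(xs[i]))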