# pytorch — associative_scan higher-order op (362 lines, 12.8 KB)
1# mypy: allow-untyped-defs
2import functools
3import itertools
4from typing import Callable, List
5
6import torch
7import torch._prims_common as utils
8import torch._subclasses.functional_tensor
9import torch.utils._pytree as pytree
10from torch._C import DispatchKey
11from torch._higher_order_ops.utils import (
12_set_compilation_env,
13autograd_not_implemented,
14reenter_make_fx,
15unique_graph_id,
16)
17from torch._inductor.utils import is_pointwise_use
18from torch._ops import HigherOrderOperator
19from torch._subclasses.fake_tensor import FakeTensorMode
20from torch.fx.experimental.proxy_tensor import (
21disable_proxy_modes_tracing,
22ProxyTorchDispatchMode,
23track_tensor_tree,
24)
25
26
# Convenience handle to the ATen operator namespace used throughout this file
# (e.g. ``aten.slice`` / ``aten.clone``).
aten = torch._ops.ops.aten
28
29
30def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves):
31assert len(args) == 2 * num_leaves
32lhs = pytree.tree_unflatten(args[:num_leaves], spec)
33rhs = pytree.tree_unflatten(args[num_leaves:], spec)
34combined = combine_fn(lhs, rhs)
35combined_leaves = pytree.tree_leaves(combined)
36assert num_leaves == len(combined_leaves)
37return combined_leaves
38
39
40def _interleave(a, b, dim):
41# https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors
42if b_trunc := (a.shape[dim] == b.shape[dim] + 1):
43pad = (
44[0] * ((b.ndim - dim - 1) * 2 + 1)
45+ [1]
46+ [0] * (b.ndim * 2 - ((b.ndim - dim - 1) * 2 + 2))
47)
48b = torch.nn.functional.pad(b, pad)
49
50stacked = torch.stack([a, b], dim=dim + 1)
51interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1)
52if b_trunc:
53# TODO: find torch alternative for slice_along dim for torch.jit.script to work
54interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1)
55return interleaved
56
57
def safe_map(f, *args):
    """``map`` over several sequences, requiring them all to be equal length.

    Unlike the builtin ``map``/``zip``, which silently truncate to the
    shortest input, this raises if the inputs disagree in length.

    Raises:
        ValueError: if the argument sequences have mismatched lengths.
    """
    args = list(map(list, args))
    n = len(args[0])
    for arg in args[1:]:
        if len(arg) != n:
            # Bug fix: the message was missing the ``f`` prefix, so the
            # placeholder was printed literally instead of the lengths.
            raise ValueError(f"length mismatch: {list(map(len, args))}")

    def nf(a):
        return f(*a)

    return list(map(nf, zip(*args)))
69
70
class AssociativeScanOp(HigherOrderOperator):
    """Higher-order operator wrapping ``associative_scan``.

    Registered under the name ``"associative_scan"``; the per-dispatch-key
    implementations are attached below via ``py_impl`` /
    ``py_functionalize_impl``.
    """

    def __init__(self):
        super().__init__("associative_scan")

    def __call__(self, combine_fn, input, dim):
        # Pure pass-through; exists to pin down the operator's call signature.
        return super().__call__(combine_fn, input, dim)


# Module-level singleton through which all dispatch happens.
associative_scan_op = AssociativeScanOp()
80
81
def associative_scan(
    combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree],
    input: pytree.PyTree,
    dim: int,
    reverse: bool = False,
    combine_mode: str = "pointwise",
) -> torch.Tensor:
    r"""
    Performs an inclusive scan with an associative pointwise combine function.

    .. warning::
        `torch.associative_scan` is a prototype feature in PyTorch. It currently
        does not support autograd and you may run into miscompiles.
        Read more about feature classification at:
        https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

    This operator requires runtime code generation and so requires support for
    ``torch.compile``. Further, only CUDA device codegen is supported at the moment.

    Args:
        combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``,
            or if input is a pytree ``(pytree, pytree) -> pytree``.
            This function must be pure, pointwise, and satisfy the associative property.
        input (torch.Tensor): The input tensor, or nested pytree of tensors.
            All inputs are expected to have the same shape.
        dim (int): the dimension to scan over
        reverse (bool): A boolean stating if the scan should be reversed with respect to the dimension.
        combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``.
            If ``combine_mode=pointwise``, ``combine_fn`` must be pure, may only contain pointwise operations
            and ``input`` must be CUDA tensors.
            In all other cases ``combine_mode=generic`` should be used.
            Note: ``combine_mode=pointwise`` is more efficient than ``combine_mode=generic``.


    Example::

        def add(x: torch.Tensor, y: torch.Tensor):
            return x + y

        cumsum = associative_scan(add, x, dim)

    """
    # Bug fix: both messages below were missing the ``f`` prefix, so the
    # placeholders were printed literally instead of the offending values.
    assert callable(combine_fn), f"combine_fn must be a callable, but got {combine_fn}"
    assert isinstance(dim, int), f"dim must be an int, but got {type(dim)}"
    assert combine_mode in ["pointwise", "generic"]

    # Outside of compilation, re-enter through torch.compile so the HOP is
    # traced; fullgraph=True because the op cannot fall back to eager.
    if not torch._dynamo.is_compiling():
        with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
            return torch.compile(associative_scan, fullgraph=True)(
                combine_fn, input, dim, reverse=reverse, combine_mode=combine_mode
            )

    leaves, spec = pytree.tree_flatten(input)

    if combine_mode == "pointwise" and not all(
        leaf.device.type == "cuda" for leaf in leaves
    ):
        raise ValueError(
            "For combine_mode='pointwise', all input tensors need to be on CUDA"
        )

    assert len(leaves) >= 1, "expected at least 1 input leaf"
    assert all(
        isinstance(x, torch.Tensor) for x in leaves
    ), "input leaves must be a Tensor"

    # A reversed scan is implemented by flipping the inputs along ``dim``,
    # scanning forward, and flipping the results back at the end.
    if reverse:
        leaves = [torch.flip(elem, [dim]) for elem in leaves]

    shape = leaves[0].shape
    ndim = len(shape)
    dim = utils.canonicalize_dim(ndim, dim)

    for x in leaves[1:]:
        assert x.shape == shape, "All input tensors must have the same shape"

    # Run combine_fn once on the unflattened input to discover its output
    # pytree and verify it matches the input structure and leaf shapes.
    out = combine_fn(
        pytree.tree_unflatten(leaves, spec),
        pytree.tree_unflatten(leaves, spec),
    )
    out_leaves, tree_out = pytree.tree_flatten(out)
    assert len(leaves) == len(
        out_leaves
    ), "The pytree of the output of the operator needs to match the input pytree"
    for x in out_leaves:
        assert (
            x.shape == shape
        ), "The pytree of the output of the operator needs to match the input pytree"

    # From here on, combine_fn works on flat leaf lists (see
    # wrap_combine_fn_flat) so the scan implementations only handle tensors.
    combine_fn = functools.partial(
        wrap_combine_fn_flat, combine_fn=combine_fn, spec=spec, num_leaves=len(leaves)
    )

    if combine_mode == "generic":
        result_flat = generic_associative_scan(combine_fn, leaves, dim)
    else:
        result_flat = associative_scan_op(combine_fn, leaves, dim)

    if reverse:
        result_flat = [torch.flip(elem, [dim]) for elem in result_flat]

    return pytree.tree_unflatten(result_flat, spec)
182
183
def generic_associative_scan(operator, elems_flat, dim=0):
    r"""
    This function performs the associative_scan operation.
    The algorithm works by recursively collecting neighbours of ``elems_flat`` and subsequently
    applying the ``operator`` on all pairs in parallel along ``dim``.
    The results of the recursive calls are later combined.

    Args:
        operator (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``,
            or if input is a pytree ``(pytree, pytree) -> pytree``.
            This function must be pure, pointwise, and satisfy the associative property.
        elems_flat (torch.Tensor): A list of torch.Tensors converted from the pytree of
            ``input`` provided to ``associative_scan``.
            All inputs are expected to have the same shape.
        dim (int): the dimension to scan over


    Example::

        def add(x: torch.Tensor, y: torch.Tensor):
            return x + y

        elems_flat = torch.tensor([0.0, 1.0, 2.0, 3.0])

        First iteration of _scan ->
            # odd_elems -> apply operator on all neighbours
            # odd_elems = operator([torch.tensor([0.0, 2.0])],
            #                      [torch.tensor([1.0, 3.0])])
            odd_elems = torch.tensor([1.0, 5.0])
            Second iteration of _scan ->
                # odd_elems = operator([torch.tensor([1.0])],
                #                      [torch.tensor([5.0])])
                odd_elems = torch.tensor([6.0])
                # even_elems -> apply operator on all odd_elems and
                # every second element of ``elems``, starting from the second element.
                # even_elems is expanded with the first element of ``elems``
                even_elems = [1.0]
                # Merges odd_elems and even_elems
                res = torch.tensor([1.0, 6.0])
            # even_elems -> apply operator on all odd_elems and
            # every second element of ``elems``, starting from the second element.
            # even_elems is expanded with the first element of ``elems``
            even_elems = [0.0, 3.0]
            # Merges odd_elems and even_elems
            res = torch.tensor([0.0, 1.0, 3.0, 6.0])

    """

    def _scan(elems):
        """Perform the actual recursive scan on ``elems``."""
        num_elems = elems[0].shape[dim]

        # Base case: a scan over 0 or 1 elements is the input itself.
        if num_elems < 2:
            return elems

        # Pairwise-combine neighbours: elements at even positions with the
        # following odd positions. ``operator`` takes the flat leaves of both
        # operands back to back (see wrap_combine_fn_flat).
        reduced_elems = operator(
            *[aten.slice(elem, dim, 0, -1, 2) for elem in elems],
            *[aten.slice(elem, dim, 1, None, 2) for elem in elems],
        )

        # Recursively compute scan for partially reduced tensors.
        # ``odd_elems`` holds the scan results at the odd output positions.
        odd_elems = _scan(reduced_elems)

        # The even output positions (from index 2 on) are obtained by
        # combining each odd result with the next even input element; the
        # even case drops the last odd result since there is no element after it.
        if num_elems % 2 == 0:
            even_elems = operator(
                *[aten.slice(e, dim, 0, -1) for e in odd_elems],
                *[aten.slice(e, dim, 2, None, 2) for e in elems],
            )
        else:
            even_elems = operator(
                *odd_elems,
                *[aten.slice(e, dim, 2, None, 2) for e in elems],
            )

        # The first element of a scan is the same as the first element
        # of the original `elems`.
        even_elems = [
            torch.cat([aten.slice(elem, dim, 0, 1), result], dim=dim)
            if result.shape.numel() > 0 and elem.shape[dim] > 0
            else result
            if result.shape.numel() > 0
            else aten.slice(
                elem, dim, 0, 1
            )  # Jax allows/ignores concat with 0-dim, Pytorch does not
            for (elem, result) in zip(elems, even_elems)
        ]

        # Merge the even- and odd-position results back into order.
        return list(
            safe_map(functools.partial(_interleave, dim=dim), even_elems, odd_elems)
        )

    scans = _scan(elems_flat)

    return scans
278
279
def trace_associative_scan(
    proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int
):
    """Trace ``associative_scan_op`` into the proxy graph.

    Traces ``combine_fn`` into a subgraph, validates that it is pointwise and
    structure-preserving, registers the subgraph on the tracer root, and emits
    a ``call_function`` proxy node for ``func_overload``. Returns output
    tensors tracked against that proxy node.
    """
    # Build fake sample inputs (two copies of ``input``: lhs leaves followed
    # by rhs leaves) and trace combine_fn into an FX graph, outside of proxy
    # tracing so the sample tensors don't end up in the outer graph.
    with disable_proxy_modes_tracing():
        sample_inputs = [
            torch.empty_like(
                x,
                dtype=x.dtype,
                device=x.device,
                requires_grad=x.requires_grad,
            )
            for x in itertools.chain(input, input)
        ]
        combine_graph = reenter_make_fx(combine_fn)(*sample_inputs)

    outputs = None
    for node in combine_graph.graph.nodes:
        if node.op == "output":
            # FX graphs have a single output node whose sole arg is the
            # tuple/list of result values.
            assert outputs is None
            assert len(node.args) == 1
            outputs = node.args[0]

        # Every node in the combine graph must only feed pointwise uses (or
        # the output) — required for combine_mode='pointwise' codegen.
        if not all(is_pointwise_use(use) or use.op == "output" for use in node.users):
            raise ValueError(
                "For combine_mode='pointwise', the combine_fn needs to be pointwise"
            )

    assert outputs is not None
    assert len(outputs) == len(
        input
    ), f"expected combine_fn to return {len(input)} results but got {len(outputs)}"

    # combine_fn must preserve dtypes leaf-by-leaf.
    for i, o in zip(input, outputs):
        o_meta = o.meta["tensor_meta"]
        assert o_meta.dtype == i.dtype, (
            f"combine_fn output type mismatch, expected {i.dtype} "
            + f"but got {o_meta.dtype}"
        )

    # Attach the traced combine graph to the tracer root under a unique name.
    _, combine_graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")

    proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph)

    # Emit the associative_scan call into the traced graph.
    args = (combine_graph, input, dim)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="associative_scan"
    )

    # Placeholder outputs: the scan preserves shape/dtype of its inputs, so
    # clones of the inputs serve as the tracked example values.
    with disable_proxy_modes_tracing():
        out = [aten.clone(x) for x in input]

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
333
334
@associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def associative_scan_op_dense(combine_fn, input, dim):
    # Dense/eager execution is intentionally unsupported: the op requires
    # runtime codegen via torch.compile (see associative_scan's docstring).
    raise NotImplementedError("associative_scan is not implemented for eager")
338
339
# Autograd is not yet supported; registering a deferred error means the
# exception is raised only if a backward pass actually reaches this op.
associative_scan_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(associative_scan_op, deferred_error=True)
)
343
344
@associative_scan_op.py_impl(ProxyTorchDispatchMode)
def associative_scan_proxy_mode(mode, combine_fn, input, dim):
    # Proxy tracing delegates entirely to trace_associative_scan above.
    return trace_associative_scan(mode, associative_scan_op, combine_fn, input, dim)
348
349
# NOTE(review): the function name is misspelled ("assoiciative"); harmless
# because it is only referenced through the decorator registration, but worth
# renaming upstream.
@associative_scan_op.py_impl(FakeTensorMode)
def assoiciative_scan_fake_tensor_mode(mode, combine_fn, input, dim):
    # The scan preserves the shape/dtype of each input leaf, so fake-tensor
    # propagation just clones the inputs under the fake mode.
    with mode:
        return [x.clone() for x in input]
354
355
@associative_scan_op.py_functionalize_impl
def associative_scan_functionalize(ctx, combine_fn, input, dim):
    # Unwrap functional tensors, functionalize the combine_fn, redispatch the
    # op below the functionalization key, and re-wrap the results.
    unwrapped_input = ctx.unwrap_tensors(input)
    with ctx.redispatch_to_next() as m:
        functional_combine_fn = ctx.functionalize(combine_fn)
        ret = associative_scan_op(functional_combine_fn, unwrapped_input, dim)
    return ctx.wrap_tensors(ret)
363