pytorch

Форк
0
/
associative_scan.py 
362 строки · 12.8 Кб
1
# mypy: allow-untyped-defs
2
import functools
3
import itertools
4
from typing import Callable, List
5

6
import torch
7
import torch._prims_common as utils
8
import torch._subclasses.functional_tensor
9
import torch.utils._pytree as pytree
10
from torch._C import DispatchKey
11
from torch._higher_order_ops.utils import (
12
    _set_compilation_env,
13
    autograd_not_implemented,
14
    reenter_make_fx,
15
    unique_graph_id,
16
)
17
from torch._inductor.utils import is_pointwise_use
18
from torch._ops import HigherOrderOperator
19
from torch._subclasses.fake_tensor import FakeTensorMode
20
from torch.fx.experimental.proxy_tensor import (
21
    disable_proxy_modes_tracing,
22
    ProxyTorchDispatchMode,
23
    track_tensor_tree,
24
)
25

26

27
# Shorthand for the ATen operator namespace; used in this module for aten.slice / aten.clone.
aten = torch._ops.ops.aten
28

29

30
def wrap_combine_fn_flat(*args, combine_fn, spec, num_leaves):
31
    assert len(args) == 2 * num_leaves
32
    lhs = pytree.tree_unflatten(args[:num_leaves], spec)
33
    rhs = pytree.tree_unflatten(args[num_leaves:], spec)
34
    combined = combine_fn(lhs, rhs)
35
    combined_leaves = pytree.tree_leaves(combined)
36
    assert num_leaves == len(combined_leaves)
37
    return combined_leaves
38

39

40
def _interleave(a, b, dim):
41
    # https://stackoverflow.com/questions/60869537/how-can-i-interleave-5-pytorch-tensors
42
    if b_trunc := (a.shape[dim] == b.shape[dim] + 1):
43
        pad = (
44
            [0] * ((b.ndim - dim - 1) * 2 + 1)
45
            + [1]
46
            + [0] * (b.ndim * 2 - ((b.ndim - dim - 1) * 2 + 2))
47
        )
48
        b = torch.nn.functional.pad(b, pad)
49

50
    stacked = torch.stack([a, b], dim=dim + 1)
51
    interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1)
52
    if b_trunc:
53
        # TODO: find torch alternative for slice_along dim for torch.jit.script to work
54
        interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1)
55
    return interleaved
56

57

58
def safe_map(f, *args):
    """Apply ``f`` element-wise across equally long iterables and return a list.

    Works like ``list(map(f, *args))``, except that it raises on inputs of
    different lengths instead of silently truncating to the shortest one.

    Raises:
        ValueError: if the inputs do not all have the same length. The message
            lists the observed lengths.
    """
    arg_lists = [list(arg) for arg in args]
    n = len(arg_lists[0])
    if any(len(arg) != n for arg in arg_lists[1:]):
        # f-string so the actual lengths are reported rather than the literal
        # placeholder text.
        raise ValueError(f"length mismatch: {[len(arg) for arg in arg_lists]}")
    return [f(*items) for items in zip(*arg_lists)]
69

70

71
class AssociativeScanOp(HigherOrderOperator):
72
    def __init__(self):
73
        super().__init__("associative_scan")
74

75
    def __call__(self, combine_fn, input, dim):
76
        return super().__call__(combine_fn, input, dim)
77

78

79
associative_scan_op = AssociativeScanOp()
80

81

82
def associative_scan(
    combine_fn: Callable[[pytree.PyTree, pytree.PyTree], pytree.PyTree],
    input: pytree.PyTree,
    dim: int,
    reverse: bool = False,
    combine_mode: str = "pointwise",
) -> pytree.PyTree:
    r"""
    Performs an inclusive scan with an associative combine function.

    .. warning::
        `torch.associative_scan` is a prototype feature in PyTorch. It currently
        does not support autograd and you may run into miscompiles.
        Read more about feature classification at:
        https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

    This operator requires runtime code generation and so requires support for
    ``torch.compile``. With ``combine_mode="pointwise"`` only CUDA device codegen
    is supported at the moment.

    Args:
        combine_fn (Callable): A binary callable with type ``(Tensor, Tensor) -> Tensor``,
            or if input is a pytree ``(pytree, pytree) -> pytree``.
            This function must be pure and satisfy the associative property; with
            ``combine_mode="pointwise"`` it must also be pointwise.
        input (torch.Tensor): The input tensor, or nested pytree of tensors.
            All inputs are expected to have the same shape.
        dim (int): the dimension to scan over
        reverse (bool): A boolean stating if the scan should be reversed with respect to the dimension.
        combine_mode (str): A string indicating whether the ``combine_fn`` is ``pointwise`` or ``generic``.
            If ``combine_mode=pointwise``, ``combine_fn`` must be pure, may only contain pointwise operations
            and ``input`` must be CUDA tensors.
            In all other cases ``combine_mode=generic`` should be used.
            Note: ``combine_mode=pointwise`` is more efficient than ``combine_mode=generic``.

    Returns:
        A pytree with the same structure as ``input``. Each leaf holds the inclusive
        scan of the corresponding input leaf along ``dim``.

    Raises:
        ValueError: if ``combine_mode="pointwise"`` and any input leaf is not on CUDA.
        AssertionError: on invalid arguments. This covers a non-callable
            ``combine_fn``, a non-int ``dim``, an unknown ``combine_mode``,
            empty or non-tensor inputs, and a mismatched shape or leaf count in
            the inputs or in ``combine_fn``'s output.

    Example::

        def add(x: torch.Tensor, y: torch.Tensor):
            return x + y

        cumsum = associative_scan(add, x, dim)

    """
    assert callable(combine_fn), f"combine_fn must be a callable, but got {combine_fn}"
    assert isinstance(dim, int), f"dim must be an int, but got {type(dim)}"
    assert combine_mode in [
        "pointwise",
        "generic",
    ], f"combine_mode must be 'pointwise' or 'generic', but got {combine_mode!r}"

    if not torch._dynamo.is_compiling():
        # Re-enter this function under torch.compile; everything below therefore
        # runs while being traced.
        with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
            return torch.compile(associative_scan, fullgraph=True)(
                combine_fn, input, dim, reverse=reverse, combine_mode=combine_mode
            )

    leaves, spec = pytree.tree_flatten(input)

    # Validate the leaves before reading tensor attributes such as ``.device``, so a
    # non-tensor leaf hits this assertion instead of an AttributeError.
    assert len(leaves) >= 1, "expected at least 1 input leaf"
    assert all(
        isinstance(x, torch.Tensor) for x in leaves
    ), "input leaves must be a Tensor"

    if combine_mode == "pointwise" and not all(
        leaf.device.type == "cuda" for leaf in leaves
    ):
        raise ValueError(
            "For combine_mode='pointwise', all input tensors need to be on CUDA"
        )

    if reverse:
        leaves = [torch.flip(elem, [dim]) for elem in leaves]

    shape = leaves[0].shape
    ndim = len(shape)
    dim = utils.canonicalize_dim(ndim, dim)

    for x in leaves[1:]:
        assert x.shape == shape, "All input tensors must have the same shape"

    # Trial call on the inputs themselves. It checks that combine_fn returns one
    # leaf per input leaf, each with the input shape.
    out = combine_fn(
        pytree.tree_unflatten(leaves, spec),
        pytree.tree_unflatten(leaves, spec),
    )
    out_leaves = pytree.tree_leaves(out)
    assert len(leaves) == len(
        out_leaves
    ), "The pytree of the output of the operator needs to match the input pytree"
    for x in out_leaves:
        assert (
            x.shape == shape
        ), "The pytree of the output of the operator needs to match the input pytree"

    # Both scan implementations work on flat leaf lists. The wrapped function takes
    # lhs leaves followed by rhs leaves and returns a flat list of leaves.
    combine_fn = functools.partial(
        wrap_combine_fn_flat, combine_fn=combine_fn, spec=spec, num_leaves=len(leaves)
    )

    if combine_mode == "generic":
        result_flat = generic_associative_scan(combine_fn, leaves, dim)
    else:
        result_flat = associative_scan_op(combine_fn, leaves, dim)

    if reverse:
        # Undo the input flip so results line up with the original ordering.
        result_flat = [torch.flip(elem, [dim]) for elem in result_flat]

    return pytree.tree_unflatten(result_flat, spec)
182

183

184
def generic_associative_scan(operator, elems_flat, dim=0):
    r"""
    This function performs the associative_scan operation.
    The algorithm works by recursively collecting neighbours of ``elems_flat`` and subsequently
    applying the ``operator`` on all pairs in parallel along ``dim``.
    The results of the recursive calls are later combined.

    Args:
        operator (Callable): A combine function on flat leaf lists. It is called as
            ``operator(*lhs_leaves, *rhs_leaves)`` and must return a list with one
            tensor per leaf (``associative_scan`` passes a ``wrap_combine_fn_flat``
            partial). It must be pure and satisfy the associative property.
        elems_flat (List[torch.Tensor]): A list of torch.Tensors converted from the pytree of
            ``input`` provided to ``associative_scan``.
            All inputs are expected to have the same shape.
        dim (int): the dimension to scan over; negative values count from the end.

    Returns:
        A list with the inclusive scan of each tensor of ``elems_flat`` along ``dim``.

    Example::

        def add(x: torch.Tensor, y: torch.Tensor):
            return x + y

        elems_flat = torch.tensor([0.0, 1.0, 2.0, 3.0])

        First iteration of _scan ->
            # odd_elems -> apply operator on all neighbours
            # odd_elems = operator([torch.tensor([0.0, 2.0])],
            #                      [torch.tensor([1.0, 3.0])])
            odd_elems = torch.tensor([1.0, 5.0])
            Second iteration of _scan ->
                # odd_elems = operator([torch.tensor([1.0])],
                #                      [torch.tensor([5.0])])
                odd_elems = torch.tensor([6.0])
                # even_elems -> apply operator on all odd_elems and
                # every second element of ``elems``, starting from the second element.
                # even_elems is expanded with the first element of ``elems``
                even_elems = [1.0]
                # Merges odd_elems and even_elems
                res = torch.tensor([1.0, 6.0])
            # even_elems -> apply operator on all odd_elems and
            # every second element of ``elems``, starting from the second element.
            # even_elems is expanded with the first element of ``elems``
            even_elems = [0.0, 3.0]
            # Merges odd_elems and even_elems
            res = torch.tensor([0.0, 1.0, 3.0, 6.0])

    """
    # _interleave stacks along ``dim + 1``, which is only correct for a
    # non-negative dim, so normalize it once up front.
    dim = utils.canonicalize_dim(elems_flat[0].ndim, dim)

    def _scan(elems):
        """Perform the actual recursive scan on ``elems``."""
        num_elems = elems[0].shape[dim]

        if num_elems < 2:
            return elems

        # Combine adjacent pairs (0, 1), (2, 3), ...; the ``-1`` end bound leaves
        # out a trailing unpaired element when num_elems is odd.
        reduced_elems = operator(
            *[aten.slice(elem, dim, 0, -1, 2) for elem in elems],
            *[aten.slice(elem, dim, 1, None, 2) for elem in elems],
        )

        # Recursively compute scan for partially reduced tensors. These are the
        # scan results at the odd positions 1, 3, 5, ...
        odd_elems = _scan(reduced_elems)

        # Scan results at the even positions 2, 4, ...: combine the preceding
        # odd-position result with the original element at that position.
        if num_elems % 2 == 0:
            even_elems = operator(
                *[aten.slice(e, dim, 0, -1) for e in odd_elems],
                *[aten.slice(e, dim, 2, None, 2) for e in elems],
            )
        else:
            even_elems = operator(
                *odd_elems,
                *[aten.slice(e, dim, 2, None, 2) for e in elems],
            )

        # The first element of a scan is the same as the first element of the
        # original ``elems``. Test the size along ``dim`` rather than ``numel()``:
        # a zero-sized non-scan dimension makes ``numel()`` zero even when results
        # exist along ``dim``, which would drop them and break the interleave.
        even_elems = [
            torch.cat([aten.slice(elem, dim, 0, 1), result], dim=dim)
            if result.shape[dim] > 0
            else aten.slice(elem, dim, 0, 1)
            for (elem, result) in zip(elems, even_elems)
        ]

        return safe_map(functools.partial(_interleave, dim=dim), even_elems, odd_elems)

    return _scan(elems_flat)
278

279

280
def trace_associative_scan(
    proxy_mode, func_overload, combine_fn: Callable, input: List[torch.Tensor], dim: int
):
    """Record ``associative_scan`` as a single node in ``proxy_mode``'s graph.

    ``combine_fn`` is traced separately (via ``reenter_make_fx``) on placeholder
    tensors created with ``empty_like`` from ``input``, passed twice: lhs leaves
    then rhs leaves. The resulting graph is validated and registered as a submodule
    of the tracer root. A ``call_function`` node for ``func_overload`` is then
    created with the graph, ``input`` and ``dim`` as arguments.

    Raises:
        ValueError: if a node of the combine graph has a user that is neither the
            output node nor accepted by ``is_pointwise_use``.
        AssertionError: if the combine graph does not return one output per input
            leaf, or an output dtype differs from the corresponding input dtype.

    Returns:
        Clones of ``input`` tracked against the new proxy node.
    """
    # Trace combine_fn with proxy tracing disabled so its graph is built on its own.
    with disable_proxy_modes_tracing():
        placeholders = [
            torch.empty_like(
                t,
                dtype=t.dtype,
                device=t.device,
                requires_grad=t.requires_grad,
            )
            for t in (*input, *input)
        ]
        combine_graph = reenter_make_fx(combine_fn)(*placeholders)

    graph_outputs = None
    for graph_node in combine_graph.graph.nodes:
        if graph_node.op == "output":
            assert graph_outputs is None
            assert len(graph_node.args) == 1
            graph_outputs = graph_node.args[0]

        only_pointwise_users = all(
            is_pointwise_use(user) or user.op == "output" for user in graph_node.users
        )
        if not only_pointwise_users:
            raise ValueError(
                "For combine_mode='pointwise', the combine_fn needs to be pointwise"
            )

    assert graph_outputs is not None
    assert len(graph_outputs) == len(
        input
    ), f"expected combine_fn to return {len(input)} results but got {len(graph_outputs)}"

    for inp, out_node in zip(input, graph_outputs):
        out_meta = out_node.meta["tensor_meta"]
        assert out_meta.dtype == inp.dtype, (
            f"combine_fn output type mismatch, expected {inp.dtype} "
            + f"but got {out_meta.dtype}"
        )

    tracer = proxy_mode.tracer
    _, graph_name = unique_graph_id(proxy_mode, prefix="scan_combine_graph")
    tracer.root.register_module(graph_name, combine_graph)

    proxy_args = pytree.tree_map(tracer.unwrap_proxy, (combine_graph, input, dim))
    out_proxy = tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="associative_scan"
    )

    # Concrete outputs mirroring the inputs, bound to out_proxy by track_tensor_tree.
    with disable_proxy_modes_tracing():
        outputs = [aten.clone(t) for t in input]

    return track_tensor_tree(outputs, out_proxy, constant=None, tracer=tracer)
333

334

335
@associative_scan_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def associative_scan_op_dense(combine_fn, input, dim):
    """Eager (CompositeExplicitAutograd) kernel for ``associative_scan_op``.

    There is no eager implementation; ``associative_scan`` routes calls through
    ``torch.compile`` instead. This kernel always raises.
    """
    raise NotImplementedError("associative_scan is not implemented for eager")


# Autograd key: register the ``autograd_not_implemented`` wrapper for this op.
# NOTE(review): ``deferred_error=True`` presumably postpones the error until a
# backward pass is actually requested; confirm in torch._higher_order_ops.utils.
associative_scan_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(associative_scan_op, deferred_error=True)
)
343

344

345
@associative_scan_op.py_impl(ProxyTorchDispatchMode)
def associative_scan_proxy_mode(mode, combine_fn, input, dim):
    """ProxyTorchDispatchMode kernel: record the op in ``mode``'s traced graph."""
    return trace_associative_scan(
        mode, associative_scan_op, combine_fn, input, dim=dim
    )
348

349

350
@associative_scan_op.py_impl(FakeTensorMode)
def associative_scan_fake_tensor_mode(mode, combine_fn, input, dim):
    """FakeTensorMode kernel for ``associative_scan_op``.

    The op returns one output per input leaf, mirroring that leaf's metadata.
    Clones of the inputs, created under ``mode``, therefore serve as the fake
    outputs.
    """
    with mode:
        return [x.clone() for x in input]


# Backward-compatible alias for the previous, misspelled function name.
assoiciative_scan_fake_tensor_mode = associative_scan_fake_tensor_mode
354

355

356
@associative_scan_op.py_functionalize_impl
def associative_scan_functionalize(ctx, combine_fn, input, dim):
    """Functionalization kernel for ``associative_scan_op``.

    Unwraps the functional input tensors and functionalizes ``combine_fn``. It then
    re-dispatches the op to the next dispatch key and wraps the results back up.
    """
    unwrapped_input = ctx.unwrap_tensors(input)
    # The context manager's value was bound to an unused name before; drop it.
    with ctx.redispatch_to_next():
        functional_combine_fn = ctx.functionalize(combine_fn)
        ret = associative_scan_op(functional_combine_fn, unwrapped_input, dim)
    return ctx.wrap_tensors(ret)
363

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.