# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
from typing import Any, Callable, Dict, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_mutation,
    autograd_not_implemented,
    reenter_make_fx,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    make_fx,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.fx.graph_module import GraphModule
from torch.overrides import TorchFunctionMode


class TransformGetItemToIndex(TorchFunctionMode):
    # This is needed since we want to support calling
    # A[q_idx], where q_idx is a scalar tensor in score_mod.
    # Today, when q_idx is a scalar tensor, we implicitly convert it to a python
    # scalar and create a view. We do not want that behavior in this case, so we
    # use this TorchFunctionMode to override that behavior for score_mod
    # wherever we're running it.
    def __torch_function__(self, func, types, args, kwargs=None):
        if func == torch.Tensor.__getitem__:
            index_args = pytree.tree_leaves(args[1])
            if all(isinstance(x, torch.Tensor) for x in index_args):
                return torch.ops.aten.index(args[0], index_args)
        return func(*args, **(kwargs or {}))
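
# Example (illustrative, not part of the original module): inside this mode,
# indexing with a 0-dim tensor dispatches to aten.index instead of converting
# the scalar tensor to a Python int and taking a view:
#
#     q_idx = torch.tensor(3)
#     bias = torch.randn(16)
#     with TransformGetItemToIndex():
#         b = bias[q_idx]  # equivalent to torch.ops.aten.index(bias, [q_idx])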


class FlexAttentionHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("flex_attention")

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        score_mod: Callable,
        block_mask: Tuple,
        scale: float,
        kernel_options: Dict[str, Any],
        score_mod_other_buffers: Tuple = (),
        mask_mod_other_buffers: Tuple = (),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if not all(
            isinstance(buf, torch.Tensor)
            for buf in score_mod_other_buffers + mask_mod_other_buffers
        ):
            raise RuntimeError("Other buffers must be tensors.")
        return super().__call__(
            query,
            key,
            value,
            score_mod,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )


flex_attention = FlexAttentionHOP()


class FlexAttentionBackwardHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("flex_attention_backward")

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        logsumexp: torch.Tensor,
        grad_out: torch.Tensor,
        grad_logsumexp: torch.Tensor,
        fw_graph: Union[Callable, GraphModule],
        joint_graph: GraphModule,
        block_mask: Tuple,
        scale: float,
        kernel_options: Dict[str, Any],
        score_mod_other_buffers: Tuple = (),
        mask_mod_other_buffers: Tuple = (),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if not all(
            isinstance(buf, torch.Tensor)
            for buf in score_mod_other_buffers + mask_mod_other_buffers
        ):
            raise RuntimeError("Other buffers must be tensors.")
        return super().__call__(
            query,
            key,
            value,
            out,
            logsumexp,
            grad_out,
            grad_logsumexp,
            fw_graph,
            joint_graph,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )


flex_attention_backward = FlexAttentionBackwardHOP()


def _math_attention_inner(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32

    scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision)

    b = torch.arange(0, scores.size(0), device=scores.device)
    h = torch.arange(0, scores.size(1), device=scores.device)
    m = torch.arange(0, scores.size(2), device=scores.device)
    n = torch.arange(0, scores.size(3), device=scores.device)

    captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
    from torch.nn.attention.flex_attention import _vmap_for_bhqkv

    # first input is score
    score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim)

    mask_mod = block_mask[-1]
    mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers)
    mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers)

    with TransformGetItemToIndex():
        scores = (scores * scale).to(working_precision)
        post_mod_scores = torch.where(
            mask_mod(b, h, m, n, *mask_mod_other_buffers),
            score_mod(scores, b, h, m, n, *score_mod_other_buffers),
            torch.tensor(-float("inf"), dtype=working_precision, device=scores.device),
        )

    return scores, post_mod_scores


def math_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Eager implementation

    This implementation uses vmap to vectorize the score_mod function over the batch, head, m, and n dimensions.
    We then apply the vectorized score_mod function to the scores matrix. Each wrap of vmap applies one of the
    batch, head, m, or n dimensions. We need to apply vmap 4 times to vectorize over all 4 dimensions.
    (See the commented sketch after this function for an example score_mod.)

    Args:
        query: The query tensor
        key: The key tensor
        value: The value tensor
        score_mod: The score_mod function
        other_buffers: Other buffers that are passed to the score_mod function
    """
    # broadcast key & value along the head dim for GQA
    G = query.size(1) // key.size(1)
    value = torch.repeat_interleave(value, G, dim=1)
    key = torch.repeat_interleave(key, G, dim=1)

    _, post_mod_scores = _math_attention_inner(
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )

    # Set fully masked rows' sumexp to 0.0
    logsumexp = post_mod_scores.logsumexp(dim=-1)
    masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1)
    logsumexp = torch.where(masked_rows, -float("inf"), logsumexp)

    post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1)

    return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2)
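
# Illustrative sketch (not part of the original module): a score_mod receives the
# scalar score together with the (b, h, q_idx, kv_idx) indices and returns a new
# score, e.g. a hypothetical relative-position bias:
#
#     def rel_bias(score, b, h, q_idx, kv_idx):
#         return score + (q_idx - kv_idx)
#
# math_attention applies such a function pointwise over the full [B, Hq, M, N]
# score matrix via the nested vmaps in _math_attention_inner, before masking
# and softmax.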


@flex_attention.py_impl(DispatchKey.CompositeExplicitAutograd)
def sdpa_dense(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    out, lse = math_attention(
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    out = out.contiguous()
    return out, lse


def trace_flex_attention(
    proxy_mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Traces the flex_attention operator with the given score_mod function and other_buffers.

    Trace SDPA will call make_fx with "fake" example vals and then trace the score_mod function.
    This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We
    access this graph module in inductor to inline the score_mod function into the triton template.
    """
    example_out = flex_attention(
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    example_vals = [
        torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
    ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
    mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)]
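    # Note (added for clarity): the example values above follow the calling
    # conventions used throughout this file:
    #   score_mod(score, b, h, q_idx, kv_idx, *score_mod_other_buffers) -> score
    #   mask_mod(b, h, q_idx, kv_idx, *mask_mod_other_buffers) -> mask
    # i.e. one scalar of query.dtype plus four int index scalars for score_mod,
    # and four int index scalars for mask_mod.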
    mask_mod = block_mask[-1]
    with TransformGetItemToIndex():
        score_graph = reenter_make_fx(score_mod)(
            *example_vals, *score_mod_other_buffers
        )
        mask_graph = reenter_make_fx(mask_mod)(
            *mask_example_vals, *mask_mod_other_buffers
        )
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    block_mask = block_mask[:-1] + (mask_graph,)
    qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_score")
    proxy_mode.tracer.root.register_module(qualname, score_graph)
    mask_qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_mask")
    proxy_mode.tracer.root.register_module(mask_qualname, mask_graph)
    node_args = (
        query,
        key,
        value,
        score_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", flex_attention, proxy_args, {}
    )
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


@flex_attention.py_impl(ProxyTorchDispatchMode)
def flex_attention_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert mode is not None, "Mode should always be enabled for python fallback key"
    return trace_flex_attention(
        mode,
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )


@flex_attention.py_functionalize_impl
def flex_attention_functionalize(
    ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Defines the functionalization rules for the flex_attention operator.

    Right now we are unwrapping each tensor and then redispatching to the next mode. However, we want
    to guard against any mutations in the score_mod function to the other_buffers, since those
    are free variables.
    """
    query_unwrapped = ctx.unwrap_tensors(query)
    key_unwrapped = ctx.unwrap_tensors(key)
    value_unwrapped = ctx.unwrap_tensors(value)
    block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
    score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
    mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)

    # Appease the mypy overlords
    assert isinstance(query_unwrapped, torch.Tensor)
    assert isinstance(key_unwrapped, torch.Tensor)
    assert isinstance(value_unwrapped, torch.Tensor)
    assert isinstance(block_mask_unwrapped, tuple)
    assert isinstance(score_mod_other_buffers_unwrapped, tuple)
    assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
    assert all(
        isinstance(item, torch.Tensor)
        for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
    )

    example_vals = (
        [torch.zeros((), dtype=query.dtype)]
        + [torch.zeros((), dtype=torch.int) for _ in range(4)]
        + list(score_mod_other_buffers_unwrapped)
    )
    with ctx.redispatch_to_next() as m:
        functional_score_mod = ctx.functionalize(score_mod)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        with TransformGetItemToIndex():
            mutates = _has_potential_branch_input_mutation(
                functional_score_mod, example_vals, pre_dispatch
            )
        # We only care about mutations of existing buffers, since we can't replay these.
        # However, we can just error if anything is detected.
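        # For example (illustrative sketch, not from the original module), a
        # score_mod that mutates a captured buffer in place:
        #
        #     def bad_score_mod(score, b, h, q_idx, kv_idx, bias):
        #         bias.add_(1)  # in-place mutation of a free variable
        #         return score + bias[q_idx]
        #
        # would trip the check above and raise UnsupportedAliasMutationException below.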
        if mutates:
            raise UnsupportedAliasMutationException("Mutations detected in score_mod")

        out = flex_attention(
            query_unwrapped,
            key_unwrapped,
            value_unwrapped,
            functional_score_mod,
            block_mask_unwrapped,
            scale,
            kernel_options,
            score_mod_other_buffers_unwrapped,
            mask_mod_other_buffers_unwrapped,
        )
    return ctx.wrap_tensors(out)  # type: ignore[return-value, arg-type]


@flex_attention.py_impl(FakeTensorMode)
def flex_attention_fake_tensor_mode(
    mode: FakeTensorMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    with mode:
        v_head_dim = value.size(-1)
        batch_size, num_heads, seq_len_q, q_head_dim = query.shape
        logsumexp = query.new_empty(
            batch_size, num_heads, seq_len_q, dtype=torch.float32
        )
        out_shape = (batch_size, num_heads, seq_len_q, v_head_dim)
        return query.new_empty(out_shape), logsumexp


# ---------------------------- Autograd Implementation ----------------------------
def create_fw_bw_graph(score_mod, index_values, other_buffers):
    # See Note:[HOP create fw_bw graph]

    # All of these imports need to be here in order to avoid circular dependencies
    from torch._dispatch.python import suspend_functionalization
    from torch._functorch.aot_autograd import AOTConfig, create_joint
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
    from torch._subclasses.functional_tensor import disable_functional_mode
    from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():

            def _from_fun(t):
                return torch.empty_strided(
                    t.size(),
                    t.stride(),
                    device=t.device,
                    dtype=t.dtype,
                    requires_grad=t.requires_grad,
                )

            # If someone runs this hop under the default compiler backend ("eager"),
            # then this path will be run with the actual user inputs. We convert them
            # to fake tensors in order to not perform any actual compute.
            from torch._guards import detect_fake_mode

            fake_mode = detect_fake_mode(index_values)
            if fake_mode is None:
                fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

            with fake_mode:
                unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
                unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)

            assert all(isinstance(t, FakeTensor) for t in unwrapped_score_mod_indexes)
            assert all(isinstance(t, FakeTensor) for t in unwrapped_other_buffers)

            example_flat_out = pytree.tree_map(
                _from_fun,
                score_mod(*unwrapped_score_mod_indexes, *unwrapped_other_buffers),
            )
            if not isinstance(example_flat_out, torch.Tensor):
                raise RuntimeError(
                    "Expected output of score_mod to be a tensor. "
                    f"Got type {type(example_flat_out)}."
                )
            example_grad = _from_fun(example_flat_out)

        def joint_f(score, b, h, m, n, example_grad, *other_buffers):
            def fw_with_masks(*args):
                fw_out = score_mod(*args)
                out_requires_grad = fw_out.requires_grad
                return ((fw_out,), (out_requires_grad,))

            joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
            args = [score, b, h, m, n] + list(other_buffers)
            optional_grad = [example_grad] if example_grad.requires_grad else []
            _, grads = joint(args, optional_grad)

            return grads

        joint_graph = make_fx(joint_f)(
            *unwrapped_score_mod_indexes, example_grad, *unwrapped_other_buffers
        )
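        # Note (added for clarity): joint_graph traces joint_f, so it takes
        # (score, b, h, q_idx, kv_idx, grad_out, *other_buffers) and returns the
        # gradients of score_mod's inputs; the first returned gradient is
        # d(score_mod output) / d(score), which is what sdpa_dense_backward consumes.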
        return score_mod, joint_graph


class FlexAttentionAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        any_buffer_requires_grad = any(
            buffer.requires_grad
            for buffer in score_mod_other_buffers + mask_mod_other_buffers
        )
        assert (
            not any_buffer_requires_grad
        ), "Captured buffers that require grad are not yet supported."
        ctx._fw_graph = fw_graph
        ctx._joint_graph = joint_graph
        ctx._mask_graph = block_mask[-1]
        # KV_BLOCK_SIZE and Q_BLOCK_SIZE are integers, so can't use ctx.save_for_backward
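        # Note (added for clarity): this file assumes the block_mask tuple is laid
        # out as (kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices,
        # q_num_blocks, q_indices, full_q_num_blocks, full_q_indices,
        # KV_BLOCK_SIZE, Q_BLOCK_SIZE, mask_mod), matching how backward() unpacks it.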
        ctx._KV_BLOCK_SIZE = block_mask[8]
        ctx._Q_BLOCK_SIZE = block_mask[9]
        ctx.scale = scale
        ctx.kernel_options = kernel_options
        ctx._score_mod_other_buffers_len = len(score_mod_other_buffers)
        with torch._C._AutoDispatchBelowAutograd():
            out, logsumexp = flex_attention(
                query,
                key,
                value,
                fw_graph,
                block_mask,
                scale,
                kernel_options,
                score_mod_other_buffers,
                mask_mod_other_buffers,
            )

        ctx.save_for_backward(
            query,
            key,
            value,
            out,
            logsumexp,
            *block_mask[:8],
            *score_mod_other_buffers,
            *mask_mod_other_buffers,
        )
        return out, logsumexp

    @staticmethod
    def backward(ctx, grad_out, grad_logsumexp):
        fw_args = ctx.saved_tensors
        (
            query,
            key,
            value,
            out,
            logsumexp,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
            *other_buffers,
        ) = fw_args
        fw_graph = ctx._fw_graph
        joint_graph = ctx._joint_graph
        mask_graph = ctx._mask_graph
        KV_BLOCK_SIZE = ctx._KV_BLOCK_SIZE
        Q_BLOCK_SIZE = ctx._Q_BLOCK_SIZE
        scale = ctx.scale
        kernel_options = ctx.kernel_options
        score_mod_other_buffers = tuple(
            other_buffers[: ctx._score_mod_other_buffers_len]
        )
        mask_mod_other_buffers = tuple(
            other_buffers[ctx._score_mod_other_buffers_len :]
        )
        # We have asserted that other_buffers do not require grad in the forward
        none_grads = [None] * 7
        grad_query, grad_key, grad_value = flex_attention_backward(
            query,
            key,
            value,
            out,
            logsumexp,
            grad_out,
            grad_logsumexp,
            fw_graph,
            joint_graph,
            (
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
                q_num_blocks,
                q_indices,
                full_q_num_blocks,
                full_q_indices,
                KV_BLOCK_SIZE,
                Q_BLOCK_SIZE,
                mask_graph,
            ),
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )
        return grad_query, grad_key, grad_value, *none_grads


@flex_attention.py_impl(DispatchKey.Autograd)
def flex_attention_autograd(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    with TransformGetItemToIndex():
        input_requires_grad = any(t.requires_grad for t in (query, key, value))
        if torch.is_grad_enabled() and input_requires_grad:
            example_vals = [
                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
            ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
            fw_graph, bw_graph = create_fw_bw_graph(
                score_mod, example_vals, score_mod_other_buffers
            )
        else:
            fw_graph, bw_graph = score_mod, None
        out, logsumexp = FlexAttentionAutogradOp.apply(
            query,
            key,
            value,
            fw_graph,
            bw_graph,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )
    return out, logsumexp


# ---------------------------- Backward HOP Implementation ----------------------------


@flex_attention_backward.py_impl(DispatchKey.CompositeExplicitAutograd)
def sdpa_dense_backward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Callable,  # GraphModule type hint?
    joint_graph: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple,
    mask_mod_other_buffers: Tuple,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    G = query.size(1) // key.size(1)
    key = torch.repeat_interleave(key, G, dim=1)
    value = torch.repeat_interleave(value, G, dim=1)

    # We're undoing the log -> log2 change of base in the forwards
    logsumexp = logsumexp * math.log(2)
    # The backwards formula for the log -> log2 change of base in the forwards
    grad_logsumexp = grad_logsumexp / math.log(2)
    scores, post_mod_scores = _math_attention_inner(
        query,
        key,
        value,
        fw_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    masked_out_rows = logsumexp == -float("inf")
    softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1))
    softmax_scores = torch.where(masked_out_rows.unsqueeze(-1), 0, softmax_scores)

    grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out

    grad_softmax_scores = grad_out @ value.transpose(-2, -1)

    sum_scores = torch.sum(out * grad_out, -1, keepdim=True)
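    # Note (added for clarity): the expression below is the standard softmax
    # backward, dscore = p * (dp - sum(p * dp)), where sum(p * dp) equals
    # sum(out * grad_out) because out = p @ value, plus p * grad_logsumexp since
    # d(logsumexp)/d(score) = p.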
    grad_score_mod = softmax_scores * (
        grad_softmax_scores - sum_scores + grad_logsumexp.unsqueeze(-1)
    )

    b = torch.arange(0, scores.size(0), device=scores.device)
    h = torch.arange(0, scores.size(1), device=scores.device)
    m = torch.arange(0, scores.size(2), device=scores.device)
    n = torch.arange(0, scores.size(3), device=scores.device)

    mask_graph = block_mask[-1]
    # Gradient of the inline score_mod function, with respect to the scores
    captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
    out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers)
    from torch.nn.attention.flex_attention import _vmap_for_bhqkv

    # inputs are [score, b, h, q_idx, kv_idx, gradOut, ...]
    # score and gradOut are "fully" batched
    joint_score_mod = _vmap_for_bhqkv(
        joint_graph,
        prefix=(0,),
        suffix=(0,) + captured_buffers_in_dim,
        out_dims=out_dims,
    )
    with TransformGetItemToIndex():
        grad_scores, *_ = joint_score_mod(
            scores, b, h, m, n, grad_score_mod, *score_mod_other_buffers
        )
    grad_scores = grad_scores * scale
    grad_scores = grad_scores.to(query.dtype)

    mask_mod = _vmap_for_bhqkv(
        mask_graph, prefix=(), suffix=(None,) * len(mask_mod_other_buffers)
    )
    with TransformGetItemToIndex():
        mask_scores = mask_mod(b, h, m, n, *mask_mod_other_buffers)
        grad_scores = torch.where(
            mask_scores, grad_scores, torch.tensor(0, dtype=query.dtype)
        )

    grad_query = grad_scores @ key
    grad_key = grad_scores.transpose(-2, -1) @ query

    # Reduce DK, DV along broadcasted heads.
    grad_key = grad_key.view(
        grad_key.size(0), -1, G, grad_key.size(-2), grad_key.size(-1)
    )
    grad_value = grad_value.view(
        grad_value.size(0), -1, G, grad_value.size(-2), grad_value.size(-1)
    )

    grad_key = torch.sum(grad_key, 2, keepdim=False)
    grad_value = torch.sum(grad_value, 2, keepdim=False)

    return grad_query.contiguous(), grad_key.contiguous(), grad_value.contiguous()


def trace_flex_attention_backward(
    proxy_mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """We already have the forward graph and joint graph from the forward pass, so we create a proxy and attach both graphs."""
    example_out = flex_attention_backward(
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )

    fw_example_vals = [
        torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
    ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
    bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)]
    mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)]
    mask_graph = block_mask[-1]
    with TransformGetItemToIndex():
        fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers)
        joint_graph = reenter_make_fx(joint_graph)(
            *bw_example_vals, *score_mod_other_buffers
        )
        mask_graph = reenter_make_fx(mask_graph)(
            *mask_example_vals, *mask_mod_other_buffers
        )
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    block_mask = block_mask[:-1] + (mask_graph,)
    proxy_mode.tracer.root.register_module("fw_graph", fw_graph)  # type: ignore[arg-type]
    proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
    proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
    node_args = (
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function",
        flex_attention_backward,
        proxy_args,
        {},
        name="flex_attention_backward",
    )
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


@flex_attention_backward.py_impl(ProxyTorchDispatchMode)
def flex_attention_backward_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    assert mode is not None, "Mode should always be enabled for python fallback key"
    return trace_flex_attention_backward(
        mode,
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )


@flex_attention_backward.py_functionalize_impl
def flex_attention_backward_functionalize(
    ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Defines the functionalization rules for the flex_attention_backward operator.

    Right now we are unwrapping each tensor and then redispatching to the next mode.
    Since we know that the forward score_mod function is guaranteed to be free of mutations
    to the other_buffers, we skip the mutation check and go straight to redispatching.
    """
    query_unwrapped = ctx.unwrap_tensors(query)
    key_unwrapped = ctx.unwrap_tensors(key)
    value_unwrapped = ctx.unwrap_tensors(value)
    out_unwrapped = ctx.unwrap_tensors(out)
    logsumexp_unwrapped = ctx.unwrap_tensors(logsumexp)
    grad_out_unwrapped = ctx.unwrap_tensors(grad_out)
    grad_logsumexp_unwrapped = ctx.unwrap_tensors(grad_logsumexp)
    block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
    score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
    mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)

    # Appease the mypy overlords
    assert isinstance(query_unwrapped, torch.Tensor)
    assert isinstance(key_unwrapped, torch.Tensor)
    assert isinstance(value_unwrapped, torch.Tensor)
    assert isinstance(out_unwrapped, torch.Tensor)
    assert isinstance(logsumexp_unwrapped, torch.Tensor)
    assert isinstance(grad_out_unwrapped, torch.Tensor)
    assert isinstance(grad_logsumexp_unwrapped, torch.Tensor)
    assert isinstance(block_mask_unwrapped, tuple)
    assert isinstance(score_mod_other_buffers_unwrapped, tuple)
    assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
    assert all(
        isinstance(item, torch.Tensor)
        for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
    )

    with ctx.redispatch_to_next() as m:
        functional_fw_graph = ctx.functionalize(fw_graph)
        functional_joint_graph = ctx.functionalize(joint_graph)

        grad_query, grad_key, grad_value = flex_attention_backward(
            query_unwrapped,
            key_unwrapped,
            value_unwrapped,
            out_unwrapped,
            logsumexp_unwrapped,
            grad_out_unwrapped,
            grad_logsumexp_unwrapped,
            functional_fw_graph,  # type: ignore[arg-type]
            functional_joint_graph,  # type: ignore[arg-type]
            block_mask_unwrapped,
            scale,
            kernel_options,
            score_mod_other_buffers_unwrapped,
            mask_mod_other_buffers_unwrapped,
        )

    return ctx.wrap_tensors((grad_query, grad_key, grad_value))  # type: ignore[return-value,arg-type]


@flex_attention_backward.py_impl(FakeTensorMode)
def flex_attention_backward_fake_tensor_mode(
    mode: FakeTensorMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    with mode:
        grad_query = torch.empty_like(query)
        grad_key = torch.empty_like(key)
        grad_value = torch.empty_like(value)
        return grad_query, grad_key, grad_value


flex_attention_backward.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(flex_attention_backward, deferred_error=True)
)
