pytorch

Форк
0
/
test_transformers.py 
3386 строк · 166.1 Кб
1
# Owner(s): ["module: nn"]
2

3
import contextlib
4
from functools import partial
5
from collections import namedtuple
6
import sys
7
import torch
8
import torch.nn as nn
9
import torch.nn.functional as F
10
from torch.nn.functional import scaled_dot_product_attention
11
from torch.nn.attention import sdpa_kernel, SDPBackend
12
from torch.nn.attention.bias import CausalVariant, causal_lower_right, causal_upper_left
13
from torch.nn.parameter import Parameter
14
import unittest
15
from unittest.mock import patch, MagicMock, ANY
16
import math
17
import torch.optim as optim
18
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
19
from typing import List, Tuple, Optional
20
from torch.testing._internal.common_nn import NNTestCase
21
from torch.testing._internal.common_utils import (
22
    TEST_WITH_ROCM,
23
    skipIfRocm,
24
    TEST_FAIRSEQ,
25
    run_tests,
26
    parametrize,
27
    freeze_rng_state,
28
    TEST_WITH_CROSSREF,
29
    slowTest,
30
    set_default_dtype,
31
    gradcheck,
32
    make_tensor,
33
    NOTEST_CPU,
34
    IS_WINDOWS,
35
    TEST_WITH_TORCHDYNAMO,
36
)
37
from torch._dynamo.testing import CompileCounterWithBackend
38

39

40
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
41
from torch.testing._internal.common_cuda import (
42
    SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
43
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
44
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
45
    PLATFORM_SUPPORTS_CUDNN_ATTENTION
46
)
47

48
if TEST_FAIRSEQ:
49
    import fairseq.models.transformer as fairseq_transformer
50

51
# Shape descriptor for scaled-dot-product-attention test inputs.
SdpaShape = namedtuple('Sdpa_Shape', 'batch num_heads seq_len head_dim')
# Pair of absolute / relative tolerances.
Tolerances = namedtuple('Tolerances', 'atol rtol')
53

54
@contextlib.contextmanager
def use_deterministic_algorithims(mode: bool, warn_only: bool):
    r"""
    Temporarily set PyTorch's deterministic-algorithms flag.

    Args:
        mode: value passed to ``torch.use_deterministic_algorithms`` while the
            context is active.
        warn_only: forwarded as ``warn_only`` to that same call.

    Yields:
        An empty dict.

    On exit, whether normal or via an exception, both the flag and its
    ``warn_only`` setting are restored to the values observed on entry.
    """
    saved = (
        torch.are_deterministic_algorithms_enabled(),
        torch.is_deterministic_algorithms_warn_only_enabled(),
    )
    try:
        torch.use_deterministic_algorithms(mode, warn_only=warn_only)
        yield {}
    finally:
        saved_mode, saved_warn_only = saved
        torch.use_deterministic_algorithms(saved_mode, warn_only=saved_warn_only)
67

68

69
# Default per-dtype tolerances, used as lower bounds by the tolerance helpers
# in this file. Adapted from torch/testing/_comparison.py.
# NOTE(review): the float16/bfloat16 atol values may be deliberately looser
# than the torch.testing defaults -- confirm before syncing them.
default_atol = {
    torch.float16: 1e-3,
    torch.bfloat16: 1e-3,
    torch.float32: 1e-5,
}
default_rtol = {
    torch.float16: 1e-3,
    torch.bfloat16: 1.6e-2,
    torch.float32: 1.3e-6,
}

# GPU architecture flags. Each is False when CUDA is unavailable, because the
# `and` short-circuits before the device capability is queried.
isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in ((8, 6), (8, 7), (8, 9))
isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8
77

78
def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    """Return the largest elementwise relative error of ``computed_value``.

    The relative error is ``|(true_value - computed_value) / true_value|``.
    NaN entries (e.g. from 0/0) are replaced by the default rtol for
    ``computed_value.dtype``. Infinite ratios are left to ``nan_to_num_``'s
    default posinf/neginf replacement, since only ``nan`` is specified.
    """
    rel_err = ((true_value - computed_value) / true_value).abs()
    torch.nan_to_num_(rel_err, nan=default_rtol[computed_value.dtype])
    return rel_err.max().item()
84

85

86
def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    """Return the largest absolute elementwise difference between the two tensors."""
    return (true_value - computed_value).abs().max().item()
90

91

92
def get_tolerances(
    true_value: torch.Tensor,
    computed_value: torch.Tensor,
    fudge_factor: Optional[float] = None,
) -> Tuple[float, float]:
    """Return ``(atol, rtol)`` for comparing ``computed_value`` against ``true_value``.

    Each tolerance is the observed error (from get_atol / get_rtol), floored at
    the default for ``computed_value.dtype`` and then multiplied by
    ``fudge_factor`` (``None`` means 1.0).
    """
    scale = 1.0 if fudge_factor is None else fudge_factor
    observed_atol = get_atol(true_value, computed_value)
    observed_rtol = get_rtol(true_value, computed_value)

    dtype = computed_value.dtype
    atol = scale * max(observed_atol, default_atol[dtype])
    rtol = scale * max(observed_rtol, default_rtol[dtype])
    # A huge rtol (presumably from true values at or near zero) makes
    # torch.isclose() behave unexpectedly, see
    # https://github.com/pytorch/pytorch/issues/102400 -- fall back to the
    # dtype default in that case.
    if rtol > 1e30:
        rtol = default_rtol[dtype]
    return atol, rtol
109

110

111
def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                           dtype: Optional[torch.dtype] = None):
    """Return detached copies of query, key and value cast to ``dtype``.

    Args:
        query, key, value: tensors to copy.
        dtype: target dtype for all three copies; defaults to ``query.dtype``.

    Returns:
        Tuple ``(query_ref, key_ref, value_ref)``. Each copy is detached from its
        source's autograd history and has ``requires_grad`` set to match the
        source, so it can be used as an independent leaf in reference computations.
    """
    if dtype is None:
        dtype = query.dtype

    def _clone(t: torch.Tensor) -> torch.Tensor:
        # clone() gives new storage, detach() drops autograd history, and
        # requires_grad_ restores the source's grad flag on the new leaf.
        return t.clone().detach().to(dtype).requires_grad_(t.requires_grad)

    return _clone(query), _clone(key), _clone(value)
119

120
def get_platform_specific_sdpa():
121
    ret = []
122
    if PLATFORM_SUPPORTS_FLASH_ATTENTION:
123
        ret.append(SDPBackend.FLASH_ATTENTION)
124
    if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
125
        ret.append(SDPBackend.EFFICIENT_ATTENTION)
126
    if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
127
        ret.append(SDPBackend.CUDNN_ATTENTION)
128
    if not ret:
129
        # Add a placeholder, an empty list causes "An empty arg_values was passed to @parametrize"
130
        ret.append(SDPBackend.EFFICIENT_ATTENTION)
131
    return ret
132

133
# Evaluated once at import time: the SDPA backends reported by
# get_platform_specific_sdpa(), which always returns at least one entry.
PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa()
134

135
def rand_sdpa_tensor(shape: SdpaShape, device: str, dtype: torch.dtype, type: str,
                     requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
    """Creates a random dense or nested tensor with the given shape and type.

    Args:
        shape (SdpaShape): ``(batch, num_heads, seq_len, head_dim)``. For
            ``type == "nested"``, ``seq_len`` may be a list giving one sequence
            length per batch entry; for dense tensors it must be an ``int``.
        device (str): which device to create the tensor on.
        dtype (torch.dtype): the tensor's dtype.
        type (str): ``"nested"`` for a nested tensor; any other value yields a
            dense tensor.
        requires_grad (bool, optional): whether the returned tensor requires
            grad. Defaults to False.
        packed (bool, optional): if True, the last two dimensions are fused into
            one of size ``3 * num_heads * head_dim`` (a single packed QKV
            tensor). Defaults to False.

    Returns:
        torch.Tensor: a dense tensor of shape ``(batch, seq_len, num_heads,
        head_dim)`` (``(batch, seq_len, 3 * num_heads * head_dim)`` when packed),
        or a nested tensor with ``batch`` components of shape
        ``(seq_len_i, num_heads, head_dim)`` (or the packed equivalent).
    """
    batch, num_heads, seq_len, head_dim = shape.batch, shape.num_heads, shape.seq_len, shape.head_dim

    def _size(length):
        # Per-sequence shape, excluding the batch dimension.
        if packed:
            return (length, 3 * num_heads * head_dim)
        return (length, num_heads, head_dim)

    if type == "nested":
        lengths = seq_len if isinstance(seq_len, list) else [seq_len] * batch
        # torch.nested.nested_tensor detaches its components and builds a new
        # leaf whose requires_grad comes only from its own argument, so the flag
        # must be passed here; setting it on the components was silently dropped.
        return torch.nested.nested_tensor(
            [torch.randn(_size(lengths[i]), device=device, dtype=dtype) for i in range(batch)],
            requires_grad=requires_grad,
        )

    assert isinstance(seq_len, int)
    return torch.randn((batch,) + _size(seq_len), device=device, dtype=dtype, requires_grad=requires_grad)
168

169
def calculate_nt_tolerances(nt_ref_hp, nt_ref_lp, default_dtype, fudge_factor=1):
    """Return ``(atol, rtol)`` for comparing two nested tensors component by component.

    Both tolerances start at the defaults for ``default_dtype``. For each pair of
    unbound components, atol is raised to ``fudge_factor`` times the largest
    absolute difference, and rtol to the relative error from get_rtol. Note that
    ``fudge_factor`` scales only the atol contribution.
    """
    # TODO use NT ops when we have implemented Max for NestedTensor instead of unrolling
    atol = default_atol[default_dtype]
    rtol = default_rtol[default_dtype]
    for hp_part, lp_part in zip(nt_ref_hp.unbind(), nt_ref_lp.unbind()):
        component_atol = (fudge_factor * torch.abs(hp_part - lp_part)).max().item()
        atol = max(component_atol, atol)
        rtol = max(get_rtol(hp_part, lp_part), rtol)
    return atol, rtol
177

178
class TestTransformers(NNTestCase):
179
    _do_cuda_memory_leak_check = True
180
    _do_cuda_non_default_stream = True
181

182
    @onlyCUDA
    @unittest.skip("4D mask not supported yet - activate when 4D mask supported")
    def test_self_attn_TxT_attn_mask(self, device):
        """Check that a [T, T] attn_mask and the same mask expanded to
        [N, num_heads, T, T] give identical MultiheadAttention outputs."""
        embed_dim = 16
        num_heads = 4
        batch_size = 10
        tgt_len = 16

        query = torch.rand(batch_size, tgt_len, embed_dim, device=device)  # [N, T, D]
        # Random additive mask: 0 keeps a position, -inf masks it out.
        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len), device=device).float()  # [T, T]
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, 0.0)

        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)  # [N, H, T, T]

        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, device=device)
        mta_model.eval()

        with torch.inference_mode():
            # With batch_first=True the outputs are [N, T, D]; both are transposed
            # the same way, which does not affect the comparison.
            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
            output_mask_4d = output_mask_4d.transpose(0, 1)

            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
            output_mask_TxT = output_mask_TxT.transpose(0, 1)

            self.assertEqual(output_mask_4d, output_mask_TxT)
208

209
    @slowTest
    def test_train_with_pad_and_catch_error(self, device):
        """Train a small TransformerEncoder with a bool key-padding mask and check that:

        * uint8 and int64 padding masks are rejected with an AssertionError
          (in train mode and in eval mode respectively), and
        * train-mode and eval-mode outputs agree for the same bool mask.

        Each unsupported-dtype check used to be a ``try/except ...: continue``.
        That skipped the rest of the loop body on success, and on failure hit an
        unbound exception name. ``assertRaises`` keeps every check live.
        """
        iters = 100
        pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool).to(device)
        layer = nn.TransformerEncoderLayer(
            d_model=2,
            dim_feedforward=4,
            nhead=2,
            batch_first=True,
            activation="gelu",
            dropout=0,
        )
        criterion = nn.MSELoss()
        encoder = nn.TransformerEncoder(layer, 2).to(device)
        optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
        for _ in range(iters):
            encoder.train()
            optimizer.zero_grad()
            inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)

            outputs = encoder(inputs, src_key_padding_mask=pad_mask)

            loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                test = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)

                # Expect uint8 type not supported
                with self.assertRaises(AssertionError, msg="Failed to catch unsupported uint8 type exception"):
                    encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))

                test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
                encoder.eval()

                # Expect long type not supported
                with self.assertRaises(AssertionError, msg="Failed to catch unsupported Long type exception"):
                    encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))

                test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
                l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
                self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
261

262
    @parametrize("attn_mask_dim", [2, 3, None])
    @parametrize("key_padding_mask_dim", [2, None])
    @parametrize("mask_dtype", [torch.bool, torch.float32])
    def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
        """MultiheadAttention must produce the same output on the slow (train-mode)
        and fast (eval-mode) paths for each attn_mask / key_padding_mask rank and dtype."""
        batch, seq_len, embed_dim, n_heads = 2, 4, 8, 4
        with torch.no_grad():
            attn_mask_shapes = {2: (seq_len, seq_len), 3: (batch * n_heads, seq_len, seq_len)}
            key_padding_mask_shapes = {2: (batch, seq_len)}

            attn_mask = None
            if attn_mask_dim is not None:
                attn_mask = make_tensor(attn_mask_shapes[attn_mask_dim], dtype=mask_dtype, device=device)

            key_padding_mask = None
            if key_padding_mask_dim is not None:
                key_padding_mask = make_tensor(key_padding_mask_shapes[key_padding_mask_dim],
                                               dtype=mask_dtype, device=device)

            mha = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True, device=device)
            X = torch.randn(batch, seq_len, embed_dim, device=device)

            def run(train_mode):
                # train(True) disables the fast path; train(False) == eval() enables it.
                mha.train(train_mode)
                return mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask,
                           need_weights=False)[0]

            slow_out = run(True)
            fast_out = run(False)
            self.assertEqual(slow_out, fast_out)
292

293
    @parametrize("nhead", [1, 4, 8])
    def test_transformerencoderlayer_src_mask(self, device, nhead):
        """Smoke test: a bool src_mask is accepted both in train mode and in
        eval mode under no_grad."""
        batch_size, seqlen, d_model, dim_feedforward = 2, 4, 8, 32

        model = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        ).to(device)
        src = torch.rand(batch_size, seqlen, d_model).to(device)  # (bs, seqlen, d_model)
        src_mask = torch.zeros(seqlen, seqlen, dtype=torch.bool).to(device)  # masks nothing

        model(src, src_mask=src_mask)  # train mode
        model.eval()
        with torch.no_grad():
            model(src, src_mask=src_mask)  # eval mode
312

313
    @parametrize("use_torchscript", [False])
    @parametrize("enable_nested_tensor", [True, False])
    @parametrize("use_autocast", [True, False])
    @parametrize("d_model", [12, 256])
    def test_transformerencoder_fastpath(self, device, use_torchscript, enable_nested_tensor, use_autocast, d_model):
        """
        Test TransformerEncoder fastpath output matches slowpath output
        """
        torch.manual_seed(1234)
        nhead = 4
        dim_feedforward = d_model
        batch_first = True

        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                batch_first=batch_first),
            num_layers=2,
            enable_nested_tensor=enable_nested_tensor
        ).to(device).eval()

        if use_torchscript:
            model = torch.jit.script(model)

        # each input is (input, mask); a mask entry of 1 marks a padding token
        input_mask_pairs = [
            (
                torch.rand(3, 2, d_model),
                [
                    [0, 1],
                    [0, 1],
                    [1, 1]
                ]
            ),
            (
                torch.rand(2, 100, d_model),
                [
                    [0] * 98 + [1] * 2,
                    [0] * 90 + [1] * 10
                ]
            ),
            # softmax.cu switches from fast->slowpath at masked seqlen 1024. test 1024.
            (
                torch.rand(2, 1024, d_model),
                [
                    [0] * 1020 + [1] * 4,
                    [0] * 1024,
                ]
            ),
            (
                torch.rand(1, 1026, d_model),
                [[0] * 1024 + [1] * 2]
            ),
            # softmax.cu switches from fast->slowpath at masked seqlen 1024. test range of masks above 1024.
            (
                torch.rand(4, 1040, d_model),
                [
                    [0] * 1024 + [1] * 16,
                    [0] * 1025 + [1] * 15,
                    [0] * 1031 + [1] * 9,
                    [0] * 1040,
                ]
            )
        ]
        input_mask_pairs = [
            (
                # The inputs are already tensors: use .to() instead of torch.tensor(),
                # which copy-constructs and emits a UserWarning for tensor arguments.
                src.to(device=device, dtype=torch.get_default_dtype()),  # float input
                torch.tensor(mask, device=device, dtype=torch.bool)  # bool mask
            ) for src, mask in input_mask_pairs
        ]

        # NOTE(review): device_type is hard-coded to "cuda"; presumably this is a
        # no-op when `device` is CPU -- confirm whether CPU autocast is intended.
        maybe_autocast = torch.autocast("cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext()
        with maybe_autocast:
            for src, src_key_padding_mask in input_mask_pairs:
                with torch.no_grad():
                    fastpath_output = model(src, src_key_padding_mask=src_key_padding_mask)
                slowpath_output = model(src, src_key_padding_mask=src_key_padding_mask)  # reference
                # Make sure fastpath_output is same shape as slowpath_output and mask.
                # When enable_nested_tensor=true, fastpath_output may be smaller than input tensor.
                # Eg if input bs=1, seqlen=6, and we mask out 2 tokens, fastpath_output will have bs=1, seqlen=4.
                # Expand back to old size to match.
                bs, true_seqlen, embed_dim = fastpath_output.shape
                expanded_seqlen = src_key_padding_mask.shape[1]
                fastpath_output_expanded = torch.zeros(bs, expanded_seqlen, embed_dim, device=device)
                fastpath_output_expanded[:, :true_seqlen, :] = fastpath_output
                # No guarantees on the output for masked tokens, so it may differ between slow/fast path;
                # set all of those positions to 0 before comparing.
                fastpath_output_expanded = fastpath_output_expanded.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
                slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
                torch.testing.assert_close(fastpath_output_expanded, slowpath_output, rtol=1e-7, atol=1e-5)
404

405
    @parametrize("with_no_grad", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [False])
    def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
        """
        Edge case: the (batch size, sequence length, embedding dimension) input has
        batch size == sequence length, so those two dimensions cannot be told
        apart by shape alone.
        """
        model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
            num_layers=2,
            enable_nested_tensor=enable_nested_tensor
        ).to(device)

        # Deterministic weights: each parameter is filled with cos(0), cos(1), ...
        # (computed in float32) reshaped to the parameter's shape.
        with torch.no_grad():
            for param in model.parameters():
                param.data.copy_(torch.cos(torch.arange(0, param.numel()).float().view(param.shape)))

        model = model.train(training)  # train(False) is the same as eval()

        x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.get_default_dtype()).to(device)
        src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)

        grad_mode = torch.no_grad() if with_no_grad else contextlib.nullcontext()
        with grad_mode:
            result = model(x, mask=src_mask)

        first_row = [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]
        second_row = [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]
        ref_output = torch.Tensor([[first_row, first_row],
                                   [second_row, second_row]]).to(device)
        self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
        torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
449

450
    @parametrize("batch_first", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [True, False])
    def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device):
        """Deterministic TransformerEncoder test against hard-coded reference outputs.

        Covers one layer with no mask, an all-False src_mask and key padding
        masks, then 2 and 6 stacked layers without and with a final LayerNorm.
        Inputs and references are written in (seq, batch, feature) order and
        transposed when ``batch_first`` is set.
        """
        def make_layer(activation, batch_first=False):
            # d_model=4, nhead=2, no dropout, deterministic cosine weights.
            layer = nn.TransformerEncoderLayer(
                4,
                2,
                dim_feedforward=16,
                dropout=0.0,
                activation=activation,
                batch_first=batch_first,
            ).to(device)
            with torch.no_grad():
                for param in layer.parameters():
                    param.data.copy_(torch.cos(torch.arange(0, param.numel()).float().view(param.shape)))
            return layer

        def run_checks(batch_first, training, enable_nested_tensor):
            def to_layout(t):
                return t.transpose(1, 0) if batch_first else t

            encoder_layer = make_layer(F.relu, batch_first=batch_first)

            def build(num_layers, norm=None):
                encoder = nn.TransformerEncoder(
                    encoder_layer, num_layers, norm=norm, enable_nested_tensor=enable_nested_tensor
                ).to(device)
                return encoder if training else encoder.eval()

            def check(result, expected_rows):
                expected = to_layout(torch.tensor(expected_rows)).to(device)
                self.assertEqual(tuple(result.shape), tuple(expected.shape))
                torch.testing.assert_close(result, expected, rtol=1e-7, atol=1e-5)

            # deterministic input
            encoder_input = to_layout(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
                                                     [0.5387, 0.1655, 0.3565, 0.0471]],
                                                    [[0.8335, 0.2799, 0.5031, 0.2947],
                                                     [0.1402, 0.0318, 0.7636, 0.1346]],
                                                    [[0.6333, 0.9344, 0.1376, 0.9938],
                                                     [0.8924, 0.2872, 0.6692, 0.2944]],
                                                    [[0.9897, 0.6915, 0.3154, 0.1733],
                                                     [0.8645, 0.3513, 0.3064, 0.0767]],
                                                    [[0.8117, 0.2366, 0.4838, 0.7881],
                                                     [0.3718, 0.4945, 0.9511, 0.0864]]]
                                                   )).to(device)

            unmasked_ref = [[[2.428589, 0.020835, -0.602055, -0.085249],
                             [2.427987, 0.021213, -0.602496, -0.084103]],
                            [[2.424689, 0.019155, -0.604793, -0.085672],
                             [2.413863, 0.022211, -0.612486, -0.072490]],
                            [[2.433774, 0.021598, -0.598343, -0.087548],
                             [2.425104, 0.019748, -0.604515, -0.084839]],
                            [[2.436185, 0.022682, -0.596625, -0.087261],
                             [2.433556, 0.021891, -0.598509, -0.086832]],
                            [[2.416246, 0.017512, -0.610712, -0.082961],
                             [2.422901, 0.024187, -0.606178, -0.074929]]]

            # test case 1: a single layer
            model = build(1)
            check(model(encoder_input), unmasked_ref)

            # all 0 src_mask
            src_mask = torch.zeros([5, 5]).to(device) == 1
            check(model(encoder_input, mask=src_mask), unmasked_ref)

            # all 0 key padding mask
            mask = torch.zeros([2, 5]).to(device) == 1
            check(model(encoder_input, src_key_padding_mask=mask), unmasked_ref)

            # pad position 1 of batch entry 0 and positions 3-4 of batch entry 1
            mask[0, 1] = 1
            mask[1, 3] = 1
            mask[1, 4] = 1
            check(model(encoder_input, src_key_padding_mask=mask),
                  [[[2.429026, 0.020793, -0.601741, -0.085642],
                    [2.428811, 0.021445, -0.601912, -0.084252]],
                   [[2.425009, 0.019155, -0.604566, -0.085899],
                    [2.415408, 0.02249, -0.611415, -0.073]],
                   [[2.434199, 0.021682, -0.598039, -0.087699],
                    [2.42598, 0.019941, -0.603896, -0.085091]],
                   [[2.436457, 0.022736, -0.59643, -0.08736],
                    [2.434021, 0.022093, -0.598179, -0.08679]],
                   [[2.416531, 0.017498, -0.610513, -0.083181],
                    [2.4242, 0.024653, -0.605266, -0.074959]]])

            # test case 2, multiple layers no norm
            check(build(2)(encoder_input, src_key_padding_mask=mask),
                  [[[2.419051, 0.017446, -0.608738, -0.085003],
                    [2.419102, 0.017452, -0.608703, -0.085026]],
                   [[2.419043, 0.017445, -0.608744, -0.084999],
                    [2.419052, 0.017446, -0.608738, -0.085004]],
                   [[2.419067, 0.017448, -0.608727, -0.085010],
                    [2.419098, 0.017452, -0.608706, -0.085024]],
                   [[2.419072, 0.017449, -0.608724, -0.085012],
                    [2.419119, 0.017455, -0.608691, -0.085034]],
                   [[2.419019, 0.017442, -0.608761, -0.084989],
                    [2.419075, 0.017449, -0.608722, -0.085014]]])

            six_layer_row_a = [2.419101, 0.017453, -0.608703, -0.085025]
            six_layer_row_b = [2.419101, 0.017453, -0.608704, -0.085025]
            check(build(6)(encoder_input, src_key_padding_mask=mask),
                  [[six_layer_row_a, six_layer_row_b] for _ in range(5)])

            # test case 3, multiple layers with norm (d_model = 4)
            norm = nn.LayerNorm(4)
            check(build(2, norm)(encoder_input, src_key_padding_mask=mask),
                  [[[1.695949, -0.357635, -0.893077, -0.445238],
                    [1.695955, -0.357639, -0.893050, -0.445266]],
                   [[1.695948, -0.357634, -0.893082, -0.445233],
                    [1.695950, -0.357635, -0.893077, -0.445238]],
                   [[1.695951, -0.357636, -0.893069, -0.445246],
                    [1.695955, -0.357639, -0.893052, -0.445264]],
                   [[1.695952, -0.357636, -0.893066, -0.445249],
                    [1.695957, -0.357641, -0.893041, -0.445276]],
                   [[1.695946, -0.357632, -0.893095, -0.445220],
                    [1.695952, -0.357637, -0.893065, -0.445251]]])

            normed_row = [1.695955, -0.357639, -0.893051, -0.445265]
            check(build(6, norm)(encoder_input, src_key_padding_mask=mask),
                  [[normed_row, normed_row] for _ in range(5)])

        # TODO: remove set default dtype to double by making ref_output more precise.
        # This test was copied from test_nn.py, whose default dtype is double; with a
        # float default the tensors would not compare close because the reference
        # output precision is too low.
        with set_default_dtype(torch.double):
            # transformer fast path (eval mode) requires no grad
            grad_mode = contextlib.nullcontext() if training else torch.no_grad()
            with grad_mode:
                run_checks(batch_first, training, enable_nested_tensor)
643

644
    @unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
645
    def test_encoder_padding_and_src_mask_bool(self):
646
        encoder_layer = nn.TransformerEncoderLayer(
647
            d_model=16,
648
            nhead=2,
649
            dim_feedforward=32,
650
            dropout=0.1,
651
            activation='relu',
652
            batch_first=True,
653
        )
654
        encoder_norm = nn.LayerNorm(16)
655
        encoder = nn.TransformerEncoder(
656
            encoder_layer, 2, encoder_norm
657
        )
658

659
        inputs = torch.randn(2, 3, 16)
660

661
        src_mask = torch.ones(3, 3, dtype=torch.bool).triu_(diagonal=1)
662
        input_seq_len = torch.tensor([3, 2])
663
        padding_mask = (
664
            torch.arange(3)[None, :].cpu() >= input_seq_len[:, None]
665
        )
666

667
        with (self.assertNoLogs(None) if not TEST_WITH_TORCHDYNAMO else contextlib.nullcontext()):
668
            encoder(
669
                inputs,
670
                mask=src_mask,
671
                src_key_padding_mask=padding_mask,
672
            )
673

674
    @unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
675
    def test_decoder_padding_and_src_mask_bool(self):
676

677
        def transformer_decoder(inputs, input_seq_len, memory):
678
            decoder_layer = nn.TransformerDecoderLayer(
679
                d_model=16,
680
                nhead=2,
681
                dim_feedforward=32,
682
                dropout=0.1,
683
                activation='relu',
684
                batch_first=True,
685
            )
686
            decoder_norm = nn.LayerNorm(16)
687
            decoder = nn.TransformerDecoder(
688
                decoder_layer, 2, decoder_norm
689
            )
690

691
            src_mask = torch.ones(
692
                inputs.shape[1], inputs.shape[1], dtype=torch.bool
693
            ).triu_(diagonal=1)
694
            padding_mask = (
695
                torch.arange(inputs.shape[1])[None, :].cpu()
696
                >= input_seq_len[:, None]
697
            )
698

699
            return decoder(
700
                inputs,
701
                memory,
702
                tgt_mask=src_mask,
703
                tgt_key_padding_mask=padding_mask,
704
                memory_key_padding_mask=padding_mask,
705
            )
706

707
        inputs = torch.randn(2, 3, 16)
708
        memory = torch.randn(2, 3, 16)
709
        input_seq_len = torch.tensor([3, 2])
710

711
        with self.assertNoLogs(None):
712
            transformer_decoder(inputs, input_seq_len, memory)
713

714
    def test_encoder_is_causal(self):
715

716
        d_model = 3
717
        layer = torch.nn.TransformerEncoderLayer(d_model, 1, 6, batch_first=True)
718
        layer.eval()
719
        x = torch.randn(1, 5, d_model)
720
        unmasked_output = layer(x)
721
        mask = torch.nn.Transformer.generate_square_subsequent_mask(x.size(1))
722
        is_causal_output = layer(x, src_mask=mask, is_causal=True)
723
        masked_output = layer(x, src_mask=mask)
724

725
        self.assertEqual(masked_output, is_causal_output)
726

727
    @onlyCUDA
728
    @parametrize("nb_heads", [1, 8])
729
    @parametrize("bias", [True, False])
730
    def test_mha_native_args(self, nb_heads, bias):
731

732
        B, L, F = 8, 100, 128
733
        batch_first = True
734
        fast_path = True
735
        use_pad_mask = (bias % 2) == 1
736

737
        mha = nn.MultiheadAttention(
738
            embed_dim=F,
739
            num_heads=nb_heads,
740
            batch_first=batch_first,
741
            bias=bias
742
        ).cuda()
743
        mha.eval()
744

745
        ctx = torch.no_grad if fast_path else contextlib.nullcontext
746
        with ctx():
747
            x = torch.randn(B, L, F).cuda()
748
            if not batch_first:
749
                x = x.transpose(0, 1)
750

751
            pad_mask = None
752
            if use_pad_mask:
753
                pad_mask = torch.zeros((B, L), dtype=torch.bool).cuda()
754

755
            mha(query=x, key=x, value=x, key_padding_mask=pad_mask)
756

757
    def test_kpm_mask_trailing_column_with_nested_tensor(self, device):
758
        encoder_layer = nn.TransformerEncoderLayer(
759
            d_model=256,
760
            nhead=4,
761
            dim_feedforward=512,
762
            activation='gelu',
763
            norm_first=False,
764
            batch_first=False,
765
        )
766
        transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
767

768
        x = torch.randn(10, 6, 256).to(device)
769
        mask = torch.ones(6, 10)
770
        mask[0, :] = 0  # here I masked 5 columns instead of just one
771
        mask = mask.bool().to(device)
772
        out = transformer_encoder(src=x, src_key_padding_mask=mask)
773
        self.assertEqual(out.shape[1], 6)
774

775
    # CPU unit test has_torch_functions in test environment,
776
    #   preventing successful completion
777
    @onlyCUDA
778
    def test_with_nested_tensor_input(self, device):
779
        encoder_layer = nn.TransformerEncoderLayer(
780
            d_model=256,
781
            nhead=4,
782
            dim_feedforward=512,
783
            activation='gelu',
784
            norm_first=False,
785
            batch_first=True,
786
        )
787
        transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
788

789
        transformer_encoder.eval()
790
        with torch.no_grad():
791
            x = torch.randn(6, 10, 256).to(device)
792
            mask = torch.ones(6, 10)
793
            mask[0, 0:] = 0  # here I masked 5 columns instead of just one
794
            mask[2, 2:] = 0  # here I masked 5 columns instead of just one
795
            mask[4, 4:] = 0  # here I masked 5 columns instead of just one
796
            mask[5, 8:] = 0  # here I masked 5 columns instead of just one
797
            mask = mask.bool().to(device)
798
            x = torch._nested_tensor_from_mask(x, mask.logical_not(), mask_check=False)
799
            out = transformer_encoder(src=x, src_key_padding_mask=None)
800

801
        self.assertEqual(out.is_nested, True)
802

803

804

805
    def test_script_encoder_subclass(self, device):
806
        class MyCustomLayer(nn.TransformerEncoderLayer):
807
            pass
808

809
        encoder = nn.TransformerEncoder(
810
            MyCustomLayer(d_model=256, nhead=8), num_layers=6
811
        ).to(device=device)
812
        torch.jit.script(encoder)
813

814
    # brazenly adapted from test_transformerencoderlayer_src_mask to test execution of
815
    # torchscripted transformerencoderlayer subclass
816
    def test_transformerencoderlayer_subclass(self, device):
817
        class MyCustomLayer(nn.TransformerEncoderLayer):
818
            pass
819

820
        nhead = 4
821
        batch_size = 2
822
        seqlen = 4
823
        d_model = 8
824
        dim_feedforward = 32
825

826
        model = MyCustomLayer(
827
            d_model=d_model,
828
            nhead=nhead,
829
            dim_feedforward=dim_feedforward,
830
            batch_first=True).to(device)
831
        script_model = torch.jit.script(model)
832

833
        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
834
        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
835

836
        torch.manual_seed(42)
837
        result = model(src, src_mask=src_mask)
838
        torch.manual_seed(42)
839
        scripted_result = script_model(src, src_mask=src_mask)
840
        self.assertEqual(result, scripted_result)
841

842
        model.eval()
843
        script_model = torch.jit.script(model)
844

845
        with torch.no_grad():
846
            result = model(src, src_mask=src_mask)
847
            scripted_result = script_model(src, src_mask=src_mask)
848
            self.assertEqual(result, scripted_result)
849

850

851
    def test_transformerencoderlayer_subclass_model(self, device):
852
        class MyCustomLayer(nn.TransformerEncoderLayer):
853
            pass
854

855
        nhead = 4
856
        batch_size = 2
857
        seqlen = 4
858
        d_model = 8
859
        dim_feedforward = 32
860

861
        layer = MyCustomLayer(
862
            d_model=d_model,
863
            nhead=nhead,
864
            dim_feedforward=dim_feedforward,
865
            batch_first=True)
866
        model = nn.TransformerEncoder(
867
            layer, num_layers=6
868
        ).to(device=device)
869
        script_model = torch.jit.script(model)
870

871
        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
872
        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
873

874
        torch.manual_seed(42)
875
        result = model(src, mask=src_mask)
876
        torch.manual_seed(42)
877
        scripted_result = script_model(src, mask=src_mask)
878
        self.assertEqual(result, scripted_result)
879

880
        model.eval()
881
        script_model = torch.jit.script(model)
882

883
        with torch.no_grad():
884
            result = model(src, mask=src_mask)
885
            scripted_result = script_model(src, mask=src_mask)
886
            self.assertEqual(result, scripted_result)
887

888

889
    @onlyCUDA
890
    @unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found")
891
    def test_decoder_only_layer(self):
892
        DEFAULT_PADDING_IDX = 0
893

894
        class FairseqDecoder(torch.nn.Module):
895
            def __init__(
896
                self,
897
                embed_dim,
898
                attention_heads,
899
                ffn_embed_dim,
900
                num_layers,
901
                embedding_layer,  # torch.nn.Embedding. Must have a padding_idx field
902
                dropout=0,
903
                normalize_before=False,
904
                torch_encoder=None,  # torch encoder that you can map weights from
905
                activation="relu",
906
            ):
907
                super().__init__()
908

909
                cfg = fairseq_transformer.TransformerConfig()
910
                cfg.decoder.embed_dim = embed_dim
911
                cfg.decoder.output_dim = embed_dim
912
                cfg.decoder.attention_heads = attention_heads
913
                cfg.decoder.ffn_embed_dim = ffn_embed_dim
914
                cfg.dropout = dropout
915
                cfg.decoder.normalize_before = normalize_before
916
                cfg.decoder.layers = num_layers
917
                # make embedding behavior same as other encoders
918
                cfg.no_token_positional_embeddings = True
919
                cfg.no_scale_embedding = True
920
                cfg.activation_fn = activation
921

922
                dictionary = {}  # TODO: verify what this is
923

924
                self.decoder = fairseq_transformer.TransformerDecoder(
925
                    cfg,
926
                    dictionary,
927
                    embedding_layer,
928
                    no_encoder_attn=True,
929
                    output_projection=None,
930
                )
931

932
                if torch_encoder is not None:
933
                    self.decoder = torch_to_fairseq(torch_encoder, self.decoder)  # noqa: F821
934
                self.decoder = self.decoder.eval().cuda().half()
935

936
            def forward(
937
                self,
938
                tokens,
939
                src_lengths=None,
940
                with_triangle_mask=False,
941
                incremental_state=None,
942
            ):
943
                return self.decoder(
944
                    prev_output_tokens=tokens,
945
                    encoder_out=None,
946
                    incremental_state=incremental_state,
947
                    features_only=True,
948
                    full_context_alignment=not with_triangle_mask,
949
                    alignment_layer=None,
950
                    alignment_heads=None,
951
                    src_lengths=src_lengths,
952
                    return_all_hiddens=False,
953
                )[0]
954

955
    @parametrize("input_dim,attn_mask_dim,is_causal",
956
                 [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
957
                  (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
958
                 name_fn=lambda input_dim, attn_dim, is_causal: (
959
                     f"{input_dim}D_input_dim_" + (
960
                         f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask"
961
                         if attn_dim is not None else "no_attn_mask")))
962
    @parametrize("dropout_p", [0.0, 0.2, 0.5])
963
    @sdpa_kernel(backends=[SDPBackend.MATH])
964
    def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
965
        def sdp_ref(
966
                q,
967
                k,
968
                v,
969
                attn_mask=None,
970
                dropout_p=0.0):
971
            E = q.size(-1)
972
            q = q / math.sqrt(E)
973
            # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
974
            if attn_mask is not None:
975
                attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
976
            else:
977
                attn = torch.bmm(q, k.transpose(-2, -1))
978

979
            attn = torch.nn.functional.softmax(attn, dim=-1)
980
            if dropout_p > 0.0:
981
                attn = torch.nn.functional.dropout(attn, p=dropout_p)
982
            # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
983
            output = torch.bmm(attn, v)
984
            return output
985
        # TODO: Support cross-device / dtype testing properly when instantiate_device_type_tests() is used.
986
        dtypes = [torch.double, torch.float]
987
        for dtype in dtypes:
988

989
            def rand_tensor(*shape):
990
                return torch.randn(shape, device=device, dtype=dtype)
991

992
            # This test compares python and C++ implementations of SDP.
993
            N, N_prime, L, S, E = 5, 2, 4, 3, 6
994
            if input_dim == 3:
995
                query = rand_tensor(N, L, E)
996
                key = rand_tensor(N, S, E)
997
                value = rand_tensor(N, S, E)
998
            elif input_dim == 4:
999
                query = rand_tensor(N, N_prime, L, E)
1000
                key = rand_tensor(N, N_prime, S, E)
1001
                value = rand_tensor(N, N_prime, S, E)
1002
            else:
1003
                self.fail(f'Invalid input_dim {input_dim} encountered in SDP test')
1004

1005
            attn_mask = None
1006
            if attn_mask_dim is not None:
1007
                assert attn_mask_dim in [2, input_dim]
1008
                mask_size = (L, S) if attn_mask_dim == 2 else ((N, L, S) if input_dim == 3 else (N, N_prime, L, S))
1009
                attn_mask = (torch.ones(mask_size, device=device, dtype=torch.bool).tril() if is_causal
1010
                             else torch.randint(0, 2, size=mask_size, device=device, dtype=torch.bool))
1011

1012
            with freeze_rng_state():
1013
                # Python impl only supports float mask and 3D inputs.
1014
                attn_mask_float = attn_mask
1015
                if attn_mask_float is not None:
1016
                    attn_mask_float = torch.zeros_like(attn_mask, dtype=query.dtype)
1017
                    attn_mask_float.masked_fill_(attn_mask.logical_not(), float("-inf"))
1018
                q, k, v = query.view(-1, L, E), key.view(-1, S, E), value.view(-1, S, E)
1019
                a = attn_mask_float
1020
                if a is not None and attn_mask_dim > 3:
1021
                    a = a.view(-1, L, S)
1022
                expected = sdp_ref(q, k, v, attn_mask=a, dropout_p=dropout_p)
1023
                if input_dim > 3:
1024
                    expected = expected.view(-1, N_prime, L, E)
1025

1026
            with freeze_rng_state():
1027
                if is_causal:
1028
                    # NB: Don't pass attn_mask here
1029
                    actual = torch.nn.functional.scaled_dot_product_attention(
1030
                        query, key, value, None, dropout_p, is_causal)
1031

1032
                    # Error case: both explicit attn_mask and is_causal are set
1033
                    with self.assertRaisesRegex(RuntimeError,
1034
                                                "Explicit attn_mask should not be set when is_causal=True"):
1035
                        torch.nn.functional.scaled_dot_product_attention(
1036
                            query, key, value, attn_mask, dropout_p, is_causal)
1037
                else:
1038
                    actual = torch.nn.functional.scaled_dot_product_attention(
1039
                        query, key, value, attn_mask, dropout_p, is_causal)
1040

1041
                self.assertEqual(actual, expected)
1042

1043
        if attn_mask_dim is None:
1044
            q = q.double().clone()
1045
            k = k.double().clone()
1046
            v = v.double().clone()
1047
            q.requires_grad_()
1048
            k.requires_grad_()
1049
            v.requires_grad_()
1050

1051
            assert gradcheck(lambda *args, **kwargs: wrapper_set_seed(sdp_ref, *args, **kwargs),
1052
                             (q, k, v, attn_mask, dropout_p))
1053
            assert gradcheck(lambda *args, **kwargs:
1054
                             wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
1055
                             (q, k, v, attn_mask, dropout_p))
1056

1057
        def test_incompatible_mask(self, device):
1058
            def ones_tensor(*shape):
1059
                return torch.ones(shape, dtype=torch.float32)
1060
            S, L, E, H = 1, 2, 4, 1
1061
            qkv = ones_tensor(S, L, E)
1062

1063
            mha = nn.MultiheadAttention(E, H)
1064
            mha.in_proj_weight = Parameter(torch.ones((E * 3, E)))
1065
            mha.out_proj.weight = Parameter(torch.ones((E, E)))
1066
            qkv = qkv.to(float)
1067
            kpm = ones_tensor(S, L) * float("-inf")
1068
            am = ones_tensor(L, L).to(bool)
1069

1070
            def func():
1071
                return mha(qkv, qkv, qkv, need_weights=False, key_padding_mask=kpm, attn_mask=am)
1072

1073
            self.assertRaises(RuntimeError, func)
1074

1075
    @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
1076
    @torch.no_grad()
1077
    def test_mask_check_fastpath(self):
1078
        """
1079
        Test that fastpath is executed independently of the masks that are passed.
1080
        If the passed key padding mask is left aligned or mask_check=False, test that nested tensors are used
1081
        (sparsity fastpath), otherwise use fastpath with traditional tensors.
1082
        Also test that fast path is executed with both key padding mask and attention mask passed at the same time.
1083
        """
1084

1085
        x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float)
1086

1087
        def _test_fastpath(model, key_padding_mask, mock_return_value, attn_mask=None, nested_tensors=True):
1088
            with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
1089
                fastpath_mock.return_value = mock_return_value
1090
                model(x, src_key_padding_mask=key_padding_mask, mask=attn_mask)
1091

1092
                # If mock was called, fastpath was taken
1093
                self.assertTrue(fastpath_mock.called)
1094

1095
                # If mock was called with nested tensors, sparsity fastpath was taken
1096
                for call_args, _ in fastpath_mock.call_args_list:
1097
                    self.assertEqual(call_args[0].is_nested, nested_tensors)
1098

1099
        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
1100

1101
        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
1102
        model.eval()
1103

1104
        aligned_key_padding_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool)
1105
        not_aligned_key_padding_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool)
1106
        attn_mask = torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]).to(torch.bool)
1107
        nested_tensor_return_value = torch.nested.nested_tensor([torch.ones((2, 2), dtype=torch.float)])
1108
        tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float)
1109

1110
        # Left aligned mask results in sparsity fastpath
1111
        _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1112

1113
        # Not aligned mask results in fastpath
1114
        _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1115

1116
        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True)
1117
        model.eval()
1118

1119
        # If nested tensor disabled, fastpath is always taken
1120
        _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1121
        _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1122
        # Fast path is taken if both attention mask and key padding mask are present
1123
        _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, attn_mask=attn_mask, nested_tensors=False)
1124

1125
        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False)
1126
        model.eval()
1127

1128
        # Mask check disabled results in sparisty fastpath, independently of the mask
1129
        _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1130
        _test_fastpath(model, not_aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1131

1132
    # Test failing MHA when bias was NoneType
1133
    def test_bias_is_none(self):
1134
        x = torch.rand((1, 5, 10))
1135
        model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True)
1136
        model.eval()
1137
        model(x, x, x)
1138
        # completes without error
1139

1140
    def test_transformer_bias_is_none(self, device):
1141
        batch_size = 2
1142
        seqlen = 3
1143
        d_model = 8
1144
        nhead = 4
1145

1146
        encoder_layer = torch.nn.TransformerEncoderLayer(d_model, nhead, bias=False, batch_first=True, device=device)
1147
        encoder_layer.eval()
1148
        x = torch.randn(batch_size, seqlen, d_model, device=device)
1149
        # runs without error
1150
        encoder_layer(x)
1151

1152
        with self.assertWarnsRegex(UserWarning, "encoder_layer.self_attn was passed bias=False"):
1153
            encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=1).eval()
1154
            encoder(x)
1155

1156
        with self.assertWarnsRegex(UserWarning, "self_attn was passed bias=False"):
1157
            transformer = torch.nn.Transformer(
1158
                d_model=d_model, nhead=nhead, bias=False, batch_first=True, device=device
1159
            ).eval()
1160
            transformer(x, x)
1161

1162
    def test_train_with_is_causal(self, device):
1163
        # training with is_causal
1164
        S, L, E, H = 1, 2, 2, 1
1165
        layer = nn.TransformerEncoderLayer(
1166
            d_model=2,
1167
            dim_feedforward=4,
1168
            nhead=H,
1169
            batch_first=True,
1170
            activation="gelu",
1171
            dropout=0,
1172
        )
1173
        criterion = nn.MSELoss()
1174
        encoder = nn.TransformerEncoder(layer, 2).to(device)
1175
        optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
1176
        encoder.train()
1177

1178
        encoder.train()
1179
        optimizer.zero_grad()
1180
        inputs = torch.randn(S, L, E).to(device)
1181
        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1182
            inputs.size(1), device=device
1183
        )
1184

1185
        outputs = encoder(inputs, mask=mask, is_causal=True)
1186

1187
        loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
1188
        loss.backward()
1189
        optimizer.step()
1190

1191
        # inference with is_causal
1192
        t_qvk = torch.randn((S, L, E), device=device, dtype=torch.float32)
1193
        mha = nn.MultiheadAttention(E, H).to(device)
1194
        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1195
            S, device=device
1196
        )
1197

1198
        attn_out, _ = mha(t_qvk, t_qvk, t_qvk, attn_mask=mask, is_causal=True)
1199

1200
        # Can't give only is_causal
1201
        attn_mask = torch.randint(0, 2, size=(L, L), device=device, dtype=torch.bool)
1202
        with self.assertRaises(RuntimeError):
1203
            _ = mha(t_qvk, t_qvk, t_qvk, is_causal=True)
1204

1205
        # # Passing a causal mask sets is_causal to 1
1206
        causal_mask = torch.triu(
1207
            torch.ones(L, L, device=inputs.device) * float('-inf'), diagonal=1
1208
        ).to(torch.bool)
1209

1210
        mock_layer = MagicMock(torch.nn.MultiheadAttention(E, H), return_value=inputs)
1211
        encoder.layers[1] = mock_layer
1212
        outputs = encoder(inputs, mask=causal_mask)
1213
        mock_layer.assert_called_with(ANY, src_mask=ANY, is_causal=True, src_key_padding_mask=ANY)
1214

1215
        # check expected numerical values with all kernels
1216
        self.is_causal_kernels([SDPBackend.MATH], device)
1217

1218
    def is_causal_kernels(self, kernels, device):
1219
        def ones_tensor(*shape):
1220
            return torch.ones(shape, device=device, dtype=torch.float32).to(device)
1221
        S, L, E, H = 1, 2, 4, 1
1222
        qkv = ones_tensor(S, L, E)
1223

1224
        mha = nn.MultiheadAttention(E, H).to(device)
1225
        mha.in_proj_weight = Parameter(torch.ones((E * 3, E), device=device))
1226
        mha.out_proj.weight = Parameter(torch.ones((E, E), device=device))
1227
        expected = torch.ones(size=(S, L, E)).to(device) * 16
1228
        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1229
            qkv.size(1), device=device
1230
        )
1231

1232
        for kernel in kernels:
1233
            with sdpa_kernel(backends=[kernel]):
1234
                actual, _ = mha(qkv, qkv, qkv, attn_mask=mask, need_weights=False, is_causal=True)
1235
                self.assertTrue(torch.equal(actual, expected))
1236

1237
                if kernel != SDPBackend.MATH:
1238
                    # fails with embedding size not multiple of 4
1239
                    with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1240
                        qkv_f, mha_f = ones_tensor(S, L, 2), nn.MultiheadAttention(2, H).to(device)
1241
                        mask = torch.nn.Transformer.generate_square_subsequent_mask(
1242
                            qkv_f.size(1), device=device
1243
                        )
1244
                        _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
1245
                        torch.cuda.synchronize()
1246

1247
    @skipIfRocm  # Missing EFFICIENT_ATTENTION
1248
    @unittest.skipIf(
1249
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
1250
    )
1251
    def test_is_causal_gpu(self):
1252
        device = 'cuda'
1253
        self.is_causal_kernels([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION], device)
1254

1255
    def test_script_mha_in_proj_weight_none(self):
1256
        mha = torch.nn.MultiheadAttention(
1257
            embed_dim=128, num_heads=8, kdim=256, vdim=256
1258
        ).eval()
1259

1260
        torch.jit.script(mha)
1261

1262
    @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
1263
    @torch.no_grad()
1264
    def test_disable_fastpath(self, device):
1265
        def _test_te_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True):
1266
            if kwargs is None:
1267
                kwargs = {}
1268
            with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
1269
                fastpath_mock.return_value = return_value
1270
                output = model(*args, **kwargs)
1271
                self.assertTrue(fastpath_mock.called == is_called)
1272

1273
        def _test_mha_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True):
1274
            if kwargs is None:
1275
                kwargs = {}
1276
            with patch('torch._native_multi_head_attention') as fastpath_mock:
1277
                fastpath_mock.return_value = return_value
1278
                output = model(*args, **kwargs)
1279
                self.assertTrue(fastpath_mock.called == is_called)
1280

1281
        inp = torch.tensor([[[1, 2], [3, 4], [5, 6]]], dtype=torch.float32, device=device)
1282
        aligned_key_padding_mask = torch.tensor([[0, 0, 1]], dtype=torch.bool, device=device)
1283
        src_key_padding_mask = torch.tensor([[1, 0, 1]], dtype=torch.bool, device=device)
1284
        attn_mask = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype=torch.bool, device=device)
1285
        te_return_value = torch.ones((1, 3, 2), dtype=torch.float32)
1286

1287
        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
1288
        te = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
1289
        te = te.to(device).eval()
1290

1291
        t = torch.nn.Transformer(d_model=2, nhead=2, batch_first=True, device=device).eval()
1292
        src = torch.tensor([[[0, 1], [2, 3], [4, 5]]], dtype=torch.float32, device=device)
1293
        tgt = torch.tensor([[[0, 1], [2, 3], [4, 5], [6, 7]]], dtype=torch.float32, device=device)
1294
        t_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device)
1295

1296
        mha = nn.MultiheadAttention(2, 2, batch_first=True, device=device).eval()
1297
        q = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.float32, device=device)
1298
        mha_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device)
1299

1300
        _test_te_fastpath_called(
1301
            te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1302
            return_value=te_return_value, is_called=True
1303
        )
1304
        _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True)
1305
        _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True)
1306

1307
        torch.backends.mha.set_fastpath_enabled(False)
1308
        _test_te_fastpath_called(
1309
            te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1310
            return_value=te_return_value, is_called=False
1311
        )
1312
        _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=False)
1313
        _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=False)
1314

1315
        torch.backends.mha.set_fastpath_enabled(True)
1316
        _test_te_fastpath_called(
1317
            te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1318
            return_value=te_return_value, is_called=True
1319
        )
1320
        _test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True)
1321
        _test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True)
1322

1323

1324
class TestSDPAFailureModes(NNTestCase):
    """ Used to test the failure modes of scaled_dot_product_attention

    Each test restricts dispatch to specific backend(s) with ``sdpa_kernel`` and
    feeds inputs that the selected backend(s) must reject, asserting that a
    ``RuntimeError`` is raised (and, where relevant, that the expected
    ``UserWarning`` explaining the rejection is emitted).
    """
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
                     "Does not support fused SDPA or not SM86+ hardware")
    @parametrize("head_dim", [193, 204, 256])
    def test_flash_backward_failure_sm86plus(self, device, head_dim: int):
        dtype = torch.float16
        make_tensor = partial(torch.rand, device=device, dtype=dtype)
        # See check_requires_grad_and_head_dim_gt64_and_sm_ge86 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
        size = (2, 2, 4, head_dim)
        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)

        with sdpa_kernel(backends=[SDPBackend.MATH]):
            math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)

        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
            # Should not fail because inputs don't require grad
            flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)

            self.assertEqual(math_ref, flash_ref, atol=1e-3, rtol=1e-3)

            # Should fail because inputs require grad
            q = make_tensor(size, requires_grad=True)
            k = make_tensor(size, requires_grad=True)
            v = make_tensor(size, requires_grad=True)
            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                q, k, v, None, 0.0, False))

    @onlyCUDA
    def test_dispatch_fails_no_backend(self, device):
        dtype = torch.float16
        with sdpa_kernel(backends=[SDPBackend.ERROR]):
            size = (2, 3, 4)
            q = torch.randn(size, device=device, dtype=dtype)
            k = torch.randn(size, device=device, dtype=dtype)
            v = torch.randn(size, device=device, dtype=dtype)
            self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
                                   lambda: torch._fused_sdp_choice(q, k, v))
            self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
                                   lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
    @parametrize(
        "kernel",
        PLATFORM_SPECIFIC_SDPA,
    )
    def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend):
        with sdpa_kernel(backends=[kernel]):
            # Dim is not 4
            size = (2, 3, 8)
            dtype = torch.float16
            q = torch.randn(size, device=device, dtype=dtype)
            k = torch.randn(size, device=device, dtype=dtype)
            v = torch.randn(size, device=device, dtype=dtype)
            with self.assertWarnsRegex(UserWarning, "Both fused kernels requires query, key and value to be 4 dimensional"):
                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, None, 0.0, False))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
    @parametrize(
        "kernel",
        PLATFORM_SPECIFIC_SDPA,
    )
    def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend):
        with sdpa_kernel(backends=[kernel]):
            #  Fused Kernels don't support broadcasting for dense inputs
            dtype = torch.float16
            size = (2, 4, 3, 8)
            size_broadcast = (1, 4, 3, 8)
            q = torch.randn(size_broadcast, device=device, dtype=dtype)
            k = torch.randn(size, device=device, dtype=dtype)
            v = torch.randn(size, device=device, dtype=dtype)
            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                q, k, v, None, 0.0, False))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
    @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
    def test_invalid_sequence_lengths(self, device, kernel: SDPBackend):
        with sdpa_kernel(backends=[kernel]):
            # Passing in a q,k,v with 0 length sequences will error
            dtype = torch.float16
            make_tensor = partial(torch.rand, device=device, dtype=dtype)
            size = SdpaShape(2, 2, 0, 8)
            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
            with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support zero seq_len_q or seq_len_kv."):
                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, None, 0.0, False))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
    @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
    def test_invalid_last_dim_stride(self, device, kernel: SDPBackend):
        with sdpa_kernel(backends=[kernel]):
            # Passing in a q whose last dimension does not have stride 1 will error
            dtype = torch.float16
            make_tensor = partial(torch.rand, device=device, dtype=dtype)
            size = SdpaShape(2, 2, 8, 8)
            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
            # Re-stride q in place so that its last dimension has stride 2
            q.as_strided_(size, [2, 2, 2, 2])
            with self.assertWarnsRegex(UserWarning, "Both fused kernels require the last dimension of the input to have stride 1."):
                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, None, 0.0, False))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
    def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
        with sdpa_kernel(backends=[kernel]):
            # Unsupported embed dim per head: 9 for efficient attention, 257 for flash attention
            dtype = torch.float16
            make_tensor = partial(torch.rand, device=device, dtype=dtype)
            size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257)
            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                q, k, v, None, 0.0, False))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
    @parametrize(
        "kernel",
        PLATFORM_SPECIFIC_SDPA,
    )
    def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend):
        with sdpa_kernel(backends=[kernel]):
            # Invalid dtype for both Flash Attention and Mem Efficient Attention
            size = SdpaShape(2, 2, 3, 16)
            make_tensor = partial(torch.rand, device=device, dtype=torch.float64)
            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                q, k, v, None, 0.0, False))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION])
    def test_invalid_fused_inputs_attn_mask_present(self, device, kernel: SDPBackend):
        with sdpa_kernel(backends=[kernel]):
            # Failures for unsupported SDP args
            size = SdpaShape(2, 2, 3, 16)
            make_tensor = partial(torch.rand, size, device=device, dtype=torch.float16)
            q, k, v = make_tensor(), make_tensor(), make_tensor()
            # Non-None attention mask
            mask = torch.ones((2, 2, 3, 3), device=device, dtype=q.dtype)
            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                q, k, v, mask, 0.0, False))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
    def test_unaligned_tensors(self, device):
        # Alignment requirements are architecture dependent; a head_dim of 5 is
        # expected to be rejected by the efficient-attention kernel here.
        dtype = torch.float16
        size = SdpaShape(2, 2, 8, 5)
        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
        q, k, v = make_tensor(), make_tensor(), make_tensor()
        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                q, k, v, None, 0.0, False))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
    def test_flash_fail_fp32(self, device):
        dtype = torch.float
        size = SdpaShape(16, 16, 32, 32)
        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
        q, k, v = make_tensor(), make_tensor(), make_tensor()
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
            with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, BFloat16}"):
                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, None, 0.0, False))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    def test_flash_autocast_fp32_float16(self, device):
        # fp32 inputs are accepted by flash attention once autocast casts them to fp16
        dtype = torch.float
        size = SdpaShape(16, 16, 32, 32)
        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
        q, k, v = make_tensor(), make_tensor(), make_tensor()
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                _ = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, None, 0.0, False)

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    def test_flash_autocast_fp32_bfloat16(self, device):
        # fp32 inputs are accepted by flash attention once autocast casts them to bf16
        dtype = torch.float
        size = SdpaShape(16, 16, 32, 32)
        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
        q, k, v = make_tensor(), make_tensor(), make_tensor()
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                _ = torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, None, 0.0, False)

    # Note: do not truncate the list according to platforms. These tests should always raise errors.
    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
    def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend):
        with sdpa_kernel(backends=[kernel]):
            # Different datatypes
            shape = (1, 4, 8, 16)
            query = torch.randn(shape, dtype=torch.float32, device=device)
            key = torch.randn(shape, dtype=torch.float16, device=device)
            value = torch.randn(shape, dtype=torch.float16, device=device)
            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))

    @onlyCUDA
    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
    def test_invalid_inputs_different_devices(self, device, kernel: SDPBackend):
        with sdpa_kernel(backends=[kernel]):
            # Different devices. All inputs share a single dtype so that the error
            # is triggered by the device mismatch rather than by a dtype mismatch
            # (the dtype check runs before the device check).
            shape = (1, 4, 8, 16)
            query = torch.randn(shape, dtype=torch.float16, device=device)
            key = torch.randn(shape, dtype=torch.float16, device='cpu')
            value = torch.randn(shape, dtype=torch.float16, device='cpu')
            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))

    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
    def test_invalid_inputs_1_dimensional_inputs(self, device, kernel: SDPBackend):
        with sdpa_kernel(backends=[kernel]):
            # 1 dimensional input
            shape = (1, 4)
            query = torch.randn(4, dtype=torch.float16, device=device)
            key = torch.randn(shape, dtype=torch.float16, device=device)
            value = torch.randn(shape, dtype=torch.float16, device=device)
            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))

    @onlyCUDA
    @skipIfRocm  # Missing EFFICIENT_ATTENTION
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
    def test_fused_kernels_nested_broadcasting_error_cases(self, device):
        # one of k,v needs to be broadcasted and other has non consistent seq_len dim
        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
        batch, num_heads, head_dim = 32, 8, 64
        seq_lens_q = torch.randint(low=1, high=32, size=(batch,)).tolist()
        seq_lens_v = torch.randint(low=1, high=32, size=(batch,)).tolist()

        q_shape = SdpaShape(batch, num_heads, seq_lens_q, head_dim)
        k_shape = SdpaShape(1, num_heads, 1, head_dim)
        v_shape = SdpaShape(batch, num_heads, seq_lens_v, head_dim)

        query = rand_nested_tensor(q_shape).transpose(1, 2)
        key = rand_nested_tensor(k_shape).transpose(1, 2)
        value = rand_nested_tensor(v_shape).transpose(1, 2)

        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            with self.assertRaisesRegex(RuntimeError, "No available kernel"):
                torch.nn.functional.scaled_dot_product_attention(
                    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
    def test_nested_fails_on_padding_head_dim(self, device):
        dtype = torch.bfloat16
        seq_len_list = [2, 4, 5, 6, 7]
        shape = SdpaShape(5, 8, seq_len_list, 57)
        make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype)
        q, k, v = make_tensor(), make_tensor(), make_tensor()
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
            with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"):
                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, None, 0.0, False))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device,
                     "Current platform does not support fused SDPA or is an SM80+ device.")
    def test_mem_efficient_fail_bfloat16_less_than_sm80(self, device):
        dtype = torch.bfloat16
        size = SdpaShape(16, 16, 32, 32)
        make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
        q, k, v = make_tensor(), make_tensor(), make_tensor()
        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, Float}"):
                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, None, 0.0, False))

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
    def test_fused_kernels_seq_len_0_inputs(self, device, fused_kernel):
        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
        batch, num_heads, head_dim = 32, 16, 64
        seq_lens = torch.randint(low=1, high=32, size=(batch,))
        # make sure some seq_lens are 0
        num_zeros = 10
        indices = torch.randint(low=0, high=batch, size=(num_zeros,))
        seq_lens.scatter_(0, indices, 0)

        shape = SdpaShape(batch, num_heads, seq_lens.tolist(), head_dim)
        query = rand_nested_tensor(shape)
        key = rand_nested_tensor(shape)
        value = rand_nested_tensor(shape)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        with sdpa_kernel(backends=[fused_kernel]):
            with self.assertRaisesRegex(RuntimeError, "No available kernel"):
                torch.nn.functional.scaled_dot_product_attention(
                    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
    def test_fused_kernels_nested_broadcasting_requires_grad_failure(self, device):
        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16, requires_grad=True)
        batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 64
        seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist()
        q_shape = SdpaShape(1, num_heads, 1, head_dim)
        k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim)
        v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v)

        # create a dense query
        query = torch.randn(q_shape, device=device, dtype=torch.float16, requires_grad=True)
        key = rand_nested_tensor(k_shape)
        value = rand_nested_tensor(v_shape)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
            with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"):
                with self.assertRaisesRegex(RuntimeError, "No available kernel"):
                    torch.nn.functional.scaled_dot_product_attention(
                        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)

    @onlyCUDA
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
    def test_flash_attention_fail_with_non_square_causal_attention(self, device):
        dtype = torch.bfloat16
        q_shape = SdpaShape(1, 1, 8, 16)
        kv_shape = SdpaShape(1, 1, 12, 16)
        make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
        make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
        q, k, v = make_q(), make_kv(), make_kv()
        warning_str = "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k."
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
            with self.assertWarnsRegex(UserWarning, warning_str):
                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                    q, k, v, None, 0.0, is_causal=True))
1672

1673
def _get_block_size(device, head_dim, is_causal):
1674
    # This should match the block sizes in the CUDA kernel
1675
    # Mask is only interesting when we are setting dropout
1676
    is_dropout = True
1677
    assert head_dim <= 256
1678
    major, minor = torch.cuda.get_device_capability(device)
1679
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
1680
    is_sm80 = major == 8 and minor == 0
1681
    is_sm90 = major == 9 and minor == 0
1682
    if head_dim <= 32:
1683
        return 128, 128
1684
    if head_dim <= 64:
1685
        return (128, 128) if not is_dropout else (128, 64)
1686
    elif head_dim <= 96:
1687
        return (64, 64) if (is_sm8x and is_causal) else (128, 64)
1688
    elif head_dim <= 128:
1689
        if is_sm8x:
1690
            return (64, 64) if (not is_dropout and is_causal) else (128, 32)
1691
        else:
1692
            return 128, (64 if not is_dropout else 32)
1693
    elif head_dim <= 160:
1694
        if is_sm8x:
1695
            return (128, 64) if not is_causal else (64, 64)
1696
        else:
1697
            return 128, 32
1698
    elif head_dim <= 192:
1699
        return (128, 64) if not is_dropout else (64, 64)
1700
    elif head_dim <= 224:
1701
        return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
1702
    elif head_dim <= 256:
1703
        return (128, 64) if is_sm80 else (64, 64)
1704

1705

1706
def pad_last_dim(input_tensor, alignment_size, slice: bool = False):
    """Zero-pad the last dimension of ``input_tensor`` up to a multiple of ``alignment_size``.

    Args:
        input_tensor: tensor whose last dimension should be aligned.
        alignment_size: required multiple for the size of the last dimension.
        slice: if True, narrow the padded tensor back to the original size.
            The result has the input's values but is a view into padded,
            aligned storage.

    Returns:
        A ``(tensor, original_last_dim_size)`` pair. When the last dimension is
        already a multiple of ``alignment_size``, ``input_tensor`` itself is
        returned untouched.
    """
    original_size = input_tensor.size(-1)
    remainder = original_size % alignment_size
    if remainder == 0:
        return input_tensor, original_size

    padded = F.pad(input_tensor, (0, alignment_size - remainder))
    result = padded[..., :original_size] if slice else padded
    return result, original_size
1715

1716

1717
class TestSDPA(NNTestCase):
    """ Used to test generic functionality of scaled_dot_product_attention
    Summary:
        If you are adding a new test to this class, make sure that it runs
        for both cpu and cuda. If your test is only applicable to cuda,
        add it to TestSDPACudaOnly.
    """
    @parametrize("contiguous_inputs", [True, False])
    def test_sdp_math_gradcheck(self, device, contiguous_inputs: bool):

        batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device,
                              dtype=torch.float64, requires_grad=True, packed=True)

        qkv = make_tensor(shape)
        query, key, value = qkv.chunk(3, dim=-1)

        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

        if contiguous_inputs:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

        with sdpa_kernel(backends=[SDPBackend.MATH]):
            assert gradcheck(lambda *args, **kwargs:
                             wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
                             (query, key, value, None, 0.0, False)
                             )

    @onlyCPU
    @parametrize("type", ["dense", "nested"])
    @parametrize("dropout", [0.0, 0.7])
    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half])
    def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: torch.dtype):
        # Test that cpu and nestedtensor cpu return MATH backend
        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype)
        size = SdpaShape(2, 8, 128, 64)
        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
        if type == "nested" \
                or dropout > 0.0 \
                or dtype not in [torch.float32, torch.float64, torch.bfloat16, torch.float16]:
            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.MATH.value
        else:
            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.FLASH_ATTENTION.value

    @onlyCPU
    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION])
    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16])
    @parametrize("batch_size", [2, 12])
    @parametrize("seq_len", [267, 1030])
    @parametrize("n_head", [1, 3])
    @parametrize("head_dim", [8, 16])
    @parametrize("causal", [True, False])
    @parametrize("train", [True, False])
    def test_scaled_dot_product_fused_attention_vs_math_cpu(
        self,
        device,
        fused_kernel,
        dtype,
        batch_size,
        seq_len,
        n_head,
        head_dim,
        causal,
        train,
    ):
        atol = 1e-5
        rtol = 5e-6
        if dtype is torch.bfloat16:
            atol = 5e-2
            rtol = 5e-2
        if dtype is torch.float16:
            atol = 1e-2
            rtol = 1e-2

        n_embd = n_head * head_dim
        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, packed=True, requires_grad=False)
        shape = SdpaShape(batch_size, n_head, seq_len, head_dim)
        x = make_tensor(shape)
        x2 = x.clone()

        if train:
            x.requires_grad_(True)
            x2.requires_grad_(True)

        q, k, v = x.split(n_embd, dim=2)
        q2, k2, v2 = x2.split(n_embd, dim=2)

        if dtype in [torch.bfloat16, torch.float16]:
            q2 = q2.float()
            k2 = k2.float()
            v2 = v2.float()

        # (B, nh, T, hs)
        k = k.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
        q = q.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
        k2 = k2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
        q2 = q2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
        v2 = v2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)

        with sdpa_kernel(backends=[fused_kernel]):
            actual = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal)
        with sdpa_kernel(backends=[SDPBackend.MATH]):
            math_ref = torch.nn.functional.scaled_dot_product_attention(
                q2, k2, v2, attn_mask=None, dropout_p=0.0, is_causal=causal)

        if dtype in [torch.bfloat16, torch.float16]:
            math_ref = math_ref.to(dtype)

        self.assertEqual(actual, math_ref, atol=atol, rtol=rtol)

        if train:
            actual.sum().backward()
            math_ref.sum().backward()

            # x and x2 are the leaf tensors, so their .grad is populated.
            grad_x, grad_x2 = x.grad, x2.grad
            grad_q_actual, grad_k_actual, grad_v_actual = grad_x.split(n_embd, dim=2)
            grad_q_ref, grad_k_ref, grad_v_ref = grad_x2.split(n_embd, dim=2)

            self.assertEqual(grad_q_actual, grad_q_ref, atol=atol, rtol=rtol)
            self.assertEqual(grad_k_actual, grad_k_ref, atol=atol, rtol=rtol)
            self.assertEqual(grad_v_actual, grad_v_ref, atol=atol, rtol=rtol)

    @onlyCPU
    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION])
    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16])
    @parametrize("batch_size", [2, 12])
    @parametrize("q_seq_len", [267, 1030])
    @parametrize("kv_seq_len", [514, 1179])
    @parametrize("n_head", [1, 3])
    @parametrize("head_dim", [8, 16])
    @parametrize("mask_dim", [2, 4])
    @parametrize("bool_mask", [0, 1])
    @parametrize("train", [True, False])
    def test_scaled_dot_product_fused_attention_mask_vs_math_cpu(
        self,
        device,
        fused_kernel,
        dtype,
        batch_size,
        q_seq_len,
        kv_seq_len,
        n_head,
        head_dim,
        mask_dim,
        bool_mask,
        train,
    ):
        tol = Tolerances(1e-5, 5e-6)
        if dtype is torch.bfloat16:
            tol = Tolerances(5e-2, 5e-2)
        if dtype is torch.float16:
            tol = Tolerances(1e-2, 1e-2)

        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
        q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim)
        kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim)
        q = make_tensor(q_shape)
        k = make_tensor(kv_shape)
        v = make_tensor(kv_shape)
        q2, k2, v2 = q.clone(), k.clone(), v.clone()

        if train:
            q.requires_grad_(True)
            k.requires_grad_(True)
            v.requires_grad_(True)
            q2.requires_grad_(True)
            k2.requires_grad_(True)
            v2.requires_grad_(True)

        # Keep handles on the leaf tensors. The names q/k/v and q2/k2/v2 are
        # rebound below to view/transpose (and upcast) results, which are
        # non-leaf tensors whose .grad is never populated; gradients must be
        # read from these leaves instead.
        q_leaf, k_leaf, v_leaf = q, k, v
        q2_leaf, k2_leaf, v2_leaf = q2, k2, v2

        if dtype in [torch.bfloat16, torch.float16]:
            q2, k2, v2 = q2.float(), k2.float(), v2.float()
        # (B, nh, T, hs)
        q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
        k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
        v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
        if mask_dim == 4:
            mask_shape = (batch_size, n_head, q_seq_len, kv_seq_len)
        else:
            mask_shape = (q_seq_len, kv_seq_len)
        if bool_mask:
            attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
        else:
            attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
        q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
        k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
        v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)

        with sdpa_kernel(backends=[fused_kernel]):
            actual = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        with sdpa_kernel(backends=[SDPBackend.MATH]):
            if not bool_mask and dtype in [torch.bfloat16, torch.float16]:
                attn_mask = attn_mask.float()
            math_ref = torch.nn.functional.scaled_dot_product_attention(
                q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)

        if dtype in [torch.bfloat16, torch.float16]:
            math_ref = math_ref.to(dtype)

        self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)

        if train:
            actual.sum().backward()
            math_ref.sum().backward()

            grad_q_actual, grad_k_actual, grad_v_actual = q_leaf.grad, k_leaf.grad, v_leaf.grad
            grad_q_ref, grad_k_ref, grad_v_ref = q2_leaf.grad, k2_leaf.grad, v2_leaf.grad

            # Guard against a vacuous comparison of None with None.
            for grad in (grad_q_actual, grad_k_actual, grad_v_actual, grad_q_ref, grad_k_ref, grad_v_ref):
                self.assertIsNotNone(grad)

            self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol)
            self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol)
            self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol)

    @parametrize("kernel", [SDPBackend.MATH])
    def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend):
        # https://github.com/pytorch/pytorch/issues/105190.
        def ref(x):
            v1 = torch.matmul(x, x.transpose(-1, -2))
            v2 = v1 / -0.0001
            v3 = v2.softmax(dim=-1)
            v4 = torch.matmul(v3, x)
            return v4

        x = torch.randn(1, 3, 64, 64, device=device)
        ref_result = ref(x)
        with sdpa_kernel(backends=[kernel]):
            sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
        self.assertEqual(ref_result, sdp_math)
1951

1952
class TestSDPACudaOnly(NNTestCase):
1953
    """ Used to test CUDA only functionality of scaled_dot_product_attention
1954
    Quarks:
1955
        There is some trickiness with this function. It's runtime behavior
1956
        is dependent on the CUDA architecture you are testing it on. See
1957
        `PLATFORM_SUPPORTS_FUSED_ATTENTION` at the top of the file.
1958
        Summary:
1959
            Math: always supported
1960
            FlashAttention: Supported on sm80 or newer hardware
1961
            MemEfficientAttention: Supported on sm50 or newer hardware
1962
    """
1963
    _do_cuda_memory_leak_check = True
1964
    _do_cuda_non_default_stream = True
1965

1966
    def convert_flash_attn_S_to_softmax(self, S, query_padding_mask, key_padding_mask, head_dim, causal=False):
        """Rearrange the S matrix returned by FlashAttention into a plain row/column layout.

        FlashAttention stores the S matrix in a tiled layout; this undoes the tiling
        (using the kernel block sizes for ``head_dim``) and zeroes masked-out entries.

        Arguments:
            S: (batch_size, nheads, seqlen_q, seqlen_k)
            query_padding_mask: (batch_size, seqlen_q), or None
            key_padding_mask: (batch_size, seqlen_k), or None
            head_dim: head dimension; only used to look up the kernel block sizes
            causal: when True, entries strictly above the diagonal are zeroed

        Returns:
            The rearranged tensor. Its rows/cols are trimmed to the padding-mask
            lengths when those are shorter than S's, otherwise zero-padded by the
            difference. On ROCm, ``S`` is returned unchanged.
        """
        if TEST_WITH_ROCM:
            return S

        def ceil_div(n, d):
            return -(-n // d)

        b, h, seqlen_q, seqlen_k = S.shape
        warps_n = 4
        blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, causal)
        nblocks_m = ceil_div(seqlen_q, blocksize_m)
        nblocks_n = ceil_div(seqlen_k, blocksize_n)
        mmas_n = ceil_div(blocksize_n, 16)

        # Group the elements of each (blocksize_m x blocksize_n) block together:
        # (b, h, nblocks_m, blocksize_m, nblocks_n, blocksize_n)
        #   -> (b, h, nblocks_m, nblocks_n, blocksize_m * blocksize_n)
        per_block = (
            S.view(b, h, nblocks_m, blocksize_m, nblocks_n, blocksize_n)
            .permute(0, 1, 2, 4, 3, 5)
            .reshape(b, h, nblocks_m, nblocks_n, blocksize_m * blocksize_n)
        )
        # Split each block's flat contents into its tile pieces (mmas_n / warps_n
        # factors in the view below) and permute them back into row/column order.
        unpacked = per_block.view(b, h, nblocks_m, nblocks_n, mmas_n, -1, warps_n, 8, 4, 2, 2, 2)
        unpacked = unpacked.permute(0, 1, 2, 5, 6, 10, 7, 3, 4, 9, 8, 11)
        num_rows = nblocks_m * unpacked.size(3) * warps_n * 2 * 8
        num_cols = nblocks_n * mmas_n * 2 * 4 * 2
        S_converted = unpacked.reshape(b, h, num_rows, num_cols)

        if causal:
            above_diagonal = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1)
            S_converted.masked_fill_(above_diagonal, 0.0)

        def fit_mask(mask, length):
            # Zero-pad or truncate a (batch, len) mask along its last dim to `length`.
            current = mask.shape[-1]
            if current < length:
                return F.pad(mask, (0, length - current))
            return mask[:, :length]

        # Need to zero out things not in the attention mask in case S was initialized
        # with random values and some of those values aren't overwritten.
        seqlen_q_og = seqlen_q if query_padding_mask is None else query_padding_mask.shape[-1]
        if query_padding_mask is not None:
            q_keep = fit_mask(query_padding_mask, seqlen_q)
            S_converted = S_converted.masked_fill(~q_keep.view(q_keep.shape[0], 1, q_keep.shape[1], 1), 0.0)

        seqlen_k_og = seqlen_k if key_padding_mask is None else key_padding_mask.shape[-1]
        if key_padding_mask is not None:
            k_keep = fit_mask(key_padding_mask, seqlen_k)
            S_converted = S_converted.masked_fill(~k_keep.view(k_keep.shape[0], 1, 1, k_keep.shape[1]), 0.0)

        # Bring the row/col extents back to the original (mask) lengths.
        S_converted = (S_converted[:, :, :seqlen_q_og, :] if seqlen_q_og < seqlen_q
                       else F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q)))
        S_converted = (S_converted[:, :, :, :seqlen_k_og] if seqlen_k_og < seqlen_k
                       else F.pad(S_converted, (0, seqlen_k_og - seqlen_k)))
        return S_converted
2022

2023
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("mask_dim", [1, 2, 3, 4])
    def test_mem_efficient_attetntion_mask_variants(self, device, mask_dim: int):
        """Memory-efficient SDPA forward + backward with additive masks of rank 1 through 4.

        A rank-``mask_dim`` mask takes the trailing ``mask_dim`` dims of the full
        (batch, num_heads, seq_len_q, seq_len_kv) score shape, so lower-rank masks
        must be broadcast by SDPA.
        """
        dtype = torch.float16
        # Named to avoid shadowing the module-level `make_tensor` helper.
        rand_input = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
        batch, num_heads, head_dim = 8, 8, 64
        seq_len_q, seq_len_kv = 64, 32
        query = rand_input(SdpaShape(batch, num_heads, seq_len_q, head_dim))
        kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
        key, value = rand_input(kv_shape), rand_input(kv_shape)

        full_mask_shape = (batch, num_heads, seq_len_q, seq_len_kv)
        mask = torch.randn(full_mask_shape[-mask_dim:], device=device, dtype=dtype)
        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            out = F.scaled_dot_product_attention(query, key, value, mask)
        out.sum().backward()
2045

2046
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("dtype", [torch.float, torch.float16])
    def test_mem_eff_attention_pad_mask(self, device, dtype):
        """Memory-efficient SDPA forward + backward with a full-rank mask and a key length of 15.

        15 is not a multiple of 8; presumably this exercises the kernel's mask
        padding path (NOTE(review): confirm intent).
        """
        batch, num_heads, head_dim = 8, 8, 64
        seq_len_q, seq_len_kv = 64, 15

        def rand_leaf(seq_len):
            return torch.rand(SdpaShape(batch, num_heads, seq_len, head_dim),
                              device=device, dtype=dtype, requires_grad=True)

        query = rand_leaf(seq_len_q)
        key = rand_leaf(seq_len_kv)
        value = rand_leaf(seq_len_kv)
        mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)

        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            result = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
        result.sum().backward()
2059

2060
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("dtype", [torch.float, torch.float16])
    def test_mem_eff_attention_non_contiguous_mask(self, device, dtype):
        """Memory-efficient SDPA forward + backward with a non-contiguous mask.

        The mask is re-strided with strides (0, 0, 0, 1), so every (batch, head, query)
        row aliases the same seq_len_kv values.
        """
        batch, num_heads, head_dim = 8, 8, 64
        seq_len_q, seq_len_kv = 64, 16

        def rand_leaf(seq_len):
            return torch.rand(SdpaShape(batch, num_heads, seq_len, head_dim),
                              device=device, dtype=dtype, requires_grad=True)

        query = rand_leaf(seq_len_q)
        key = rand_leaf(seq_len_kv)
        value = rand_leaf(seq_len_kv)

        mask_shape = (batch, num_heads, seq_len_q, seq_len_kv)
        dense_mask = torch.randn(mask_shape, device=device, dtype=dtype)
        mask = dense_mask.as_strided(mask_shape, (0, 0, 0, 1))

        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            result = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
        result.sum().backward()
2074

2075
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("dtype", [torch.float, torch.float16])
    def test_mem_eff_attention_long_sequence_mask(self, device, dtype):
        """Memory-efficient SDPA forward + backward with a dense 8192 x 8192 mask over 32 heads.

        Skipped on GPUs with less than 80 GiB of memory.
        """
        # `unittest.skip(...)` only builds a decorator and never raises. Calling it
        # here followed by `return` made the test silently PASS; `self.skipTest`
        # raises SkipTest, so the test is reported as skipped.
        if torch.cuda.get_device_properties(device).total_memory < 80 * 2**30:
            self.skipTest("This test requires substantial GPU memory.")
        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
        batch, num_heads, head_dim = 1, 32, 64
        seq_len_q, seq_len_kv = 8192, 8192
        query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
        kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
        key, value = make_tensor(kv_shape), make_tensor(kv_shape)
        mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            out = F.scaled_dot_product_attention(query, key, value, mask)
        out.sum().backward()
2091

2092
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
    def test_mem_eff_attention_non_contig_mask_bug(self, device):
        """A non-contiguous attention bias must give the same output as its contiguous copy.

        Without the fix this produces `AssertionError: assert 0.07352933287620544 < 1e-07`.
        Shapes and strides are taken from the repro.
        """
        query_size, query_strides = (3, 16, 1, 128), (2304, 128, 2048, 1)
        key_size, key_strides = (3, 16, 14, 128), (3584, 0, 256, 1)
        value_size, value_strides = (3, 16, 14, 128), (3584, 0, 256, 1)
        attention_mask_size, attn_mask_strides = (3, 1, 1, 14), (14, 14, 14, 1)

        def strided_randn(size, strides):
            # The smallest storage that can back (size, strides) is the offset of the
            # last addressable element plus one. `max(size * stride)` only happens to
            # be large enough for some layouts.
            num_elements = 1 + sum((s - 1) * st for s, st in zip(size, strides))
            return torch.randn(num_elements, device=device).as_strided(size, strides)

        query = strided_randn(query_size, query_strides)
        key = strided_randn(key_size, key_strides)
        value = strided_randn(value_size, value_strides)
        bias = strided_randn(attention_mask_size, attn_mask_strides)

        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            out = F.scaled_dot_product_attention(query, key, value, bias)
            out_contig = F.scaled_dot_product_attention(query, key, value, bias.contiguous())

        # Mean (not max) absolute difference between the two outputs.
        mean_abs_diff = (out - out_contig).abs().mean()
        self.assertLess(mean_abs_diff.item(), 1e-7)
2123

2124
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
    def test_singelton_head_dim_stride_ne_1(self, device):
        """Flash attention must accept a query whose size-1 head_dim has a stride != 1.

        The transpose turns the (1, 1, 1, 2) query into shape (1, 1, 2, 1) with a
        last-dim stride of 2.
        """
        query = torch.tensor([[[[1, 2]]]], dtype=torch.float16, device=device).transpose(-1, -2)
        key = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
        value = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)

        # Use the same backend selector as the rest of this file. The deprecated
        # `torch.backends.cuda.sdp_kernel(...)` also leaves backends it does not
        # mention (e.g. cuDNN) enabled.
        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
            scaled_dot_product_attention(query, key, value)
2133

2134
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("type", ["dense", "nested"])
    @parametrize("is_contiguous", [True, False])
    def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
        """Mem-efficient SDPA on q/k/v carved out of one packed fp16 QKV tensor must match the math backend."""
        batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
        qkv = rand_sdpa_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim),
                               type=type, device=device, dtype=torch.float16, packed=True)

        # Split the packed tensor on its last dim and lay each piece out as
        # (batch, num_heads, -1, head_dim).
        query, key, value = [
            part.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
            for part in qkv.chunk(3, dim=-1)
        ]
        if is_contiguous:
            query, key, value = [t.contiguous() for t in (query, key, value)]

        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            actual = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
        with sdpa_kernel(backends=[SDPBackend.MATH]):
            math_ref = torch.nn.functional.scaled_dot_product_attention(
                query.contiguous(), key.contiguous(), value.contiguous(),
                attn_mask=None, dropout_p=0.0, is_causal=False)

        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
2165

2166
    @skipIfRocm  # Missing nested and EFFICIENT_ATTENTION
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("type", ["dense", "nested"])
    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
    def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, device, type: str,
                                                                        fused_kernel: SDPBackend):
        """Compare a fused kernel on fp16 packed QKV against fp16 and fp32 math-backend references.

        ``fused_kernel`` is an SDPBackend member (the previous ``str`` annotation was wrong).
        """
        def rand_nt(shape):
            batch, seq_len, num_heads, head_dim = shape
            # Values drawn uniformly from [-3, 3).
            tensors = [6 * torch.rand((seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3
                       for _ in range(batch)]
            return (torch.nested.nested_tensor(tensors, device=device, dtype=torch.float32),
                    torch.nested.nested_tensor(tensors, device=device, dtype=torch.float16))

        def rand_tensor(shape):
            batch, seq_len, num_heads, head_dim = shape
            tensor = 6 * torch.rand((batch, seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3
            return tensor, tensor.to(dtype=torch.float16)

        def split_heads(packed):
            # Split packed QKV on its last dim; view each piece as (batch, num_heads, -1, head_dim).
            return tuple(t.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
                         for t in packed.chunk(3, dim=-1))

        batch_size, seq_len, num_heads, head_dim = 16, 8, 4, 64
        shape = (batch_size, seq_len, num_heads, head_dim)

        # Test Packed
        qkv, qkv_low_precision = rand_tensor(shape) if type == "dense" else rand_nt(shape)
        query, key, value = split_heads(qkv)
        query_lp, key_lp, value_lp = split_heads(qkv_low_precision)

        with sdpa_kernel(backends=[fused_kernel]):
            actual = torch.nn.functional.scaled_dot_product_attention(
                query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, is_causal=False)

        with sdpa_kernel(backends=[SDPBackend.MATH]):
            math_ref_lp = torch.nn.functional.scaled_dot_product_attention(
                query_lp.contiguous(), key_lp.contiguous(), value_lp.contiguous(),
                attn_mask=None, dropout_p=0.0, is_causal=False)

            math_ref = torch.nn.functional.scaled_dot_product_attention(
                query.contiguous(), key.contiguous(), value.contiguous(),
                attn_mask=None, dropout_p=0.0, is_causal=False)

        actual_test = actual
        math_ref_test = math_ref
        math_ref_lp_test = math_ref_lp

        if actual_test.is_nested:
            actual_test = torch.nested.to_padded_tensor(actual_test.contiguous(), padding=0.0)
            math_ref_test = torch.nested.to_padded_tensor(math_ref_test, padding=0.0)
            math_ref_lp_test = torch.nested.to_padded_tensor(math_ref_lp_test, padding=0.0)

        actual_test = actual_test.to(dtype=torch.float32).contiguous()
        math_ref_test = math_ref_test.to(dtype=torch.float32).contiguous()
        math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous()

        self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3)
        self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3)
2231

2232
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Memory efficient attention was not built for this system")
    @parametrize("contiguous_inputs", [True, False])
    @parametrize("is_causal", [True, False])
    def test_sdp_mem_efficient_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool):
        """Gradients of mem-efficient SDPA in fp32 must match math-backend gradients in fp64.

        The skip message previously said "Flash Attention" although the guard checks
        mem-efficient attention support.
        """
        batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
        qkv = rand_sdpa_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim), type="dense",
                               device=device, dtype=torch.float64, requires_grad=True, packed=True)
        # fp32 copy of the same values as an independent autograd leaf.
        qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_()

        def split_heads(packed):
            # Split packed QKV on its last dim; view each piece as (batch, num_heads, -1, head_dim).
            heads = [t.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) for t in packed.chunk(3, dim=-1)]
            if contiguous_inputs:
                heads = [t.contiguous() for t in heads]
            return heads

        query, key, value = split_heads(qkv)
        query_lp, key_lp, value_lp = split_heads(qkv_lp)

        with sdpa_kernel(backends=[SDPBackend.MATH]):
            out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal)

        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            out_lp = torch.nn.functional.scaled_dot_product_attention(
                query_lp, key_lp, value_lp, None, 0.0, is_causal)

        rand_upward = torch.rand_like(out)
        rand_upward_lp = rand_upward.to(torch.float32)

        out.backward(rand_upward)
        out_lp.backward(rand_upward_lp)

        # Cast up and compare
        self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)
2278

2279
    @skipIfRocm  # Small matrices
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system")
    @parametrize("contiguous_inputs", [True, False])
    @parametrize("is_causal", [True, False])
    @parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_sdp_flash_attention_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool, dtype: torch.dtype):
        """Flash-attention gradients in ``dtype`` must track math-backend gradients computed in fp64."""
        batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
        qkv_ref = rand_sdpa_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim), type="dense",
                                   device=device, dtype=torch.float64, requires_grad=True, packed=True)
        # Reduced-precision copy of the same values as an independent autograd leaf.
        qkv_low = qkv_ref.detach().clone().to(dtype).requires_grad_()

        def heads_of(packed):
            # Split packed QKV on its last dim; view each piece as (batch, num_heads, -1, head_dim).
            q, k, v = (t.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) for t in packed.chunk(3, dim=-1))
            if contiguous_inputs:
                q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
            return q, k, v

        q_ref, k_ref, v_ref = heads_of(qkv_ref)
        q_low, k_low, v_low = heads_of(qkv_low)

        with sdpa_kernel(backends=[SDPBackend.MATH]):
            out_ref = torch.nn.functional.scaled_dot_product_attention(q_ref, k_ref, v_ref, None, 0.0, is_causal)

        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
            out_low = torch.nn.functional.scaled_dot_product_attention(q_low, k_low, v_low, None, 0.0, is_causal)

        upstream = torch.rand_like(out_ref)
        upstream_low = upstream.to(dtype)
        out_ref.backward(upstream)
        out_low.backward(upstream_low)

        # Flash computes in reduced precision, so the fp64 comparison needs a loose
        # tolerance. bfloat16 has fewer mantissa bits than float16 and needs a looser one still.
        tol = 7e-4 if dtype == torch.float16 else 7e-3
        self.assertEqual(qkv_ref.grad, qkv_low.grad.to(torch.float64), atol=tol, rtol=tol)
2331

2332
    @skipIfRocm  # Missing nested and EFFICIENT_ATTENTION
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA")
    @parametrize("type", ["dense", "nested"])
    def test_fused_sdp_choice(self, device, type: str):
        """`torch._fused_sdp_choice` picks flash (when supported) for fp16 and efficient attention for fp32."""
        batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)

        def make_qkv(dtype, **kwargs):
            # Build packed QKV and view each piece as (batch, num_heads, -1, head_dim).
            qkv = rand_sdpa_tensor(shape, type=type, device=device, dtype=dtype, packed=True, **kwargs)
            return [t.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) for t in qkv.chunk(3, dim=-1)]

        query, key, value = make_qkv(torch.float16, requires_grad=True)
        expected = SDPBackend.FLASH_ATTENTION if PLATFORM_SUPPORTS_FLASH_ATTENTION else SDPBackend.EFFICIENT_ATTENTION
        # self.assertEqual (not bare assert) gives a useful failure message and is not stripped by `python -O`.
        self.assertEqual(torch._fused_sdp_choice(query, key, value), expected.value)

        # Change dtype to float32 so that efficient attention should get chosen
        query, key, value = make_qkv(torch.float32)
        self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value)
2363

2364
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
    @parametrize("warn_only", [True, False])
    def test_sdp_choice_with_determinism(self, device, warn_only):
        """Deterministic mode (warn-only or strict) must not stop efficient attention from being chosen."""
        batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, packed=False)
        query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)

        with use_deterministic_algorithims(True, warn_only=warn_only):
            # Note that this should switch to a testing version when we remove the old context manager
            with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
                self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value)
2376

2377
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
    @parametrize("warn_only", [True, False])
    def test_mem_eff_backwards_throws_determinism_warning(self, device, warn_only):
        """Mem-efficient backward under deterministic mode.

        With warn_only=True the backward must emit the non-determinism UserWarning.
        With warn_only=False it is only required to run.
        """
        batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        query, key, value = (
            rand_sdpa_tensor(shape, type="dense", device=device, dtype=torch.float32, packed=False, requires_grad=True)
            for _ in range(3)
        )

        if warn_only:
            expect_warning = self.assertWarnsRegex(
                UserWarning,
                "Memory Efficient attention defaults to a non-deterministic algorithm.",
            )
        else:
            expect_warning = contextlib.nullcontext()

        with use_deterministic_algorithims(True, warn_only=warn_only), \
                sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            with expect_warning:
                loss = torch.nn.functional.scaled_dot_product_attention(query, key, value).sum()
                loss.backward()
2397

2398
    @unittest.skip("This test is not behaving deterministaclly non-deterministaclly on CI/CD")
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
    def test_mem_eff_backwards_determinism(self, device):
        """Mem-efficient backward: non-deterministic by default, deterministic once enforced.

        Fixes versus the previous version:
        - the guard checked flash support although only EFFICIENT_ATTENTION is used;
        - the deterministic phase ran outside `sdpa_kernel`, so it was not guaranteed
          to exercise mem-efficient attention at all.
        """
        # Need big seq_len to ensure that num_splits > 1
        dtype = torch.float32
        batch_size, seq_len, n_heads, head_dim = 1, 1024, 8, 64
        query = torch.rand(batch_size, n_heads, seq_len, head_dim,
                           device=device, dtype=dtype, requires_grad=True)
        key = torch.rand(batch_size, n_heads, seq_len, head_dim, device=device,
                         dtype=dtype, requires_grad=True)
        value = torch.rand(batch_size, n_heads, seq_len, head_dim,
                           device=device, dtype=dtype, requires_grad=True)

        def baseline_query_grad():
            # Run forward/backward once with a fresh random upstream gradient and
            # return (upstream gradient, resulting query gradient).
            query.grad = None
            out = F.scaled_dot_product_attention(query, key, value)
            upward_grad = torch.rand_like(out)
            out.backward(upward_grad)
            return upward_grad, query.grad

        def query_grad_ever_differs(upward_grad, reference_grad, num_trials=100):
            # Re-run the op with the same upward grad and report whether any run
            # yields a query gradient that is not bitwise equal to `reference_grad`.
            for _ in range(num_trials):
                query.grad = None
                out = F.scaled_dot_product_attention(query, key, value)
                out.backward(upward_grad)
                if not torch.equal(reference_grad, query.grad):
                    return True
            return False

        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            upward_grad, initial_query_grad = baseline_query_grad()
            # By default the backward is expected to be non-deterministic.
            self.assertTrue(query_grad_ever_differs(upward_grad, initial_query_grad))

        with use_deterministic_algorithims(True, warn_only=False), \
                sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            upward_grad, initial_query_grad = baseline_query_grad()
            # Deterministic now that we have enforced it.
            self.assertFalse(query_grad_ever_differs(upward_grad, initial_query_grad))
2448

2449
    # verified passing successfully on H100
2450
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
2451
    @parametrize("batch_size", [1, 8])
2452
    @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
2453
    @parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
2454
    @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if SM80OrLater else [8, 16, 32, 64])
2455
    @parametrize("is_causal", [False, True])
2456
    @parametrize("dropout_p", [0.0, 0.22])
2457
    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if
2458
                 SM80OrLater else [torch.float16, torch.float32])
2459
    @parametrize("scale", [None, "l1"])
2460
    def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                       head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                       scale: str):
        """Check EFFICIENT_ATTENTION SDPA output and q/k/v gradients against math-backend references.

        Two math references are computed: one in the input dtype ("low precision") and one in a
        higher precision (float64 for float32 inputs, float32 otherwise). The error between them
        sets the tolerances used to judge the fused kernel (see [Note] Fused Tolerances below).
        ``scale="l1"`` selects an explicit scale of ``1 / head_dim``; ``None`` keeps the default.
        """
        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device):
            # Regenerate a float keep-mask (1.0 where the uniform sample exceeds p) from the given
            # seed/offset so the math reference can be fed the same pattern via ``dropout_mask``.
            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset)
            mask = (rand_uniform > p).to(torch.float32)
            return mask

        if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
            # unittest.skip(...) only returns a decorator (calling it here is a no-op that lets the
            # test silently pass); skipTest raises SkipTest so the run reports a skip.
            self.skipTest("Reference implementation OOM")
        seed = 42
        scale = scale if scale is None else (1 / head_dim)
        n_heads = 4
        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
                           device=device, dtype=dtype, requires_grad=True)
        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                         dtype=dtype, requires_grad=True)
        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
                           device=device, dtype=dtype, requires_grad=True)

        # Run the math kernel on low precision references
        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)

        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)

        # Create real output
        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            # Set the seed and run the kernel
            torch.manual_seed(seed)
            out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)

        if dropout_p == 0.0:
            with sdpa_kernel(backends=[SDPBackend.MATH]):
                # High Precision Math Reference
                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
                # Low Precision Math Reference
                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
        else:
            if seq_len_q > 1024:
                self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!")
            # Create the dropout_mask, using the same seed the fused kernel ran with and offset 0
            # (assumes the kernel's dropout RNG starts at offset 0 after manual_seed — TODO confirm).
            torch.manual_seed(seed)
            dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q, seq_len_k, dropout_p, seed, 0, device=device)
            # High Precision Math Reference
            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
            # Low Precision Math Reference
            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
                dropout_mask=dropout_mask)[0]

        upstream_grad = torch.rand_like(out, requires_grad=False)

        out.backward(upstream_grad)
        out_ref.backward(upstream_grad.to(out_ref.dtype))
        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))

        # [Note] Fused Tolerances
        # Establish the numerical error between the "true" high precision math output
        # and the low precision math reference. We use this reference for the atol
        # And we use the default rtol for the low precision type.
        # We then provide a fudge factor for gradients respectively to account
        # for the use of the fused kernel rather than the eager implementation.
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)

        # Fudge Factor when dropout is enabled
        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 2.0

        query_fudge_factor = dropout_fudge_factor
        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)

        # TODO: Investigate why grad_k needs larger tolerances
        key_fudge_factor = 8 * dropout_fudge_factor
        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)

        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)

        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2549

2550

2551
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
    @parametrize("batch_size", [1, 8])
    @parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 152, 256, 512])
    @parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if SM80OrLater else [4, 8, 37, 64, 128, 256, 512])
    @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if SM80OrLater else [8, 16, 32, 64])
    @parametrize("is_causal", [False])
    @parametrize("dropout_p", [0.0, 0.22])
    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if
                 SM80OrLater else [torch.float16, torch.float32])
    @parametrize("scale", [None, "l1"])
    def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                 seq_len_k: int, head_dim: int, is_causal: bool,
                                                                 dropout_p: float, dtype: torch.dtype,
                                                                 scale: str):
        """Like the plain mem-efficient grad test, but with a learnable additive ``attn_mask``.

        Compares the EFFICIENT_ATTENTION output and the gradients of q, k, v and ``attn_mask``
        against math-backend references computed in the input dtype and in a higher precision.
        ``scale="l1"`` selects an explicit scale of ``1 / head_dim``; ``None`` keeps the default.
        """
        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device):
            # Regenerate a float keep-mask (1.0 where the uniform sample exceeds p) from the given
            # seed/offset so the math reference can be fed the same pattern via ``dropout_mask``.
            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset)
            mask = (rand_uniform > p).to(torch.float32)
            return mask

        if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
            # unittest.skip(...) only returns a decorator (calling it here is a no-op that lets the
            # test silently pass); skipTest raises SkipTest so the run reports a skip.
            self.skipTest("Reference implementation OOM")
        seed = 42
        scale = scale if scale is None else (1 / head_dim)
        n_heads = 4
        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
                           device=device, dtype=dtype, requires_grad=True)
        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                         dtype=dtype, requires_grad=True)
        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
                           device=device, dtype=dtype, requires_grad=True)

        # A (seq_len_q, seq_len_k) mask that is broadcast over batch and heads; it requires grad
        # so the test also checks the mask gradient.
        attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)

        # Run the math kernel on low precision references
        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
        attn_mask_ref_lp = attn_mask.detach().to(dtype).requires_grad_(True)

        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
        attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)

        # Create real output
        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            # Set the seed and run the kernel
            torch.manual_seed(seed)
            out = F.scaled_dot_product_attention(query, key, value, attn_mask, dropout_p=dropout_p,
                                                 is_causal=is_causal, scale=scale)

        if dropout_p == 0.0:
            with sdpa_kernel(backends=[SDPBackend.MATH]):
                # High Precision Math Reference
                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, attn_mask_ref,
                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
                # Low Precision Math Reference
                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
        else:
            if seq_len_q > 1024:
                self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!")
            # Create the dropout_mask, using the same seed the fused kernel ran with and offset 0
            # (assumes the kernel's dropout RNG starts at offset 0 after manual_seed — TODO confirm).
            torch.manual_seed(seed)
            dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q,
                                                  seq_len_k, dropout_p, seed, 0, device=device)
            # High Precision Math Reference
            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref, key_ref, value_ref, attn_mask_ref, dropout_p=dropout_p, is_causal=is_causal,
                scale=scale, dropout_mask=dropout_mask)[0]
            # Low Precision Math Reference
            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
                dropout_p=dropout_p, is_causal=is_causal, scale=scale,
                dropout_mask=dropout_mask)[0]

        upstream_grad = torch.rand_like(out, requires_grad=False)

        out.backward(upstream_grad)
        out_ref.backward(upstream_grad.to(out_ref.dtype))
        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))

        # [Note] Fused Tolerances
        # Establish the numerical error between the "true" high precision math output
        # and the low precision math reference. We use this reference for the atol
        # And we use the default rtol for the low precision type.
        # We then provide a fudge factor for gradients respectively to account
        # for the use of the fused kernel rather than the eager implementation.
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)

        # Fudge Factor when dropout is enabled
        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.75
        # Extra slack on grad_k when a mask is present (attn_mask is always set in this test,
        # so this evaluates to 1.5).
        mask_fudge_factor = 1.0 if attn_mask is None else 1.5

        query_fudge_factor = dropout_fudge_factor
        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)

        # TODO: Investigate why grad_k needs larger tolerances
        key_fudge_factor = 8 * dropout_fudge_factor * mask_fudge_factor
        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)

        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)

        # Separate name from mask_fudge_factor above: this one scales the attn_mask gradient
        # tolerance, with more slack for small masks.
        attn_mask_grad_fudge_factor = 12 if attn_mask.numel() > 512 else 22
        grad_attn_mask_atol, grad_attn_mask_rtol = get_tolerances(
            attn_mask_ref.grad, attn_mask_ref_lp.grad, attn_mask_grad_fudge_factor)

        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)

        self.assertEqual(attn_mask.grad, attn_mask_ref.grad.to(attn_mask.grad.dtype),
                         atol=grad_attn_mask_atol, rtol=grad_attn_mask_rtol)
2667

2668
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    @parametrize("batch_size", [1, 8])
    @parametrize("seq_len_q", [4, 8, 64, 143, 256, 512, 1024, 2048])
    @parametrize("seq_len_k", [4, 8, 64, 128, 256, 587, 1024, 2048])
    @parametrize("head_dim", [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256])
    @parametrize("is_causal", [True, False])
    @parametrize("dropout_p", [0.0, 0.22, 0.48])
    @parametrize("dtype", [torch.float16, torch.bfloat16])
    @parametrize("scale", [None, "l1"])
    def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                               head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                               scale: str):
        """Check FLASH_ATTENTION SDPA output and q/k/v gradients against math-backend references.

        Without dropout the public SDPA entry point is used. With dropout, the flash op is called
        directly on inputs padded to a multiple of 8 in the last dim. The op is asked for its debug
        mask, from which the dropout pattern is rebuilt and handed to the math references.
        ``scale="l1"`` selects an explicit scale of ``1 / head_dim``; ``None`` keeps the default.
        """
        if TEST_WITH_ROCM:
            def is_power_of_2(n):
                return n & (n - 1) == 0
            if not is_power_of_2(seq_len_q) or not is_power_of_2(seq_len_k) or not is_power_of_2(head_dim):
                self.skipTest("Flash attention on ROCM only supports power of two seq_len_q seq_len_k headdim, for now.")
            if head_dim < 16 or seq_len_q < 16 or seq_len_k < 16:
                self.skipTest("Flash attention on ROCM only supports power of two seq_len_q, seq_len_k, headdim >= 16, for now.")
            if head_dim > 128:
                self.skipTest("Flash attention on ROCM only supports power of two headdim <= 128, for now.")

        # Flash attention (incl. its backward) for 193 <= head_dim <= 256 is disabled on sm86/87/89.
        # Skipping here covers the whole configuration, so no separate backward check is needed later.
        if isSM8XDevice and head_dim in range(193, 256 + 1):
            self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
        if is_causal and seq_len_q != seq_len_k:
            self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")

        scale = scale if scale is None else (1 / head_dim)
        n_heads = 4
        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
                           device=device, dtype=dtype, requires_grad=True)
        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                         dtype=dtype, requires_grad=True)
        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
                           device=device, dtype=dtype, requires_grad=True)

        # Run the math kernel on low precision references
        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)

        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)

        is_dropout = dropout_p > 0.0

        if not is_dropout:
            # Problem: We pad sizes in the composite region of the top level SDPA. But we need the
            # debug mask when we have dropout, so we pad manually below when testing dropout.
            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
            with sdpa_kernel(backends=[SDPBackend.MATH]):
                # High Precision Math Reference
                out_ref = F.scaled_dot_product_attention(
                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
                # Low Precision Math Reference
                out_lp_ref = F.scaled_dot_product_attention(
                    query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
        else:
            q_padded, q_og_size = pad_last_dim(query, 8)
            k_padded, k_og_size = pad_last_dim(key, 8)
            v_padded, v_og_size = pad_last_dim(value, 8)
            # scale needs to be calculated on the og head_size
            if scale is None:
                scale = 1 / math.sqrt(q_og_size)
            output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
                q_padded, k_padded, v_padded, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=is_dropout)
            out = output_tuple[0]
            # Drop the padding columns added by pad_last_dim.
            out = out[..., :v_og_size]
            # Build dropout_mask from the debug mask (last element of the op's output)
            dbug_mask = output_tuple[-1]
            query_padding_mask = torch.ones(
                batch_size, seq_len_q, device=device, dtype=torch.bool)
            key_padding_mask = torch.ones(
                batch_size, seq_len_k, device=device, dtype=torch.bool)

            softmax_mask = self.convert_flash_attn_S_to_softmax(
                dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim,
                causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
            dropout_mask = softmax_mask >= 0
            # High Precision Math Reference
            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
            # Low Precision Math Reference
            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
                dropout_mask=dropout_mask)[0]

        upstream_grad = torch.rand_like(out, requires_grad=False)

        out.backward(upstream_grad)
        out_ref.backward(upstream_grad.to(out_ref.dtype))
        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))

        # See [Note] Fused Tolerances above
        output_fudge_factor = 3 if head_dim % 8 != 0 or TEST_WITH_ROCM else 1
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref, output_fudge_factor)

        # TODO: Investigate why grad_q needs larger tolerances
        query_fudge_factor = 4
        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)

        key_fudge_factor = 2
        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)

        value_fudge_factor = 2
        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)

        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2785

2786
    @skipIfRocm  # FIXME: "capturing stream has unjoined work"
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    @parametrize("batch_size", [1, 8])
    @parametrize("seq_len_q", [256, 512, 1024])
    @parametrize("seq_len_k", [256, 512, 1024])
    @parametrize("head_dim", [32, 64])
    @parametrize("is_causal", [True, False])
    @parametrize("dropout_p", [0.0, 0.22])
    @parametrize("dtype", [torch.float16,])
    @parametrize("scale", [None, "l1"])
    @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
    def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                         head_dim: int,
                                                         is_causal: bool,
                                                         dropout_p: float,
                                                         dtype: torch.dtype,
                                                         scale: str,
                                                         fused_kernel: SDPBackend):
        """Run a fused SDPA kernel forward and backward under CUDA graph capture and compare to math.

        The forward op is warmed up on a side stream, then captured into a graph and replayed twice.
        Replays must be bit-identical without dropout and must differ with dropout. The backward
        is also captured and replayed. Results are checked against math-backend references, which
        reuse the dropout mask rebuilt from the graph's outputs.
        ``scale="l1"`` selects an explicit scale of ``1 / head_dim``; ``None`` keeps the default.
        """
        def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, dropout_p, seed, offset, device=device):
            # Regenerate a float keep-mask (1.0 where the uniform sample exceeds dropout_p) from
            # the given seed/offset.
            mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
            rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, dropout_p, seed, offset)
            mask = (rand_uniform > dropout_p).to(torch.float32)
            return mask

        def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, dropout_p, device=device):
            """Rebuild the dropout mask used to produce ``output`` (the fused op's output tuple).

            Uses only its arguments (previously it read the enclosing ``output_tuple`` and
            ``seq_len_q``/``seq_len_k`` through the closure and ignored ``output``/``q_len``/``kv_len``).
            """
            if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
                # The efficient-attention op's outputs at indices 2 and 3 are the seed and offset.
                output_seed, output_offset = output[2], output[3]
                output_seed = output_seed.item()
                output_offset = output_offset.item()
                return _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len,
                                              dropout_p, output_seed, output_offset, device=device)
            else:
                # Build dropout_mask from the flash debug mask (last element of the output tuple)
                dbug_mask = output[-1]
                query_padding_mask = torch.ones(
                    batch_size, q_len, device=device, dtype=torch.bool)
                key_padding_mask = torch.ones(
                    batch_size, kv_len, device=device, dtype=torch.bool)

                # Crop to the logical sizes, as the non-graph flash test does; a no-op when the
                # converted mask is already (q_len, kv_len) in its last two dims.
                softmax_mask = self.convert_flash_attn_S_to_softmax(
                    dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim,
                    causal=is_causal)[:, :, :q_len, :kv_len]
                dropout_mask = softmax_mask >= 0
                return dropout_mask

        if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k:
            self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")

        seed = 42
        scale = scale if scale is None else (1 / head_dim)
        n_heads = 4
        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
                           device=device, dtype=dtype, requires_grad=True)
        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                         dtype=dtype, requires_grad=True)
        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
                           device=device, dtype=dtype, requires_grad=True)

        fused_op = (torch.ops.aten._scaled_dot_product_efficient_attention
                    if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else torch.ops.aten._scaled_dot_product_flash_attention)
        # Run the math kernel on low precision references
        query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)

        higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
        query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)

        # warmup
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        # Set the global seed before capture
        torch.manual_seed(seed)
        kwargs = {"dropout_p": dropout_p, "is_causal": is_causal, "scale": scale}
        if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
            kwargs["compute_log_sumexp"] = True
            kwargs["attn_bias"] = None
        if fused_kernel == SDPBackend.FLASH_ATTENTION:
            # The debug mask is only needed to rebuild the dropout pattern.
            kwargs['return_debug_mask'] = dropout_p > 0.0
        with torch.cuda.stream(s):
            # Create real output
            output_tuple = fused_op(query, key, value, **kwargs)

        torch.cuda.current_stream().wait_stream(s)
        out = output_tuple[0]
        upstream_grad = torch.rand_like(out, requires_grad=False)
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            out.backward(upstream_grad)
        # Discard warmup gradients so only the graph-captured backward populates .grad.
        for x in (query, key, value):
            x.grad = None
        g = torch.cuda.CUDAGraph()
        # Create real output
        with torch.cuda.graph(g):
            tmp = torch.rand_like(query, device=query.device)  # test non-zero intragraph offset
            # Create real output
            output_tuple = fused_op(query, key, value, **kwargs)
            assert all(not isinstance(o, torch.Tensor) or o.is_cuda for o in output_tuple)
        g.replay()
        out_first = output_tuple[0].clone()
        g.replay()
        out = output_tuple[0]
        if dropout_p == 0.0:
            self.assertEqual(out_first, out, atol=0, rtol=0)
        else:
            # replays produce different results
            self.assertNotEqual(out_first, out)

        with sdpa_kernel(backends=[SDPBackend.MATH]):
            if dropout_p == 0.0:
                # High Precision Math Reference
                out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
                                                         dropout_p=dropout_p, is_causal=is_causal, scale=scale)
                # Low Precision Math Reference
                out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
                                                            dropout_p=dropout_p, is_causal=is_causal, scale=scale)
            else:
                # Create the dropout_mask from the outputs of the latest replay
                dropout_mask = get_dropout_mask(output_tuple, fused_kernel, batch_size,
                                                n_heads, seq_len_q, seq_len_k, dropout_p, device)
                # High Precision Math Reference
                out_ref = torch.ops.aten._scaled_dot_product_attention_math(
                    query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
                    scale=scale, dropout_mask=dropout_mask)[0]
                # Low Precision Math Reference
                out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                    query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
                    dropout_mask=dropout_mask)[0]

        g1 = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g1):
            out.backward(upstream_grad)
        g1.replay()
        out_ref.backward(upstream_grad.to(out_ref.dtype))
        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))

        # [Note] Fused Tolerances
        # Establish the numerical error between the "true" high precision math output
        # and the low precision math reference. We use this reference for the atol
        # And we use the default rtol for the low precision type.
        # We then provide a fudge factor for gradients respectively to account
        # for the use of the fused kernel rather than the eager implementation.
        output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)

        # Fudge Factor when dropout is enabled
        dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5

        query_fudge_factor = dropout_fudge_factor
        grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)

        # TODO: Investigate why grad_k needs larger tolerances
        key_fudge_factor = 8 * dropout_fudge_factor
        grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)

        value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
        grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)

        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
        self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2948

2949
    @skipIfRocm  # Nested Tensor
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
    def test_fused_kernels_seq_len_1_inputs(self, device, fused_kernel):
        """Fused SDPA on float16 nested tensors containing length-1 sequences matches float32 math."""
        make_nested = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
        batch, num_heads, head_dim = 32, 16, 64

        # Random per-sample lengths in [1, 32), then force length 1 at 10 randomly drawn
        # positions (drawn with replacement, so possibly fewer distinct samples).
        lengths = torch.randint(low=1, high=32, size=(batch,))
        forced_positions = torch.randint(low=0, high=batch, size=(10,))
        lengths.scatter_(0, forced_positions, 1)

        nested_shape = SdpaShape(batch, num_heads, lengths.tolist(), head_dim)
        # Generated in query, key, value order; then swap dims 1 and 2 of each.
        q, k, v = (make_nested(nested_shape).transpose(1, 2) for _ in range(3))

        with sdpa_kernel(backends=[fused_kernel]):
            fused_out = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

        with sdpa_kernel(backends=[SDPBackend.MATH]):
            q32, k32, v32 = (t.contiguous().to(torch.float32) for t in (q, k, v))
            reference = torch.nn.functional.scaled_dot_product_attention(
                q32, k32, v32, attn_mask=None, dropout_p=0.0, is_causal=False)

        expected = reference.contiguous().to(torch.float16)
        self.assertEqual(fused_out.contiguous(), expected, atol=1e-3, rtol=1e-2)
2982

2983

2984
    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
    @parametrize("expand_q_batch", [True, False])
    @parametrize("expand_k_batch", [True, False])
    @parametrize("expand_v_batch", [True, False])
    @parametrize("expand_q_num_heads", [True, False])
    @parametrize("expand_k_num_heads", [True, False])
    @parametrize("expand_v_num_heads", [True, False])
    def test_fused_kernels_nested_broadcasting(
        self,
        device,
        kernel,
        expand_q_batch,
        expand_k_batch,
        expand_v_batch,
        expand_q_num_heads,
        expand_k_num_heads,
        expand_v_num_heads,
    ):
        """Check a fused SDPA kernel on nested q/k/v that broadcast over batch and/or heads.

        Each of q, k and v may be created with a batch size of 1 and/or a single
        head. The fused ``kernel`` runs on those un-expanded nested tensors; the
        reference runs the MATH backend on float32 copies explicitly expanded to
        the full batch / head counts, and the outputs are compared in ``dtype``.
        """
        is_efficient = kernel == SDPBackend.EFFICIENT_ATTENTION
        # The efficient kernel is exercised in float32 and with a value head dim
        # (head_dim_v) that differs from the query/key head dim.
        dtype = torch.float32 if is_efficient else torch.float16
        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype)
        batch, num_heads, head_dim = 32, 8, 64
        head_dim_v = 32 if is_efficient else head_dim
        # A batch-broadcast input holds a single sequence, so it gets one scalar
        # length; otherwise every batch entry gets its own length. Key and value
        # share lengths, so broadcasting either one forces the scalar form.
        seq_lens_q = (torch.randint(low=1, high=5, size=(1,)).item()
                      if expand_q_batch
                      else torch.randint(low=1, high=32, size=(batch,)).tolist())
        seq_lens_kv = (torch.randint(low=1, high=5, size=(1,)).item()
                       if (expand_k_batch or expand_v_batch)
                       else torch.randint(low=1, high=32, size=(batch,)).tolist())

        batch_q = 1 if expand_q_batch else batch
        batch_k = 1 if expand_k_batch else batch
        batch_v = 1 if expand_v_batch else batch

        # handle case where all batch_sizes are 1
        batch = max(batch_q, batch_k, batch_v)

        num_heads_q = 1 if expand_q_num_heads else num_heads
        num_heads_k = 1 if expand_k_num_heads else num_heads
        num_heads_v = 1 if expand_v_num_heads else num_heads

        # handle case where all num_heads are 1
        num_heads = max(num_heads_q, num_heads_k, num_heads_v)

        q_shape = SdpaShape(batch_q, num_heads_q, seq_lens_q, head_dim)
        k_shape = SdpaShape(batch_k, num_heads_k, seq_lens_kv, head_dim)
        v_shape = SdpaShape(batch_v, num_heads_v, seq_lens_kv, head_dim_v)

        query = rand_nested_tensor(q_shape)
        key = rand_nested_tensor(k_shape)
        value = rand_nested_tensor(v_shape)

        def _broadcast(t, batch_broadcasted, num_heads_broadcasted):
            # Materialize the broadcast explicitly as a float32 nested tensor so the
            # MATH reference sees full-size inputs.
            if batch_broadcasted and num_heads_broadcasted:
                # (1, seq_len, 1, head_dim) -> (batch, seq_len, num_heads, head_dim)
                result = torch.nested.nested_tensor(
                    [t[0].expand(-1, num_heads, t.size(-1)) for _ in range(batch)], dtype=torch.float32)
            elif batch_broadcasted:
                # (1, seq_len, num_heads, head_dim) -> (batch, seq_len, num_heads, head_dim)
                result = torch.nested.nested_tensor([t[0] for _ in range(batch)], dtype=torch.float32)
            elif num_heads_broadcasted:
                # (batch, seq_len, 1, head_dim) -> (batch, seq_len, num_heads, head_dim)
                result = torch.nested.nested_tensor([x.expand(-1, num_heads, x.size(-1))
                                                     for x in t.unbind()], dtype=torch.float32)
            else:
                result = t.to(torch.float32)
            return result

        query_expanded = _broadcast(query, expand_q_batch, expand_q_num_heads).transpose(1, 2)
        key_expanded = _broadcast(key, expand_k_batch, expand_k_num_heads).transpose(1, 2)
        value_expanded = _broadcast(value, expand_v_batch, expand_v_num_heads).transpose(1, 2)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        with sdpa_kernel(backends=[kernel]):
            actual = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
        with sdpa_kernel(backends=[SDPBackend.MATH]):
            math_ref = torch.nn.functional.scaled_dot_product_attention(
                query_expanded.contiguous(), key_expanded.contiguous(), value_expanded.contiguous(),
                attn_mask=None, dropout_p=0.0, is_causal=False)

        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)

    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
    def test_fused_kernels_nested_broadcasting_query_dense(self, device):
        """Efficient-attention SDPA with a dense query broadcast against nested key/value.

        The query is a dense (1, 1, num_heads, head_dim) tensor and the value has a
        single head; the efficient kernel must broadcast both. The reference runs the
        MATH backend on nested inputs where the query is repeated per batch entry and
        the value is expanded to ``num_heads`` heads.
        """
        rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
        batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 96
        seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist()
        q_shape = (1, 1, num_heads, head_dim)
        k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim)
        v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v)

        # create a dense query
        query = torch.randn(q_shape, device=device, dtype=torch.float32)
        key = rand_nested_tensor(k_shape)
        value = rand_nested_tensor(v_shape)

        # (1, 1, num_heads, head_dim) -> (batch, 1, num_heads, head_dim)
        query_expanded = torch.nested.nested_tensor([query.squeeze(0) for _ in range(batch)]).transpose(1, 2)
        # (batch, seq_lens, 1, head_dim_v) -> (batch, seq_lens, num_heads, head_dim_v)
        value_expanded = torch.nested.nested_tensor(
            [v.expand(-1, num_heads, head_dim_v) for v in value.unbind()]).transpose(1, 2)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
            actual = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
        with sdpa_kernel(backends=[SDPBackend.MATH]):
            math_ref = torch.nn.functional.scaled_dot_product_attention(
                query_expanded.contiguous(), key.contiguous(), value_expanded.contiguous(),
                attn_mask=None, dropout_p=0.0, is_causal=False)

        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2)

    @onlyCUDA
    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
    @parametrize("batch_size", [8, 32])
    @parametrize("max_seq_len_q", [32, 256])
    @parametrize("max_seq_len_kv", [32, 256])
    @parametrize("head_dim", [8, 64])
    @parametrize("dropout_p", [0.0, 0.1])
    @parametrize("dtype", [torch.float16])
    @parametrize("scale", [None, "l1"])
    @parametrize("is_causal", [True, False])
    def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int,
                                                            head_dim: int, dropout_p: float, dtype: torch.dtype,
                                                            scale: Optional[str], is_causal: bool):
        """Compare flash-attention outputs and q/k/v grads on nested inputs to math references.

        Two MATH references are computed: one in float32 and one in ``dtype``; their
        disagreement sets the tolerances used to judge the flash result. When
        ``dropout_p > 0`` the flash debug mask is converted into a nested dropout
        mask so both references replay the same dropout pattern.
        """
        if is_causal:
            # TODO we should support this. Report the case as skipped rather than
            # letting it pass without checking anything.
            self.skipTest("Nested tensors for query / key are not supported when is_causal=True")
        # "l1" selects an explicit scale of 1 / head_dim; None uses the SDPA default.
        scale = scale if scale is None else (1 / head_dim)
        n_heads = 4
        seq_lens_q = torch.randint(low=1, high=max_seq_len_q, size=(batch_size,))
        # Set one entry to max length
        seq_lens_q[torch.randint(0, batch_size, size=(1,))] = max_seq_len_q
        seq_lens_kv = torch.randint(low=1, high=max_seq_len_kv, size=(batch_size,))
        seq_lens_kv[torch.randint(0, batch_size, size=(1,))] = max_seq_len_kv

        def rand_nt(sequence_list, num_heads, head_dim):
            # One (num_heads, seq_len, head_dim) component per entry of sequence_list.
            tensors = [torch.rand((num_heads, seq_len, head_dim)) for seq_len in sequence_list]
            return torch.nested.nested_tensor(tensors, requires_grad=True, device=device, dtype=dtype)

        query = rand_nt(seq_lens_q, n_heads, head_dim)
        key = rand_nt(seq_lens_kv, n_heads, head_dim)
        value = rand_nt(seq_lens_kv, n_heads, head_dim)

        # Run the math kernel on low precision references
        query_ref_lp = query.clone().detach().requires_grad_(True)
        key_ref_lp = key.clone().detach().requires_grad_(True)
        value_ref_lp = value.clone().detach().requires_grad_(True)

        query_ref = query.clone().detach().to(torch.float32).requires_grad_(True)
        key_ref = key.clone().detach().to(torch.float32).requires_grad_(True)
        value_ref = value.clone().detach().to(torch.float32).requires_grad_(True)

        is_dropout = dropout_p > 0.0

        if not is_dropout:
            with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
                out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
            with sdpa_kernel(backends=[SDPBackend.MATH]):
                # High Precision Math Reference
                out_ref = F.scaled_dot_product_attention(
                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
                # Low Precision Math Reference
                out_lp_ref = F.scaled_dot_product_attention(
                    query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
        else:
            # Create real output; the debug mask lets the references reproduce the dropout.
            output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
                query, key, value, dropout_p=dropout_p, is_causal=is_causal,
                scale=scale, return_debug_mask=is_dropout)
            out = output_tuple[0]
            debug_mask = output_tuple[-1]

            query_padding_mask = torch.arange(max_seq_len_q).unsqueeze(0).expand(
                batch_size, max_seq_len_q
            ) < seq_lens_q.unsqueeze(-1)
            query_padding_mask = query_padding_mask.to(device)

            key_padding_mask = torch.arange(max_seq_len_kv).unsqueeze(0).expand(
                batch_size, max_seq_len_kv
            ) < seq_lens_kv.unsqueeze(-1)
            key_padding_mask = key_padding_mask.to(device)

            softmax_mask = self.convert_flash_attn_S_to_softmax(
                debug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)
            dropout_mask = softmax_mask >= 0
            # Slice each batch entry's mask down to its actual (seq_len_q, seq_len_kv)
            # extent so it lines up with the nested inputs.
            nested_dropout_mask = torch.nested.nested_tensor([
                dropout_mask[i, :n_heads, :seq_lens_q[i], :seq_lens_kv[i]]
                for i in range(batch_size)
            ])
            # High Precision Math Reference
            out_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref, key_ref, value_ref, dropout_p=dropout_p,
                is_causal=is_causal, scale=scale, dropout_mask=nested_dropout_mask)[0]
            # Low Precision Math Reference
            out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
                dropout_mask=nested_dropout_mask)[0]

        upstream_grad = out.detach().clone().contiguous()

        out.backward(upstream_grad)
        out_ref.backward(upstream_grad.to(out_ref.dtype))
        out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))

        # See [Note] Fused Tolerances above
        output_ref_atol, output_ref_rtol = calculate_nt_tolerances(out_ref, out_lp_ref, out.dtype)
        grad_q_ref_atol, grad_q_ref_rtol = calculate_nt_tolerances(query_ref.grad, query_ref_lp.grad,
                                                                   query.grad.dtype, fudge_factor=4)
        grad_k_ref_atol, grad_k_ref_rtol = calculate_nt_tolerances(key_ref.grad, key_ref_lp.grad, key.grad.dtype)
        grad_v_ref_atol, grad_v_ref_rtol = calculate_nt_tolerances(value_ref.grad, value_ref_lp.grad, value.grad.dtype)

        self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
        self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
                         atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
        self.assertEqual(key.grad.contiguous(), key_ref.grad.contiguous().to(key.grad.dtype),
                         atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
        self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                         atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)


class TestAttnBias(NNTestCase):
    # (bsz, num_heads, seq_len_q, seq_len_kv, head_dim) configurations shared by
    # the causal-bias tests below.
    _CAUSAL_SHAPES = [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)]
    # Forward / gradient tolerances shared by the causal-bias comparisons.
    _FORW_TOL = Tolerances(1e-3, 1e-3)
    _GRAD_TOL = Tolerances(5e-3, 5e-3)

    def run_test(
        self,
        device,
        make_q,
        make_kv,
        attn_bias=None,
        forw_tolerances: Optional[Tolerances] = None,
        grad_tolerances: Optional[Tolerances] = None,
        backend=None,
    ):
        """Compare SDPA driven by ``attn_bias`` against SDPA with its materialized mask.

        ``make_q`` / ``make_kv`` are zero-argument tensor factories; ``make_kv`` is
        called twice (key, value). The reference passes ``attn_bias._materialize(device)``
        as a dense mask to eager SDPA; the path under test passes ``attn_bias`` itself,
        through ``torch.compile(..., backend=backend)`` when ``backend`` is given.
        Outputs and q/k/v gradients (from a shared random upstream gradient) are
        compared with ``torch.testing.assert_close``; ``None`` tolerances fall back
        to its dtype-based defaults.
        """
        if backend is not None:
            # Clear dynamo caches so the compile below starts from a fresh state.
            torch._dynamo.reset()

        query, key, value = make_q(), make_kv(), make_kv()
        query_prototype, key_prototype, value_prototype = query_key_value_clones(query, key, value)

        realized = attn_bias._materialize(device) if attn_bias is not None else None
        pytorch_output = scaled_dot_product_attention(
            query, key, value, attn_mask=realized, dropout_p=0.0, is_causal=False
        )

        sdpa_op = (
            torch.compile(scaled_dot_product_attention, backend=backend)
            if backend is not None
            else scaled_dot_product_attention
        )
        sdpa_output = sdpa_op(
            query_prototype,
            key_prototype,
            value_prototype,
            attn_mask=attn_bias,
            dropout_p=0.0,
            is_causal=False,
            scale=None,
        )

        dOut = torch.randn_like(pytorch_output)
        pytorch_output.backward(dOut)
        sdpa_output.backward(dOut)

        # Use default assert_close tolerances for dtypes
        if forw_tolerances is None:
            forw_tolerances = Tolerances(atol=None, rtol=None)
        if grad_tolerances is None:
            grad_tolerances = Tolerances(atol=None, rtol=None)

        torch.testing.assert_close(pytorch_output, sdpa_output, rtol=forw_tolerances.rtol, atol=forw_tolerances.atol)
        torch.testing.assert_close(query.grad, query_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
        torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
        torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)

    def _causal_variant_setup(self, device, causal_variant: CausalVariant, shape: Tuple[int, int, int, int, int]):
        """Build q/kv factories and the causal bias for ``causal_variant``.

        Skips the test for LOWER_RIGHT when seq_len_q > seq_len_kv.
        Returns ``(make_q_tensor, make_kv_tensor, attn_bias)``.
        """
        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
        if causal_variant == CausalVariant.LOWER_RIGHT and seq_len_q > seq_len_kv:
            self.skipTest(
                "Lower right causal mask will produce NaNs in the output when seq_len_q > seq_len_kv!"
            )

        make_rand = partial(torch.rand, device=device, dtype=torch.float16, requires_grad=True)
        make_q_tensor = partial(make_rand, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
        make_kv_tensor = partial(make_rand, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))

        if causal_variant == CausalVariant.UPPER_LEFT:
            attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
        else:
            attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
        return make_q_tensor, make_kv_tensor, attn_bias

    @skipIfRocm  # No support for the second variant for now
    @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
    @parametrize("shape", _CAUSAL_SHAPES)
    def test_causal_variants(self, device, causal_variant: CausalVariant, shape: Tuple[int, int, int, int, int]):
        make_q_tensor, make_kv_tensor, attn_bias = self._causal_variant_setup(device, causal_variant, shape)
        self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias,
                      self._FORW_TOL, self._GRAD_TOL, backend=None)

    @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
    @parametrize("shape", _CAUSAL_SHAPES)
    @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: Tuple[int, int, int, int, int]):
        make_q_tensor, make_kv_tensor, attn_bias = self._causal_variant_setup(device, causal_variant, shape)
        cnts = CompileCounterWithBackend("aot_eager")
        self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias,
                      self._FORW_TOL, self._GRAD_TOL, backend=cnts)
        self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")

    @parametrize("shape", _CAUSAL_SHAPES)
    def test_is_causal_equals_upper_left(self, device, shape: Tuple[int, int, int, int, int]):
        """``causal_upper_left`` as attn_mask must match ``is_causal=True``."""
        make_rand = partial(torch.rand, device=device, dtype=torch.float16, requires_grad=True)
        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape

        query = make_rand(SdpaShape(bsz, num_heads, seq_len_q, head_dim))
        key = make_rand(SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
        value = make_rand(SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
        attn_bias = causal_upper_left(seq_len_q, seq_len_kv)

        out_attn_bias = scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, dropout_p=0.0)
        out_is_causal = scaled_dot_product_attention(query, key, value, is_causal=True, dropout_p=0.0)
        forw_tol = self._FORW_TOL
        torch.testing.assert_close(out_attn_bias, out_is_causal, rtol=forw_tol.rtol, atol=forw_tol.atol)

    def test_is_causal_and_mask_fails(self, device):
        """Passing a CausalBias together with ``is_causal=True`` must raise ValueError."""
        make_rand = partial(torch.rand, device=device, dtype=torch.float16, requires_grad=True)

        query = make_rand(SdpaShape(16, 16, 128, 16))
        key = make_rand(SdpaShape(16, 16, 128, 16))
        value = make_rand(SdpaShape(16, 16, 128, 16))
        attn_bias = causal_upper_left(128, 128)

        with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
            scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)


# CPU variants are not generated when NOTEST_CPU is set.
if NOTEST_CPU:
    device_types = ("cuda", )
else:
    device_types = ("cpu", "cuda")

instantiate_device_type_tests(TestTransformers, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types)
# Trailing comma: only_for must be a tuple of device types, not the bare string "cuda".
instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda",))
instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)

if __name__ == '__main__':
    run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.