4
from functools import partial
5
from collections import namedtuple
9
import torch.nn.functional as F
10
from torch.nn.functional import scaled_dot_product_attention
11
from torch.nn.attention import sdpa_kernel, SDPBackend
12
from torch.nn.attention.bias import CausalVariant, causal_lower_right, causal_upper_left
13
from torch.nn.parameter import Parameter
15
from unittest.mock import patch, MagicMock, ANY
17
import torch.optim as optim
18
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
19
from typing import List, Tuple, Optional
20
from torch.testing._internal.common_nn import NNTestCase
21
from torch.testing._internal.common_utils import (
35
TEST_WITH_TORCHDYNAMO,
37
from torch._dynamo.testing import CompileCounterWithBackend
40
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
41
from torch.testing._internal.common_cuda import (
42
SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
43
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
44
PLATFORM_SUPPORTS_FUSED_ATTENTION,
45
PLATFORM_SUPPORTS_CUDNN_ATTENTION
49
import fairseq.models.transformer as fairseq_transformer
51
SdpaShape = namedtuple('Sdpa_Shape', ['batch', 'num_heads', 'seq_len', 'head_dim'])
52
Tolerances = namedtuple('Tolerances', ['atol', 'rtol'])
54
@contextlib.contextmanager
55
def use_deterministic_algorithims(mode: bool, warn_only: bool):
57
This context manager can be used to temporarily enable or disable deterministic algorithms.
58
Upon exiting the context manager, the previous state of the flag will be restored.
60
previous_mode: bool = torch.are_deterministic_algorithms_enabled()
61
previous_warn_only: bool = torch.is_deterministic_algorithms_warn_only_enabled()
63
torch.use_deterministic_algorithms(mode, warn_only=warn_only)
66
torch.use_deterministic_algorithms(previous_mode, warn_only=previous_warn_only)
70
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
71
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}
73
isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 7), (8, 9)]
74
isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
75
isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
76
isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8
78
def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
79
deviation = true_value - computed_value
80
deviation = torch.abs(deviation / true_value)
82
torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
83
return deviation.max().item()
86
def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
87
deviation = true_value - computed_value
88
atol = torch.abs(deviation).max().item()
93
true_value: torch.Tensor,
94
computed_value: torch.Tensor,
95
fudge_factor: Optional[float] = None,
96
) -> Tuple[float, float]:
97
"""Returns the absolute and relative tolerances for comparing two tensors."""
98
fudge_factor = fudge_factor if fudge_factor is not None else 1.0
99
atol = get_atol(true_value, computed_value)
100
rtol = get_rtol(true_value, computed_value)
102
atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
103
rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
107
rtol = default_rtol[computed_value.dtype]
111
def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, dtype: torch.dtype = None):
112
""" Clones the query, key, and value tensors and moves them to the specified dtype. """
115
query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
116
key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
117
value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
118
return query_ref, key_ref, value_ref
120
def get_platform_specific_sdpa():
122
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
123
ret.append(SDPBackend.FLASH_ATTENTION)
124
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
125
ret.append(SDPBackend.EFFICIENT_ATTENTION)
126
if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
127
ret.append(SDPBackend.CUDNN_ATTENTION)
130
ret.append(SDPBackend.EFFICIENT_ATTENTION)
133
PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa()
135
def rand_sdpa_tensor(shape: SdpaShape, device: str, dtype: torch.dtype, type: str,
136
requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
137
"""Creates rand dense or nested tensor with given shape and type.
140
shape (Tuple[int]): Shape of Tensor to construct
141
device (str): which device to create tensor on
142
dtype (torch.dtype): Tensors' dtype
143
type (str): Nested or Dense
144
requires_grad (bool, optional): Tensors grad status. Defaults to False.
145
packed (bool, optional): Whether to create a single QKV packed or not. Defaults to False.
148
torch.Tensor: A new tensor
150
batch, num_heads, seq_len, head_dim = shape.batch, shape.num_heads, shape.seq_len, shape.head_dim
152
if isinstance(seq_len, list):
154
return (seq_len[i], num_heads, head_dim) if not packed else (seq_len[i], 3 * num_heads * head_dim)
156
return torch.nested.nested_tensor([
157
torch.randn(_size(i), device=device, dtype=dtype, requires_grad=requires_grad)
158
for i in range(batch)])
160
size = (seq_len, num_heads, head_dim) if not packed else (seq_len, 3 * num_heads * head_dim)
161
return torch.nested.nested_tensor([
162
torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
163
for _ in range(batch)])
165
assert (isinstance(seq_len, int))
166
size = (batch, seq_len, num_heads, head_dim) if not packed else (batch, seq_len, 3 * num_heads * head_dim)
167
return torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad)
169
def calculate_nt_tolerances(nt_ref_hp, nt_ref_lp, default_dtype, fudge_factor=1):
171
ref_atol = default_atol[default_dtype]
172
ref_rtol = default_rtol[default_dtype]
173
for tensor_component_ref, tensor_component_ref_lp in zip(nt_ref_hp.unbind(), nt_ref_lp.unbind()):
174
ref_atol = max((fudge_factor * torch.abs(tensor_component_ref - tensor_component_ref_lp)).max().item(), ref_atol)
175
ref_rtol = max(get_rtol(tensor_component_ref, tensor_component_ref_lp), ref_rtol)
176
return ref_atol, ref_rtol
178
class TestTransformers(NNTestCase):
179
_do_cuda_memory_leak_check = True
180
_do_cuda_non_default_stream = True
183
@unittest.skip("4D mask not supported yet - activate when 4D mask supported")
184
def test_self_attn_TxT_attn_mask(self, device):
190
query = torch.rand(batch_size, tgt_len, embed_dim, device=device)
191
attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()
192
attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, 0.0)
194
attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)
196
mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
200
with torch.inference_mode():
201
output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
202
output_mask_4d = output_mask_4d.transpose(0, 1)
204
output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
205
output_mask_TxT = output_mask_TxT.transpose(0, 1)
207
self.assertEqual(output_mask_4d, output_mask_TxT)
210
def test_train_with_pad_and_catch_error(self, device):
212
pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool).to(device)
213
layer = nn.TransformerEncoderLayer(
221
criterion = nn.MSELoss()
222
encoder = nn.TransformerEncoder(layer, 2).to(device)
223
optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
225
for i in range(iters):
227
optimizer.zero_grad()
228
inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)
230
outputs = encoder(inputs, src_key_padding_mask=pad_mask)
232
loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
236
with torch.no_grad():
237
test = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)
242
test_train_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
243
except AssertionError as e:
245
self.assertFalse(e, "Failed to catch unsupported uint8 type exception")
247
test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
253
test_eval_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
254
except AssertionError as e:
256
self.assertFalse(e, "Failed to catch unsupported Long type exception")
258
test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
259
l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
260
self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
262
@parametrize("attn_mask_dim", [2, 3, None])
263
@parametrize("key_padding_mask_dim", [2, None])
264
@parametrize("mask_dtype", [torch.bool, torch.float32])
265
def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
266
with torch.no_grad():
272
if attn_mask_dim == 2:
273
attn_mask = make_tensor((L, L), dtype=mask_dtype, device=device)
274
elif attn_mask_dim == 3:
275
attn_mask = make_tensor((B * H, L, L), dtype=mask_dtype, device=device)
276
elif attn_mask_dim is None:
279
if key_padding_mask_dim == 2:
280
key_padding_mask = make_tensor((B, L), dtype=mask_dtype, device=device)
281
elif key_padding_mask_dim is None:
282
key_padding_mask = None
284
mha = nn.MultiheadAttention(D, H, batch_first=True, device=device)
285
X = torch.randn(B, L, D, device=device)
288
out, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
290
out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)
291
self.assertEqual(out, out_fp)
293
@parametrize("nhead", [1, 4, 8])
294
def test_transformerencoderlayer_src_mask(self, device, nhead):
300
model = torch.nn.TransformerEncoderLayer(
303
dim_feedforward=dim_feedforward,
304
batch_first=True).to(device)
305
src = torch.rand(batch_size, seqlen, d_model).to(device)
306
src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
308
model(src, src_mask=src_mask)
310
with torch.no_grad():
311
model(src, src_mask=src_mask)
313
@parametrize("use_torchscript", [False])
314
@parametrize("enable_nested_tensor", [True, False])
315
@parametrize("use_autocast", [True, False])
316
@parametrize("d_model", [12, 256])
317
def test_transformerencoder_fastpath(self, device, use_torchscript, enable_nested_tensor, use_autocast, d_model):
319
Test TransformerEncoder fastpath output matches slowpath output
321
torch.manual_seed(1234)
323
dim_feedforward = d_model
326
model = torch.nn.TransformerEncoder(
327
torch.nn.TransformerEncoderLayer(
330
dim_feedforward=dim_feedforward,
331
batch_first=batch_first),
333
enable_nested_tensor=enable_nested_tensor
337
model = torch.jit.script(model)
342
torch.rand(3, 2, d_model),
350
torch.rand(2, 100, d_model),
358
torch.rand(2, 1024, d_model),
360
[0] * 1020 + [1] * 4,
365
torch.rand(1, 1026, d_model),
366
[[0] * 1024 + [1] * 2]
370
torch.rand(4, 1040, d_model),
372
[0] * 1024 + [1] * 16,
373
[0] * 1025 + [1] * 15,
374
[0] * 1031 + [1] * 9,
381
torch.tensor(pair[0], device=device, dtype=torch.get_default_dtype()),
382
torch.tensor(pair[1], device=device, dtype=torch.bool)
383
) for pair in input_mask_pairs
386
maybe_autocast = torch.autocast("cuda", dtype=torch.float16) if use_autocast else contextlib.nullcontext()
388
for input, src_key_padding_mask in input_mask_pairs:
389
with torch.no_grad():
390
fastpath_output = model(input, src_key_padding_mask=src_key_padding_mask)
391
slowpath_output = model(input, src_key_padding_mask=src_key_padding_mask)
396
bs, true_seqlen, embed_dim = fastpath_output.shape
397
expanded_seqlen = src_key_padding_mask.shape[1]
398
fastpath_output_expanded = torch.zeros(bs, expanded_seqlen, embed_dim, device=device)
399
fastpath_output_expanded[:, :true_seqlen, :] = fastpath_output
401
fastpath_output_expanded = fastpath_output_expanded.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
402
slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
403
torch.testing.assert_close(fastpath_output_expanded, slowpath_output, rtol=1e-7, atol=1e-5)
405
@parametrize("with_no_grad", [True, False])
406
@parametrize("training", [True, False])
407
@parametrize("enable_nested_tensor", [False])
408
def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
410
Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
411
batch size == sequence length
413
model = torch.nn.TransformerEncoder(
414
torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
416
enable_nested_tensor=enable_nested_tensor
419
with torch.no_grad():
421
for idx, p in enumerate(model.parameters()):
423
sz = x.view(-1).size(0)
425
x = torch.cos(torch.arange(0, sz).float().view(shape))
429
model = model.train()
432
x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.get_default_dtype()).to(device)
433
src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)
438
cm = contextlib.nullcontext()
440
result = model(x, mask=src_mask)
442
ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
443
[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
444
[[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
445
[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]]
447
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
448
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
450
@parametrize("batch_first", [True, False])
451
@parametrize("training", [True, False])
452
@parametrize("enable_nested_tensor", [True, False])
453
def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device):
454
def get_a_test_layer(activation, batch_first=False):
460
layer = nn.TransformerEncoderLayer(
463
dim_feedforward=dim_feedforward,
465
activation=activation,
466
batch_first=batch_first,
469
with torch.no_grad():
471
for idx, p in enumerate(layer.parameters()):
473
sz = x.view(-1).size(0)
475
x = torch.cos(torch.arange(0, sz).float().view(shape))
483
def _test(batch_first, training, enable_nested_tensor):
485
return x.transpose(1, 0) if batch_first else x
487
encoder_layer = get_a_test_layer(activation=activation,
488
batch_first=batch_first)
490
model = nn.TransformerEncoder(
491
encoder_layer, 1, enable_nested_tensor=enable_nested_tensor
498
encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
499
[0.5387, 0.1655, 0.3565, 0.0471]],
500
[[0.8335, 0.2799, 0.5031, 0.2947],
501
[0.1402, 0.0318, 0.7636, 0.1346]],
502
[[0.6333, 0.9344, 0.1376, 0.9938],
503
[0.8924, 0.2872, 0.6692, 0.2944]],
504
[[0.9897, 0.6915, 0.3154, 0.1733],
505
[0.8645, 0.3513, 0.3064, 0.0767]],
506
[[0.8117, 0.2366, 0.4838, 0.7881],
507
[0.3718, 0.4945, 0.9511, 0.0864]]]
509
result = model(encoder_input)
510
ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249],
511
[2.427987, 0.021213, -0.602496, -0.084103]],
512
[[2.424689, 0.019155, -0.604793, -0.085672],
513
[2.413863, 0.022211, -0.612486, -0.072490]],
514
[[2.433774, 0.021598, -0.598343, -0.087548],
515
[2.425104, 0.019748, -0.604515, -0.084839]],
516
[[2.436185, 0.022682, -0.596625, -0.087261],
517
[2.433556, 0.021891, -0.598509, -0.086832]],
518
[[2.416246, 0.017512, -0.610712, -0.082961],
519
[2.422901, 0.024187, -0.606178, -0.074929]]]
521
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
522
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
525
src_mask = torch.zeros([5, 5]).to(device) == 1
526
result = model(encoder_input, mask=src_mask)
527
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
528
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
531
mask = torch.zeros([2, 5]).to(device) == 1
532
result = model(encoder_input, src_key_padding_mask=mask)
533
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
534
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
539
result = model(encoder_input, src_key_padding_mask=mask)
540
ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
541
[2.428811, 0.021445, -0.601912, -0.084252]],
542
[[2.425009, 0.019155, -0.604566, -0.085899],
543
[2.415408, 0.02249, -0.611415, -0.073]],
544
[[2.434199, 0.021682, -0.598039, -0.087699],
545
[2.42598, 0.019941, -0.603896, -0.085091]],
546
[[2.436457, 0.022736, -0.59643, -0.08736],
547
[2.434021, 0.022093, -0.598179, -0.08679]],
548
[[2.416531, 0.017498, -0.610513, -0.083181],
549
[2.4242, 0.024653, -0.605266, -0.074959]]]
551
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
552
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
555
model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device)
558
result = model(encoder_input, src_key_padding_mask=mask)
559
ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003],
560
[2.419102, 0.017452, -0.608703, -0.085026]],
561
[[2.419043, 0.017445, -0.608744, -0.084999],
562
[2.419052, 0.017446, -0.608738, -0.085004]],
563
[[2.419067, 0.017448, -0.608727, -0.085010],
564
[2.419098, 0.017452, -0.608706, -0.085024]],
565
[[2.419072, 0.017449, -0.608724, -0.085012],
566
[2.419119, 0.017455, -0.608691, -0.085034]],
567
[[2.419019, 0.017442, -0.608761, -0.084989],
568
[2.419075, 0.017449, -0.608722, -0.085014]]]
570
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
571
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
573
model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device)
576
result = model(encoder_input, src_key_padding_mask=mask)
577
ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025],
578
[2.419101, 0.017453, -0.608704, -0.085025]],
579
[[2.419101, 0.017453, -0.608703, -0.085025],
580
[2.419101, 0.017453, -0.608704, -0.085025]],
581
[[2.419101, 0.017453, -0.608703, -0.085025],
582
[2.419101, 0.017453, -0.608704, -0.085025]],
583
[[2.419101, 0.017453, -0.608703, -0.085025],
584
[2.419101, 0.017453, -0.608704, -0.085025]],
585
[[2.419101, 0.017453, -0.608703, -0.085025],
586
[2.419101, 0.017453, -0.608704, -0.085025]]]
588
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
589
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
593
norm = nn.LayerNorm(4)
594
model = nn.TransformerEncoder(encoder_layer, 2, norm=norm,
595
enable_nested_tensor=enable_nested_tensor).to(device)
598
result = model(encoder_input, src_key_padding_mask=mask)
599
ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238],
600
[1.695955, -0.357639, -0.893050, -0.445266]],
601
[[1.695948, -0.357634, -0.893082, -0.445233],
602
[1.695950, -0.357635, -0.893077, -0.445238]],
603
[[1.695951, -0.357636, -0.893069, -0.445246],
604
[1.695955, -0.357639, -0.893052, -0.445264]],
605
[[1.695952, -0.357636, -0.893066, -0.445249],
606
[1.695957, -0.357641, -0.893041, -0.445276]],
607
[[1.695946, -0.357632, -0.893095, -0.445220],
608
[1.695952, -0.357637, -0.893065, -0.445251]]]
610
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
611
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
613
model = nn.TransformerEncoder(encoder_layer, 6, norm=norm,
614
enable_nested_tensor=enable_nested_tensor).to(device)
617
result = model(encoder_input, src_key_padding_mask=mask)
618
ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265],
619
[1.695955, -0.357639, -0.893051, -0.445265]],
620
[[1.695955, -0.357639, -0.893051, -0.445265],
621
[1.695955, -0.357639, -0.893051, -0.445265]],
622
[[1.695955, -0.357639, -0.893051, -0.445265],
623
[1.695955, -0.357639, -0.893051, -0.445265]],
624
[[1.695955, -0.357639, -0.893051, -0.445265],
625
[1.695955, -0.357639, -0.893051, -0.445265]],
626
[[1.695955, -0.357639, -0.893051, -0.445265],
627
[1.695955, -0.357639, -0.893051, -0.445265]]]
629
self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
630
torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
636
with set_default_dtype(torch.double):
638
cm = contextlib.nullcontext()
642
_test(batch_first, training, enable_nested_tensor)
644
@unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
645
def test_encoder_padding_and_src_mask_bool(self):
646
encoder_layer = nn.TransformerEncoderLayer(
654
encoder_norm = nn.LayerNorm(16)
655
encoder = nn.TransformerEncoder(
656
encoder_layer, 2, encoder_norm
659
inputs = torch.randn(2, 3, 16)
661
src_mask = torch.ones(3, 3, dtype=torch.bool).triu_(diagonal=1)
662
input_seq_len = torch.tensor([3, 2])
664
torch.arange(3)[None, :].cpu() >= input_seq_len[:, None]
667
with (self.assertNoLogs(None) if not TEST_WITH_TORCHDYNAMO else contextlib.nullcontext()):
671
src_key_padding_mask=padding_mask,
674
@unittest.skipIf(sys.version_info < (3, 11), "not supported on pre-3.11 Python")
675
def test_decoder_padding_and_src_mask_bool(self):
677
def transformer_decoder(inputs, input_seq_len, memory):
678
decoder_layer = nn.TransformerDecoderLayer(
686
decoder_norm = nn.LayerNorm(16)
687
decoder = nn.TransformerDecoder(
688
decoder_layer, 2, decoder_norm
691
src_mask = torch.ones(
692
inputs.shape[1], inputs.shape[1], dtype=torch.bool
695
torch.arange(inputs.shape[1])[None, :].cpu()
696
>= input_seq_len[:, None]
703
tgt_key_padding_mask=padding_mask,
704
memory_key_padding_mask=padding_mask,
707
inputs = torch.randn(2, 3, 16)
708
memory = torch.randn(2, 3, 16)
709
input_seq_len = torch.tensor([3, 2])
711
with self.assertNoLogs(None):
712
transformer_decoder(inputs, input_seq_len, memory)
714
def test_encoder_is_causal(self):
717
layer = torch.nn.TransformerEncoderLayer(d_model, 1, 6, batch_first=True)
719
x = torch.randn(1, 5, d_model)
720
unmasked_output = layer(x)
721
mask = torch.nn.Transformer.generate_square_subsequent_mask(x.size(1))
722
is_causal_output = layer(x, src_mask=mask, is_causal=True)
723
masked_output = layer(x, src_mask=mask)
725
self.assertEqual(masked_output, is_causal_output)
728
@parametrize("nb_heads", [1, 8])
729
@parametrize("bias", [True, False])
730
def test_mha_native_args(self, nb_heads, bias):
732
B, L, F = 8, 100, 128
735
use_pad_mask = (bias % 2) == 1
737
mha = nn.MultiheadAttention(
740
batch_first=batch_first,
745
ctx = torch.no_grad if fast_path else contextlib.nullcontext
747
x = torch.randn(B, L, F).cuda()
749
x = x.transpose(0, 1)
753
pad_mask = torch.zeros((B, L), dtype=torch.bool).cuda()
755
mha(query=x, key=x, value=x, key_padding_mask=pad_mask)
757
def test_kpm_mask_trailing_column_with_nested_tensor(self, device):
758
encoder_layer = nn.TransformerEncoderLayer(
766
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
768
x = torch.randn(10, 6, 256).to(device)
769
mask = torch.ones(6, 10)
771
mask = mask.bool().to(device)
772
out = transformer_encoder(src=x, src_key_padding_mask=mask)
773
self.assertEqual(out.shape[1], 6)
778
def test_with_nested_tensor_input(self, device):
779
encoder_layer = nn.TransformerEncoderLayer(
787
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
789
transformer_encoder.eval()
790
with torch.no_grad():
791
x = torch.randn(6, 10, 256).to(device)
792
mask = torch.ones(6, 10)
797
mask = mask.bool().to(device)
798
x = torch._nested_tensor_from_mask(x, mask.logical_not(), mask_check=False)
799
out = transformer_encoder(src=x, src_key_padding_mask=None)
801
self.assertEqual(out.is_nested, True)
805
def test_script_encoder_subclass(self, device):
806
class MyCustomLayer(nn.TransformerEncoderLayer):
809
encoder = nn.TransformerEncoder(
810
MyCustomLayer(d_model=256, nhead=8), num_layers=6
812
torch.jit.script(encoder)
816
def test_transformerencoderlayer_subclass(self, device):
817
class MyCustomLayer(nn.TransformerEncoderLayer):
826
model = MyCustomLayer(
829
dim_feedforward=dim_feedforward,
830
batch_first=True).to(device)
831
script_model = torch.jit.script(model)
833
src = torch.rand(batch_size, seqlen, d_model).to(device)
834
src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
836
torch.manual_seed(42)
837
result = model(src, src_mask=src_mask)
838
torch.manual_seed(42)
839
scripted_result = script_model(src, src_mask=src_mask)
840
self.assertEqual(result, scripted_result)
843
script_model = torch.jit.script(model)
845
with torch.no_grad():
846
result = model(src, src_mask=src_mask)
847
scripted_result = script_model(src, src_mask=src_mask)
848
self.assertEqual(result, scripted_result)
851
def test_transformerencoderlayer_subclass_model(self, device):
852
class MyCustomLayer(nn.TransformerEncoderLayer):
861
layer = MyCustomLayer(
864
dim_feedforward=dim_feedforward,
866
model = nn.TransformerEncoder(
869
script_model = torch.jit.script(model)
871
src = torch.rand(batch_size, seqlen, d_model).to(device)
872
src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
874
torch.manual_seed(42)
875
result = model(src, mask=src_mask)
876
torch.manual_seed(42)
877
scripted_result = script_model(src, mask=src_mask)
878
self.assertEqual(result, scripted_result)
881
script_model = torch.jit.script(model)
883
with torch.no_grad():
884
result = model(src, mask=src_mask)
885
scripted_result = script_model(src, mask=src_mask)
886
self.assertEqual(result, scripted_result)
890
@unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found")
891
def test_decoder_only_layer(self):
892
DEFAULT_PADDING_IDX = 0
894
class FairseqDecoder(torch.nn.Module):
903
normalize_before=False,
909
cfg = fairseq_transformer.TransformerConfig()
910
cfg.decoder.embed_dim = embed_dim
911
cfg.decoder.output_dim = embed_dim
912
cfg.decoder.attention_heads = attention_heads
913
cfg.decoder.ffn_embed_dim = ffn_embed_dim
914
cfg.dropout = dropout
915
cfg.decoder.normalize_before = normalize_before
916
cfg.decoder.layers = num_layers
918
cfg.no_token_positional_embeddings = True
919
cfg.no_scale_embedding = True
920
cfg.activation_fn = activation
924
self.decoder = fairseq_transformer.TransformerDecoder(
928
no_encoder_attn=True,
929
output_projection=None,
932
if torch_encoder is not None:
933
self.decoder = torch_to_fairseq(torch_encoder, self.decoder)
934
self.decoder = self.decoder.eval().cuda().half()
940
with_triangle_mask=False,
941
incremental_state=None,
944
prev_output_tokens=tokens,
946
incremental_state=incremental_state,
948
full_context_alignment=not with_triangle_mask,
949
alignment_layer=None,
950
alignment_heads=None,
951
src_lengths=src_lengths,
952
return_all_hiddens=False,
955
@parametrize("input_dim,attn_mask_dim,is_causal",
956
[(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
957
(4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
958
name_fn=lambda input_dim, attn_dim, is_causal: (
959
f"{input_dim}D_input_dim_" + (
960
f"{attn_dim}D_{'causal_' if is_causal else ''}attn_mask"
961
if attn_dim is not None else "no_attn_mask")))
962
@parametrize("dropout_p", [0.0, 0.2, 0.5])
963
@sdpa_kernel(backends=[SDPBackend.MATH])
964
def test_scaled_dot_product_attention(self, device, input_dim, attn_mask_dim, is_causal, dropout_p):
974
if attn_mask is not None:
975
attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
977
attn = torch.bmm(q, k.transpose(-2, -1))
979
attn = torch.nn.functional.softmax(attn, dim=-1)
981
attn = torch.nn.functional.dropout(attn, p=dropout_p)
983
output = torch.bmm(attn, v)
986
dtypes = [torch.double, torch.float]
989
def rand_tensor(*shape):
990
return torch.randn(shape, device=device, dtype=dtype)
993
N, N_prime, L, S, E = 5, 2, 4, 3, 6
995
query = rand_tensor(N, L, E)
996
key = rand_tensor(N, S, E)
997
value = rand_tensor(N, S, E)
999
query = rand_tensor(N, N_prime, L, E)
1000
key = rand_tensor(N, N_prime, S, E)
1001
value = rand_tensor(N, N_prime, S, E)
1003
self.fail(f'Invalid input_dim {input_dim} encountered in SDP test')
1006
if attn_mask_dim is not None:
1007
assert attn_mask_dim in [2, input_dim]
1008
mask_size = (L, S) if attn_mask_dim == 2 else ((N, L, S) if input_dim == 3 else (N, N_prime, L, S))
1009
attn_mask = (torch.ones(mask_size, device=device, dtype=torch.bool).tril() if is_causal
1010
else torch.randint(0, 2, size=mask_size, device=device, dtype=torch.bool))
1012
with freeze_rng_state():
1014
attn_mask_float = attn_mask
1015
if attn_mask_float is not None:
1016
attn_mask_float = torch.zeros_like(attn_mask, dtype=query.dtype)
1017
attn_mask_float.masked_fill_(attn_mask.logical_not(), float("-inf"))
1018
q, k, v = query.view(-1, L, E), key.view(-1, S, E), value.view(-1, S, E)
1020
if a is not None and attn_mask_dim > 3:
1021
a = a.view(-1, L, S)
1022
expected = sdp_ref(q, k, v, attn_mask=a, dropout_p=dropout_p)
1024
expected = expected.view(-1, N_prime, L, E)
1026
with freeze_rng_state():
1029
actual = torch.nn.functional.scaled_dot_product_attention(
1030
query, key, value, None, dropout_p, is_causal)
1033
with self.assertRaisesRegex(RuntimeError,
1034
"Explicit attn_mask should not be set when is_causal=True"):
1035
torch.nn.functional.scaled_dot_product_attention(
1036
query, key, value, attn_mask, dropout_p, is_causal)
1038
actual = torch.nn.functional.scaled_dot_product_attention(
1039
query, key, value, attn_mask, dropout_p, is_causal)
1041
self.assertEqual(actual, expected)
1043
if attn_mask_dim is None:
1044
q = q.double().clone()
1045
k = k.double().clone()
1046
v = v.double().clone()
1051
assert gradcheck(lambda *args, **kwargs: wrapper_set_seed(sdp_ref, *args, **kwargs),
1052
(q, k, v, attn_mask, dropout_p))
1053
assert gradcheck(lambda *args, **kwargs:
1054
wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
1055
(q, k, v, attn_mask, dropout_p))
1057
def test_incompatible_mask(self, device):
1058
def ones_tensor(*shape):
1059
return torch.ones(shape, dtype=torch.float32)
1060
S, L, E, H = 1, 2, 4, 1
1061
qkv = ones_tensor(S, L, E)
1063
mha = nn.MultiheadAttention(E, H)
1064
mha.in_proj_weight = Parameter(torch.ones((E * 3, E)))
1065
mha.out_proj.weight = Parameter(torch.ones((E, E)))
1067
kpm = ones_tensor(S, L) * float("-inf")
1068
am = ones_tensor(L, L).to(bool)
1071
return mha(qkv, qkv, qkv, need_weights=False, key_padding_mask=kpm, attn_mask=am)
1073
self.assertRaises(RuntimeError, func)
1075
@unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
1077
def test_mask_check_fastpath(self):
1079
Test that fastpath is executed independently of the masks that are passed.
1080
If the passed key padding mask is left aligned or mask_check=False, test that nested tensors are used
1081
(sparsity fastpath), otherwise use fastpath with traditional tensors.
1082
Also test that fast path is executed with both key padding mask and attention mask passed at the same time.
1085
x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float)
1087
def _test_fastpath(model, key_padding_mask, mock_return_value, attn_mask=None, nested_tensors=True):
1088
with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
1089
fastpath_mock.return_value = mock_return_value
1090
model(x, src_key_padding_mask=key_padding_mask, mask=attn_mask)
1093
self.assertTrue(fastpath_mock.called)
1096
for call_args, _ in fastpath_mock.call_args_list:
1097
self.assertEqual(call_args[0].is_nested, nested_tensors)
1099
encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
1101
model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
1104
aligned_key_padding_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool)
1105
not_aligned_key_padding_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool)
1106
attn_mask = torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]).to(torch.bool)
1107
nested_tensor_return_value = torch.nested.nested_tensor([torch.ones((2, 2), dtype=torch.float)])
1108
tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float)
1111
_test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1114
_test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1116
model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True)
1120
_test_fastpath(model, aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1121
_test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False)
1123
_test_fastpath(model, aligned_key_padding_mask, tensor_return_value, attn_mask=attn_mask, nested_tensors=False)
1125
model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False)
1129
_test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1130
_test_fastpath(model, not_aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True)
1133
def test_bias_is_none(self):
1134
x = torch.rand((1, 5, 10))
1135
model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True)
1140
def test_transformer_bias_is_none(self, device):
1146
encoder_layer = torch.nn.TransformerEncoderLayer(d_model, nhead, bias=False, batch_first=True, device=device)
1147
encoder_layer.eval()
1148
x = torch.randn(batch_size, seqlen, d_model, device=device)
1152
with self.assertWarnsRegex(UserWarning, "encoder_layer.self_attn was passed bias=False"):
1153
encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=1).eval()
1156
with self.assertWarnsRegex(UserWarning, "self_attn was passed bias=False"):
1157
transformer = torch.nn.Transformer(
1158
d_model=d_model, nhead=nhead, bias=False, batch_first=True, device=device
1162
def test_train_with_is_causal(self, device):
1164
S, L, E, H = 1, 2, 2, 1
1165
layer = nn.TransformerEncoderLayer(
1173
criterion = nn.MSELoss()
1174
encoder = nn.TransformerEncoder(layer, 2).to(device)
1175
optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
1179
optimizer.zero_grad()
1180
inputs = torch.randn(S, L, E).to(device)
1181
mask = torch.nn.Transformer.generate_square_subsequent_mask(
1182
inputs.size(1), device=device
1185
outputs = encoder(inputs, mask=mask, is_causal=True)
1187
loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
1192
t_qvk = torch.randn((S, L, E), device=device, dtype=torch.float32)
1193
mha = nn.MultiheadAttention(E, H).to(device)
1194
mask = torch.nn.Transformer.generate_square_subsequent_mask(
1198
attn_out, _ = mha(t_qvk, t_qvk, t_qvk, attn_mask=mask, is_causal=True)
1201
attn_mask = torch.randint(0, 2, size=(L, L), device=device, dtype=torch.bool)
1202
with self.assertRaises(RuntimeError):
1203
_ = mha(t_qvk, t_qvk, t_qvk, is_causal=True)
1206
causal_mask = torch.triu(
1207
torch.ones(L, L, device=inputs.device) * float('-inf'), diagonal=1
1210
mock_layer = MagicMock(torch.nn.MultiheadAttention(E, H), return_value=inputs)
1211
encoder.layers[1] = mock_layer
1212
outputs = encoder(inputs, mask=causal_mask)
1213
mock_layer.assert_called_with(ANY, src_mask=ANY, is_causal=True, src_key_padding_mask=ANY)
1216
self.is_causal_kernels([SDPBackend.MATH], device)
1218
def is_causal_kernels(self, kernels, device):
1219
def ones_tensor(*shape):
1220
return torch.ones(shape, device=device, dtype=torch.float32).to(device)
1221
S, L, E, H = 1, 2, 4, 1
1222
qkv = ones_tensor(S, L, E)
1224
mha = nn.MultiheadAttention(E, H).to(device)
1225
mha.in_proj_weight = Parameter(torch.ones((E * 3, E), device=device))
1226
mha.out_proj.weight = Parameter(torch.ones((E, E), device=device))
1227
expected = torch.ones(size=(S, L, E)).to(device) * 16
1228
mask = torch.nn.Transformer.generate_square_subsequent_mask(
1229
qkv.size(1), device=device
1232
for kernel in kernels:
1233
with sdpa_kernel(backends=[kernel]):
1234
actual, _ = mha(qkv, qkv, qkv, attn_mask=mask, need_weights=False, is_causal=True)
1235
self.assertTrue(torch.equal(actual, expected))
1237
if kernel != SDPBackend.MATH:
1239
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1240
qkv_f, mha_f = ones_tensor(S, L, 2), nn.MultiheadAttention(2, H).to(device)
1241
mask = torch.nn.Transformer.generate_square_subsequent_mask(
1242
qkv_f.size(1), device=device
1244
_ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
1245
torch.cuda.synchronize()
1249
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
1251
def test_is_causal_gpu(self):
1253
self.is_causal_kernels([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION], device)
1255
def test_script_mha_in_proj_weight_none(self):
1256
mha = torch.nn.MultiheadAttention(
1257
embed_dim=128, num_heads=8, kdim=256, vdim=256
1260
torch.jit.script(mha)
1262
@unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
1264
def test_disable_fastpath(self, device):
1265
def _test_te_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True):
1268
with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
1269
fastpath_mock.return_value = return_value
1270
output = model(*args, **kwargs)
1271
self.assertTrue(fastpath_mock.called == is_called)
1273
def _test_mha_fastpath_called(model, args, kwargs=None, return_value=None, is_called=True):
1276
with patch('torch._native_multi_head_attention') as fastpath_mock:
1277
fastpath_mock.return_value = return_value
1278
output = model(*args, **kwargs)
1279
self.assertTrue(fastpath_mock.called == is_called)
1281
inp = torch.tensor([[[1, 2], [3, 4], [5, 6]]], dtype=torch.float32, device=device)
1282
aligned_key_padding_mask = torch.tensor([[0, 0, 1]], dtype=torch.bool, device=device)
1283
src_key_padding_mask = torch.tensor([[1, 0, 1]], dtype=torch.bool, device=device)
1284
attn_mask = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]], dtype=torch.bool, device=device)
1285
te_return_value = torch.ones((1, 3, 2), dtype=torch.float32)
1287
encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
1288
te = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
1289
te = te.to(device).eval()
1291
t = torch.nn.Transformer(d_model=2, nhead=2, batch_first=True, device=device).eval()
1292
src = torch.tensor([[[0, 1], [2, 3], [4, 5]]], dtype=torch.float32, device=device)
1293
tgt = torch.tensor([[[0, 1], [2, 3], [4, 5], [6, 7]]], dtype=torch.float32, device=device)
1294
t_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device)
1296
mha = nn.MultiheadAttention(2, 2, batch_first=True, device=device).eval()
1297
q = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.float32, device=device)
1298
mha_return_value = torch.ones((1, 3, 2), dtype=torch.float32, device=device)
1300
_test_te_fastpath_called(
1301
te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1302
return_value=te_return_value, is_called=True
1304
_test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True)
1305
_test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True)
1307
torch.backends.mha.set_fastpath_enabled(False)
1308
_test_te_fastpath_called(
1309
te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1310
return_value=te_return_value, is_called=False
1312
_test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=False)
1313
_test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=False)
1315
torch.backends.mha.set_fastpath_enabled(True)
1316
_test_te_fastpath_called(
1317
te, (inp,), kwargs={'src_key_padding_mask': src_key_padding_mask},
1318
return_value=te_return_value, is_called=True
1320
_test_te_fastpath_called(t, (src, tgt), return_value=t_return_value, is_called=True)
1321
_test_mha_fastpath_called(mha, (q, q, q,), return_value=mha_return_value, is_called=True)
1324
class TestSDPAFailureModes(NNTestCase):
1325
""" Used to test the failure modes of scaled_dot_product_attention
1327
_do_cuda_memory_leak_check = True
1328
_do_cuda_non_default_stream = True
1331
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
1332
"Does not support fused SDPA or not SM86+ hardware")
1333
@parametrize("head_dim", [193, 204, 256])
1334
def test_flash_backward_failure_sm86plus(self, device, head_dim: int):
1335
dtype = torch.float16
1336
make_tensor = partial(torch.rand, device=device, dtype=dtype)
1338
size = (2, 2, 4, head_dim)
1339
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1341
with sdpa_kernel(backends=[SDPBackend.MATH]):
1342
math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
1344
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1346
flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
1348
self.assertEqual(math_ref, flash_ref, atol=1e-3, rtol=1e-3)
1351
q = make_tensor(size, requires_grad=True)
1352
k = make_tensor(size, requires_grad=True)
1353
v = make_tensor(size, requires_grad=True)
1354
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1355
q, k, v, None, 0.0, False))
1358
def test_dispatch_fails_no_backend(self, device):
1359
dtype = torch.float16
1360
with sdpa_kernel(backends=[SDPBackend.ERROR]):
1362
q = torch.randn(size, device=device, dtype=dtype)
1363
k = torch.randn(size, device=device, dtype=dtype)
1364
v = torch.randn(size, device=device, dtype=dtype)
1365
self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
1366
lambda: torch._fused_sdp_choice(q, k, v))
1367
self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
1368
lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v))
1371
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1374
PLATFORM_SPECIFIC_SDPA,
1376
def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend):
1377
with sdpa_kernel(backends=[kernel]):
1380
dtype = torch.float16
1381
q = torch.randn(size, device=device, dtype=dtype)
1382
k = torch.randn(size, device=device, dtype=dtype)
1383
v = torch.randn(size, device=device, dtype=dtype)
1384
with self.assertWarnsRegex(UserWarning, "Both fused kernels requires query, key and value to be 4 dimensional"):
1385
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1386
q, k, v, None, 0.0, False))
1389
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1392
PLATFORM_SPECIFIC_SDPA,
1394
def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend):
1395
with sdpa_kernel(backends=[kernel]):
1397
dtype = torch.float16
1399
size_broadcast = (1, 4, 3, 8)
1400
q = torch.randn(size_broadcast, device=device, dtype=dtype)
1401
k = torch.randn(size, device=device, dtype=dtype)
1402
v = torch.randn(size, device=device, dtype=dtype)
1403
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1404
q, k, v, None, 0.0, False))
1407
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1408
@parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
1409
def test_invalid_sequence_lengths(self, device, kernel: SDPBackend):
1410
with sdpa_kernel(backends=[kernel]):
1412
dtype = torch.float16
1413
make_tensor = partial(torch.rand, device=device, dtype=dtype)
1414
size = SdpaShape(2, 2, 0, 8)
1415
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1416
with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support zero seq_len_q or seq_len_kv."):
1417
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1418
q, k, v, None, 0.0, False))
1421
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1422
@parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
1423
def test_invalid_last_dim_stride(self, device, kernel: SDPBackend):
1424
with sdpa_kernel(backends=[kernel]):
1426
dtype = torch.float16
1427
make_tensor = partial(torch.rand, device=device, dtype=dtype)
1428
size = SdpaShape(2, 2, 8, 8)
1429
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1430
q.as_strided_(size, [2, 2, 2, 2])
1431
with self.assertWarnsRegex(UserWarning, "Both fused kernels require the last dimension of the input to have stride 1."):
1432
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1433
q, k, v, None, 0.0, False))
1436
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention")
1437
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1438
def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
1439
with sdpa_kernel(backends=[kernel]):
1441
dtype = torch.float16
1442
make_tensor = partial(torch.rand, device=device, dtype=dtype)
1443
size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257)
1444
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1445
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1446
q, k, v, None, 0.0, False))
1449
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
1452
PLATFORM_SPECIFIC_SDPA,
1454
def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend):
1455
with sdpa_kernel(backends=[kernel]):
1457
size = SdpaShape(2, 2, 3, 16)
1458
make_tensor = partial(torch.rand, device=device, dtype=torch.float64)
1459
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1460
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1461
q, k, v, None, 0.0, False))
1464
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
1465
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION])
1466
def test_invalid_fused_inputs_attn_mask_present(self, device, kernel: SDPBackend):
1467
with sdpa_kernel(backends=[kernel]):
1469
size = SdpaShape(2, 2, 3, 16)
1470
make_tensor = partial(torch.rand, size, device=device, dtype=torch.float16)
1471
q, k, v = make_tensor(), make_tensor(), make_tensor()
1473
mask = torch.ones((2, 2, 3, 3), device=device, dtype=q.dtype)
1474
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1475
q, k, v, mask, 0.0, False))
1478
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
1479
def test_unaligned_tensors(self, device):
1481
dtype = torch.float16
1482
size = SdpaShape(2, 2, 8, 5)
1483
make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1484
q, k, v = make_tensor(), make_tensor(), make_tensor()
1485
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
1486
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1487
q, k, v, None, 0.0, False))
1490
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
1491
def test_flash_fail_fp32(self, device):
1493
size = SdpaShape(16, 16, 32, 32)
1494
make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1495
q, k, v = make_tensor(), make_tensor(), make_tensor()
1496
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1497
with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, BFloat16}"):
1498
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1499
q, k, v, None, 0.0, False))
1502
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
1503
def test_flash_autocast_fp32_float16(self, device):
1505
size = SdpaShape(16, 16, 32, 32)
1506
make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1507
q, k, v = make_tensor(), make_tensor(), make_tensor()
1508
with torch.autocast(device_type='cuda', dtype=torch.float16):
1509
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1510
_ = torch.nn.functional.scaled_dot_product_attention(
1511
q, k, v, None, 0.0, False)
1514
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
1515
def test_flash_autocast_fp32_bfloat16(self, device):
1517
size = SdpaShape(16, 16, 32, 32)
1518
make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1519
q, k, v = make_tensor(), make_tensor(), make_tensor()
1520
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
1521
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1522
_ = torch.nn.functional.scaled_dot_product_attention(
1523
q, k, v, None, 0.0, False)
1526
@parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1527
def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend):
1528
with sdpa_kernel(backends=[kernel]):
1530
shape = (1, 4, 8, 16)
1531
query = torch.randn(shape, dtype=torch.float32, device=device)
1532
key = torch.randn(shape, dtype=torch.float16, device=device)
1533
value = torch.randn(shape, dtype=torch.float16, device=device)
1534
self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
1537
@parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1538
def test_invalid_inputs_different_devices(self, device, kernel: SDPBackend):
1540
shape = (1, 4, 8, 16)
1541
query = torch.randn(shape, dtype=torch.float32, device=device)
1542
key = torch.randn(shape, dtype=torch.float16, device='cpu')
1543
value = torch.randn(shape, dtype=torch.float16, device='cpu')
1544
self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
1546
@parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
1547
def test_invalid_inputs_1_dimensional_inputs(self, device, kernel: SDPBackend):
1548
with sdpa_kernel(backends=[kernel]):
1551
query = torch.randn(4, dtype=torch.float16, device=device)
1552
key = torch.randn(shape, dtype=torch.float16, device=device)
1553
value = torch.randn(shape, dtype=torch.float16, device=device)
1554
self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
1558
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
1559
def test_fused_kernels_nested_broadcasting_error_cases(self, device):
1561
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
1562
batch, num_heads, head_dim = 32, 8, 64
1563
seq_lens_q = torch.randint(low=1, high=32, size=(batch,)).tolist()
1564
seq_lens_v = torch.randint(low=1, high=32, size=(batch,)).tolist()
1566
q_shape = SdpaShape(batch, num_heads, seq_lens_q, head_dim)
1567
k_shape = SdpaShape(1, num_heads, 1, head_dim)
1568
v_shape = SdpaShape(batch, num_heads, seq_lens_v, head_dim)
1570
query = rand_nested_tensor(q_shape).transpose(1, 2)
1571
key = rand_nested_tensor(k_shape).transpose(1, 2)
1572
value = rand_nested_tensor(v_shape).transpose(1, 2)
1574
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
1575
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1576
torch.nn.functional.scaled_dot_product_attention(
1577
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
1580
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
1581
def test_nested_fails_on_padding_head_dim(self, device):
1582
dtype = torch.bfloat16
1583
seq_len_list = [2, 4, 5, 6, 7]
1584
shape = SdpaShape(5, 8, seq_len_list, 57)
1585
make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype)
1586
q, k, v = make_tensor(), make_tensor(), make_tensor()
1587
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1588
with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"):
1589
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1590
q, k, v, None, 0.0, False))
1594
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device,
1595
"Current platform does not support fused SDPA or is an SM80+ device.")
1596
def test_mem_efficient_fail_bfloat16_less_than_sm80(self, device):
1597
dtype = torch.bfloat16
1598
size = SdpaShape(16, 16, 32, 32)
1599
make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
1600
q, k, v = make_tensor(), make_tensor(), make_tensor()
1601
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
1602
with self.assertWarnsRegex(UserWarning, "Expected query, key and value to all be of dtype: {Half, Float}"):
1603
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1604
q, k, v, None, 0.0, False))
1607
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
1608
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
1609
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
1610
def test_fused_kernels_seq_len_0_inputs(self, device, fused_kernel):
1611
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
1612
batch, num_heads, head_dim = 32, 16, 64
1613
seq_lens = torch.randint(low=1, high=32, size=(batch,))
1616
indices = torch.randint(low=0, high=batch, size=(num_zeros,))
1617
seq_lens.scatter_(0, indices, 0)
1619
shape = SdpaShape(batch, num_heads, seq_lens.tolist(), head_dim)
1620
query = rand_nested_tensor(shape)
1621
key = rand_nested_tensor(shape)
1622
value = rand_nested_tensor(shape)
1624
query = query.transpose(1, 2)
1625
key = key.transpose(1, 2)
1626
value = value.transpose(1, 2)
1628
with sdpa_kernel(backends=[fused_kernel]):
1629
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1630
torch.nn.functional.scaled_dot_product_attention(
1631
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
1634
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
1635
def test_fused_kernels_nested_broadcasting_requires_grad_failure(self, device):
1636
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16, requires_grad=True)
1637
batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 64
1638
seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist()
1639
q_shape = SdpaShape(1, num_heads, 1, head_dim)
1640
k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim)
1641
v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v)
1644
query = torch.randn(q_shape, device=device, dtype=torch.float16, requires_grad=True)
1645
key = rand_nested_tensor(k_shape)
1646
value = rand_nested_tensor(v_shape)
1648
query = query.transpose(1, 2)
1649
key = key.transpose(1, 2)
1650
value = value.transpose(1, 2)
1652
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1653
with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"):
1654
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
1655
out = torch.nn.functional.scaled_dot_product_attention(
1656
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
1659
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
1660
def test_flash_attention_fail_with_non_square_causal_attention(self, device):
1661
dtype = torch.bfloat16
1662
q_shape = SdpaShape(1, 1, 8, 16)
1663
kv_shape = SdpaShape(1, 1, 12, 16)
1664
make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
1665
make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
1666
q, k, v = make_q(), make_kv(), make_kv()
1667
warning_str = "Flash attention does not support the is_causal flag when seqlen_q != seqlen_k."
1668
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
1669
with self.assertWarnsRegex(UserWarning, warning_str):
1670
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
1671
q, k, v, None, 0.0, is_causal=True))
1673
def _get_block_size(device, head_dim, is_causal):
1677
assert head_dim <= 256
1678
major, minor = torch.cuda.get_device_capability(device)
1679
is_sm8x = major == 8 and minor > 0
1680
is_sm80 = major == 8 and minor == 0
1681
is_sm90 = major == 9 and minor == 0
1685
return (128, 128) if not is_dropout else (128, 64)
1686
elif head_dim <= 96:
1687
return (64, 64) if (is_sm8x and is_causal) else (128, 64)
1688
elif head_dim <= 128:
1690
return (64, 64) if (not is_dropout and is_causal) else (128, 32)
1692
return 128, (64 if not is_dropout else 32)
1693
elif head_dim <= 160:
1695
return (128, 64) if not is_causal else (64, 64)
1698
elif head_dim <= 192:
1699
return (128, 64) if not is_dropout else (64, 64)
1700
elif head_dim <= 224:
1701
return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
1702
elif head_dim <= 256:
1703
return (128, 64) if is_sm80 else (64, 64)
1706
def pad_last_dim(input_tensor, alignment_size, slice: bool = False):
1707
last_dim_size = input_tensor.size(-1)
1708
if (last_dim_size % alignment_size == 0):
1709
return input_tensor, last_dim_size
1710
pad_count = alignment_size - (last_dim_size % alignment_size)
1711
padded_tensor = F.pad(input_tensor, (0, pad_count))
1713
return padded_tensor[..., :last_dim_size], last_dim_size
1714
return padded_tensor, last_dim_size
1717
class TestSDPA(NNTestCase):
1718
""" Used to test generic functionality of scaled_dot_product_attention
1720
If you are adding a new test to this class, make sure that it runs
1721
for both cpu and cuda. If you're test is only applicable to cuda,
1722
add it to TestSDPACudaOnly.
1724
@parametrize("contiguous_inputs", [True, False])
1725
def test_sdp_math_gradcheck(self, device, contiguous_inputs: bool):
1727
batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
1728
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
1729
make_tensor = partial(rand_sdpa_tensor, type="dense", device=device,
1730
dtype=torch.float64, requires_grad=True, packed=True)
1732
qkv = make_tensor(shape)
1733
query, key, value = qkv.chunk(3, dim=-1)
1735
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1736
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1737
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
1739
if contiguous_inputs:
1740
query = query.contiguous()
1741
key = key.contiguous()
1742
value = value.contiguous()
1744
with sdpa_kernel(backends=[SDPBackend.MATH]):
1745
assert gradcheck(lambda *args, **kwargs:
1746
wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),
1747
(query, key, value, None, 0.0, False)
1751
@parametrize("type", ["dense", "nested"])
1752
@parametrize("dropout", [0.0, 0.7])
1753
@parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half])
1754
def test_fused_sdp_choice_cpu(self, device, type: str, dropout: float, dtype: torch.dtype):
1756
make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype)
1757
size = SdpaShape(2, 8, 128, 64)
1758
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1759
if type == "nested" \
1761
or dtype not in [torch.float32, torch.float64, torch.bfloat16, torch.float16]:
1762
assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.MATH.value
1764
assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.FLASH_ATTENTION.value
1767
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION])
1768
@parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16])
1769
@parametrize("batch_size", [2, 12])
1770
@parametrize("seq_len", [267, 1030])
1771
@parametrize("n_head", [1, 3])
1772
@parametrize("head_dim", [8, 16])
1773
@parametrize("causal", [True, False])
1774
@parametrize("train", [True, False])
1775
def test_scaled_dot_product_fused_attention_vs_math_cpu(
1789
if dtype is torch.bfloat16:
1792
if dtype is torch.float16:
1796
n_embd = n_head * head_dim
1797
make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, packed=True, requires_grad=False)
1798
shape = SdpaShape(batch_size, n_head, seq_len, head_dim)
1799
x = make_tensor(shape)
1803
x.requires_grad_(True)
1804
x2.requires_grad_(True)
1806
q, k, v = x.split(n_embd, dim=2)
1807
q2, k2, v2 = x2.split(n_embd, dim=2)
1809
if dtype in [torch.bfloat16, torch.float16]:
1815
k = k.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1816
q = q.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1817
v = v.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1818
k2 = k2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1819
q2 = q2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1820
v2 = v2.view(batch_size, seq_len, n_head, head_dim).transpose(1, 2)
1822
with sdpa_kernel(backends=[fused_kernel]):
1823
actual = torch.nn.functional.scaled_dot_product_attention(
1824
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal)
1825
with sdpa_kernel(backends=[SDPBackend.MATH]):
1826
math_ref = torch.nn.functional.scaled_dot_product_attention(
1827
q2, k2, v2, attn_mask=None, dropout_p=0.0, is_causal=causal)
1829
if dtype in [torch.bfloat16, torch.float16]:
1830
math_ref = math_ref.to(dtype)
1832
self.assertEqual(actual, math_ref, atol=atol, rtol=rtol)
1835
actual.sum().backward()
1836
math_ref.sum().backward()
1838
grad_x, grad_x2 = x.grad, x2.grad
1839
grad_q_actual, grad_k_actual, grad_v_actual = grad_x.split(n_embd, dim=2)
1840
grad_q_ref, grad_k_ref, grad_v_ref = grad_x2.split(n_embd, dim=2)
1842
self.assertEqual(grad_q_actual, grad_q_ref, atol=atol, rtol=rtol)
1843
self.assertEqual(grad_k_actual, grad_k_ref, atol=atol, rtol=rtol)
1844
self.assertEqual(grad_v_actual, grad_v_ref, atol=atol, rtol=rtol)
1847
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION])
1848
@parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.float16])
1849
@parametrize("batch_size", [2, 12])
1850
@parametrize("q_seq_len", [267, 1030])
1851
@parametrize("kv_seq_len", [514, 1179])
1852
@parametrize("n_head", [1, 3])
1853
@parametrize("head_dim", [8, 16])
1854
@parametrize("mask_dim", [2, 4])
1855
@parametrize("bool_mask", [0, 1])
1856
@parametrize("train", [True, False])
1857
def test_scaled_dot_product_fused_attention_mask_vs_math_cpu(
1871
tol = Tolerances(1e-5, 5e-6)
1872
if dtype is torch.bfloat16:
1873
tol = Tolerances(5e-2, 5e-2)
1874
if dtype is torch.float16:
1875
tol = Tolerances(1e-2, 1e-2)
1877
make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
1878
q_shape = SdpaShape(batch_size, n_head, q_seq_len, head_dim)
1879
kv_shape = SdpaShape(batch_size, n_head, kv_seq_len, head_dim)
1880
q = make_tensor(q_shape)
1881
k = make_tensor(kv_shape)
1882
v = make_tensor(kv_shape)
1883
q2, k2, v2 = q.clone(), k.clone(), v.clone()
1886
q.requires_grad_(True)
1887
k.requires_grad_(True)
1888
v.requires_grad_(True)
1889
q2.requires_grad_(True)
1890
k2.requires_grad_(True)
1891
v2.requires_grad_(True)
1893
if dtype in [torch.bfloat16, torch.float16]:
1894
q2, k2, v2 = q2.float(), k2.float(), v2.float()
1896
q = q.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
1897
k = k.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1898
v = v.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1900
mask_shape = (batch_size, n_head, q_seq_len, kv_seq_len)
1902
mask_shape = (q_seq_len, kv_seq_len)
1904
attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
1906
attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
1907
q2 = q2.view(batch_size, q_seq_len, n_head, head_dim).transpose(1, 2)
1908
k2 = k2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1909
v2 = v2.view(batch_size, kv_seq_len, n_head, head_dim).transpose(1, 2)
1911
with sdpa_kernel(backends=[fused_kernel]):
1912
actual = torch.nn.functional.scaled_dot_product_attention(
1913
q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
1914
with sdpa_kernel(backends=[SDPBackend.MATH]):
1915
if not bool_mask and dtype in [torch.bfloat16, torch.float16]:
1916
attn_mask = attn_mask.float()
1917
math_ref = torch.nn.functional.scaled_dot_product_attention(
1918
q2, k2, v2, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
1920
if dtype in [torch.bfloat16, torch.float16]:
1921
math_ref = math_ref.to(dtype)
1923
self.assertEqual(actual, math_ref, atol=tol.atol, rtol=tol.rtol)
1926
actual.sum().backward()
1927
math_ref.sum().backward()
1929
grad_q_actual, grad_k_actual, grad_v_actual = q.grad, k.grad, v.grad
1930
grad_q_ref, grad_k_ref, grad_v_ref = q2.grad, k2.grad, v2.grad
1932
self.assertEqual(grad_q_actual, grad_q_ref, atol=tol.atol, rtol=tol.rtol)
1933
self.assertEqual(grad_k_actual, grad_k_ref, atol=tol.atol, rtol=tol.rtol)
1934
self.assertEqual(grad_v_actual, grad_v_ref, atol=tol.atol, rtol=tol.rtol)
1936
@parametrize("kernel", [SDPBackend.MATH])
1937
def test_scaled_dot_product_attention_math_with_negative_scale(self, device, kernel: SDPBackend):
1940
v1 = torch.matmul(x, x.transpose(-1, -2))
1942
v3 = v2.softmax(dim=-1)
1943
v4 = torch.matmul(v3, x)
1946
x = torch.randn(1, 3, 64, 64, device=device)
1948
with sdpa_kernel(backends=[kernel]):
1949
sdp_math = torch.nn.functional.scaled_dot_product_attention(x, x, x, scale=-1.0 / 0.0001)
1950
self.assertEqual(ref_result, sdp_math)
1952
class TestSDPACudaOnly(NNTestCase):
1953
""" Used to test CUDA only functionality of scaled_dot_product_attention
1955
There is some trickiness with this function. Its runtime behavior
1956
is dependent on the CUDA architecture you are testing it on. See
1957
`PLATFORM_SUPPORTS_FUSED_ATTENTION` at the top of the file.
1959
Math: always supported
1960
FlashAttention: Supported on sm80 or newer hardware
1961
MemEfficientAttention: Supported on sm50 or newer hardware
1963
_do_cuda_memory_leak_check = True
1964
_do_cuda_non_default_stream = True
1966
def convert_flash_attn_S_to_softmax(self, S, query_padding_mask, key_padding_mask, head_dim, causal=False):
1967
"""FlashAttention stores the S matrix in a different way.
1969
S: (batch_size, nheads, seqlen_q, seqlen_k)
1970
query_padding_mask: (batch_size, seqlen_q)
1971
key_padding_mask: (batch_size, seqlen_k)
1976
b, h, seqlen_q, seqlen_k = S.shape
1978
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, causal)
1979
nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m
1980
nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n
1981
mmas_n = (blocksize_n + 16 - 1) // 16
1984
S_flat = S.view(b, h, nblocks_m, blocksize_m, nblocks_n, blocksize_n)
1985
S_flat = S_flat.permute(0, 1, 2, 4, 3, 5)
1986
S_flat = S_flat.reshape(b, h, nblocks_m, nblocks_n, (blocksize_m * blocksize_n))
1987
S_converted = S_flat.view(b, h, nblocks_m, nblocks_n, mmas_n, -1, warps_n, 8, 4, 2, 2, 2)
1988
S_converted = S_converted.permute(0, 1, 2, 5, 6, 10, 7, 3, 4, 9, 8, 11)
1989
S_converted = S_converted.reshape(b, h, (nblocks_m * S_converted.size(3) *
1990
warps_n * 2 * 8), (nblocks_n * mmas_n * 2 * 4 * 2))
1993
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1)
1994
S_converted.masked_fill_(causal_mask, 0.0)
1997
seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q
1998
if query_padding_mask is not None:
1999
if seqlen_q_og < seqlen_q:
2000
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
2002
query_padding_mask = query_padding_mask[:, :seqlen_q]
2003
q_mask_fill = ~query_padding_mask.view(query_padding_mask.shape[0], 1, query_padding_mask.shape[1], 1)
2004
S_converted = S_converted.masked_fill(q_mask_fill, 0.0)
2005
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
2006
if key_padding_mask is not None:
2007
if seqlen_k_og < seqlen_k:
2008
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
2010
key_padding_mask = key_padding_mask[:, :seqlen_k]
2011
k_mask_fill = ~key_padding_mask.view(key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1])
2012
S_converted = S_converted.masked_fill(k_mask_fill, 0.0)
2013
if seqlen_q_og < seqlen_q:
2014
S_converted = S_converted[:, :, :seqlen_q_og, :]
2016
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q))
2017
if seqlen_k_og < seqlen_k:
2018
S_converted = S_converted[:, :, :, :seqlen_k_og]
2020
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k))
2023
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2024
@parametrize("mask_dim", [1, 2, 3, 4])
2025
def test_mem_efficient_attetntion_mask_variants(self, device, mask_dim: List[int]):
2026
dtype = torch.float16
2027
make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
2028
batch, num_heads, head_dim = 8, 8, 64
2029
seq_len_q, seq_len_kv = 64, 32
2030
query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
2031
kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
2032
key, value = make_tensor(kv_shape), make_tensor(kv_shape)
2035
mask = torch.randn((seq_len_kv,), device=device, dtype=dtype)
2037
mask = torch.randn((seq_len_q, seq_len_kv), device=device, dtype=dtype)
2039
mask = torch.randn((num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2041
mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
2042
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2043
out = F.scaled_dot_product_attention(query, key, value, mask)
2044
out.sum().backward()
2046
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("dtype", [torch.float, torch.float16])
def test_mem_eff_attention_pad_mask(self, device, dtype):
    """Smoke-test memory-efficient SDPA forward + backward with a key/value length of 15.

    The additive mask is dense with shape (batch, num_heads, q_len, kv_len). The odd
    kv length presumably exercises the kernel's padding path (inferred from the test
    name; TODO confirm). Only checks that both passes run without error.
    """
    n_batch, n_heads, d_head = 8, 8, 64
    q_len, kv_len = 64, 15
    sample = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)

    # Draw q, k, v, then the mask, in that order (same RNG consumption as before).
    q = sample(SdpaShape(n_batch, n_heads, q_len, d_head))
    kv_shape = SdpaShape(n_batch, n_heads, kv_len, d_head)
    k = sample(kv_shape)
    v = sample(kv_shape)
    bias = torch.randn((n_batch, n_heads, q_len, kv_len), device=device, dtype=dtype)

    with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
        attended = F.scaled_dot_product_attention(q, k, v, bias)
        attended.sum().backward()
2060
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("dtype", [torch.float, torch.float16])
def test_mem_eff_attention_non_contiguous_mask(self, device, dtype):
    """Smoke-test memory-efficient SDPA with a non-contiguous additive mask.

    The mask is an ``as_strided`` view with strides (0, 0, 0, 1). Every
    (batch, head, query) row therefore reads the same first ``kv_len`` elements of
    the random buffer. Only checks that forward and backward run.
    """
    batch, heads, dim = 8, 8, 64
    q_len, kv_len = 64, 16
    sample = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)

    query = sample(SdpaShape(batch, heads, q_len, dim))
    kv_shape = SdpaShape(batch, heads, kv_len, dim)
    key, value = sample(kv_shape), sample(kv_shape)

    mask_shape = (batch, heads, q_len, kv_len)
    base = torch.randn(mask_shape, device=device, dtype=dtype)
    mask = torch.as_strided(base, mask_shape, (0, 0, 0, 1))

    with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
        result = F.scaled_dot_product_attention(query, key, value, mask)
        result.sum().backward()
2075
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("dtype", [torch.float, torch.float16])
def test_mem_eff_attention_long_sequence_mask(self, device, dtype):
    """Run memory-efficient SDPA forward + backward on 8192-long sequences.

    Uses a dense (batch, num_heads, q_len, kv_len) additive mask. The test is skipped
    on devices with less than 80 GiB of total memory.
    """
    # ``unittest.skip(reason)`` only returns a decorator; calling it inside a test body
    # does not skip anything. ``self.skipTest`` raises ``SkipTest`` so the run is
    # actually reported as skipped. Query the device this test instance runs on.
    if torch.cuda.get_device_properties(device).total_memory < 80 * 2**30:
        self.skipTest("This test requires substantial GPU memory.")

    make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
    batch, num_heads, head_dim = 1, 32, 64
    seq_len_q, seq_len_kv = 8192, 8192
    query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
    kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
    key, value = make_tensor(kv_shape), make_tensor(kv_shape)
    mask = torch.randn((batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype)
    with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(query, key, value, mask)
        out.sum().backward()
2092
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
def test_mem_eff_attention_non_contig_mask_bug(self, device):
    """Regression test: memory-efficient SDPA with ``bias`` must match ``bias.contiguous()``.

    Query, key, value and bias are ``as_strided`` views over flat random buffers. The
    sizes and strides are fixed below; key and value use a 0 stride on dim 1. The mean
    absolute difference between the two outputs must be below 1e-7.
    """

    def _storage_numel(size, stride):
        # Smallest flat buffer for which ``as_strided(size, stride)`` stays in bounds:
        # the offset of the last addressed element plus one. The previous
        # ``max(size * stride)`` heuristic under-allocates for overlapping strides,
        # e.g. size (3, 3) / stride (1, 1) needs 5 elements, not 3.
        # Assumes every size is >= 1.
        return 1 + sum((s - 1) * st for s, st in zip(size, stride))

    query_size = (3, 16, 1, 128)
    query_strides = (2304, 128, 2048, 1)
    key_size = (3, 16, 14, 128)
    key_strides = (3584, 0, 256, 1)
    value_size = (3, 16, 14, 128)
    value_strides = (3584, 0, 256, 1)
    attention_mask_size = (3, 1, 1, 14)
    attn_mask_strides = (14, 14, 14, 1)

    query = torch.randn(_storage_numel(query_size, query_strides),
                        device=device).as_strided(query_size, query_strides)
    key = torch.randn(_storage_numel(key_size, key_strides),
                      device=device).as_strided(key_size, key_strides)
    value = torch.randn(_storage_numel(value_size, value_strides),
                        device=device).as_strided(value_size, value_strides)
    bias = torch.randn(_storage_numel(attention_mask_size, attn_mask_strides),
                       device=device).as_strided(attention_mask_size, attn_mask_strides)

    with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(query, key, value, bias)
        out_contig = F.scaled_dot_product_attention(query, key, value, bias.contiguous())

    # This is a mean (not a max) of the absolute element-wise difference.
    mean_abs_diff = (out - out_contig).abs().mean()
    self.assertLess(mean_abs_diff.item(), 1e-7)
2124
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
def test_singelton_head_dim_stride_ne_1(self, device):
    """FlashAttention must accept a query whose size-1 last dim has a stride other than 1.

    Transposing the contiguous (1, 1, 1, 2) query gives shape (1, 1, 2, 1) with strides
    (2, 2, 1, 2). Key and value each hold a single element.
    """
    query = torch.tensor([[[[1, 2]]]], dtype=torch.float16, device=device)
    query = query.transpose(-1, -2)
    key = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
    value = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)

    # Restrict dispatch to FlashAttention with the same ``sdpa_kernel`` API used by the
    # rest of this file (``torch.backends.cuda.sdp_kernel`` is deprecated).
    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
        out = scaled_dot_product_attention(query, key, value)

    # With a single key the softmax weight is exactly 1. Each of the two query rows
    # therefore returns the lone value element (1.0).
    self.assertEqual(out, torch.ones(1, 1, 2, 1, dtype=torch.float16, device=device))
2134
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
2135
@parametrize("type", ["dense", "nested"])
2136
@parametrize("is_contiguous", [True, False])
2137
def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
2138
make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)
2140
batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
2141
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
2144
qkv = make_tensor(shape)
2145
query, key, value = qkv.chunk(3, dim=-1)
2147
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2148
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2149
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2152
query = query.contiguous()
2153
key = key.contiguous()
2154
value = value.contiguous()
2156
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2157
actual = torch.nn.functional.scaled_dot_product_attention(
2158
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
2159
with sdpa_kernel(backends=[SDPBackend.MATH]):
2160
math_ref = torch.nn.functional.scaled_dot_product_attention(
2161
query.contiguous(), key.contiguous(), value.contiguous(),
2162
attn_mask=None, dropout_p=0.0, is_causal=False)
2164
self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
2167
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
2168
@parametrize("type", ["dense", "nested"])
2169
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
2170
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
2171
def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, device, type: str, fused_kernel: str):
2173
batch, seq_len, num_heads, head_dim = shape
2174
tensors = [6 * torch.rand((seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3
2175
for _ in range(batch)]
2176
return (torch.nested.nested_tensor(tensors, device=device, dtype=torch.float32),
2177
torch.nested.nested_tensor(tensors, device=device, dtype=torch.float16))
2179
def rand_tensor(shape):
2180
batch, seq_len, num_heads, head_dim = shape
2181
tensor = 6 * torch.rand((batch, seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3
2182
return tensor, tensor.to(dtype=torch.float16)
2184
batch_size, seq_len, num_heads, head_dim = 16, 8, 4, 64
2185
shape = (batch_size, seq_len, num_heads, head_dim)
2188
qkv, qkv_low_precision = rand_tensor(shape) if type == "dense" else rand_nt(shape)
2189
query, key, value = qkv.chunk(3, dim=-1)
2190
query_lp, key_lp, value_lp = qkv_low_precision.chunk(3, dim=-1)
2192
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2193
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2194
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2196
query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2197
key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2198
value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2200
with sdpa_kernel(backends=[fused_kernel]):
2201
actual = torch.nn.functional.scaled_dot_product_attention(
2202
query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, is_causal=False)
2204
with sdpa_kernel(backends=[SDPBackend.MATH]):
2205
math_ref_lp = torch.nn.functional.scaled_dot_product_attention(
2206
query_lp.contiguous(), key_lp.contiguous(), value_lp.contiguous(),
2207
attn_mask=None, dropout_p=0.0, is_causal=False)
2209
math_query = query.contiguous()
2210
math_key = key.contiguous()
2211
math_value = value.contiguous()
2213
math_ref = torch.nn.functional.scaled_dot_product_attention(
2214
math_query, math_key, math_value, attn_mask=None, dropout_p=0.0, is_causal=False)
2216
actual_test = actual
2217
math_ref_test = math_ref
2218
math_ref_lp_test = math_ref_lp
2220
if actual_test.is_nested:
2221
actual_test = torch.nested.to_padded_tensor(actual_test.contiguous(), padding=0.0)
2222
math_ref_test = torch.nested.to_padded_tensor(math_ref_test, padding=0.0)
2223
math_ref_lp_test = torch.nested.to_padded_tensor(math_ref_lp_test, padding=0.0)
2225
actual_test = actual_test.to(dtype=torch.float32).contiguous()
2226
math_ref_test = math_ref_test.to(dtype=torch.float32).contiguous()
2227
math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous()
2229
self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3)
2230
self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3)
2232
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Memory efficient attention was not built for this system")
@parametrize("contiguous_inputs", [True, False])
@parametrize("is_causal", [True, False])
def test_sdp_mem_efficient_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool):
    """Compare memory-efficient SDPA gradients (float32) against math-backend gradients (float64).

    A packed float64 ``qkv`` tensor is cloned to float32. Both are chunked into
    query/key/value, viewed as (batch, -1, num_heads, head_dim) and transposed on dims
    1 and 2. Each set goes through SDPA (math for float64, memory-efficient for
    float32) and is back-propagated with the same upstream gradient. The gradients of
    the packed tensors must match.
    """
    batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
    make_tensor = partial(rand_sdpa_tensor, type="dense", device=device,
                          dtype=torch.float64, requires_grad=True, packed=True)

    qkv = make_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim))
    qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_()

    query, key, value = qkv.chunk(3, dim=-1)
    query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1)

    query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

    query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

    if contiguous_inputs:
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

        query_lp = query_lp.contiguous()
        key_lp = key_lp.contiguous()
        value_lp = value_lp.contiguous()

    with sdpa_kernel(backends=[SDPBackend.MATH]):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value, None, 0.0, is_causal)

    with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
        out_lp = torch.nn.functional.scaled_dot_product_attention(
            query_lp, key_lp, value_lp, None, 0.0, is_causal)

    rand_upward = torch.rand_like(out)
    rand_upward_lp = rand_upward.to(torch.float32)

    out.backward(rand_upward)
    out_lp.backward(rand_upward_lp)

    # Cast the float32 gradient up to float64 before comparing.
    self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)
2280
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system")
@parametrize("contiguous_inputs", [True, False])
@parametrize("is_causal", [True, False])
@parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sdp_flash_attention_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool, dtype: torch.dtype):
    """Check FlashAttention gradients in ``dtype`` against float64 math-backend gradients.

    The float64 packed tensor is the reference. A detached copy cast to ``dtype`` feeds
    FlashAttention. Both paths are back-propagated with the same random upstream
    gradient. The packed-tensor gradients are then compared with a dtype-dependent
    tolerance.
    """
    batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
    make_tensor = partial(rand_sdpa_tensor, type="dense", device=device,
                          dtype=torch.float64, requires_grad=True, packed=True)

    qkv_ref = make_tensor(SdpaShape(batch_size, num_heads, seq_len, head_dim))
    qkv_low = qkv_ref.detach().clone().to(dtype).requires_grad_()

    def _as_heads(packed):
        # Chunk the last dim into three parts. View each part as
        # (batch, -1, num_heads, head_dim), swap dims 1 and 2, and optionally
        # materialize a contiguous copy.
        parts = [chunk.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
                 for chunk in packed.chunk(3, dim=-1)]
        if contiguous_inputs:
            parts = [part.contiguous() for part in parts]
        return parts

    q_ref, k_ref, v_ref = _as_heads(qkv_ref)
    q_low, k_low, v_low = _as_heads(qkv_low)

    with sdpa_kernel(backends=[SDPBackend.MATH]):
        out_ref = torch.nn.functional.scaled_dot_product_attention(q_ref, k_ref, v_ref, None, 0.0, is_causal)

    with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
        out_low = torch.nn.functional.scaled_dot_product_attention(
            q_low, k_low, v_low, None, 0.0, is_causal)

    upstream = torch.rand_like(out_ref)
    upstream_low = upstream.to(dtype)
    out_ref.backward(upstream)
    out_low.backward(upstream_low)

    # Low-precision compute needs looser tolerances; bfloat16 is looser than float16.
    tol = 7e-4 if dtype == torch.float16 else 7e-3
    self.assertEqual(qkv_ref.grad, qkv_low.grad.to(torch.float64), atol=tol, rtol=tol)
2333
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA")
2334
@parametrize("type", ["dense", "nested"])
2335
def test_fused_sdp_choice(self, device, type: str):
2336
batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64
2337
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
2338
make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float16, packed=True, requires_grad=True)
2340
qkv = make_tensor(shape, type=type)
2341
query, key, value = qkv.chunk(3, dim=-1)
2343
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2344
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2345
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2347
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
2348
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION.value
2350
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
2353
make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float32, packed=True)
2355
qkv = make_tensor(shape, type=type)
2356
query, key, value = qkv.chunk(3, dim=-1)
2358
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2359
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2360
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
2362
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
2364
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
@parametrize("warn_only", [True, False])
def test_sdp_choice_with_determinism(self, device, warn_only):
    """Check the fused-backend choice while deterministic algorithms are enabled.

    Deterministic algorithms are enabled either strictly or in warn-only mode, and
    only the memory-efficient and math backends are allowed. The fused-backend choice
    must still be memory-efficient attention.
    """
    batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
    shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
    make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, packed=False)
    query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)

    with use_deterministic_algorithims(True, warn_only=warn_only):
        with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
            # Use the TestCase assertion. A bare ``assert`` is stripped under
            # ``python -O`` and reports neither value on failure.
            self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value)
2377
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
2378
@parametrize("warn_only", [True, False])
2379
def test_mem_eff_backwards_throws_determinism_warning(self, device, warn_only):
2380
batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
2381
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
2382
make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, packed=False, requires_grad=True)
2383
query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)
2386
self.assertWarnsRegex(
2388
"Memory Efficient attention defaults to a non-deterministic algorithm.",
2391
else contextlib.nullcontext()
2393
with use_deterministic_algorithims(True, warn_only=warn_only):
2394
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2395
with warning_context:
2396
torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward()
2398
@unittest.skip("This test is not behaving deterministaclly non-deterministaclly on CI/CD")
2399
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA")
2400
def test_mem_eff_backwards_determinism(self, device):
2402
dtype = torch.float32
2403
batch_size, seq_len, n_heads, head_dim = 1, 1024, 8, 64
2404
query = torch.rand(batch_size, n_heads, seq_len, head_dim,
2405
device=device, dtype=dtype, requires_grad=True)
2406
key = torch.rand(batch_size, n_heads, seq_len, head_dim, device=device,
2407
dtype=dtype, requires_grad=True)
2408
value = torch.rand(batch_size, n_heads, seq_len, head_dim,
2409
device=device, dtype=dtype, requires_grad=True)
2411
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2413
out = F.scaled_dot_product_attention(query, key, value)
2414
upward_grad = torch.rand_like(out)
2415
out.backward(upward_grad)
2416
intial_query_grad = query.grad
2420
diff_anwser_once = False
2421
for _ in range(100):
2423
out = F.scaled_dot_product_attention(query, key, value)
2424
out.backward(upward_grad)
2425
if not torch.equal(intial_query_grad, query.grad):
2426
diff_anwser_once = True
2428
self.assertTrue(diff_anwser_once)
2430
with use_deterministic_algorithims(True, warn_only=False):
2432
out = F.scaled_dot_product_attention(query, key, value)
2433
upward_grad = torch.rand_like(out)
2434
out.backward(upward_grad)
2435
intial_query_grad = query.grad
2439
diff_anwser_once = False
2440
for _ in range(100):
2442
out = F.scaled_dot_product_attention(query, key, value)
2443
out.backward(upward_grad)
2444
if not torch.equal(intial_query_grad, query.grad):
2445
diff_anwser_once = True
2447
self.assertFalse(diff_anwser_once)
2450
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
2451
@parametrize("batch_size", [1, 8])
2452
@parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
2453
@parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
2454
@parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if SM80OrLater else [8, 16, 32, 64])
2455
@parametrize("is_causal", [False, True])
2456
@parametrize("dropout_p", [0.0, 0.22])
2457
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if
2458
SM80OrLater else [torch.float16, torch.float32])
2459
@parametrize("scale", [None, "l1"])
2460
def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
2461
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
2463
def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device):
2464
mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
2465
rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset)
2466
mask = (rand_uniform > p).to(torch.float32)
2468
if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
2469
unittest.skip("Reference implementation OOM")
2472
scale = scale if scale is None else (1 / head_dim)
2474
query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
2475
device=device, dtype=dtype, requires_grad=True)
2476
key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
2477
dtype=dtype, requires_grad=True)
2478
value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
2479
device=device, dtype=dtype, requires_grad=True)
2482
query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
2484
higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
2485
query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
2488
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2490
torch.manual_seed(seed)
2491
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2493
if dropout_p == 0.0:
2494
with sdpa_kernel(backends=[SDPBackend.MATH]):
2496
out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
2497
dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2499
out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
2500
dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2502
if seq_len_q > 1024:
2503
self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!")
2505
torch.manual_seed(seed)
2506
dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q, seq_len_k, dropout_p, seed, 0, device=device)
2508
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
2509
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
2511
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
2512
query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
2513
dropout_mask=dropout_mask)[0]
2515
upstream_grad = torch.rand_like(out, requires_grad=False)
2517
out.backward(upstream_grad)
2518
out_ref.backward(upstream_grad.to(out_ref.dtype))
2519
out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
2527
output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
2530
dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 2.0
2532
query_fudge_factor = dropout_fudge_factor
2533
grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
2536
key_fudge_factor = 8 * dropout_fudge_factor
2537
grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
2539
value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
2540
grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
2542
self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
2543
self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
2544
atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
2545
self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
2546
atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
2547
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
2548
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2551
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
2552
@parametrize("batch_size", [1, 8])
2553
@parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 152, 256, 512])
2554
@parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if SM80OrLater else [4, 8, 37, 64, 128, 256, 512])
2555
@parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128] if SM80OrLater else [8, 16, 32, 64])
2556
@parametrize("is_causal", [False])
2557
@parametrize("dropout_p", [0.0, 0.22])
2558
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if
2559
SM80OrLater else [torch.float16, torch.float32])
2560
@parametrize("scale", [None, "l1"])
2561
def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
2562
seq_len_k: int, head_dim: int, is_causal: bool,
2563
dropout_p: float, dtype: torch.dtype,
2565
def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, device=device):
2566
mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
2567
rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, p, seed, offset)
2568
mask = (rand_uniform > p).to(torch.float32)
2570
if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
2571
unittest.skip("Reference implementation OOM")
2574
scale = scale if scale is None else (1 / head_dim)
2576
query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
2577
device=device, dtype=dtype, requires_grad=True)
2578
key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
2579
dtype=dtype, requires_grad=True)
2580
value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
2581
device=device, dtype=dtype, requires_grad=True)
2583
attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)
2586
query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
2587
attn_mask_ref_lp = attn_mask.detach().to(dtype).requires_grad_(True)
2589
higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
2590
query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
2591
attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)
2594
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
2596
torch.manual_seed(seed)
2597
out = F.scaled_dot_product_attention(query, key, value, attn_mask, dropout_p=dropout_p,
2598
is_causal=is_causal, scale=scale)
2600
if dropout_p == 0.0:
2601
with sdpa_kernel(backends=[SDPBackend.MATH]):
2603
out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref, attn_mask_ref,
2604
dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2606
out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
2607
dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2609
if seq_len_q > 1024:
2610
self.skipTest("Will call _fill_mem_eff_dropout_mask with too many threads!")
2612
torch.manual_seed(seed)
2613
dropout_mask = _get_mem_eff_drop_mask(batch_size, n_heads, seq_len_q,
2614
seq_len_k, dropout_p, seed, 0, device=device)
2616
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
2617
query_ref, key_ref, value_ref, attn_mask_ref, dropout_p=dropout_p, is_causal=is_causal,
2618
scale=scale, dropout_mask=dropout_mask)[0]
2620
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
2621
query_ref_lp, key_ref_lp, value_ref_lp, attn_mask_ref_lp,
2622
dropout_p=dropout_p, is_causal=is_causal, scale=scale,
2623
dropout_mask=dropout_mask)[0]
2625
upstream_grad = torch.rand_like(out, requires_grad=False)
2627
out.backward(upstream_grad)
2628
out_ref.backward(upstream_grad.to(out_ref.dtype))
2629
out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
2637
output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
2640
dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.75
2641
mask_fudge_factor = 1.0 if attn_mask is None else 1.5
2643
query_fudge_factor = dropout_fudge_factor
2644
grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
2647
key_fudge_factor = 8 * dropout_fudge_factor * mask_fudge_factor
2648
grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
2650
value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
2651
grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
2653
mask_fudge_factor = 12 if attn_mask.numel() > 512 else 22
2654
grad_attn_mask_atol, grad_attn_mask_rtol = get_tolerances(
2655
attn_mask_ref.grad, attn_mask_ref_lp.grad, mask_fudge_factor)
2657
self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
2658
self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
2659
atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
2660
self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
2661
atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
2662
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
2663
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2665
self.assertEqual(attn_mask.grad, attn_mask_ref.grad.to(attn_mask.grad.dtype),
2666
atol=grad_attn_mask_atol, rtol=grad_attn_mask_rtol)
2668
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
2669
@parametrize("batch_size", [1, 8])
2670
@parametrize("seq_len_q", [4, 8, 64, 143, 256, 512, 1024, 2048])
2671
@parametrize("seq_len_k", [4, 8, 64, 128, 256, 587, 1024, 2048])
2672
@parametrize("head_dim", [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256])
2673
@parametrize("is_causal", [True, False])
2674
@parametrize("dropout_p", [0.0, 0.22, 0.48])
2675
@parametrize("dtype", [torch.float16, torch.bfloat16])
2676
@parametrize("scale", [None, "l1"])
2677
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
2678
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
2681
def is_power_of_2(n):
2682
return n & (n - 1) == 0
2683
if not is_power_of_2(seq_len_q) or not is_power_of_2(seq_len_k) or not is_power_of_2(head_dim):
2684
self.skipTest("Flash attention on ROCM only supports power of two seq_len_q seq_len_k headdim, for now.")
2685
if head_dim < 16 or seq_len_q < 16 or seq_len_k < 16:
2686
self.skipTest("Flash attention on ROCM only supports power of two seq_len_q, seq_len_k, headdim >= 16, for now.")
2688
self.skipTest("Flash attention on ROCM only supports power of two headdim <= 128, for now.")
2690
if isSM8XDevice and head_dim in range(193, 256 + 1):
2691
self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
2692
if is_causal and seq_len_q != seq_len_k:
2693
self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
2695
scale = scale if scale is None else (1 / head_dim)
2697
query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
2698
device=device, dtype=dtype, requires_grad=True)
2699
key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
2700
dtype=dtype, requires_grad=True)
2701
value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
2702
device=device, dtype=dtype, requires_grad=True)
2705
query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
2707
higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
2708
query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
2710
is_dropout = dropout_p > 0.0
2715
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
2716
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2717
with sdpa_kernel(backends=[SDPBackend.MATH]):
2719
out_ref = F.scaled_dot_product_attention(
2720
query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
2722
out_lp_ref = F.scaled_dot_product_attention(
2723
query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
2725
q_padded, q_og_size = pad_last_dim(query, 8)
2726
k_padded, k_og_size = pad_last_dim(key, 8)
2727
v_padded, v_og_size = pad_last_dim(value, 8)
2730
scale = 1 / math.sqrt(q_og_size)
2731
output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
2732
q_padded, k_padded, v_padded, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=is_dropout)
2733
out = output_tuple[0]
2734
out = out[..., :v_og_size]
2736
dbug_mask = output_tuple[-1]
2737
query_padding_mask = torch.ones(
2738
batch_size, seq_len_q, device=device, dtype=torch.bool)
2739
key_padding_mask = torch.ones(
2740
batch_size, seq_len_k, device=device, dtype=torch.bool)
2742
softmax_mask = self.convert_flash_attn_S_to_softmax(
2743
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim,
2744
causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
2745
dropout_mask = softmax_mask >= 0
2747
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
2748
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
2750
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
2751
query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
2752
dropout_mask=dropout_mask)[0]
2754
upstream_grad = torch.rand_like(out, requires_grad=False)
2757
if isSM8XDevice and head_dim in range(193, 256):
2758
self.assertRaises(RuntimeError, lambda: out.backward(upstream_grad))
2760
out.backward(upstream_grad)
2761
out_ref.backward(upstream_grad.to(out_ref.dtype))
2762
out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
2765
output_fudge_factor = 3 if head_dim % 8 != 0 or TEST_WITH_ROCM else 1
2766
output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref, output_fudge_factor)
2769
query_fudge_factor = 4
2770
grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
2772
key_fudge_factor = 2
2773
grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
2775
value_fudge_factor = 2
2776
grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
2778
self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
2779
self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
2780
atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
2781
self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
2782
atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
2783
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
2784
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2787
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
2788
@parametrize("batch_size", [1, 8])
2789
@parametrize("seq_len_q", [256, 512, 1024])
2790
@parametrize("seq_len_k", [256, 512, 1024])
2791
@parametrize("head_dim", [32, 64])
2792
@parametrize("is_causal", [True, False])
2793
@parametrize("dropout_p", [0.0, 0.22])
2794
@parametrize("dtype", [torch.float16,])
2795
@parametrize("scale", [None, "l1"])
2796
@parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
2797
def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
2803
fused_kernel: SDPBackend):
2804
def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, dropout_p, seed, offset, device=device):
2805
mask = torch.empty((batch_size, n_heads, q_len, kv_len), device=device, dtype=torch.float32)
2806
rand_uniform = torch._fill_mem_eff_dropout_mask_(mask, dropout_p, seed, offset)
2807
mask = (rand_uniform > dropout_p).to(torch.float32)
2810
def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, dropout_p, device=device):
2811
if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
2812
output_seed, output_offset = output_tuple[2], output_tuple[3]
2813
output_seed = output_seed.item()
2814
output_offset = output_offset.item()
2815
return _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len,
2816
dropout_p, output_seed, output_offset, device=device)
2819
dbug_mask = output_tuple[-1]
2820
query_padding_mask = torch.ones(
2821
batch_size, seq_len_q, device=device, dtype=torch.bool)
2822
key_padding_mask = torch.ones(
2823
batch_size, seq_len_k, device=device, dtype=torch.bool)
2825
softmax_mask = self.convert_flash_attn_S_to_softmax(
2826
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)
2827
dropout_mask = softmax_mask >= 0
2830
if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k:
2831
self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
2834
scale = scale if scale is None else (1 / head_dim)
2836
query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
2837
device=device, dtype=dtype, requires_grad=True)
2838
key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
2839
dtype=dtype, requires_grad=True)
2840
value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
2841
device=device, dtype=dtype, requires_grad=True)
2843
fused_op = (torch.ops.aten._scaled_dot_product_efficient_attention
2844
if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else torch.ops.aten._scaled_dot_product_flash_attention)
2846
query_ref_lp, key_ref_lp, value_ref_lp = query_key_value_clones(query, key, value, dtype=dtype)
2848
higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
2849
query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
2852
s = torch.cuda.Stream()
2853
s.wait_stream(torch.cuda.current_stream())
2855
torch.manual_seed(seed)
2856
kwargs = {"dropout_p": dropout_p, "is_causal": is_causal, "scale": scale}
2857
if fused_kernel == SDPBackend.EFFICIENT_ATTENTION:
2858
kwargs["compute_log_sumexp"] = True
2859
kwargs["attn_bias"] = None
2860
if fused_kernel == SDPBackend.FLASH_ATTENTION:
2861
kwargs['return_debug_mask'] = dropout_p > 0.0
2862
with torch.cuda.stream(s):
2864
output_tuple = fused_op(query, key, value, **kwargs)
2866
torch.cuda.current_stream().wait_stream(s)
2867
out = output_tuple[0]
2868
upstream_grad = torch.rand_like(out, requires_grad=False)
2869
s.wait_stream(torch.cuda.current_stream())
2870
with torch.cuda.stream(s):
2871
out.backward(upstream_grad)
2872
for x in (query, key, value):
2874
g = torch.cuda.CUDAGraph()
2876
with torch.cuda.graph(g):
2877
tmp = torch.rand_like(query, device=query.device)
2879
output_tuple = fused_op(query, key, value, **kwargs)
2880
assert all(not isinstance(o, torch.Tensor) or o.is_cuda for o in output_tuple)
2882
out_first = output_tuple[0].clone()
2884
out = output_tuple[0]
2885
if dropout_p == 0.0:
2886
self.assertEqual(out_first, out, atol=0, rtol=0)
2889
self.assertNotEqual(out_first, out)
2891
with sdpa_kernel(backends=[SDPBackend.MATH]):
2892
if dropout_p == 0.0:
2894
out_ref = F.scaled_dot_product_attention(query_ref, key_ref, value_ref,
2895
dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2897
out_lp_ref = F.scaled_dot_product_attention(query_ref_lp, key_ref_lp, value_ref_lp,
2898
dropout_p=dropout_p, is_causal=is_causal, scale=scale)
2901
dropout_mask = get_dropout_mask(output_tuple, fused_kernel, batch_size,
2902
n_heads, seq_len_q, seq_len_k, dropout_p, device)
2904
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
2905
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
2906
scale=scale, dropout_mask=dropout_mask)[0]
2908
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
2909
query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
2910
dropout_mask=dropout_mask)[0]
2913
g1 = torch.cuda.CUDAGraph()
2914
with torch.cuda.graph(g1):
2915
out.backward(upstream_grad)
2917
out_ref.backward(upstream_grad.to(out_ref.dtype))
2918
out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
2926
output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
2929
dropout_fudge_factor = 1.0 if dropout_p == 0.0 else 1.5
2931
query_fudge_factor = dropout_fudge_factor
2932
grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
2935
key_fudge_factor = 8 * dropout_fudge_factor
2936
grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad, key_fudge_factor)
2938
value_fudge_factor = 7 if not SM80OrLater and dtype == torch.float16 else 1.0
2939
grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
2941
self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
2942
self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
2943
atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
2944
self.assertEqual(key.grad, key_ref.grad.to(key.grad.dtype),
2945
atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
2946
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
2947
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
2950
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
2951
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
2952
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
2953
def test_fused_kernels_seq_len_1_inputs(self, device, fused_kernel):
2954
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
2955
batch, num_heads, head_dim = 32, 16, 64
2956
seq_lens = torch.randint(low=1, high=32, size=(batch,))
2959
indices = torch.randint(low=0, high=batch, size=(num_ones,))
2960
seq_lens.scatter_(0, indices, 1)
2962
shape = SdpaShape(batch, num_heads, seq_lens.tolist(), head_dim)
2963
query = rand_nested_tensor(shape)
2964
key = rand_nested_tensor(shape)
2965
value = rand_nested_tensor(shape)
2967
query = query.transpose(1, 2)
2968
key = key.transpose(1, 2)
2969
value = value.transpose(1, 2)
2971
with sdpa_kernel(backends=[fused_kernel]):
2972
actual = torch.nn.functional.scaled_dot_product_attention(
2973
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
2974
with sdpa_kernel(backends=[SDPBackend.MATH]):
2975
math_ref = torch.nn.functional.scaled_dot_product_attention(
2976
query.contiguous().to(torch.float32),
2977
key.contiguous().to(torch.float32),
2978
value.contiguous().to(torch.float32),
2979
attn_mask=None, dropout_p=0.0, is_causal=False)
2981
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2)
2985
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
2986
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
2987
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
2988
@parametrize("expand_q_batch", [True, False])
2989
@parametrize("expand_k_batch", [True, False])
2990
@parametrize("expand_v_batch", [True, False])
2991
@parametrize("expand_q_num_heads", [True, False])
2992
@parametrize("expand_k_num_heads", [True, False])
2993
@parametrize("expand_v_num_heads", [True, False])
2994
def test_fused_kernels_nested_broadcasting(
3005
is_efficient = kernel == SDPBackend.EFFICIENT_ATTENTION
3006
dtype = torch.float32 if is_efficient else torch.float16
3007
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype)
3008
batch, num_heads, head_dim = 32, 8, 64
3009
head_dim_v = 32 if is_efficient else head_dim
3010
seq_lens_q = (torch.randint(low=1, high=5, size=(1,)).item()
3012
else torch.randint(low=1, high=32, size=(batch,)).tolist())
3013
seq_lens_kv = (torch.randint(low=1, high=5, size=(1,)).item()
3014
if (expand_k_batch or expand_v_batch)
3015
else torch.randint(low=1, high=32, size=(batch,)).tolist())
3017
batch_q = 1 if expand_q_batch else batch
3018
batch_k = 1 if expand_k_batch else batch
3019
batch_v = 1 if expand_v_batch else batch
3022
batch = max(batch_q, batch_k, batch_v)
3024
num_heads_q = 1 if expand_q_num_heads else num_heads
3025
num_heads_k = 1 if expand_k_num_heads else num_heads
3026
num_heads_v = 1 if expand_v_num_heads else num_heads
3029
num_heads = max(num_heads_q, num_heads_k, num_heads_v)
3031
q_shape = SdpaShape(batch_q, num_heads_q, seq_lens_q, head_dim)
3032
k_shape = SdpaShape(batch_k, num_heads_k, seq_lens_kv, head_dim)
3033
v_shape = SdpaShape(batch_v, num_heads_v, seq_lens_kv, head_dim_v)
3035
query = rand_nested_tensor(q_shape)
3036
key = rand_nested_tensor(k_shape)
3037
value = rand_nested_tensor(v_shape)
3039
def _broadcast(t, batch_broadcasted, num_heads_broadcasted):
3040
if batch_broadcasted and num_heads_broadcasted:
3042
result = torch.nested.nested_tensor(
3043
[t[0].expand(-1, num_heads, t.size(-1)) for _ in range(batch)], dtype=torch.float32)
3044
elif batch_broadcasted:
3046
result = torch.nested.nested_tensor([t[0] for _ in range(batch)], dtype=torch.float32)
3047
elif num_heads_broadcasted:
3049
result = torch.nested.nested_tensor([x.expand(-1, num_heads, t.size(-1))
3050
for x in t.unbind()], dtype=torch.float32)
3052
result = t.to(torch.float32)
3055
query_expanded = _broadcast(query, expand_q_batch, expand_q_num_heads).transpose(1, 2)
3056
key_expanded = _broadcast(key, expand_k_batch, expand_k_num_heads).transpose(1, 2)
3057
value_expanded = _broadcast(value, expand_v_batch, expand_v_num_heads).transpose(1, 2)
3059
query = query.transpose(1, 2)
3060
key = key.transpose(1, 2)
3061
value = value.transpose(1, 2)
3063
with sdpa_kernel(backends=[kernel]):
3064
actual = torch.nn.functional.scaled_dot_product_attention(
3065
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
3066
with sdpa_kernel(backends=[SDPBackend.MATH]):
3067
math_ref = torch.nn.functional.scaled_dot_product_attention(
3068
query_expanded.contiguous(), key_expanded.contiguous(), value_expanded.contiguous(),
3069
attn_mask=None, dropout_p=0.0, is_causal=False)
3071
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
3073
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
3074
def test_fused_kernels_nested_broadcasting_query_dense(self, device):
3075
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
3076
batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 96
3077
seq_lens = torch.randint(low=1, high=32, size=(batch,)).tolist()
3078
q_shape = (1, 1, num_heads, head_dim)
3079
k_shape = SdpaShape(batch, num_heads, seq_lens, head_dim)
3080
v_shape = SdpaShape(batch, 1, seq_lens, head_dim_v)
3083
query = torch.randn(q_shape, device=device, dtype=torch.float32)
3084
key = rand_nested_tensor(k_shape)
3085
value = rand_nested_tensor(v_shape)
3088
query_expanded = torch.nested.nested_tensor([query.squeeze(0) for _ in range(batch)]).transpose(1, 2)
3090
value_expanded = torch.nested.nested_tensor(
3091
[t.expand(-1, num_heads, head_dim_v) for t in value.unbind()]).transpose(1, 2)
3093
query = query.transpose(1, 2)
3094
key = key.transpose(1, 2)
3095
value = value.transpose(1, 2)
3097
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
3098
actual = torch.nn.functional.scaled_dot_product_attention(
3099
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
3100
with sdpa_kernel(backends=[SDPBackend.MATH]):
3101
math_ref = torch.nn.functional.scaled_dot_product_attention(
3102
query_expanded.contiguous(), key.contiguous(), value_expanded.contiguous(),
3103
attn_mask=None, dropout_p=0.0, is_causal=False)
3105
self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2)
3109
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
3110
@parametrize("batch_size", [8, 32])
3111
@parametrize("max_seq_len_q", [32, 256])
3112
@parametrize("max_seq_len_kv", [32, 256])
3113
@parametrize("head_dim", [8, 64])
3114
@parametrize("dropout_p", [0.0, 0.1])
3115
@parametrize("dtype", [torch.float16])
3116
@parametrize("scale", [None, "l1"])
3117
@parametrize("is_causal", [True, False])
3118
def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int,
3119
head_dim: int, dropout_p: float, dtype: torch.dtype,
3120
scale: str, is_causal: bool):
3123
self.assertRaisesRegex(RuntimeError, "Nested tensors for query / key are not supported when is_causal=True")
3125
scale = scale if scale is None else (1 / head_dim)
3127
seq_lens_q = torch.randint(low=1, high=max_seq_len_q, size=(batch_size,))
3129
seq_lens_q[torch.randint(0, batch_size, size=(1,))] = max_seq_len_q
3130
seq_lens_kv = torch.randint(low=1, high=max_seq_len_kv, size=(batch_size,))
3131
seq_lens_kv[torch.randint(0, batch_size, size=(1,))] = max_seq_len_kv
3133
def rand_nt(sequence_list, num_heads, head_dim):
3134
tensors = [torch.rand((num_heads, seq_len, head_dim)) for seq_len in sequence_list]
3135
return torch.nested.nested_tensor(tensors, requires_grad=True, device=device, dtype=dtype)
3137
query = rand_nt(seq_lens_q, n_heads, head_dim)
3138
key = rand_nt(seq_lens_kv, n_heads, head_dim)
3139
value = rand_nt(seq_lens_kv, n_heads, head_dim)
3142
query_ref_lp = query.clone().detach().requires_grad_(True)
3143
key_ref_lp = key.clone().detach().requires_grad_(True)
3144
value_ref_lp = value.clone().detach().requires_grad_(True)
3146
query_ref = query.clone().detach().to(torch.float32).requires_grad_(True)
3147
key_ref = key.clone().detach().to(torch.float32).requires_grad_(True)
3148
value_ref = value.clone().detach().to(torch.float32).requires_grad_(True)
3150
is_dropout = dropout_p > 0.0
3153
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
3154
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
3155
with sdpa_kernel(backends=[SDPBackend.MATH]):
3157
out_ref = F.scaled_dot_product_attention(
3158
query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
3160
out_lp_ref = F.scaled_dot_product_attention(
3161
query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
3164
output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
3165
query, key, value, dropout_p=dropout_p, is_causal=is_causal,
3166
scale=scale, return_debug_mask=is_dropout)
3167
out = output_tuple[0]
3168
dbug_mask = output_tuple[-1]
3170
query_padding_mask = torch.arange(max_seq_len_q).unsqueeze(0).expand(
3171
batch_size, max_seq_len_q
3172
) < seq_lens_q.unsqueeze(-1)
3173
query_padding_mask = query_padding_mask.to("cuda")
3175
key_padding_mask = torch.arange(max_seq_len_kv).unsqueeze(0).expand(
3176
batch_size, max_seq_len_kv
3177
) < seq_lens_kv.unsqueeze(-1)
3178
key_padding_mask = key_padding_mask.to("cuda")
3180
softmax_mask = self.convert_flash_attn_S_to_softmax(
3181
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)
3182
dropout_mask = softmax_mask >= 0
3184
for tensor_component in range(batch_size):
3186
for head in range(n_heads):
3187
batch_stack.append(dropout_mask[tensor_component, head,
3188
0:seq_lens_q[tensor_component],
3189
0:seq_lens_kv[tensor_component]].unsqueeze(0))
3190
nt_stack.append(torch.cat(batch_stack))
3191
nested_dropout_mask = torch.nested.nested_tensor(nt_stack)
3193
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
3194
query_ref, key_ref, value_ref, dropout_p=dropout_p,
3195
is_causal=is_causal, scale=scale, dropout_mask=nested_dropout_mask)[0]
3197
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
3198
query_ref_lp, key_ref_lp, value_ref_lp, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
3199
dropout_mask=nested_dropout_mask)[0]
3201
upstream_grad = out.detach().clone().contiguous()
3203
out.backward(upstream_grad)
3204
out_ref.backward(upstream_grad.to(out_ref.dtype))
3205
out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
3208
output_ref_atol, output_ref_rtol = calculate_nt_tolerances(out_ref, out_lp_ref, out.dtype)
3209
grad_q_ref_atol, grad_q_ref_rtol = calculate_nt_tolerances(query_ref.grad, query_ref_lp.grad,
3210
query.grad.dtype, fudge_factor=4)
3211
grad_k_ref_atol, grad_k_ref_rtol = calculate_nt_tolerances(key_ref.grad, key_ref_lp.grad, key.grad.dtype)
3212
grad_v_ref_atol, grad_v_ref_rtol = calculate_nt_tolerances(value_ref.grad, value_ref_lp.grad, value.grad.dtype)
3214
self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
3215
self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
3216
atol=grad_q_ref_atol, rtol=grad_q_ref_rtol)
3217
self.assertEqual(key.grad.contiguous(), key_ref.grad.contiguous().to(key.grad.dtype),
3218
atol=grad_k_ref_atol, rtol=grad_k_ref_rtol)
3219
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
3220
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
3222
class TestAttnBias(NNTestCase):
3230
forw_tolerances: Optional[Tolerances] = None,
3231
grad_tolerances: Optional[Tolerances] = None,
3234
if backend is not None:
3235
torch._dynamo.reset()
3237
query, key, value = make_q(), make_kv(), make_kv()
3238
query_prototype, key_prototype, value_prototype = query_key_value_clones(query, key, value)
3240
realized = attn_bias._materialize(device) if attn_bias is not None else None
3241
pytorch_output = scaled_dot_product_attention(
3242
query, key, value, attn_mask=realized, dropout_p=0.0, is_causal=False
3246
torch.compile(scaled_dot_product_attention, backend=backend)
3247
if backend is not None
3248
else scaled_dot_product_attention
3250
sdpa_output = sdpa_op(
3254
attn_mask=attn_bias,
3260
dOut = torch.randn_like(pytorch_output)
3261
pytorch_output.backward(dOut)
3262
sdpa_output.backward(dOut)
3265
if forw_tolerances is None:
3266
forw_tolerances = Tolerances(atol=None, rtol=None)
3267
if grad_tolerances is None:
3268
grad_tolerances = Tolerances(atol=None, rtol=None)
3270
torch.testing.assert_close(pytorch_output, sdpa_output, rtol=forw_tolerances.rtol, atol=forw_tolerances.atol)
3271
torch.testing.assert_close(query.grad, query_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
3272
torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
3273
torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
3276
@parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
3279
[(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
3281
def test_causal_variants(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
3282
make_tensor = partial(
3283
torch.rand, device=device, dtype=torch.float16, requires_grad=True
3286
bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
3287
make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
3288
make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
3289
if causal_variant == CausalVariant.LOWER_RIGHT and seq_len_q > seq_len_kv:
3291
"Lower right causal mask will produce NaNs in the output when seq_len_q > seq_len_kv!"
3294
forw_tol = Tolerances(1e-3, 1e-3)
3295
grad_tol = Tolerances(5e-3, 5e-3)
3297
if causal_variant == CausalVariant.UPPER_LEFT:
3298
attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
3300
attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
3302
self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None)
3304
@parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
3307
[(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
3309
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
3311
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
3313
def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
3314
cnts = CompileCounterWithBackend("aot_eager")
3315
make_tensor = partial(
3316
torch.rand, device=device, dtype=torch.float16, requires_grad=True
3319
bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
3320
make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
3321
make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
3322
if causal_variant == CausalVariant.LOWER_RIGHT and seq_len_q > seq_len_kv:
3324
"Lower right causal mask will produce NaNs in the output when seq_len_q > seq_len_kv!"
3326
forw_tol = Tolerances(1e-3, 1e-3)
3327
grad_tol = Tolerances(5e-3, 5e-3)
3329
if causal_variant == CausalVariant.UPPER_LEFT:
3330
attn_bias = causal_upper_left(seq_len_q, seq_len_kv)
3332
attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
3334
self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts)
3335
self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
3337
@parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
def test_is_causal_equals_upper_left(self, device, shape: Tuple[int, int, int, int, int]):
    """``is_causal=True`` must give the same forward output as a ``causal_upper_left`` mask.

    ``shape`` is unpacked as ``(bsz, num_heads, seq_len_q, seq_len_kv, head_dim)``.
    Query tensors use ``seq_len_q``; key/value tensors use ``seq_len_kv``.
    Both ``scaled_dot_product_attention`` calls use ``dropout_p=0.0``, so the
    outputs can be compared directly with ``torch.testing.assert_close``.
    """
    make_tensor = partial(
        torch.rand, device=device, dtype=torch.float16, requires_grad=True
    )

    bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
    make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
    make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))

    # Only forward outputs are compared (no backward pass), so only a forward tolerance is needed.
    forw_tol = Tolerances(1e-3, 1e-3)

    query = make_q_tensor()
    key = make_kv_tensor()
    value = make_kv_tensor()
    attn_bias = causal_upper_left(seq_len_q, seq_len_kv)

    out_attn_bias = scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, dropout_p=0.0)
    out_is_causal = scaled_dot_product_attention(query, key, value, is_causal=True, dropout_p=0.0)
    torch.testing.assert_close(out_attn_bias, out_is_causal, rtol=forw_tol.rtol, atol=forw_tol.atol)
3359
def test_is_causal_and_mask_fails(self, device):
    """Passing a CausalBias as ``attn_mask`` together with ``is_causal=True`` must raise.

    Query, key and value all share one (16, 16, 128, 16) float16 shape.
    The call is expected to fail with a ``ValueError`` whose message matches
    "CausalBias should not be used with causal=True".
    """
    qkv_shape = SdpaShape(16, 16, 128, 16)
    rand_qkv = partial(
        torch.rand, qkv_shape, device=device, dtype=torch.float16, requires_grad=True
    )
    # Drawn in the order query, key, value.
    query, key, value = (rand_qkv() for _ in range(3))
    causal_mask = causal_upper_left(128, 128)

    with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
        scaled_dot_product_attention(query, key, value, attn_mask=causal_mask, is_causal=True, dropout_p=0.0)
3375
device_types = ("cuda", )
3377
device_types = ("cpu", "cuda")
3379
# Instantiate each test class for the device types listed in ``device_types``.
instantiate_device_type_tests(TestTransformers, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types)
# The trailing comma matters: ``("cuda")`` is just the string "cuda", and membership
# tests against a string are substring checks. Pass a one-element tuple, matching
# the tuple form of ``device_types`` used by the other calls.
instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda",))
instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
3385
if __name__ == '__main__':