# Owner(s): ["module: nn"]
import math
import copy

import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    onlyCUDA,
    skipMeta,
)
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_WITH_ROCM

class TestMHADeviceType(TestCase):
    @torch.no_grad()
    def _test_transform_bias_rescale_qkv_impl(
        self, device, dtype, use_nt, use_padding=False
    ):
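        # Each test case is (embed_dim, num_heads, batch_size, seq_len).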
        tests = [
            (64, 4, 16, 8),
            # dim_per_head = 12 does not divide evenly by CPU vectorization length of 8
            (24, 2, 4, 2),
            # Make sure CUDA can handle small input sizes
            (2, 2, 2, 2),
            # dim_per_head = 6 does not divide evenly by CUDA vectorization length of 4,
            # causes alignment issues
            (24, 4, 4, 2),
            (48, 4, 16, 8),
        ]
        for (embed_dim, num_heads, bs, sl) in tests:
            with self.subTest(embed_dim=embed_dim, num_heads=num_heads, bs=bs, sl=sl):
                torch.manual_seed(9343)
                dense_x = x = (
                    torch.randn(bs, sl, 3 * embed_dim, device=device, dtype=dtype) * 10
                )
                if use_padding:
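                    # Fill the last position of the first batch element with
                    # -inf to simulate padding; for nested-tensor inputs that
                    # position is dropped from x and zeroed in the dense
                    # reference below.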
                    x[0][-1] = torch.full(x[0][-1].shape, float("-Inf"))
                if use_nt:
                    xs = list(torch.unbind(x))
                    if use_padding:
                        xs[0] = xs[0][:-1]
                    x = torch.nested.nested_tensor(xs, device=device, dtype=dtype)
                qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype)

                # We have to use inference_mode here because q/k/v are
                # all views of the same Tensor, which autograd doesn't
                # like. This is fine because this function is only
                # exposed to Python for purposes of writing this test.
                with torch.inference_mode():
                    (q, k, v) = torch._transform_bias_rescale_qkv(
                        x, qkv.bias, num_heads=num_heads
                    )

                    def simple_transform_bias_rescale_qkv(qkv, bias):
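                        # Dense reference for torch._transform_bias_rescale_qkv:
                        # split the packed QKV projection, add the bias, reshape
                        # to (bs, num_heads, seq_len, dim_per_head), and
                        # pre-scale q by 1 / sqrt(dim_per_head).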
                        (q, k, v) = torch.split(qkv, embed_dim, dim=-1)
                        (q_bias, k_bias, v_bias) = torch.split(bias, embed_dim, dim=-1)

                        def embiggen(x):
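                            # For nested-tensor inputs, zero-pad the dense
                            # reference's sequence dimension up to a multiple of
                            # 8 to match the padded layout of the native
                            # kernel's output; dense inputs pass through as-is.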
                            if not use_nt:
                                return x
                            b, t, d = x.size()
                            t = t + (8 - t % 8) % 8
                            newsize = (b, t, d)
                            new_x = torch.zeros(newsize, device=device, dtype=dtype)
                            new_x[:x.size()[0], :x.size()[1], :x.size()[2]] = x
                            return new_x
                        return tuple(
                            embiggen(x).reshape(
                                (bs, -1, num_heads, embed_dim // num_heads)
                            ).transpose(2, 1)
                            for x in (
                                (q + q_bias) / math.sqrt(embed_dim // num_heads),
                                (k + k_bias),
                                (v + v_bias),
                            )
                        )

                    correct_q, correct_k, correct_v = simple_transform_bias_rescale_qkv(
                        dense_x, qkv.bias
                    )
                    if use_nt and use_padding:
                        for t in (correct_q, correct_k, correct_v):
                            t[t == float("-Inf")] = 0

                self.assertEqual(q.size(), correct_q.size())
                torch.testing.assert_close(q, correct_q)
                torch.testing.assert_close(k, correct_k)
                torch.testing.assert_close(v, correct_v)

    @dtypesIfCUDA(torch.float)
    @dtypes(torch.float)
    @skipMeta
    def test_transform_bias_rescale_qkv(self, device, dtype):
        for use_padding in (False, True):
            with self.subTest(use_padding=use_padding):
                self._test_transform_bias_rescale_qkv_impl(
                    device, dtype, use_nt=False, use_padding=use_padding
                )

    @dtypesIfCUDA(torch.float)
    @dtypes(torch.float)
    @skipMeta
    @onlyCUDA
    def test_transform_bias_rescale_qkv_nested(self, device, dtype):
        for use_padding in (False, True):
            with self.subTest(use_padding=use_padding):
                self._test_transform_bias_rescale_qkv_impl(
                    device, dtype, use_nt=True, use_padding=use_padding
                )

    def _test_multihead_attention_impl(
        self, device, dtype, mode, use_nt, need_weights, average_attn_weights, use_padding=False, pad_all=False
    ):
        embed_dim = 64
        num_heads = 4
        bs = 16
        sl = 8

        q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
        if use_padding:
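            # Zero the "padded" positions of q and build a boolean
            # key_padding_mask of shape (bs, sl) in which True marks positions
            # attention should ignore (the last position of every sequence when
            # pad_all, otherwise only the last position of the first sequence).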
            if pad_all:
                for q_i in q:
                    q_i[-1] = torch.zeros_like(q[0][-1], device=device, dtype=torch.float32)
                mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
                for mask_i in mask:
                    mask_i[-1] = True
            else:
                q[0][-1] = torch.zeros_like(q[0][-1], device=device, dtype=torch.float32)
                mask = torch.zeros(q.shape[:-1], device=device, dtype=torch.bool)
                mask[0][-1] = True
        if mode == "self":
            k = q
            v = q
        elif mode == "encdec":
            k = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
            v = k
        elif mode == "generic":
            k = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
            v = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
        else:
            self.fail(f"invalid mode `{mode}`!")

        qkv = torch.nn.Linear(embed_dim, 3 * embed_dim, device=device, dtype=torch.float32)
        native_qkv = copy.deepcopy(qkv).to(dtype=dtype)

        proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32)
        native_proj = copy.deepcopy(proj).to(dtype=dtype)

        pt = torch.nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True, device=device, dtype=torch.float32
        )

        pt.in_proj_weight = qkv.weight
        pt.in_proj_bias = qkv.bias
        pt.out_proj.weight = proj.weight
        pt.out_proj.bias = proj.bias

        class NativeMHA(torch.nn.Module):
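            # Thin wrapper around the private torch._native_multi_head_attention
            # op, driven by copies of the same qkv / out-projection weights as
            # the reference nn.MultiheadAttention above so that the two paths
            # are directly comparable.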
            def __init__(self, embed_dim, num_heads, qkv, proj):
                super().__init__()
                self.qkv = qkv
                self.proj = proj
                self.embed_dim = embed_dim
                self.num_heads = num_heads

            def forward(self, q, k, v, key_padding_mask):
                return torch._native_multi_head_attention(
                    q,
                    k,
                    v,
                    self.embed_dim,
                    self.num_heads,
                    self.qkv.weight,
                    self.qkv.bias,
                    self.proj.weight,
                    self.proj.bias,
                    key_padding_mask,
                    need_weights=need_weights,
                    average_attn_weights=average_attn_weights,
                    mask_type=1,   # mask_type = 1 => src_key_padding_mask, mask_type = 0 => src_mask
                )

        npt = NativeMHA(
            embed_dim=embed_dim, num_heads=num_heads, qkv=native_qkv, proj=native_proj
        ).to(dtype)

        if device == "cuda":
            pt = pt.cuda()
            npt = npt.cuda()

        ypt, weight_pt = pt(
            q,
            k,
            v,
            need_weights=need_weights,
            average_attn_weights=average_attn_weights,
            key_padding_mask=mask if use_padding else None,
        )
        if use_nt:
            qs = list(torch.unbind(q))
            if use_padding:
                if pad_all:
                    qs = [x[:-1] for x in qs]
                else:
                    qs[0] = qs[0][:-1]
            q = torch.nested.nested_tensor(qs, device=device, dtype=dtype)
            if mode == "self":
                k = v = q
            elif mode == "encdec":
                k = torch.nested.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
                v = k
            else:
                k = torch.nested.nested_tensor(torch.unbind(k), device=device, dtype=dtype)
                v = torch.nested.nested_tensor(torch.unbind(v), device=device, dtype=dtype)

        native_q = q.to(dtype=dtype)
        native_k = k.to(dtype=dtype)
        native_v = v.to(dtype=dtype)

        ynpt, weight_npt = npt(
            native_q, native_k, native_v, key_padding_mask=mask if use_padding and not use_nt else None
        )
        if use_nt:
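            # Convert the nested output back to a zero-padded dense tensor so
            # it can be compared elementwise against the reference output; when
            # every sequence was truncated (pad_all), re-embed it into a tensor
            # of ypt's full shape first.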
            ynpt = ynpt.to_padded_tensor(0)
            if pad_all:
                ynpt_final = torch.zeros_like(ypt)
                ynpt_final[:, :ynpt.shape[1], :] = ynpt
                ynpt = ynpt_final

        def do_pad_all(tensors):
            for t in tensors:
                for t_i in t:
                    t_i[-1] = torch.zeros_like(t_i[-1], device=device, dtype=dtype)

        # PyTorch implementation returns non-zero junk in the padding
        # locations; overwrite it so that the comparison works out.
        if use_padding:
            ypt[0][-1] = torch.zeros_like(ypt[0][-1], device=device, dtype=dtype)
            ynpt[0][-1] = torch.zeros_like(ynpt[0][-1], device=device, dtype=dtype)
            if pad_all:
                do_pad_all((ypt, ynpt))
            # Zero the last row of each TxT weight matrix
            if need_weights:
                if average_attn_weights:
                    weight_pt[0][-1] = torch.zeros_like(weight_pt[0][-1], device=device, dtype=dtype)
                    weight_npt[0][-1] = torch.zeros_like(weight_npt[0][-1], device=device, dtype=dtype)
                    if pad_all:
                        do_pad_all((weight_pt, weight_npt))
                else:
                    for nh in range(num_heads):
                        weight_pt[0][nh][-1] = torch.zeros_like(weight_pt[0][nh][-1], device=device, dtype=dtype)
                        weight_npt[0][nh][-1] = torch.zeros_like(weight_npt[0][nh][-1], device=device, dtype=dtype)

        if dtype == torch.half:
            torch.testing.assert_close(ypt, ynpt.to(torch.float32), atol=1e-3, rtol=1e-3)
        else:
            # High rtol seems necessary for
            # test_native_multihead_attention_cpu_float32 on Windows,
            # otherwise 2e-4 would likely be fine.
            torch.testing.assert_close(ypt, ynpt, atol=2e-5, rtol=2e-3)

        if need_weights:
            torch.testing.assert_close(weight_pt, weight_npt.to(torch.float32), atol=5e-4, rtol=5e-4)
        else:
            self.assertEqual(weight_pt, weight_npt)

    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @parametrize("use_nt", [False, True])
    @parametrize("use_padding, pad_all", [(False, False), (True, False), (True, True)])
    @parametrize("need_weights", [False])
    @parametrize("average_attn_weights", [False, True])
    @parametrize("fused", [False, True])
    @torch.no_grad()
    def test_native_multihead_self_attention(self, device, dtype, use_nt,
                                             need_weights, average_attn_weights, use_padding, pad_all, fused):
        if TEST_WITH_ROCM and use_nt:
            self.skipTest("ROCM does not support nested tensors for Flash Attention for now.")
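        # The parametrized need_weights is always False; loop here so the
        # need_weights=True case is also exercised whenever pad_all is False.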
        for need_weights in (False, not pad_all):
            with self.subTest(use_padding=use_padding, pad_all=pad_all,
                              use_nt=use_nt, need_weights=need_weights,
                              average_attn_weights=average_attn_weights):
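                # fused=False disables the flash and memory-efficient SDPA
                # backends (leaving only the math fallback); fused=True keeps
                # both fused backends enabled.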
                with torch.backends.cuda.sdp_kernel(
                        enable_flash=False, enable_mem_efficient=False
                ) if not fused else torch.backends.cuda.sdp_kernel(
                        enable_flash=True, enable_mem_efficient=True
                ):
                    self._test_multihead_attention_impl(
                        device,
                        dtype,
                        "self",
                        use_nt=use_nt,
                        use_padding=use_padding,
                        pad_all=pad_all,
                        need_weights=need_weights,
                        average_attn_weights=average_attn_weights,
                    )

    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @torch.no_grad()
    def test_native_multihead_encoder_decoder_attention(self, device, dtype):
        self._test_multihead_attention_impl(
            device,
            dtype,
            "encdec",
            use_nt=False,
            need_weights=False,
            average_attn_weights=False,
        )

    @dtypesIfCUDA(torch.float, torch.half)
    @dtypes(torch.float)
    @skipMeta
    @torch.no_grad()
    def test_native_multihead_attention(self, device, dtype):
        self._test_multihead_attention_impl(
            device,
            dtype,
            "generic",
            use_nt=False,
            need_weights=False,
            average_attn_weights=False,
        )


instantiate_device_type_tests(TestMHADeviceType, globals())
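# The call above registers per-device test classes (e.g. TestMHADeviceTypeCPU,
# TestMHADeviceTypeCUDA) in this module's globals for run_tests() to pick up.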

if __name__ == "__main__":
    run_tests()