intel-extension-for-pytorch

import torch
import torch.nn as nn
from common_utils import TestCase
import unittest
from typing import Tuple
import intel_extension_for_pytorch as ipex


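# Reference attention module: the naive branch materializes the full key/value
# tensors and computes the matmul/softmax explicitly, while the indirect-access
# branch delegates to the fused IPEX kernel using externally managed caches.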
class MaskedMHA(torch.nn.Module):
    def __init__(self, hidden_size=4096, n_head=16, n_head_kv=16, head_dim=256):
        super().__init__()
        self.num_heads = n_head
        self.num_kv = n_head_kv
        self.head_dim = head_dim
        self.query_key_value = nn.Linear(
            hidden_size, (n_head_kv * 2 + n_head) * head_dim
        )

    def _split_heads(
        self, fused_qkv: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Split the last dimension into (num_heads, head_dim); the results share the
        same memory storage as `fused_qkv`.

        Args:
            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, (num_heads + kv_num * 2) * head_dim]

        Returns:
            query: [batch_size, seq_length, num_heads, head_dim]
            key: [batch_size, seq_length, kv_num, head_dim]
            value: [batch_size, seq_length, kv_num, head_dim]
        """
        bs = fused_qkv.shape[0]
        query_layer = fused_qkv[:, :, : self.num_heads * self.head_dim]
        query_layer = query_layer.view(bs, -1, self.num_heads, self.head_dim)
        key_layer = fused_qkv[
            :,
            :,
            self.num_heads
            * self.head_dim : (self.num_heads + self.num_kv)
            * self.head_dim,
        ]
        key_layer = key_layer.view(bs, -1, self.num_kv, self.head_dim)
        value_layer = fused_qkv[:, :, (self.num_heads + self.num_kv) * self.head_dim :]
        value_layer = value_layer.view(bs, -1, self.num_kv, self.head_dim)
        return query_layer, key_layer, value_layer

    def _repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        "torch.repeat_interleave(x, dim=2, repeats=n_rep)"
        bs, slen, n_kv_heads, head_dim = x.shape
        if n_rep == 1:
            return x
        return (
            x[:, :, :, None, :]
            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
        )

    def forward(
        self,
        input_t,
        key_cache,
        value_cache,
        max_position,
        attention_mask,
        beam_idx,
        indirect_access_kv_cache=False,
        offset=0,
        enable_linear=True,
    ):
        head_size = self.head_dim
        origin_type = input_t.dtype
        if enable_linear:
            query, key, value = self._split_heads(self.query_key_value(input_t))
        else:
            query, key, value = self._split_heads(input_t)
        if indirect_access_kv_cache:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()
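            # Delegate to the IPEX fused kernel: it receives the current-step
            # query/key/value, the pre-allocated caches, the beam reorder table
            # and the running offset; head_size**0.5 matches the scaling applied
            # in the naive branch below. The None positional argument is simply
            # left unset in this test.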
            return torch.ops.torch_ipex.masked_multihead_self_attention(
                query,
                key,
                value,
                key_cache,
                value_cache,
                beam_idx,
                offset,
                head_size**0.5,
                max_position,
                None,
                attention_mask,
            )
        else:
            # Get the concatenated key and value
            if key_cache is not None:
                key = torch.cat([key_cache, key], dim=1)
                value = torch.cat([value_cache, value], dim=1)
            key_cache = key
            value_cache = value
            n_rep = self.num_heads // self.num_kv
            key = self._repeat_kv(key, n_rep)
            value = self._repeat_kv(value, n_rep)

            key = key.transpose(1, 2)
            query = query.transpose(1, 2)
            value = value.transpose(1, 2)
            if origin_type == torch.half:
                key = key.to(torch.float32)
                query = query.to(torch.float32)
                value = value.to(torch.float32)
            # matmul query and key to get the attention scores
            attention_scores = torch.matmul(query, key.transpose(-1, -2))
            # scale the attention scores
            attention_scores = attention_scores / (head_size**0.5)
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask
            # softmax the attention scores
            attention_probs = attention_scores.softmax(dim=-1)
            # matmul the attention probs and value to get the context
            attention_output = torch.matmul(attention_probs, value)
            if origin_type == torch.half:
                attention_output = attention_output.to(origin_type)
            return attention_output, None, key_cache, value_cache, None


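# The tests below run a first-token (prefill) step followed by two single-token
# decode steps, comparing the naive branch against the indirect-access KV-cache
# op, checking both the attention outputs and the cache contents, and reordering
# the naive caches with beam_idx_t between decode steps to mimic beam search.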
class MaskedMHATest(TestCase):
    def _test_mha(self, torchcompile=False):
        beam_size_list = [1, 4]
        batch_size_list = [1, 2, 4]
        head_size = 256
        head_num = 16
        head_num_kv_list = [1, 4, 16]
        max_seq_len = 64
        first_seq_len = 32
        for batch_size in batch_size_list:
            for beam_size in beam_size_list:
                for head_num_kv in head_num_kv_list:
                    key_cache = None
                    value_cache = None
                    offset = 0
                    mha = MaskedMHA(
                        n_head=head_num, n_head_kv=head_num_kv, head_dim=head_size
                    )

                    if torchcompile:
                        torch._dynamo.reset()
                        ipex._set_compiler_backend("inductor")
                        mha = torch.compile(mha, backend="ipex")

                    # first token decode
                    input_t = torch.randn(
                        batch_size,
                        first_seq_len,
                        head_num * head_size,
                        dtype=torch.float32,
                    )
                    key_cache_iakv = torch.randn(
                        max_seq_len,
                        beam_size * batch_size,
                        head_num,
                        head_size,
                        dtype=torch.float32,
                    )
                    value_cache_iakv = torch.randn(
                        max_seq_len,
                        beam_size * batch_size,
                        head_num,
                        head_size,
                        dtype=torch.float32,
                    )
                    beam_idx = torch.zeros(
                        max_seq_len, beam_size * batch_size, dtype=torch.int64
                    )
                    # create attention mask and causal mask
                    attention_mask = torch.zeros(
                        batch_size, 1, first_seq_len, first_seq_len, dtype=torch.float32
                    )
                    causal_mask = torch.full(
                        (first_seq_len, first_seq_len), -1e6, dtype=input_t.dtype
                    )
                    causal_mask = causal_mask.triu(1)
                    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
                    attention_mask = (
                        attention_mask + causal_mask
                    )  # combine the attention mask and causal mask
                    # UT for first token with fp32
                    with torch.inference_mode(), torch.no_grad():
                        naive_output, _, key_cache, value_cache, _ = mha(
                            input_t, None, None, max_seq_len, attention_mask, None, None
                        )
                        (
                            indirect_access_kv_cache_output,
                            _,
                            key_cache_iakv,
                            value_cache_iakv,
                            beam_idx,
                        ) = mha(
                            input_t,
                            key_cache_iakv,
                            value_cache_iakv,
                            max_seq_len,
                            attention_mask,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        # self.assertEqual(naive_output, indirect_access_kv_cache_output)
                        key_cache = key_cache.repeat_interleave(beam_size, dim=0)
                        value_cache = value_cache.repeat_interleave(beam_size, dim=0)
                        for i in range(batch_size):
                            self.assertEqual(
                                key_cache.transpose(0, 1)[:, i * beam_size, :, :],
                                key_cache_iakv[0:first_seq_len, i * beam_size, :, :],
                            )
                            self.assertEqual(
                                value_cache.transpose(0, 1)[:, i * beam_size, :, :],
                                value_cache_iakv[0:first_seq_len, i * beam_size, :, :],
                            )
                        if beam_size == 4:
                            beam_idx_t = torch.zeros(
                                beam_size * batch_size, dtype=torch.int64
                            )
                            for i in range(1, batch_size):
                                beam_idx_t[
                                    i * beam_size : i * beam_size + beam_size
                                ] = (
                                    beam_idx_t[
                                        i * beam_size : i * beam_size + beam_size
                                    ]
                                    + i * beam_size
                                )
                        elif beam_size == 1:
                            beam_idx_t = torch.arange(batch_size)
                        beam_idx[offset] = beam_idx_t
                        # reorder cache for naive implementation
                        key_cache = torch.index_select(key_cache, 0, beam_idx_t)
                        value_cache = torch.index_select(value_cache, 0, beam_idx_t)

                    # UT for first token with bf16
                    input_t_bf16 = input_t.bfloat16()
                    key_cache_iakv_bf16 = key_cache_iakv.bfloat16()
                    value_cache_iakv_bf16 = value_cache_iakv.bfloat16()
                    attention_mask_bf16 = attention_mask.bfloat16()
                    with torch.inference_mode(), torch.no_grad(), torch.autocast(
                        device_type="cpu",
                        enabled=True,
                        dtype=torch.bfloat16,
                    ):
                        naive_output_bf16, _, key_cache_bf16, value_cache_bf16, _ = mha(
                            input_t_bf16,
                            None,
                            None,
                            max_seq_len,
                            attention_mask_bf16,
                            None,
                            None,
                        )
                        (
                            indirect_access_kv_cache_output_bf16,
                            _,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            beam_idx,
                        ) = mha(
                            input_t_bf16,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            max_seq_len,
                            attention_mask_bf16,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        self.assertEqual(
                            naive_output_bf16,
                            indirect_access_kv_cache_output_bf16,
                            prec=2e-2,
                        )
                        key_cache_bf16 = key_cache_bf16.repeat_interleave(
                            beam_size, dim=0
                        )
                        value_cache_bf16 = value_cache_bf16.repeat_interleave(
                            beam_size, dim=0
                        )
                        for i in range(batch_size):
                            self.assertEqual(
                                key_cache_bf16.transpose(0, 1)[:, i * beam_size, :, :],
                                key_cache_iakv_bf16[
                                    0:first_seq_len, i * beam_size, :, :
                                ],
                            )
                            self.assertEqual(
                                value_cache_bf16.transpose(0, 1)[
                                    :, i * beam_size, :, :
                                ],
                                value_cache_iakv_bf16[
                                    0:first_seq_len, i * beam_size, :, :
                                ],
                            )
                        key_cache_bf16 = torch.index_select(
                            key_cache_bf16, 0, beam_idx_t
                        )
                        value_cache_bf16 = torch.index_select(
                            value_cache_bf16, 0, beam_idx_t
                        )

                    offset = offset + first_seq_len
                    # UT for next token with fp32
                    input_t = torch.randn(
                        beam_size * batch_size,
                        1,
                        head_num * head_size,
                        dtype=torch.float32,
                    )
                    attention_mask = torch.zeros(
                        beam_size * batch_size, 1, 1, offset + 1, dtype=torch.float32
                    )
                    with torch.inference_mode(), torch.no_grad():
                        naive_output, _, key_cache, value_cache, _ = mha(
                            input_t,
                            key_cache,
                            value_cache,
                            max_seq_len,
                            attention_mask,
                            None,
                            None,
                        )
                        (
                            indirect_access_kv_cache_output,
                            _,
                            key_cache_iakv,
                            value_cache_iakv,
                            beam_idx,
                        ) = mha(
                            input_t,
                            key_cache_iakv,
                            value_cache_iakv,
                            max_seq_len,
                            attention_mask,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        self.assertEqual(naive_output, indirect_access_kv_cache_output)
                        self.assertEqual(
                            key_cache.transpose(0, 1)[offset],
                            key_cache_iakv[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache.transpose(0, 1)[offset],
                            value_cache_iakv[offset, :, :, :],
                        )
                    # UT for next token with bf16
                    input_t_bf16 = input_t.bfloat16()
                    attention_mask_bf16 = attention_mask.bfloat16()
                    with torch.inference_mode(), torch.no_grad(), torch.autocast(
                        device_type="cpu",
                        enabled=True,
                        dtype=torch.bfloat16,
                    ):
                        naive_output_bf16, _, key_cache_bf16, value_cache_bf16, _ = mha(
                            input_t_bf16,
                            key_cache_bf16,
                            value_cache_bf16,
                            max_seq_len,
                            attention_mask_bf16,
                            None,
                            None,
                        )
                        (
                            indirect_access_kv_cache_output_bf16,
                            _,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            beam_idx,
                        ) = mha(
                            input_t_bf16,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            max_seq_len,
                            attention_mask_bf16,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        self.assertEqual(
                            naive_output_bf16,
                            indirect_access_kv_cache_output_bf16,
                            prec=0.05,
                        )
                        self.assertEqual(
                            key_cache_bf16.transpose(0, 1)[offset],
                            key_cache_iakv_bf16[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache_bf16.transpose(0, 1)[offset],
                            value_cache_iakv_bf16[offset, :, :, :],
                        )
                        if beam_size == 4:
                            beam_idx_t = torch.tensor([1, 3, 0, 0]).repeat(batch_size)
                            for i in range(1, batch_size):
                                beam_idx_t[
                                    i * beam_size : i * beam_size + beam_size
                                ] = (
                                    beam_idx_t[
                                        i * beam_size : i * beam_size + beam_size
                                    ]
                                    + i * beam_size
                                )
                        elif beam_size == 1:
                            beam_idx_t = torch.arange(batch_size)
                        beam_idx[offset] = beam_idx_t
                        offset = offset + 1
                        # reorder cache for naive implementation
                        key_cache = torch.index_select(key_cache, 0, beam_idx_t)
                        value_cache = torch.index_select(value_cache, 0, beam_idx_t)
                        key_cache_bf16 = torch.index_select(
                            key_cache_bf16, 0, beam_idx_t
                        )
                        value_cache_bf16 = torch.index_select(
                            value_cache_bf16, 0, beam_idx_t
                        )
                    # UT for next token with fp32
                    input_t = torch.randn(
                        beam_size * batch_size,
                        1,
                        head_num * head_size,
                        dtype=torch.float32,
                    )
                    attention_mask = torch.zeros(
                        beam_size * batch_size, 1, 1, offset + 1, dtype=torch.float32
                    )
                    with torch.inference_mode(), torch.no_grad():
                        naive_output, _, key_cache, value_cache, _ = mha(
                            input_t,
                            key_cache,
                            value_cache,
                            max_seq_len,
                            attention_mask,
                            None,
                            None,
                        )
                        (
                            indirect_access_kv_cache_output,
                            _,
                            key_cache_iakv,
                            value_cache_iakv,
                            beam_idx,
                        ) = mha(
                            input_t,
                            key_cache_iakv,
                            value_cache_iakv,
                            max_seq_len,
                            attention_mask,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        self.assertEqual(naive_output, indirect_access_kv_cache_output)
                        self.assertEqual(
                            key_cache.transpose(0, 1)[offset],
                            key_cache_iakv[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache.transpose(0, 1)[offset],
                            value_cache_iakv[offset, :, :, :],
                        )
                    # UT for next token with bf16
                    input_t_bf16 = input_t.bfloat16()
                    attention_mask_bf16 = attention_mask.bfloat16()
                    with torch.inference_mode(), torch.no_grad(), torch.autocast(
                        device_type="cpu",
                        enabled=True,
                        dtype=torch.bfloat16,
                    ):
                        naive_output_bf16, _, key_cache_bf16, value_cache_bf16, _ = mha(
                            input_t_bf16,
                            key_cache_bf16,
                            value_cache_bf16,
                            max_seq_len,
                            attention_mask_bf16,
                            None,
                            None,
                        )
                        (
                            indirect_access_kv_cache_output_bf16,
                            _,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            beam_idx,
                        ) = mha(
                            input_t_bf16,
                            key_cache_iakv_bf16,
                            value_cache_iakv_bf16,
                            max_seq_len,
                            attention_mask_bf16,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                        )
                        self.assertEqual(
                            naive_output_bf16,
                            indirect_access_kv_cache_output_bf16,
                            prec=0.05,
                        )
                        self.assertEqual(
                            key_cache_bf16.transpose(0, 1)[offset],
                            key_cache_iakv_bf16[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache_bf16.transpose(0, 1)[offset],
                            value_cache_iakv_bf16[offset, :, :, :],
                        )

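    # Same flow as _test_mha, but in fp16 and with enable_linear=False, so the
    # input tensor is already the fused (q + 2*kv) projection and the nn.Linear
    # layer is bypassed.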
    def _test_mha_fp16(self, torchcompile=False):
        beam_size_list = [1, 4]
        batch_size_list = [1, 2, 4]
        head_size = 256
        head_num = 16
        head_num_kv_list = [1, 4, 16]
        max_seq_len = 64
        first_seq_len = 32
        for batch_size in batch_size_list:
            for beam_size in beam_size_list:
                for head_num_kv in head_num_kv_list:
                    offset = 0
                    mha = MaskedMHA(
                        n_head=head_num, n_head_kv=head_num_kv, head_dim=head_size
                    )

                    if torchcompile:
                        torch._dynamo.reset()
                        ipex._set_compiler_backend("inductor")
                        mha = torch.compile(mha, backend="ipex")

                    # first token decode
                    input_t = torch.randn(
                        batch_size,
                        first_seq_len,
                        (head_num + 2 * head_num_kv) * head_size,
                        dtype=torch.float32,
                    )
                    key_cache_iakv = torch.randn(
                        max_seq_len,
                        beam_size * batch_size,
                        head_num,
                        head_size,
                        dtype=torch.float32,
                    )
                    value_cache_iakv = torch.randn(
                        max_seq_len,
                        beam_size * batch_size,
                        head_num,
                        head_size,
                        dtype=torch.float32,
                    )
                    beam_idx = torch.zeros(
                        max_seq_len, beam_size * batch_size, dtype=torch.int64
                    )
                    # create attention mask and causal mask
                    attention_mask = torch.zeros(
                        batch_size, 1, first_seq_len, first_seq_len, dtype=torch.float32
                    )
                    causal_mask = torch.full(
                        (first_seq_len, first_seq_len), -1e6, dtype=input_t.dtype
                    )
                    causal_mask = causal_mask.triu(1)
                    causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
                    attention_mask = (
                        attention_mask + causal_mask
                    )  # combine the attention mask and causal mask
                    if beam_size == 4:
                        beam_idx_t = torch.zeros(
                            beam_size * batch_size, dtype=torch.int64
                        )
                        for i in range(1, batch_size):
                            beam_idx_t[i * beam_size : i * beam_size + beam_size] = (
                                beam_idx_t[i * beam_size : i * beam_size + beam_size]
                                + i * beam_size
                            )
                    elif beam_size == 1:
                        beam_idx_t = torch.arange(batch_size)
                    beam_idx[offset] = beam_idx_t
                    # UT for first token with fp16
                    input_t_half = input_t.half()
                    key_cache_iakv_half = key_cache_iakv.half()
                    value_cache_iakv_half = value_cache_iakv.half()
                    attention_mask_half = attention_mask.half()
                    with torch.inference_mode(), torch.no_grad():
                        naive_output_half, _, key_cache_half, value_cache_half, _ = mha(
                            input_t_half,
                            None,
                            None,
                            max_seq_len,
                            attention_mask_half,
                            None,
                            None,
                            enable_linear=False,
                        )
                        (
                            indirect_access_kv_cache_output_half,
                            _,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            beam_idx,
                        ) = mha(
                            input_t_half,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            max_seq_len,
                            attention_mask_half,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                            enable_linear=False,
                        )
                        self.assertEqual(
                            naive_output_half,
                            indirect_access_kv_cache_output_half,
                            prec=2e-2,
                        )
                        key_cache_half = key_cache_half.repeat_interleave(
                            beam_size, dim=0
                        )
                        value_cache_half = value_cache_half.repeat_interleave(
                            beam_size, dim=0
                        )
                        for i in range(batch_size):
                            self.assertEqual(
                                key_cache_half.transpose(0, 1)[:, i * beam_size, :, :],
                                key_cache_iakv_half[
                                    0:first_seq_len, i * beam_size, :, :
                                ],
                            )
                            self.assertEqual(
                                value_cache_half.transpose(0, 1)[
                                    :, i * beam_size, :, :
                                ],
                                value_cache_iakv_half[
                                    0:first_seq_len, i * beam_size, :, :
                                ],
                            )
                        key_cache_half = torch.index_select(
                            key_cache_half, 0, beam_idx_t
                        )
                        value_cache_half = torch.index_select(
                            value_cache_half, 0, beam_idx_t
                        )

                    offset = offset + first_seq_len
                    # UT for next token: create fp32 inputs, converted to fp16 below
                    input_t = torch.randn(
                        beam_size * batch_size,
                        1,
                        (head_num + 2 * head_num_kv) * head_size,
                        dtype=torch.float32,
                    )
                    attention_mask = torch.zeros(
                        beam_size * batch_size, 1, 1, offset + 1, dtype=torch.float32
                    )
                    # UT for next token with fp16
                    input_t_half = input_t.half()
                    attention_mask_half = attention_mask.half()
                    with torch.inference_mode(), torch.no_grad():
                        naive_output_half, _, key_cache_half, value_cache_half, _ = mha(
                            input_t_half,
                            key_cache_half,
                            value_cache_half,
                            max_seq_len,
                            attention_mask_half,
                            None,
                            None,
                            enable_linear=False,
                        )
                        (
                            indirect_access_kv_cache_output_half,
                            _,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            beam_idx,
                        ) = mha(
                            input_t_half,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            max_seq_len,
                            attention_mask_half,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                            enable_linear=False,
                        )
                        self.assertEqual(
                            naive_output_half,
                            indirect_access_kv_cache_output_half,
                            prec=0.05,
                        )
                        self.assertEqual(
                            key_cache_half.transpose(0, 1)[offset],
                            key_cache_iakv_half[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache_half.transpose(0, 1)[offset],
                            value_cache_iakv_half[offset, :, :, :],
                        )
                        if beam_size == 4:
                            beam_idx_t = torch.tensor([1, 3, 0, 0]).repeat(batch_size)
                            for i in range(1, batch_size):
                                beam_idx_t[
                                    i * beam_size : i * beam_size + beam_size
                                ] = (
                                    beam_idx_t[
                                        i * beam_size : i * beam_size + beam_size
                                    ]
                                    + i * beam_size
                                )
                        elif beam_size == 1:
                            beam_idx_t = torch.arange(batch_size)
                        beam_idx[offset] = beam_idx_t
                        offset = offset + 1
                        key_cache_half = torch.index_select(
                            key_cache_half, 0, beam_idx_t
                        )
                        value_cache_half = torch.index_select(
                            value_cache_half, 0, beam_idx_t
                        )
                    # UT for next token: create fp32 inputs, converted to fp16 below
                    input_t = torch.randn(
                        beam_size * batch_size,
                        1,
                        (head_num + 2 * head_num_kv) * head_size,
                        dtype=torch.float32,
                    )
                    attention_mask = torch.zeros(
                        beam_size * batch_size, 1, 1, offset + 1, dtype=torch.float32
                    )
                    # UT for next token with fp16
                    input_t_half = input_t.half()
                    attention_mask_half = attention_mask.half()
                    with torch.inference_mode(), torch.no_grad():
                        naive_output_half, _, key_cache_half, value_cache_half, _ = mha(
                            input_t_half,
                            key_cache_half,
                            value_cache_half,
                            max_seq_len,
                            attention_mask_half,
                            None,
                            None,
                            enable_linear=False,
                        )
                        (
                            indirect_access_kv_cache_output_half,
                            _,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            beam_idx,
                        ) = mha(
                            input_t_half,
                            key_cache_iakv_half,
                            value_cache_iakv_half,
                            max_seq_len,
                            attention_mask_half,
                            beam_idx,
                            True,
                            torch.tensor(offset),
                            enable_linear=False,
                        )
                        self.assertEqual(
                            naive_output_half,
                            indirect_access_kv_cache_output_half,
                            prec=0.05,
                        )
                        self.assertEqual(
                            key_cache_half.transpose(0, 1)[offset],
                            key_cache_iakv_half[offset, :, :, :],
                        )
                        self.assertEqual(
                            value_cache_half.transpose(0, 1)[offset],
                            value_cache_iakv_half[offset, :, :, :],
                        )

    def test_mha(self):
        self._test_mha(torchcompile=False)
        self._test_mha_fp16(torchcompile=False)


if __name__ == "__main__":
    test = unittest.main()
