intel-extension-for-pytorch / test_paged_attention.py
305 lines · 10.3 KB
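# Correctness tests for the paged-attention custom ops exposed through
# torch.ops.torch_ipex (single_query_cached_kv_attention and reshape_and_cache),
# checked on CPU against pure-PyTorch reference implementations.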
import torch
from common_utils import TestCase
import unittest
import random
from typing import List, Optional, Tuple
from itertools import product


class PagedAttentionTest(TestCase):
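    # Build one key cache and one value cache per layer, shaped
    # (num_blocks, block_size, num_head, head_size), filled with uniform
    # random values in [-scale, scale] where scale = head_size ** -0.5.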
    def create_kv_caches(
        self,
        num_blocks: int,
        block_size: int,
        num_layer: int,
        num_head: int,
        head_size: int,
        dtype: torch.dtype,
        seed: int,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        torch.random.manual_seed(seed)
        torch.manual_seed(seed)

        scale = head_size**-0.5
        key_cache_shape = (num_blocks, block_size, num_head, head_size)
        key_caches = []
        for _ in range(num_layer):
            key_cache = torch.empty(size=key_cache_shape, dtype=dtype)
            key_cache.uniform_(-scale, scale)
            key_caches.append(key_cache)

        value_cache_shape = (num_blocks, block_size, num_head, head_size)
        value_caches = []
        for _ in range(num_layer):
            value_cache = torch.empty(size=value_cache_shape, dtype=dtype)
            value_cache.uniform_(-scale, scale)
            value_caches.append(value_cache)
        return key_caches, value_caches

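    # Reference attention for a single query position: scaled QK^T per head,
    # optional additive mask (e.g. ALiBi bias), softmax, then the weighted
    # sum over the values.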
    def ref_masked_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        scale: float,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
        if attn_mask is not None:
            attn_weights = attn_weights + attn_mask.float()
        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
        return out

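    # Reference implementation of paged single-query attention: for each
    # sequence, gather its keys/values token by token through the block table,
    # repeat KV heads for MQA/GQA, and run ref_masked_attention.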
    def ref_single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        num_queries_per_kv: int,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
    ) -> None:
        num_query_heads = query.shape[1]
        num_kv_head = value_cache.shape[2]
        head_size = value_cache.shape[3]
        block_size = value_cache.shape[1]
        num_seqs = query.shape[0]

        block_tables = block_tables.cpu().tolist()
        context_lens = context_lens.cpu().tolist()
        for i in range(num_seqs):
            q = query[i].unsqueeze(0)
            block_table = block_tables[i]
            context_len = int(context_lens[i])

            keys = []
            values = []
            for j in range(context_len):
                block_number = int(block_table[j // block_size])
                block_offset = j % block_size

                k = key_cache[block_number, block_offset, :, :]
                k = k.reshape(num_kv_head, head_size)
                keys.append(k)

                v = value_cache[block_number, block_offset, :, :]
                values.append(v)
            keys = torch.stack(keys, dim=0)
            values = torch.stack(values, dim=0)
            if num_queries_per_kv > 1:
                # Handle MQA and GQA
                keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
                values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
            alibi_bias = None
            if alibi_slopes is not None:
                # Create the ALiBi bias used in the paged attention kernel.
                position_ids = torch.arange(context_len, device="cpu").int()
                alibi_bias = (position_ids - context_len + 1).float()
                alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)

            out = self.ref_masked_attention(q, keys, values, scale, alibi_bias)
            out = out.view(num_query_heads, head_size)
            output[i].copy_(out, non_blocking=True)

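    # Run torch.ops.torch_ipex.single_query_cached_kv_attention on one random
    # configuration and compare it against the reference implementation.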
    def _test_paged_attention_func(
        self,
        num_seqs: int,
        num_head: Tuple[int, int],
        head_size: int,
        use_alibi: bool,
        num_blocks: int,
        block_size: int,
        dtype: torch.dtype,
        seed: int,
    ) -> None:
        random.seed(seed)
        torch.random.manual_seed(seed)
        torch.manual_seed(seed)
        max_seq_len = 1024
        scale = float(1.0 / (head_size**0.5))
        num_query_heads, num_kv_head = num_head
        query = torch.empty(
            num_seqs, num_query_heads, head_size, dtype=dtype, device="cpu"
        )
        query.uniform_(-scale, scale)
        assert num_query_heads % num_kv_head == 0
        num_queries_per_kv = num_query_heads // num_kv_head
        head_mapping = torch.repeat_interleave(
            torch.arange(num_kv_head, dtype=torch.int32, device="cpu"),
            num_queries_per_kv,
        )
        alibi_slopes = None
        if use_alibi:
            alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device="cpu")

        context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)]
        context_lens[-1] = max_seq_len
        max_context_len = max(context_lens)
        context_lens = torch.tensor(context_lens, dtype=torch.int, device="cpu")

        # Create the block tables.
        max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
        block_tables = []
        for _ in range(num_seqs):
            block_table = [
                random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq)
            ]
            block_tables.append(block_table)
        block_tables = torch.tensor(block_tables, dtype=torch.int, device="cpu")

        # Create the KV caches.
        key_caches, value_caches = self.create_kv_caches(
            num_blocks, block_size, 1, num_kv_head, head_size, dtype, seed
        )
        key_cache, value_cache = key_caches[0], value_caches[0]
        # Call the paged attention kernel.
        output = torch.empty_like(query)
        torch.ops.torch_ipex.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

        # Run the reference implementation.
        ref_output = torch.empty_like(query)
        self.ref_single_query_cached_kv_attention(
            ref_output,
            query,
            num_queries_per_kv,
            key_cache,
            value_cache,
            block_tables,
            context_lens,
            scale,
            alibi_slopes,
        )
        assert torch.allclose(output, ref_output, atol=5e-3, rtol=1e-3)

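    # Sweep sequence counts, MHA/GQA head layouts, head sizes, ALiBi on/off,
    # block sizes, dtypes, and seeds.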
    def test_paged_attention(self):
        num_blocks = 128
        dtypes = [torch.bfloat16, torch.float]
        num_gen_seqs = [7]  # Arbitrary values for testing
        num_heads = [(40, 40), (64, 16)]  # Arbitrary values for testing
        head_sizes = [64, 80, 96, 112, 128, 256]
        block_sizes = [16, 32]
        use_alibis = [True, False]
        seeds = [0]
        for (
            num_seqs,
            num_head,
            head_size,
            use_alibi,
            block_size,
            dtype,
            seed,
        ) in product(
            num_gen_seqs,
            num_heads,
            head_sizes,
            use_alibis,
            block_sizes,
            dtypes,
            seeds,
        ):
            self._test_paged_attention_func(
                num_seqs,
                num_head,
                head_size,
                use_alibi,
                num_blocks,
                block_size,
                dtype,
                seed,
            )

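    # Scatter new key/value tokens into the caches via
    # torch.ops.torch_ipex.reshape_and_cache and compare against a manual
    # scatter into cloned caches.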
    def _test_reshape_and_cache_func(
        self,
        num_token: int,
        num_head: int,
        head_size: int,
        block_size: int,
        num_blocks: int,
        dtype: torch.dtype,
        seed: int,
    ) -> None:
        random.seed(seed)
        torch.random.manual_seed(seed)
        torch.manual_seed(seed)

        # Create a random slot mapping.
        num_slots = block_size * num_blocks
        slot_mapping = random.sample(range(num_slots), num_token)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cpu")

        qkv = torch.randn(num_token, 3, num_head, head_size, dtype=dtype, device="cpu")
        _, key, value = qkv.unbind(dim=1)
        # Create the KV caches.
        key_caches, value_caches = self.create_kv_caches(
            num_blocks, block_size, 1, num_head, head_size, dtype, seed
        )
        key_cache, value_cache = key_caches[0], value_caches[0]
        # Clone the KV caches.
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

        # Call the reshape_and_cache kernel.
        torch.ops.torch_ipex.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping
        )

        # Run the reference implementation.
        block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
        block_indices = block_indices.cpu().tolist()
        block_offsets = slot_mapping % block_size
        block_offsets = block_offsets.cpu().tolist()
        for i in range(num_token):
            block_idx = block_indices[i]
            block_offset = block_offsets[i]
            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
            cloned_value_cache[block_idx, block_offset, :, :] = value[i]

        assert torch.allclose(key_cache, cloned_key_cache)
        assert torch.allclose(value_cache, cloned_value_cache)

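    # Sweep token counts, KV head counts, head sizes, block sizes, and dtypes
    # for the reshape_and_cache check.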
    def test_reshape_and_cache(self):
        num_blocks = 128  # Arbitrary values for testing
        num_tokens = [1, 83, 1024]  # Arbitrary values for testing
        num_kv_heads = [8]  # Arbitrary values for testing
        head_sizes = [64, 80, 96, 112, 128, 256]
        block_sizes = [16, 32]
        dtypes = [torch.bfloat16, torch.float]
        seeds = [0]
        for (
            num_token,
            num_kv_head,
            head_size,
            block_size,
            dtype,
            seed,
        ) in product(
            num_tokens,
            num_kv_heads,
            head_sizes,
            block_sizes,
            dtypes,
            seeds,
        ):
            self._test_reshape_and_cache_func(
                num_token, num_kv_head, head_size, block_size, num_blocks, dtype, seed
            )


if __name__ == "__main__":
    test = unittest.main()
