import torch
from common_utils import TestCase
import unittest
import random
from typing import List, Optional, Tuple
from itertools import product


class PagedAttentionTest(TestCase):
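    """Tests for the torch_ipex paged attention and KV-cache update kernels."""
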
    def create_kv_caches(
        self,
        num_blocks: int,
        block_size: int,
        num_layer: int,
        num_head: int,
        head_size: int,
        dtype: torch.dtype,
        seed: int,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
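        """Create random per-layer key/value caches of shape
        (num_blocks, block_size, num_head, head_size), filled uniformly
        in [-head_size**-0.5, head_size**-0.5]."""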
        torch.random.manual_seed(seed)
        torch.manual_seed(seed)

        scale = head_size**-0.5
        key_cache_shape = (num_blocks, block_size, num_head, head_size)
        key_caches = []
        for _ in range(num_layer):
            key_cache = torch.empty(size=key_cache_shape, dtype=dtype)
            key_cache.uniform_(-scale, scale)
            key_caches.append(key_cache)

        value_cache_shape = (num_blocks, block_size, num_head, head_size)
        value_caches = []
        for _ in range(num_layer):
            value_cache = torch.empty(size=value_cache_shape, dtype=dtype)
            value_cache.uniform_(-scale, scale)
            value_caches.append(value_cache)
        return key_caches, value_caches

    def ref_masked_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        scale: float,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
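        """Reference attention: softmax(scale * Q @ K^T + mask) @ V, computed
        per head via einsum ("q"/"k" index tokens, "h" heads, "d" head dim)."""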
        attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
        if attn_mask is not None:
            attn_weights = attn_weights + attn_mask.float()
        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
        return out

    def ref_single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        num_queries_per_kv: int,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
    ) -> None:
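        """Reference for single-query decoding over a paged KV cache: for each
        sequence, gather its keys/values block by block via the block table,
        repeat them for MQA/GQA, optionally apply an ALiBi bias, and write the
        attention result into `output`."""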
        num_query_heads = query.shape[1]
        num_kv_head = value_cache.shape[2]
        head_size = value_cache.shape[3]
        block_size = value_cache.shape[1]
        num_seqs = query.shape[0]

        block_tables = block_tables.cpu().tolist()
        context_lens = context_lens.cpu().tolist()
        for i in range(num_seqs):
            q = query[i].unsqueeze(0)
            block_table = block_tables[i]
            context_len = int(context_lens[i])

            keys = []
            values = []
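            # Token j of the context lives in block block_table[j // block_size]
            # at offset j % block_size.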
            for j in range(context_len):
                block_number = int(block_table[j // block_size])
                block_offset = j % block_size

                k = key_cache[block_number, block_offset, :, :]
                k = k.reshape(num_kv_head, head_size)
                keys.append(k)

                v = value_cache[block_number, block_offset, :, :]
                values.append(v)
            keys = torch.stack(keys, dim=0)
            values = torch.stack(values, dim=0)
            if num_queries_per_kv > 1:
                # Handle MQA and GQA.
                keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
                values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
            alibi_bias = None
            if alibi_slopes is not None:
                # Create the ALiBi bias used in the paged attention kernel.
                position_ids = torch.arange(context_len, device="cpu").int()
                alibi_bias = (position_ids - context_len + 1).float()
                alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
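                # The bias is 0 at the newest position and decreases by one
                # slope unit per step into the past.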

            out = self.ref_masked_attention(q, keys, values, scale, alibi_bias)
            out = out.view(num_query_heads, head_size)
            output[i].copy_(out, non_blocking=True)

    def _test_paged_attention_func(
        self,
        num_seqs: int,
        num_head: Tuple[int, int],
        head_size: int,
        use_alibi: bool,
        num_blocks: int,
        block_size: int,
        dtype: torch.dtype,
        seed: int,
    ) -> None:
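        """Check torch.ops.torch_ipex.single_query_cached_kv_attention against
        the pure-PyTorch reference above on random queries, caches, and block
        tables."""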
        random.seed(seed)
        torch.random.manual_seed(seed)
        torch.manual_seed(seed)
        max_seq_len = 1024
        scale = float(1.0 / (head_size**0.5))
        num_query_heads, num_kv_head = num_head
        query = torch.empty(
            num_seqs, num_query_heads, head_size, dtype=dtype, device="cpu"
        )
        query.uniform_(-scale, scale)
        assert num_query_heads % num_kv_head == 0
        num_queries_per_kv = num_query_heads // num_kv_head
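        # Map each query head to its KV head, e.g. 4 query heads over 2 KV
        # heads gives head_mapping = [0, 0, 1, 1].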
        head_mapping = torch.repeat_interleave(
            torch.arange(num_kv_head, dtype=torch.int32, device="cpu"),
            num_queries_per_kv,
        )
        alibi_slopes = None
        if use_alibi:
            alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device="cpu")

        context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)]
        context_lens[-1] = max_seq_len
        max_context_len = max(context_lens)
        context_lens = torch.tensor(context_lens, dtype=torch.int, device="cpu")

        # Create the block tables.
        max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
        block_tables = []
        for _ in range(num_seqs):
            block_table = [
                random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq)
            ]
            block_tables.append(block_table)
        block_tables = torch.tensor(block_tables, dtype=torch.int, device="cpu")

        # Create the KV caches.
        key_caches, value_caches = self.create_kv_caches(
            num_blocks, block_size, 1, num_kv_head, head_size, dtype, seed
        )
        key_cache, value_cache = key_caches[0], value_caches[0]
        # Call the paged attention kernel.
        output = torch.empty_like(query)
        torch.ops.torch_ipex.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

        # Run the reference implementation.
        ref_output = torch.empty_like(query)
        self.ref_single_query_cached_kv_attention(
            ref_output,
            query,
            num_queries_per_kv,
            key_cache,
            value_cache,
            block_tables,
            context_lens,
            scale,
            alibi_slopes,
        )
        assert torch.allclose(output, ref_output, atol=5e-3, rtol=1e-3)

    def test_paged_attention(self):
        num_blocks = 128
        dtypes = [torch.bfloat16, torch.float]
        num_gen_seqs = [7]  # Arbitrary values for testing
        num_heads = [(40, 40), (64, 16)]  # Arbitrary values for testing
        head_sizes = [64, 80, 96, 112, 128, 256]
        block_sizes = [16, 32]
        use_alibis = [True, False]
        seeds = [0]
        for (
            num_seqs,
            num_head,
            head_size,
            use_alibi,
            block_size,
            dtype,
            seed,
        ) in product(
            num_gen_seqs,
            num_heads,
            head_sizes,
            use_alibis,
            block_sizes,
            dtypes,
            seeds,
        ):
            self._test_paged_attention_func(
                num_seqs,
                num_head,
                head_size,
                use_alibi,
                num_blocks,
                block_size,
                dtype,
                seed,
            )

    def _test_reshape_and_cache_func(
        self,
        num_token: int,
        num_head: int,
        head_size: int,
        block_size: int,
        num_blocks: int,
        dtype: torch.dtype,
        seed: int,
    ) -> None:
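        """Check torch.ops.torch_ipex.reshape_and_cache: scatter new key/value
        tokens into the paged caches at the slots given by `slot_mapping` and
        compare against a manual scatter."""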
        random.seed(seed)
        torch.random.manual_seed(seed)
        torch.manual_seed(seed)

        # Create a random slot mapping.
        num_slots = block_size * num_blocks
        slot_mapping = random.sample(range(num_slots), num_token)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cpu")
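        # A slot is a flat cache index: slot = block_idx * block_size + offset,
        # e.g. with block_size=16, slot 35 is block 2, offset 3.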

        qkv = torch.randn(num_token, 3, num_head, head_size, dtype=dtype, device="cpu")
        _, key, value = qkv.unbind(dim=1)
        # Create the KV caches.
        key_caches, value_caches = self.create_kv_caches(
            num_blocks, block_size, 1, num_head, head_size, dtype, seed
        )
        key_cache, value_cache = key_caches[0], value_caches[0]
        # Clone the KV caches.
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

        # Call the reshape_and_cache kernel.
        torch.ops.torch_ipex.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping
        )

        # Run the reference implementation.
        block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
        block_indices = block_indices.cpu().tolist()
        block_offsets = slot_mapping % block_size
        block_offsets = block_offsets.cpu().tolist()
        for i in range(num_token):
            block_idx = block_indices[i]
            block_offset = block_offsets[i]
            cloned_key_cache[block_idx, block_offset, :, :] = key[i]
            cloned_value_cache[block_idx, block_offset, :, :] = value[i]

        assert torch.allclose(key_cache, cloned_key_cache)
        assert torch.allclose(value_cache, cloned_value_cache)

    def test_reshape_and_cache(self):
        num_blocks = 128  # Arbitrary values for testing
        num_tokens = [1, 83, 1024]  # Arbitrary values for testing
        num_kv_heads = [8]  # Arbitrary values for testing
        head_sizes = [64, 80, 96, 112, 128, 256]
        block_sizes = [16, 32]
        dtypes = [torch.bfloat16, torch.float]
        seeds = [0]
        for (
            num_token,
            num_kv_head,
            head_size,
            block_size,
            dtype,
            seed,
        ) in product(
            num_tokens,
            num_kv_heads,
            head_sizes,
            block_sizes,
            dtypes,
            seeds,
        ):
            self._test_reshape_and_cache_func(
                num_token, num_kv_head, head_size, block_size, num_blocks, dtype, seed
            )


if __name__ == "__main__":
    test = unittest.main()