"""Test hashing of cache blocks.

Run `pytest tests/test_cache_block_hashing.py`.
"""
import pytest

from vllm.transformers_utils.tokenizer import TokenizerGroup
from vllm.sequence import Sequence
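
# The hash of a block is expected to cover every token from the start of the
# sequence up to and including that block, so prompts that share a prefix
# should also share the hashes of all blocks that lie entirely inside that
# prefix.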
# Make two prefixes with different first blocks.
prefix_start = [("You are an expert"), ("You are a")]
prefix_common = (
    " school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for a 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. Based on this, fulfill "
    "the following paragraph: ")
prefixes = [start + prefix_common for start in prefix_start]
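# The two prefixes differ from their very first tokens, so none of their
# block hashes should collide; this is checked at the end of the test.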

# Sample prompts appended to each prefix.
sample_prompts = [
    "Hello, my name is", "The president of the United States is",
    "The capital of France is", "The future of AI is"
]


def flatten_2d(li):
    """Flatten a list of lists into a single flat list."""
    return [lss for ls in li for lss in ls]


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("max_num_seqs", [256])
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
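    # Tokenize every (prefix + prompt) combination and record the hash of
    # each full block; partial trailing blocks are ignored.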

    tokenizer = TokenizerGroup(
        tokenizer_id="facebook/opt-125m",
        enable_lora=False,  # LoRA tokenizers are not needed here.
        max_num_seqs=max_num_seqs,
        max_input_length=None,
    )

    hashes = []

    for prefix in prefixes:
        hashes.append([])
        prompts = [prefix + prompt for prompt in sample_prompts]
        seq_id = 0
        for prompt in prompts:
            hashes[-1].append([])
            prompt_token_ids = tokenizer.encode(prompt)
            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)

            num_blocks = len(prompt_token_ids) // block_size
            for idx in range(num_blocks):
                hashes[-1][-1].append(seq.hash_of_block(idx))

            seq_id += 1
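
    # hashes is indexed as hashes[prefix][prompt][block].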

    # Check that hashes made with two prefixes with different first blocks
    # are different everywhere.
    for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])):
        assert hash0 != hash1

    # Check that hashes of different prompts made with the same prefix agree
    # on every block except the last one, whose hash also covers the prompt.
    for hash_pref in hashes:
        same_hashes = [tuple(h[:-1]) for h in hash_pref]
        different_hashes = [h[-1] for h in hash_pref]
        assert len(set(same_hashes)) == 1
        assert len(set(different_hashes)) == len(different_hashes)