vllm

Форк
0
/
test_cache_block_hashing.py 
76 строк · 2.7 Кб
1
"""Test hashing of cache blocks.
2

3
Run `pytest tests/test_cache_block_hashing.py`.
4
"""
5
import pytest
6

7
from vllm.transformers_utils.tokenizer import TokenizerGroup
8
from vllm.sequence import Sequence
9

10
# Make two prefixes with different first blocks.
11
prefix_start = [("You are an expert"), ("You are a")]
12
prefix_common = (
13
    " school principal, skilled in effectively managing "
14
    "faculty and staff. Draft 10-15 questions for a potential first grade "
15
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
16
    "community, joyful discovery, and life-long learning. The candidate is "
17
    "coming in for a first-round panel interview for a 8th grade Math "
18
    "teaching role. They have 5 years of previous teaching experience "
19
    "as an assistant teacher at a co-ed, public school with experience "
20
    "in middle school math teaching. Based on this, fulfill "
21
    "the following: ")
22
prefixes = [start + prefix_common for start in prefix_start]
23

24
# Sample prompts.
25
sample_prompts = [
26
    "Hello, my name is", "The president of the United States is",
27
    "The capital of France is", "The future of AI is"
28
]
29

30

31
# Helper function.
32
def flatten_2d(li):
33
    return [lss for ls in li for lss in ls]
34

35

36
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
37
@pytest.mark.parametrize("block_size", [16])
38
@pytest.mark.parametrize("max_num_seqs", [256])
39
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
40

41
    tokenizer = TokenizerGroup(
42
        tokenizer_id="facebook/opt-125m",
43
        enable_lora=False,
44
        max_num_seqs=max_num_seqs,
45
        max_input_length=None,
46
    )
47

48
    hashes = []
49

50
    for prefix in prefixes:
51
        hashes.append([])
52
        prompts = [prefix + prompt for prompt in sample_prompts]
53
        seq_id = 0
54
        for prompt in prompts:
55
            hashes[-1].append([])
56
            prompt_token_ids = tokenizer.encode(prompt)
57
            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
58

59
            num_blocks = len(prompt_token_ids) // block_size
60
            for idx in range(num_blocks):
61
                hashes[-1][-1].append(seq.hash_of_block(idx))
62

63
            seq_id += 1
64

65
    # Check that hashes made with two prefixes with different first blocks are
66
    # different everywhere.
67
    for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])):
68
        assert (hash0 != hash1)
69

70
    # Check that hashes of different prompts made with the same prefix are the
71
    # same until the hashes that contain the prompt.
72
    for hash_pref in hashes:
73
        same_hashes = [tuple(h[:-1]) for h in hash_pref]
74
        different_hashes = [h[-1] for h in hash_pref]
75
        assert (len(set(same_hashes)) == 1)
76
        assert (len(set(different_hashes)) == len(different_hashes))
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.