vllm

Форк
0
/
test_worker.py 
63 строки · 2.1 Кб
1
import os
2
import random
3
import tempfile
4
from unittest.mock import patch
5

6
from vllm.lora.models import LoRAMapping
7
from vllm.lora.request import LoRARequest
8
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
9
                         DeviceConfig, LoRAConfig)
10
from vllm.worker.worker import Worker
11

12

13
@patch.dict(os.environ, {"RANK": "0"})
14
def test_worker_apply_lora(sql_lora_files):
15
    worker = Worker(
16
        model_config=ModelConfig(
17
            "meta-llama/Llama-2-7b-hf",
18
            "meta-llama/Llama-2-7b-hf",
19
            tokenizer_mode="auto",
20
            trust_remote_code=False,
21
            download_dir=None,
22
            load_format="dummy",
23
            seed=0,
24
            dtype="float16",
25
            revision=None,
26
        ),
27
        parallel_config=ParallelConfig(1, 1, False),
28
        scheduler_config=SchedulerConfig(32, 32, 32, 256),
29
        device_config=DeviceConfig("cuda"),
30
        local_rank=0,
31
        rank=0,
32
        lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
33
                               max_loras=32),
34
        distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
35
    )
36
    worker.init_model()
37
    worker.load_model()
38

39
    worker.model_runner.set_active_loras([], LoRAMapping([], []))
40
    assert worker.list_loras() == set()
41

42
    n_loras = 32
43
    lora_requests = [
44
        LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras)
45
    ]
46

47
    worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], []))
48
    assert worker.list_loras() == {
49
        lora_request.lora_int_id
50
        for lora_request in lora_requests
51
    }
52

53
    for i in range(32):
54
        random.seed(i)
55
        iter_lora_requests = random.choices(lora_requests,
56
                                            k=random.randint(1, n_loras))
57
        random.shuffle(iter_lora_requests)
58
        iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)]
59
        worker.model_runner.set_active_loras(iter_lora_requests,
60
                                             LoRAMapping([], []))
61
        assert worker.list_loras().issuperset(
62
            {lora_request.lora_int_id
63
             for lora_request in iter_lora_requests})
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.