import os
import random
import tempfile
from unittest.mock import patch

from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
                         DeviceConfig, LoRAConfig)
from vllm.worker.worker import Worker


@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
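    """Exercise LoRA bookkeeping on a Worker.

    Activates no adapters, then all of them, then random subsets, and
    checks worker.list_loras() after each set_active_loras() call.
    """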
    worker = Worker(
        model_config=ModelConfig(
            "meta-llama/Llama-2-7b-hf",
            "meta-llama/Llama-2-7b-hf",
            tokenizer_mode="auto",
            trust_remote_code=False,
            download_dir=None,
            load_format="dummy",  # random-weight init; no download needed
            seed=0,
            dtype="float16",
            revision=None,
        ),
        parallel_config=ParallelConfig(1, 1, False),
        scheduler_config=SchedulerConfig(32, 32, 32, 256),
        device_config=DeviceConfig("cuda"),
        local_rank=0,
        rank=0,
        lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
                               max_loras=32),
        distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
    )
    worker.init_model()
    worker.load_model()
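
    # With no adapters activated, the worker should report an empty set.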
    worker.model_runner.set_active_loras([], LoRAMapping([], []))
    assert worker.list_loras() == set()
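
    # Register one request per adapter slot; lora_int_id is 1-based.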
    n_loras = 32
    lora_requests = [
        LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras)
    ]

    worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], []))
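    # Every registered ID should now be visible.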
    assert worker.list_loras() == {
        lora_request.lora_int_id
        for lora_request in lora_requests
    }
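
    # Activate random subsets (choices may repeat) and verify that every
    # requested adapter ends up loaded afterwards.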
    for i in range(32):
        random.seed(i)
        iter_lora_requests = random.choices(lora_requests,
                                            k=random.randint(1, n_loras))
        random.shuffle(iter_lora_requests)
        # Drop a random tail, guarding n_drop == 0: lst[:-0] is lst[:0],
        # which would empty the list rather than leave it unchanged.
        n_drop = random.randint(0, n_loras)
        if n_drop:
            iter_lora_requests = iter_lora_requests[:-n_drop]
        worker.model_runner.set_active_loras(iter_lora_requests,
                                             LoRAMapping([], []))
        assert worker.list_loras().issuperset(
            {lora_request.lora_int_id
             for lora_request in iter_lora_requests})