import contextlib
import gc
import tempfile
from collections import OrderedDict
from unittest.mock import patch, MagicMock

import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import vllm
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parallel_utils.parallel_state import (
    destroy_model_parallel, initialize_model_parallel)
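

# cleanup() tears down any global model-parallel and torch.distributed state
# so each test starts from a clean slate; it is used both by the autouse
# fixture below and explicitly by fixtures that build full engines.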
def cleanup():
    destroy_model_parallel()
    # The process group may never have been initialized, in which case
    # destroy_process_group raises an AssertionError we can safely ignore.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.cuda.empty_cache()
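

# Every test gets automatic teardown; dist_init additionally brings up a
# single-process torch.distributed group plus vLLM's model-parallel state
# for tests that exercise the parallel layers directly.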
@pytest.fixture(autouse=True)
def cleanup_fixture():
    yield
    cleanup()


@pytest.fixture
def dist_init():
    if not torch.distributed.is_initialized():
        # Single-process "distributed" group backed by a temp-file store.
        temp_file = tempfile.mkstemp()[1]
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=1,
            rank=0,
            init_method=f"file://{temp_file}",
        )
        # Warm up the communicator with a no-op collective.
        torch.distributed.all_reduce(torch.zeros(1).cuda())
    initialize_model_parallel(1, 1)
    yield
    cleanup()
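

# Same distributed bring-up as dist_init, but without touching vLLM's
# model-parallel state; useful for tests that only need torch.distributed.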
@pytest.fixture
def dist_init_torch_only():
    if torch.distributed.is_initialized():
        return
    temp_file = tempfile.mkstemp()[1]
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=1,
        rank=0,
        init_method=f"file://{temp_file}",
    )
@pytest.fixture
def dummy_model() -> nn.Module:
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("output", ColumnParallelLinear(50, 10)),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("sampler", Sampler(512))
        ]))
    model.config = MagicMock()
    return model
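

# Variant of dummy_model whose projection is a MergedColumnParallelLinear
# ("gate_up_proj"), mirroring the fused gate/up projection of LLaMA-style
# MLP blocks, so packed-module LoRA handling can be exercised.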
@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("sampler", Sampler(512))
        ]))
    model.config = MagicMock()
    return model
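

# Session-scoped fixtures: each LoRA adapter is downloaded from the
# Hugging Face Hub once per test session and the local path is reused.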
@pytest.fixture(scope="session")
def sql_lora_files():
    return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


@pytest.fixture(scope="session")
def mixtral_lora_files():
    return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")


@pytest.fixture(scope="session")
def gemma_lora_files():
    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
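

# Builds a real Llama-2-7B engine whose base model is loaded with a
# LoRAConfig, giving the embeddings the extra rows LoRA needs, while the
# engine itself keeps enable_lora=False.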
@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
    cleanup()
    get_model_old = get_model

    def get_model_patched(model_config, device_config, **kwargs):
        # Force a LoRAConfig onto the base model so its embeddings are
        # allocated with the extra rows LoRA adapters can write into.
        return get_model_old(model_config,
                             device_config,
                             lora_config=LoRAConfig(max_loras=4,
                                                    max_lora_rank=8))

    with patch("vllm.worker.model_runner.get_model", get_model_patched):
        engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
    yield engine.llm_engine
    del engine
    cleanup()
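

# Convenience fixture that unwraps the engine down to the underlying
# nn.Module on the driver worker's model runner.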
@pytest.fixture
def llama_2_7b_model_extra_embeddings(
        llama_2_7b_engine_extra_embeddings) -> nn.Module:
    yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model