# conftest.py
import contextlib
import gc
import tempfile
from collections import OrderedDict
from unittest.mock import patch, MagicMock

import pytest
import ray
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import vllm
from vllm.config import LoRAConfig
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parallel_utils.parallel_state import (
    destroy_model_parallel, initialize_model_parallel)


def cleanup():
    """Tear down model parallel state, process groups, Ray, and CUDA cache."""
    destroy_model_parallel()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.cuda.empty_cache()
    ray.shutdown()


@pytest.fixture(autouse=True)
def cleanup_fixture():
    """Run cleanup() automatically after every test."""
    yield
    cleanup()


@pytest.fixture
def dist_init():
    """Set up a single-rank NCCL process group and model parallelism."""
    if not torch.distributed.is_initialized():
        temp_file = tempfile.mkstemp()[1]
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=1,
            rank=0,
            init_method=f"file://{temp_file}",
        )
        torch.distributed.all_reduce(torch.zeros(1).cuda())
    initialize_model_parallel(1, 1)
    yield
    cleanup()


@pytest.fixture
def dist_init_torch_only():
    """Set up only the torch process group, without model parallelism."""
    if torch.distributed.is_initialized():
        return
    temp_file = tempfile.mkstemp()[1]
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=1,
        rank=0,
        init_method=f"file://{temp_file}",
    )
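

# Illustrative sketch, not part of the original conftest: a test module in
# this suite could request `dist_init` to get a single-rank NCCL group plus
# model-parallel state before touching the parallel layers below.  The
# function name and assertions are assumptions made for illustration only.
def example_dist_init_usage(dist_init):
    assert torch.distributed.is_initialized()
    assert torch.distributed.get_world_size() == 1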


@pytest.fixture
def dummy_model() -> nn.Module:
    """Toy model with parallel linear layers, an lm_head, and a sampler."""
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("output", ColumnParallelLinear(50, 10)),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("sampler", Sampler(512))
        ]))
    model.config = MagicMock()
    return model
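

# Illustrative sketch, not part of the original conftest: a LoRA test could
# combine `dist_init` and `dummy_model` to check which submodules exist before
# wrapping them.  The expectations below are assumptions for illustration only.
def example_dummy_model_usage(dist_init, dummy_model):
    assert isinstance(dummy_model.get_submodule("lm_head"), ParallelLMHead)
    assert isinstance(dummy_model.get_submodule("layer1.dense1"),
                      ColumnParallelLinear)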


@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
    """Toy model like dummy_model, but with a merged gate_up_proj layer."""
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("sampler", Sampler(512))
        ]))
    model.config = MagicMock()
    return model


@pytest.fixture(scope="session")
def sql_lora_files():
    """Download the Llama 2 7B SQL LoRA test adapter once per session."""
    return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


@pytest.fixture(scope="session")
def mixtral_lora_files():
    """Download the Mixtral LoRA test adapter once per session."""
    return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")


@pytest.fixture(scope="session")
def gemma_lora_files():
    """Download the Gemma 7B LoRA test adapter once per session."""
    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
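

# Illustrative sketch, not part of the original conftest: a test could wrap the
# downloaded snapshot path in a LoRARequest for an engine started with
# enable_lora=True.  The vllm.lora.request import path and the
# (name, id, local path) constructor are assumptions about the vLLM version
# this fork tracks.
def example_sql_lora_request(sql_lora_files):
    from vllm.lora.request import LoRARequest

    lora_request = LoRARequest("sql-lora", 1, sql_lora_files)
    assert lora_request.lora_int_id == 1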


@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
    """Llama 2 7B engine whose model loader is patched to add a LoRAConfig."""
    cleanup()
    get_model_old = get_model

    def get_model_patched(model_config, device_config, **kwargs):
        # Ignore the caller's kwargs and force a LoRA-aware model load.
        return get_model_old(model_config,
                             device_config,
                             lora_config=LoRAConfig(max_loras=4,
                                                    max_lora_rank=8))

    with patch("vllm.worker.model_runner.get_model", get_model_patched):
        engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
    yield engine.llm_engine
    del engine
    cleanup()


@pytest.fixture
def llama_2_7b_model_extra_embeddings(
        llama_2_7b_engine_extra_embeddings) -> nn.Module:
    """Underlying nn.Module of the patched Llama 2 7B engine."""
    yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model
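

# Illustrative sketch, not part of the original conftest: the engine fixture
# above works because unittest.mock.patch swaps the name that
# vllm.worker.model_runner looks up, so the LoRAConfig is injected even though
# vllm.LLM is created with enable_lora=False.  The toy object below only
# demonstrates that patching pattern; it is not a vLLM API.
def example_patch_pattern():
    from types import SimpleNamespace

    loader = SimpleNamespace(get_model=lambda: "base")
    with patch.object(loader, "get_model", lambda: "patched"):
        assert loader.get_model() == "patched"
    assert loader.get_model() == "base"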