# conftest.py: shared pytest fixtures for the vLLM test suite
import os
from typing import Any, List, Optional, Tuple

import pytest
import torch
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]


def _read_prompts(filename: str) -> List[str]:
    with open(filename, "r") as f:
        prompts = f.readlines()
        return prompts


@pytest.fixture
def example_prompts() -> List[str]:
    prompts = []
    for filename in _TEST_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


@pytest.fixture
def example_long_prompts() -> List[str]:
    prompts = []
    for filename in _LONG_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts
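

# A minimal usage sketch (illustrative only, not part of this conftest): a
# test that lists one of these fixtures as a parameter receives the prompt
# strings read from the files above, e.g.
#
#     def test_prompts_are_nonempty(example_prompts):
#         assert all(prompt.strip() for prompt in example_prompts)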


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}
43

44

45
class HfRunner:
46

47
    def __init__(
48
        self,
49
        model_name: str,
50
        tokenizer_name: Optional[str] = None,
51
        dtype: str = "half",
52
    ) -> None:
53
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
54
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
55
        self.model = AutoModelForCausalLM.from_pretrained(
56
            model_name,
57
            torch_dtype=torch_dtype,
58
            trust_remote_code=True,
59
        ).cuda()
60
        if tokenizer_name is None:
61
            tokenizer_name = model_name
62
        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

    def generate(
        self,
        prompts: List[str],
        **kwargs,
    ) -> List[Tuple[List[int], str]]:
        outputs: List[Tuple[List[int], str]] = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output_ids = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                **kwargs,
            )
            # `output_ids` holds the full sequences (prompt + completion), so
            # the decoded strings below include the prompt text as well.
            output_str = self.tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            outputs[i] = (output_ids[0], output_str[0])
        return outputs

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            # HF pads the returned beams to a common length; drop the padding
            # tokens so the sequences can be compared directly.
            for j in range(len(output_ids)):
                output_ids[j] = [
                    x for x in output_ids[j]
                    if x != self.tokenizer.pad_token_id
                ]
            outputs[i] = (output_ids, output_str)
        return outputs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        all_logprobs = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            seq_logprobs = []
            for hidden_states in output.hidden_states:
                # Recompute logits from the last layer's hidden states by
                # projecting through the output embedding (LM head) matrix.
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if self.model.get_output_embeddings().bias is not None:
                    logits += self.model.get_output_embeddings(
                    ).bias.unsqueeze(0)
                logprobs = torch.nn.functional.log_softmax(logits,
                                                           dim=-1,
                                                           dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs
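
    # Note on the structure returned above: `output.hidden_states` has one
    # entry per generation step, so `all_logprobs[i][t]` is a float32 tensor
    # of log-probabilities over the vocabulary for prompt `i` at step `t`
    # (step 0 spans all prompt positions, later steps a single new token).
    # Tests can compare these against the logprobs reported by vLLM.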


@pytest.fixture
def hf_runner():
    return HfRunner
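
# Note: the hf_runner/vllm_runner fixtures return the class itself rather
# than an instance, so each test can choose its own model, tokenizer and
# dtype, and the heavyweight model is only constructed inside the test that
# needs it.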


class VllmRunner:
    """Runner that generates with a vLLM ``LLM`` engine."""

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=0,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            **kwargs,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str]]:
        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids = []
            req_sample_output_strs = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                # vLLM returns only the completion, so prepend the prompt to
                # match the full sequences produced by HfRunner.generate.
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Any]]:
        assert sampling_params.logprobs is not None

        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            # Only a single sample per request is expected here; with n > 1
            # this would keep just the last sample of each request.
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                output_logprobs = sample.logprobs
            outputs.append((output_ids, output_str, output_logprobs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Any]]:
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params)

        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        outputs = self.generate(prompts, beam_search_params)
        return outputs


@pytest.fixture
def vllm_runner():
    return VllmRunner
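

# A minimal end-to-end sketch of how these fixtures are meant to be combined
# (an illustrative example, not part of this conftest; the model name is a
# placeholder and exact-match assertions may need loosening for some models):
#
#     def test_greedy_hf_vs_vllm(hf_runner, vllm_runner, example_prompts):
#         hf_model = hf_runner("facebook/opt-125m", dtype="half")
#         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens=32)
#         del hf_model
#
#         vllm_model = vllm_runner("facebook/opt-125m", dtype="half")
#         vllm_outputs = vllm_model.generate_greedy(example_prompts,
#                                                   max_tokens=32)
#         del vllm_model
#
#         for (hf_ids, hf_str), (vllm_ids, vllm_str) in zip(
#                 hf_outputs, vllm_outputs):
#             assert hf_str == vllm_str
#             assert hf_ids == vllm_ids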
