import os
from typing import Any, List, Optional, Tuple

import pytest
import torch
from transformers import AutoModelForCausalLM

from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
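
# Shared test helpers: HfRunner drives a HuggingFace `transformers` model as
# the reference backend, VllmRunner drives vLLM's `LLM` engine, and the two
# expose matching generate_* methods so tests can compare their outputs
# prompt-for-prompt.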


def _read_prompts(filename: str) -> List[str]:
    with open(filename, "r") as f:
        prompts = f.readlines()
    return prompts
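
# Prompt-loading fixtures. The @pytest.fixture decorators are an assumption:
# the module reads as a standard pytest conftest.py.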
@pytest.fixture
def example_prompts() -> List[str]:
    prompts = []
    for filename in _TEST_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


@pytest.fixture
def example_long_prompts() -> List[str]:
    prompts = []
    for filename in _LONG_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


# "half" and "float" are assumed entries alongside the surviving "bfloat16"
# mapping.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}
class HfRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        ).cuda()
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

    def generate(
        self,
        prompts: List[str],
        **kwargs,
    ) -> List[Tuple[List[int], str]]:
        outputs: List[Tuple[List[int], str]] = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output_ids = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                **kwargs,
            )
            # batch_decode keeps one string per returned sequence.
            output_str = self.tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs
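
    # Greedy decoding; `do_sample=False` below is assumed for the elided
    # kwarg so decoding is deterministic, and the [0] indexing unwraps the
    # single returned sequence per prompt.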
    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            outputs[i] = (output_ids[0], output_str[0])
        return outputs

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            # HF pads all beams to a common length; drop the pad tokens so
            # the ids can be compared with an unpadded backend.
            for j in range(len(output_ids)):
                output_ids[j] = [
                    x for x in output_ids[j]
                    if x != self.tokenizer.pad_token_id
                ]
            outputs[i] = (output_ids, output_str)
        return outputs
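
    # Greedy decoding that also returns per-step logprobs. HF generate can
    # return hidden states for every generated token; projecting the last
    # layer's hidden state through the output embedding matrix (plus its
    # bias, when present) recovers the logits, and log_softmax converts them
    # to log-probabilities.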
    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        all_logprobs = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            seq_logprobs = []
            for hidden_states in output.hidden_states:
                last_hidden_states = hidden_states[-1][0]
                embed_out = self.model.get_output_embeddings()
                logits = torch.matmul(
                    last_hidden_states,
                    embed_out.weight.t(),
                )
                if embed_out.bias is not None:
                    logits += embed_out.bias.unsqueeze(0)
                # Compute in float32 for numerical stability.
                logprobs = torch.nn.functional.log_softmax(logits,
                                                           dim=-1,
                                                           dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs
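

# System under test: wraps vLLM's LLM engine with the same interface as
# HfRunner so the two backends can be compared directly. The class name is
# assumed.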
class VllmRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # The dtype parameter is assumed to parallel HfRunner's.
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str]]:
        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids = []
            req_sample_output_strs = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                # vLLM returns only the completion; prepend the prompt so the
                # results line up with HF generate, which includes the prompt.
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs
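
    # Like generate, but also collects each sample's per-token logprobs;
    # callers must request them via SamplingParams(logprobs=...).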
    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str, Any]]:
        assert sampling_params.logprobs is not None

        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                output_logprobs = sample.logprobs
                # Append inside the loop so every sample is kept, not just
                # the last one per request.
                outputs.append((output_ids, output_str, output_logprobs))
        return outputs
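
    # Convenience wrappers. Greedy decoding is temperature=0.0 in
    # SamplingParams; with a single sample per request, the [0] indexing
    # unwraps it.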
    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str, Any]]:
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        # temperature=0.0 is assumed for the elided kwarg; vLLM's beam search
        # does not sample.
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        outputs = self.generate(prompts, beam_search_params)
        return outputs
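

# A sketch of typical usage in a test (model name and max_tokens are
# illustrative, not taken from this file):
#
#     def test_greedy_equivalence(example_prompts):
#         hf = HfRunner("facebook/opt-125m")
#         hf_outputs = hf.generate_greedy(example_prompts, max_tokens=128)
#         del hf
#         vllm = VllmRunner("facebook/opt-125m")
#         vllm_outputs = vllm.generate_greedy(example_prompts, max_tokens=128)
#         for (_, hf_str), (_, vllm_str) in zip(hf_outputs, vllm_outputs):
#             assert hf_str == vllm_str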