import pytest
import ray

import vllm
from vllm.lora.request import LoRARequest

from .conftest import cleanup

# Base model (a Hugging Face Hub id) that every engine in this module loads;
# the SQL LoRA adapter under test is applied on top of it.
MODEL_PATH = "meta-llama/Llama-2-7b-hf"


def do_sample(llm, lora_path: str, lora_id: int):
    """Run a fixed set of text-to-SQL prompts through ``llm`` and return the texts.

    Args:
        llm: The ``vllm.LLM`` engine to query.
        lora_path: Path of the LoRA adapter files, passed to ``LoRARequest``.
        lora_id: Integer adapter id. ``0`` means "no adapter": the prompts
            are run against the base model and no ``LoRARequest`` is sent.

    Returns:
        A list with the text of the first completion of each output returned
        by ``llm.generate`` (one entry per prompt).
    """
    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]",
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]",
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]"
    ]
    # temperature=0 keeps decoding deterministic so outputs can be compared
    # against fixed strings; the expected base-model outputs are long, so the
    # token budget must be raised explicitly. Generation stops at the closing
    # assistant tag.
    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=256,
                                          stop=["[/assistant]"])
    # lora_id == 0 selects the base model: do not build a LoRARequest at all.
    lora_request = (LoRARequest(str(lora_id), lora_id, lora_path)
                    if lora_id else None)
    outputs = llm.generate(prompts,
                           sampling_params,
                           lora_request=lora_request)
    # Collect (and print, for debugging failed comparisons) each completion.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@pytest.mark.parametrize("tp_size", [1])
def test_llama_lora(sql_lora_files, tp_size):
    """Interleave base-model and LoRA requests on one engine and check outputs.

    The same engine is queried without an adapter (``lora_id=0``), with
    adapter id 1, without an adapter again, and with adapter id 2 (same
    adapter files, different id). Every base-model run must equal
    ``expected_no_lora_output`` and every adapter run must equal
    ``expected_lora_output``, so the test fails if switching adapters on and
    off changes results.
    """
    # Cannot use as it will initialize torch.cuda too early...
    # if torch.cuda.device_count() < tp_size:
    #     pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

    llm = vllm.LLM(MODEL_PATH,
                   enable_lora=True,
                   max_num_seqs=16,
                   max_loras=4,
                   tensor_parallel_size=tp_size)

    # Deterministic (temperature=0) base-model completions for the prompts in
    # do_sample, in prompt order.
    expected_no_lora_output = [
        "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]",
        " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ",
        "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m",
        " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ",
        " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ",
        "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE",
    ]
    # Completions with the SQL LoRA adapter applied, in prompt order.
    expected_lora_output = [
        " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",
        " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ",
        " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ",
        " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ",
        " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ",
        " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' "
    ]

    print("lora adapter created")
    assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output

    print("lora 1")
    assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output

    print("no lora")
    assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output

    # Same adapter files registered under a different id.
    print("lora 2")
    assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output

    print("removing lora")


@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(sql_lora_files):
    """Check that LoRA outputs match across tensor-parallel sizes 1, 2 and 4.

    The TP=1 result is the reference; TP=2 is compared first and TP=4 is only
    built if that comparison passes.
    """
    # Cannot use as it will initialize torch.cuda too early...
    # if torch.cuda.device_count() < 4:
    #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

    def _sample_with_tp(tp_size):
        """Build an engine with ``tp_size`` ranks, sample with adapter 1, tear down."""
        llm = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       tensor_parallel_size=tp_size)
        output = do_sample(llm, sql_lora_files, lora_id=1)
        # Drop the engine and run the conftest cleanup before the next engine
        # is created (presumably to release GPU resources — see conftest).
        del llm
        cleanup()
        return output

    output_tp1 = _sample_with_tp(1)

    output_tp2 = _sample_with_tp(2)
    assert output_tp1 == output_tp2

    output_tp4 = _sample_with_tp(4)
    assert output_tp1 == output_tp4


def test_llama_lora_warmup(sql_lora_files):
    """Check that enabling LoRA makes engine start-up more conservative.

    Two engines for ``MODEL_PATH`` are built, one with ``enable_lora=True``
    and one without, each inside its own Ray task that requests one GPU
    (``num_gpus=1``). The number of KV-cache GPU blocks each engine allocates
    is returned, and the LoRA engine must end up with strictly fewer blocks.
    """
    # NOTE(review): sql_lora_files is not used in the body; confirm whether
    # the fixture is requested only for its setup side effects.

    @ray.remote(num_gpus=1)
    def get_num_gpu_blocks_lora():
        llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16)
        num_gpu_blocks_lora_warmup = (
            llm.llm_engine.cache_config.num_gpu_blocks)
        return num_gpu_blocks_lora_warmup

    @ray.remote(num_gpus=1)
    def get_num_gpu_blocks_no_lora():
        llm = vllm.LLM(MODEL_PATH, max_num_seqs=16)
        num_gpu_blocks_no_lora_warmup = (
            llm.llm_engine.cache_config.num_gpu_blocks)
        return num_gpu_blocks_no_lora_warmup

    # Run the two measurements one after the other (each ray.get blocks
    # until its task finishes) rather than concurrently.
    num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote())
    num_gpu_blocks_no_lora_warmup = ray.get(
        get_num_gpu_blocks_no_lora.remote())
    assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, (
        "The warmup with lora should be more"
        " conservative than without lora, therefore the number of memory blocks for the KV cache should be "
        "less when using lora than when not using lora")