import ast
import json

import requests
import typer

try:
    from typing import Annotated
except ImportError:
    from typing_extensions import Annotated

from typing import List, Optional

from rich import print as rp
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

# NOTE: assumed import path for the `sdk` module used throughout this file;
# it must provide the models/metadata/query/stream/run/shutdown calls and the
# get_aviary_backend helper referenced below.
from rayllm import sdk
from rayllm.common.evaluation import GPT

app = typer.Typer()

__all__ = ["app", "models", "metadata", "query", "run"]

model_type = typer.Option(
    default=..., help="The model to use. You can specify multiple models."
)
prompt_type = typer.Option(help="Prompt to query")
stats_type = typer.Option(help="Whether to print generated statistics")
prompt_file_type = typer.Option(
    default=..., help="File containing prompts. A simple text file"
)
separator_type = typer.Option(help="Separator used in prompt files")
results_type = typer.Option(help="Where to save the results")
true_or_false_type = typer.Option(default=False, is_flag=True)
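
# These shared Option objects are attached to command parameters via
# typing.Annotated; e.g. `model: Annotated[List[str], model_type]` becomes a
# repeatable `--model` option. A minimal sketch of the same pattern, using a
# hypothetical command that is not part of this CLI:
#
#   @app.command()
#   def echo_models(model: Annotated[List[str], model_type]):
#       rp(model)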


@app.command()
def models(metadata: Annotated[bool, "Whether to print metadata"] = False):
    """Get a list of the available models"""
    result = sdk.models()
    if metadata:
        for model in result:
            rp(f"[bold]{model}:[/]")
            rp(sdk.metadata(model))
    else:
        print("\n".join(result))


@app.command()
def metadata(model: Annotated[List[str], model_type]):
    """Get metadata for models."""
    results = [sdk.metadata(m) for m in model]
    rp(results)
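
# Example (sketch): `aviary metadata --model <model-id>` prints the raw
# metadata dict returned by the backend for each `--model` given.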


def _get_text(result: dict) -> str:
    if "text" in result["choices"][0]:
        return result["choices"][0]["text"]
    elif "message" in result["choices"][0]:
        return result["choices"][0]["message"]["content"]
    elif "delta" in result["choices"][0]:
        return result["choices"][0]["delta"].get("content", "")
    return ""
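
# The three OpenAI-style payload shapes handled above, for reference:
#
#   {"choices": [{"text": "..."}]}                   # plain completion
#   {"choices": [{"message": {"content": "..."}}]}   # chat completion
#   {"choices": [{"delta": {"content": "..."}}]}     # streaming chunk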


def _print_result(result, model, print_stats):
    rp(f"[bold]{model}:[/]")
    # Print the full result payload (with stats) or just the generated text.
    rp(result if print_stats else _get_text(result))


def progress_spinner():
    return Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    )


@app.command()
def query(
    model: Annotated[List[str], model_type],
    prompt: Annotated[Optional[List[str]], prompt_type] = None,
    prompt_file: Annotated[Optional[str], prompt_file_type] = None,
    separator: Annotated[str, separator_type] = "----",
    output_file: Annotated[str, results_type] = "aviary-output.json",
    print_stats: Annotated[bool, stats_type] = False,
):
    """Query one or several models with one or multiple prompts,
    optionally read from file, and save the results to a file."""
    # TODO (max): deprecate and rename to "completions" to match the API
    with progress_spinner() as progress:
        if prompt_file:
            with open(prompt_file, "r") as f:
                prompt = f.read().split(separator)

        results = {p: [] for p in prompt}

        for m in model:
            progress.add_task(
                description=f"Processing all prompts against model: {m}.",
                total=None,
            )
            query_results = [sdk.query(m, p) for p in prompt]
            for result in query_results:
                _print_result(result, m, print_stats)

            for i, p in enumerate(prompt):
                result = query_results[i]
                text = _get_text(result)
                results[p].append({"model": m, "result": text, "stats": result})

        progress.add_task(description="Writing output file.", total=None)
        with open(output_file, "w") as f:
            f.write(json.dumps(results, indent=2))
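
# Example (sketch): query two models with one prompt; the model ids are
# placeholders:
#
#   $ aviary query --model <model-a> --model <model-b> \
#         --prompt "What is the capital of France?"
#
# The output file maps each prompt to a list of per-model entries of the form
# {"model": ..., "result": <generated text>, "stats": <full response dict>}.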


def _get_yes_or_no_input(prompt) -> bool:
    while True:
        user_input = input(prompt).strip().lower()
        if user_input == "yes" or user_input == "y":
            return True
        elif user_input == "no" or user_input == "n" or user_input == "":
            return False
        else:
            print("Invalid input. Please enter 'yes / y' or 'no / n'.")


@app.command()
def run(
    model: Annotated[List[str], model_type],
    blocking: bool = True,
    restart: bool = true_or_false_type,
):
    """Start a model in Aviary.

    Args:
        *model: Models to run.
        blocking: Whether to block the CLI until the application is ready.
        restart: Whether to restart Aviary if it is already running.
    """
    msg = (
        "Running `aviary run` while Aviary is running will stop any existing "
        "Aviary (or other Ray Serve) deployments "
        f"and run the specified ones ({model}).\n"
        "Do you want to continue? [y/N]\n"
    )
    try:
        backend = sdk.get_aviary_backend(verbose=False)
        aviary_url = backend.backend_url
        aviary_started = False
        if aviary_url:
            health_check_url = f"{aviary_url}/health_check"
            aviary_started = requests.get(health_check_url).status_code == 200
        if aviary_started:
            if restart:
                restart_aviary = True
            else:
                restart_aviary = _get_yes_or_no_input(msg) or False
            if not restart_aviary:
                return
    except (requests.exceptions.ConnectionError, sdk.URLNotSetException):
        pass  # Aviary is not running

    sdk.run(model, blocking=blocking)
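
# Example (sketch): start one model and restart any running deployment
# without the interactive confirmation prompt:
#
#   $ aviary run --model <model-id> --restart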


@app.command()
def shutdown():
    """Shutdown Aviary."""
    sdk.shutdown()


evaluator_type = typer.Option(help="Which LLM to use for evaluation")


@app.command()
def evaluate(
    input_file: Annotated[str, results_type] = "aviary-output.json",
    evaluation_file: Annotated[str, results_type] = "evaluation-output.json",
    evaluator: Annotated[str, evaluator_type] = "gpt-4",
):
    """Evaluate and summarize the results of a `query` run with a strong
    'evaluator' LLM like GPT-4.
    The results of the ranking are stored to file and displayed in a table.
    """
    with progress_spinner() as progress:
        progress.add_task(description="Loading the evaluator LLM.", total=None)
        if evaluator == "gpt-4":
            eval_model = GPT()
        else:
            raise NotImplementedError(f"No evaluator for {evaluator}")

        with open(input_file, "r") as f:
            results = json.load(f)

        for prompt, result_list in results.items():
            progress.add_task(
                description=f"Evaluating results for prompt: {prompt}.", total=None
            )
            evaluation = eval_model.evaluate_results(prompt, result_list)
            try:
                # GPT-4 returns a string with a Python dictionary, hopefully!
                evaluation = ast.literal_eval(evaluation)
            except Exception:
                print(f"Could not parse evaluation: {evaluation}")

            for i, _res in enumerate(results[prompt]):
                results[prompt][i]["rank"] = evaluation[i]["rank"]

        progress.add_task(description="Storing evaluations.", total=None)
        with open(evaluation_file, "w") as f:
            f.write(json.dumps(results, indent=2))
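
    # Shape note (assumed, inferred from the `evaluation[i]["rank"]` lookups
    # above): the parsed `evaluation` is a list with one dict per model
    # response, e.g. [{"rank": 2, ...}, {"rank": 1, ...}].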
    for prompt in results.keys():
        table = Table(title="Evaluation results (higher ranks are better)")
        table.add_column("Model", justify="left", style="cyan", no_wrap=True)
        table.add_column("Rank", style="magenta")
        table.add_column("Response", justify="right", style="green")

        for i, _res in enumerate(results[prompt]):
            model = results[prompt][i]["model"]
            response = results[prompt][i]["result"]
            rank = results[prompt][i]["rank"]
            table.add_row(model, str(rank), response)

        console = Console()
        console.print(table)
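
# Example (sketch): rank the answers collected by `query` above:
#
#   $ aviary evaluate --input-file aviary-output.json --evaluator gpt-4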


@app.command()
def stream(
    model: Annotated[str, model_type],
    prompt: Annotated[str, prompt_type],
    print_stats: Annotated[bool, stats_type] = False,
):
    """Query a single model and stream back the response."""
    for chunk in sdk.stream(model, prompt):
        text = _get_text(chunk)
        print(text, end="", flush=True)
    print()
    if print_stats:
        rp("[bold]Stats:[/]")
        rp(chunk)


if __name__ == "__main__":
    app()