ray-llm / cli.py (254 lines · 7.9 KB)
import ast
import json

import requests

try:
    from typing import Annotated
except ImportError:
    from typing_extensions import Annotated

from typing import List, Optional

import typer
from rich import print as rp
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

from rayllm import sdk
from rayllm.common.evaluation import GPT

__all__ = ["app", "models", "metadata", "query", "run"]

app = typer.Typer()

model_type = typer.Option(
    default=..., help="The model to use. You can specify multiple models."
)
prompt_type = typer.Option(help="Prompt to query")
stats_type = typer.Option(help="Whether to print generated statistics")
prompt_file_type = typer.Option(
    default=..., help="File containing prompts. A simple text file"
)
separator_type = typer.Option(help="Separator used in prompt files")
results_type = typer.Option(help="Where to save the results")
true_or_false_type = typer.Option(default=False, is_flag=True)
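# The typer.Option(...) objects above are shared option definitions; the commands
# below attach them to individual parameters via typing.Annotated.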


@app.command()
def models(metadata: Annotated[bool, "Whether to print metadata"] = False):
    """Get a list of the available models"""
    result = sdk.models()
    if metadata:
        for model in result:
            rp(f"[bold]{model}:[/]")
            rp(sdk.metadata(model))
    else:
        print("\n".join(result))


@app.command()
def metadata(model: Annotated[List[str], model_type]):
    """Get metadata for models."""
    results = [sdk.metadata(m) for m in model]
    rp(results)


def _get_text(result: dict) -> str:
    if "text" in result["choices"][0]:
        return result["choices"][0]["text"]
    elif "message" in result["choices"][0]:
        return result["choices"][0]["message"]["content"]
    elif "delta" in result["choices"][0]:
        return result["choices"][0]["delta"].get("content", "")
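# _get_text above covers the three OpenAI-style payload shapes the SDK may return:
# plain completions ("text"), chat completions ("message"), and streaming chunks ("delta").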


def _print_result(result, model, print_stats):
    rp(f"[bold]{model}:[/]")
    if print_stats:
        rp("[bold]Stats:[/]")
        rp(result)
    else:
        rp(_get_text(result))


def progress_spinner():
    return Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        transient=True,
    )
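# progress_spinner returns a transient rich progress display: the spinner is shown
# while tasks run and is cleared from the terminal when the context manager exits.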


@app.command()
def query(
    model: Annotated[List[str], model_type],
    prompt: Annotated[Optional[List[str]], prompt_type] = None,
    prompt_file: Annotated[Optional[str], prompt_file_type] = None,
    separator: Annotated[str, separator_type] = "----",
    output_file: Annotated[str, results_type] = "aviary-output.json",
    print_stats: Annotated[bool, stats_type] = False,
):
    """Query one or several models with one or multiple prompts,
    optionally read from file, and save the results to a file."""
    # TODO (max): deprecate and rename to "completions" to match the API
    with progress_spinner() as progress:
        if prompt_file:
            with open(prompt_file, "r") as f:
                prompt = f.read().split(separator)

        results = {p: [] for p in prompt}

        for m in model:
            progress.add_task(
                description=f"Processing all prompts against model: {m}.",
                total=None,
            )
            query_results = [sdk.query(m, p) for p in prompt]
            for result in query_results:
                _print_result(result, m, print_stats)

            for i, p in enumerate(prompt):
                result = query_results[i]
                text = _get_text(result)
                results[p].append({"model": m, "result": text, "stats": result})

        progress.add_task(description="Writing output file.", total=None)
        with open(output_file, "w") as f:
            f.write(json.dumps(results, indent=2))
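# Example invocations (model IDs are placeholders; --model and --prompt may be
# repeated to query several models or prompts at once):
#   aviary query --model <model-id> --prompt "What is Ray?" --print-stats
#   aviary query --model <model-id> --prompt-file prompts.txt --separator "----"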


def _get_yes_or_no_input(prompt) -> bool:
    while True:
        user_input = input(prompt).strip().lower()
        if user_input == "yes" or user_input == "y":
            return True
        elif user_input == "no" or user_input == "n" or user_input == "":
            return False
        else:
            print("Invalid input. Please enter 'yes / y' or 'no / n'.")


@app.command()
def run(
    model: Annotated[List[str], model_type],
    blocking: bool = True,
    restart: bool = true_or_false_type,
):
    """Start a model in Aviary.

    Args:
        *model: Models to run.
        blocking: Whether to block the CLI until the application is ready.
        restart: Whether to restart Aviary if it is already running.
    """
    msg = (
        "Running `aviary run` while Aviary is running will stop any existing Aviary (or other Ray Serve) deployments "
        f"and run the specified ones ({model}).\n"
        "Do you want to continue? [y/N]\n"
    )
    try:
        backend = sdk.get_aviary_backend(verbose=False)
        aviary_url = backend.backend_url
        aviary_started = False
        if aviary_url:
            health_check_url = f"{aviary_url}/health_check"
            aviary_started = requests.get(health_check_url).status_code == 200
        if aviary_started:
            if restart:
                restart_aviary = True
            else:
                restart_aviary = _get_yes_or_no_input(msg) or False

            if not restart_aviary:
                return
    except (requests.exceptions.ConnectionError, sdk.URLNotSetException):
        pass  # Aviary is not running

    sdk.shutdown()
    sdk.run(model, blocking=blocking)
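# Example invocation (pass --restart to skip the confirmation prompt when Aviary
# is already running; the model ID is a placeholder):
#   aviary run --model <model-id>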


@app.command()
def shutdown():
    """Shutdown Aviary."""
    sdk.shutdown()


evaluator_type = typer.Option(help="Which LLM to use for evaluation")


@app.command()
def evaluate(
    input_file: Annotated[str, results_type] = "aviary-output.json",
    evaluation_file: Annotated[str, results_type] = "evaluation-output.json",
    evaluator: Annotated[str, evaluator_type] = "gpt-4",
):
    """Evaluate and summarize the results of a `query` run with a strong
    'evaluator' LLM like GPT-4.
    The results of the ranking are stored to file and displayed in a table.
    """
    with progress_spinner() as progress:
        progress.add_task(description="Loading the evaluator LLM.", total=None)
        if evaluator == "gpt-4":
            eval_model = GPT()
        else:
            raise NotImplementedError(f"No evaluator for {evaluator}")

        with open(input_file, "r") as f:
            results = json.load(f)

        for prompt, result_list in results.items():
            progress.add_task(
                description=f"Evaluating results for prompt: {prompt}.", total=None
            )
            evaluation = eval_model.evaluate_results(prompt, result_list)
            try:
                # GPT-4 returns a string with a Python dictionary, hopefully!
                evaluation = ast.literal_eval(evaluation)
            except Exception:
                print(f"Could not parse evaluation: {evaluation}")

            for i, _res in enumerate(results[prompt]):
                results[prompt][i]["rank"] = evaluation[i]["rank"]

        progress.add_task(description="Storing evaluations.", total=None)
        with open(evaluation_file, "w") as f:
            f.write(json.dumps(results, indent=2))

    for prompt in results.keys():
        table = Table(title="Evaluation results (higher ranks are better)")

        table.add_column("Model", justify="left", style="cyan", no_wrap=True)
        table.add_column("Rank", style="magenta")
        table.add_column("Response", justify="right", style="green")

        for i, _res in enumerate(results[prompt]):
            model = results[prompt][i]["model"]
            response = results[prompt][i]["result"]
            rank = results[prompt][i]["rank"]
            table.add_row(model, str(rank), response)

        console = Console()
        console.print(table)
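# Example invocation (reads the file written by `query`; the GPT-4 evaluator needs
# whatever credentials rayllm.common.evaluation.GPT expects):
#   aviary evaluate --input-file aviary-output.json --evaluator gpt-4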


@app.command()
def stream(
    model: Annotated[str, model_type],
    prompt: Annotated[str, prompt_type],
    print_stats: Annotated[bool, stats_type] = False,
):
    """Query a single model with a single prompt and stream back the response."""
    for chunk in sdk.stream(model, prompt):
        text = _get_text(chunk)
        rp(text, end="")
    rp("")
    if print_stats:
        rp("[bold]Stats:[/]")
        rp(chunk)
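# Example invocation (model ID is a placeholder; output is printed chunk by chunk):
#   aviary stream --model <model-id> --prompt "Tell me a story"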


if __name__ == "__main__":
    app()