ray-llm

import asyncio
import os
import time
from typing import AsyncGenerator, List, Optional, Tuple

import async_timeout
from fastapi import FastAPI, HTTPException, status
from fastapi import Response as FastAPIResponse
from fastapi.middleware.cors import CORSMiddleware
from httpx import HTTPStatusError as HTTPXHTTPStatusError
from ray import serve
from starlette.exceptions import ExceptionMiddleware
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse

from rayllm.backend.llm.embedding.embedding_models import Embeddings
from rayllm.backend.logger import get_logger
from rayllm.backend.observability.telemetry import configure_telemetry
from rayllm.backend.server.models import (
    AviaryModelResponse,
    ChatCompletionsParams,
    CompletionsParams,
    Prompt,
    QueuePriority,
)
from rayllm.backend.server.openai_compat.openai_exception import OpenAIHTTPException
from rayllm.backend.server.openai_compat.openai_middleware import (
    openai_exception_handler,
)
from rayllm.backend.server.plugins.router_query_engine import RouterQueryClient
from rayllm.backend.server.routers.middleware import add_request_id
from rayllm.backend.server.utils import _replace_prefix, get_response_for_error
from rayllm.common.models import (
    ChatCompletion,
    ChoiceLogProbs,
    Completion,
    DeletedModel,
    DeltaChoices,
    DeltaContent,
    DeltaEOS,
    DeltaRole,
    EmbeddingsData,
    EmbeddingsOutput,
    EmbeddingsUsage,
    LogProbs,
    Message,
    MessageChoices,
    Model,
    ModelData,
    TextChoice,
    Usage,
)

logger = get_logger(__name__)

# Timeout of 10 minutes. Streaming can take longer than 3 minutes.
TIMEOUT = float(os.environ.get("AVIARY_ROUTER_HTTP_TIMEOUT", 600))
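# Illustrative override (a sketch): setting AVIARY_ROUTER_HTTP_TIMEOUT=1200 in the
# environment before starting the router raises this limit to 20 minutes.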


def init() -> FastAPI:
    router_app = FastAPI()

    router_app.add_exception_handler(OpenAIHTTPException, openai_exception_handler)
    router_app.add_exception_handler(HTTPException, openai_exception_handler)
    router_app.add_exception_handler(HTTPXHTTPStatusError, openai_exception_handler)
    router_app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Add a unique per-request ID
    router_app.middleware("http")(add_request_id)
    # Configure common FastAPI app telemetry
    configure_telemetry(router_app, "model_router_app")
    # This is necessary for passing exceptions through to users; it appears to
    # work around a flaw in Starlette, see the discussion at
    # https://github.com/encode/starlette/issues/1175
    router_app.add_middleware(
        ExceptionMiddleware, handlers=router_app.exception_handlers
    )

    return router_app


router_app = init()


async def _openai_json_generator(
    generator: AsyncGenerator[AviaryModelResponse, None],
    first_response: Optional[AviaryModelResponse] = None,
):
    if first_response is not None:
        yield "data: " + first_response.json() + "\n\n"
    async for response in generator:
        yield "data: " + response.json() + "\n\n"
    yield "data: [DONE]\n\n"


async def _peek_at_openai_json_generator(
    generator: AsyncGenerator[AviaryModelResponse, None]
) -> Tuple[AviaryModelResponse, AsyncGenerator[str, None]]:
    """Runs one iteration of the underlying generator
    and returns the result alongside the generator itself (with the
    first iteration still there).
    """
    first_response = await generator.__anext__()
    return first_response, _openai_json_generator(generator, first_response)
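
# For reference, the helpers above serialize each streamed response as a
# server-sent-events message, "data: <response JSON>\n\n", and terminate the
# stream with "data: [DONE]\n\n". A sketch of what a client sees on the wire
# (payload fields are illustrative):
#
#   data: {"id": "...", "object": "text_completion", "choices": [...]}
#
#   data: [DONE]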


async def _completions_wrapper(
    model: str,
    request_id: str,
    response: Response,
    generator: AsyncGenerator[AviaryModelResponse, None],
) -> AsyncGenerator[AviaryModelResponse, None]:
    had_error = False
    completion_id = _get_model_request_id(model, request_id)
    async with async_timeout.timeout(TIMEOUT):
        all_results = []
        try:
            async for results in generator:
                for subresult in results.unpack():
                    all_results.append(subresult)
                    subresult_dict = subresult.dict()
                    if subresult_dict.get("error"):
                        response.status_code = subresult_dict["error"]["code"]
                        # Drop finish reason as OpenAI doesn't expect it
                        # for errors in streaming
                        subresult_dict["finish_reason"] = None
                        logger.error(
                            f"Reporting back an error: {subresult_dict['error']}"
                        )
                        all_results.pop()
                        had_error = True
                        yield AviaryModelResponse.parse_obj(subresult_dict)
                        # Return early in case of an error
                        break
                    choices = [
                        TextChoice(
                            text=subresult_dict["generated_text"] or "",
                            index=0,
                            logprobs={},
                            finish_reason=subresult_dict["finish_reason"],
                        )
                    ]
                    usage = None
                    if subresult_dict["finish_reason"]:
                        usage = (
                            Usage.from_response(
                                AviaryModelResponse.merge_stream(*all_results)
                            )
                            if all_results
                            else None
                        )
                    yield Completion(
                        id=completion_id,
                        object="text_completion",
                        created=int(time.time()),
                        model=model,
                        choices=choices,
                        usage=usage,
                    )
                if had_error:
                    # Return early in case of an error
                    break
        except Exception as e:
            logger.error(
                f"Failed while handling completions for request ({request_id}): {repr(e)}",
                exc_info=e,
            )

            exc_response = get_response_for_error(e, request_id)
            response.status_code = exc_response.error.code
            had_error = True
            yield exc_response


async def _chat_completions_wrapper(
    model: str,
    request_id: str,
    response: Response,
    generator: AsyncGenerator[AviaryModelResponse, None],
) -> AsyncGenerator[AviaryModelResponse, None]:
    had_error = False
    completion_id = _get_model_request_id(model, request_id)
    async with async_timeout.timeout(TIMEOUT):
        finish_reason = None
        choices: List[DeltaChoices] = [
            DeltaChoices(
                delta=DeltaRole(role="assistant"),
                index=0,
                finish_reason=None,
            )
        ]

        yielded_role = False
        all_results = []
        try:
            async for results in generator:
                for subresult in results.unpack():
                    logger.info(f"subresult: {subresult}")
                    all_results.append(subresult)
                    subresult_dict = subresult.dict()
                    if subresult_dict.get("error"):
                        response.status_code = subresult_dict["error"]["code"]
                        logger.error(f"{subresult_dict['error']}")
                        # Drop finish reason as OpenAI doesn't expect it
                        # for errors in streaming
                        subresult_dict["finish_reason"] = None
                        all_results.pop()
                        had_error = True
                        yield AviaryModelResponse.parse_obj(subresult_dict)
                        # Return early in case of an error
                        break
                    else:
                        finish_reason = subresult_dict["finish_reason"]

                        if not yielded_role:
                            choices: List[DeltaChoices] = [
                                DeltaChoices(
                                    delta=DeltaRole(role="assistant"),
                                    index=0,
                                    finish_reason=None,
                                    logprobs=ChoiceLogProbs(content=[]),
                                )
                            ]
                            yield ChatCompletion(
                                id=completion_id,
                                object="text_completion",
                                created=int(time.time()),
                                model=model,
                                choices=choices,
                                usage=None,
                            )
                            yielded_role = True
                        if subresult_dict["logprobs"]:
                            logprobs = ChoiceLogProbs(
                                content=[
                                    LogProbs.parse_obj(logprob)
                                    for logprob in subresult_dict["logprobs"]
                                ]
                            )
                        else:
                            logprobs = None
                        choices: List[DeltaChoices] = [
                            DeltaChoices(
                                delta=DeltaContent(
                                    content=subresult_dict["generated_text"] or "",
                                    tool_calls=subresult_dict["tool_calls"] or None,
                                ),
                                index=0,
                                finish_reason=None,
                                logprobs=logprobs,
                            )
                        ]
                        yield ChatCompletion(
                            id=completion_id,
                            object="text_completion",
                            created=int(time.time()),
                            model=model,
                            choices=choices,
                            usage=None,
                        )
                if had_error:
                    # Return early in case of an error
                    break
        except Exception as e:
            logger.error(
                f"Failed while handling chat-completions for request ({request_id}): {repr(e)}",
                exc_info=e,
            )

            exc_response = get_response_for_error(e, request_id)
            response.status_code = exc_response.error.code
            had_error = True
            yield exc_response

        if not had_error:
            choices: List[DeltaChoices] = [
                DeltaChoices(
                    delta=DeltaEOS(),
                    index=0,
                    finish_reason=finish_reason,
                )
            ]
            usage = (
                Usage.from_response(AviaryModelResponse.merge_stream(*all_results))
                if all_results
                else None
            )
            yield ChatCompletion(
                id=completion_id,
                object="text_completion",
                created=int(time.time()),
                model=model,
                choices=choices,
                usage=usage,
            )
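
# Note on the chunk order produced by _chat_completions_wrapper above: the first
# chunk carries only the assistant role delta, the following chunks carry content
# (plus tool_calls and logprobs when present) with finish_reason=None, and a
# final DeltaEOS chunk reports the finish_reason together with usage aggregated
# over the whole stream.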


class Router:
    def __init__(
        self,
        query_engine: RouterQueryClient,
    ) -> None:
        # Increase the amount of time allocated for fetching the queue length
        # TODO(tchordia): use the associated env var instead once it's available
        serve._private.router.PowerOfTwoChoicesReplicaScheduler.queue_len_response_deadline_s = (
            0.5
        )
        self.query_engine = query_engine

    @router_app.get("/v1/models", response_model=Model)
    async def models(self) -> Model:
        """OpenAI API-compliant endpoint to get all Aviary models."""
        models = await self.query_engine.models()
        return Model(data=list(models.values()))
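
    # Example call (a sketch; the host and port are illustrative and depend on how
    # the Ray Serve application is deployed):
    #
    #   import requests
    #   print(requests.get("http://localhost:8000/v1/models").json()["data"])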

    # :path allows us to have slashes in the model name
    @router_app.get("/v1/models/{model:path}", response_model=ModelData)
    async def model_data(self, model: str) -> ModelData:
        """OpenAI API-compliant endpoint to get one Aviary model.

        :param model: The Aviary model ID (e.g. "amazon/LightGPT")
        """
        model = _replace_prefix(model)
        model_data = await self.query_engine.model(model)
        if model_data is None:
            raise OpenAIHTTPException(
                message=f"Invalid model '{model}'",
                status_code=status.HTTP_400_BAD_REQUEST,
                type="InvalidModel",
            )
        return model_data
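
    # Example (a sketch; the address is illustrative). Because of the
    # `{model:path}` converter above, the slash in the model id can appear
    # literally in the URL:
    #
    #   import requests
    #   print(requests.get("http://localhost:8000/v1/models/amazon/LightGPT").json())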

    @router_app.delete("/v1/models/{model:path}", response_model=DeletedModel)
    async def delete_fine_tuned_model(self, model: str) -> DeletedModel:
        """OpenAI API-compliant endpoint to delete one fine-tuned model.

        :param model: The fine-tuned model ID (e.g. "meta-llama/Llama-2-7b-chat-hf:john:aBc1234")
        """
        model = _replace_prefix(model)
        await self.query_engine.delete_fine_tuned_model(model)
        return DeletedModel(id=model)

    @router_app.post("/v1/completions")
    async def completions(
        self,
        body: CompletionsParams,
        request: Request,
        response: FastAPIResponse,
    ):
        """Given a prompt, the model will return one or more predicted completions,
        and can also return the probabilities of alternative tokens at each position.

        Returns:
            A response object with completions.
        """
        req_id = request.state.request_id
        prompt = Prompt(
            prompt=body.prompt,
            parameters=body,
            use_prompt_format=False,
        )

        if body.stream:
            first_response, wrapper = await _peek_at_openai_json_generator(
                _completions_wrapper(
                    body.model,
                    req_id,
                    response,
                    self.query_engine.stream(
                        body.model,
                        prompt,
                        request,
                        priority=QueuePriority.GENERATE_TEXT,
                    ),
                ),
            )
            if isinstance(first_response, AviaryModelResponse) and first_response.error:
                raise OpenAIHTTPException.from_model_response(first_response)
            return StreamingResponse(wrapper, media_type="text/event-stream")
        else:
            async with async_timeout.timeout(TIMEOUT):
                results = await self.query_engine.query(body.model, prompt, request)
                if results.error:
                    raise OpenAIHTTPException(
                        message=results.error.message,
                        status_code=results.error.code,
                        type=results.error.type,
                    )
                results = results.dict()

                choices = [
                    TextChoice(
                        text=results["generated_text"] or "",
                        index=0,
                        logprobs={},
                        finish_reason=results["finish_reason"],
                    )
                ]
                usage = Usage.from_response(results)
                # TODO: pick up parameters that make sense, remove the rest

                return Completion(
                    id=_get_model_request_id(body.model, req_id),
                    object="text_completion",
                    created=int(time.time()),
                    model=body.model,
                    choices=choices,
                    usage=usage,
                )
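
    # Example non-streaming request (a sketch; the address and model id are
    # illustrative and must match a model actually deployed behind this router):
    #
    #   import requests
    #   resp = requests.post(
    #       "http://localhost:8000/v1/completions",
    #       json={"model": "amazon/LightGPT", "prompt": "Once upon a time"},
    #   )
    #   print(resp.json()["choices"][0]["text"])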

    @router_app.post("/v1/chat/completions")
    async def chat(
        self,
        body: ChatCompletionsParams,
        request: Request,
        response: FastAPIResponse,
    ):
        """Given a prompt, the model will return one or more predicted completions,
        and can also return the probabilities of alternative tokens at each position.

        Returns:
            A response object with completions.
        """
        tools = body.tools
        tool_choice = body.tool_choice
        # Doing this to remove them from sampling params
        body.tools = None
        body.tool_choice = None

        req_id = request.state.request_id
        prompt = Prompt(
            prompt=body.messages, parameters=body, tools=tools, tool_choice=tool_choice
        )

        if body.stream:
            first_response, wrapper = await _peek_at_openai_json_generator(
                _chat_completions_wrapper(
                    body.model,
                    req_id,
                    response,
                    self.query_engine.stream(
                        body.model,
                        prompt,
                        request,
                        priority=QueuePriority.GENERATE_TEXT,
                    ),
                ),
            )
            if isinstance(first_response, AviaryModelResponse) and first_response.error:
                raise OpenAIHTTPException.from_model_response(first_response)
            return StreamingResponse(wrapper, media_type="text/event-stream")
        else:
            async with async_timeout.timeout(TIMEOUT):
                results = await self.query_engine.query(body.model, prompt, request)
                if results.error:
                    raise OpenAIHTTPException(
                        message=results.error.message,
                        status_code=results.error.code,
                        type=results.error.type,
                    )
                # TODO: pick up parameters that make sense, remove the rest
                logprobs = results.logprobs
                if logprobs:
                    logprobs = ChoiceLogProbs(
                        content=[LogProbs.parse_obj(logprob) for logprob in logprobs]
                    )
                else:
                    logprobs = None
                if results.tool_calls:
                    msg = Message(role="assistant", tool_calls=results.tool_calls)
                    # Delete this field so that it doesn't appear in the response
                    del msg.tool_call_id
                    choices: List[MessageChoices] = [
                        MessageChoices(
                            message=msg,
                            index=0,
                            finish_reason=results.finish_reason,
                            logprobs=logprobs,
                        )
                    ]
                else:
                    choices: List[MessageChoices] = [
                        MessageChoices(
                            message=Message(
                                role="assistant",
                                content=results.generated_text or "",
                            ),
                            index=0,
                            finish_reason=results.finish_reason,
                            logprobs=logprobs,
                        )
                    ]

                usage = Usage.from_response(results)

                return ChatCompletion(
                    id=_get_model_request_id(body.model, req_id),
                    object="text_completion",
                    created=int(time.time()),
                    model=body.model,
                    choices=choices,
                    usage=usage,
                )
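
    # Example streaming request (a sketch; the address, model id, and the use of
    # the OpenAI Python client are illustrative; any OpenAI-compatible HTTP client
    # should work against this endpoint):
    #
    #   from openai import OpenAI
    #   client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")
    #   stream = client.chat.completions.create(
    #       model="meta-llama/Llama-2-7b-chat-hf",
    #       messages=[{"role": "user", "content": "Hello!"}],
    #       stream=True,
    #   )
    #   for chunk in stream:
    #       print(chunk.choices[0].delta.content or "", end="")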

    @router_app.post("/v1/embeddings")
    async def embed(
        self,
        body: Embeddings,
        request: Request,
    ):
        """Given input text (a string or a list of strings), the model will
        return one embedding per input.

        Returns:
            A response object with the embeddings.
        """
        embedding_id = _get_model_request_id(body.model, request.state.request_id)

        async with async_timeout.timeout(TIMEOUT):
            if isinstance(body.input, str):
                input = [body.input]
            else:
                input = body.input
            prompts = [Prompt(prompt=x, parameters=body) for x in input]
            results_list: List[AviaryModelResponse] = await asyncio.gather(
                *[
                    self.query_engine.query(body.model, prompt, request)
                    for prompt in prompts
                ]
            )
            final_results = []
            tokens = 0
            for results in results_list:
                if results.error:
                    raise OpenAIHTTPException.from_model_response(results)
                final_results.append(results.dict())
                tokens += results.num_input_tokens

            return EmbeddingsOutput(
                data=[
                    EmbeddingsData(
                        embedding=results["embedding_outputs"],
                        index=i,
                        object="embedding",
                    )
                    for i, results in enumerate(final_results)
                ],
                id=embedding_id,
                object="list",
                created=int(time.time()),
                model=body.model,
                usage=EmbeddingsUsage(
                    prompt_tokens=tokens,
                    total_tokens=tokens,
                ),
            )
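
    # Example request (a sketch; the address is illustrative and the model id is
    # a hypothetical embedding model deployed behind this router):
    #
    #   import requests
    #   resp = requests.post(
    #       "http://localhost:8000/v1/embeddings",
    #       json={"model": "thenlper/gte-large", "input": ["hello", "world"]},
    #   )
    #   print(len(resp.json()["data"]))  # one embedding per input string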

    @router_app.get("/v1/health_check")
    async def health_check(self) -> bool:
        """Check if the router is still running."""
        return True


def _get_model_request_id(model: str, request_id: str):
    return model + "-" + request_id
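
# Illustrative example of the id format produced above:
#   _get_model_request_id("amazon/LightGPT", "abc123") -> "amazon/LightGPT-abc123"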