ray-llm

Fork
0
312 lines · 9.7 KB
1
import asyncio
2
import os
3
import traceback
4
from functools import partial, wraps
5
from typing import (
6
    AsyncIterable,
7
    Awaitable,
8
    Callable,
9
    List,
10
    Optional,
11
    TypeVar,
12
    Union,
13
)
14

15
import aiohttp
16
import pydantic
17
from fastapi import HTTPException, Request, status
18
from httpx import HTTPStatusError as HTTPXHTTPStatusError
19
from opentelemetry import trace
20
from pydantic import ValidationError as PydanticValidationError
21

22
from rayllm.backend.logger import get_logger
23
from rayllm.backend.server import constants
24
from rayllm.backend.server.models import (
25
    AviaryModelResponse,
26
    LLMApp,
27
)
28
from rayllm.backend.server.openai_compat.openai_exception import OpenAIHTTPException
29
from rayllm.common.models import ErrorResponse
30

31
T = TypeVar("T")
32

33
logger = get_logger(__name__)
34
AVIARY_ROUTER_HTTP_TIMEOUT = float(os.environ.get("AVIARY_ROUTER_HTTP_TIMEOUT", 175))
35

36

37
def parse_args(
    args: Union[str, LLMApp, List[Union[LLMApp, str]]], llm_app_cls=LLMApp
) -> List[LLMApp]:
    """Parse the input args and return a standardized list of LLMApp objects

    Supported args format:
    1. The path to a yaml file defining your LLMApp
    2. The path to a folder containing yaml files, which define your LLMApps
    3. A list of yaml files defining multiple LLMApps
    4. A dict or LLMApp object
    5. A list of dicts or LLMApp objects

    Models with ``enabled=False`` are filtered out of the result.

    Args:
        args: One of the supported formats above.
        llm_app_cls: The model class used to parse each entry.

    Returns:
        A list of enabled, parsed LLMApp objects.

    Raises:
        ValueError: If a string is neither an existing path nor parseable yaml.
    """
    # Normalize to a list so every input shape takes the same code path below.
    raw_models = args if isinstance(args, list) else [args]

    models: List[LLMApp] = []
    for raw_model in raw_models:
        if isinstance(raw_model, str):
            if os.path.exists(raw_model):
                # An existing filesystem path: load the file or directory.
                parsed_models = _parse_path_args(raw_model, llm_app_cls=llm_app_cls)
            else:
                # Not a path: treat the string itself as inline yaml.
                try:
                    parsed_models = [llm_app_cls.parse_yaml(raw_model)]
                except pydantic.ValidationError as e:
                    # A "__root__" validation error means the string did not
                    # parse as a mapping at all -- most likely a mistyped path.
                    if "__root__" in repr(e):
                        raise ValueError(
                            "Could not parse string as yaml. If you are specifying a path, make sure it exists and can be reached."
                        ) from e
                    else:
                        raise
        else:
            parsed_models = [llm_app_cls.parse_obj(raw_model)]
        models += parsed_models
    return [model for model in models if model.enabled]
77

78

79
def _parse_path_args(path: str, llm_app_cls=LLMApp) -> List[LLMApp]:
    """Load one or more LLMApp definitions from a file or directory.

    Args:
        path: Path to a yaml file, or a directory that is walked recursively
            for yaml/json files.
        llm_app_cls: The model class used to parse each yaml document.

    Returns:
        A list of parsed LLMApp objects.

    Raises:
        ValueError: If ``path`` does not exist, or is neither a file nor a
            directory.
    """
    # Raise instead of assert: asserts are stripped under `python -O`, and a
    # missing model path must always be reported. ValueError matches the
    # not-a-file-or-directory error path below.
    if not os.path.exists(path):
        raise ValueError(f"Could not load model from {path}, as it does not exist.")
    if os.path.isfile(path):
        with open(path, "r") as f:
            return [llm_app_cls.parse_yaml(f)]
    if os.path.isdir(path):
        apps = []
        for root, _dirs, files in os.walk(path):
            for p in files:
                if _is_yaml_file(p):
                    with open(os.path.join(root, p), "r") as f:
                        apps.append(llm_app_cls.parse_yaml(f))
        return apps
    raise ValueError(
        f"Could not load model from {path}, as it is not a file or directory."
    )
98

99

100
def _is_yaml_file(filename: str) -> bool:
101
    yaml_exts = [".yml", ".yaml", ".json"]
102
    for s in yaml_exts:
103
        if filename.endswith(s):
104
            return True
105
    return False
106

107

108
def _replace_prefix(model: str) -> str:
109
    """Replace -- with / in model name to handle slashes within the URL path segment"""
110
    return model.replace("--", "/")
111

112

113
async def _until_disconnected(request: Request):
    """Block until the client has disconnected, polling once per second."""
    while not await request.is_disconnected():
        await asyncio.sleep(1)
    return True
118

119

120
# Sentinel values that mark end-of-stream when they appear as items.
# NOTE(review): not referenced elsewhere in this file; presumably consumed by
# importers -- confirm before removing.
EOS_SENTINELS = (None, StopIteration, StopAsyncIteration)
121

122

123
async def collapse_stream(async_iterator: AsyncIterable[T]) -> List[T]:
    """Drain an async iterator completely and return its items as a list."""
    items: List[T] = []
    async for item in async_iterator:
        items.append(item)
    return items
125

126

127
async def get_lines_batched(async_iterator: AsyncIterable[bytes]):
    """Re-frame a byte stream on newline boundaries.

    Buffers incoming chunks and, whenever a newline is seen, emits every
    complete line accumulated so far as one bytes object (trailing b"\\n"
    included).

    Eg. if the iterator output:
    b'a'
    b'b'
    b'c\\n'

    The output would be
    b'abc\\n'

    If multiple lines have accumulated, they are emitted together in one
    chunk. Any bytes left after the source ends are yielded as-is, without a
    trailing newline.

    Args:
        async_iterator (AsyncIterable[bytes]): A bytes iterator

    Yields:
        AsyncIterable[bytes]: A bytes iterator that chunks on lines
    """
    buffered = b""
    async for piece in async_iterator:
        buffered += piece
        # rpartition splits on the LAST newline: everything before it is a
        # run of complete lines, everything after it is a partial line.
        complete, newline, partial = buffered.rpartition(b"\n")
        if newline:
            yield complete + b"\n"
            buffered = partial

    if buffered:
        yield buffered
157

158

159
async def get_model_response_batched(async_iterator: AsyncIterable[bytes]):
    """Parse AviaryModelResponse objects out of a byte stream.

    The raw stream is first re-framed on newline boundaries via
    get_lines_batched; each line is then parsed as one AviaryModelResponse,
    and all responses in a chunk are merged into a single one.

    Args:
        async_iterator (AsyncIterable[bytes]): the input iterator

    Yields:
        AviaryModelResponse
    """
    async for line_chunk in get_lines_batched(async_iterator):
        parsed = [
            AviaryModelResponse.parse_raw(line)
            for line in line_chunk.split(b"\n")
            if line
        ]
        yield AviaryModelResponse.merge_stream(*parsed)
175

176

177
async def stream_model_responses(
    url: str, json=None, timeout=AVIARY_ROUTER_HTTP_TIMEOUT
):
    """Make a streaming POST request, and parse the output into a stream of AviaryModelResponse

    Opens a fresh aiohttp session per call with ``raise_for_status=True``, so
    a non-2xx response raises instead of being parsed.

    Args:
        url (str): The url to query
        json (optional): the json body of the request
        timeout (optional): per-request timeout in seconds.
            Defaults to AVIARY_ROUTER_HTTP_TIMEOUT.

    Yields:
        AviaryModelResponse
    """
    async with aiohttp.ClientSession(raise_for_status=True) as session:
        async with session.post(
            url,
            json=json,
            timeout=timeout,
        ) as response:
            # iter_any() yields chunks as they arrive with no size guarantees;
            # get_model_response_batched re-frames them on newline boundaries.
            async for combined_response in get_model_response_batched(
                response.content.iter_any()
            ):
                yield combined_response
202

203

204
# NOTE(review): duplicate of the `T = TypeVar("T")` declared near the imports;
# harmless (rebinds the same name) but safe to delete.
T = TypeVar("T")
205

206

207
def make_async(_func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    """Take a blocking function, and run it on in an executor thread.

    This function prevents the blocking function from blocking the asyncio
    event loop. The code in this function needs to be thread safe, since it
    runs on a thread-pool thread (`executor=None` selects the event loop's
    default executor).

    Args:
        _func: A synchronous callable.

    Returns:
        A callable with the same arguments whose result is awaitable.
    """

    # `wraps` preserves the wrapped function's name/docstring/signature
    # metadata for introspection and debugging.
    @wraps(_func)
    def _async_wrapper(*args, **kwargs) -> asyncio.Future:
        # NOTE(review): get_event_loop is deprecated outside a running loop
        # (3.10+); fine as long as the returned future is awaited from a
        # coroutine.
        loop = asyncio.get_event_loop()
        # run_in_executor only forwards positional args, so bind everything
        # (including kwargs) into a no-arg callable first.
        func = partial(_func, *args, **kwargs)
        return loop.run_in_executor(executor=None, func=func)

    return _async_wrapper
220

221

222
def extract_message_from_exception(e: Exception) -> str:
    """Extract just the exception message, dropping any embedded stack trace.

    Ray exceptions can embed a (space-indented) stack trace inside the
    formatted exception text; walking the lines bottom-up and stopping at the
    first indented line keeps only the trailing message. For ordinary
    exceptions this simply returns everything from format_exception_only.
    """
    lines = traceback.format_exception_only(type(e), e)[-1].strip().split("\n")
    # Collect message lines from the bottom; once a non-indented (message)
    # line has been seen, the next indented line marks the stack trace.
    kept = []
    seen_message_line = False
    for line in reversed(lines):
        indented = line.startswith(" ")
        if not indented:
            seen_message_line = True
        if seen_message_line and indented:
            break
        kept.append(line)
    return "\n".join(reversed(kept)).strip()
240

241

242
def get_response_for_error(
    e: Exception,
    request_id: str,
    span: Optional[trace.Span] = None,
    prefix="",
    enable_returning_500_errors: Optional[bool] = None,
) -> AviaryModelResponse:
    """Convert an exception to an AviaryModelResponse object

    Maps known exception types to an HTTP status code, logs the failure,
    records the exception on the (current or supplied) OpenTelemetry span,
    and wraps the message into an AviaryModelResponse carrying an
    ErrorResponse.

    Args:
        e: The exception to convert.
        request_id: Appended to the returned messages for correlation.
        span: Span to record the exception on; defaults to the current span.
        prefix: Text prepended to the log line.
        enable_returning_500_errors: When False, 500-level errors are masked
            as a generic "Internal Server Error". Defaults to the
            server-wide constant when None.
    """
    # Per-call override falls back to the server-wide setting.
    enable_returning_500_errors = (
        enable_returning_500_errors
        if enable_returning_500_errors is not None
        else constants.enable_returning_internal_exceptions
    )
    # Order matters: explicit HTTP-ish exception types first, then a generic
    # `status_code` attribute probe, else 500.
    status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
    if isinstance(e, HTTPException):
        status_code = e.status_code
    elif isinstance(e, OpenAIHTTPException):
        status_code = e.status_code
    elif isinstance(e, PydanticValidationError):
        status_code = 400
    elif isinstance(e, HTTPXHTTPStatusError):
        status_code = e.response.status_code
    else:
        # Try to get the status code attribute
        status_code = getattr(e, "status_code", status_code)

    logger.error(
        f"{prefix}. Request {request_id} failed with status code {status_code}: {e}",
        exc_info=e,
    )

    if (
        status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
        and not enable_returning_500_errors
    ):
        # Mask internal errors: never leak internal exception details to
        # clients unless explicitly enabled.
        message = "Internal Server Error"
        internal_message = message
        exc_type = "InternalServerError"
    else:
        # Prefer the exception's own internal message when available;
        # otherwise strip the message out of the formatted traceback.
        if isinstance(e, OpenAIHTTPException) and e.internal_message is not None:
            internal_message = e.internal_message
        else:
            internal_message = extract_message_from_exception(e)
        if isinstance(e, HTTPException):
            message = e.detail
        elif isinstance(e, OpenAIHTTPException):
            message = e.message
        else:
            message = internal_message
        exc_type = e.__class__.__name__

    # Tag both messages with the request id unless one is already present.
    # TODO make this more robust
    if "(Request ID: " not in message:
        message += f" (Request ID: {request_id})"

    if "(Request ID: " not in internal_message:
        internal_message += f" (Request ID: {request_id})"

    # Record the failure on the tracing span so it shows up in traces even
    # when the caller did not pass one in.
    if span is None:
        span = trace.get_current_span()
    span.record_exception(e)
    span.set_status(trace.StatusCode.ERROR, description=message)

    return AviaryModelResponse(
        error=ErrorResponse(
            message=message,
            code=status_code,
            internal_message=internal_message,
            type=exc_type,
        ),
    )
313

Cookie usage

We use cookies in accordance with the Privacy Policy and the Cookie Usage Policy.

By clicking "Accept", you give SberTech JSC consent to process your personal data in order to improve our website and the GitVerse service, and to make them more convenient to use.

You can disable cookies yourself in your browser settings.