4
from functools import partial
17
from fastapi import HTTPException, Request, status
18
from httpx import HTTPStatusError as HTTPXHTTPStatusError
19
from opentelemetry import trace
20
from pydantic import ValidationError as PydanticValidationError
22
from rayllm.backend.logger import get_logger
23
from rayllm.backend.server import constants
24
from rayllm.backend.server.models import (
28
from rayllm.backend.server.openai_compat.openai_exception import OpenAIHTTPException
29
from rayllm.common.models import ErrorResponse
33
logger = get_logger(__name__)
34
AVIARY_ROUTER_HTTP_TIMEOUT = float(os.environ.get("AVIARY_ROUTER_HTTP_TIMEOUT", 175))
38
args: Union[str, LLMApp, List[Union[LLMApp, str]]], llm_app_cls=LLMApp
40
"""Parse the input args and return a standardized list of LLMApp objects
42
Supported args format:
43
1. The path to a yaml file defining your LLMApp
44
2. The path to a folder containing yaml files, which define your LLMApps
45
3. A list of yaml files defining multiple LLMApps
4. A dict or LLMApp object
5. A list of dicts or LLMApp objects
52
if isinstance(args, list):
58
models: List[LLMApp] = []
59
for raw_model in raw_models:
60
if isinstance(raw_model, str):
61
if os.path.exists(raw_model):
62
parsed_models = _parse_path_args(raw_model, llm_app_cls=llm_app_cls)
65
parsed_models = [llm_app_cls.parse_yaml(raw_model)]
66
except pydantic.ValidationError as e:
67
if "__root__" in repr(e):
69
"Could not parse string as yaml. If you are specifying a path, make sure it exists and can be reached."
74
parsed_models = [llm_app_cls.parse_obj(raw_model)]
75
models += parsed_models
76
return [model for model in models if model.enabled]
79
def _parse_path_args(path: str, llm_app_cls=LLMApp) -> List[LLMApp]:
80
assert os.path.exists(
82
), f"Could not load model from {path}, as it does not exist."
83
if os.path.isfile(path):
84
with open(path, "r") as f:
85
return [llm_app_cls.parse_yaml(f)]
86
elif os.path.isdir(path):
88
for root, _dirs, files in os.walk(path):
91
with open(os.path.join(root, p), "r") as f:
92
apps.append(llm_app_cls.parse_yaml(f))
96
f"Could not load model from {path}, as it is not a file or directory."
100
def _is_yaml_file(filename: str) -> bool:
101
yaml_exts = [".yml", ".yaml", ".json"]
103
if filename.endswith(s):
108
def _replace_prefix(model: str) -> str:
109
"""Replace -- with / in model name to handle slashes within the URL path segment"""
110
return model.replace("--", "/")
113
async def _until_disconnected(request: Request):
115
if await request.is_disconnected():
117
await asyncio.sleep(1)
120
# Values treated as end-of-stream markers: ``None`` plus the builtin
# stop-iteration exception classes themselves (the classes, not instances).
EOS_SENTINELS = (None, StopIteration, StopAsyncIteration)
123
async def collapse_stream(async_iterator: AsyncIterable[T]) -> List[T]:
    """Consume *async_iterator* to exhaustion and return its items as a list."""
    collected: List[T] = []
    async for item in async_iterator:
        collected.append(item)
    return collected
127
async def get_lines_batched(async_iterator: AsyncIterable[bytes]):
128
"""Batch the lines of a bytes iterator
130
Group the output of an async iterator into lines.
132
Eg. if the iterator output:
140
Furthermore, if multiple lines are present, return all of them.
143
async_iterator (AsyncIterable[bytes]): A bytes iterator
146
AsyncIterable[bytes]: A bytes iterator that chunks on lines
149
async for chunk in async_iterator:
151
if b"\n" in remainder:
152
out, remainder = remainder.rsplit(b"\n", 1)
159
async def get_model_response_batched(async_iterator: AsyncIterable[bytes]):
    """Parse AviaryModelResponse from byte stream

    First group the data from the iterator into lines.
    Then for each line, parse it as an AviaryModelResponse object.

    Args:
        async_iterator (AsyncIterable[bytes]): the input iterator

    Yields:
        AviaryModelResponse: one object per chunk of lines, produced by
        merging all responses in that chunk via ``merge_stream``.
    """
    async for chunk in get_lines_batched(async_iterator):
        # A chunk may hold several newline-separated payloads; parse each
        # non-empty one individually.
        responses = [AviaryModelResponse.parse_raw(p) for p in chunk.split(b"\n") if p]
        # Collapse the batch into a single response so downstream consumers
        # see one object per chunk.
        combined_response = AviaryModelResponse.merge_stream(*responses)
        yield combined_response
177
async def stream_model_responses(
178
url: str, json=None, timeout=AVIARY_ROUTER_HTTP_TIMEOUT
180
"""Make a streaming network request, and parse the output into a stream of AviaryModelResponse
182
Take the output stream of the request and parse it into a stream of AviaryModelResponse.
185
url (str): The url to query
186
json (_type_, optional): the json body
187
timeout (_type_, optional): Defaults to AVIARY_ROUTER_HTTP_TIMEOUT.
192
async with aiohttp.ClientSession(raise_for_status=True) as session:
193
async with session.post(
198
async for combined_response in get_model_response_batched(
199
response.content.iter_any()
201
yield combined_response
207
def make_async(_func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    """Adapt a blocking callable so it can be awaited without stalling
    the asyncio event loop.

    The call is dispatched to the event loop's default thread-pool
    executor, so ``_func`` must be thread safe.
    """

    def _async_wrapper(*args, **kwargs) -> asyncio.Future:
        # Bind the arguments now; the executor invokes the zero-arg partial.
        call = partial(_func, *args, **kwargs)
        return asyncio.get_event_loop().run_in_executor(executor=None, func=call)

    return _async_wrapper
222
def extract_message_from_exception(e: Exception) -> str:
223
# If the exception is a Ray exception, we need to dig through the text to get just
224
# the exception message without the stack trace
225
# This also works for normal exceptions (we will just return everything from
226
# format_exception_only in that case)
227
message_lines = traceback.format_exception_only(type(e), e)[-1].strip().split("\n")
229
# The stack trace lines will be prefixed with spaces, so we need to start from the bottom
230
# and stop at the last line before a line with a space
231
found_last_line_before_stack_trace = False
232
for line in reversed(message_lines):
233
if not line.startswith(" "):
234
found_last_line_before_stack_trace = True
235
if found_last_line_before_stack_trace and line.startswith(" "):
237
message = line + "\n" + message
238
message = message.strip()
242
def get_response_for_error(
245
span: Optional[trace.Span] = None,
247
enable_returning_500_errors: Optional[bool] = None,
248
) -> AviaryModelResponse:
249
"""Convert an exception to an AviaryModelResponse object"""
250
enable_returning_500_errors = (
251
enable_returning_500_errors
252
if enable_returning_500_errors is not None
253
else constants.enable_returning_internal_exceptions
255
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
256
if isinstance(e, HTTPException):
257
status_code = e.status_code
258
elif isinstance(e, OpenAIHTTPException):
259
status_code = e.status_code
260
elif isinstance(e, PydanticValidationError):
262
elif isinstance(e, HTTPXHTTPStatusError):
263
status_code = e.response.status_code
265
# Try to get the status code attribute
266
status_code = getattr(e, "status_code", status_code)
269
f"{prefix}. Request {request_id} failed with status code {status_code}: {e}",
274
status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
275
and not enable_returning_500_errors
277
message = "Internal Server Error"
278
internal_message = message
279
exc_type = "InternalServerError"
281
if isinstance(e, OpenAIHTTPException) and e.internal_message is not None:
282
internal_message = e.internal_message
284
internal_message = extract_message_from_exception(e)
285
if isinstance(e, HTTPException):
287
elif isinstance(e, OpenAIHTTPException):
290
message = internal_message
291
exc_type = e.__class__.__name__
293
# TODO make this more robust
294
if "(Request ID: " not in message:
295
message += f" (Request ID: {request_id})"
297
if "(Request ID: " not in internal_message:
298
internal_message += f" (Request ID: {request_id})"
301
span = trace.get_current_span()
302
span.record_exception(e)
303
span.set_status(trace.StatusCode.ERROR, description=message)
305
return AviaryModelResponse(
309
internal_message=internal_message,