from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypeVar, Union

from fastapi import HTTPException, status
from pydantic import BaseModel, root_validator, validator

from rayllm.backend.server.models import AviaryModelResponse

# TypeVars let classmethod constructors (e.g. Model.list) be annotated
# as returning the concrete subclass.
TModel = TypeVar("TModel", bound="Model")
TCompletion = TypeVar("TCompletion", bound="Completion")
TChatCompletion = TypeVar("TChatCompletion", bound="ChatCompletion")

# Marker substring used by Prompt.get_log_str to locate a traceable
# snippet inside a prompt for log correlation.
PROMPT_TRACE_KEY = "+TRACE_"
16
class PromptFormatDisabledError(ValueError):
    """Raised when a request needs prompt formatting but the model's
    prompt format is disabled (see DisabledPromptFormat).

    NOTE(review): class body was lost in the garbled source; assuming an
    empty body — TODO confirm no extra attributes upstream.
    """

    pass
class ModelData(BaseModel):
25
rayllm_metadata: Dict[str, Any]
28
class Model(BaseModel):
33
def list(cls) -> TModel:
37
class DeletedModel(BaseModel):
43
class TextChoice(BaseModel):
47
finish_reason: Optional[str]
50
class Usage(BaseModel):
    """Token accounting for a single generation request.

    NOTE(review): only completion_tokens survived the garbled source as a
    declared field, but from_response populates all three — field list
    inferred; TODO confirm against upstream.
    """

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

    @classmethod
    def from_response(
        cls, response: Union["AviaryModelResponse", Dict[str, Any]]
    ) -> "Usage":
        """Build a Usage from an AviaryModelResponse or its dict form.

        Token counts that are None are coerced to 0, so partial
        responses never produce a validation error.
        """
        if isinstance(response, BaseModel):
            response_dict = response.dict()
        else:
            response_dict = response
        return cls(
            prompt_tokens=response_dict["num_input_tokens"] or 0,
            completion_tokens=response_dict["num_generated_tokens"] or 0,
            total_tokens=(response_dict["num_input_tokens"] or 0)
            + (response_dict["num_generated_tokens"] or 0),
        )
71
class EmbeddingsUsage(BaseModel):
76
class Completion(BaseModel):
81
choices: List[TextChoice]
82
usage: Optional[Usage]
89
use_prompt_format: bool = True,
90
max_tokens: Optional[int] = 16,
91
temperature: Optional[float] = 1.0,
92
top_p: Optional[float] = 1.0,
94
stop: Optional[List[str]] = None,
95
frequency_penalty: float = 0.0,
96
top_k: Optional[int] = None,
97
typical_p: Optional[float] = None,
98
watermark: Optional[bool] = False,
99
seed: Optional[int] = None,
104
class EmbeddingsData(BaseModel):
105
embedding: List[float]
110
class EmbeddingsOutput(BaseModel):
111
data: List[EmbeddingsData]
116
usage: Optional[EmbeddingsUsage]
119
class FunctionCall(BaseModel):
121
arguments: Optional[str] = None
124
class ToolCall(BaseModel):
    """A single tool invocation emitted by the assistant."""

    function: FunctionCall
    type: Literal["function"]
    # NOTE(review): lines between `type` and `__str__` were lost in the
    # garbled source; an OpenAI-style `id` field most likely lived here —
    # TODO confirm upstream before relying on it.
    id: str

    def __str__(self):
        return str(self.dict())
133
class Function(BaseModel):
135
description: Optional[str] = None
136
parameters: Optional[Dict[str, Any]] = None
139
class LogProb(BaseModel):
    """One token's log-probability entry.

    NOTE(review): the field block was lost in the garbled source; types
    inferred from LogProbs.create's use of .token/.logprob/.bytes and the
    OpenAI logprob schema — TODO confirm upstream.
    """

    token: str
    logprob: float
    bytes: Optional[List[int]] = None


class LogProbs(BaseModel):
    token: str
    logprob: float
    bytes: Optional[List[int]] = None
    top_logprobs: List[LogProb]

    @classmethod
    def create(cls, logprobs: List[LogProb], top_logprobs: Optional[int] = None):
        """Collapse a ranked LogProb list into a single LogProbs entry.

        The first element supplies the chosen token; the full list is
        retained as top_logprobs only when the caller asked for
        alternatives (top_logprobs truthy).
        """
        assert len(logprobs) > 0, "logprobs must be a non-empty list"
        token = logprobs[0].token
        logprob = logprobs[0].logprob
        bytes = logprobs[0].bytes  # shadows builtin; kept as in original
        all_logprobs = logprobs if top_logprobs else []
        ret = cls(token=token, logprob=logprob, bytes=bytes, top_logprobs=all_logprobs)
        # NOTE(review): trailing return was lost in the garbled source;
        # `return ret` is the only sensible completion.
        return ret
162
class ToolChoice(BaseModel):
163
type: Literal["function"]
167
class Tool(BaseModel):
168
type: Literal["function"]
172
class Message(BaseModel):
    """One chat message (OpenAI-style), validated per role."""

    role: Literal["system", "assistant", "user", "tool"]
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None

    def __str__(self):
        # Tool-call messages stringify their (possibly None) content;
        # everything else dumps the full dict.
        if getattr(self, "tool_calls", None):
            return str(self.content)
        return str(self.dict())

    @root_validator
    def check_fields(cls, values):
        """Enforce per-role required fields.

        NOTE(review): reconstructed from a garbled source — the raise
        wrappers, the assistant else-branch, and the final return were
        inferred from surviving fragments; verify against upstream.
        """
        if values["role"] in ["system", "user"]:
            if not isinstance(values.get("content"), str):
                raise ValueError("content must be a string")
        if values["role"] == "tool":
            if not isinstance(values.get("tool_call_id"), str):
                raise ValueError("tool_call_id must be a str")
            # Tool messages must carry the tool's result (or error text).
            if not isinstance(values.get("content"), str):
                raise ValueError(
                    "content must be a str with results or errors for " "the tool call"
                )
        if values["role"] == "assistant":
            if values.get("tool_calls"):
                if not isinstance(values.get("tool_calls"), list):
                    raise ValueError("tool_calls must be a list")
                for tool_call in values["tool_calls"]:
                    if not isinstance(tool_call, ToolCall):
                        raise TypeError("Tool calls must be of type ToolCall")
            else:
                # Plain assistant message: non-empty string content required.
                if (
                    not isinstance(values.get("content"), str)
                    or values.get("content") == ""
                ):
                    raise ValueError("content must be a string or None")
        return values
217
class DeltaRole(BaseModel):
218
role: Literal["system", "assistant", "user"]
224
class DeltaContent(BaseModel):
    """Streaming delta carrying generated content and/or tool calls.

    NOTE(review): the `content` field line and the `__str__` header were
    lost in the garbled source; reconstructed from the surviving return
    statements — TODO confirm upstream.
    """

    content: str
    tool_calls: Optional[List[Dict[str, Any]]] = None

    def __str__(self):
        if self.tool_calls:
            return str(self.tool_calls)
        return str(self.dict())
235
class DeltaEOS(BaseModel):
240
class ChoiceLogProbs(BaseModel):
241
content: List[LogProbs]
244
class MessageChoices(BaseModel):
248
logprobs: Optional[ChoiceLogProbs] = None
251
class DeltaChoices(BaseModel):
252
delta: Union[DeltaRole, DeltaContent, DeltaEOS]
254
finish_reason: Optional[str]
255
logprobs: Optional[ChoiceLogProbs] = None
258
class ChatCompletion(BaseModel):
263
choices: List[Union[MessageChoices, DeltaChoices]]
264
usage: Optional[Usage]
270
messages: List[Dict[str, str]],
271
max_tokens: Optional[int] = None,
272
temperature: Optional[float] = 1.0,
273
top_p: Optional[float] = 1.0,
274
stream: bool = False,
275
stop: Optional[List[str]] = None,
276
frequency_penalty: float = 0.0,
277
top_k: Optional[int] = None,
278
typical_p: Optional[float] = None,
279
watermark: Optional[bool] = False,
280
seed: Optional[int] = None,
281
) -> TChatCompletion:
285
class Prompt(BaseModel):
    """A completion request: raw string or chat message list, plus
    generation parameters and optional tool definitions."""

    prompt: Union[str, List[Message]]
    use_prompt_format: bool = True
    parameters: Optional[Union[Dict[str, Any], BaseModel]] = None
    tools: Optional[List[Tool]] = None
    tool_choice: Union[Literal["auto", "none"], ToolChoice] = "auto"

    # NOTE(review): decorator line was lost in the garbled source;
    # @validator("prompt") is the only field this check applies to.
    @validator("prompt")
    def check_prompt(cls, value):
        if isinstance(value, list) and not value:
            raise ValueError("Messages cannot be an empty list.")
        return value

    def to_unformatted_string(self) -> str:
        # Message lists are flattened to a comma-joined content string;
        # a plain-string prompt is returned as-is.
        if isinstance(self.prompt, list):
            return ", ".join(str(message.content) for message in self.prompt)
        return self.prompt

    def get_log_str(self):
        """Return a greppable snippet of the prompt for log correlation.

        If the prompt embeds PROMPT_TRACE_KEY, return the key plus the
        next 31 characters. NOTE(review): the no-key branch was lost in
        the garbled source; assuming it returns None — TODO confirm.
        """
        prompt_str = self.to_unformatted_string()
        if PROMPT_TRACE_KEY in prompt_str:
            start_idx = prompt_str.find(PROMPT_TRACE_KEY)
            return prompt_str[start_idx : start_idx + len(PROMPT_TRACE_KEY) + 31]
        return None
314
class ErrorResponse(BaseModel):
316
internal_message: str
319
param: Dict[str, Any] = {}
322
class AbstractPromptFormat(BaseModel):
326
def generate_prompt(self, messages: Union[Prompt, List[Message]]) -> str:
327
raise NotImplementedError()
330
class PromptFormat(AbstractPromptFormat):
333
trailing_assistant: str
336
default_system_message: str = ""
337
system_in_user: bool = False
338
add_system_tags_even_if_message_is_empty: bool = False
339
strip_whitespace: bool = True
342
def check_system(cls, value):
344
"{instruction}" in value
345
), "system must be a string containing '{instruction}'"
348
@validator("assistant")
349
def check_assistant(cls, value):
351
value and "{instruction}" in value
352
), "assistant must be a string containing '{instruction}'"
356
def check_user(cls, value):
358
"{instruction}" in value
359
), "user must be a string containing '{instruction}'"
363
def check_user_system_in_user(cls, values):
364
if values["system_in_user"]:
366
"{system}" in values["user"]
367
), "If system_in_user=True, user must contain '{system}'"
370
def generate_prompt(self, messages: Union[Prompt, List[Message]]) -> str:
371
if isinstance(messages, Prompt):
372
if isinstance(messages.prompt, str):
373
if not messages.use_prompt_format:
374
return messages.prompt
376
if self.default_system_message:
378
Message(role="system", content=self.default_system_message),
381
Message(role="user", content=messages.prompt),
383
messages = new_messages
385
messages = messages.prompt
388
system_message_index = -1
389
for i, message in enumerate(messages):
390
if message.role == "system":
391
if system_message_index == -1:
392
system_message_index = i
395
status.HTTP_400_BAD_REQUEST,
396
"Only one system message can be specified.",
399
system_message = None
400
if system_message_index != -1:
401
system_message = messages.pop(system_message_index)
403
self.default_system_message or self.add_system_tags_even_if_message_is_empty
405
system_message = Message(role="system", content=self.default_system_message)
407
system_message is not None
409
system_message.content or self.add_system_tags_even_if_message_is_empty
411
and not self.system_in_user
413
messages.insert(0, system_message)
416
for message in messages:
417
message_content = message.content
418
if self.strip_whitespace:
419
message_content = message_content.strip()
420
if message.role == "system":
421
prompt.append(self.system.format(instruction=message_content))
422
elif message.role == "user":
423
if self.system_in_user:
426
instruction=message_content,
427
system=self.system.format(
428
instruction=system_message.content
434
system_message = None
436
prompt.append(self.user.format(instruction=message_content))
437
elif message.role == "assistant":
438
prompt.append(self.assistant.format(instruction=message_content))
439
prompt.append(self.trailing_assistant)
440
return "".join(prompt)
443
class DisabledPromptFormat(AbstractPromptFormat):
444
def generate_prompt(self, messages: Union[Prompt, List[Message]]) -> str:
446
isinstance(messages, Prompt)
447
and isinstance(messages.prompt, str)
448
and not messages.use_prompt_format
450
return messages.prompt
451
raise PromptFormatDisabledError(
452
"This model doesn't support chat completions. Please use the completions "