ray-llm

Форк
0
/
models.py 
454 строки · 13.1 Кб
1
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypeVar, Union
2

3
from fastapi import HTTPException, status
4
from pydantic import BaseModel, root_validator, validator
5

6
if TYPE_CHECKING:
7
    from rayllm.backend.server.models import AviaryModelResponse
8

9
TModel = TypeVar("TModel", bound="Model")
10
TCompletion = TypeVar("TCompletion", bound="Completion")
11
TChatCompletion = TypeVar("TChatCompletion", bound="ChatCompletion")
12

13
PROMPT_TRACE_KEY = "+TRACE_"
14

15

16
class PromptFormatDisabledError(ValueError):
    """Raised when chat-style prompting is requested for a model whose
    prompt format is disabled (see DisabledPromptFormat below)."""

    # HTTP status to surface for this error.
    status_code = 404
18

19

20
class ModelData(BaseModel):
    """One model entry in an OpenAI-style model listing."""

    id: str
    object: str
    owned_by: str
    permission: List[str]
    # rayllm-specific metadata attached to the model entry.
    rayllm_metadata: Dict[str, Any]
26

27

28
class Model(BaseModel):
    """OpenAI-style response body for the model-listing endpoint."""

    data: List[ModelData]
    object: str = "list"

    @classmethod
    def list(cls) -> TModel:
        """List available models.

        NOTE(review): body is `pass` here — presumably implemented by a
        client/SDK layer elsewhere; confirm before relying on it.
        """
        pass
35

36

37
class DeletedModel(BaseModel):
    """Response body confirming that a model was deleted."""

    id: str
    object: str = "model"
    deleted: bool = True
41

42

43
class TextChoice(BaseModel):
    """One generated choice in a (non-chat) text completion response."""

    text: str
    index: int
    logprobs: dict
    # Why generation stopped; None when not (yet) finished.
    finish_reason: Optional[str]
48

49

50
class Usage(BaseModel):
    """Token accounting for a single request."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

    @classmethod
    def from_response(
        cls, response: Union["AviaryModelResponse", Dict[str, Any]]
    ) -> "Usage":
        """Build a Usage record from a backend response.

        Accepts either an AviaryModelResponse (any pydantic model) or a
        plain dict. Missing or None token counts are treated as 0, so a
        partial response never raises here.
        """
        if isinstance(response, BaseModel):
            response_dict = response.dict()
        else:
            response_dict = response
        # Look each count up once; `.get(...) or 0` coalesces both a missing
        # key and an explicit None to 0 (the original indexed directly and
        # would KeyError on a missing key).
        prompt_tokens = response_dict.get("num_input_tokens") or 0
        completion_tokens = response_dict.get("num_generated_tokens") or 0
        return cls(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
69

70

71
class EmbeddingsUsage(BaseModel):
    """Token accounting for an embeddings request (no generated tokens)."""

    prompt_tokens: int
    total_tokens: int
74

75

76
class Completion(BaseModel):
    """OpenAI-style text completion response."""

    id: str
    object: str
    created: int
    model: str
    choices: List[TextChoice]
    usage: Optional[Usage]

    @classmethod
    def create(
        cls,
        model: str,
        prompt: str,
        use_prompt_format: bool = True,
        max_tokens: Optional[int] = 16,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        stream: bool = False,
        stop: Optional[List[str]] = None,
        frequency_penalty: float = 0.0,
        top_k: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: Optional[bool] = False,
        seed: Optional[int] = None,
    ) -> TCompletion:
        """Stub constructor mirroring the completions API signature.

        NOTE(review): body is `pass` here — presumably implemented by a
        client/SDK layer elsewhere; confirm before relying on it.
        """
        pass
102

103

104
class EmbeddingsData(BaseModel):
    """One embedding vector in an embeddings response."""

    embedding: List[float]
    index: int
    object: str
108

109

110
class EmbeddingsOutput(BaseModel):
    """OpenAI-style embeddings response."""

    data: List[EmbeddingsData]
    id: str
    object: str
    created: int
    model: str
    usage: Optional[EmbeddingsUsage]
117

118

119
class FunctionCall(BaseModel):
    """The function name and arguments of a tool call emitted by the model."""

    name: str
    # Serialized arguments; None when the model supplied none.
    arguments: Optional[str] = None
122

123

124
class ToolCall(BaseModel):
    """A single tool invocation requested by the assistant."""

    function: FunctionCall
    # Only "function" tools are supported.
    type: Literal["function"]
    id: str

    def __str__(self):
        # Render the full dict form for logging/templating.
        return str(self.dict())
131

132

133
class Function(BaseModel):
    """Declaration of a callable function exposed to the model as a tool."""

    name: str
    description: Optional[str] = None
    # JSON-schema-like parameter description; None if the function takes none.
    parameters: Optional[Dict[str, Any]] = None
137

138

139
class LogProb(BaseModel):
    """Log probability of a single candidate token."""

    logprob: float
    token: str
    # UTF-8 byte values of the token text.
    bytes: List[int]
143

144

145
class LogProbs(BaseModel):
    """Log-probability info for one sampled token (OpenAI-style schema)."""

    token: str
    logprob: float
    bytes: List[int]
    # Candidate tokens at this position; empty unless top logprobs were
    # requested.
    top_logprobs: List[LogProb]

    @classmethod
    def create(cls, logprobs: List[LogProb], top_logprobs: Optional[int] = None):
        """Build a LogProbs entry from per-position candidates.

        The first element of ``logprobs`` is the sampled token; the full
        candidate list is kept as ``top_logprobs`` only when the caller
        requested a (non-zero) number of top logprobs.

        Raises:
            ValueError: if ``logprobs`` is empty.
        """
        # Raise instead of `assert` so the validation survives `python -O`.
        if not logprobs:
            raise ValueError("logprobs must be a non-empty list")
        sampled = logprobs[0]
        # Build directly from the sampled candidate; this also avoids the
        # original's shadowing of the builtin `bytes`.
        return cls(
            token=sampled.token,
            logprob=sampled.logprob,
            bytes=sampled.bytes,
            top_logprobs=logprobs if top_logprobs else [],
        )
160

161

162
class ToolChoice(BaseModel):
    """Explicit choice of a specific function tool the model must call."""

    type: Literal["function"]
    function: Function
165

166

167
class Tool(BaseModel):
    """A tool (function) made available to the model."""

    type: Literal["function"]
    function: Function
170

171

172
class Message(BaseModel):
    """A single chat message in the OpenAI chat format.

    ``tool_calls`` is only meaningful on assistant messages; ``tool_call_id``
    is only meaningful on tool messages.
    """

    role: Literal["system", "assistant", "user", "tool"]
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None

    def __str__(self):
        # If tool_calls is set, we are rendering a tool-call message.
        # getattr is used instead of direct attribute access just in case
        # the attribute has been deleted off of the object.
        if getattr(self, "tool_calls", None):
            return str(self.content)
        return str(self.dict())

    @root_validator
    def check_fields(cls, values):
        """Enforce per-role requirements on content / tool fields.

        A bare (post) root validator still runs when a field-level validation
        failed, in which case that field's key is absent from ``values`` —
        use ``values.get("role")`` so a bad role surfaces as a validation
        error instead of an opaque KeyError.
        """
        role = values.get("role")
        if role in ["system", "user"]:
            if not isinstance(values.get("content"), str):
                raise ValueError("content must be a string")
        if role == "tool":
            if not isinstance(values.get("tool_call_id"), str):
                raise ValueError("tool_call_id must be a str")
            # content should either be a dict with errors or with results
            if not isinstance(values.get("content"), str):
                raise ValueError(
                    "content must be a str with results or errors for " "the tool call"
                )
        if role == "assistant":
            if values.get("tool_calls"):
                # assistant message carrying tool calls
                if not isinstance(values.get("tool_calls"), list):
                    raise ValueError("tool_calls must be a list")
                for tool_call in values["tool_calls"]:
                    if not isinstance(tool_call, ToolCall):
                        raise TypeError("Tool calls must be of type ToolCall")
            else:
                # passing a regular assistant message
                if (
                    not isinstance(values.get("content"), str)
                    or values.get("content") == ""
                ):
                    raise ValueError("content must be a string or None")
        return values
215

216

217
class DeltaRole(BaseModel):
    """Streamed delta announcing the role of the message being generated."""

    role: Literal["system", "assistant", "user"]

    def __str__(self):
        return self.role
222

223

224
class DeltaContent(BaseModel):
    """Streamed delta carrying a chunk of message content."""

    content: str
    # Streamed tool-call fragments, if the model is emitting tool calls.
    tool_calls: Optional[List[Dict[str, Any]]] = None

    def __str__(self):
        # Prefer the tool calls when present; otherwise the full dict form.
        if self.tool_calls:
            return str(self.tool_calls)
        else:
            return str(self.dict())
233

234

235
class DeltaEOS(BaseModel):
    # Sentinel delta marking end-of-stream; intentionally carries no fields.
    class Config:
        # Forbid extras so only a truly empty delta parses as EOS.
        extra = "forbid"
238

239

240
class ChoiceLogProbs(BaseModel):
    """Per-token log-probability entries for one choice."""

    content: List[LogProbs]
242

243

244
class MessageChoices(BaseModel):
    """One complete (non-streamed) choice in a chat completion."""

    message: Message
    index: int
    finish_reason: str
    logprobs: Optional[ChoiceLogProbs] = None
249

250

251
class DeltaChoices(BaseModel):
    """One streamed choice chunk in a chat completion."""

    # Role announcement, content chunk, or end-of-stream sentinel.
    delta: Union[DeltaRole, DeltaContent, DeltaEOS]
    index: int
    finish_reason: Optional[str]
    logprobs: Optional[ChoiceLogProbs] = None
256

257

258
class ChatCompletion(BaseModel):
    """OpenAI-style chat completion response (full or streamed chunk)."""

    id: str
    object: str
    created: int
    model: str
    # MessageChoices for full responses, DeltaChoices for streamed chunks.
    choices: List[Union[MessageChoices, DeltaChoices]]
    usage: Optional[Usage]

    @classmethod
    def create(
        cls,
        model: str,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        stream: bool = False,
        stop: Optional[List[str]] = None,
        frequency_penalty: float = 0.0,
        top_k: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: Optional[bool] = False,
        seed: Optional[int] = None,
    ) -> TChatCompletion:
        """Stub constructor mirroring the chat-completions API signature.

        NOTE(review): body is `pass` here — presumably implemented by a
        client/SDK layer elsewhere; confirm before relying on it.
        """
        pass
283

284

285
class Prompt(BaseModel):
    """User-facing prompt: either a raw string or a list of chat Messages."""

    prompt: Union[str, List[Message]]
    # When False and `prompt` is a string, the string is sent verbatim.
    use_prompt_format: bool = True
    # Extra generation parameters forwarded to the backend.
    parameters: Optional[Union[Dict[str, Any], BaseModel]] = None
    tools: Optional[List[Tool]] = None
    tool_choice: Union[Literal["auto", "none"], ToolChoice] = "auto"

    @validator("prompt")
    def check_prompt(cls, value):
        # Reject an empty message list; string prompts are not checked here.
        if isinstance(value, list) and not value:
            raise ValueError("Messages cannot be an empty list.")
        return value

    def to_unformatted_string(self) -> str:
        """Return the raw string, or message contents joined with ", "."""
        if isinstance(self.prompt, list):
            return ", ".join(str(message.content) for message in self.prompt)
        return self.prompt

    def get_log_str(self):
        """Return the trace marker plus its trailing characters, or None.

        Only prompts containing PROMPT_TRACE_KEY yield a log string, so
        ordinary prompts never get logged through this path.
        """
        prompt_str = self.to_unformatted_string()
        if PROMPT_TRACE_KEY in prompt_str:
            start_idx = prompt_str.find(PROMPT_TRACE_KEY)

            # Grab the prompt key and the next following 31 chars.
            return prompt_str[start_idx : start_idx + len(PROMPT_TRACE_KEY) + 31]
        else:
            return None
312

313

314
class ErrorResponse(BaseModel):
    """Error payload returned to API clients."""

    message: str
    # presumably a more detailed message for internal logging — confirm usage
    internal_message: str
    code: int
    type: str
    # Mutable default is safe here: pydantic deep-copies field defaults per
    # instance, unlike plain function defaults.
    param: Dict[str, Any] = {}
320

321

322
class AbstractPromptFormat(BaseModel):
    """Base class for prompt formats that render messages into a string."""

    class Config:
        extra = "forbid"

    def generate_prompt(self, messages: Union[Prompt, List[Message]]) -> str:
        """Render a Prompt or message list into the model's input string."""
        raise NotImplementedError()
328

329

330
class PromptFormat(AbstractPromptFormat):
    """Template-based prompt format.

    ``system``/``assistant``/``user`` are templates containing
    '{instruction}'; each message is rendered through the template for its
    role and the results are concatenated, ending with
    ``trailing_assistant``.
    """

    system: str
    assistant: str
    # Appended verbatim after all rendered messages.
    trailing_assistant: str
    user: str

    # Used when the caller supplies no system message of their own.
    default_system_message: str = ""
    # If True, the rendered system text is injected into the first user
    # message via the user template's '{system}' placeholder instead of
    # being emitted as a standalone message.
    system_in_user: bool = False
    add_system_tags_even_if_message_is_empty: bool = False
    strip_whitespace: bool = True

    # NOTE: asserts are the pydantic-v1 validator idiom — pydantic converts
    # AssertionError raised inside a validator into a ValidationError.
    @validator("system")
    def check_system(cls, value):
        assert value and (
            "{instruction}" in value
        ), "system must be a string containing '{instruction}'"
        return value

    @validator("assistant")
    def check_assistant(cls, value):
        assert (
            value and "{instruction}" in value
        ), "assistant must be a string containing '{instruction}'"
        return value

    @validator("user")
    def check_user(cls, value):
        assert value and (
            "{instruction}" in value
        ), "user must be a string containing '{instruction}'"
        return value

    @root_validator
    def check_user_system_in_user(cls, values):
        # system_in_user requires a '{system}' slot in the user template.
        if values["system_in_user"]:
            assert (
                "{system}" in values["user"]
            ), "If system_in_user=True, user must contain '{system}'"
        return values

    def generate_prompt(self, messages: Union[Prompt, List[Message]]) -> str:
        """Render a Prompt or message list into a single prompt string.

        The input is never mutated (the original popped the system message
        out of the caller's list in place).

        Raises:
            HTTPException: 400 if more than one system message is supplied.
        """
        if isinstance(messages, Prompt):
            if isinstance(messages.prompt, str):
                # Raw string prompt: optionally bypass formatting entirely.
                if not messages.use_prompt_format:
                    return messages.prompt
                new_messages = []
                if self.default_system_message:
                    new_messages.append(
                        Message(role="system", content=self.default_system_message),
                    )
                new_messages.append(
                    Message(role="user", content=messages.prompt),
                )
                messages = new_messages
            else:
                messages = messages.prompt

        # Work on a copy: popping the system message below must not mutate
        # the caller's list (or the list held inside the Prompt object).
        messages = list(messages)

        # Locate the (single) system message, if any.
        system_message_index = -1
        for i, message in enumerate(messages):
            if message.role == "system":
                if system_message_index == -1:
                    system_message_index = i
                else:
                    raise HTTPException(
                        status.HTTP_400_BAD_REQUEST,
                        "Only one system message can be specified.",
                    )

        system_message = None
        if system_message_index != -1:
            system_message = messages.pop(system_message_index)
        elif (
            self.default_system_message or self.add_system_tags_even_if_message_is_empty
        ):
            system_message = Message(role="system", content=self.default_system_message)
        # A standalone system message goes first, unless it is folded into
        # the user template (system_in_user) or is empty with tags not forced.
        if (
            system_message is not None
            and (
                system_message.content or self.add_system_tags_even_if_message_is_empty
            )
            and not self.system_in_user
        ):
            messages.insert(0, system_message)

        prompt = []
        for message in messages:
            message_content = message.content
            if self.strip_whitespace:
                # NOTE(review): assumes content is a str here — a message
                # with content=None (e.g. assistant tool calls) would raise;
                # confirm such messages never reach a prompt format.
                message_content = message_content.strip()
            if message.role == "system":
                prompt.append(self.system.format(instruction=message_content))
            elif message.role == "user":
                if self.system_in_user:
                    prompt.append(
                        self.user.format(
                            instruction=message_content,
                            system=self.system.format(
                                instruction=system_message.content
                            )
                            if system_message
                            else "",
                        )
                    )
                    # Only fold the system text into the first user message.
                    system_message = None
                else:
                    prompt.append(self.user.format(instruction=message_content))
            elif message.role == "assistant":
                prompt.append(self.assistant.format(instruction=message_content))
        prompt.append(self.trailing_assistant)
        return "".join(prompt)
441

442

443
class DisabledPromptFormat(AbstractPromptFormat):
    """Prompt format for models that only accept raw string prompts."""

    def generate_prompt(self, messages: Union[Prompt, List[Message]]) -> str:
        """Pass a raw, unformatted string prompt through; reject anything else."""
        is_raw_string_prompt = (
            isinstance(messages, Prompt)
            and isinstance(messages.prompt, str)
            and not messages.use_prompt_format
        )
        if not is_raw_string_prompt:
            # Chat-style input cannot be rendered for this model.
            raise PromptFormatDisabledError(
                "This model doesn't support chat completions. Please use the completions "
                "endpoint instead."
            )
        return messages.prompt
455

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.