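"""Xinference LLM integration for LlamaIndex (legacy).

Wraps a model served by a running Xinference endpoint behind the LlamaIndex
``CustomLLM`` interface, providing chat, completion, and streaming variants.
"""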
import warnings
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.legacy.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.llms.xinference_utils import (
    xinference_message_to_history,
    xinference_modelname_to_contextsize,
)
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode

# An approximation of the ratio between llama and GPT-2 tokens.
TOKEN_RATIO = 2.5
DEFAULT_XINFERENCE_TEMP = 1.0


class Xinference(CustomLLM):
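    """A CustomLLM backed by a model served from a Xinference endpoint.

    On construction, the model's generator handle, context window, and
    description are fetched from the running Xinference server via
    ``load_model``.
    """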
    model_uid: str = Field(description="The Xinference model to use.")
    endpoint: str = Field(description="The Xinference endpoint URL to use.")
    temperature: float = Field(
        description="The temperature to use for sampling.", ge=0.0, le=1.0
    )
    max_tokens: int = Field(
        description="The maximum new tokens to generate as answer.", gt=0
    )
    context_window: int = Field(
        description="The maximum number of context tokens for the model.", gt=0
    )
    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )

    _generator: Any = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        temperature: float = DEFAULT_XINFERENCE_TEMP,
        max_tokens: Optional[int] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        generator, context_window, model_description = self.load_model(
            model_uid, endpoint
        )
        self._generator = generator
        # Default to a quarter of the context window when max_tokens is unset.
        if max_tokens is None:
            max_tokens = context_window // 4
        elif max_tokens > context_window:
            raise ValueError(
                f"received max_tokens {max_tokens} with context window "
                f"{context_window}; max_tokens cannot exceed the context "
                "window of the model"
            )

        super().__init__(
            model_uid=model_uid,
            endpoint=endpoint,
            temperature=temperature,
            context_window=context_window,
            max_tokens=max_tokens,
            model_description=model_description,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]:
        try:
            from xinference.client import RESTfulClient
        except ImportError:
            raise ImportError(
                "Could not import Xinference library. "
                'Please install Xinference with `pip install "xinference[all]"`.'
            )

        client = RESTfulClient(endpoint)

        if not isinstance(client, RESTfulClient):
            raise RuntimeError(
                "Could not create RESTfulClient instance. "
                "Please make sure the Xinference endpoint is running at the "
                "correct port."
            )

        generator = client.get_model(model_uid)
        model_description = client.list_models()[model_uid]

        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint. "
                "Please make sure the Xinference endpoint is running at the "
                "correct port."
            )

        model = model_description["model_name"]
        if "context_length" in model_description:
            context_window = model_description["context_length"]
        else:
            warnings.warn(
                "Parameter `context_length` not found in the model description; "
                "falling back to `xinference_modelname_to_contextsize`, which is "
                "no longer maintained. Please update Xinference to the newest "
                "version."
            )
            context_window = xinference_modelname_to_contextsize(model)

        return generator, context_window, model_description

    @classmethod
    def class_name(cls) -> str:
        return "Xinference_llm"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        assert isinstance(self.context_window, int)
        return LLMMetadata(
            # Scale the context window down by the llama/GPT-2 token ratio.
            context_window=int(self.context_window // TOKEN_RATIO),
            num_output=self.max_tokens,
            model_name=self.model_uid,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        assert self.context_window is not None
        base_kwargs = {
            "temperature": self.temperature,
            "max_length": self.context_window,
        }
        return {
            **base_kwargs,
            **self.model_description,
        }

    def _get_input_dict(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        return {"prompt": prompt, **self._model_kwargs, **kwargs}

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        assert self._generator is not None
        # The last message is the new prompt; the rest become chat history.
        prompt = messages[-1].content if len(messages) > 0 else ""
        history = [xinference_message_to_history(message) for message in messages[:-1]]
        response_text = self._generator.chat(
            prompt=prompt,
            chat_history=history,
            generate_config={
                "stream": False,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            },
        )["choices"][0]["message"]["content"]
        return ChatResponse(
            message=ChatMessage(
                role=MessageRole.ASSISTANT,
                content=response_text,
            ),
            delta=None,
        )

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        assert self._generator is not None
        prompt = messages[-1].content if len(messages) > 0 else ""
        history = [xinference_message_to_history(message) for message in messages[:-1]]
        response_iter = self._generator.chat(
            prompt=prompt,
            chat_history=history,
            generate_config={
                "stream": True,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            },
        )

        def gen() -> ChatResponseGen:
            # Accumulate streamed deltas into the full response text.
            text = ""
            for c in response_iter:
                delta = c["choices"][0]["delta"].get("content", "")
                text += delta
                yield ChatResponse(
                    message=ChatMessage(
                        role=MessageRole.ASSISTANT,
                        content=text,
                    ),
                    delta=delta,
                )

        return gen()

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        assert self._generator is not None
        # Completion is implemented through the chat endpoint with no history.
        response_text = self._generator.chat(
            prompt=prompt,
            chat_history=None,
            generate_config={
                "stream": False,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            },
        )["choices"][0]["message"]["content"]
        return CompletionResponse(
            delta=None,
            text=response_text,
        )

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        assert self._generator is not None
        response_iter = self._generator.chat(
            prompt=prompt,
            chat_history=None,
            generate_config={
                "stream": True,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            },
        )

        def gen() -> CompletionResponseGen:
            # Accumulate streamed deltas into the full completion text.
            text = ""
            for c in response_iter:
                delta = c["choices"][0]["delta"].get("content", "")
                text += delta
                yield CompletionResponse(
                    delta=delta,
                    text=text,
                )

        return gen()
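
# Minimal usage sketch: the endpoint URL and model UID below are placeholders
# and assume a Xinference server is already running with a launched model.
#
#     llm = Xinference(
#         model_uid="<model-uid>",
#         endpoint="http://127.0.0.1:9997",
#         temperature=0.7,
#     )
#     print(llm.complete("What is a llama?").text)
#
#     # Streaming: each chunk carries the new delta and the accumulated text.
#     for chunk in llm.stream_complete("What is a llama?"):
#         print(chunk.delta, end="", flush=True)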