llama-index

import warnings
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.legacy.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.legacy.llms.custom import CustomLLM
from llama_index.legacy.llms.xinference_utils import (
    xinference_message_to_history,
    xinference_modelname_to_contextsize,
)
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode

# An approximation of the ratio between llama and GPT-2 tokens.
TOKEN_RATIO = 2.5
DEFAULT_XINFERENCE_TEMP = 1.0
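
# TOKEN_RATIO is consumed in the `metadata` property below: the context window
# advertised to LlamaIndex is divided by it, presumably so that GPT-2-style
# token counting does not overestimate how much text fits in a llama-tokenized
# context.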


class Xinference(CustomLLM):
    model_uid: str = Field(description="The Xinference model to use.")
    endpoint: str = Field(description="The Xinference endpoint URL to use.")
    temperature: float = Field(
        description="The temperature to use for sampling.", gte=0.0, lte=1.0
    )
    max_tokens: int = Field(
        description="The maximum number of new tokens to generate as the answer.", gt=0
    )
    context_window: int = Field(
        description="The maximum number of context tokens for the model.", gt=0
    )
    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )

    _generator: Any = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        temperature: float = DEFAULT_XINFERENCE_TEMP,
        max_tokens: Optional[int] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        generator, context_window, model_description = self.load_model(
            model_uid, endpoint
        )
        self._generator = generator
        if max_tokens is None:
            # Default to a quarter of the context window, leaving room for input.
            max_tokens = context_window // 4
        elif max_tokens > context_window:
            raise ValueError(
                f"Received max_tokens {max_tokens} with context window "
                f"{context_window}; max_tokens cannot exceed the context window "
                "of the model."
            )

        super().__init__(
            model_uid=model_uid,
            endpoint=endpoint,
            temperature=temperature,
            context_window=context_window,
            max_tokens=max_tokens,
            model_description=model_description,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]:
        try:
            from xinference.client import RESTfulClient
        except ImportError:
            raise ImportError(
                "Could not import Xinference library. "
                'Please install Xinference with `pip install "xinference[all]"`.'
            )

        client = RESTfulClient(endpoint)

        if not isinstance(client, RESTfulClient):
            raise RuntimeError(
                "Could not create RESTfulClient instance. "
                "Please make sure the Xinference endpoint is running at the "
                "correct port."
            )

        generator = client.get_model(model_uid)
        model_description = client.list_models()[model_uid]

        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint. "
                "Please make sure the Xinference endpoint is running at the "
                "correct port."
            )

        model = model_description["model_name"]
        if "context_length" in model_description:
            context_window = model_description["context_length"]
        else:
            warnings.warn(
                "Parameter `context_length` not found in model description; "
                "falling back to `xinference_modelname_to_contextsize`, which "
                "is no longer maintained. Please update Xinference to the "
                "newest version."
            )
            context_window = xinference_modelname_to_contextsize(model)

        return generator, context_window, model_description

    @classmethod
    def class_name(cls) -> str:
        return "Xinference_llm"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        assert isinstance(self.context_window, int)
        return LLMMetadata(
            context_window=int(self.context_window // TOKEN_RATIO),
            num_output=self.max_tokens,
            model_name=self.model_uid,
        )

    @property
    def _model_kwargs(self) -> Dict[str, Any]:
        assert self.context_window is not None
        base_kwargs = {
            "temperature": self.temperature,
            "max_length": self.context_window,
        }
        return {
            **base_kwargs,
            **self.model_description,
        }

    def _get_input_dict(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        return {"prompt": prompt, **self._model_kwargs, **kwargs}

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        assert self._generator is not None
        # The last message becomes the prompt; all earlier messages are sent
        # as chat history.
        prompt = messages[-1].content if len(messages) > 0 else ""
        history = [xinference_message_to_history(message) for message in messages[:-1]]
        response_text = self._generator.chat(
            prompt=prompt,
            chat_history=history,
            generate_config={
                "stream": False,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            },
        )["choices"][0]["message"]["content"]
        return ChatResponse(
            message=ChatMessage(
                role=MessageRole.ASSISTANT,
                content=response_text,
            ),
            delta=None,
        )

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        assert self._generator is not None
        prompt = messages[-1].content if len(messages) > 0 else ""
        history = [xinference_message_to_history(message) for message in messages[:-1]]
        response_iter = self._generator.chat(
            prompt=prompt,
            chat_history=history,
            generate_config={
                "stream": True,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            },
        )

        def gen() -> ChatResponseGen:
            # Accumulate streamed deltas so each ChatResponse carries both the
            # newest delta and the full text generated so far.
            text = ""
            for c in response_iter:
                delta = c["choices"][0]["delta"].get("content", "")
                text += delta
                yield ChatResponse(
                    message=ChatMessage(
                        role=MessageRole.ASSISTANT,
                        content=text,
                    ),
                    delta=delta,
                )

        return gen()

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        assert self._generator is not None
        response_text = self._generator.chat(
            prompt=prompt,
            chat_history=None,
            generate_config={
                "stream": False,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            },
        )["choices"][0]["message"]["content"]
        return CompletionResponse(
            delta=None,
            text=response_text,
        )

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        assert self._generator is not None
        response_iter = self._generator.chat(
            prompt=prompt,
            chat_history=None,
            generate_config={
                "stream": True,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
            },
        )

        def gen() -> CompletionResponseGen:
            text = ""
            for c in response_iter:
                delta = c["choices"][0]["delta"].get("content", "")
                text += delta
                yield CompletionResponse(
                    delta=delta,
                    text=text,
                )

        return gen()
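

# Usage sketch: a minimal example, assuming a local Xinference server is
# running with a model already launched. The endpoint URL and model UID below
# are placeholders; substitute the values reported by your own deployment
# (e.g. the UID printed by `xinference launch`).
if __name__ == "__main__":
    llm = Xinference(
        model_uid="my-model-uid",  # hypothetical UID from `xinference launch`
        endpoint="http://127.0.0.1:9997",  # assumed default local endpoint
        temperature=0.7,
        max_tokens=512,
    )

    # One-shot completion.
    print(llm.complete("What is a vector index?").text)

    # Streaming chat: each chunk carries the newest delta plus the text so far.
    messages = [ChatMessage(role=MessageRole.USER, content="Hello!")]
    for chunk in llm.stream_chat(messages):
        print(chunk.delta, end="", flush=True)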