import ast
import json
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from llama_index.legacy.bridge.pydantic import Field
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode

DEFAULT_RUNGPT_MODEL = "rungpt"
DEFAULT_RUNGPT_TEMP = 0.75


class RunGptLLM(LLM):
    """A client for language models served with Jina AI's rungpt."""

    model: Optional[str] = Field(
        default=DEFAULT_RUNGPT_MODEL, description="The rungpt model to use."
    )
    endpoint: str = Field(description="The address of the serving endpoint.")
    temperature: float = Field(
        default=DEFAULT_RUNGPT_TEMP,
        description="The temperature to use for sampling.",
        ge=0.0,
        le=1.0,
    )
    max_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of context tokens for the model.",
        gt=0,
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the rungpt API."
    )
    base_url: str = Field(
        description="The base URL of the model served by rungpt."
    )

    def __init__(
        self,
        model: Optional[str] = DEFAULT_RUNGPT_MODEL,
        endpoint: str = "0.0.0.0:51002",
        temperature: float = DEFAULT_RUNGPT_TEMP,
        max_tokens: int = DEFAULT_NUM_OUTPUTS,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ):
        # Normalize the endpoint into a full base URL; a bare host:port
        # gets an "http://" prefix.
        if endpoint.startswith(("http://", "https://")):
            base_url = endpoint
        else:
            base_url = "http://" + endpoint
        super().__init__(
            model=model,
            endpoint=endpoint,
            temperature=temperature,
            max_tokens=max_tokens,
            context_window=context_window,
            additional_kwargs=additional_kwargs or {},
            callback_manager=callback_manager or CallbackManager([]),
            base_url=base_url,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        return "RunGptLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_tokens,
            model_name=self.model,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        try:
            import requests
        except ImportError:
            raise ImportError(
                "Could not import the requests library. "
                "Please install it with `pip install requests`."
            )
        response_gpt = requests.post(
            self.base_url + "/generate",
            json=self._request_pack("complete", prompt, **kwargs),
            stream=False,
        ).json()

        return CompletionResponse(
            text=response_gpt["choices"][0]["text"],
            additional_kwargs=response_gpt["usage"],
            raw=response_gpt,
        )
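
    # The /generate endpoint is assumed to return an OpenAI-style body, which
    # complete() indexes into above; a sketch of the expected shape (field
    # values are illustrative):
    #
    #   {
    #       "choices": [{"text": "generated text"}],
    #       "usage": {"prompt_tokens": 5, "completion_tokens": 16},
    #   }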
126

127
    @llm_completion_callback()
128
    def stream_complete(
129
        self, prompt: str, formatted: bool = False, **kwargs: Any
130
    ) -> CompletionResponseGen:
131
        try:
132
            import requests
133
        except ImportError:
134
            raise ImportError(
135
                "Could not import requests library."
136
                "Please install requests with `pip install requests`"
137
            )
138
        response_gpt = requests.post(
139
            self.base_url + "/generate_stream",
140
            json=self._request_pack("complete", prompt, **kwargs),
141
            stream=True,
142
        )
143
        try:
144
            import sseclient
145
        except ImportError:
146
            raise ImportError(
147
                "Could not import sseclient-py library."
148
                "Please install requests with `pip install sseclient-py`"
149
            )
150
        client = sseclient.SSEClient(response_gpt)
151
        response_iter = client.events()
152

153
        def gen() -> CompletionResponseGen:
154
            text = ""
155
            for item in response_iter:
156
                item_dict = json.loads(json.dumps(eval(item.data)))
157
                delta = item_dict["choices"][0]["text"]
158
                additional_kwargs = item_dict["usage"]
159
                text = text + self._space_handler(delta)
160
                yield CompletionResponse(
161
                    text=text,
162
                    delta=delta,
163
                    raw=item_dict,
164
                    additional_kwargs=additional_kwargs,
165
                )
166

167
        return gen()

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        message_list = self._message_wrapper(messages)
        try:
            import requests
        except ImportError:
            raise ImportError(
                "Could not import the requests library. "
                "Please install it with `pip install requests`."
            )
        response_gpt = requests.post(
            self.base_url + "/chat",
            json=self._request_pack("chat", message_list, **kwargs),
            stream=False,
        ).json()
        chat_message, _ = self._message_unpacker(response_gpt)
        return ChatResponse(message=chat_message, raw=response_gpt)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        message_list = self._message_wrapper(messages)
        try:
            import requests
        except ImportError:
            raise ImportError(
                "Could not import the requests library. "
                "Please install it with `pip install requests`."
            )
        response_gpt = requests.post(
            self.base_url + "/chat_stream",
            json=self._request_pack("chat", message_list, **kwargs),
            stream=True,
        )
        try:
            import sseclient
        except ImportError:
            raise ImportError(
                "Could not import the sseclient-py library. "
                "Please install it with `pip install sseclient-py`."
            )
        client = sseclient.SSEClient(response_gpt)
        chat_iter = client.events()

        def gen() -> ChatResponseGen:
            content = ""
            for item in chat_iter:
                # As in stream_complete, the event payload is a Python
                # literal; parse it safely instead of calling eval.
                item_dict = ast.literal_eval(item.data)
                chat_message, delta = self._message_unpacker(item_dict)
                content = content + self._space_handler(delta)
                chat_message.content = content
                yield ChatResponse(message=chat_message, raw=item_dict, delta=delta)

        return gen()

    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        # Synchronous fallback: there is no native async client here.
        return self.chat(messages, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        async def gen() -> ChatResponseAsyncGen:
            for message in self.stream_chat(messages, **kwargs):
                yield message

        # NOTE: convert generator to async generator
        return gen()

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        return self.complete(prompt, formatted=formatted, **kwargs)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        async def gen() -> CompletionResponseAsyncGen:
            for message in self.stream_complete(prompt, formatted=formatted, **kwargs):
                yield message

        return gen()

    def _message_wrapper(self, messages: Sequence[ChatMessage]) -> List[Dict[str, Any]]:
        message_list = []
        for message in messages:
            role = message.role.value
            content = message.content
            message_list.append({"role": role, "content": content})
        return message_list

    def _message_unpacker(
        self, response_gpt: Dict[str, Any]
    ) -> Tuple[ChatMessage, str]:
        message = response_gpt["choices"][0]["message"]
        additional_kwargs = response_gpt["usage"]
        role = message["role"]
        content = message["content"]
        # Map the role string back onto the MessageRole enum, falling back
        # to SYSTEM if the server returns an unknown role.
        key = MessageRole.SYSTEM
        for r in MessageRole:
            if r.value == role:
                key = r
        chat_message = ChatMessage(
            role=key, content=content, additional_kwargs=additional_kwargs
        )
        return chat_message, content

    def _request_pack(
        self, mode: str, prompt: Union[str, List[Dict[str, Any]]], **kwargs: Any
    ) -> Optional[Dict[str, Any]]:
        # The complete and chat endpoints take identical sampling parameters;
        # only the key that carries the input ("prompt" vs. "messages") differs.
        if mode not in ("complete", "chat"):
            return None
        request = {
            "max_tokens": kwargs.pop("max_tokens", self.max_tokens),
            "temperature": kwargs.pop("temperature", self.temperature),
            "top_k": kwargs.pop("top_k", 50),
            "top_p": kwargs.pop("top_p", 0.95),
            "repetition_penalty": kwargs.pop("repetition_penalty", 1.2),
            "do_sample": kwargs.pop("do_sample", False),
            "echo": kwargs.pop("echo", True),
            "n": kwargs.pop("n", 1),
            "stop": kwargs.pop("stop", "."),
        }
        request["prompt" if mode == "complete" else "messages"] = prompt
        return request
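
    # For reference, a "complete" request packed with the defaults looks like
    # this (256 assumes DEFAULT_NUM_OUTPUTS == 256; the prompt is illustrative):
    #
    #   {"max_tokens": 256, "temperature": 0.75, "top_k": 50, "top_p": 0.95,
    #    "repetition_penalty": 1.2, "do_sample": False, "echo": True,
    #    "n": 1, "stop": ".", "prompt": "Hello"}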

    def _space_handler(self, word: str) -> str:
        # Streamed deltas arrive without separators; prefix purely
        # alphanumeric tokens with a space when stitching text back together.
        if word.isalnum():
            return " " + word
        return word
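

# Minimal usage sketch (assumes a rungpt server is already listening at the
# address below; the endpoint, model name, and prompts are illustrative):
#
#     llm = RunGptLLM(model="rungpt", endpoint="0.0.0.0:51002")
#     print(llm.complete("The capital of France is").text)
#
#     messages = [ChatMessage(role=MessageRole.USER, content="Say hello.")]
#     print(llm.chat(messages).message.content)
#
#     # Streaming: each response carries the accumulated text plus the delta.
#     for response in llm.stream_complete("Once upon a time"):
#         print(response.delta, end="")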