llama-index
320 lines · 11.2 KB
import ast
import json
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from llama_index.legacy.bridge.pydantic import Field
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
21
# Default model tag sent to the rungpt server when none is supplied.
DEFAULT_RUNGPT_MODEL = "rungpt"
# Default sampling temperature for generation requests.
DEFAULT_RUNGPT_TEMP = 0.75
24
25
class RunGptLLM(LLM):
    """LLM wrapper for models served by rungpt (Jina AI).

    Talks to a rungpt HTTP server over its ``/generate``, ``/chat`` and
    streaming (SSE) endpoints using ``requests`` and ``sseclient-py``.
    """

    model: Optional[str] = Field(
        default=DEFAULT_RUNGPT_MODEL, description="The rungpt model to use."
    )
    endpoint: str = Field(description="The endpoint of serving address.")
    temperature: float = Field(
        default=DEFAULT_RUNGPT_TEMP,
        description="The temperature to use for sampling.",
        # BUG FIX: the original used `gte`/`lte`, which pydantic treats as
        # inert extra metadata; `ge`/`le` actually enforce the 0.0-1.0 bound.
        ge=0.0,
        le=1.0,
    )
    max_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="Max tokens model generates.",
        gt=0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum number of context tokens for the model.",
        gt=0,
    )
    additional_kwargs: Dict[str, Any] = Field(
        # BUG FIX: description previously said "Replicate API" (copy-paste).
        default_factory=dict, description="Additional kwargs for the rungpt API."
    )
    base_url: str = Field(
        description="The address of your target model served by rungpt."
    )
55
56def __init__(
57self,
58model: Optional[str] = DEFAULT_RUNGPT_MODEL,
59endpoint: str = "0.0.0.0:51002",
60temperature: float = DEFAULT_RUNGPT_TEMP,
61max_tokens: Optional[int] = DEFAULT_NUM_OUTPUTS,
62context_window: int = DEFAULT_CONTEXT_WINDOW,
63additional_kwargs: Optional[Dict[str, Any]] = None,
64callback_manager: Optional[CallbackManager] = None,
65system_prompt: Optional[str] = None,
66messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
67completion_to_prompt: Optional[Callable[[str], str]] = None,
68pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
69output_parser: Optional[BaseOutputParser] = None,
70):
71if endpoint.startswith("http://"):
72base_url = endpoint
73else:
74base_url = "http://" + endpoint
75super().__init__(
76model=model,
77endpoint=endpoint,
78temperature=temperature,
79max_tokens=max_tokens,
80context_window=context_window,
81additional_kwargs=additional_kwargs or {},
82callback_manager=callback_manager or CallbackManager([]),
83base_url=base_url,
84system_prompt=system_prompt,
85messages_to_prompt=messages_to_prompt,
86completion_to_prompt=completion_to_prompt,
87pydantic_program_mode=pydantic_program_mode,
88output_parser=output_parser,
89)
90
91@classmethod
92def class_name(cls) -> str:
93return "RunGptLLM"
94
95@property
96def metadata(self) -> LLMMetadata:
97"""LLM metadata."""
98return LLMMetadata(
99context_window=self.context_window,
100num_output=self.max_tokens,
101model_name=self._model,
102)
103
104@llm_completion_callback()
105def complete(
106self, prompt: str, formatted: bool = False, **kwargs: Any
107) -> CompletionResponse:
108try:
109import requests
110except ImportError:
111raise ImportError(
112"Could not import requests library."
113"Please install requests with `pip install requests`"
114)
115response_gpt = requests.post(
116self.base_url + "/generate",
117json=self._request_pack("complete", prompt, **kwargs),
118stream=False,
119).json()
120
121return CompletionResponse(
122text=response_gpt["choices"][0]["text"],
123additional_kwargs=response_gpt["usage"],
124raw=response_gpt,
125)
126
127@llm_completion_callback()
128def stream_complete(
129self, prompt: str, formatted: bool = False, **kwargs: Any
130) -> CompletionResponseGen:
131try:
132import requests
133except ImportError:
134raise ImportError(
135"Could not import requests library."
136"Please install requests with `pip install requests`"
137)
138response_gpt = requests.post(
139self.base_url + "/generate_stream",
140json=self._request_pack("complete", prompt, **kwargs),
141stream=True,
142)
143try:
144import sseclient
145except ImportError:
146raise ImportError(
147"Could not import sseclient-py library."
148"Please install requests with `pip install sseclient-py`"
149)
150client = sseclient.SSEClient(response_gpt)
151response_iter = client.events()
152
153def gen() -> CompletionResponseGen:
154text = ""
155for item in response_iter:
156item_dict = json.loads(json.dumps(eval(item.data)))
157delta = item_dict["choices"][0]["text"]
158additional_kwargs = item_dict["usage"]
159text = text + self._space_handler(delta)
160yield CompletionResponse(
161text=text,
162delta=delta,
163raw=item_dict,
164additional_kwargs=additional_kwargs,
165)
166
167return gen()
168
169@llm_chat_callback()
170def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
171message_list = self._message_wrapper(messages)
172try:
173import requests
174except ImportError:
175raise ImportError(
176"Could not import requests library."
177"Please install requests with `pip install requests`"
178)
179response_gpt = requests.post(
180self.base_url + "/chat",
181json=self._request_pack("chat", message_list, **kwargs),
182stream=False,
183).json()
184chat_message, _ = self._message_unpacker(response_gpt)
185return ChatResponse(message=chat_message, raw=response_gpt)
186
187@llm_chat_callback()
188def stream_chat(
189self, messages: Sequence[ChatMessage], **kwargs: Any
190) -> ChatResponseGen:
191message_list = self._message_wrapper(messages)
192try:
193import requests
194except ImportError:
195raise ImportError(
196"Could not import requests library."
197"Please install requests with `pip install requests`"
198)
199response_gpt = requests.post(
200self.base_url + "/chat_stream",
201json=self._request_pack("chat", message_list, **kwargs),
202stream=True,
203)
204try:
205import sseclient
206except ImportError:
207raise ImportError(
208"Could not import sseclient-py library."
209"Please install requests with `pip install sseclient-py`"
210)
211client = sseclient.SSEClient(response_gpt)
212chat_iter = client.events()
213
214def gen() -> ChatResponseGen:
215content = ""
216for item in chat_iter:
217item_dict = json.loads(json.dumps(eval(item.data)))
218chat_message, delta = self._message_unpacker(item_dict)
219content = content + self._space_handler(delta)
220chat_message.content = content
221yield ChatResponse(message=chat_message, raw=item_dict, delta=delta)
222
223return gen()
224
225@llm_chat_callback()
226async def achat(
227self,
228messages: Sequence[ChatMessage],
229**kwargs: Any,
230) -> ChatResponse:
231return self.chat(messages, **kwargs)
232
233@llm_chat_callback()
234async def astream_chat(
235self,
236messages: Sequence[ChatMessage],
237**kwargs: Any,
238) -> ChatResponseAsyncGen:
239async def gen() -> ChatResponseAsyncGen:
240for message in self.stream_chat(messages, **kwargs):
241yield message
242
243# NOTE: convert generator to async generator
244return gen()
245
246@llm_completion_callback()
247async def acomplete(
248self, prompt: str, formatted: bool = False, **kwargs: Any
249) -> CompletionResponse:
250return self.complete(prompt, **kwargs)
251
252@llm_completion_callback()
253async def astream_complete(
254self, prompt: str, formatted: bool = False, **kwargs: Any
255) -> CompletionResponseAsyncGen:
256async def gen() -> CompletionResponseAsyncGen:
257for message in self.stream_complete(prompt, **kwargs):
258yield message
259
260return gen()
261
262def _message_wrapper(self, messages: Sequence[ChatMessage]) -> List[Dict[str, Any]]:
263message_list = []
264for message in messages:
265role = message.role.value
266content = message.content
267message_list.append({"role": role, "content": content})
268return message_list
269
270def _message_unpacker(
271self, response_gpt: Dict[str, Any]
272) -> Tuple[ChatMessage, str]:
273message = response_gpt["choices"][0]["message"]
274additional_kwargs = response_gpt["usage"]
275role = message["role"]
276content = message["content"]
277key = MessageRole.SYSTEM
278for r in MessageRole:
279if r.value == role:
280key = r
281chat_message = ChatMessage(
282role=key, content=content, additional_kwargs=additional_kwargs
283)
284return chat_message, content
285
286def _request_pack(
287self, mode: str, prompt: Union[str, List[Dict[str, Any]]], **kwargs: Any
288) -> Optional[Dict[str, Any]]:
289if mode == "complete":
290return {
291"prompt": prompt,
292"max_tokens": kwargs.pop("max_tokens", self.max_tokens),
293"temperature": kwargs.pop("temperature", self.temperature),
294"top_k": kwargs.pop("top_k", 50),
295"top_p": kwargs.pop("top_p", 0.95),
296"repetition_penalty": kwargs.pop("repetition_penalty", 1.2),
297"do_sample": kwargs.pop("do_sample", False),
298"echo": kwargs.pop("echo", True),
299"n": kwargs.pop("n", 1),
300"stop": kwargs.pop("stop", "."),
301}
302elif mode == "chat":
303return {
304"messages": prompt,
305"max_tokens": kwargs.pop("max_tokens", self.max_tokens),
306"temperature": kwargs.pop("temperature", self.temperature),
307"top_k": kwargs.pop("top_k", 50),
308"top_p": kwargs.pop("top_p", 0.95),
309"repetition_penalty": kwargs.pop("repetition_penalty", 1.2),
310"do_sample": kwargs.pop("do_sample", False),
311"echo": kwargs.pop("echo", True),
312"n": kwargs.pop("n", 1),
313"stop": kwargs.pop("stop", "."),
314}
315return None
316
317def _space_handler(self, word: str) -> str:
318if word.isalnum():
319return " " + word
320return word
321