llama-index

Форк
0
398 строк · 14.7 Кб
1
import json
2
from typing import Any, Callable, Dict, Optional, Sequence
3

4
import httpx
5
import requests
6

7
from llama_index.legacy.bridge.pydantic import Field
8
from llama_index.legacy.callbacks import CallbackManager
9
from llama_index.legacy.core.llms.types import (
10
    ChatMessage,
11
    ChatResponse,
12
    ChatResponseAsyncGen,
13
    ChatResponseGen,
14
    CompletionResponse,
15
    CompletionResponseAsyncGen,
16
    CompletionResponseGen,
17
    LLMMetadata,
18
)
19
from llama_index.legacy.llms.base import llm_chat_callback, llm_completion_callback
20
from llama_index.legacy.llms.llm import LLM
21
from llama_index.legacy.types import BaseOutputParser, PydanticProgramMode
22

23

24
class Perplexity(LLM):
    """LlamaIndex LLM wrapper around the Perplexity REST API.

    Every request (completion or chat, blocking or streaming, sync or async)
    is POSTed to ``{api_base}/chat/completions``. Completion-style calls wrap
    the prompt in an optional system message plus a single user message, so
    they share the request/response format of the chat calls.
    """

    model: str = Field(description="The Perplexity model to use.")
    temperature: float = Field(description="The temperature to use during generation.")
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate.",
    )
    context_window: Optional[int] = Field(
        default=None,
        description="The context window to use during generation.",
    )
    # Defaults to None, hence Optional; exclude=True keeps the secret out of
    # dict()/json() output.
    api_key: Optional[str] = Field(
        default=None, description="The Perplexity API key.", exclude=True
    )
    api_base: str = Field(
        default="https://api.perplexity.ai",
        description="The base URL for Perplexity API.",
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Perplexity API."
    )
    # NOTE(review): stored, but no request method in this class consults it.
    max_retries: int = Field(
        default=10, description="The maximum number of API retries."
    )
    # The headers embed "Bearer <api_key>", so they must be excluded from
    # serialization just like api_key. __init__ rebuilds them from api_key
    # (a serialized "headers" entry would also collide with the one __init__
    # passes to super()).
    headers: Dict[str, str] = Field(
        default_factory=dict, description="Headers for API requests.", exclude=True
    )

    def __init__(
        self,
        model: str = "mistral-7b-instruct",
        temperature: float = 0.1,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = "https://api.perplexity.ai",
        additional_kwargs: Optional[Dict[str, Any]] = None,
        max_retries: int = 10,
        context_window: Optional[int] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> None:
        """Build the model and derive the JSON/auth request headers from ``api_key``."""
        additional_kwargs = additional_kwargs or {}
        # NOTE(review): when api_key is None this yields "Bearer None".
        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "authorization": f"Bearer {api_key}",
        }
        super().__init__(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            additional_kwargs=additional_kwargs,
            max_retries=max_retries,
            callback_manager=callback_manager,
            api_key=api_key,
            api_base=api_base,
            headers=headers,
            context_window=context_window,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization name of this component."""
        return "perplexity_llm"

    @property
    def metadata(self) -> LLMMetadata:
        """Describe the configured model (context window, output cap, chat flag)."""
        return LLMMetadata(
            context_window=(
                self.context_window
                if self.context_window is not None
                else self._get_context_window()
            ),
            # NOTE(review): -1 is reported when max_tokens is unset; confirm
            # downstream prompt sizing tolerates it.
            num_output=self.max_tokens or -1,
            is_chat_model=self._is_chat_model(),
            model_name=self.model,
        )

    def _get_context_window(self) -> int:
        """Return the known context window for ``self.model``, else 4096."""
        model_context_windows = {
            "codellama-34b-instruct": 16384,
            "llama-2-70b-chat": 4096,
            "mistral-7b-instruct": 4096,
            "mixtral-8x7b-instruct": 4096,
            "pplx-7b-chat": 8192,
            "pplx-70b-chat": 4096,
            "pplx-7b-online": 4096,
            "pplx-70b-online": 4096,
        }
        return model_context_windows.get(self.model, 4096)

    def _is_chat_model(self) -> bool:
        """Return True when ``self.model`` is one of the known chat models."""
        chat_models = {
            "codellama-34b-instruct",
            "llama-2-70b-chat",
            "mistral-7b-instruct",
            "mixtral-8x7b-instruct",
            "pplx-7b-chat",
            "pplx-70b-chat",
            "pplx-7b-online",
            "pplx-70b-online",
        }
        return self.model in chat_models

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        """Get all data for the request as a dictionary.

        Later sources win: base settings < ``additional_kwargs`` < call kwargs.
        """
        base_kwargs = {
            "model": self.model,
            "temperature": self.temperature,
        }
        if self.max_tokens is not None:
            base_kwargs["max_tokens"] = self.max_tokens
        return {**base_kwargs, **self.additional_kwargs, **kwargs}

    # ------------------------------------------------------------------ helpers

    def _chat_url(self) -> str:
        """Return the chat-completions endpoint every request is sent to."""
        return f"{self.api_base}/chat/completions"

    def _prompt_to_messages(self, prompt: str) -> Sequence[Dict[str, Any]]:
        """Wrap a raw prompt as chat messages, prefixed by the system prompt if set.

        The system message is omitted when ``system_prompt`` is empty/None so
        the API never receives a message with null content.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    @staticmethod
    def _serialize_messages(
        messages: Sequence[ChatMessage],
    ) -> Sequence[Dict[str, Any]]:
        """Convert ChatMessages to request dicts, dropping ``additional_kwargs``."""
        return [message.dict(exclude={"additional_kwargs"}) for message in messages]

    def _build_payload(
        self,
        messages: Sequence[Dict[str, Any]],
        stream: bool,
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Assemble the JSON request body.

        ``kwargs`` is taken as a dict (not splatted) so a caller-supplied
        ``stream`` key cannot clash with the ``stream`` parameter. As before,
        merged kwargs are applied last and may override earlier keys.
        """
        payload: Dict[str, Any] = {"model": self.model, "messages": list(messages)}
        if stream:
            payload["stream"] = True
        payload.update(self._get_all_kwargs(**kwargs))
        return payload

    @staticmethod
    def _message_text(data: Dict[str, Any]) -> str:
        """Extract the assistant text from a non-streaming chat-completions body."""
        return data["choices"][0]["message"]["content"]

    @staticmethod
    def _parse_sse_line(line: str) -> Optional[Dict[str, Any]]:
        """Decode one server-sent-events line into a JSON chunk.

        Returns None for lines carrying no chunk: blank keep-alive lines,
        non-``data:`` fields, and a ``[DONE]`` terminator.
        """
        line = line.strip()
        if not line.startswith("data:"):
            return None
        body = line[len("data:") :].strip()
        if not body or body == "[DONE]":
            return None
        return json.loads(body)

    @staticmethod
    def _delta_text(chunk: Dict[str, Any]) -> str:
        """Return the incremental text of a streamed chunk ("" when absent)."""
        choices = chunk.get("choices") or [{}]
        delta = choices[0].get("delta") or {}
        return delta.get("content") or ""

    # ------------------------------------------------------------ blocking API

    def _complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Send ``prompt`` as a chat request and return the reply text."""
        payload = self._build_payload(self._prompt_to_messages(prompt), False, kwargs)
        response = requests.post(self._chat_url(), json=payload, headers=self.headers)
        response.raise_for_status()
        data = response.json()
        return CompletionResponse(text=self._message_text(data), raw=data)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Complete ``prompt``; raises ValueError for models listed as chat models."""
        if self._is_chat_model():
            raise ValueError("The complete method is not supported for chat models.")
        return self._complete(prompt, **kwargs)

    def _chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Send ``messages`` and wrap the reply as an assistant ChatMessage."""
        payload = self._build_payload(self._serialize_messages(messages), False, kwargs)
        response = requests.post(self._chat_url(), json=payload, headers=self.headers)
        response.raise_for_status()
        data = response.json()
        message = ChatMessage(role="assistant", content=self._message_text(data))
        return ChatResponse(message=message, raw=data)

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Chat with the model (blocking)."""
        return self._chat(messages, **kwargs)

    # --------------------------------------------------------------- async API

    async def _acomplete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Async counterpart of ``_complete``."""
        payload = self._build_payload(self._prompt_to_messages(prompt), False, kwargs)
        # httpx enforces a 5-second default timeout, which long generations
        # exceed; disable it to match the blocking ``requests`` path.
        async with httpx.AsyncClient(timeout=None) as client:
            response = await client.post(
                self._chat_url(), json=payload, headers=self.headers
            )
        response.raise_for_status()
        data = response.json()
        return CompletionResponse(text=self._message_text(data), raw=data)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        """Async ``complete``; raises ValueError for models listed as chat models."""
        if self._is_chat_model():
            raise ValueError("The complete method is not supported for chat models.")
        return await self._acomplete(prompt, **kwargs)

    async def _achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Async counterpart of ``_chat``."""
        payload = self._build_payload(self._serialize_messages(messages), False, kwargs)
        # See _acomplete: disable httpx's 5-second default timeout.
        async with httpx.AsyncClient(timeout=None) as client:
            response = await client.post(
                self._chat_url(), json=payload, headers=self.headers
            )
        response.raise_for_status()
        data = response.json()
        message = ChatMessage(role="assistant", content=self._message_text(data))
        return ChatResponse(message=message, raw=data)

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        """Chat with the model (async)."""
        return await self._achat(messages, **kwargs)

    # ----------------------------------------------------------- streaming API

    def _stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        """Stream a completion, yielding the delta and the accumulated text."""
        url = self._chat_url()
        payload = self._build_payload(self._prompt_to_messages(prompt), True, kwargs)

        def gen() -> CompletionResponseGen:
            text = ""
            with requests.Session() as session:
                with session.post(
                    url, json=payload, headers=self.headers, stream=True
                ) as response:
                    response.raise_for_status()
                    # Read raw bytes and decode as UTF-8 explicitly: requests may
                    # guess ISO-8859-1 (or nothing) for an event-stream response.
                    for raw_line in response.iter_lines():
                        chunk = self._parse_sse_line(raw_line.decode("utf-8"))
                        if chunk is None:
                            continue
                        delta = self._delta_text(chunk)
                        text += delta
                        yield CompletionResponse(delta=delta, text=text, raw=chunk)

        return gen()

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        """Streaming ``complete``; raises ValueError for models listed as chat models."""
        if self._is_chat_model():
            raise ValueError("The complete method is not supported for chat models.")
        return self._stream_complete(prompt, **kwargs)

    async def _astream_complete(
        self, prompt: str, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async streaming completion via aiohttp (imported lazily)."""
        import aiohttp

        url = self._chat_url()
        payload = self._build_payload(self._prompt_to_messages(prompt), True, kwargs)

        async def gen() -> CompletionResponseAsyncGen:
            text = ""
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url, json=payload, headers=self.headers
                ) as response:
                    response.raise_for_status()
                    async for raw_line in response.content:
                        chunk = self._parse_sse_line(raw_line.decode("utf-8"))
                        if chunk is None:
                            continue
                        delta = self._delta_text(chunk)
                        text += delta
                        yield CompletionResponse(delta=delta, text=text, raw=chunk)

        return gen()

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        """Async streaming ``complete``; raises ValueError for listed chat models."""
        if self._is_chat_model():
            raise ValueError("The complete method is not supported for chat models.")
        return await self._astream_complete(prompt, **kwargs)

    def _stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Stream a chat reply, yielding the delta and the accumulated message."""
        url = self._chat_url()
        payload = self._build_payload(self._serialize_messages(messages), True, kwargs)

        def gen() -> ChatResponseGen:
            content = ""
            with requests.Session() as session:
                with session.post(
                    url, json=payload, headers=self.headers, stream=True
                ) as response:
                    response.raise_for_status()
                    # Explicit UTF-8 decoding; see _stream_complete.
                    for raw_line in response.iter_lines():
                        chunk = self._parse_sse_line(raw_line.decode("utf-8"))
                        if chunk is None:
                            continue
                        delta = self._delta_text(chunk)
                        content += delta
                        message = ChatMessage(role="assistant", content=content)
                        yield ChatResponse(message=message, delta=delta, raw=chunk)

        return gen()

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Streaming chat (sync generator)."""
        return self._stream_chat(messages, **kwargs)

    async def _astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Async streaming chat via aiohttp (imported lazily)."""
        import aiohttp

        url = self._chat_url()
        payload = self._build_payload(self._serialize_messages(messages), True, kwargs)

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url, json=payload, headers=self.headers
                ) as response:
                    response.raise_for_status()
                    async for raw_line in response.content:
                        chunk = self._parse_sse_line(raw_line.decode("utf-8"))
                        if chunk is None:
                            continue
                        delta = self._delta_text(chunk)
                        content += delta
                        message = ChatMessage(role="assistant", content=content)
                        yield ChatResponse(message=message, delta=delta, raw=chunk)

        return gen()

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        """Streaming chat (async generator)."""
        return await self._astream_chat(messages, **kwargs)
399

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.