"""Wrapper functions around an LLM chain."""

import logging
from abc import ABC, abstractmethod
from collections import ChainMap
from typing import Any, Dict, List, Optional, Union

from typing_extensions import Self

from llama_index.legacy.bridge.pydantic import BaseModel, PrivateAttr
from llama_index.legacy.callbacks.base import CallbackManager
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
from llama_index.legacy.core.llms.types import (
    ChatMessage,
    LLMMetadata,
    MessageRole,
)
from llama_index.legacy.llms.llm import (
    LLM,
    astream_chat_response_to_tokens,
    astream_completion_response_to_tokens,
    stream_chat_response_to_tokens,
    stream_completion_response_to_tokens,
)
from llama_index.legacy.llms.utils import LLMType, resolve_llm
from llama_index.legacy.prompts.base import BasePromptTemplate, PromptTemplate
from llama_index.legacy.schema import BaseComponent
from llama_index.legacy.types import PydanticProgramMode, TokenAsyncGen, TokenGen

logger = logging.getLogger(__name__)


class BaseLLMPredictor(BaseComponent, ABC):
    """Base LLM Predictor."""

    def dict(self, **kwargs: Any) -> Dict[str, Any]:
        data = super().dict(**kwargs)
        data["llm"] = self.llm.to_dict()
        return data

    def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
        data = super().to_dict(**kwargs)
        data["llm"] = self.llm.to_dict()
        return data

    @property
    @abstractmethod
    def llm(self) -> LLM:
        """Get LLM."""

    @property
    @abstractmethod
    def callback_manager(self) -> CallbackManager:
        """Get callback manager."""

    @property
    @abstractmethod
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""

    @abstractmethod
    def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        """Predict the answer to a query."""

    @abstractmethod
    def stream(self, prompt: BasePromptTemplate, **prompt_args: Any) -> TokenGen:
        """Stream the answer to a query."""

    @abstractmethod
    async def apredict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:
        """Async predict the answer to a query."""

    @abstractmethod
    async def astream(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> TokenAsyncGen:
77
        """Async predict the answer to a query."""
78

79

80
class LLMPredictor(BaseLLMPredictor):
81
    """LLM predictor class.
82

83
    A lightweight wrapper on top of LLMs that handles:
84
    - conversion of prompts to the string input format expected by LLMs
85
    - logging of prompts and responses to a callback manager
86

87
    NOTE: Mostly keeping around for legacy reasons. A potential future path is to
88
    deprecate this class and move all functionality into the LLM class.
89
    """

    class Config:
        arbitrary_types_allowed = True

    system_prompt: Optional[str]
    query_wrapper_prompt: Optional[BasePromptTemplate]
    pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT

    _llm: LLM = PrivateAttr()

    def __init__(
        self,
        llm: Optional[LLMType] = "default",
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        query_wrapper_prompt: Optional[BasePromptTemplate] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
    ) -> None:
        """Initialize params."""
        self._llm = resolve_llm(llm)

        if callback_manager:
            self._llm.callback_manager = callback_manager

        super().__init__(
            system_prompt=system_prompt,
            query_wrapper_prompt=query_wrapper_prompt,
            pydantic_program_mode=pydantic_program_mode,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        if isinstance(kwargs, dict):
            data.update(kwargs)

        data.pop("class_name", None)

        llm = data.get("llm", "default")
        if llm != "default":
            from llama_index.legacy.llms.loading import load_llm

            llm = load_llm(llm)

        data["llm"] = llm
        return cls(**data)
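    # Serialization round-trip sketch: to_dict() (see BaseLLMPredictor) embeds
    # the LLM's own dict under the "llm" key, and from_dict() reloads it via
    # load_llm() unless the stored value is the literal string "default".
    #
    #     state = predictor.to_dict()
    #     restored = LLMPredictor.from_dict(state)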

    @classmethod
    def class_name(cls) -> str:
        return "LLMPredictor"

    @property
    def llm(self) -> LLM:
        """Get LLM."""
        return self._llm

    @property
    def callback_manager(self) -> CallbackManager:
        """Get callback manager."""
        return self._llm.callback_manager

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return self._llm.metadata

    def _log_template_data(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> None:
        template_vars = {
            k: v
            for k, v in ChainMap(prompt.kwargs, prompt_args).items()
            if k in prompt.template_vars
        }
        with self.callback_manager.event(
            CBEventType.TEMPLATING,
            payload={
                EventPayload.TEMPLATE: prompt.get_template(llm=self._llm),
                EventPayload.TEMPLATE_VARS: template_vars,
                EventPayload.SYSTEM_PROMPT: self.system_prompt,
                EventPayload.QUERY_WRAPPER_PROMPT: self.query_wrapper_prompt,
            },
        ):
            pass

    def _run_program(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> str:
        from llama_index.legacy.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self._llm,
            pydantic_program_mode=self.pydantic_program_mode,
        )

        chat_response = program(**prompt_args)
        return chat_response.json()

    async def _arun_program(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> str:
        from llama_index.legacy.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self._llm,
            pydantic_program_mode=self.pydantic_program_mode,
        )

        chat_response = await program.acall(**prompt_args)
        return chat_response.json()

    def predict(
        self,
        prompt: BasePromptTemplate,
        output_cls: Optional[BaseModel] = None,
        **prompt_args: Any,
    ) -> str:
        """Predict."""
        self._log_template_data(prompt, **prompt_args)

        if output_cls is not None:
            output = self._run_program(output_cls, prompt, **prompt_args)
        elif self._llm.metadata.is_chat_model:
            messages = prompt.format_messages(llm=self._llm, **prompt_args)
            messages = self._extend_messages(messages)
            chat_response = self._llm.chat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = prompt.format(llm=self._llm, **prompt_args)
            formatted_prompt = self._extend_prompt(formatted_prompt)
            response = self._llm.complete(formatted_prompt)
            output = response.text

        logger.debug(output)

        return output

    def stream(
        self,
        prompt: BasePromptTemplate,
        output_cls: Optional[BaseModel] = None,
        **prompt_args: Any,
    ) -> TokenGen:
        """Stream."""
        if output_cls is not None:
            raise NotImplementedError("Streaming with output_cls not supported.")

        self._log_template_data(prompt, **prompt_args)

        if self._llm.metadata.is_chat_model:
            messages = prompt.format_messages(llm=self._llm, **prompt_args)
            messages = self._extend_messages(messages)
            chat_response = self._llm.stream_chat(messages)
            stream_tokens = stream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = prompt.format(llm=self._llm, **prompt_args)
            formatted_prompt = self._extend_prompt(formatted_prompt)
            stream_response = self._llm.stream_complete(formatted_prompt)
            stream_tokens = stream_completion_response_to_tokens(stream_response)
        return stream_tokens

    async def apredict(
        self,
        prompt: BasePromptTemplate,
        output_cls: Optional[BaseModel] = None,
        **prompt_args: Any,
    ) -> str:
        """Async predict."""
        self._log_template_data(prompt, **prompt_args)

        if output_cls is not None:
            output = await self._arun_program(output_cls, prompt, **prompt_args)
        elif self._llm.metadata.is_chat_model:
            messages = prompt.format_messages(llm=self._llm, **prompt_args)
            messages = self._extend_messages(messages)
            chat_response = await self._llm.achat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = prompt.format(llm=self._llm, **prompt_args)
            formatted_prompt = self._extend_prompt(formatted_prompt)
            response = await self._llm.acomplete(formatted_prompt)
            output = response.text

        logger.debug(output)

        return output

    async def astream(
        self,
        prompt: BasePromptTemplate,
        output_cls: Optional[BaseModel] = None,
        **prompt_args: Any,
    ) -> TokenAsyncGen:
        """Async stream."""
        if output_cls is not None:
            raise NotImplementedError("Streaming with output_cls not supported.")

        self._log_template_data(prompt, **prompt_args)

        if self._llm.metadata.is_chat_model:
            messages = prompt.format_messages(llm=self._llm, **prompt_args)
            messages = self._extend_messages(messages)
            chat_response = await self._llm.astream_chat(messages)
            stream_tokens = await astream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = prompt.format(llm=self._llm, **prompt_args)
            formatted_prompt = self._extend_prompt(formatted_prompt)
            stream_response = await self._llm.astream_complete(formatted_prompt)
            stream_tokens = await astream_completion_response_to_tokens(stream_response)
        return stream_tokens

    def _extend_prompt(
        self,
        formatted_prompt: str,
    ) -> str:
        """Add system and query wrapper prompts to base prompt."""
        extended_prompt = formatted_prompt
        if self.system_prompt:
            extended_prompt = self.system_prompt + "\n\n" + extended_prompt

        if self.query_wrapper_prompt:
            extended_prompt = self.query_wrapper_prompt.format(
                query_str=extended_prompt
            )

        return extended_prompt
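    # Worked example (sketch): with system_prompt="SYS" and
    # query_wrapper_prompt=PromptTemplate("[INST] {query_str} [/INST]"),
    # a formatted prompt "Question?" is extended to:
    #
    #     "[INST] SYS\n\nQuestion? [/INST]"
    #
    # i.e. the system prompt is prepended first, and the result is then
    # substituted for {query_str} in the wrapper template.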

    def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
        """Add system prompt to chat message list."""
        if self.system_prompt:
            messages = [
                ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt),
                *messages,
            ]
        return messages


LLMPredictorType = Union[LLMPredictor, LLM]
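# Structured-output sketch: passing output_cls routes predict()/apredict()
# through a pydantic program (_run_program above) and returns the model
# serialized as JSON. "Album" is a hypothetical schema, not defined here.
#
#     class Album(BaseModel):
#         title: str
#         year: int
#
#     album_json = predictor.predict(
#         PromptTemplate("Name one album by {artist}."),
#         output_cls=Album,
#         artist="...",
#     )
#
# Streaming follows the same chat/completion branching but yields tokens:
#
#     for token in predictor.stream(PromptTemplate("Tell me about {topic}."), topic="..."):
#         print(token, end="")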