1"""Wrapper functions around an LLM chain."""
2
3import logging4from abc import ABC, abstractmethod5from collections import ChainMap6from typing import Any, Dict, List, Optional, Union7
8from typing_extensions import Self9
10from llama_index.legacy.bridge.pydantic import BaseModel, PrivateAttr11from llama_index.legacy.callbacks.base import CallbackManager12from llama_index.legacy.callbacks.schema import CBEventType, EventPayload13from llama_index.legacy.core.llms.types import (14ChatMessage,15LLMMetadata,16MessageRole,17)
18from llama_index.legacy.llms.llm import (19LLM,20astream_chat_response_to_tokens,21astream_completion_response_to_tokens,22stream_chat_response_to_tokens,23stream_completion_response_to_tokens,24)
25from llama_index.legacy.llms.utils import LLMType, resolve_llm26from llama_index.legacy.prompts.base import BasePromptTemplate, PromptTemplate27from llama_index.legacy.schema import BaseComponent28from llama_index.legacy.types import PydanticProgramMode, TokenAsyncGen, TokenGen29
30logger = logging.getLogger(__name__)31
32
33class BaseLLMPredictor(BaseComponent, ABC):34"""Base LLM Predictor."""35
36def dict(self, **kwargs: Any) -> Dict[str, Any]:37data = super().dict(**kwargs)38data["llm"] = self.llm.to_dict()39return data40
41def to_dict(self, **kwargs: Any) -> Dict[str, Any]:42data = super().to_dict(**kwargs)43data["llm"] = self.llm.to_dict()44return data45
46@property47@abstractmethod48def llm(self) -> LLM:49"""Get LLM."""50
51@property52@abstractmethod53def callback_manager(self) -> CallbackManager:54"""Get callback manager."""55
56@property57@abstractmethod58def metadata(self) -> LLMMetadata:59"""Get LLM metadata."""60
61@abstractmethod62def predict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:63"""Predict the answer to a query."""64
65@abstractmethod66def stream(self, prompt: BasePromptTemplate, **prompt_args: Any) -> TokenGen:67"""Stream the answer to a query."""68
69@abstractmethod70async def apredict(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str:71"""Async predict the answer to a query."""72

    @abstractmethod
    async def astream(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> TokenAsyncGen:
        """Async stream the answer to a query."""

class LLMPredictor(BaseLLMPredictor):
    """LLM predictor class.

    A lightweight wrapper on top of LLMs that handles:
    - conversion of prompts to the string input format expected by LLMs
    - logging of prompts and responses to a callback manager

    NOTE: Mostly keeping around for legacy reasons. A potential future path is to
    deprecate this class and move all functionality into the LLM class.
    """

    class Config:
        arbitrary_types_allowed = True

    system_prompt: Optional[str]
    query_wrapper_prompt: Optional[BasePromptTemplate]
    pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT

    _llm: LLM = PrivateAttr()

    def __init__(
        self,
        llm: Optional[LLMType] = "default",
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        query_wrapper_prompt: Optional[BasePromptTemplate] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
    ) -> None:
        """Initialize params."""
        self._llm = resolve_llm(llm)

        if callback_manager:
            self._llm.callback_manager = callback_manager

        super().__init__(
            system_prompt=system_prompt,
            query_wrapper_prompt=query_wrapper_prompt,
            pydantic_program_mode=pydantic_program_mode,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
        if isinstance(kwargs, dict):
            data.update(kwargs)

        data.pop("class_name", None)

        llm = data.get("llm", "default")
        if llm != "default":
            from llama_index.legacy.llms.loading import load_llm

            llm = load_llm(llm)

        data["llm"] = llm
        return cls(**data)

    @classmethod
    def class_name(cls) -> str:
        return "LLMPredictor"

    @property
    def llm(self) -> LLM:
        """Get LLM."""
        return self._llm

    @property
    def callback_manager(self) -> CallbackManager:
        """Get callback manager."""
        return self._llm.callback_manager

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return self._llm.metadata

    def _log_template_data(
        self, prompt: BasePromptTemplate, **prompt_args: Any
    ) -> None:
        template_vars = {
            k: v
            for k, v in ChainMap(prompt.kwargs, prompt_args).items()
            if k in prompt.template_vars
        }
        with self.callback_manager.event(
            CBEventType.TEMPLATING,
            payload={
                EventPayload.TEMPLATE: prompt.get_template(llm=self._llm),
                EventPayload.TEMPLATE_VARS: template_vars,
                EventPayload.SYSTEM_PROMPT: self.system_prompt,
                EventPayload.QUERY_WRAPPER_PROMPT: self.query_wrapper_prompt,
            },
        ):
            pass

    def _run_program(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> str:
        from llama_index.legacy.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self._llm,
            pydantic_program_mode=self.pydantic_program_mode,
        )

        chat_response = program(**prompt_args)
        return chat_response.json()

    async def _arun_program(
        self,
        output_cls: BaseModel,
        prompt: PromptTemplate,
        **prompt_args: Any,
    ) -> str:
        from llama_index.legacy.program.utils import get_program_for_llm

        program = get_program_for_llm(
            output_cls,
            prompt,
            self._llm,
            pydantic_program_mode=self.pydantic_program_mode,
        )

        chat_response = await program.acall(**prompt_args)
        return chat_response.json()

    def predict(
        self,
        prompt: BasePromptTemplate,
        output_cls: Optional[BaseModel] = None,
        **prompt_args: Any,
    ) -> str:
        """Predict."""
        self._log_template_data(prompt, **prompt_args)

        if output_cls is not None:
            output = self._run_program(output_cls, prompt, **prompt_args)
        elif self._llm.metadata.is_chat_model:
            messages = prompt.format_messages(llm=self._llm, **prompt_args)
            messages = self._extend_messages(messages)
            chat_response = self._llm.chat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = prompt.format(llm=self._llm, **prompt_args)
            formatted_prompt = self._extend_prompt(formatted_prompt)
            response = self._llm.complete(formatted_prompt)
            output = response.text

        logger.debug(output)

        return output

    def stream(
        self,
        prompt: BasePromptTemplate,
        output_cls: Optional[BaseModel] = None,
        **prompt_args: Any,
    ) -> TokenGen:
        """Stream."""
        if output_cls is not None:
            raise NotImplementedError("Streaming with output_cls not supported.")

        self._log_template_data(prompt, **prompt_args)

        if self._llm.metadata.is_chat_model:
            messages = prompt.format_messages(llm=self._llm, **prompt_args)
            messages = self._extend_messages(messages)
            chat_response = self._llm.stream_chat(messages)
            stream_tokens = stream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = prompt.format(llm=self._llm, **prompt_args)
            formatted_prompt = self._extend_prompt(formatted_prompt)
            stream_response = self._llm.stream_complete(formatted_prompt)
            stream_tokens = stream_completion_response_to_tokens(stream_response)
        return stream_tokens

    async def apredict(
        self,
        prompt: BasePromptTemplate,
        output_cls: Optional[BaseModel] = None,
        **prompt_args: Any,
    ) -> str:
        """Async predict."""
        self._log_template_data(prompt, **prompt_args)

        if output_cls is not None:
            output = await self._arun_program(output_cls, prompt, **prompt_args)
        elif self._llm.metadata.is_chat_model:
            messages = prompt.format_messages(llm=self._llm, **prompt_args)
            messages = self._extend_messages(messages)
            chat_response = await self._llm.achat(messages)
            output = chat_response.message.content or ""
        else:
            formatted_prompt = prompt.format(llm=self._llm, **prompt_args)
            formatted_prompt = self._extend_prompt(formatted_prompt)
            response = await self._llm.acomplete(formatted_prompt)
            output = response.text

        logger.debug(output)

        return output

    async def astream(
        self,
        prompt: BasePromptTemplate,
        output_cls: Optional[BaseModel] = None,
        **prompt_args: Any,
    ) -> TokenAsyncGen:
        """Async stream."""
        if output_cls is not None:
            raise NotImplementedError("Streaming with output_cls not supported.")

        self._log_template_data(prompt, **prompt_args)

        if self._llm.metadata.is_chat_model:
            messages = prompt.format_messages(llm=self._llm, **prompt_args)
            messages = self._extend_messages(messages)
            chat_response = await self._llm.astream_chat(messages)
            stream_tokens = await astream_chat_response_to_tokens(chat_response)
        else:
            formatted_prompt = prompt.format(llm=self._llm, **prompt_args)
            formatted_prompt = self._extend_prompt(formatted_prompt)
            stream_response = await self._llm.astream_complete(formatted_prompt)
            stream_tokens = await astream_completion_response_to_tokens(stream_response)
        return stream_tokens

    def _extend_prompt(
        self,
        formatted_prompt: str,
    ) -> str:
        """Add system and query wrapper prompts to base prompt."""
        extended_prompt = formatted_prompt
        if self.system_prompt:
            extended_prompt = self.system_prompt + "\n\n" + extended_prompt

        if self.query_wrapper_prompt:
            extended_prompt = self.query_wrapper_prompt.format(
                query_str=extended_prompt
            )

        return extended_prompt

    def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
        """Add system prompt to chat message list."""
        if self.system_prompt:
            messages = [
                ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt),
                *messages,
            ]
        return messages

LLMPredictorType = Union[LLMPredictor, LLM]
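
# Illustrative sketch of how the alias can be consumed; ``as_predictor`` below
# is a hypothetical helper, not part of this module: it accepts either a bare
# LLM or an LLMPredictor and normalizes to a predictor.
#
#     def as_predictor(llm_or_predictor: LLMPredictorType) -> LLMPredictor:
#         if isinstance(llm_or_predictor, LLMPredictor):
#             return llm_or_predictor
#         return LLMPredictor(llm=llm_or_predictor)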