from typing import Any, Dict, Optional, Type, cast

from llama_index.legacy.bridge.pydantic import BaseModel
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.llms.openai import OpenAI
from llama_index.legacy.output_parsers.pydantic import PydanticOutputParser
from llama_index.legacy.prompts.base import BasePromptTemplate, PromptTemplate
from llama_index.legacy.types import BaseOutputParser, BasePydanticProgram


class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]):
    """LLM Text Completion Program.

    Uses generic LLM text completion + an output parser to generate a
    structured output.
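
    Example (a minimal sketch; ``Song`` and the prompt string below are
    illustrative, not part of this module, and the default OpenAI LLM
    requires an API key)::

        from llama_index.legacy.bridge.pydantic import BaseModel

        class Song(BaseModel):
            title: str
            length_seconds: int

        program = LLMTextCompletionProgram.from_defaults(
            output_cls=Song,
            prompt_template_str="Write a song about {topic}.",
        )
        song = program(topic="the sea")  # returns a Song instance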
17"""
18
    def __init__(
        self,
        output_parser: BaseOutputParser,
        output_cls: Type[BaseModel],
        prompt: BasePromptTemplate,
        llm: LLM,
        verbose: bool = False,
    ) -> None:
        self._output_parser = output_parser
        self._output_cls = output_cls
        self._llm = llm
        self._prompt = prompt
        self._verbose = verbose

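        # Attach the parser to the prompt so the parser's format instructions
        # can be injected when the prompt is rendered.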
        self._prompt.output_parser = output_parser

    @classmethod
    def from_defaults(
        cls,
        output_parser: Optional[BaseOutputParser] = None,
        output_cls: Optional[Type[BaseModel]] = None,
        prompt_template_str: Optional[str] = None,
        prompt: Optional[PromptTemplate] = None,
        llm: Optional[LLM] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> "LLMTextCompletionProgram":
        llm = llm or OpenAI(temperature=0, model="gpt-3.5-turbo-0613")
        if prompt is None and prompt_template_str is None:
            raise ValueError("Must provide either prompt or prompt_template_str.")
        if prompt is not None and prompt_template_str is not None:
            raise ValueError(
                "Must provide only one of prompt or prompt_template_str."
            )
        if prompt_template_str is not None:
            prompt = PromptTemplate(prompt_template_str)

        # decide default output class if not set
        if output_cls is None:
            if not isinstance(output_parser, PydanticOutputParser):
                raise ValueError("Output parser must be PydanticOutputParser.")
            output_cls = output_parser.output_cls
        elif output_parser is None:
            output_parser = PydanticOutputParser(output_cls=output_cls)

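        # At this point both output_parser and output_cls are set: either
        # output_cls was inferred from the PydanticOutputParser, or a
        # PydanticOutputParser was built around the given output_cls.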
        return cls(
            output_parser,
            output_cls,
            prompt=cast(PromptTemplate, prompt),
            llm=llm,
            verbose=verbose,
        )

    @property
    def output_cls(self) -> Type[BaseModel]:
        return self._output_cls

    @property
    def prompt(self) -> BasePromptTemplate:
        return self._prompt

    @prompt.setter
    def prompt(self, prompt: BasePromptTemplate) -> None:
        self._prompt = prompt

    def __call__(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> BaseModel:
        llm_kwargs = llm_kwargs or {}
        if self._llm.metadata.is_chat_model:
            messages = self._prompt.format_messages(llm=self._llm, **kwargs)
            response = self._llm.chat(messages, **llm_kwargs)
            raw_output = response.message.content or ""
        else:
            formatted_prompt = self._prompt.format(llm=self._llm, **kwargs)
            response = self._llm.complete(formatted_prompt, **llm_kwargs)
            raw_output = response.text

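        # Parse the raw model text into the target Pydantic type and verify
        # the parser actually returned an instance of that type.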
        output = self._output_parser.parse(raw_output)
        if not isinstance(output, self._output_cls):
            raise ValueError(
                f"Output parser returned {type(output)} but expected {self._output_cls}"
            )
        return output

    async def acall(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> BaseModel:
        llm_kwargs = llm_kwargs or {}
        if self._llm.metadata.is_chat_model:
            messages = self._prompt.format_messages(llm=self._llm, **kwargs)
            response = await self._llm.achat(messages, **llm_kwargs)
            raw_output = response.message.content or ""
        else:
            formatted_prompt = self._prompt.format(llm=self._llm, **kwargs)
            response = await self._llm.acomplete(formatted_prompt, **llm_kwargs)
            raw_output = response.text

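        # Same parse-and-validate step as the synchronous __call__ path.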
        output = self._output_parser.parse(raw_output)
        if not isinstance(output, self._output_cls):
            raise ValueError(
                f"Output parser returned {type(output)} but expected {self._output_cls}"
            )
        return output
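

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): ``Album``, the prompt
# text, and the need for an OPENAI_API_KEY are assumptions. This shows the
# explicit-parser configuration, where ``output_cls`` is inferred from the
# PydanticOutputParser rather than passed directly.
#
#     from llama_index.legacy.bridge.pydantic import BaseModel
#     from llama_index.legacy.output_parsers.pydantic import PydanticOutputParser
#
#     class Album(BaseModel):
#         name: str
#         artist: str
#
#     program = LLMTextCompletionProgram.from_defaults(
#         output_parser=PydanticOutputParser(output_cls=Album),
#         prompt_template_str="Generate an album inspired by {theme}.",
#     )
#     album = program(theme="movies")
#     print(album.name, album.artist)
# ---------------------------------------------------------------------------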