llama-index
103 lines · 3.7 KB
1import json2from typing import Any, Dict, Optional, Type, Union, cast3
4from llama_index.legacy.bridge.pydantic import BaseModel5from llama_index.legacy.llms.huggingface import HuggingFaceLLM6from llama_index.legacy.llms.llama_cpp import LlamaCPP7from llama_index.legacy.program.llm_prompt_program import BaseLLMFunctionProgram8from llama_index.legacy.prompts.base import PromptTemplate9from llama_index.legacy.prompts.lmformatenforcer_utils import (10activate_lm_format_enforcer,11build_lm_format_enforcer_function,12)
13
14
class LMFormatEnforcerPydanticProgram(BaseLLMFunctionProgram):
    """
    A lm-format-enforcer-based function that returns a pydantic model.

    In LMFormatEnforcerPydanticProgram, prompt_template_str can also have a {json_schema} parameter
    that will be automatically filled by the json_schema of output_cls.
    Note: this interface is not yet stable.
    """

    def __init__(
        self,
        output_cls: Type[BaseModel],
        prompt_template_str: str,
        llm: Optional[Union[LlamaCPP, HuggingFaceLLM]] = None,
        verbose: bool = False,
    ) -> None:
        """Initialize the program.

        Args:
            output_cls: Pydantic model class; its JSON schema is enforced
                during generation and used to parse the LLM's raw output.
            prompt_template_str: Prompt template; may contain a
                ``{json_schema}`` placeholder filled at call time.
            llm: LLM to drive. Defaults to a freshly constructed ``LlamaCPP``.
            verbose: Verbosity flag (stored; not otherwise read in this class).

        Raises:
            ImportError: If ``lm-format-enforcer`` is not installed, or if
                ``llm`` is None and ``llama-cpp-python`` is not installed.
        """
        try:
            import lmformatenforcer
        except ImportError as e:
            raise ImportError(
                "lm-format-enforcer package not found."
                "please run `pip install lm-format-enforcer`"
            ) from e

        if llm is None:
            try:
                from llama_index.legacy.llms import LlamaCPP

                llm = LlamaCPP()
            except ImportError as e:
                raise ImportError(
                    "llama.cpp package not found."
                    "please run `pip install llama-cpp-python`"
                ) from e

        self.llm = llm

        self._prompt_template_str = prompt_template_str
        self._output_cls = output_cls
        self._verbose = verbose
        # Build the token-enforcer function once at construction time; it is
        # reused on every __call__ while the enforcer context is active.
        json_schema_parser = lmformatenforcer.JsonSchemaParser(self.output_cls.schema())
        self._token_enforcer_fn = build_lm_format_enforcer_function(
            self.llm, json_schema_parser
        )

    @classmethod
    def from_defaults(
        cls,
        output_cls: Type[BaseModel],
        prompt_template_str: Optional[str] = None,
        prompt: Optional[PromptTemplate] = None,
        llm: Optional[Union["LlamaCPP", "HuggingFaceLLM"]] = None,
        **kwargs: Any,
    ) -> "BaseLLMFunctionProgram":
        """Construct from either a raw template string or a PromptTemplate.

        Exactly one of ``prompt`` / ``prompt_template_str`` must be provided.

        Raises:
            ValueError: If neither or both of ``prompt`` and
                ``prompt_template_str`` are given.
        """
        if prompt is None and prompt_template_str is None:
            raise ValueError("Must provide either prompt or prompt_template_str.")
        if prompt is not None and prompt_template_str is not None:
            # Previously this raised the same message as the "neither given"
            # branch, making the two failure modes indistinguishable.
            raise ValueError(
                "Must provide either prompt or prompt_template_str, not both."
            )
        if prompt is not None:
            prompt_template_str = prompt.template
        prompt_template_str = cast(str, prompt_template_str)
        return cls(
            output_cls,
            prompt_template_str,
            llm=llm,
            **kwargs,
        )

    @property
    def output_cls(self) -> Type[BaseModel]:
        """The pydantic model class this program produces."""
        return self._output_cls

    def __call__(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> BaseModel:
        """Format the prompt, run the LLM under format enforcement, and parse.

        Args:
            llm_kwargs: Extra keyword arguments forwarded to ``llm.complete``.
            *args: Positional format arguments for the prompt template.
            **kwargs: Keyword format arguments for the prompt template
                (``json_schema`` is supplied automatically).

        Returns:
            An instance of ``output_cls`` parsed from the LLM's raw text.
        """
        llm_kwargs = llm_kwargs or {}
        # While the format enforcer is active, any calls to the llm will have the format enforced.
        with activate_lm_format_enforcer(self.llm, self._token_enforcer_fn):
            json_schema_str = json.dumps(self.output_cls.schema())
            full_str = self._prompt_template_str.format(
                *args, **kwargs, json_schema=json_schema_str
            )
            output = self.llm.complete(full_str, **llm_kwargs)
            text = output.text
            return self.output_cls.parse_raw(text)