llama-index

import json
from typing import Any, Dict, Optional, Type, Union, cast

from llama_index.legacy.bridge.pydantic import BaseModel
from llama_index.legacy.llms.huggingface import HuggingFaceLLM
from llama_index.legacy.llms.llama_cpp import LlamaCPP
from llama_index.legacy.program.llm_prompt_program import BaseLLMFunctionProgram
from llama_index.legacy.prompts.base import PromptTemplate
from llama_index.legacy.prompts.lmformatenforcer_utils import (
    activate_lm_format_enforcer,
    build_lm_format_enforcer_function,
)


class LMFormatEnforcerPydanticProgram(BaseLLMFunctionProgram):
    """
    An lm-format-enforcer-based function that returns a pydantic model.

    prompt_template_str may contain a {json_schema} placeholder, which is
    automatically filled with the JSON schema of output_cls.
    Note: this interface is not yet stable.
    """

    def __init__(
        self,
        output_cls: Type[BaseModel],
        prompt_template_str: str,
        llm: Optional[Union[LlamaCPP, HuggingFaceLLM]] = None,
        verbose: bool = False,
    ):
        try:
            import lmformatenforcer
        except ImportError as e:
            raise ImportError(
                "lm-format-enforcer package not found. "
                "Please run `pip install lm-format-enforcer`."
            ) from e

        if llm is None:
            try:
                from llama_index.legacy.llms import LlamaCPP

                llm = LlamaCPP()
            except ImportError as e:
                raise ImportError(
                    "llama.cpp package not found. "
                    "Please run `pip install llama-cpp-python`."
                ) from e

        self.llm = llm

        self._prompt_template_str = prompt_template_str
        self._output_cls = output_cls
        self._verbose = verbose
        # Build a token-level enforcer from the output class's JSON schema;
        # it constrains generation to schema-valid JSON.
        json_schema_parser = lmformatenforcer.JsonSchemaParser(self.output_cls.schema())
        self._token_enforcer_fn = build_lm_format_enforcer_function(
            self.llm, json_schema_parser
        )

    @classmethod
    def from_defaults(
        cls,
        output_cls: Type[BaseModel],
        prompt_template_str: Optional[str] = None,
        prompt: Optional[PromptTemplate] = None,
        llm: Optional[Union["LlamaCPP", "HuggingFaceLLM"]] = None,
        **kwargs: Any,
    ) -> "BaseLLMFunctionProgram":
        """Create a program from default arguments."""
        if prompt is None and prompt_template_str is None:
            raise ValueError("Must provide either prompt or prompt_template_str.")
        if prompt is not None and prompt_template_str is not None:
            raise ValueError("Must provide only one of prompt or prompt_template_str.")
        if prompt is not None:
            prompt_template_str = prompt.template
        prompt_template_str = cast(str, prompt_template_str)
        return cls(
            output_cls,
            prompt_template_str,
            llm=llm,
            **kwargs,
        )

    @property
    def output_cls(self) -> Type[BaseModel]:
        return self._output_cls

    def __call__(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> BaseModel:
        llm_kwargs = llm_kwargs or {}
        # While the format enforcer is active, any calls to the llm will have the format enforced.
        with activate_lm_format_enforcer(self.llm, self._token_enforcer_fn):
            json_schema_str = json.dumps(self.output_cls.schema())
            full_str = self._prompt_template_str.format(
                *args, **kwargs, json_schema=json_schema_str
            )
            output = self.llm.complete(full_str, **llm_kwargs)
            text = output.text
            return self.output_cls.parse_raw(text)

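# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library module).
# It shows how the {json_schema} placeholder and the extra template variables
# are filled when the program is called. It assumes llama-cpp-python and
# lm-format-enforcer are installed and that a local GGUF model exists at the
# hypothetical path "/path/to/model.gguf"; a HuggingFaceLLM could be passed
# the same way instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from llama_index.legacy.llms.llama_cpp import LlamaCPP

    class Album(BaseModel):
        """Toy output schema for the sketch."""

        name: str
        artist: str
        year: int

    prompt_template_str = (
        "Here is a JSON schema: {json_schema}\n"
        "Generate an album by {artist_name} as a JSON object that follows "
        "the schema exactly:\n"
    )

    program = LMFormatEnforcerPydanticProgram(
        output_cls=Album,
        prompt_template_str=prompt_template_str,
        llm=LlamaCPP(model_path="/path/to/model.gguf"),  # hypothetical path
        verbose=True,
    )

    # {json_schema} is filled automatically from Album's schema; the remaining
    # template variables are supplied as keyword arguments to __call__.
    album = program(artist_name="The Beatles")
    print(album.name, album.artist, album.year)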