"""Guidance-based pydantic program (llama-index legacy)."""
1
from functools import partial
from typing import TYPE_CHECKING, Any, Optional, Type, cast

from llama_index.legacy.bridge.pydantic import BaseModel
from llama_index.legacy.program.llm_prompt_program import BaseLLMFunctionProgram
from llama_index.legacy.prompts.base import PromptTemplate
from llama_index.legacy.prompts.guidance_utils import (
    parse_pydantic_from_guidance_program,
)

if TYPE_CHECKING:
    from guidance.models import Model as GuidanceLLM


class GuidancePydanticProgram(BaseLLMFunctionProgram["GuidanceLLM"]):
    """
    A guidance-based function that returns a pydantic model.

    Fills a prompt template (with ``{{tools_str}}`` and ``{{query_str}}``
    placeholders), runs it through a guidance LLM, and parses the generated
    text back into an instance of ``output_cls``.

    Note: this interface is not yet stable.
    """

    def __init__(
        self,
        output_cls: Type[BaseModel],
        prompt_template_str: str,
        guidance_llm: Optional["GuidanceLLM"] = None,
        verbose: bool = False,
    ):
        """
        Args:
            output_cls: Pydantic model class the LLM output is parsed into.
            prompt_template_str: Prompt template containing ``{{tools_str}}``
                and ``{{query_str}}`` placeholders.
            guidance_llm: Guidance LLM to run the program with. Defaults to
                ``OpenAIChat("gpt-3.5-turbo")`` when not provided.
            verbose: If True, do not silence guidance program output.

        Raises:
            ImportError: If the ``guidance`` package is not installed.
        """
        try:
            from guidance.models import OpenAIChat
        except ImportError as e:
            raise ImportError(
                "guidance package not found. Please run `pip install guidance`"
            ) from e

        # BUG FIX: the original condition was inverted — when no LLM was
        # passed it assigned `llm = guidance_llm` (i.e. None), and it only
        # built the default OpenAIChat when a caller-supplied LLM *was*
        # present, silently discarding it. Use the caller's LLM when given,
        # otherwise fall back to the default.
        if guidance_llm is None:
            llm = OpenAIChat("gpt-3.5-turbo")
        else:
            llm = guidance_llm

        full_str = prompt_template_str + "\n"
        self._full_str = full_str
        # Pre-bind llm/silent so __call__ only needs the template kwargs.
        self._guidance_program = partial(self.program, llm=llm, silent=not verbose)
        self._output_cls = output_cls
        self._verbose = verbose

    def program(
        self,
        llm: "GuidanceLLM",
        silent: bool,
        tools_str: str,
        query_str: str,
        **kwargs: dict,
    ) -> "GuidanceLLM":
        """A wrapper to execute the program with new guidance version.

        Substitutes ``tools_str``/``query_str`` into the template, sends the
        result as a user turn, and generates the assistant reply (stopping at
        the first ``.``). Returns the extended guidance model object.
        """
        from guidance import assistant, gen, user

        given_query = self._full_str.replace("{{tools_str}}", tools_str).replace(
            "{{query_str}}", query_str
        )
        with user():
            llm = llm + given_query

        with assistant():
            llm = llm + gen(stop=".")

        return llm  # noqa: RET504

    @classmethod
    def from_defaults(
        cls,
        output_cls: Type[BaseModel],
        prompt_template_str: Optional[str] = None,
        prompt: Optional[PromptTemplate] = None,
        llm: Optional["GuidanceLLM"] = None,
        **kwargs: Any,
    ) -> "BaseLLMFunctionProgram":
        """Construct from defaults.

        Exactly one of ``prompt`` or ``prompt_template_str`` must be given.

        Raises:
            ValueError: If neither or both of ``prompt`` and
                ``prompt_template_str`` are provided.
        """
        if prompt is None and prompt_template_str is None:
            raise ValueError("Must provide either prompt or prompt_template_str.")
        if prompt is not None and prompt_template_str is not None:
            # BUG FIX (message): this branch fires when BOTH are given; the
            # original message duplicated the "either" wording above.
            raise ValueError(
                "Must provide only one of prompt or prompt_template_str."
            )
        if prompt is not None:
            prompt_template_str = prompt.template
        prompt_template_str = cast(str, prompt_template_str)
        return cls(
            output_cls,
            prompt_template_str,
            guidance_llm=llm,
            **kwargs,
        )

    @property
    def output_cls(self) -> Type[BaseModel]:
        """Pydantic model class produced by this program."""
        return self._output_cls

    def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> BaseModel:
        """Execute the guidance program and parse the output.

        Keyword args are forwarded to :meth:`program` (``tools_str``,
        ``query_str``). Positional args are accepted for interface
        compatibility but ignored.
        """
        executed_program = self._guidance_program(**kwargs)
        response = str(executed_program)

        return parse_pydantic_from_guidance_program(
            response=response, cls=self._output_cls
        )
108

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.