from typing import Any, Dict, Optional, Type, cast

from llama_index.legacy.bridge.pydantic import BaseModel
from llama_index.legacy.llms.llm import LLM
from llama_index.legacy.llms.openai import OpenAI
from llama_index.legacy.output_parsers.pydantic import PydanticOutputParser
from llama_index.legacy.prompts.base import BasePromptTemplate, PromptTemplate
from llama_index.legacy.types import BaseOutputParser, BasePydanticProgram

class LLMTextCompletionProgram(BasePydanticProgram[BaseModel]):
    """LLM Text Completion Program.

    Uses generic LLM text completion plus an output parser to generate
    a structured (pydantic) output. See the usage sketch under
    ``if __name__ == "__main__"`` at the bottom of this file.
    """

    def __init__(
        self,
        output_parser: BaseOutputParser,
        output_cls: Type[BaseModel],
        prompt: BasePromptTemplate,
        llm: LLM,
        verbose: bool = False,
    ) -> None:
        self._output_parser = output_parser
        self._output_cls = output_cls
        self._llm = llm
        self._prompt = prompt
        self._verbose = verbose

        # Attach the parser to the prompt so that formatting can include the
        # parser's format instructions in the final prompt text.
        self._prompt.output_parser = output_parser

    @classmethod
    def from_defaults(
        cls,
        output_parser: Optional[BaseOutputParser] = None,
        output_cls: Optional[Type[BaseModel]] = None,
        prompt_template_str: Optional[str] = None,
        prompt: Optional[PromptTemplate] = None,
        llm: Optional[LLM] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> "LLMTextCompletionProgram":
        llm = llm or OpenAI(temperature=0, model="gpt-3.5-turbo-0613")
        if prompt is None and prompt_template_str is None:
            raise ValueError("Must provide either prompt or prompt_template_str.")
        if prompt is not None and prompt_template_str is not None:
            raise ValueError("Must provide only one of prompt or prompt_template_str.")
        if prompt_template_str is not None:
            prompt = PromptTemplate(prompt_template_str)

        # Decide the default output class if not set: either infer it from a
        # PydanticOutputParser, or build a parser from the given output class.
        if output_cls is None:
            if not isinstance(output_parser, PydanticOutputParser):
                raise ValueError("Output parser must be PydanticOutputParser.")
            output_cls = output_parser.output_cls
        else:
            if output_parser is None:
                output_parser = PydanticOutputParser(output_cls=output_cls)

        return cls(
            output_parser,
            output_cls,
            prompt=cast(PromptTemplate, prompt),
            llm=llm,
            verbose=verbose,
        )
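    # Illustrative shorthand (not part of the original file; ``Album`` is a
    # hypothetical caller-defined pydantic model):
    #
    #   program = LLMTextCompletionProgram.from_defaults(
    #       output_cls=Album,
    #       prompt_template_str="Describe an album inspired by {topic}.",
    #   )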

    @property
    def output_cls(self) -> Type[BaseModel]:
        return self._output_cls

    @property
    def prompt(self) -> BasePromptTemplate:
        return self._prompt

    @prompt.setter
    def prompt(self, prompt: BasePromptTemplate) -> None:
        self._prompt = prompt

    def __call__(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> BaseModel:
        llm_kwargs = llm_kwargs or {}
        if self._llm.metadata.is_chat_model:
            # Chat models take a message list; kwargs fill the prompt template.
            messages = self._prompt.format_messages(llm=self._llm, **kwargs)
            response = self._llm.chat(messages, **llm_kwargs)
            raw_output = response.message.content or ""
        else:
            # Completion models take a single formatted prompt string.
            formatted_prompt = self._prompt.format(llm=self._llm, **kwargs)
            response = self._llm.complete(formatted_prompt, **llm_kwargs)
            raw_output = response.text

        # Parse the raw LLM output into the target pydantic class, and verify
        # that the parser actually produced an instance of that class.
        output = self._output_parser.parse(raw_output)
        if not isinstance(output, self._output_cls):
            raise ValueError(
                f"Output parser returned {type(output)} but expected {self._output_cls}"
            )
        return output

    async def acall(
        self,
        llm_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> BaseModel:
        # Async mirror of __call__: same branching, with awaited LLM calls.
        llm_kwargs = llm_kwargs or {}
        if self._llm.metadata.is_chat_model:
            messages = self._prompt.format_messages(llm=self._llm, **kwargs)
            response = await self._llm.achat(messages, **llm_kwargs)
            raw_output = response.message.content or ""
        else:
            formatted_prompt = self._prompt.format(llm=self._llm, **kwargs)
            response = await self._llm.acomplete(formatted_prompt, **llm_kwargs)
            raw_output = response.text

        output = self._output_parser.parse(raw_output)
        if not isinstance(output, self._output_cls):
            raise ValueError(
                f"Output parser returned {type(output)} but expected {self._output_cls}"
            )
        return output
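

if __name__ == "__main__":
    # Usage sketch, not part of the original module: builds a program via
    # ``from_defaults`` and parses the completion into a pydantic object.
    # Assumes OPENAI_API_KEY is set in the environment; ``Song``, ``Album``,
    # and the ``movie_name`` template variable are illustrative, not library
    # names.
    from typing import List

    class Song(BaseModel):
        title: str
        length_seconds: int

    class Album(BaseModel):
        name: str
        artist: str
        songs: List[Song]

    program = LLMTextCompletionProgram.from_defaults(
        output_cls=Album,
        prompt_template_str=(
            "Generate an example album, with an artist and a list of songs, "
            "inspired by the movie {movie_name}."
        ),
    )

    # Template variables are passed as keyword arguments; the result is an
    # Album instance rather than raw text. Async callers can use
    # ``await program.acall(movie_name=...)`` the same way.
    album = program(movie_name="The Shining")
    print(album)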