llama-index

from typing import Any, Dict, List, Optional, Sequence, cast

from llama_index.legacy.core.base_selector import (
    BaseSelector,
    SelectorResult,
    SingleSelection,
)
from llama_index.legacy.llm_predictor.base import LLMPredictorType
from llama_index.legacy.output_parsers.base import StructuredOutput
from llama_index.legacy.output_parsers.selection import Answer, SelectionOutputParser
from llama_index.legacy.prompts.mixin import PromptDictType
from llama_index.legacy.prompts.prompt_type import PromptType
from llama_index.legacy.schema import QueryBundle
from llama_index.legacy.selectors.prompts import (
    DEFAULT_MULTI_SELECT_PROMPT_TMPL,
    DEFAULT_SINGLE_SELECT_PROMPT_TMPL,
    MultiSelectPrompt,
    SingleSelectPrompt,
)
from llama_index.legacy.service_context import ServiceContext
from llama_index.legacy.tools.types import ToolMetadata
from llama_index.legacy.types import BaseOutputParser


def _build_choices_text(choices: Sequence[ToolMetadata]) -> str:
    """Convert sequence of metadata to enumeration text."""
    texts: List[str] = []
    for ind, choice in enumerate(choices):
        text = " ".join(choice.description.splitlines())
        text = f"({ind + 1}) {text}"  # to one indexing
        texts.append(text)
    return "\n\n".join(texts)


def _structured_output_to_selector_result(output: Any) -> SelectorResult:
    """Convert structured output to selector result."""
    structured_output = cast(StructuredOutput, output)
    answers = cast(List[Answer], structured_output.parsed_output)

    # adjust for zero indexing
    selections = [
        SingleSelection(index=answer.choice - 1, reason=answer.reason)
        for answer in answers
    ]
    return SelectorResult(selections=selections)
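
# Illustrative sketch (assumed values, not part of the original module):
# a parsed StructuredOutput(parsed_output=[Answer(choice=2, reason="best match")])
# becomes SelectorResult(selections=[SingleSelection(index=1, reason="best match")]),
# i.e. the LLM's one-indexed choice is shifted back to a zero-based index.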


class LLMSingleSelector(BaseSelector):
    """LLM single selector.

    LLM-based selector that chooses one out of many options.

    Args:
        llm (LLM): An LLM.
        prompt (SingleSelectPrompt): An LLM prompt for selecting one out of many options.
    """

    def __init__(
        self,
        llm: LLMPredictorType,
        prompt: SingleSelectPrompt,
    ) -> None:
        self._llm = llm
        self._prompt = prompt

        if self._prompt.output_parser is None:
            raise ValueError("Prompt should have output parser.")

    @classmethod
    def from_defaults(
        cls,
        service_context: Optional[ServiceContext] = None,
        prompt_template_str: Optional[str] = None,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> "LLMSingleSelector":
        # optionally initialize defaults
        service_context = service_context or ServiceContext.from_defaults()
        prompt_template_str = prompt_template_str or DEFAULT_SINGLE_SELECT_PROMPT_TMPL
        output_parser = output_parser or SelectionOutputParser()

        # construct prompt
        prompt = SingleSelectPrompt(
            template=prompt_template_str,
            output_parser=output_parser,
            prompt_type=PromptType.SINGLE_SELECT,
        )
        return cls(service_context.llm, prompt)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {"prompt": self._prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "prompt" in prompts:
            self._prompt = prompts["prompt"]

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        choices_text = _build_choices_text(choices)

        # predict
        prediction = self._llm.predict(
            prompt=self._prompt,
            num_choices=len(choices),
            context_list=choices_text,
            query_str=query.query_str,
        )

        # parse output
        assert self._prompt.output_parser is not None
        parse = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parse)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        choices_text = _build_choices_text(choices)

        # predict
        prediction = await self._llm.apredict(
            prompt=self._prompt,
            num_choices=len(choices),
            context_list=choices_text,
            query_str=query.query_str,
        )

        # parse output
        assert self._prompt.output_parser is not None
        parse = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parse)
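
# Usage sketch (not part of the original module; assumes a ServiceContext with
# a working LLM, e.g. OpenAI credentials available in the environment):
#
#   selector = LLMSingleSelector.from_defaults()
#   choices = [
#       ToolMetadata(name="vector", description="Useful for semantic search."),
#       ToolMetadata(name="keyword", description="Useful for exact keyword lookup."),
#   ]
#   result = selector.select(choices, query="Find passages similar in meaning.")
#   print(result.selections[0].index, result.selections[0].reason)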


class LLMMultiSelector(BaseSelector):
    """LLM multi selector.

    LLM-based selector that chooses multiple out of many options.

    Args:
        llm (LLM): An LLM.
        prompt (MultiSelectPrompt): An LLM prompt for selecting multiple out of many
            options.
    """

    def __init__(
        self,
        llm: LLMPredictorType,
        prompt: MultiSelectPrompt,
        max_outputs: Optional[int] = None,
    ) -> None:
        self._llm = llm
        self._prompt = prompt
        self._max_outputs = max_outputs

        if self._prompt.output_parser is None:
            raise ValueError("Prompt should have output parser.")

    @classmethod
    def from_defaults(
        cls,
        service_context: Optional[ServiceContext] = None,
        prompt_template_str: Optional[str] = None,
        output_parser: Optional[BaseOutputParser] = None,
        max_outputs: Optional[int] = None,
    ) -> "LLMMultiSelector":
        service_context = service_context or ServiceContext.from_defaults()
        prompt_template_str = prompt_template_str or DEFAULT_MULTI_SELECT_PROMPT_TMPL
        output_parser = output_parser or SelectionOutputParser()

        # add output formatting
        prompt_template_str = output_parser.format(prompt_template_str)

        # construct prompt
        prompt = MultiSelectPrompt(
            template=prompt_template_str,
            output_parser=output_parser,
            prompt_type=PromptType.MULTI_SELECT,
        )
        return cls(service_context.llm, prompt, max_outputs)
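
    # Unlike LLMSingleSelector.from_defaults, the default template here is first
    # passed through output_parser.format(...), which appends the parser's
    # formatting instructions to the prompt text before it is wrapped in a
    # MultiSelectPrompt.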

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {"prompt": self._prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "prompt" in prompts:
            self._prompt = prompts["prompt"]

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        context_list = _build_choices_text(choices)
        max_outputs = self._max_outputs or len(choices)

        prediction = self._llm.predict(
            prompt=self._prompt,
            num_choices=len(choices),
            max_outputs=max_outputs,
            context_list=context_list,
            query_str=query.query_str,
        )

        assert self._prompt.output_parser is not None
        parsed = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parsed)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        context_list = _build_choices_text(choices)
        max_outputs = self._max_outputs or len(choices)

        prediction = await self._llm.apredict(
            prompt=self._prompt,
            num_choices=len(choices),
            max_outputs=max_outputs,
            context_list=context_list,
            query_str=query.query_str,
        )

        assert self._prompt.output_parser is not None
        parsed = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parsed)
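
# Usage sketch (not part of the original module; reuses the hypothetical
# `choices` list from the LLMSingleSelector sketch above). max_outputs caps
# how many options the LLM may select:
#
#   selector = LLMMultiSelector.from_defaults(max_outputs=2)
#   result = selector.select(choices, query="Which tools could help answer this?")
#   for selection in result.selections:
#       print(selection.index, selection.reason)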