llama-index
from typing import Any, Dict, List, Optional, Sequence, cast

from llama_index.legacy.core.base_selector import (
    BaseSelector,
    SelectorResult,
    SingleSelection,
)
from llama_index.legacy.llm_predictor.base import LLMPredictorType
from llama_index.legacy.output_parsers.base import StructuredOutput
from llama_index.legacy.output_parsers.selection import Answer, SelectionOutputParser
from llama_index.legacy.prompts.mixin import PromptDictType
from llama_index.legacy.prompts.prompt_type import PromptType
from llama_index.legacy.schema import QueryBundle
from llama_index.legacy.selectors.prompts import (
    DEFAULT_MULTI_SELECT_PROMPT_TMPL,
    DEFAULT_SINGLE_SELECT_PROMPT_TMPL,
    MultiSelectPrompt,
    SingleSelectPrompt,
)
from llama_index.legacy.service_context import ServiceContext
from llama_index.legacy.tools.types import ToolMetadata
from llama_index.legacy.types import BaseOutputParser


def _build_choices_text(choices: Sequence[ToolMetadata]) -> str:
    """Convert sequence of metadata to enumeration text."""
    texts: List[str] = []
    for ind, choice in enumerate(choices):
        text = " ".join(choice.description.splitlines())
        text = f"({ind + 1}) {text}"  # to one indexing
        texts.append(text)
    return "\n\n".join(texts)
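
# Illustrative note (not part of the original module): for two hypothetical
# choices whose descriptions are "search the index" and "summarize the docs",
# _build_choices_text produces the one-indexed, blank-line-separated text
#
#     (1) search the index
#
#     (2) summarize the docs
#
# which is the enumeration format the selection prompt shows to the LLM.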


def _structured_output_to_selector_result(output: Any) -> SelectorResult:
    """Convert structured output to selector result."""
    structured_output = cast(StructuredOutput, output)
    answers = cast(List[Answer], structured_output.parsed_output)

    # adjust for zero indexing
    selections = [
        SingleSelection(index=answer.choice - 1, reason=answer.reason)
        for answer in answers
    ]
    return SelectorResult(selections=selections)
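
# Illustrative note (not part of the original module): a parsed
# Answer(choice=1, reason="...") from the output parser becomes
# SingleSelection(index=0, reason="..."); the LLM answers with one-based
# choice numbers while the selector result uses zero-based indices.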


class LLMSingleSelector(BaseSelector):
    """LLM single selector.

    LLM-based selector that chooses one out of many options.

    Args:
        llm (LLM): An LLM.
        prompt (SingleSelectPrompt): An LLM prompt for selecting one out of many
            options.
    """

    def __init__(
        self,
        llm: LLMPredictorType,
        prompt: SingleSelectPrompt,
    ) -> None:
        self._llm = llm
        self._prompt = prompt

        if self._prompt.output_parser is None:
            raise ValueError("Prompt should have output parser.")

    @classmethod
    def from_defaults(
        cls,
        service_context: Optional[ServiceContext] = None,
        prompt_template_str: Optional[str] = None,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> "LLMSingleSelector":
        # optionally initialize defaults
        service_context = service_context or ServiceContext.from_defaults()
        prompt_template_str = prompt_template_str or DEFAULT_SINGLE_SELECT_PROMPT_TMPL
        output_parser = output_parser or SelectionOutputParser()

        # construct prompt
        prompt = SingleSelectPrompt(
            template=prompt_template_str,
            output_parser=output_parser,
            prompt_type=PromptType.SINGLE_SELECT,
        )
        return cls(service_context.llm, prompt)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {"prompt": self._prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "prompt" in prompts:
            self._prompt = prompts["prompt"]

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        choices_text = _build_choices_text(choices)

        # predict
        prediction = self._llm.predict(
            prompt=self._prompt,
            num_choices=len(choices),
            context_list=choices_text,
            query_str=query.query_str,
        )

        # parse output
        assert self._prompt.output_parser is not None
        parse = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parse)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        choices_text = _build_choices_text(choices)

        # predict
        prediction = await self._llm.apredict(
            prompt=self._prompt,
            num_choices=len(choices),
            context_list=choices_text,
            query_str=query.query_str,
        )

        # parse output
        assert self._prompt.output_parser is not None
        parse = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parse)
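
# Usage sketch (illustrative, not part of the original module). It assumes a
# configured ServiceContext/LLM (e.g. an OpenAI API key in the environment) and
# uses the public `select` wrapper inherited from BaseSelector; the tool names,
# descriptions, and query below are made up for the example.
#
#     selector = LLMSingleSelector.from_defaults()
#     choices = [
#         ToolMetadata(name="vector", description="Useful for semantic search."),
#         ToolMetadata(name="summary", description="Useful for summarization."),
#     ]
#     result = selector.select(choices, query="Summarize the whole document.")
#     chosen = result.selections[0]  # SingleSelection(index=..., reason=...)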


class LLMMultiSelector(BaseSelector):
    """LLM multi selector.

    LLM-based selector that chooses multiple out of many options.

    Args:
        llm (LLM): An LLM.
        prompt (MultiSelectPrompt): An LLM prompt for selecting multiple out of many
            options.
    """

    def __init__(
        self,
        llm: LLMPredictorType,
        prompt: MultiSelectPrompt,
        max_outputs: Optional[int] = None,
    ) -> None:
        self._llm = llm
        self._prompt = prompt
        self._max_outputs = max_outputs

        if self._prompt.output_parser is None:
            raise ValueError("Prompt should have output parser.")

    @classmethod
    def from_defaults(
        cls,
        service_context: Optional[ServiceContext] = None,
        prompt_template_str: Optional[str] = None,
        output_parser: Optional[BaseOutputParser] = None,
        max_outputs: Optional[int] = None,
    ) -> "LLMMultiSelector":
        service_context = service_context or ServiceContext.from_defaults()
        prompt_template_str = prompt_template_str or DEFAULT_MULTI_SELECT_PROMPT_TMPL
        output_parser = output_parser or SelectionOutputParser()

        # add output formatting
        prompt_template_str = output_parser.format(prompt_template_str)

        # construct prompt
        prompt = MultiSelectPrompt(
            template=prompt_template_str,
            output_parser=output_parser,
            prompt_type=PromptType.MULTI_SELECT,
        )
        return cls(service_context.llm, prompt, max_outputs)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        return {"prompt": self._prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "prompt" in prompts:
            self._prompt = prompts["prompt"]

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        context_list = _build_choices_text(choices)
        max_outputs = self._max_outputs or len(choices)

        prediction = self._llm.predict(
            prompt=self._prompt,
            num_choices=len(choices),
            max_outputs=max_outputs,
            context_list=context_list,
            query_str=query.query_str,
        )

        assert self._prompt.output_parser is not None
        parsed = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parsed)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        # prepare input
        context_list = _build_choices_text(choices)
        max_outputs = self._max_outputs or len(choices)

        prediction = await self._llm.apredict(
            prompt=self._prompt,
            num_choices=len(choices),
            max_outputs=max_outputs,
            context_list=context_list,
            query_str=query.query_str,
        )

        assert self._prompt.output_parser is not None
        parsed = self._prompt.output_parser.parse(prediction)
        return _structured_output_to_selector_result(parsed)
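
# Usage sketch (illustrative, not part of the original module). Same assumptions
# as the single-selector example above, reusing its hypothetical `choices` list;
# `max_outputs` caps how many choices the LLM may return, and the resulting
# SelectorResult can carry several selections.
#
#     selector = LLMMultiSelector.from_defaults(max_outputs=2)
#     result = selector.select(choices, query="Compare the sources, then summarize.")
#     for selection in result.selections:
#         print(selection.index, selection.reason)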