llama-index
147 строк · 4.8 Кб
1from typing import Any, Dict, Optional, Sequence
2
3from llama_index.legacy.core.base_selector import (
4BaseSelector,
5MultiSelection,
6SelectorResult,
7SingleSelection,
8)
9from llama_index.legacy.llms.openai import OpenAI
10from llama_index.legacy.program.openai_program import OpenAIPydanticProgram
11from llama_index.legacy.prompts.mixin import PromptDictType
12from llama_index.legacy.schema import QueryBundle
13from llama_index.legacy.selectors.llm_selectors import _build_choices_text
14from llama_index.legacy.selectors.prompts import (
15DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL,
16DEFAULT_SINGLE_PYD_SELECT_PROMPT_TMPL,
17)
18from llama_index.legacy.tools.types import ToolMetadata
19from llama_index.legacy.types import BasePydanticProgram
20
21
def _pydantic_output_to_selector_result(output: Any) -> SelectorResult:
    """
    Convert the output of a selector pydantic program into a ``SelectorResult``.

    Every answer index is shifted down by one so that it becomes a zero-based
    index into the list of choices (the program presumably answers with
    1-based numbering as rendered in the prompt -- confirm against
    ``_build_choices_text``).

    Note: the selection objects inside ``output`` are modified in place, and
    those same objects are placed in the returned result.

    Args:
        output: A ``SingleSelection`` or ``MultiSelection`` produced by a
            selector program.

    Returns:
        A ``SelectorResult`` wrapping the re-indexed selection(s).

    Raises:
        ValueError: If ``output`` is neither a ``SingleSelection`` nor a
            ``MultiSelection``.
    """
    if isinstance(output, SingleSelection):
        output.index -= 1
        return SelectorResult(selections=[output])

    if isinstance(output, MultiSelection):
        for selection in output.selections:
            selection.index -= 1
        return SelectorResult(selections=output.selections)

    raise ValueError(f"Unsupported output type: {type(output)}")
36
37
class PydanticSingleSelector(BaseSelector):
    """
    Selector that picks exactly one choice by running a pydantic program.

    The program is called with the number of choices, the rendered choice
    text and the query string, and its output is converted with
    ``_pydantic_output_to_selector_result`` (expected to be a
    ``SingleSelection``).
    """

    def __init__(self, selector_program: BasePydanticProgram) -> None:
        """
        Args:
            selector_program: Program used to make the selection; its output
                should be a ``SingleSelection``.
        """
        self._selector_program = selector_program

    @classmethod
    def from_defaults(
        cls,
        program: Optional[BasePydanticProgram] = None,
        llm: Optional[OpenAI] = None,
        prompt_template_str: str = DEFAULT_SINGLE_PYD_SELECT_PROMPT_TMPL,
        verbose: bool = False,
    ) -> "PydanticSingleSelector":
        """
        Build a selector, creating a default program when none is given.

        Args:
            program: Program to use. If None, an ``OpenAIPydanticProgram`` with
                ``output_cls=SingleSelection`` is created.
            llm: LLM for the default program (only used when ``program`` is None).
            prompt_template_str: Prompt template for the default program (only
                used when ``program`` is None).
            verbose: Passed to the default program (only used when ``program``
                is None).

        Returns:
            A new ``PydanticSingleSelector``.
        """
        if program is None:
            program = OpenAIPydanticProgram.from_defaults(
                output_cls=SingleSelection,
                prompt_template_str=prompt_template_str,
                llm=llm,
                verbose=verbose,
            )

        return cls(selector_program=program)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        # TODO: no accessible prompts for a base pydantic program
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts (no-op: there are no accessible prompts to update)."""

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """
        Ask the program to pick one of ``choices`` for ``query``.

        Returns:
            A ``SelectorResult`` whose selection index is zero-based.
        """
        choices_text = _build_choices_text(choices)

        prediction = self._selector_program(
            num_choices=len(choices),
            context_list=choices_text,
            query_str=query.query_str,
        )

        return _pydantic_output_to_selector_result(prediction)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """
        Async entry point.

        Falls back to the synchronous ``_select`` (the same approach
        ``PydanticMultiSelector`` uses), so the program call blocks the event
        loop while it runs.
        """
        return self._select(choices, query)
90
91
class PydanticMultiSelector(BaseSelector):
    """
    Selector that picks one or more choices by running a pydantic program.

    The program is called with the number of choices, a maximum number of
    selections, the rendered choice text and the query string. Its output is
    converted with ``_pydantic_output_to_selector_result`` (expected to be a
    ``MultiSelection``).
    """

    def __init__(
        self, selector_program: BasePydanticProgram, max_outputs: Optional[int] = None
    ) -> None:
        """
        Args:
            selector_program: Program used to make the selection; its output
                should be a ``MultiSelection``.
            max_outputs: Maximum number of selections to return. None (or 0)
                means "up to the number of choices".
        """
        self._selector_program = selector_program
        self._max_outputs = max_outputs

    @classmethod
    def from_defaults(
        cls,
        program: Optional[BasePydanticProgram] = None,
        llm: Optional[OpenAI] = None,
        prompt_template_str: str = DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL,
        max_outputs: Optional[int] = None,
        verbose: bool = False,
    ) -> "PydanticMultiSelector":
        """
        Build a selector, creating a default program when none is given.

        Args:
            program: Program to use. If None, an ``OpenAIPydanticProgram`` with
                ``output_cls=MultiSelection`` is created.
            llm: LLM for the default program (only used when ``program`` is None).
            prompt_template_str: Prompt template for the default program (only
                used when ``program`` is None).
            max_outputs: Maximum number of selections to return.
            verbose: Passed to the default program (only used when ``program``
                is None).

        Returns:
            A new ``PydanticMultiSelector``.
        """
        if program is None:
            program = OpenAIPydanticProgram.from_defaults(
                output_cls=MultiSelection,
                prompt_template_str=prompt_template_str,
                llm=llm,
                verbose=verbose,
            )

        return cls(selector_program=program, max_outputs=max_outputs)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        # TODO: no accessible prompts for a base pydantic program
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts (no-op: there are no accessible prompts to update)."""

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """
        Ask the program to pick up to ``max_outputs`` of ``choices`` for ``query``.

        The limit is passed to the program and is also enforced on the parsed
        result. Nothing else here guarantees that the program respects it.

        Returns:
            A ``SelectorResult`` with at most ``max_outputs`` zero-based
            selections.
        """
        choices_text = _build_choices_text(choices)
        # `or` (rather than `is None`) keeps treating 0 as "no explicit limit".
        max_outputs = self._max_outputs or len(choices)

        prediction = self._selector_program(
            num_choices=len(choices),
            max_outputs=max_outputs,
            context_list=choices_text,
            query_str=query.query_str,
        )

        result = _pydantic_output_to_selector_result(prediction)
        # The limit is only a prompt variable; cap the output so the
        # configured maximum is actually honoured.
        if len(result.selections) > max_outputs:
            return SelectorResult(selections=result.selections[:max_outputs])
        return result

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """
        Async entry point.

        Delegates to the synchronous ``_select``, so the program call blocks
        the event loop while it runs.
        """
        return self._select(choices, query)
148