llama-index

Форк
0
147 строк · 4.8 Кб
1
from typing import Any, Dict, Optional, Sequence
2

3
from llama_index.legacy.core.base_selector import (
4
    BaseSelector,
5
    MultiSelection,
6
    SelectorResult,
7
    SingleSelection,
8
)
9
from llama_index.legacy.llms.openai import OpenAI
10
from llama_index.legacy.program.openai_program import OpenAIPydanticProgram
11
from llama_index.legacy.prompts.mixin import PromptDictType
12
from llama_index.legacy.schema import QueryBundle
13
from llama_index.legacy.selectors.llm_selectors import _build_choices_text
14
from llama_index.legacy.selectors.prompts import (
15
    DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL,
16
    DEFAULT_SINGLE_PYD_SELECT_PROMPT_TMPL,
17
)
18
from llama_index.legacy.tools.types import ToolMetadata
19
from llama_index.legacy.types import BasePydanticProgram
20

21

22
def _pydantic_output_to_selector_result(output: Any) -> SelectorResult:
    """
    Convert a selector program's pydantic output into a ``SelectorResult``.

    Answer indices produced by the program are treated as 1-based and are
    shifted down by one so they become zero-based positions into the choices.

    NOTE: the shift is applied in place on ``output`` (and, for a
    ``MultiSelection``, on each of its selections), and those same objects are
    placed in the returned result. Calling this twice on the same object would
    shift the indices twice.

    Args:
        output: Object returned by the selector program; expected to be a
            ``SingleSelection`` or a ``MultiSelection``.

    Returns:
        A ``SelectorResult`` whose selections carry zero-based indices.

    Raises:
        ValueError: If ``output`` is neither a ``SingleSelection`` nor a
            ``MultiSelection``.
    """
    if isinstance(output, SingleSelection):
        output.index -= 1
        return SelectorResult(selections=[output])

    if isinstance(output, MultiSelection):
        # Iterate the selection objects directly; each one is adjusted in place.
        for selection in output.selections:
            selection.index -= 1
        return SelectorResult(selections=output.selections)

    raise ValueError(f"Unsupported output type: {type(output)}")
36

37

38
class PydanticSingleSelector(BaseSelector):
    """
    Selector that picks exactly one choice using a pydantic program.

    The wrapped program is called with the rendered choices and the query and
    is expected to return a ``SingleSelection`` whose index is 1-based; the
    index is converted to a zero-based position before being returned.
    """

    def __init__(self, selector_program: BasePydanticProgram) -> None:
        """
        Args:
            selector_program: Program invoked with the keyword arguments
                ``num_choices``, ``context_list`` and ``query_str``.
        """
        self._selector_program = selector_program

    @classmethod
    def from_defaults(
        cls,
        program: Optional[BasePydanticProgram] = None,
        llm: Optional[OpenAI] = None,
        prompt_template_str: str = DEFAULT_SINGLE_PYD_SELECT_PROMPT_TMPL,
        verbose: bool = False,
    ) -> "PydanticSingleSelector":
        """
        Build a selector, creating a default OpenAI pydantic program if needed.

        Args:
            program: Pre-built selector program. When provided, ``llm``,
                ``prompt_template_str`` and ``verbose`` are ignored.
            llm: LLM forwarded to ``OpenAIPydanticProgram.from_defaults``.
            prompt_template_str: Prompt template for the default program.
            verbose: Verbosity flag for the default program.

        Returns:
            A new ``PydanticSingleSelector``.
        """
        if program is None:
            program = OpenAIPydanticProgram.from_defaults(
                output_cls=SingleSelection,
                prompt_template_str=prompt_template_str,
                llm=llm,
                verbose=verbose,
            )

        return cls(selector_program=program)

    def _get_prompts(self) -> Dict[str, Any]:
        """Get prompts."""
        # TODO: no accessible prompts for a base pydantic program
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        # No-op: this selector exposes no prompts to update.

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """
        Run the selector program once and convert its answer to a result.

        NOTE(review): the returned index is shifted by -1 but never checked
        against ``len(choices)``; an answer of 0 would become -1.
        """
        # prepare input
        choices_text = _build_choices_text(choices)

        # predict
        prediction = self._selector_program(
            num_choices=len(choices),
            context_list=choices_text,
            query_str=query.query_str,
        )

        # parse output
        return _pydantic_output_to_selector_result(prediction)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """
        Async entry point; runs the synchronous selection.

        This matches ``PydanticMultiSelector._aselect``. Previously this raised
        ``NotImplementedError``.
        NOTE: the program call inside ``_select`` is synchronous, so this
        blocks the event loop while it runs.
        """
        return self._select(choices, query)
90

91

92
class PydanticMultiSelector(BaseSelector):
    """
    Selector that may pick several choices using a pydantic program.

    The wrapped program receives the rendered choices, the query and an upper
    bound on how many selections to make. It is expected to return a
    ``MultiSelection`` with 1-based indices, which are converted to zero-based
    positions before being returned.
    """

    def __init__(
        self, selector_program: BasePydanticProgram, max_outputs: Optional[int] = None
    ) -> None:
        """
        Args:
            selector_program: Program invoked with the keyword arguments
                ``num_choices``, ``max_outputs``, ``context_list`` and
                ``query_str``.
            max_outputs: Upper bound on selections passed to the program. A
                falsy value (``None`` or ``0``) means "up to every choice".
        """
        self._selector_program = selector_program
        self._max_outputs = max_outputs

    @classmethod
    def from_defaults(
        cls,
        program: Optional[BasePydanticProgram] = None,
        llm: Optional[OpenAI] = None,
        prompt_template_str: str = DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL,
        max_outputs: Optional[int] = None,
        verbose: bool = False,
    ) -> "PydanticMultiSelector":
        """
        Build a selector, creating a default OpenAI pydantic program if needed.

        Args:
            program: Pre-built selector program. When provided, ``llm``,
                ``prompt_template_str`` and ``verbose`` are ignored.
            llm: LLM forwarded to ``OpenAIPydanticProgram.from_defaults``.
            prompt_template_str: Prompt template for the default program.
            max_outputs: Upper bound on selections (see ``__init__``).
            verbose: Verbosity flag for the default program.

        Returns:
            A new ``PydanticMultiSelector``.
        """
        if program is not None:
            return cls(selector_program=program, max_outputs=max_outputs)

        default_program = OpenAIPydanticProgram.from_defaults(
            output_cls=MultiSelection,
            prompt_template_str=prompt_template_str,
            llm=llm,
            verbose=verbose,
        )
        return cls(selector_program=default_program, max_outputs=max_outputs)

    def _get_prompts(self) -> Dict[str, Any]:
        """Return the selector's prompts; none are exposed."""
        # TODO: a base pydantic program offers no accessible prompts.
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Accept prompt updates; this selector has nothing to update."""
        return None

    def _select(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """Run the selector program once and convert its answer to a result."""
        choice_count = len(choices)
        # A falsy limit (None or 0) falls back to allowing every choice.
        limit = self._max_outputs if self._max_outputs else choice_count

        program_kwargs: Dict[str, Any] = {
            "num_choices": choice_count,
            "max_outputs": limit,
            "context_list": _build_choices_text(choices),
            "query_str": query.query_str,
        }
        prediction = self._selector_program(**program_kwargs)

        return _pydantic_output_to_selector_result(prediction)

    async def _aselect(
        self, choices: Sequence[ToolMetadata], query: QueryBundle
    ) -> SelectorResult:
        """Async entry point that simply runs the synchronous selection."""
        # NOTE: the program call is synchronous, so this blocks the event loop.
        selection = self._select(choices, query)
        return selection
148

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.