llama-index

Форк
0
1
"""Router components."""
2

3
from typing import Any, Dict, List
4

5
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
6
from llama_index.legacy.callbacks.base import CallbackManager
7
from llama_index.legacy.core.base_selector import BaseSelector
8
from llama_index.legacy.core.query_pipeline.query_component import (
9
    QUERY_COMPONENT_TYPE,
10
    ChainableMixin,
11
    InputKeys,
12
    OutputKeys,
13
    QueryComponent,
14
    validate_and_convert_stringable,
15
)
16
from llama_index.legacy.utils import print_text
17

18

19
class SelectorComponent(QueryComponent):
    """Selector component.

    Wraps a selector as a query component: given ``choices`` and a ``query``
    it calls the selector and returns the resulting selections under the
    ``"output"`` key.
    """

    selector: BaseSelector = Field(..., description="Selector")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager.

        Intentionally a no-op: this component does not store or forward the
        callback manager.
        """

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component.

        Args:
            input: Raw inputs; must contain ``"choices"`` (a list) and
                ``"query"``.

        Returns:
            The same dict, with every choice and the query passed through
            ``validate_and_convert_stringable``.

        Raises:
            ValueError: If ``"choices"`` or ``"query"`` is missing, or if
                ``"choices"`` is not a list.
        """
        if "choices" not in input:
            raise ValueError("Input must have key 'choices'")
        if not isinstance(input["choices"], list):
            raise ValueError("Input choices must be a list")

        # Build a new list instead of overwriting elements in place, so the
        # list object owned by the caller is never modified (not even
        # partially, should one of the conversions raise).
        input["choices"] = [
            validate_and_convert_stringable(choice) for choice in input["choices"]
        ]

        # make sure `query` is stringable
        if "query" not in input:
            raise ValueError("Input must have key 'query'")
        input["query"] = validate_and_convert_stringable(input["query"])

        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component.

        Expects ``choices`` and ``query`` keyword arguments and returns
        ``{"output": <selections>}``.
        """
        output = self.selector.select(kwargs["choices"], kwargs["query"])
        return {"output": output.selections}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async).

        Awaits the selector's ``aselect`` rather than delegating to the
        synchronous ``_run_component``, so the coroutine does not run a
        blocking selection call.
        """
        output = await self.selector.aselect(kwargs["choices"], kwargs["query"])
        return {"output": output.selections}

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        return InputKeys.from_keys({"choices", "query"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        return OutputKeys.from_keys({"output"})
67

68

69
class RouterComponent(QueryComponent):
    """Router Component.

    Routes queries to different query components based on a selector.

    Assumes a single query component is selected.

    """

    selector: BaseSelector = Field(..., description="Selector")
    choices: List[str] = Field(
        ..., description="Choices (must correspond to components)"
    )
    components: List[QueryComponent] = Field(
        ..., description="Components (must correspond to choices)"
    )
    verbose: bool = Field(default=False, description="Verbose")

    # For each component (same order as `components`), the name of the single
    # input key that receives the query when that component is run.
    _query_keys: List[str] = PrivateAttr()

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        selector: BaseSelector,
        choices: List[str],
        components: List[QUERY_COMPONENT_TYPE],
        verbose: bool = False,
    ) -> None:
        """Init.

        Args:
            selector: Selector used to pick one of ``choices`` for a query.
            choices: Choices handed to the selector; ``choices[i]`` describes
                ``components[i]``.
            components: Components to route to. ``ChainableMixin`` instances
                are converted with ``as_query_component()``.
            verbose: Whether to print the selection made on each run.

        Raises:
            ValueError: If a component does not have exactly one entry in
                ``free_req_input_keys``, or if the number of choices differs
                from the number of components.
        """
        new_components = []
        query_keys = []
        for component in components:
            if isinstance(component, ChainableMixin):
                new_component = component.as_query_component()
            else:
                new_component = component

            # validate component has one input key
            if len(new_component.free_req_input_keys) != 1:
                raise ValueError("Expected one required input key")
            query_keys.append(next(iter(new_component.free_req_input_keys)))
            new_components.append(new_component)

        super().__init__(
            selector=selector,
            choices=choices,
            components=new_components,
            verbose=verbose,
        )

        # Private attributes are assigned once the model itself has been
        # initialised.
        self._query_keys = query_keys

        # The selected index is used to look up both the component and its
        # query key, so the two lists must line up one-to-one.
        if len(self.choices) != len(self.components):
            raise ValueError(
                "Expected the same number of choices and components, got "
                f"{len(self.choices)} choices and "
                f"{len(self.components)} components"
            )

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Set callback manager on every routed component."""
        for component in self.components:
            component.set_callback_manager(callback_manager)

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component.

        Raises:
            ValueError: If ``"query"`` is missing from ``input``.
        """
        # make sure `query` is stringable
        if "query" not in input:
            raise ValueError("Input must have key 'query'")
        input["query"] = validate_and_convert_stringable(input["query"])

        return input

    def validate_component_outputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component outputs during run_component.

        Returns the outputs unchanged: they come from whichever component was
        selected, and this router declares no output keys of its own.
        (The parameter is named ``input`` but holds the output dict.)
        """
        return input

    def _validate_component_outputs(self, output: Dict[str, Any]) -> Dict[str, Any]:
        """Not supported.

        This class overrides ``validate_component_outputs`` to return outputs
        as-is and never calls this hook itself.
        """
        raise NotImplementedError

    def _run_component(self, **kwargs: Any) -> Any:
        """Run component.

        Selects one choice for ``kwargs["query"]`` and runs the matching
        component with the query bound to that component's input key.

        Raises:
            ValueError: If the selector does not return exactly one selection.
        """
        sel_output = self.selector.select(self.choices, kwargs["query"])
        # assume one selection
        if len(sel_output.selections) != 1:
            raise ValueError("Expected one selection")
        component = self.components[sel_output.ind]
        if self.verbose:
            log_str = f"Selecting component {sel_output.ind}: " f"{sel_output.reason}."
            print_text(log_str + "\n", color="pink")
        # run with input_keys of component
        return component.run_component(
            **{self._query_keys[sel_output.ind]: kwargs["query"]}
        )

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Run component (async).

        Async counterpart of ``_run_component``: awaits the selector's
        ``aselect`` and the selected component's ``arun_component``.

        Raises:
            ValueError: If the selector does not return exactly one selection.
        """
        sel_output = await self.selector.aselect(self.choices, kwargs["query"])
        # assume one selection
        if len(sel_output.selections) != 1:
            raise ValueError("Expected one selection")
        component = self.components[sel_output.ind]
        if self.verbose:
            log_str = f"Selecting component {sel_output.ind}: " f"{sel_output.reason}."
            print_text(log_str + "\n", color="pink")
        # run with input_keys of component
        return await component.arun_component(
            **{self._query_keys[sel_output.ind]: kwargs["query"]}
        )

    @property
    def input_keys(self) -> InputKeys:
        """Input keys."""
        return InputKeys.from_keys({"query"})

    @property
    def output_keys(self) -> OutputKeys:
        """Output keys."""
        # not used
        return OutputKeys.from_keys(set())

    @property
    def sub_query_components(self) -> List["QueryComponent"]:
        """Get sub query components.

        Certain query components may have sub query components, e.g. a
        query pipeline will have sub query components, and so will
        an IfElseComponent.

        """
        return self.components
198

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.