llama-index

Форк
0
330 строк · 11.7 Кб
1
"""Base retriever."""
2

3
from abc import abstractmethod
4
from typing import Any, Dict, List, Optional
5

6
from llama_index.legacy.bridge.pydantic import Field
7
from llama_index.legacy.callbacks.base import CallbackManager
8
from llama_index.legacy.callbacks.schema import CBEventType, EventPayload
9
from llama_index.legacy.core.base_query_engine import BaseQueryEngine
10
from llama_index.legacy.core.query_pipeline.query_component import (
11
    ChainableMixin,
12
    InputKeys,
13
    OutputKeys,
14
    QueryComponent,
15
    validate_and_convert_stringable,
16
)
17
from llama_index.legacy.prompts.mixin import (
18
    PromptDictType,
19
    PromptMixin,
20
    PromptMixinType,
21
)
22
from llama_index.legacy.schema import (
23
    BaseNode,
24
    IndexNode,
25
    NodeWithScore,
26
    QueryBundle,
27
    QueryType,
28
    TextNode,
29
)
30
from llama_index.legacy.service_context import ServiceContext
31
from llama_index.legacy.utils import print_text
32

33

34
class BaseRetriever(ChainableMixin, PromptMixin):
    """Base retriever.

    Subclasses implement ``_retrieve`` (and optionally ``_aretrieve``). The
    public ``retrieve`` / ``aretrieve`` entry points wrap those hooks in
    callback events, then expand every retrieved ``IndexNode`` whose target
    object can be resolved into the nodes produced by that object, and
    finally drop duplicate nodes (compared by ``node.hash``).
    """

    def __init__(
        self,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[Dict] = None,
        objects: Optional[List[IndexNode]] = None,
        verbose: bool = False,
    ) -> None:
        """Initialize the retriever.

        Args:
            callback_manager: Callback manager to report events to; a fresh
                ``CallbackManager`` is created when omitted.
            object_map: Mapping from ``index_id`` to the object an
                ``IndexNode`` with that id refers to.
            objects: Index nodes used to build the mapping. When given, the
                mapping built from them replaces ``object_map`` entirely.
            verbose: Whether to print progress while recursing into objects.

        """
        self.callback_manager = callback_manager or CallbackManager()

        if objects is not None:
            object_map = {obj.index_id: obj.obj for obj in objects}

        self.object_map = object_map or {}
        self._verbose = verbose

    def _check_callback_manager(self) -> None:
        """Ensure ``self.callback_manager`` exists, creating one if missing."""
        if not hasattr(self, "callback_manager"):
            self.callback_manager = CallbackManager()

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {}

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt modules."""
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""

    @staticmethod
    def _component_kwargs(obj: Any, query_bundle: QueryBundle) -> Dict[str, Any]:
        """Build the keyword arguments used to run a query component.

        Shared by the sync and async object-retrieval paths so both validate
        the component's required input keys identically.

        Returns:
            An empty dict when the component requires no input, otherwise a
            single-entry dict mapping its only required key to the query
            string.

        Raises:
            ValueError: If the component requires more than one input key.

        """
        component_keys = obj.input_keys.required_keys
        if len(component_keys) > 1:
            raise ValueError(
                f"QueryComponent {obj} has more than one input key: {component_keys}"
            )
        if len(component_keys) == 0:
            return {}
        return {next(iter(component_keys)): query_bundle.query_str}

    @staticmethod
    def _dedupe_nodes(nodes: List[NodeWithScore]) -> List[NodeWithScore]:
        """Return ``nodes`` without duplicates, keeping first occurrences.

        Two entries are duplicates when their ``node.hash`` values are equal.
        """
        seen: set = set()
        unique: List[NodeWithScore] = []
        for n in nodes:
            node_hash = n.node.hash
            if node_hash not in seen:
                seen.add(node_hash)
                unique.append(n)
        return unique

    def _retrieve_from_object(
        self,
        obj: Any,
        query_bundle: QueryBundle,
        score: float,
    ) -> List[NodeWithScore]:
        """Retrieve nodes from object.

        Dispatches on the type of ``obj``:

        * ``NodeWithScore``: returned as-is (``score`` is not applied).
        * ``BaseNode``: wrapped with ``score``.
        * ``BaseQueryEngine``: queried; the stringified response becomes a
          ``TextNode`` carrying the response metadata.
        * ``BaseRetriever``: its ``retrieve`` result is returned.
        * ``QueryComponent``: run with the query string; the first output
          value is stringified into a ``TextNode``.

        Raises:
            ValueError: If ``obj`` is none of the above, or is a query
                component that requires more than one input key.

        """
        if self._verbose:
            print_text(
                f"Retrieving from object {obj.__class__.__name__} with query {query_bundle.query_str}\n",
                color="llama_pink",
            )

        if isinstance(obj, NodeWithScore):
            return [obj]
        elif isinstance(obj, BaseNode):
            return [NodeWithScore(node=obj, score=score)]
        elif isinstance(obj, BaseQueryEngine):
            response = obj.query(query_bundle)
            return [
                NodeWithScore(
                    node=TextNode(text=str(response), metadata=response.metadata or {}),
                    score=score,
                )
            ]
        elif isinstance(obj, BaseRetriever):
            return obj.retrieve(query_bundle)
        elif isinstance(obj, QueryComponent):
            kwargs = self._component_kwargs(obj, query_bundle)
            component_response = obj.run_component(**kwargs)

            result_output = str(next(iter(component_response.values())))
            return [NodeWithScore(node=TextNode(text=result_output), score=score)]
        else:
            raise ValueError(f"Object {obj} is not retrievable.")

    async def _aretrieve_from_object(
        self,
        obj: Any,
        query_bundle: QueryBundle,
        score: float,
    ) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes from object.

        Async counterpart of ``_retrieve_from_object`` with the same
        dispatch rules, using ``aquery`` / ``aretrieve`` / ``arun_component``
        on the target object.

        Raises:
            ValueError: If ``obj`` is not a supported type, or is a query
                component that requires more than one input key.

        """
        # Emit the same verbose trace as the sync path.
        if self._verbose:
            print_text(
                f"Retrieving from object {obj.__class__.__name__} with query {query_bundle.query_str}\n",
                color="llama_pink",
            )

        if isinstance(obj, NodeWithScore):
            return [obj]
        elif isinstance(obj, BaseNode):
            return [NodeWithScore(node=obj, score=score)]
        elif isinstance(obj, BaseQueryEngine):
            response = await obj.aquery(query_bundle)
            # Keep the response metadata, exactly as the sync path does.
            return [
                NodeWithScore(
                    node=TextNode(text=str(response), metadata=response.metadata or {}),
                    score=score,
                )
            ]
        elif isinstance(obj, BaseRetriever):
            return await obj.aretrieve(query_bundle)
        elif isinstance(obj, QueryComponent):
            kwargs = self._component_kwargs(obj, query_bundle)
            component_response = await obj.arun_component(**kwargs)

            result_output = str(next(iter(component_response.values())))
            return [NodeWithScore(node=TextNode(text=result_output), score=score)]
        else:
            raise ValueError(f"Object {obj} is not retrievable.")

    def _handle_recursive_retrieval(
        self, query_bundle: QueryBundle, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Expand retrieved index nodes into the nodes of their target objects.

        An ``IndexNode`` whose object can be resolved (from ``node.obj`` or,
        failing that, ``self.object_map``) is replaced by whatever
        ``_retrieve_from_object`` returns for it. All other nodes pass
        through unchanged. The result is de-duplicated by node hash.
        """
        retrieved_nodes: List[NodeWithScore] = []
        for n in nodes:
            node = n.node
            # NOTE: `or` treats a score of 0.0 the same as a missing score.
            score = n.score or 1.0
            if isinstance(node, IndexNode):
                obj = node.obj or self.object_map.get(node.index_id, None)
                if obj is not None:
                    if self._verbose:
                        print_text(
                            f"Retrieval entering {node.index_id}: {obj.__class__.__name__}\n",
                            color="llama_turquoise",
                        )
                    retrieved_nodes.extend(
                        self._retrieve_from_object(
                            obj, query_bundle=query_bundle, score=score
                        )
                    )
                else:
                    retrieved_nodes.append(n)
            else:
                retrieved_nodes.append(n)

        # remove any duplicates based on hash
        return self._dedupe_nodes(retrieved_nodes)

    async def _ahandle_recursive_retrieval(
        self, query_bundle: QueryBundle, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Asynchronously expand retrieved index nodes.

        Async counterpart of ``_handle_recursive_retrieval``; target objects
        are awaited one after another via ``_aretrieve_from_object``.
        """
        retrieved_nodes: List[NodeWithScore] = []
        for n in nodes:
            node = n.node
            # NOTE: `or` treats a score of 0.0 the same as a missing score.
            score = n.score or 1.0
            if isinstance(node, IndexNode):
                # Prefer the object attached to the node itself, as the sync
                # path does, and only then fall back to the object map.
                obj = node.obj or self.object_map.get(node.index_id, None)
                if obj is not None:
                    if self._verbose:
                        print_text(
                            f"Retrieval entering {node.index_id}: {obj.__class__.__name__}\n",
                            color="llama_turquoise",
                        )
                    # TODO: Add concurrent execution via `run_jobs()` ?
                    retrieved_nodes.extend(
                        await self._aretrieve_from_object(
                            obj, query_bundle=query_bundle, score=score
                        )
                    )
                else:
                    retrieved_nodes.append(n)
            else:
                retrieved_nodes.append(n)

        # remove any duplicates based on hash
        return self._dedupe_nodes(retrieved_nodes)

    def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
        """Retrieve nodes given query.

        Args:
            str_or_query_bundle (QueryType): Either a query string or
                a QueryBundle object.

        Returns:
            The nodes produced by ``_retrieve`` after recursive expansion of
            index nodes and de-duplication.

        """
        self._check_callback_manager()

        if isinstance(str_or_query_bundle, str):
            query_bundle = QueryBundle(str_or_query_bundle)
        else:
            query_bundle = str_or_query_bundle
        with self.callback_manager.as_trace("query"):
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = self._retrieve(query_bundle)
                nodes = self._handle_recursive_retrieval(query_bundle, nodes)
                retrieve_event.on_end(
                    payload={EventPayload.NODES: nodes},
                )

        return nodes

    async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes given query.

        Args:
            str_or_query_bundle (QueryType): Either a query string or
                a QueryBundle object.

        Returns:
            The nodes produced by ``_aretrieve`` after recursive expansion of
            index nodes and de-duplication.

        """
        self._check_callback_manager()

        if isinstance(str_or_query_bundle, str):
            query_bundle = QueryBundle(str_or_query_bundle)
        else:
            query_bundle = str_or_query_bundle
        with self.callback_manager.as_trace("query"):
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = await self._aretrieve(query_bundle)
                nodes = await self._ahandle_recursive_retrieval(query_bundle, nodes)
                retrieve_event.on_end(
                    payload={EventPayload.NODES: nodes},
                )

        return nodes

    @abstractmethod
    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query.

        Implemented by the user.

        """

    # TODO: make this abstract
    # @abstractmethod
    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes given query.

        Implemented by the user. The default implementation simply calls the
        synchronous ``_retrieve``.

        """
        return self._retrieve(query_bundle)

    def get_service_context(self) -> Optional[ServiceContext]:
        """Attempts to resolve a service context.

        Short-circuits at self.service_context, self._service_context,
        or self._index.service_context, in that order.

        Returns:
            The first service context found, or None when the retriever has
            none of those attributes.

        """
        if hasattr(self, "service_context"):
            return self.service_context
        if hasattr(self, "_service_context"):
            return self._service_context
        if hasattr(self, "_index") and hasattr(self._index, "service_context"):
            return self._index.service_context
        return None

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """Return a query component wrapping this retriever."""
        return RetrieverComponent(retriever=self)
292

293

294
class RetrieverComponent(QueryComponent):
    """Query-pipeline component that delegates to a retriever.

    Takes a single ``"input"`` key (the query) and produces a single
    ``"output"`` key holding the value returned by the wrapped retriever.
    """

    retriever: BaseRetriever = Field(..., description="Retriever")

    class Config:
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> InputKeys:
        """Keys accepted as input: only ``"input"``."""
        return InputKeys.from_keys({"input"})

    @property
    def output_keys(self) -> OutputKeys:
        """Keys produced as output: only ``"output"``."""
        return OutputKeys.from_keys({"output"})

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Assign the callback manager to the wrapped retriever."""
        self.retriever.callback_manager = callback_manager

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the inputs passed to ``run_component``.

        The ``"input"`` entry is passed through
        ``validate_and_convert_stringable`` and written back into the same
        dict, which is then returned.
        """
        query = input["input"]
        input["input"] = validate_and_convert_stringable(query)
        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Call the retriever's ``retrieve`` with the ``"input"`` argument."""
        return {"output": self.retriever.retrieve(kwargs["input"])}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Await the retriever's ``aretrieve`` with the ``"input"`` argument."""
        return {"output": await self.retriever.aretrieve(kwargs["input"])}
331

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.