llama-index
330 строк · 11.7 Кб
1"""Base retriever."""
2
3from abc import abstractmethod4from typing import Any, Dict, List, Optional5
6from llama_index.legacy.bridge.pydantic import Field7from llama_index.legacy.callbacks.base import CallbackManager8from llama_index.legacy.callbacks.schema import CBEventType, EventPayload9from llama_index.legacy.core.base_query_engine import BaseQueryEngine10from llama_index.legacy.core.query_pipeline.query_component import (11ChainableMixin,12InputKeys,13OutputKeys,14QueryComponent,15validate_and_convert_stringable,16)
17from llama_index.legacy.prompts.mixin import (18PromptDictType,19PromptMixin,20PromptMixinType,21)
22from llama_index.legacy.schema import (23BaseNode,24IndexNode,25NodeWithScore,26QueryBundle,27QueryType,28TextNode,29)
30from llama_index.legacy.service_context import ServiceContext31from llama_index.legacy.utils import print_text32
33
class BaseRetriever(ChainableMixin, PromptMixin):
    """Base retriever.

    Subclasses implement ``_retrieve`` (and optionally ``_aretrieve``).
    The public ``retrieve`` / ``aretrieve`` entry points wrap those hooks in a
    callback trace and then perform recursive retrieval: every returned
    ``IndexNode`` that resolves to an object (through ``node.obj`` or
    ``object_map``) is replaced by whatever that object yields for the query.
    """

    def __init__(
        self,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[Dict] = None,
        objects: Optional[List[IndexNode]] = None,
        verbose: bool = False,
    ) -> None:
        """Initialize the retriever.

        Args:
            callback_manager: Callback manager to report events to. A fresh
                ``CallbackManager`` is created when omitted.
            object_map: Mapping from ``index_id`` to the object to recurse into.
            objects: Index nodes used to build the mapping. When given, the
                mapping built from them replaces ``object_map`` entirely.
            verbose: Whether to print progress while recursing into objects.
        """
        self.callback_manager = callback_manager or CallbackManager()

        if objects is not None:
            object_map = {obj.index_id: obj.obj for obj in objects}

        self.object_map = object_map or {}
        self._verbose = verbose

    def _check_callback_manager(self) -> None:
        """Ensure ``self.callback_manager`` exists, creating a default if not."""
        if not hasattr(self, "callback_manager"):
            self.callback_manager = CallbackManager()

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {}

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt modules."""
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts (no-op in the base class)."""

    @staticmethod
    def _dedupe_by_hash(nodes: List[NodeWithScore]) -> List[NodeWithScore]:
        """Return ``nodes`` without duplicates, keyed on ``node.hash``.

        The first occurrence of each hash is kept and the input order is
        preserved.
        """
        seen = set()
        unique_nodes: List[NodeWithScore] = []
        for candidate in nodes:
            node_hash = candidate.node.hash
            if node_hash in seen:
                continue
            seen.add(node_hash)
            unique_nodes.append(candidate)
        return unique_nodes

    def _retrieve_from_object(
        self,
        obj: Any,
        query_bundle: QueryBundle,
        score: float,
    ) -> List[NodeWithScore]:
        """Retrieve nodes from object.

        Args:
            obj: The object to resolve. Supported kinds are ``NodeWithScore``,
                ``BaseNode``, ``BaseQueryEngine``, ``BaseRetriever`` and
                ``QueryComponent``.
            query_bundle: The query forwarded to the object.
            score: Score assigned to nodes that this method creates itself.
                It is not applied to a ``NodeWithScore`` that is passed in, nor
                to nodes returned by a nested retriever.

        Returns:
            The nodes obtained from ``obj``.

        Raises:
            ValueError: If ``obj`` is of an unsupported kind, or is a
                ``QueryComponent`` with more than one required input key.
        """
        if self._verbose:
            print_text(
                f"Retrieving from object {obj.__class__.__name__} with query {query_bundle.query_str}\n",
                color="llama_pink",
            )

        if isinstance(obj, NodeWithScore):
            return [obj]
        elif isinstance(obj, BaseNode):
            return [NodeWithScore(node=obj, score=score)]
        elif isinstance(obj, BaseQueryEngine):
            response = obj.query(query_bundle)
            return [
                NodeWithScore(
                    node=TextNode(text=str(response), metadata=response.metadata or {}),
                    score=score,
                )
            ]
        elif isinstance(obj, BaseRetriever):
            return obj.retrieve(query_bundle)
        elif isinstance(obj, QueryComponent):
            component_keys = obj.input_keys.required_keys
            if len(component_keys) > 1:
                raise ValueError(
                    f"QueryComponent {obj} has more than one input key: {component_keys}"
                )
            elif len(component_keys) == 0:
                component_response = obj.run_component()
            else:
                # Feed the query string to the single required input.
                kwargs = {next(iter(component_keys)): query_bundle.query_str}
                component_response = obj.run_component(**kwargs)

            # Only the first output value is used, rendered as text.
            result_output = str(next(iter(component_response.values())))
            return [NodeWithScore(node=TextNode(text=result_output), score=score)]
        else:
            raise ValueError(f"Object {obj} is not retrievable.")

    async def _aretrieve_from_object(
        self,
        obj: Any,
        query_bundle: QueryBundle,
        score: float,
    ) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes from object.

        Mirrors ``_retrieve_from_object`` but awaits the async entry point of
        the object (``aquery``, ``aretrieve``, ``arun_component``).

        Args:
            obj: The object to resolve. Supported kinds are ``NodeWithScore``,
                ``BaseNode``, ``BaseQueryEngine``, ``BaseRetriever`` and
                ``QueryComponent``.
            query_bundle: The query forwarded to the object.
            score: Score assigned to nodes that this method creates itself.

        Returns:
            The nodes obtained from ``obj``.

        Raises:
            ValueError: If ``obj`` is of an unsupported kind, or is a
                ``QueryComponent`` with more than one required input key.
        """
        # Emit the same trace as the synchronous path.
        if self._verbose:
            print_text(
                f"Retrieving from object {obj.__class__.__name__} with query {query_bundle.query_str}\n",
                color="llama_pink",
            )

        if isinstance(obj, NodeWithScore):
            return [obj]
        elif isinstance(obj, BaseNode):
            return [NodeWithScore(node=obj, score=score)]
        elif isinstance(obj, BaseQueryEngine):
            response = await obj.aquery(query_bundle)
            # Keep the response metadata, as the synchronous path does.
            return [
                NodeWithScore(
                    node=TextNode(text=str(response), metadata=response.metadata or {}),
                    score=score,
                )
            ]
        elif isinstance(obj, BaseRetriever):
            return await obj.aretrieve(query_bundle)
        elif isinstance(obj, QueryComponent):
            component_keys = obj.input_keys.required_keys
            if len(component_keys) > 1:
                raise ValueError(
                    f"QueryComponent {obj} has more than one input key: {component_keys}"
                )
            elif len(component_keys) == 0:
                component_response = await obj.arun_component()
            else:
                # Feed the query string to the single required input.
                kwargs = {next(iter(component_keys)): query_bundle.query_str}
                component_response = await obj.arun_component(**kwargs)

            # Only the first output value is used, rendered as text.
            result_output = str(next(iter(component_response.values())))
            return [NodeWithScore(node=TextNode(text=result_output), score=score)]
        else:
            raise ValueError(f"Object {obj} is not retrievable.")

    def _handle_recursive_retrieval(
        self, query_bundle: QueryBundle, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Expand index nodes into the results of the objects they point to.

        Args:
            query_bundle: The query forwarded to nested objects.
            nodes: Nodes produced by ``_retrieve``.

        Returns:
            The expanded nodes with duplicates (by node hash) removed. Nodes
            that are not ``IndexNode`` instances, or that resolve to no object,
            are passed through unchanged.
        """
        retrieved_nodes: List[NodeWithScore] = []
        for n in nodes:
            node = n.node
            # NOTE(review): `or` also maps a score of 0.0 to 1.0 - confirm intended.
            score = n.score or 1.0
            if isinstance(node, IndexNode):
                # An object attached to the node wins over the object map.
                obj = node.obj or self.object_map.get(node.index_id, None)
                if obj is not None:
                    if self._verbose:
                        print_text(
                            f"Retrieval entering {node.index_id}: {obj.__class__.__name__}\n",
                            color="llama_turquoise",
                        )
                    retrieved_nodes.extend(
                        self._retrieve_from_object(
                            obj, query_bundle=query_bundle, score=score
                        )
                    )
                else:
                    retrieved_nodes.append(n)
            else:
                retrieved_nodes.append(n)

        # remove any duplicates based on hash
        return self._dedupe_by_hash(retrieved_nodes)

    async def _ahandle_recursive_retrieval(
        self, query_bundle: QueryBundle, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Asynchronously expand index nodes into their objects' results.

        Args:
            query_bundle: The query forwarded to nested objects.
            nodes: Nodes produced by ``_aretrieve``.

        Returns:
            The expanded nodes with duplicates (by node hash) removed. Nodes
            that are not ``IndexNode`` instances, or that resolve to no object,
            are passed through unchanged.
        """
        retrieved_nodes: List[NodeWithScore] = []
        for n in nodes:
            node = n.node
            # NOTE(review): `or` also maps a score of 0.0 to 1.0 - confirm intended.
            score = n.score or 1.0
            if isinstance(node, IndexNode):
                # Resolve exactly like the synchronous path: an object attached
                # to the node wins over the object map.
                obj = node.obj or self.object_map.get(node.index_id, None)
                if obj is not None:
                    if self._verbose:
                        print_text(
                            f"Retrieval entering {node.index_id}: {obj.__class__.__name__}\n",
                            color="llama_turquoise",
                        )
                    # TODO: Add concurrent execution via `run_jobs()` ?
                    retrieved_nodes.extend(
                        await self._aretrieve_from_object(
                            obj, query_bundle=query_bundle, score=score
                        )
                    )
                else:
                    retrieved_nodes.append(n)
            else:
                retrieved_nodes.append(n)

        # remove any duplicates based on hash
        return self._dedupe_by_hash(retrieved_nodes)

    def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
        """Retrieve nodes given query.

        Args:
            str_or_query_bundle (QueryType): Either a query string or
                a QueryBundle object.

        Returns:
            The nodes from ``_retrieve`` after recursive retrieval.
        """
        self._check_callback_manager()

        if isinstance(str_or_query_bundle, str):
            query_bundle = QueryBundle(str_or_query_bundle)
        else:
            query_bundle = str_or_query_bundle
        with self.callback_manager.as_trace("query"):
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = self._retrieve(query_bundle)
                nodes = self._handle_recursive_retrieval(query_bundle, nodes)
                retrieve_event.on_end(
                    payload={EventPayload.NODES: nodes},
                )

        return nodes

    async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes given query.

        Args:
            str_or_query_bundle (QueryType): Either a query string or
                a QueryBundle object.

        Returns:
            The nodes from ``_aretrieve`` after recursive retrieval.
        """
        self._check_callback_manager()

        if isinstance(str_or_query_bundle, str):
            query_bundle = QueryBundle(str_or_query_bundle)
        else:
            query_bundle = str_or_query_bundle
        with self.callback_manager.as_trace("query"):
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = await self._aretrieve(query_bundle)
                nodes = await self._ahandle_recursive_retrieval(query_bundle, nodes)
                retrieve_event.on_end(
                    payload={EventPayload.NODES: nodes},
                )

        return nodes

    @abstractmethod
    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query.

        Implemented by the user.

        """

    # TODO: make this abstract
    # @abstractmethod
    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes given query.

        Implemented by the user. The default implementation simply calls the
        synchronous ``_retrieve``.

        """
        return self._retrieve(query_bundle)

    def get_service_context(self) -> Optional[ServiceContext]:
        """Attempts to resolve a service context.

        Short-circuits at self.service_context, self._service_context,
        or self._index.service_context.

        Returns:
            The first of those attributes that exists, or ``None`` when none
            of them is present.
        """
        if hasattr(self, "service_context"):
            return self.service_context
        if hasattr(self, "_service_context"):
            return self._service_context
        if hasattr(self, "_index") and hasattr(self._index, "service_context"):
            return self._index.service_context
        return None

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """Return a query component wrapping this retriever.

        ``kwargs`` are accepted but not used.
        """
        return RetrieverComponent(retriever=self)
293
class RetrieverComponent(QueryComponent):
    """Retriever component."""

    # Wraps a retriever so that it can be run as a query component: the
    # "input" key is passed to the retriever and its nodes come back under
    # the "output" key.
    retriever: BaseRetriever = Field(..., description="Retriever")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Install ``callback_manager`` on the wrapped retriever."""
        self.retriever.callback_manager = callback_manager

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize the ``"input"`` entry and return the same dict.

        The value is passed through ``validate_and_convert_stringable`` and
        written back in place.
        """
        converted = validate_and_convert_stringable(input["input"])
        input["input"] = converted
        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Run the retriever on ``kwargs["input"]`` and wrap the nodes."""
        query = kwargs["input"]
        nodes = self.retriever.retrieve(query)
        return {"output": nodes}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Asynchronously run the retriever on ``kwargs["input"]``."""
        query = kwargs["input"]
        nodes = await self.retriever.aretrieve(query)
        return {"output": nodes}

    @property
    def input_keys(self) -> InputKeys:
        """Keys this component consumes: just ``"input"``."""
        required = {"input"}
        return InputKeys.from_keys(required)

    @property
    def output_keys(self) -> OutputKeys:
        """Keys this component produces: just ``"output"``."""
        produced = {"output"}
        return OutputKeys.from_keys(produced)