llama-index
181 строка · 6.3 Кб
1"""Base object types."""
2
3import pickle4import warnings5from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar6
7from llama_index.legacy.bridge.pydantic import Field8from llama_index.legacy.callbacks.base import CallbackManager9from llama_index.legacy.core.base_retriever import BaseRetriever10from llama_index.legacy.core.query_pipeline.query_component import (11ChainableMixin,12InputKeys,13OutputKeys,14QueryComponent,15validate_and_convert_stringable,16)
17from llama_index.legacy.indices.base import BaseIndex18from llama_index.legacy.indices.vector_store.base import VectorStoreIndex19from llama_index.legacy.objects.base_node_mapping import (20DEFAULT_PERSIST_FNAME,21BaseObjectNodeMapping,22SimpleObjectNodeMapping,23)
24from llama_index.legacy.schema import QueryType25from llama_index.legacy.storage.storage_context import (26DEFAULT_PERSIST_DIR,27StorageContext,28)
29
# Type parameter for the Python object type stored in and retrieved from the
# object index / object retriever classes in this module.
OT = TypeVar("OT")
32
class ObjectRetriever(ChainableMixin, Generic[OT]):
    """Retrieve Python objects instead of nodes.

    Delegates the search to a node-level ``BaseRetriever`` and converts every
    hit back to its source object with ``object_node_mapping.from_node``.
    """

    def __init__(
        self, retriever: BaseRetriever, object_node_mapping: BaseObjectNodeMapping[OT]
    ):
        self._retriever = retriever
        self._object_node_mapping = object_node_mapping

    @property
    def retriever(self) -> BaseRetriever:
        """Underlying node-level retriever."""
        return self._retriever

    def _to_objects(self, hits: Sequence[Any]) -> List[OT]:
        """Map each retrieval hit (read through its ``.node`` attribute) to its object."""
        mapping = self._object_node_mapping
        return [mapping.from_node(hit.node) for hit in hits]

    def retrieve(self, str_or_query_bundle: QueryType) -> List[OT]:
        """Synchronously retrieve the objects matching ``str_or_query_bundle``."""
        hits = self._retriever.retrieve(str_or_query_bundle)
        return self._to_objects(hits)

    async def aretrieve(self, str_or_query_bundle: QueryType) -> List[OT]:
        """Asynchronous counterpart of ``retrieve``."""
        hits = await self._retriever.aretrieve(str_or_query_bundle)
        return self._to_objects(hits)

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """Wrap this retriever in an ``ObjectRetrieverComponent`` (``kwargs`` are ignored)."""
        return ObjectRetrieverComponent(retriever=self)
59
class ObjectRetrieverComponent(QueryComponent):
    """Query-pipeline component that runs an ``ObjectRetriever``.

    Consumes a single ``"input"`` key (coerced to a string during validation)
    and produces the retrieved objects under the ``"output"`` key.
    """

    retriever: ObjectRetriever = Field(..., description="Retriever.")

    class Config:
        # ObjectRetriever is not a pydantic model, so allow it as a field type.
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Attach ``callback_manager`` to the wrapped node-level retriever."""
        node_retriever = self.retriever.retriever
        node_retriever.callback_manager = callback_manager

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Coerce ``input["input"]`` to a string in place and return the same dict."""
        raw_query = input["input"]
        input["input"] = validate_and_convert_stringable(raw_query)
        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Retrieve objects for ``kwargs["input"]`` synchronously."""
        query = kwargs["input"]
        retrieved = self.retriever.retrieve(query)
        return {"output": retrieved}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Retrieve objects for ``kwargs["input"]`` asynchronously."""
        query = kwargs["input"]
        retrieved = await self.retriever.aretrieve(query)
        return {"output": retrieved}

    @property
    def input_keys(self) -> InputKeys:
        """The component accepts exactly one key: ``"input"``."""
        return InputKeys.from_keys({"input"})

    @property
    def output_keys(self) -> OutputKeys:
        """The component emits exactly one key: ``"output"``."""
        return OutputKeys.from_keys({"output"})
98
class ObjectIndex(Generic[OT]):
    """Index over arbitrary Python objects.

    Objects are converted to nodes through a ``BaseObjectNodeMapping`` and the
    nodes are stored in an underlying ``BaseIndex``. ``as_retriever`` maps
    retrieved nodes back to objects.
    """

    def __init__(
        self, index: BaseIndex, object_node_mapping: BaseObjectNodeMapping
    ) -> None:
        self._index = index
        self._object_node_mapping = object_node_mapping

    @classmethod
    def from_objects(
        cls,
        objects: Sequence[OT],
        object_mapping: Optional[BaseObjectNodeMapping] = None,
        index_cls: Type[BaseIndex] = VectorStoreIndex,
        **index_kwargs: Any,
    ) -> "ObjectIndex":
        """Build an ObjectIndex from ``objects``.

        Args:
            objects: Objects to index.
            object_mapping: Mapping used to convert objects to nodes. When
                ``None``, a ``SimpleObjectNodeMapping`` is built from ``objects``.
            index_cls: Index class constructed over the resulting nodes.
            **index_kwargs: Extra keyword arguments forwarded to ``index_cls``.

        Returns:
            A new ObjectIndex wrapping the built index and the mapping.
        """
        if object_mapping is None:
            object_mapping = SimpleObjectNodeMapping.from_objects(objects)
        nodes = object_mapping.to_nodes(objects)
        index = index_cls(nodes, **index_kwargs)
        return cls(index, object_mapping)

    def insert_object(self, obj: Any) -> None:
        """Register ``obj`` with the mapping, then insert its node into the index."""
        self._object_node_mapping.add_object(obj)
        node = self._object_node_mapping.to_node(obj)
        self._index.insert_nodes([node])

    def as_retriever(self, **kwargs: Any) -> ObjectRetriever:
        """Return a retriever yielding objects; ``kwargs`` go to the index's ``as_retriever``."""
        return ObjectRetriever(
            retriever=self._index.as_retriever(**kwargs),
            object_node_mapping=self._object_node_mapping,
        )

    def as_node_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return the underlying index's node-level retriever."""
        return self._index.as_retriever(**kwargs)

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist the index's storage context and, if possible, the object mapping.

        Persisting the mapping is best-effort: if it raises ``NotImplementedError``
        or a ``pickle.PickleError``, a warning (including the reason) is emitted
        and only the index's storage context is persisted.
        """
        try:
            self._object_node_mapping.persist(
                persist_dir=persist_dir, obj_node_mapping_fname=obj_node_mapping_fname
            )
        except (NotImplementedError, pickle.PickleError) as err:
            # Surface the underlying cause; previously `err` was silently dropped.
            warnings.warn(
                (
                    "Unable to persist ObjectNodeMapping. You will need to "
                    "reconstruct the same object node mapping to build this "
                    f"ObjectIndex (reason: {err!r})"
                ),
                stacklevel=2,
            )
        self._index._storage_context.persist(persist_dir=persist_dir)

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        object_node_mapping: Optional[BaseObjectNodeMapping] = None,
    ) -> "ObjectIndex":
        """Load an ObjectIndex from ``persist_dir``.

        Args:
            persist_dir: Directory the index (and possibly the mapping) was saved to.
            object_node_mapping: Mapping to use. When ``None``, a
                ``SimpleObjectNodeMapping`` is loaded from ``persist_dir``.

        Raises:
            Exception: If no mapping is given and it cannot be loaded from
                ``persist_dir`` (the original error is chained as the cause).
        """
        from llama_index.legacy.indices import load_index_from_storage

        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context)

        # Compare against None explicitly: a caller-supplied mapping must never
        # be discarded just because the object happens to evaluate as falsy.
        if object_node_mapping is not None:
            return cls(index=index, object_node_mapping=object_node_mapping)

        # No mapping supplied: assume SimpleObjectNodeMapping, the only
        # subclass that supports loading from a persist dir.
        try:
            loaded_mapping = SimpleObjectNodeMapping.from_persist_dir(
                persist_dir=persist_dir
            )
        except Exception as err:
            raise Exception(
                "Unable to load from persist dir. The object_node_mapping cannot be loaded."
            ) from err
        return cls(index=index, object_node_mapping=loaded_mapping)