llama-index

Форк
0
181 строка · 6.3 Кб
1
"""Base object types."""
2

3
import pickle
4
import warnings
5
from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar
6

7
from llama_index.legacy.bridge.pydantic import Field
8
from llama_index.legacy.callbacks.base import CallbackManager
9
from llama_index.legacy.core.base_retriever import BaseRetriever
10
from llama_index.legacy.core.query_pipeline.query_component import (
11
    ChainableMixin,
12
    InputKeys,
13
    OutputKeys,
14
    QueryComponent,
15
    validate_and_convert_stringable,
16
)
17
from llama_index.legacy.indices.base import BaseIndex
18
from llama_index.legacy.indices.vector_store.base import VectorStoreIndex
19
from llama_index.legacy.objects.base_node_mapping import (
20
    DEFAULT_PERSIST_FNAME,
21
    BaseObjectNodeMapping,
22
    SimpleObjectNodeMapping,
23
)
24
from llama_index.legacy.schema import QueryType
25
from llama_index.legacy.storage.storage_context import (
26
    DEFAULT_PERSIST_DIR,
27
    StorageContext,
28
)
29

30
# Type variable for the kind of object stored in an ObjectIndex and returned
# by an ObjectRetriever (e.g. ``retrieve`` returns ``List[OT]``).
OT = TypeVar("OT")
31

32

33
class ObjectRetriever(ChainableMixin, Generic[OT]):
    """Retrieve objects by retrieving nodes and mapping them back to objects.

    Wraps a node-level ``BaseRetriever``. Every retrieved result's ``.node``
    is converted back into its original object through the supplied
    ``BaseObjectNodeMapping``.
    """

    def __init__(
        self, retriever: BaseRetriever, object_node_mapping: BaseObjectNodeMapping[OT]
    ):
        self._object_node_mapping = object_node_mapping
        self._retriever = retriever

    @property
    def retriever(self) -> BaseRetriever:
        """The wrapped node-level retriever."""
        return self._retriever

    def _to_objects(self, scored_nodes: Sequence[Any]) -> List[OT]:
        """Map each retrieved result's ``.node`` back to its object."""
        return [
            self._object_node_mapping.from_node(scored.node)
            for scored in scored_nodes
        ]

    def retrieve(self, str_or_query_bundle: QueryType) -> List[OT]:
        """Return the objects whose nodes the wrapped retriever returns for the query."""
        return self._to_objects(self._retriever.retrieve(str_or_query_bundle))

    async def aretrieve(self, str_or_query_bundle: QueryType) -> List[OT]:
        """Async variant of ``retrieve``."""
        retrieved = await self._retriever.aretrieve(str_or_query_bundle)
        return self._to_objects(retrieved)

    def _as_query_component(self, **kwargs: Any) -> QueryComponent:
        """Wrap this retriever as a query-pipeline component (kwargs are unused)."""
        component = ObjectRetrieverComponent(retriever=self)
        return component
58

59

60
class ObjectRetrieverComponent(QueryComponent):
    """Query-pipeline component that runs an ``ObjectRetriever``.

    Accepts a single ``input`` value, converted to a string during validation.
    Returns the retrieved objects under the ``output`` key.
    """

    retriever: ObjectRetriever = Field(..., description="Retriever.")

    class Config:
        arbitrary_types_allowed = True

    def set_callback_manager(self, callback_manager: CallbackManager) -> None:
        """Attach ``callback_manager`` to the wrapped node-level retriever."""
        node_retriever = self.retriever.retriever
        node_retriever.callback_manager = callback_manager

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Convert ``input["input"]`` to a string in place and return the dict."""
        raw_value = input["input"]
        input["input"] = validate_and_convert_stringable(raw_value)
        return input

    def _run_component(self, **kwargs: Any) -> Any:
        """Retrieve objects for ``kwargs["input"]`` synchronously."""
        return {"output": self.retriever.retrieve(kwargs["input"])}

    async def _arun_component(self, **kwargs: Any) -> Any:
        """Retrieve objects for ``kwargs["input"]`` asynchronously."""
        return {"output": await self.retriever.aretrieve(kwargs["input"])}

    @property
    def input_keys(self) -> InputKeys:
        """The single required input key, ``input``."""
        required = {"input"}
        return InputKeys.from_keys(required)

    @property
    def output_keys(self) -> OutputKeys:
        """The single output key, ``output``."""
        produced = {"output"}
        return OutputKeys.from_keys(produced)
97

98

99
class ObjectIndex(Generic[OT]):
    """Index over arbitrary objects.

    Objects are turned into nodes by a ``BaseObjectNodeMapping`` and stored in
    an underlying ``BaseIndex``. Retrievers built from this index map the
    retrieved nodes back to objects with the same mapping.
    """

    def __init__(
        self, index: BaseIndex, object_node_mapping: BaseObjectNodeMapping
    ) -> None:
        self._index = index
        self._object_node_mapping = object_node_mapping

    @classmethod
    def from_objects(
        cls,
        objects: Sequence[OT],
        object_mapping: Optional[BaseObjectNodeMapping] = None,
        index_cls: Type[BaseIndex] = VectorStoreIndex,
        **index_kwargs: Any,
    ) -> "ObjectIndex":
        """Build an ObjectIndex from a sequence of objects.

        Args:
            objects: Objects to index.
            object_mapping: Object/node mapping to use. When ``None``, a
                ``SimpleObjectNodeMapping`` is built from ``objects``.
            index_cls: Index class constructed over the generated nodes.
            **index_kwargs: Extra keyword arguments forwarded to ``index_cls``.

        Returns:
            A new ObjectIndex wrapping the constructed index and the mapping.
        """
        if object_mapping is None:
            object_mapping = SimpleObjectNodeMapping.from_objects(objects)
        nodes = object_mapping.to_nodes(objects)
        index = index_cls(nodes, **index_kwargs)
        return cls(index, object_mapping)

    def insert_object(self, obj: Any) -> None:
        """Add ``obj`` to the mapping and insert its node into the index."""
        self._object_node_mapping.add_object(obj)
        node = self._object_node_mapping.to_node(obj)
        self._index.insert_nodes([node])

    def as_retriever(self, **kwargs: Any) -> ObjectRetriever:
        """Return an object-level retriever over this index.

        ``kwargs`` are forwarded to the underlying index's ``as_retriever``.
        """
        return ObjectRetriever(
            retriever=self._index.as_retriever(**kwargs),
            object_node_mapping=self._object_node_mapping,
        )

    def as_node_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return the underlying index's node-level retriever (``kwargs`` forwarded)."""
        return self._index.as_retriever(**kwargs)

    def persist(
        self,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        obj_node_mapping_fname: str = DEFAULT_PERSIST_FNAME,
    ) -> None:
        """Persist the object node mapping (best effort) and the index storage.

        Args:
            persist_dir: Directory to write to.
            obj_node_mapping_fname: File name for the persisted mapping.

        The mapping may fail to persist with ``NotImplementedError`` or a pickling
        error. In that case a warning (including the cause) is emitted and only
        the index storage is written.
        """
        try:
            self._object_node_mapping.persist(
                persist_dir=persist_dir, obj_node_mapping_fname=obj_node_mapping_fname
            )
        except (NotImplementedError, pickle.PickleError) as err:
            # Best effort: the index itself can still be persisted, but the
            # caller must rebuild an identical mapping when loading it back.
            warnings.warn(
                (
                    "Unable to persist ObjectNodeMapping. You will need to "
                    "reconstruct the same object node mapping to build this "
                    f"ObjectIndex. Cause: {err!r}"
                ),
                stacklevel=2,
            )
        self._index._storage_context.persist(persist_dir=persist_dir)

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        object_node_mapping: Optional[BaseObjectNodeMapping] = None,
    ) -> "ObjectIndex":
        """Load an ObjectIndex from ``persist_dir``.

        Args:
            persist_dir: Directory previously written by ``persist``.
            object_node_mapping: Mapping to use. When ``None``, a
                ``SimpleObjectNodeMapping`` is loaded from ``persist_dir``.

        Raises:
            Exception: If no mapping is given and one cannot be loaded from
                ``persist_dir`` (the original error is chained as the cause).
        """
        from llama_index.legacy.indices import load_index_from_storage

        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context)
        # Explicit None check: a caller-supplied mapping must never be discarded
        # just because it happens to evaluate falsy.
        if object_node_mapping is not None:
            return cls(index=index, object_node_mapping=object_node_mapping)

        # No mapping supplied: try to load one. SimpleObjectNodeMapping is assumed
        # because it is the only subclass that supports this method.
        # NOTE(review): ``persist`` accepts a custom ``obj_node_mapping_fname`` but
        # this loader never passes one, so a mapping saved under a non-default
        # file name cannot be loaded here -- confirm and consider plumbing it.
        # NOTE(review): the mapping is presumably unpickled; only load persist
        # directories from trusted sources.
        try:
            object_node_mapping = SimpleObjectNodeMapping.from_persist_dir(
                persist_dir=persist_dir
            )
        except Exception as err:
            raise Exception(
                "Unable to load from persist dir. The object_node_mapping cannot be loaded."
            ) from err
        return cls(index=index, object_node_mapping=object_node_mapping)
182

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.