llama-index

Форк
0
232 строки · 7.0 Кб
1
"""txtai Vector store index.
2

3
An index that is built on top of an existing vector store.
4

5
"""
6

7
import json
8
import logging
9
import os
10
import pickle
11
from pathlib import Path
12
from typing import Any, List, Optional, cast
13

14
import fsspec
15
import numpy as np
16
from fsspec.implementations.local import LocalFileSystem
17

18
from llama_index.legacy.bridge.pydantic import PrivateAttr
19
from llama_index.legacy.schema import BaseNode
20
from llama_index.legacy.vector_stores.simple import DEFAULT_VECTOR_STORE, NAMESPACE_SEP
21
from llama_index.legacy.vector_stores.types import (
22
    DEFAULT_PERSIST_DIR,
23
    DEFAULT_PERSIST_FNAME,
24
    BasePydanticVectorStore,
25
    VectorStoreQuery,
26
    VectorStoreQueryResult,
27
)
28

29
# Per-module logger (best practice: getLogger(__name__) rather than the root
# logger, so messages are attributable to this module and configurable per-module).
logger = logging.getLogger(__name__)

# Default on-disk location of the persisted txtai index:
# "<DEFAULT_PERSIST_DIR>/<DEFAULT_VECTOR_STORE><NAMESPACE_SEP><DEFAULT_PERSIST_FNAME>"
DEFAULT_PERSIST_PATH = os.path.join(
    DEFAULT_PERSIST_DIR, f"{DEFAULT_VECTOR_STORE}{NAMESPACE_SEP}{DEFAULT_PERSIST_FNAME}"
)

# User-facing message raised when the optional `txtai` dependency is missing.
IMPORT_ERROR_MSG = """
    `txtai` package not found. For instructions on
    how to install `txtai` please visit
    https://neuml.github.io/txtai/install/
"""
39

40

41
class TxtaiVectorStore(BasePydanticVectorStore):
    """txtai Vector Store.

    Embeddings are stored within a txtai ANN index.

    During query time, the index uses txtai to query for the top
    k embeddings, and returns the corresponding indices.

    Args:
        txtai_index (txtai.ann.ANN): txtai index instance

    """

    # This store holds only embeddings; node text is kept in the docstore.
    stores_text: bool = False

    # The wrapped txtai ANN index (set in __init__; typed Any because txtai
    # is an optional dependency).
    _txtai_index = PrivateAttr()

    def __init__(
        self,
        txtai_index: Any,
    ) -> None:
        """Initialize params.

        Raises:
            ImportError: if the optional `txtai` package is not installed.
        """
        try:
            import txtai
        except ImportError:
            raise ImportError(IMPORT_ERROR_MSG)

        self._txtai_index = cast(txtai.ann.ANN, txtai_index)

        super().__init__()

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str = DEFAULT_PERSIST_DIR,
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> "TxtaiVectorStore":
        """Load a vector store from a persistence directory.

        Args:
            persist_dir (str): directory containing the persisted index file.
            fs: optional filesystem; only local storage is supported.

        Raises:
            NotImplementedError: if `fs` is a non-local filesystem.
        """
        persist_path = os.path.join(
            persist_dir,
            f"{DEFAULT_VECTOR_STORE}{NAMESPACE_SEP}{DEFAULT_PERSIST_FNAME}",
        )
        # only support local storage for now
        if fs and not isinstance(fs, LocalFileSystem):
            raise NotImplementedError("txtai only supports local storage for now.")
        return cls.from_persist_path(persist_path=persist_path, fs=None)

    @classmethod
    def from_persist_path(
        cls,
        persist_path: str,
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> "TxtaiVectorStore":
        """Load a vector store from a persisted index file.

        Reads the txtai configuration stored next to the index file
        ("config.json" if present, else pickle "config"), recreates the ANN
        backend from it, then loads the index data.

        Args:
            persist_path (str): path of the persisted index file.
            fs: optional filesystem; only local storage is supported.

        Raises:
            ImportError: if the optional `txtai` package is not installed.
            NotImplementedError: if `fs` is a non-local filesystem.
            ValueError: if `persist_path` does not exist.
        """
        try:
            import txtai
        except ImportError:
            raise ImportError(IMPORT_ERROR_MSG)

        if fs and not isinstance(fs, LocalFileSystem):
            raise NotImplementedError("txtai only supports local storage for now.")

        if not os.path.exists(persist_path):
            raise ValueError(f"No existing {__name__} found at {persist_path}.")

        logger.info(f"Loading {__name__} config from {persist_path}.")
        parent_directory = Path(persist_path).parent
        config_path = parent_directory / "config.json"
        jsonconfig = config_path.exists()
        # Determine if config is json or pickle
        config_path = config_path if jsonconfig else parent_directory / "config"
        # Load configuration: text + explicit UTF-8 for JSON (matching how
        # `persist` writes it), binary for pickle.
        with open(
            config_path,
            "r" if jsonconfig else "rb",
            encoding="utf-8" if jsonconfig else None,
        ) as f:
            config = json.load(f) if jsonconfig else pickle.load(f)

        logger.info(f"Loading {__name__} from {persist_path}.")
        txtai_index = txtai.ann.ANNFactory.create(config)
        txtai_index.load(persist_path)
        return cls(txtai_index=txtai_index)

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            List[str]: the 0-based txtai positions assigned to the new
            embeddings, consistent with the ids returned by `query`.
        """
        text_embedding_np = np.array(
            [node.get_embedding() for node in nodes], dtype="float32"
        )

        # Check if the ann index is already created
        # If not create the index with node embeddings
        if self._txtai_index.backend is None:
            self._txtai_index.index(text_embedding_np)
        else:
            self._txtai_index.append(text_embedding_np)

        indx_size = self._txtai_index.count()
        # BUG FIX: txtai positions are 0-based — and `query` below returns
        # those 0-based positions as ids — so the new embeddings occupy
        # positions [count - len(nodes), count). The previous 1-based range
        # produced ids that never matched query results.
        return [str(idx) for idx in range(indx_size - len(nodes), indx_size)]

    @property
    def client(self) -> Any:
        """Return the txtai index."""
        return self._txtai_index

    def persist(
        self,
        persist_path: str = DEFAULT_PERSIST_PATH,
        fs: Optional[fsspec.AbstractFileSystem] = None,
    ) -> None:
        """Save to file.

        This method saves the vector store to disk: the txtai configuration
        is written next to the index file (as "config.json" when the index's
        `format` is "json", otherwise as pickle "config"), then the index
        data itself is saved.

        Args:
            persist_path (str): The save_path of the file.

        Raises:
            NotImplementedError: if `fs` is a non-local filesystem.
        """
        if fs and not isinstance(fs, LocalFileSystem):
            raise NotImplementedError("txtai only supports local storage for now.")

        dirpath = Path(persist_path).parent
        # parents=True so persisting into a nested, not-yet-created
        # directory works (mkdir without it raises FileNotFoundError).
        dirpath.mkdir(parents=True, exist_ok=True)

        jsonconfig = self._txtai_index.config.get("format", "pickle") == "json"
        # Determine if config is json or pickle
        config_path = dirpath / "config.json" if jsonconfig else dirpath / "config"

        # Write configuration
        with open(
            config_path,
            "w" if jsonconfig else "wb",
            encoding="utf-8" if jsonconfig else None,
        ) as f:
            if jsonconfig:
                # Write config as JSON
                json.dump(self._txtai_index.config, f, default=str)
            else:
                from txtai.version import __pickle__

                # Write config as pickle format
                pickle.dump(self._txtai_index.config, f, protocol=__pickle__)

        self._txtai_index.save(persist_path)

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id.

        NOTE(review): the id is interpreted as a txtai position (cast to
        int), i.e. the same id space that `add`/`query` use — confirm callers
        pass node ids rather than document-level ids.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        self._txtai_index.delete([int(ref_doc_id)])

    def query(
        self,
        query: VectorStoreQuery,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): query to search for in the index

        Raises:
            ValueError: if metadata filters are supplied (unsupported).
        """
        if query.filters is not None:
            raise ValueError("Metadata filters not implemented for txtai yet.")

        query_embedding = cast(List[float], query.query_embedding)
        # txtai expects a 2D float32 batch of queries; search one query and
        # take its result list.
        query_embedding_np = np.array(query_embedding, dtype="float32")[np.newaxis, :]
        search_result = self._txtai_index.search(
            query_embedding_np, query.similarity_top_k
        )[0]
        # if empty, then return an empty response
        if len(search_result) == 0:
            return VectorStoreQueryResult(similarities=[], ids=[])

        # Drop padding entries (txtai uses negative positions when fewer
        # than top_k results exist).
        filtered_dists = []
        filtered_node_idxs = []
        for dist, idx in search_result:
            if idx < 0:
                continue
            filtered_dists.append(dist)
            filtered_node_idxs.append(str(idx))

        return VectorStoreQueryResult(
            similarities=filtered_dists, ids=filtered_node_idxs
        )
233

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.