llama-index

Форк
0
265 строк · 9.3 Кб
1
"""Epsilla vector store."""
2

3
import logging
4
from typing import Any, List, Optional
5

6
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
7
from llama_index.legacy.vector_stores.types import (
8
    DEFAULT_PERSIST_DIR,
9
    VectorStore,
10
    VectorStoreQuery,
11
    VectorStoreQueryMode,
12
    VectorStoreQueryResult,
13
)
14
from llama_index.legacy.vector_stores.utils import (
15
    DEFAULT_DOC_ID_KEY,
16
    DEFAULT_EMBEDDING_KEY,
17
    DEFAULT_TEXT_KEY,
18
    legacy_metadata_dict_to_node,
19
    metadata_dict_to_node,
20
    node_to_metadata_dict,
21
)
22

23
logger = logging.getLogger(__name__)
24

25

26
class EpsillaVectorStore(VectorStore):
    """The Epsilla Vector Store.

    In this vector store we store the text, its embedding and
    a few pieces of its metadata in an Epsilla collection. This implementation
    allows the use of an already existing collection.
    It also supports creating a new one if the collection does not
    exist or if `overwrite` is set to True.

    As a prerequisite, you need to install ``pyepsilla`` package
    and have a running Epsilla vector database (for example, through our docker image)
    See the following documentation for how to run an Epsilla vector database:
    https://epsilla-inc.gitbook.io/epsilladb/quick-start

    Args:
        client (Any): Epsilla client to connect to. Must be an instance of
                    ``pyepsilla.vectordb.Client``.
        collection_name (Optional[str]): Which collection to use.
                    Defaults to "llama_collection".
        db_path (Optional[str]): The path where the database will be persisted.
                    Defaults to ``DEFAULT_PERSIST_DIR``.
        db_name (Optional[str]): Give a name to the loaded database.
                    Defaults to "llama_db".
        dimension (Optional[int]): The dimension of the embeddings. If not provided,
                    collection creation will be done on first insert. Defaults to None.
        overwrite (Optional[bool]): Whether to overwrite existing collection with same
                    name. Defaults to False.

    Returns:
        EpsillaVectorStore: Vectorstore that supports add and query
        (delete is not implemented yet).
    """

    stores_text = True
    flat_metadata: bool = False

    def __init__(
        self,
        client: Any,
        collection_name: str = "llama_collection",
        db_path: Optional[str] = DEFAULT_PERSIST_DIR,  # sub folder
        db_name: Optional[str] = "llama_db",
        dimension: Optional[int] = None,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Connect to the database and prepare the target collection.

        Loads and selects ``db_name``, then inspects the existing tables:

        * collection exists and ``overwrite`` is falsy -> reuse it;
        * collection exists and ``overwrite`` is truthy -> drop it, then
          recreate it if ``dimension`` is given;
        * collection does not exist -> create it if ``dimension`` is given,
          otherwise creation is deferred to the first ``add`` call.

        Raises:
            ImportError: if ``pyepsilla`` is not installed.
            TypeError: if ``client`` is not a ``pyepsilla.vectordb.Client``.
            Exception: if listing, dropping or creating tables does not
                return status code 200.
        """
        try:
            from pyepsilla import vectordb
        except ImportError as e:
            raise ImportError(
                "Could not import pyepsilla python package. "
                "Please install pyepsilla package with `pip/pip3 install pyepsilla`."
            ) from e

        if not isinstance(client, vectordb.Client):
            raise TypeError(
                f"client should be an instance of pyepsilla.vectordb.Client, "
                f"got {type(client)}"
            )

        self._client: vectordb.Client = client
        self._collection_name = collection_name
        # NOTE(review): the results of load_db/use_db are not inspected here;
        # failures surface on the list_tables call below.
        self._client.load_db(db_name, db_path)
        self._client.use_db(db_name)
        self._collection_created = False

        status_code, response = self._client.list_tables()
        if status_code != 200:
            self._handle_error(msg=response["message"])
        table_list = response["result"]

        collection_exists = self._collection_name in table_list

        # Truthiness (rather than `is True` / `is False`) so that a truthy
        # non-bool `overwrite` cannot leave an existing collection in limbo.
        if collection_exists and overwrite:
            status_code, response = self._client.drop_table(
                table_name=self._collection_name
            )
            if status_code != 200:
                self._handle_error(msg=response["message"])
            logger.debug(
                f"Successfully removed old collection: {self._collection_name}"
            )
            collection_exists = False

        if collection_exists:
            self._collection_created = True
        elif dimension is not None:
            self._create_collection(dimension)

    def client(self) -> Any:
        """Return the Epsilla client.

        NOTE(review): kept as a plain method so existing ``store.client()``
        callers keep working; confirm whether the ``VectorStore`` interface
        expects a property here.
        """
        return self._client

    def _handle_error(self, msg: str) -> None:
        """Log ``msg`` as an error and raise an ``Exception`` carrying it.

        Used for every non-200 response from the Epsilla client.
        """
        logger.error(f"Failed to get records: {msg}")
        raise Exception(f"Error: {msg}.")

    def _create_collection(self, dimension: int) -> None:
        """
        Create collection.

        The table has a string primary key ``id``, string columns for the
        document id and the text, a float vector column of size ``dimension``
        for the embedding and a JSON ``metadata`` column.

        Args:
            dimension (int): The dimension of the embeddings.

        Raises:
            Exception: if the client does not return status code 200.
        """
        fields: List[dict] = [
            {"name": "id", "dataType": "STRING", "primaryKey": True},
            {"name": DEFAULT_DOC_ID_KEY, "dataType": "STRING"},
            {"name": DEFAULT_TEXT_KEY, "dataType": "STRING"},
            {
                "name": DEFAULT_EMBEDDING_KEY,
                "dataType": "VECTOR_FLOAT",
                "dimensions": dimension,
            },
            {"name": "metadata", "dataType": "JSON"},
        ]
        status_code, response = self._client.create_table(
            table_name=self._collection_name, table_fields=fields
        )
        if status_code != 200:
            self._handle_error(msg=response["message"])
        self._collection_created = True
        logger.debug(f"Successfully created collection: {self._collection_name}")

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """
        Add nodes to Epsilla vector store.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            List[str]: List of ids inserted. Empty if ``nodes`` is empty.

        Raises:
            Exception: if collection creation or the insert does not return
                status code 200.
        """
        # Nothing to insert: do not create a collection or hit the server.
        if not nodes:
            return []

        # If the collection doesn't exist yet, create it, taking the
        # dimension from the first node's embedding.
        if not self._collection_created:
            dimension = len(nodes[0].get_embedding())
            self._create_collection(dimension)

        ids = []
        records = []
        for node in nodes:
            ids.append(node.node_id)
            text = node.get_content(metadata_mode=MetadataMode.NONE)
            metadata_dict = node_to_metadata_dict(node, remove_text=True)
            # Only the serialized node content is stored in the JSON column.
            metadata = metadata_dict["_node_content"]
            records.append(
                {
                    "id": node.node_id,
                    DEFAULT_DOC_ID_KEY: node.ref_doc_id,
                    DEFAULT_TEXT_KEY: text,
                    DEFAULT_EMBEDDING_KEY: node.get_embedding(),
                    "metadata": metadata,
                }
            )

        status_code, response = self._client.insert(
            table_name=self._collection_name, records=records
        )
        if status_code != 200:
            self._handle_error(msg=response["message"])

        return ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id.

        Not implemented yet; always raises.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Delete with filtering will be coming soon.")

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Only ``VectorStoreQueryMode.DEFAULT`` without metadata filters or
        ``doc_ids`` restrictions is supported.

        Args:
            query (VectorStoreQuery): query.

        Returns:
            Vector store query result.

        Raises:
            ValueError: if no collection has been created yet.
            NotImplementedError: for unsupported query modes or filters.
            Exception: if the client does not return status code 200.
        """
        if not self._collection_created:
            raise ValueError("Please initialize a collection first.")

        if query.mode != VectorStoreQueryMode.DEFAULT:
            raise NotImplementedError(f"Epsilla does not support {query.mode} yet.")

        if query.filters is not None:
            raise NotImplementedError("Epsilla does not support Metadata filters yet.")

        if query.doc_ids is not None and len(query.doc_ids) > 0:
            raise NotImplementedError("Epsilla does not support filters yet.")

        status_code, response = self._client.query(
            table_name=self._collection_name,
            query_field=DEFAULT_EMBEDDING_KEY,
            query_vector=query.query_embedding,
            limit=query.similarity_top_k,
            with_distance=True,
        )
        if status_code != 200:
            self._handle_error(msg=response["message"])

        results = response["result"]
        logger.debug(
            f"Successfully searched embedding in collection: {self._collection_name}"
            f" Num Results: {len(results)}"
        )

        nodes = []
        similarities = []
        ids = []
        for res in results:
            try:
                node = metadata_dict_to_node({"_node_content": res["metadata"]})
                node.text = res[DEFAULT_TEXT_KEY]
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(
                    res["metadata"]
                )
                # The node id field is `id_`; passing `id=` would not set it,
                # leaving the node with an id that differs from `ids` below.
                node = TextNode(
                    id_=res["id"],
                    text=res[DEFAULT_TEXT_KEY],
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )
            nodes.append(node)
            # NOTE(review): the raw "@distance" value is reported as the
            # similarity score as-is; confirm callers expect a distance.
            similarities.append(res["@distance"])
            ids.append(res["id"])

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
266

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.