llama-index

Форк
0
314 строк · 11.5 Кб
1
from __future__ import annotations
2

3
from enum import Enum
4
from os import getenv
5
from time import sleep
6
from types import ModuleType
7
from typing import Any, List, Type, TypeVar
8

9
from llama_index.legacy.schema import BaseNode
10
from llama_index.legacy.vector_stores.types import (
11
    VectorStore,
12
    VectorStoreQuery,
13
    VectorStoreQueryResult,
14
)
15
from llama_index.legacy.vector_stores.utils import (
16
    DEFAULT_EMBEDDING_KEY,
17
    DEFAULT_TEXT_KEY,
18
    metadata_dict_to_node,
19
    node_to_metadata_dict,
20
)
21

22
# Type variable bound to RocksetVectorStore; it annotates the `cls`
# parameter of RocksetVectorStore.with_new_collection.
T = TypeVar("T", bound="RocksetVectorStore")
23

24

25
def _get_rockset() -> ModuleType:
26
    """Gets the rockset module and raises an ImportError if
27
    the rockset package hasn't been installed.
28

29
    Returns:
30
        rockset module (ModuleType)
31
    """
32
    try:
33
        import rockset
34
    except ImportError:
35
        raise ImportError("Please install rockset with `pip install rockset`")
36
    return rockset
37

38

39
def _get_client(api_key: str | None, api_server: str | None, client: Any | None) -> Any:
    """Return the given client if it is valid, otherwise build a new one.

    Args:
        api_key (Optional[str]): API key for a new client. Falls back to the
            `ROCKSET_API_KEY` environment variable when not given.
        api_server (Optional[str]): API server (host) for a new client. Falls
            back to the `ROCKSET_API_SERVER` environment variable when not
            given.
        client (Optional[Any]): An existing client to validate and reuse.

    Returns:
        The rockset client object (rockset.RocksetClient)

    Raises:
        ImportError: If the rockset package is not installed.
        ValueError: If `client` is given but is not a rockset.RocksetClient,
            or if no client is given and no API key can be found.
    """
    rockset = _get_rockset()

    if client:
        # isinstance (rather than an exact type match) also accepts
        # subclasses of RocksetClient.
        if not isinstance(client, rockset.RocksetClient):
            raise ValueError("Parameter `client` must be of type rockset.RocksetClient")
        return client

    # Resolve the key once so the check and the constructor see the same value.
    resolved_key = api_key or getenv("ROCKSET_API_KEY")
    if not resolved_key:
        raise ValueError(
            "Parameter `client`, `api_key` or env var `ROCKSET_API_KEY` must be set"
        )

    return rockset.RocksetClient(
        api_key=resolved_key,
        host=api_server or getenv("ROCKSET_API_SERVER"),
    )
60

61

62
class RocksetVectorStore(VectorStore):
    """Vector store that keeps node embeddings and metadata in a Rockset
    collection and retrieves them with SQL issued through the Rockset client.
    """

    stores_text: bool = True
    is_embedding_query: bool = True
    flat_metadata: bool = False

    class DistanceFunc(Enum):
        """Names of the SQL functions used to score rows against the query
        embedding. The member value is interpolated into the query as the
        function name.
        """

        COSINE_SIM = "COSINE_SIM"
        EUCLIDEAN_DIST = "EUCLIDEAN_DIST"
        DOT_PRODUCT = "DOT_PRODUCT"

    def __init__(
        self,
        collection: str,
        client: Any | None = None,
        text_key: str = DEFAULT_TEXT_KEY,
        embedding_col: str = DEFAULT_EMBEDDING_KEY,
        metadata_col: str = "metadata",
        workspace: str = "commons",
        api_server: str | None = None,
        api_key: str | None = None,
        distance_func: DistanceFunc = DistanceFunc.COSINE_SIM,
    ) -> None:
        """Rockset Vector Store Data container.

        Args:
            collection (str): The name of the collection of vectors
            client (Optional[Any]): Rockset client object
            text_key (str): The key to the text of nodes
                (default: llama_index.vector_stores.utils.DEFAULT_TEXT_KEY)
            embedding_col (str): The DB column containing embeddings
                (default: llama_index.vector_stores.utils.DEFAULT_EMBEDDING_KEY))
            metadata_col (str): The DB column containing node metadata
                (default: "metadata")
            workspace (str): The workspace containing the collection of vectors
                (default: "commons")
            api_server (Optional[str]): The Rockset API server to use
            api_key (Optional[str]): The Rockset API key to use
            distance_func (RocksetVectorStore.DistanceFunc): The metric to measure
                vector relationship. The member's string value (e.g.
                "COSINE_SIM") is accepted as well.
                (default: RocksetVectorStore.DistanceFunc.COSINE_SIM)

        Raises:
            ImportError: If the rockset package is not installed.
            ValueError: If no valid client can be obtained, or if
                `distance_func` is not a valid DistanceFunc member or value.
        """
        self.rockset = _get_rockset()
        self.rs = _get_client(api_key, api_server, client)
        self.workspace = workspace
        self.collection = collection
        self.text_key = text_key
        self.embedding_col = embedding_col
        self.metadata_col = metadata_col
        # Enum lookup returns the member unchanged and also accepts its value.
        self.distance_func = self.DistanceFunc(distance_func)
        # A smaller Euclidean distance means a closer match, so it sorts
        # ascending; the other functions sort descending. The comparison goes
        # through the enum class rather than through another member.
        self.distance_order = (
            "ASC"
            if self.distance_func is self.DistanceFunc.EUCLIDEAN_DIST
            else "DESC"
        )

        try:
            self.rs.set_application("llama_index")
        except AttributeError:
            # set_application method does not exist.
            # rockset version < 2.1.0
            pass

    @property
    def client(self) -> Any:
        """The underlying Rockset client object."""
        return self.rs

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Stores vectors in the collection.

        Args:
            nodes (List[BaseNode]): List of nodes with embeddings

        Returns:
            Stored node IDs (List[str])
        """
        if not nodes:
            # Nothing to store; skip the request entirely.
            return []

        documents = [
            {
                self.embedding_col: node.get_embedding(),
                "_id": node.node_id,
                self.metadata_col: node_to_metadata_dict(
                    node, text_field=self.text_key
                ),
            }
            for node in nodes
        ]
        response = self.rs.Documents.add_documents(
            collection=self.collection,
            workspace=self.workspace,
            data=documents,
        )
        return [row["_id"] for row in response.data]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Deletes nodes stored in the collection by their ref_doc_id.

        The matching document IDs are looked up with a SQL query first and
        then removed in a single delete request.

        Args:
            ref_doc_id (str): The ref_doc_id of the document
                whose nodes are to be deleted
        """
        rows = self.rs.sql(
            f"""
                SELECT
                    _id
                FROM
                    "{self.workspace}"."{self.collection}" x
                WHERE
                    x.{self.metadata_col}.ref_doc_id=:ref_doc_id
            """,
            params={"ref_doc_id": ref_doc_id},
        ).results
        if not rows:
            # No node references this document; there is nothing to delete.
            return

        self.rs.Documents.delete_documents(
            collection=self.collection,
            workspace=self.workspace,
            data=[
                self.rockset.models.DeleteDocumentsRequestData(id=row["_id"])
                for row in rows
            ],
        )

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Gets nodes relevant to a query.

        Args:
            query (llama_index.vector_stores.types.VectorStoreQuery): The query
            similarity_col (Optional[str]): The column to select the score
                computed by the configured distance function as
                (default: "_similarity")

        Returns:
            query results (llama_index.vector_stores.types.VectorStoreQueryResult).
            `similarities` is None when the query carries no embedding.
        """
        similarity_col = kwargs.get("similarity_col", "_similarity")

        # Resolve the filter list once so the WHERE keyword, the filter clause
        # and the bound parameters always agree with each other.
        legacy_filters = query.filters.legacy_filters() if query.filters else []

        params = {}
        conditions = []

        if query.node_ids:
            # Bind node IDs as named parameters instead of pasting them into
            # the SQL text, so quotes inside an ID cannot alter the statement.
            id_terms = []
            for idx, node_id in enumerate(query.node_ids):
                param_name = f"node_id_{idx}"
                params[param_name] = node_id
                id_terms.append(f"_id=:{param_name}")
            conditions.append(f"({' OR '.join(id_terms)})")

        if legacy_filters:
            # Positional parameter names keep two filters on the same key
            # from overwriting each other in the params dict.
            filter_terms = []
            for idx, metadata_filter in enumerate(legacy_filters):
                param_name = f"filter_{idx}"
                params[param_name] = metadata_filter.value
                filter_terms.append(
                    f"x.{self.metadata_col}.{metadata_filter.key}=:{param_name}"
                )
            conditions.append(f"({' AND '.join(filter_terms)})")

        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""

        # The score column is only selected when there is an embedding to
        # compare against.
        similarity_select = (
            f""", {self.distance_func.value}(
                    {query.query_embedding},
                    {self.embedding_col}
                ) AS {similarity_col}"""
            if query.query_embedding
            else ""
        )

        res = self.rs.sql(
            f"""
                SELECT
                    _id,
                    {self.metadata_col}
                    {similarity_select}
                FROM
                    "{self.workspace}"."{self.collection}" x
                {where_clause}
                ORDER BY
                    {similarity_col} {self.distance_order}
                LIMIT
                    {query.similarity_top_k}
            """,
            params=params,
        )

        similarities: List[float] | None = [] if query.query_embedding else None
        nodes, ids = [], []
        for row in res.results:
            if similarities is not None:
                similarities.append(row[similarity_col])
            nodes.append(metadata_dict_to_node(row[self.metadata_col]))
            ids.append(row["_id"])

        return VectorStoreQueryResult(similarities=similarities, nodes=nodes, ids=ids)

    @classmethod
    def with_new_collection(
        cls: Type[T], dimensions: int | None = None, **rockset_vector_store_args: Any
    ) -> T:
        """Creates a new collection and returns its RocksetVectorStore.

        Blocks (polling every 0.1 seconds) until the new collection reports
        the status "READY".

        Args:
            dimensions (Optional[int]): The length of the vectors to enforce
                in the collection's ingest transformation. By default, the
                collection will do no vector enforcement.
            collection (str): The name of the collection to be created
            client (Optional[Any]): Rockset client object
            workspace (str): The workspace containing the collection to be
                created (default: "commons")
            text_key (str): The key to the text of nodes
                (default: llama_index.vector_stores.utils.DEFAULT_TEXT_KEY)
            embedding_col (str): The DB column containing embeddings
                (default: llama_index.vector_stores.utils.DEFAULT_EMBEDDING_KEY))
            metadata_col (str): The DB column containing node metadata
                (default: "metadata")
            api_server (Optional[str]): The Rockset API server to use
            api_key (Optional[str]): The Rockset API key to use
            distance_func (RocksetVectorStore.DistanceFunc): The metric to measure
                vector relationship
                (default: RocksetVectorStore.DistanceFunc.COSINE_SIM)

        Returns:
            A vector store of type `cls` bound to the newly created collection.
        """
        client = rockset_vector_store_args["client"] = _get_client(
            api_key=rockset_vector_store_args.get("api_key"),
            api_server=rockset_vector_store_args.get("api_server"),
            client=rockset_vector_store_args.get("client"),
        )

        # None-valued arguments are dropped before the store is constructed,
        # so an explicit None must resolve to the same defaults here.
        workspace = rockset_vector_store_args.get("workspace") or "commons"
        collection = rockset_vector_store_args.get("collection")
        # Same keyword the constructor uses, so the enforced column is the
        # column the store will actually write embeddings to.
        embedding_col = (
            rockset_vector_store_args.get("embedding_col") or DEFAULT_EMBEDDING_KEY
        )

        collection_args = {"workspace": workspace, "name": collection}
        if dimensions:
            collection_args[
                "field_mapping_query"
            ] = _get_rockset().model.field_mapping_query.FieldMappingQuery(
                sql=f"""
                    SELECT
                        *, VECTOR_ENFORCE(
                            {embedding_col},
                            {dimensions},
                            'float'
                        ) AS {embedding_col}
                    FROM
                        _input
                """
            )

        client.Collections.create_s3_collection(**collection_args)  # create collection
        while (
            # Poll in the workspace the collection was created in.
            client.Collections.get(collection=collection, workspace=workspace).data.status
            != "READY"
        ):  # wait until collection is ready
            sleep(0.1)
            # TODO: add async, non-blocking method collection creation

        # Drop None-valued args so the constructor defaults apply.
        return cls(
            **{
                name: value
                for name, value in rockset_vector_store_args.items()
                if value is not None
            }
        )
315

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.