llama-index
314 строк · 11.5 Кб
1from __future__ import annotations2
3from enum import Enum4from os import getenv5from time import sleep6from types import ModuleType7from typing import Any, List, Type, TypeVar8
9from llama_index.legacy.schema import BaseNode10from llama_index.legacy.vector_stores.types import (11VectorStore,12VectorStoreQuery,13VectorStoreQueryResult,14)
15from llama_index.legacy.vector_stores.utils import (16DEFAULT_EMBEDDING_KEY,17DEFAULT_TEXT_KEY,18metadata_dict_to_node,19node_to_metadata_dict,20)
21
# Type variable bound to RocksetVectorStore; it is referenced by the
# `cls: Type[T]` annotation of the alternate constructor
# `RocksetVectorStore.with_new_collection`.
T = TypeVar("T", bound="RocksetVectorStore")
24
def _get_rockset() -> ModuleType:
    """Import and return the ``rockset`` module.

    The import is done lazily so that the ``rockset`` package is only
    required when this vector store is actually used.

    Returns:
        rockset module (ModuleType)

    Raises:
        ImportError: If the ``rockset`` package is not installed. The
            original import failure is attached as ``__cause__``.
    """
    try:
        import rockset
    except ImportError as err:
        # Chain the original error so the real reason for the failed import
        # (missing package, broken dependency, ...) stays in the traceback.
        raise ImportError(
            "Please install rockset with `pip install rockset`"
        ) from err
    return rockset
38
def _get_client(api_key: str | None, api_server: str | None, client: Any | None) -> Any:
    """Return the passed-in client if it is valid, else construct one.

    Args:
        api_key (Optional[str]): Rockset API key. Falls back to the
            ``ROCKSET_API_KEY`` environment variable when not given.
        api_server (Optional[str]): Rockset API server. Falls back to the
            ``ROCKSET_API_SERVER`` environment variable when not given.
        client (Optional[Any]): An existing client; when given it takes
            precedence and ``api_key`` / ``api_server`` are ignored.

    Returns:
        The rockset client object (rockset.RocksetClient)

    Raises:
        ImportError: If the ``rockset`` package is not installed.
        ValueError: If ``client`` is not a ``rockset.RocksetClient``, or if
            no client was given and no API key could be found.
    """
    rockset = _get_rockset()

    if client is not None:
        # isinstance (rather than an exact type comparison) so that
        # subclasses of RocksetClient are accepted as well.
        if not isinstance(client, rockset.RocksetClient):
            raise ValueError(
                "Parameter `client` must be of type rockset.RocksetClient"
            )
        return client

    resolved_key = api_key or getenv("ROCKSET_API_KEY")
    if not resolved_key:
        raise ValueError(
            "Parameter `client`, `api_key` or env var `ROCKSET_API_KEY` must be set"
        )

    return rockset.RocksetClient(
        api_key=resolved_key,
        host=api_server or getenv("ROCKSET_API_SERVER"),
    )
61
class RocksetVectorStore(VectorStore):
    """Vector store backed by a Rockset collection.

    Each node is stored as one document holding its embedding (in
    ``embedding_col``) and its metadata dict (in ``metadata_col``); the
    document ``_id`` is the node ID. Similarity search is expressed as a SQL
    query that calls the configured distance function.
    """

    # Capability flags of this store.
    # NOTE(review): presumably consumed by the VectorStore interface this
    # class derives from; they are not read anywhere in this class.
    stores_text: bool = True
    is_embedding_query: bool = True
    flat_metadata: bool = False

    class DistanceFunc(Enum):
        """Supported distance metrics.

        The member value is interpolated into the query as the name of the
        SQL function that scores a document against the query embedding.
        """

        COSINE_SIM = "COSINE_SIM"
        EUCLIDEAN_DIST = "EUCLIDEAN_DIST"
        DOT_PRODUCT = "DOT_PRODUCT"

    def __init__(
        self,
        collection: str,
        client: Any | None = None,
        text_key: str = DEFAULT_TEXT_KEY,
        embedding_col: str = DEFAULT_EMBEDDING_KEY,
        metadata_col: str = "metadata",
        workspace: str = "commons",
        api_server: str | None = None,
        api_key: str | None = None,
        distance_func: DistanceFunc = DistanceFunc.COSINE_SIM,
    ) -> None:
        """Rockset Vector Store Data container.

        Args:
            collection (str): The name of the collection of vectors
            client (Optional[Any]): Rockset client object
            text_key (str): The key to the text of nodes
                (default: llama_index.vector_stores.utils.DEFAULT_TEXT_KEY)
            embedding_col (str): The DB column containing embeddings
                (default: llama_index.vector_stores.utils.DEFAULT_EMBEDDING_KEY)
            metadata_col (str): The DB column containing node metadata
                (default: "metadata")
            workspace (str): The workspace containing the collection of vectors
                (default: "commons")
            api_server (Optional[str]): The Rockset API server to use
            api_key (Optional[str]): The Rockset API key to use
            distance_func (RocksetVectorStore.DistanceFunc): The metric to measure
                vector relationship; the enum's string value is accepted too
                (default: RocksetVectorStore.DistanceFunc.COSINE_SIM)

        Raises:
            ImportError: If the ``rockset`` package is not installed.
            ValueError: If no valid client / API key is available, or if
                ``distance_func`` is not a valid distance function.
        """
        self.rockset = _get_rockset()
        self.rs = _get_client(api_key, api_server, client)
        self.workspace = workspace
        self.collection = collection
        self.text_key = text_key
        self.embedding_col = embedding_col
        self.metadata_col = metadata_col
        # Enum lookup by value returns the member itself when given a
        # member, so this accepts both DistanceFunc.X and its string value.
        self.distance_func = self.DistanceFunc(distance_func)
        # A distance is best when smallest; a similarity is best when largest.
        self.distance_order = (
            "ASC"
            if self.distance_func is self.DistanceFunc.EUCLIDEAN_DIST
            else "DESC"
        )

        try:
            self.rs.set_application("llama_index")
        except AttributeError:
            # set_application method does not exist.
            # rockset version < 2.1.0
            pass

    @property
    def client(self) -> Any:
        """The underlying Rockset client object."""
        return self.rs

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Stores vectors in the collection.

        Args:
            nodes (List[BaseNode]): List of nodes with embeddings

        Returns:
            Stored node IDs (List[str])
        """
        if not nodes:
            # Nothing to store; avoid sending an empty write request.
            return []

        documents = [
            {
                self.embedding_col: node.get_embedding(),
                "_id": node.node_id,
                self.metadata_col: node_to_metadata_dict(
                    node, text_field=self.text_key
                ),
            }
            for node in nodes
        ]
        response = self.rs.Documents.add_documents(
            collection=self.collection,
            workspace=self.workspace,
            data=documents,
        )
        return [row["_id"] for row in response.data]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Deletes nodes stored in the collection by their ref_doc_id.

        The matching document IDs are looked up with a SQL query first and
        are then removed with a single delete request.

        Args:
            ref_doc_id (str): The ref_doc_id of the document
                whose nodes are to be deleted
        """
        rows = self.rs.sql(
            f"""
            SELECT
                _id
            FROM
                "{self.workspace}"."{self.collection}" x
            WHERE
                x.{self.metadata_col}.ref_doc_id=:ref_doc_id
            """,
            params={"ref_doc_id": ref_doc_id},
        ).results
        if not rows:
            # No node belongs to this document; nothing to delete.
            return

        self.rs.Documents.delete_documents(
            collection=self.collection,
            workspace=self.workspace,
            data=[
                self.rockset.models.DeleteDocumentsRequestData(id=row["_id"])
                for row in rows
            ],
        )

    def _build_query(
        self, query: VectorStoreQuery, similarity_col: str
    ) -> tuple[str, dict[str, Any]]:
        """Builds the SQL text and bind parameters for ``query``."""
        has_embedding = bool(query.query_embedding)
        params = {}  # bind-parameter name -> value
        conditions: List[str] = []

        select_cols = ["_id", self.metadata_col]
        if has_embedding:
            select_cols.append(
                f"{self.distance_func.value}("
                f"{query.query_embedding}, {self.embedding_col}"
                f") AS {similarity_col}"
            )

        if query.node_ids:
            # Node IDs are bound as parameters instead of being pasted into
            # the SQL text, so quotes inside an ID cannot alter the query.
            id_clauses = []
            for index, node_id in enumerate(query.node_ids):
                param_name = f"node_id_{index}"
                params[param_name] = node_id
                id_clauses.append(f"_id=:{param_name}")
            conditions.append(f"({' OR '.join(id_clauses)})")

        legacy_filters = query.filters.legacy_filters() if query.filters else []
        if legacy_filters:
            # Generated parameter names keep the query valid for filter keys
            # that are not legal parameter names, and keep two filters on the
            # same key from overwriting each other's value.
            filter_clauses = []
            for index, metadata_filter in enumerate(legacy_filters):
                param_name = f"filter_{index}"
                params[param_name] = metadata_filter.value
                filter_clauses.append(
                    f"x.{self.metadata_col}.{metadata_filter.key}=:{param_name}"
                )
            conditions.append(f"({' AND '.join(filter_clauses)})")

        sql_lines = [
            "SELECT",
            "    " + ", ".join(select_cols),
            "FROM",
            f'    "{self.workspace}"."{self.collection}" x',
        ]
        if conditions:
            sql_lines += ["WHERE", "    " + " AND ".join(conditions)]
        if has_embedding:
            # The similarity column only exists when an embedding was given.
            sql_lines += ["ORDER BY", f"    {similarity_col} {self.distance_order}"]
        sql_lines += ["LIMIT", f"    {query.similarity_top_k}"]

        return "\n".join(sql_lines), params

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Gets nodes relevant to a query.

        Results are restricted to ``query.node_ids`` and to the metadata
        filters in ``query.filters`` when those are given. When
        ``query.query_embedding`` is set, rows are scored with the configured
        distance function and ordered best-first; otherwise no similarities
        are returned.

        Args:
            query (llama_index.vector_stores.types.VectorStoreQuery): The query
            similarity_col (Optional[str]): The column to select the
                similarity as (default: "_similarity")

        Returns:
            query results (llama_index.vector_stores.types.VectorStoreQueryResult)
        """
        similarity_col = kwargs.get("similarity_col", "_similarity")
        sql, params = self._build_query(query, similarity_col)
        res = self.rs.sql(sql, params=params)

        similarities: List[float] | None = [] if query.query_embedding else None
        nodes, ids = [], []
        for row in res.results:
            if similarities is not None:
                similarities.append(row[similarity_col])
            nodes.append(metadata_dict_to_node(row[self.metadata_col]))
            ids.append(row["_id"])

        return VectorStoreQueryResult(similarities=similarities, nodes=nodes, ids=ids)

    @classmethod
    def with_new_collection(
        cls: Type[T], dimensions: int | None = None, **rockset_vector_store_args: Any
    ) -> T:
        """Creates a new collection and returns its RocksetVectorStore.

        This call blocks: after requesting the collection it polls its status
        every 0.1 seconds until it is "READY". There is no timeout.

        Args:
            dimensions (Optional[int]): The length of the vectors to enforce
                in the collection's ingest transformation. By default, the
                collection will do no vector enforcement.
            collection (str): The name of the collection to be created
            client (Optional[Any]): Rockset client object
            workspace (str): The workspace containing the collection to be
                created (default: "commons")
            text_key (str): The key to the text of nodes
                (default: llama_index.vector_stores.utils.DEFAULT_TEXT_KEY)
            embedding_col (str): The DB column containing embeddings
                (default: llama_index.vector_stores.utils.DEFAULT_EMBEDDING_KEY)
            metadata_col (str): The DB column containing node metadata
                (default: "metadata")
            api_server (Optional[str]): The Rockset API server to use
            api_key (Optional[str]): The Rockset API key to use
            distance_func (RocksetVectorStore.DistanceFunc): The metric to measure
                vector relationship
                (default: RocksetVectorStore.DistanceFunc.COSINE_SIM)

        Raises:
            ValueError: If ``collection`` is missing, or if no valid client /
                API key is available.
        """
        collection = rockset_vector_store_args.get("collection")
        if not collection:
            # Fail before any remote resource is touched.
            raise ValueError("Parameter `collection` must be set")

        client = rockset_vector_store_args["client"] = _get_client(
            api_key=rockset_vector_store_args.get("api_key"),
            api_server=rockset_vector_store_args.get("api_server"),
            client=rockset_vector_store_args.get("client"),
        )
        workspace = rockset_vector_store_args.get("workspace") or "commons"
        # Same keyword as the constructor, so the ingest transformation
        # enforces the column the returned store will actually use.
        embedding_col = (
            rockset_vector_store_args.get("embedding_col") or DEFAULT_EMBEDDING_KEY
        )

        collection_args = {"workspace": workspace, "name": collection}
        if dimensions:
            collection_args[
                "field_mapping_query"
            ] = _get_rockset().model.field_mapping_query.FieldMappingQuery(
                sql=f"""
                    SELECT
                        *, VECTOR_ENFORCE(
                            {embedding_col},
                            {dimensions},
                            'float'
                        ) AS {embedding_col}
                    FROM
                        _input
                """
            )

        client.Collections.create_s3_collection(**collection_args)  # create collection
        # Wait until the collection is ready. The workspace is passed
        # explicitly so the poll looks at the collection that was just
        # created rather than at the client's default workspace.
        while (
            client.Collections.get(collection=collection, workspace=workspace).data.status
            != "READY"
        ):
            sleep(0.1)
        # TODO: add async, non-blocking method collection creation

        # Drop None-valued args so the constructor defaults apply to them.
        return cls(
            **{
                name: value
                for name, value in rockset_vector_store_args.items()
                if value is not None
            }
        )