# llama-index — legacy Weaviate vector store module.
"""Weaviate Vector store index.

An index that is built on top of an existing vector store.

"""

import logging
from typing import Any, Dict, List, Optional, cast
from uuid import uuid4

from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
from llama_index.legacy.schema import BaseNode
from llama_index.legacy.vector_stores.types import (
    BasePydanticVectorStore,
    MetadataFilters,
    VectorStoreQuery,
    VectorStoreQueryMode,
    VectorStoreQueryResult,
)
from llama_index.legacy.vector_stores.utils import DEFAULT_TEXT_KEY
from llama_index.legacy.vector_stores.weaviate_utils import (
    add_node,
    class_schema_exists,
    create_default_schema,
    get_all_properties,
    get_node_similarity,
    parse_get_response,
    to_node,
)

logger = logging.getLogger(__name__)

# Error message raised when the optional `weaviate` dependency is missing.
import_err_msg = (
    "`weaviate` package not found, please run `pip install weaviate-client`"
)
38def _transform_weaviate_filter_condition(condition: str) -> str:39"""Translate standard metadata filter op to Chroma specific spec."""40if condition == "and":41return "And"42elif condition == "or":43return "Or"44else:45raise ValueError(f"Filter condition {condition} not supported")46
47
48def _transform_weaviate_filter_operator(operator: str) -> str:49"""Translate standard metadata filter operator to Chroma specific spec."""50if operator == "!=":51return "NotEqual"52elif operator == "==":53return "Equal"54elif operator == ">":55return "GreaterThan"56elif operator == "<":57return "LessThan"58elif operator == ">=":59return "GreaterThanEqual"60elif operator == "<=":61return "LessThanEqual"62else:63raise ValueError(f"Filter operator {operator} not supported")64
65
def _to_weaviate_filter(standard_filters: MetadataFilters) -> Dict[str, Any]:
    """Convert standard LlamaIndex metadata filters to a Weaviate `where` dict.

    Args:
        standard_filters: Standard metadata filters (key/operator/value plus
            an optional ``and``/``or`` condition).

    Returns:
        A Weaviate GraphQL ``where`` filter dict. A single filter is returned
        directly; multiple filters are wrapped in an ``operands`` clause.
        Returns ``{}`` when there are no filters.

    Raises:
        ValueError: If the condition or an operator is not supported.
    """
    condition = _transform_weaviate_filter_condition(
        standard_filters.condition or "and"
    )

    if not standard_filters.filters:
        return {}

    operands: List[Dict[str, Any]] = []
    for f in standard_filters.filters:
        # Work on a local copy of the value: the original code reassigned
        # `filter.value` in place, mutating the caller's filter objects.
        value = f.value
        value_type = "valueText"
        if isinstance(value, (int, float)):
            value_type = "valueNumber"
        elif isinstance(value, str) and value.isnumeric():
            # Numeric strings are coerced to numbers for Weaviate.
            # NOTE(review): str.isnumeric() misses decimals ("1.5") and
            # negatives ("-3"), which stay valueText — preserved as-is.
            value = float(value)
            value_type = "valueNumber"
        operands.append(
            {
                "path": f.key,
                "operator": _transform_weaviate_filter_operator(f.operator),
                value_type: value,
            }
        )

    if len(operands) == 1:
        # If there is only one filter, return it directly
        return operands[0]

    return {"operands": operands, "operator": condition}
97
class WeaviateVectorStore(BasePydanticVectorStore):
    """Weaviate vector store.

    In this vector store, embeddings and docs are stored within a
    Weaviate collection.

    During query time, the index uses Weaviate to query for the top
    k most similar nodes.

    Args:
        weaviate_client (weaviate.Client): WeaviateClient
            instance from `weaviate-client` package
        index_name (Optional[str]): name for Weaviate classes

    """

    # Node text is stored in Weaviate itself, so queries can return full text.
    stores_text: bool = True

    index_name: str
    url: Optional[str]
    text_key: str
    auth_config: Dict[str, Any] = Field(default_factory=dict)
    client_kwargs: Dict[str, Any] = Field(default_factory=dict)

    # Underlying weaviate.Client; private so pydantic does not serialize it.
    _client = PrivateAttr()

    def __init__(
        self,
        weaviate_client: Optional[Any] = None,
        class_prefix: Optional[str] = None,
        index_name: Optional[str] = None,
        text_key: str = DEFAULT_TEXT_KEY,
        auth_config: Optional[Any] = None,
        client_kwargs: Optional[Dict[str, Any]] = None,
        url: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        # Import lazily so the package is only required when actually used.
        try:
            import weaviate  # noqa
            from weaviate import AuthApiKey, Client
        except ImportError:
            raise ImportError(import_err_msg)

        if weaviate_client is None:
            # No client supplied: build one from url / auth_config.
            if isinstance(auth_config, dict):
                auth_config = AuthApiKey(**auth_config)

            client_kwargs = client_kwargs or {}
            self._client = Client(
                url=url, auth_client_secret=auth_config, **client_kwargs
            )
        else:
            self._client = cast(Client, weaviate_client)

        # class_prefix is the deprecated way of naming the Weaviate class.
        if class_prefix is not None:
            logger.warning("class_prefix is deprecated, please use index_name")
            # legacy, kept for backward compatibility
            index_name = f"{class_prefix}_Node"

        # Weaviate class names must start with a capital letter.
        index_name = index_name or f"LlamaIndex_{uuid4().hex}"
        if not index_name[0].isupper():
            raise ValueError(
                "Index name must start with a capital letter, e.g. 'LlamaIndex'"
            )

        # create default schema if does not exist
        if not class_schema_exists(self._client, index_name):
            create_default_schema(self._client, index_name)

        super().__init__(
            url=url,
            index_name=index_name,
            text_key=text_key,
            # Store a plain dict so the pydantic field stays serializable.
            auth_config=auth_config.__dict__ if auth_config else {},
            client_kwargs=client_kwargs or {},
        )

    @classmethod
    def from_params(
        cls,
        url: str,
        auth_config: Any,
        index_name: Optional[str] = None,
        text_key: str = DEFAULT_TEXT_KEY,
        client_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> "WeaviateVectorStore":
        """Create WeaviateVectorStore from config.

        Alternate constructor that builds the `weaviate.Client` itself from
        `url` and `auth_config` before delegating to `__init__`.
        """
        try:
            import weaviate  # noqa
            from weaviate import AuthApiKey, Client  # noqa
        except ImportError:
            raise ImportError(import_err_msg)

        client_kwargs = client_kwargs or {}
        weaviate_client = Client(
            url=url, auth_client_secret=auth_config, **client_kwargs
        )
        return cls(
            weaviate_client=weaviate_client,
            url=url,
            auth_config=auth_config.__dict__,
            client_kwargs=client_kwargs,
            index_name=index_name,
            text_key=text_key,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization class name for this vector store."""
        return "WeaviateVectorStore"

    @property
    def client(self) -> Any:
        """Get client."""
        return self._client

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            List of the node ids that were added.
        """
        ids = [r.node_id for r in nodes]

        # Use the client's batch context manager so inserts are flushed
        # efficiently (and on exit) rather than one request per node.
        with self._client.batch as batch:
            for node in nodes:
                add_node(
                    self._client,
                    node,
                    self.index_name,
                    batch=batch,
                    text_key=self.text_key,
                )
        return ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        # Match every node whose `ref_doc_id` property equals the given id.
        where_filter = {
            "path": ["ref_doc_id"],
            "operator": "Equal",
            "valueText": ref_doc_id,
        }
        # Callers may pass an extra Weaviate filter to further narrow deletion.
        if "filter" in delete_kwargs and delete_kwargs["filter"] is not None:
            where_filter = {
                "operator": "And",
                "operands": [where_filter, delete_kwargs["filter"]],  # type: ignore
            }

        query = (
            self._client.query.get(self.index_name)
            .with_additional(["id"])
            .with_where(where_filter)
            .with_limit(10000)  # 10,000 is the max weaviate can fetch
        )

        # Fetch matching object ids first, then delete them one by one.
        # NOTE(review): documents with more than 10,000 nodes would need
        # repeated passes — confirm whether that case can occur here.
        query_result = query.do()
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[self.index_name]
        for entry in entries:
            self._client.data_object.delete(entry["_additional"]["id"], self.index_name)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Supports pure vector search (DEFAULT mode) and hybrid search
        (HYBRID mode); results are converted back into LlamaIndex nodes.
        """
        all_properties = get_all_properties(self._client, self.index_name)

        # build query
        query_builder = self._client.query.get(self.index_name, all_properties)

        # list of documents to constrain search
        if query.doc_ids:
            filter_with_doc_ids = {
                "operator": "Or",
                "operands": [
                    {"path": ["doc_id"], "operator": "Equal", "valueText": doc_id}
                    for doc_id in query.doc_ids
                ],
            }
            query_builder = query_builder.with_where(filter_with_doc_ids)

        if query.node_ids:
            filter_with_node_ids = {
                "operator": "Or",
                "operands": [
                    {"path": ["id"], "operator": "Equal", "valueText": node_id}
                    for node_id in query.node_ids
                ],
            }
            query_builder = query_builder.with_where(filter_with_node_ids)

        query_builder = query_builder.with_additional(
            ["id", "vector", "distance", "score"]
        )

        # DEFAULT mode ranks by vector distance; HYBRID ranks by score.
        vector = query.query_embedding
        similarity_key = "distance"
        if query.mode == VectorStoreQueryMode.DEFAULT:
            logger.debug("Using vector search")
            if vector is not None:
                query_builder = query_builder.with_near_vector(
                    {
                        "vector": vector,
                    }
                )
        elif query.mode == VectorStoreQueryMode.HYBRID:
            logger.debug(f"Using hybrid search with alpha {query.alpha}")
            similarity_key = "score"
            if vector is not None and query.query_str:
                query_builder = query_builder.with_hybrid(
                    query=query.query_str,
                    alpha=query.alpha,
                    vector=vector,
                )

        # NOTE(review): each with_where call appears to replace any earlier
        # where clause, so metadata filters here may override the doc_ids /
        # node_ids constraints set above — confirm against client behavior.
        if query.filters is not None:
            filter = _to_weaviate_filter(query.filters)
            query_builder = query_builder.with_where(filter)
        elif "filter" in kwargs and kwargs["filter"] is not None:
            query_builder = query_builder.with_where(kwargs["filter"])

        query_builder = query_builder.with_limit(query.similarity_top_k)
        logger.debug(f"Using limit of {query.similarity_top_k}")

        # execute query
        query_result = query_builder.do()

        # parse results
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[self.index_name]

        similarities = []
        nodes: List[BaseNode] = []
        node_ids = []

        # Truncate defensively to similarity_top_k even if Weaviate
        # returns more entries than requested.
        for i, entry in enumerate(entries):
            if i < query.similarity_top_k:
                similarities.append(get_node_similarity(entry, similarity_key))
                nodes.append(to_node(entry, text_key=self.text_key))
                node_ids.append(nodes[-1].node_id)
            else:
                break

        return VectorStoreQueryResult(
            nodes=nodes, ids=node_ids, similarities=similarities
        )