# llama-index (598 lines · 20.3 KB)
1"""Elasticsearch vector store."""
2
3import asyncio4import uuid5from logging import getLogger6from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast7
8import nest_asyncio9import numpy as np10
11from llama_index.legacy.bridge.pydantic import PrivateAttr12from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode13from llama_index.legacy.vector_stores.types import (14BasePydanticVectorStore,15MetadataFilters,16VectorStoreQuery,17VectorStoreQueryMode,18VectorStoreQueryResult,19)
20from llama_index.legacy.vector_stores.utils import (21metadata_dict_to_node,22node_to_metadata_dict,23)
24
# Module-level logger; all connection/indexing diagnostics go through it.
logger = getLogger(__name__)

# Allowed names for the vector-similarity metric applied when the index is
# created (mapped to Elasticsearch's "cosine" / "dot_product" / "l2_norm").
DISTANCE_STRATEGIES = Literal[
    "COSINE",
    "DOT_PRODUCT",
    "EUCLIDEAN_DISTANCE",
]
34def _get_elasticsearch_client(35*,36es_url: Optional[str] = None,37cloud_id: Optional[str] = None,38api_key: Optional[str] = None,39username: Optional[str] = None,40password: Optional[str] = None,41) -> Any:42"""Get AsyncElasticsearch client.43
44Args:
45es_url: Elasticsearch URL.
46cloud_id: Elasticsearch cloud ID.
47api_key: Elasticsearch API key.
48username: Elasticsearch username.
49password: Elasticsearch password.
50
51Returns:
52AsyncElasticsearch client.
53
54Raises:
55ConnectionError: If Elasticsearch client cannot connect to Elasticsearch.
56"""
57try:58import elasticsearch59except ImportError:60raise ImportError(61"Could not import elasticsearch python package. "62"Please install it with `pip install elasticsearch`."63)64
65if es_url and cloud_id:66raise ValueError(67"Both es_url and cloud_id are defined. Please provide only one."68)69
70connection_params: Dict[str, Any] = {}71
72if es_url:73connection_params["hosts"] = [es_url]74elif cloud_id:75connection_params["cloud_id"] = cloud_id76else:77raise ValueError("Please provide either elasticsearch_url or cloud_id.")78
79if api_key:80connection_params["api_key"] = api_key81elif username and password:82connection_params["basic_auth"] = (username, password)83
84sync_es_client = elasticsearch.Elasticsearch(85**connection_params, headers={"user-agent": ElasticsearchStore.get_user_agent()}86)87async_es_client = elasticsearch.AsyncElasticsearch(**connection_params)88try:89sync_es_client.info() # so don't have to 'await' to just get info90except Exception as e:91logger.error(f"Error connecting to Elasticsearch: {e}")92raise93
94return async_es_client95
96
97def _to_elasticsearch_filter(standard_filters: MetadataFilters) -> Dict[str, Any]:98"""Convert standard filters to Elasticsearch filter.99
100Args:
101standard_filters: Standard Llama-index filters.
102
103Returns:
104Elasticsearch filter.
105"""
106if len(standard_filters.legacy_filters()) == 1:107filter = standard_filters.legacy_filters()[0]108return {109"term": {110f"metadata.{filter.key}.keyword": {111"value": filter.value,112}113}114}115else:116operands = []117for filter in standard_filters.legacy_filters():118operands.append(119{120"term": {121f"metadata.{filter.key}.keyword": {122"value": filter.value,123}124}125}126)127return {"bool": {"must": operands}}128
129
130def _to_llama_similarities(scores: List[float]) -> List[float]:131if scores is None or len(scores) == 0:132return []133
134scores_to_norm: np.ndarray = np.array(scores)135return np.exp(scores_to_norm - np.max(scores_to_norm)).tolist()136
137
class ElasticsearchStore(BasePydanticVectorStore):
    """Elasticsearch vector store.

    Args:
        index_name: Name of the Elasticsearch index.
        es_client: Optional. Pre-existing AsyncElasticsearch client.
        es_url: Optional. Elasticsearch URL.
        es_cloud_id: Optional. Elasticsearch cloud ID.
        es_api_key: Optional. Elasticsearch API key.
        es_user: Optional. Elasticsearch username.
        es_password: Optional. Elasticsearch password.
        text_field: Optional. Name of the Elasticsearch field that stores the text.
        vector_field: Optional. Name of the Elasticsearch field that stores the
            embedding.
        batch_size: Optional. Batch size for bulk indexing. Defaults to 200.
        distance_strategy: Optional. Distance strategy to use for similarity search.
            Defaults to "COSINE".

    Raises:
        ConnectionError: If AsyncElasticsearch client cannot connect to
            Elasticsearch.
        ValueError: If neither es_client nor es_url nor es_cloud_id is provided.
    """

    stores_text: bool = True
    index_name: str
    es_client: Optional[Any]
    es_url: Optional[str]
    es_cloud_id: Optional[str]
    es_api_key: Optional[str]
    es_user: Optional[str]
    es_password: Optional[str]
    text_field: str = "content"
    vector_field: str = "embedding"
    batch_size: int = 200
    distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE"

    # Live AsyncElasticsearch connection; kept out of the pydantic schema.
    _client = PrivateAttr()

    def __init__(
        self,
        index_name: str,
        es_client: Optional[Any] = None,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_user: Optional[str] = None,
        es_password: Optional[str] = None,
        text_field: str = "content",
        vector_field: str = "embedding",
        batch_size: int = 200,
        distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE",
    ) -> None:
        # The sync wrappers (add/delete/query) run coroutines on the current
        # loop; nest_asyncio allows that even inside an already-running loop
        # (e.g. notebooks).
        nest_asyncio.apply()

        if es_client is not None:
            self._client = es_client.options(
                headers={"user-agent": self.get_user_agent()}
            )
        elif es_url is not None or es_cloud_id is not None:
            self._client = _get_elasticsearch_client(
                es_url=es_url,
                username=es_user,
                password=es_password,
                cloud_id=es_cloud_id,
                api_key=es_api_key,
            )
        else:
            raise ValueError(
                """Either provide a pre-existing AsyncElasticsearch or valid \
credentials for creating a new connection."""
            )
        super().__init__(
            index_name=index_name,
            es_client=es_client,
            es_url=es_url,
            es_cloud_id=es_cloud_id,
            es_api_key=es_api_key,
            es_user=es_user,
            es_password=es_password,
            text_field=text_field,
            vector_field=vector_field,
            batch_size=batch_size,
            distance_strategy=distance_strategy,
        )

    @property
    def client(self) -> Any:
        """Get async elasticsearch client."""
        return self._client

    @staticmethod
    def get_user_agent() -> str:
        """Get user agent for elasticsearch client."""
        import llama_index.legacy

        return f"llama_index-py-vs/{llama_index.legacy.__version__}"

    async def _create_index_if_not_exists(
        self, index_name: str, dims_length: Optional[int] = None
    ) -> None:
        """Create the AsyncElasticsearch index if it doesn't already exist.

        Args:
            index_name: Name of the AsyncElasticsearch index to create.
            dims_length: Length of the embedding vectors.

        Raises:
            ValueError: If the index must be created but ``dims_length`` is
                missing, or if ``distance_strategy`` is unrecognized.
        """
        # BUG FIX: `indices.exists` on the async client returns a coroutine;
        # without `await` it is always truthy, so the index was never created.
        if await self.client.indices.exists(index=index_name):
            logger.debug(f"Index {index_name} already exists. Skipping creation.")
        else:
            if dims_length is None:
                raise ValueError(
                    "Cannot create index without specifying dims_length "
                    "when the index doesn't already exist. We infer "
                    "dims_length from the first embedding. Check that "
                    "you have provided an embedding function."
                )

            # Map our strategy names onto Elasticsearch similarity identifiers.
            if self.distance_strategy == "COSINE":
                similarityAlgo = "cosine"
            elif self.distance_strategy == "EUCLIDEAN_DISTANCE":
                similarityAlgo = "l2_norm"
            elif self.distance_strategy == "DOT_PRODUCT":
                similarityAlgo = "dot_product"
            else:
                raise ValueError(f"Similarity {self.distance_strategy} not supported.")

            index_settings = {
                "mappings": {
                    "properties": {
                        self.vector_field: {
                            "type": "dense_vector",
                            "dims": dims_length,
                            "index": True,
                            "similarity": similarityAlgo,
                        },
                        self.text_field: {"type": "text"},
                        # Keyword sub-fields used by metadata filtering/deletes.
                        "metadata": {
                            "properties": {
                                "document_id": {"type": "keyword"},
                                "doc_id": {"type": "keyword"},
                                "ref_doc_id": {"type": "keyword"},
                            }
                        },
                    }
                }
            }

            logger.debug(
                f"Creating index {index_name} with mappings {index_settings['mappings']}"
            )
            await self.client.indices.create(index=index_name, **index_settings)

    def add(
        self,
        nodes: List[BaseNode],
        *,
        create_index_if_not_exists: bool = True,
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to Elasticsearch index.

        Synchronous wrapper around :meth:`async_add`.

        Args:
            nodes: List of nodes with embeddings.
            create_index_if_not_exists: Optional. Whether to create
                the Elasticsearch index if it doesn't already exist.
                Defaults to True.

        Returns:
            List of node IDs that were added to the index.

        Raises:
            ImportError: If elasticsearch['async'] python package is not installed.
            BulkIndexError: If AsyncElasticsearch async_bulk indexing fails.
        """
        return asyncio.get_event_loop().run_until_complete(
            self.async_add(nodes, create_index_if_not_exists=create_index_if_not_exists)
        )

    async def async_add(
        self,
        nodes: List[BaseNode],
        *,
        create_index_if_not_exists: bool = True,
        **add_kwargs: Any,
    ) -> List[str]:
        """Asynchronous method to add nodes to Elasticsearch index.

        Args:
            nodes: List of nodes with embeddings.
            create_index_if_not_exists: Optional. Whether to create
                the AsyncElasticsearch index if it doesn't already exist.
                Defaults to True.

        Returns:
            List of node IDs that were added to the index.

        Raises:
            ImportError: If elasticsearch python package is not installed.
            BulkIndexError: If AsyncElasticsearch async_bulk indexing fails.
        """
        try:
            from elasticsearch.helpers import BulkIndexError, async_bulk
        except ImportError:
            raise ImportError(
                "Could not import elasticsearch[async] python package. "
                "Please install it with `pip install 'elasticsearch[async]'`."
            )

        if len(nodes) == 0:
            return []

        if create_index_if_not_exists:
            # Infer the vector dimensionality from the first node's embedding.
            dims_length = len(nodes[0].get_embedding())
            await self._create_index_if_not_exists(
                index_name=self.index_name, dims_length=dims_length
            )

        embeddings: List[List[float]] = []
        texts: List[str] = []
        metadatas: List[dict] = []
        ids: List[str] = []
        for node in nodes:
            ids.append(node.node_id)
            embeddings.append(node.get_embedding())
            # Store the bare text; node metadata lives in its own field.
            texts.append(node.get_content(metadata_mode=MetadataMode.NONE))
            metadatas.append(node_to_metadata_dict(node, remove_text=True))

        requests = []
        return_ids = []

        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            _id = ids[i] if ids else str(uuid.uuid4())
            request = {
                "_op_type": "index",
                "_index": self.index_name,
                self.vector_field: embeddings[i],
                self.text_field: text,
                "metadata": metadata,
                "_id": _id,
            }
            requests.append(request)
            return_ids.append(_id)

        # BUG FIX: the original issued two async_bulk calls with the same
        # requests (one plain, one stats_only), sending every document twice.
        # A single call both indexes and reports stats.
        try:
            success, failed = await async_bulk(
                self.client,
                requests,
                chunk_size=self.batch_size,
                stats_only=True,
                refresh=True,
            )
            logger.debug(f"Added {success} and failed to add {failed} texts to index")
            logger.debug(f"added texts {ids} to index")
            return return_ids
        except BulkIndexError as e:
            logger.error(f"Error adding texts: {e}")
            firstError = e.errors[0].get("index", {}).get("error", {})
            logger.error(f"First error reason: {firstError.get('reason')}")
            raise

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete node from Elasticsearch index.

        Synchronous wrapper around :meth:`adelete`.

        Args:
            ref_doc_id: ID of the node to delete.
            delete_kwargs: Optional. Additional arguments to
                pass to Elasticsearch delete_by_query.

        Raises:
            Exception: If Elasticsearch delete_by_query fails.
        """
        return asyncio.get_event_loop().run_until_complete(
            self.adelete(ref_doc_id, **delete_kwargs)
        )

    async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Async delete node from Elasticsearch index.

        Args:
            ref_doc_id: ID of the node to delete.
            delete_kwargs: Optional. Additional arguments to
                pass to AsyncElasticsearch delete_by_query.

        Raises:
            Exception: If AsyncElasticsearch delete_by_query fails.
        """
        try:
            # NOTE(review): `async with self.client` closes the client on exit,
            # making the store single-use; kept as-is to preserve existing
            # behavior — confirm this is intended.
            async with self.client as client:
                res = await client.delete_by_query(
                    index=self.index_name,
                    query={"term": {"metadata.ref_doc_id": ref_doc_id}},
                    refresh=True,
                    **delete_kwargs,
                )
            if res["deleted"] == 0:
                logger.warning(f"Could not find text {ref_doc_id} to delete")
            else:
                logger.debug(f"Deleted text {ref_doc_id} from index")
        except Exception:
            logger.error(f"Error deleting text: {ref_doc_id}")
            raise

    def query(
        self,
        query: VectorStoreQuery,
        custom_query: Optional[
            Callable[[Dict, Union[VectorStoreQuery, None]], Dict]
        ] = None,
        es_filter: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Synchronous wrapper around :meth:`aquery`.

        Args:
            query: Vector store query (embedding, mode, top-k, filters).
            custom_query: Optional. custom query function that takes in the es
                query body and returns a modified query body. This can be used
                to add additional query parameters to the Elasticsearch query.
            es_filter: Optional. Elasticsearch filter to apply to the query.
                If filter is provided in the query, this filter will be ignored.

        Returns:
            VectorStoreQueryResult: Result of the query.

        Raises:
            Exception: If Elasticsearch query fails.
        """
        return asyncio.get_event_loop().run_until_complete(
            self.aquery(query, custom_query, es_filter, **kwargs)
        )

    async def aquery(
        self,
        query: VectorStoreQuery,
        custom_query: Optional[
            Callable[[Dict, Union[VectorStoreQuery, None]], Dict]
        ] = None,
        es_filter: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Asynchronous query index for top k most similar nodes.

        Args:
            query: Vector store query (embedding, mode, top-k, filters).
            custom_query: Optional. custom query function that takes in the es
                query body and returns a modified query body. This can be used
                to add additional query parameters to the AsyncElasticsearch
                query.
            es_filter: Optional. AsyncElasticsearch filter to apply to the
                query. If filter is provided in the query, this filter will be
                ignored.

        Returns:
            VectorStoreQueryResult: Result of the query.

        Raises:
            Exception: If AsyncElasticsearch query fails.
        """
        query_embedding = cast(List[float], query.query_embedding)

        es_query = {}

        # Filters embedded in the query take precedence over es_filter.
        if query.filters is not None and len(query.filters.legacy_filters()) > 0:
            query_filter = [_to_elasticsearch_filter(query.filters)]
        else:
            query_filter = es_filter or []

        if query.mode in (
            VectorStoreQueryMode.DEFAULT,
            VectorStoreQueryMode.HYBRID,
        ):
            es_query["knn"] = {
                "filter": query_filter,
                "field": self.vector_field,
                "query_vector": query_embedding,
                "k": query.similarity_top_k,
                # Over-fetch candidates to improve approximate-kNN recall.
                "num_candidates": query.similarity_top_k * 10,
            }

        if query.mode in (
            VectorStoreQueryMode.TEXT_SEARCH,
            VectorStoreQueryMode.HYBRID,
        ):
            es_query["query"] = {
                "bool": {
                    "must": {"match": {self.text_field: {"query": query.query_str}}},
                    "filter": query_filter,
                }
            }

        if query.mode == VectorStoreQueryMode.HYBRID:
            # Reciprocal rank fusion combines the knn and text rankings.
            es_query["rank"] = {"rrf": {}}

        if custom_query is not None:
            es_query = custom_query(es_query, query)
            logger.debug(f"Calling custom_query, Query body now: {es_query}")

        async with self.client as client:
            response = await client.search(
                index=self.index_name,
                **es_query,
                size=query.similarity_top_k,
                # Embeddings are large and not needed in results.
                _source={"excludes": [self.vector_field]},
            )

        top_k_nodes = []
        top_k_ids = []
        top_k_scores = []
        hits = response["hits"]["hits"]
        for hit in hits:
            source = hit["_source"]
            metadata = source.get("metadata", None)
            text = source.get(self.text_field, None)
            node_id = hit["_id"]

            try:
                node = metadata_dict_to_node(metadata)
                node.text = text
            except Exception:
                # Legacy support for old metadata format
                logger.warning(
                    f"Could not parse metadata from hit {hit['_source']['metadata']}"
                )
                node_info = source.get("node_info")
                relationships = source.get("relationships") or {}
                start_char_idx = None
                end_char_idx = None
                if isinstance(node_info, dict):
                    start_char_idx = node_info.get("start", None)
                    end_char_idx = node_info.get("end", None)

                node = TextNode(
                    text=text,
                    metadata=metadata,
                    id_=node_id,
                    start_char_idx=start_char_idx,
                    end_char_idx=end_char_idx,
                    relationships=relationships,
                )
            top_k_nodes.append(node)
            top_k_ids.append(node_id)
            # Hybrid (rrf) responses carry _rank instead of _score.
            top_k_scores.append(hit.get("_rank", hit["_score"]))

        if query.mode == VectorStoreQueryMode.HYBRID:
            total_rank = sum(top_k_scores)
            # NOTE(review): precedence here is `total_rank - (rank / total_rank)`,
            # not `(total_rank - rank) / total_rank`; kept byte-identical to
            # preserve existing ranking behavior — confirm intent.
            top_k_scores = [total_rank - rank / total_rank for rank in top_k_scores]

        return VectorStoreQueryResult(
            nodes=top_k_nodes,
            ids=top_k_ids,
            similarities=_to_llama_similarities(top_k_scores),
        )