llama-index

Форк
0
598 строк · 20.3 Кб
1
"""Elasticsearch vector store."""
2

3
import asyncio
4
import uuid
5
from logging import getLogger
6
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
7

8
import nest_asyncio
9
import numpy as np
10

11
from llama_index.legacy.bridge.pydantic import PrivateAttr
12
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
13
from llama_index.legacy.vector_stores.types import (
14
    BasePydanticVectorStore,
15
    MetadataFilters,
16
    VectorStoreQuery,
17
    VectorStoreQueryMode,
18
    VectorStoreQueryResult,
19
)
20
from llama_index.legacy.vector_stores.utils import (
21
    metadata_dict_to_node,
22
    node_to_metadata_dict,
23
)
24

25
# Module-level logger, named after this module.
logger = getLogger(__name__)

# Closed set of names accepted for the store's ``distance_strategy`` setting.
DISTANCE_STRATEGIES = Literal[
    "COSINE",
    "DOT_PRODUCT",
    "EUCLIDEAN_DISTANCE",
]
32

33

34
def _get_elasticsearch_client(
    *,
    es_url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> Any:
    """Build an AsyncElasticsearch client after checking the cluster is reachable.

    Exactly one of ``es_url`` / ``cloud_id`` selects the cluster. ``api_key``
    takes precedence over ``username`` + ``password``; if neither is given the
    client is created without credentials.

    Args:
        es_url: Elasticsearch URL.
        cloud_id: Elasticsearch cloud ID.
        api_key: Elasticsearch API key.
        username: Elasticsearch username.
        password: Elasticsearch password.

    Returns:
        AsyncElasticsearch client.

    Raises:
        ValueError: If both or neither of ``es_url`` and ``cloud_id`` are given.
        ImportError: If the ``elasticsearch`` package is not installed.
        Exception: Whatever the connectivity probe (``info()``) raises is
            logged and re-raised unchanged.
    """
    # Validate arguments before touching the optional dependency so that a
    # malformed call fails fast with a clear message.
    if es_url and cloud_id:
        raise ValueError(
            "Both es_url and cloud_id are defined. Please provide only one."
        )
    if not es_url and not cloud_id:
        raise ValueError("Please provide either elasticsearch_url or cloud_id.")

    try:
        import elasticsearch
    except ImportError as err:
        raise ImportError(
            "Could not import elasticsearch python package. "
            "Please install it with `pip install elasticsearch`."
        ) from err

    connection_params: Dict[str, Any] = {}
    if es_url:
        connection_params["hosts"] = [es_url]
    else:
        connection_params["cloud_id"] = cloud_id

    if api_key:
        connection_params["api_key"] = api_key
    elif username and password:
        connection_params["basic_auth"] = (username, password)

    user_agent = ElasticsearchStore.get_user_agent()

    # A short-lived synchronous client is used only to probe the cluster, so
    # this function does not need to be a coroutine.
    sync_es_client = elasticsearch.Elasticsearch(
        **connection_params, headers={"user-agent": user_agent}
    )
    # The returned client carries the same user-agent as the probe; previously
    # only the throwaway probe client was tagged.
    async_es_client = elasticsearch.AsyncElasticsearch(
        **connection_params, headers={"user-agent": user_agent}
    )
    try:
        sync_es_client.info()  # so don't have to 'await' to just get info
    except Exception as e:
        logger.error(f"Error connecting to Elasticsearch: {e}")
        raise
    finally:
        # The probe client is never used again; release its connections.
        sync_es_client.close()

    return async_es_client
95

96

97
def _to_elasticsearch_filter(standard_filters: MetadataFilters) -> Dict[str, Any]:
    """Convert standard filters to an Elasticsearch filter clause.

    Every entry returned by ``standard_filters.legacy_filters()`` becomes an
    exact-match ``term`` clause on ``metadata.<key>.keyword``.

    Args:
        standard_filters: Standard Llama-index filters.

    Returns:
        The bare ``term`` clause when there is exactly one filter; otherwise a
        ``{"bool": {"must": [...]}}`` clause requiring all of them (the list is
        empty when there are no filters).
    """
    term_clauses = [
        {"term": {f"metadata.{entry.key}.keyword": {"value": entry.value}}}
        for entry in standard_filters.legacy_filters()
    ]
    if len(term_clauses) == 1:
        return term_clauses[0]
    return {"bool": {"must": term_clauses}}
128

129

130
def _to_llama_similarities(scores: List[float]) -> List[float]:
    """Rescale raw scores so that the highest one maps to 1.0.

    Each score ``s`` becomes ``exp(s - max(scores))``, so the results lie in
    ``(0, 1]`` and keep the original ordering.

    Args:
        scores: Raw scores; ``None`` or an empty list is accepted.

    Returns:
        The rescaled scores as a plain list, or ``[]`` for ``None``/empty input.
    """
    if scores is None or len(scores) == 0:
        return []

    raw_scores: np.ndarray = np.asarray(scores)
    return np.exp(raw_scores - raw_scores.max()).tolist()
136

137

138
class ElasticsearchStore(BasePydanticVectorStore):
    """Elasticsearch vector store.

    Args:
        index_name: Name of the Elasticsearch index.
        es_client: Optional. Pre-existing AsyncElasticsearch client.
        es_url: Optional. Elasticsearch URL.
        es_cloud_id: Optional. Elasticsearch cloud ID.
        es_api_key: Optional. Elasticsearch API key.
        es_user: Optional. Elasticsearch username.
        es_password: Optional. Elasticsearch password.
        text_field: Optional. Name of the Elasticsearch field that stores the text.
        vector_field: Optional. Name of the Elasticsearch field that stores the
                    embedding.
        batch_size: Optional. Batch size for bulk indexing. Defaults to 200.
        distance_strategy: Optional. Distance strategy to use for similarity search.
                        Defaults to "COSINE".

    Raises:
        ConnectionError: If AsyncElasticsearch client cannot connect to Elasticsearch.
        ValueError: If neither es_client nor es_url nor es_cloud_id is provided.

    """

    stores_text: bool = True
    index_name: str
    es_client: Optional[Any]
    es_url: Optional[str]
    es_cloud_id: Optional[str]
    es_api_key: Optional[str]
    es_user: Optional[str]
    es_password: Optional[str]
    text_field: str = "content"
    vector_field: str = "embedding"
    batch_size: int = 200
    distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE"

    # Client used for every request: either the caller's client re-wrapped
    # with our user-agent, or one built from the connection settings.
    _client = PrivateAttr()

    def __init__(
        self,
        index_name: str,
        es_client: Optional[Any] = None,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_user: Optional[str] = None,
        es_password: Optional[str] = None,
        text_field: str = "content",
        vector_field: str = "embedding",
        batch_size: int = 200,
        distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE",
    ) -> None:
        """Create the store and resolve the client it will use.

        A supplied ``es_client`` wins over ``es_url`` / ``es_cloud_id``.

        Raises:
            ValueError: If none of ``es_client``, ``es_url`` and
                ``es_cloud_id`` is provided.
        """
        # The synchronous wrappers (add/delete/query) drive coroutines with
        # run_until_complete; nest_asyncio is applied so that presumably also
        # works when an event loop is already running (e.g. notebooks).
        nest_asyncio.apply()

        if es_client is not None:
            self._client = es_client.options(
                headers={"user-agent": self.get_user_agent()}
            )
        elif es_url is not None or es_cloud_id is not None:
            self._client = _get_elasticsearch_client(
                es_url=es_url,
                username=es_user,
                password=es_password,
                cloud_id=es_cloud_id,
                api_key=es_api_key,
            )
        else:
            raise ValueError(
                """Either provide a pre-existing AsyncElasticsearch or valid \
                credentials for creating a new connection."""
            )
        super().__init__(
            index_name=index_name,
            es_client=es_client,
            es_url=es_url,
            es_cloud_id=es_cloud_id,
            es_api_key=es_api_key,
            es_user=es_user,
            es_password=es_password,
            text_field=text_field,
            vector_field=vector_field,
            batch_size=batch_size,
            distance_strategy=distance_strategy,
        )

    @property
    def client(self) -> Any:
        """Get async elasticsearch client."""
        return self._client

    @staticmethod
    def get_user_agent() -> str:
        """Get user agent for elasticsearch client."""
        import llama_index.legacy

        return f"llama_index-py-vs/{llama_index.legacy.__version__}"

    async def _create_index_if_not_exists(
        self, index_name: str, dims_length: Optional[int] = None
    ) -> None:
        """Create the AsyncElasticsearch index if it doesn't already exist.

        Args:
            index_name: Name of the AsyncElasticsearch index to create.
            dims_length: Length of the embedding vectors.

        Raises:
            ValueError: If the index is missing and ``dims_length`` is None, or
                if ``distance_strategy`` is not one of the supported names.
        """
        # The async client returns a coroutine here; without ``await`` the
        # (always truthy) coroutine object made this branch unconditional and
        # the index was never created.
        if await self.client.indices.exists(index=index_name):
            logger.debug(f"Index {index_name} already exists. Skipping creation.")
            return

        if dims_length is None:
            raise ValueError(
                "Cannot create index without specifying dims_length "
                "when the index doesn't already exist. We infer "
                "dims_length from the first embedding. Check that "
                "you have provided an embedding function."
            )

        # Map our strategy names onto Elasticsearch dense_vector similarities.
        similarity_by_strategy = {
            "COSINE": "cosine",
            "EUCLIDEAN_DISTANCE": "l2_norm",
            "DOT_PRODUCT": "dot_product",
        }
        similarity_algo = similarity_by_strategy.get(self.distance_strategy)
        if similarity_algo is None:
            raise ValueError(f"Similarity {self.distance_strategy} not supported.")

        index_settings = {
            "mappings": {
                "properties": {
                    self.vector_field: {
                        "type": "dense_vector",
                        "dims": dims_length,
                        "index": True,
                        "similarity": similarity_algo,
                    },
                    self.text_field: {"type": "text"},
                    "metadata": {
                        "properties": {
                            "document_id": {"type": "keyword"},
                            "doc_id": {"type": "keyword"},
                            "ref_doc_id": {"type": "keyword"},
                        }
                    },
                }
            }
        }

        logger.debug(
            f"Creating index {index_name} with mappings {index_settings['mappings']}"
        )
        await self.client.indices.create(index=index_name, **index_settings)

    def add(
        self,
        nodes: List[BaseNode],
        *,
        create_index_if_not_exists: bool = True,
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to Elasticsearch index.

        Synchronous wrapper that runs :meth:`async_add` to completion on the
        current event loop.

        Args:
            nodes: List of nodes with embeddings.
            create_index_if_not_exists: Optional. Whether to create
                                        the Elasticsearch index if it
                                        doesn't already exist.
                                        Defaults to True.

        Returns:
            List of node IDs that were added to the index.

        Raises:
            ImportError: If elasticsearch['async'] python package is not installed.
            BulkIndexError: If AsyncElasticsearch async_bulk indexing fails.
        """
        return asyncio.get_event_loop().run_until_complete(
            self.async_add(
                nodes,
                create_index_if_not_exists=create_index_if_not_exists,
                **add_kwargs,
            )
        )

    async def async_add(
        self,
        nodes: List[BaseNode],
        *,
        create_index_if_not_exists: bool = True,
        **add_kwargs: Any,
    ) -> List[str]:
        """Asynchronous method to add nodes to Elasticsearch index.

        Args:
            nodes: List of nodes with embeddings.
            create_index_if_not_exists: Optional. Whether to create
                                        the AsyncElasticsearch index if it
                                        doesn't already exist.
                                        Defaults to True.
            add_kwargs: Accepted for interface compatibility; not used.

        Returns:
            List of node IDs that were added to the index.

        Raises:
            ImportError: If elasticsearch python package is not installed.
            BulkIndexError: If AsyncElasticsearch async_bulk indexing fails.
        """
        try:
            from elasticsearch.helpers import BulkIndexError, async_bulk
        except ImportError as err:
            raise ImportError(
                "Could not import elasticsearch[async] python package. "
                "Please install it with `pip install 'elasticsearch[async]'`."
            ) from err

        if len(nodes) == 0:
            return []

        if create_index_if_not_exists:
            # The vector dimension of the index is inferred from the first node.
            dims_length = len(nodes[0].get_embedding())
            await self._create_index_if_not_exists(
                index_name=self.index_name, dims_length=dims_length
            )

        requests = []
        return_ids = []
        for node in nodes:
            node_id = node.node_id
            requests.append(
                {
                    "_op_type": "index",
                    "_index": self.index_name,
                    self.vector_field: node.get_embedding(),
                    self.text_field: node.get_content(
                        metadata_mode=MetadataMode.NONE
                    ),
                    "metadata": node_to_metadata_dict(node, remove_text=True),
                    "_id": node_id,
                }
            )
            return_ids.append(node_id)

        # One bulk call only: the previous code sent the same actions twice
        # (once outside the error handler, once ignoring batch_size).
        try:
            success, failed = await async_bulk(
                self.client,
                requests,
                chunk_size=self.batch_size,
                stats_only=True,
                refresh=True,
            )
        except BulkIndexError as e:
            logger.error(f"Error adding texts: {e}")
            if e.errors:
                first_error = e.errors[0].get("index", {}).get("error", {})
                logger.error(f"First error reason: {first_error.get('reason')}")
            raise

        logger.debug(f"Added {success} and failed to add {failed} texts to index")
        logger.debug(f"added texts {return_ids} to index")
        return return_ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete nodes from Elasticsearch index.

        Synchronous wrapper that runs :meth:`adelete` to completion on the
        current event loop.

        Args:
            ref_doc_id: Documents whose ``metadata.ref_doc_id`` equals this
                        value are deleted.
            delete_kwargs: Optional. Additional arguments to
                        pass to Elasticsearch delete_by_query.

        Raises:
            Exception: If Elasticsearch delete_by_query fails.
        """
        return asyncio.get_event_loop().run_until_complete(
            self.adelete(ref_doc_id, **delete_kwargs)
        )

    async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Async delete nodes from Elasticsearch index.

        Args:
            ref_doc_id: Documents whose ``metadata.ref_doc_id`` equals this
                        value are deleted.
            delete_kwargs: Optional. Additional arguments to
                        pass to AsyncElasticsearch delete_by_query.

        Raises:
            Exception: If AsyncElasticsearch delete_by_query fails.
        """
        try:
            # NOTE(review): leaving this ``async with`` presumably closes the
            # client's transport after the call -- confirm this is intended.
            async with self.client as client:
                res = await client.delete_by_query(
                    index=self.index_name,
                    query={"term": {"metadata.ref_doc_id": ref_doc_id}},
                    refresh=True,
                    **delete_kwargs,
                )
            if res["deleted"] == 0:
                logger.warning(f"Could not find text {ref_doc_id} to delete")
            else:
                logger.debug(f"Deleted text {ref_doc_id} from index")
        except Exception:
            logger.error(f"Error deleting text: {ref_doc_id}")
            raise

    def query(
        self,
        query: VectorStoreQuery,
        custom_query: Optional[
            Callable[[Dict, Union[VectorStoreQuery, None]], Dict]
        ] = None,
        es_filter: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Synchronous wrapper that runs :meth:`aquery` to completion on the
        current event loop.

        Args:
            query: Query to run (embedding, text, mode, filters, top k).
            custom_query: Optional. custom query function that takes in the es query
                        body and returns a modified query body.
                        This can be used to add additional query
                        parameters to the Elasticsearch query.
            es_filter: Optional. Elasticsearch filter to apply to the
                        query. If filter is provided in the query,
                        this filter will be ignored.

        Returns:
            VectorStoreQueryResult: Result of the query.

        Raises:
            Exception: If Elasticsearch query fails.

        """
        return asyncio.get_event_loop().run_until_complete(
            self.aquery(query, custom_query, es_filter, **kwargs)
        )

    async def aquery(
        self,
        query: VectorStoreQuery,
        custom_query: Optional[
            Callable[[Dict, Union[VectorStoreQuery, None]], Dict]
        ] = None,
        es_filter: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Asynchronous query index for top k most similar nodes.

        DEFAULT mode issues a kNN search, TEXT_SEARCH a ``match`` query on the
        text field, and HYBRID both, combined with reciprocal rank fusion.

        Args:
            query: Query to run (embedding, text, mode, filters, top k).
            custom_query: Optional. custom query function that takes in the es query
                        body and returns a modified query body.
                        This can be used to add additional query
                        parameters to the AsyncElasticsearch query.
            es_filter: Optional. AsyncElasticsearch filter to apply to the
                        query. If filter is provided in the query,
                        this filter will be ignored.
            kwargs: Accepted for interface compatibility; not used.

        Returns:
            VectorStoreQueryResult: Result of the query.

        Raises:
            Exception: If AsyncElasticsearch query fails.

        """
        query_embedding = cast(List[float], query.query_embedding)

        es_query = {}

        # Filters carried by the query take precedence over ``es_filter``.
        if query.filters is not None and len(query.filters.legacy_filters()) > 0:
            filter_clauses = [_to_elasticsearch_filter(query.filters)]
        else:
            filter_clauses = es_filter or []

        if query.mode in (
            VectorStoreQueryMode.DEFAULT,
            VectorStoreQueryMode.HYBRID,
        ):
            es_query["knn"] = {
                "filter": filter_clauses,
                "field": self.vector_field,
                "query_vector": query_embedding,
                "k": query.similarity_top_k,
                "num_candidates": query.similarity_top_k * 10,
            }

        if query.mode in (
            VectorStoreQueryMode.TEXT_SEARCH,
            VectorStoreQueryMode.HYBRID,
        ):
            es_query["query"] = {
                "bool": {
                    "must": {"match": {self.text_field: {"query": query.query_str}}},
                    "filter": filter_clauses,
                }
            }

        if query.mode == VectorStoreQueryMode.HYBRID:
            es_query["rank"] = {"rrf": {}}

        if custom_query is not None:
            es_query = custom_query(es_query, query)
            logger.debug(f"Calling custom_query, Query body now: {es_query}")

        # NOTE(review): leaving this ``async with`` presumably closes the
        # client's transport after the call -- confirm this is intended.
        async with self.client as client:
            response = await client.search(
                index=self.index_name,
                **es_query,
                size=query.similarity_top_k,
                _source={"excludes": [self.vector_field]},
            )

        top_k_nodes = []
        top_k_ids = []
        top_k_scores = []
        hits = response["hits"]["hits"]
        for hit in hits:
            source = hit["_source"]
            metadata = source.get("metadata", None)
            text = source.get(self.text_field, None)
            node_id = hit["_id"]

            try:
                node = metadata_dict_to_node(metadata)
                node.text = text
            except Exception:
                # Legacy support for old metadata format
                # Log the value already fetched with .get(): indexing
                # hit["_source"]["metadata"] raised KeyError exactly when the
                # metadata was missing.
                logger.warning(f"Could not parse metadata from hit {metadata}")
                node_info = source.get("node_info")
                relationships = source.get("relationships") or {}
                start_char_idx = None
                end_char_idx = None
                if isinstance(node_info, dict):
                    start_char_idx = node_info.get("start", None)
                    end_char_idx = node_info.get("end", None)

                node = TextNode(
                    text=text,
                    metadata=metadata or {},
                    id_=node_id,
                    start_char_idx=start_char_idx,
                    end_char_idx=end_char_idx,
                    relationships=relationships,
                )
            top_k_nodes.append(node)
            top_k_ids.append(node_id)
            # Ranked (RRF) responses carry ``_rank``; otherwise use ``_score``.
            top_k_scores.append(hit.get("_rank", hit["_score"]))

        if query.mode == VectorStoreQueryMode.HYBRID:
            # Invert ranks so that a smaller rank yields a larger score. The
            # constant offset cancels in _to_llama_similarities, which only
            # depends on differences to the maximum.
            total_rank = sum(top_k_scores)
            top_k_scores = [
                total_rank - (rank / total_rank) for rank in top_k_scores
            ]

        return VectorStoreQueryResult(
            nodes=top_k_nodes,
            ids=top_k_ids,
            similarities=_to_llama_similarities(top_k_scores),
        )
599

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.