llama-index

Форк
0
492 строки · 15.0 Кб
1
"""Elasticsearch/Opensearch vector store."""
2

3
import json
4
import uuid
5
from typing import Any, Dict, Iterable, List, Optional, Union, cast
6

7
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
8
from llama_index.legacy.vector_stores.types import (
9
    MetadataFilters,
10
    VectorStore,
11
    VectorStoreQuery,
12
    VectorStoreQueryMode,
13
    VectorStoreQueryResult,
14
)
15
from llama_index.legacy.vector_stores.utils import (
16
    metadata_dict_to_node,
17
    node_to_metadata_dict,
18
)
19

20
# Message for the ValueError raised by the lazy opensearch-py import helpers.
IMPORT_OPENSEARCH_PY_ERROR = (
    "Could not import OpenSearch. "
    "Please install it with `pip install opensearch-py`."
)
# Message for the ValueError raised when a hybrid query is missing its query
# string or the client has no search pipeline configured.
INVALID_HYBRID_QUERY_ERROR = (
    "Please specify the lexical_query "
    "and search_pipeline for hybrid search."
)
# Query clause that matches every document; default filter for script_score.
MATCH_ALL_QUERY: Dict = {"match_all": {}}
27

28

29
def _import_opensearch() -> Any:
    """Return the ``OpenSearch`` client class from opensearch-py.

    The import happens inside the function, so opensearch-py is only
    required when this helper is actually called.

    Raises:
        ValueError: If opensearch-py cannot be imported. The original
            ``ImportError`` is chained as ``__cause__``.
    """
    try:
        from opensearchpy import OpenSearch
    except ImportError as err:
        raise ValueError(IMPORT_OPENSEARCH_PY_ERROR) from err
    return OpenSearch
36

37

38
def _import_bulk() -> Any:
    """Return ``opensearchpy.helpers.bulk``.

    The import happens inside the function, so opensearch-py is only
    required when this helper is actually called.

    Raises:
        ValueError: If opensearch-py cannot be imported. The original
            ``ImportError`` is chained as ``__cause__``.
    """
    try:
        from opensearchpy.helpers import bulk
    except ImportError as err:
        raise ValueError(IMPORT_OPENSEARCH_PY_ERROR) from err
    return bulk
45

46

47
def _import_not_found_error() -> Any:
    """Return the ``opensearchpy.exceptions.NotFoundError`` exception class.

    The import happens inside the function, so opensearch-py is only
    required when this helper is actually called.

    Raises:
        ValueError: If opensearch-py cannot be imported. The original
            ``ImportError`` is chained as ``__cause__``.
    """
    try:
        from opensearchpy.exceptions import NotFoundError
    except ImportError as err:
        raise ValueError(IMPORT_OPENSEARCH_PY_ERROR) from err
    return NotFoundError
54

55

56
def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
    """Construct an ``OpenSearch`` client for ``opensearch_url``.

    Args:
        opensearch_url: Passed as the first positional argument to the
            ``OpenSearch`` constructor.
        **kwargs: Forwarded unchanged to the ``OpenSearch`` constructor.

    Returns:
        The constructed client instance.

    Raises:
        ValueError: If opensearch-py is missing or the constructor raises
            ``ValueError``. The underlying error is chained as ``__cause__``.
    """
    try:
        opensearch = _import_opensearch()
        client = opensearch(opensearch_url, **kwargs)
    except ValueError as e:
        # NOTE(review): this message is also used when the package is missing,
        # not only for malformed URLs; the chained cause tells them apart.
        raise ValueError(
            "OpenSearch client string provided is not in proper format. "
            f"Got error: {e} "
        ) from e
    return client
68

69

70
def _bulk_ingest_embeddings(
    client: Any,
    index_name: str,
    embeddings: List[List[float]],
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    vector_field: str = "embedding",
    text_field: str = "content",
    mapping: Optional[Dict] = None,
    max_chunk_bytes: Optional[int] = 1 * 1024 * 1024,
    is_aoss: bool = False,
) -> List[str]:
    """Bulk-index texts and their embeddings into ``index_name``.

    If the index does not exist yet it is created with ``mapping`` as its
    body (``{}`` when no mapping is given).

    Args:
        client: opensearch-py client.
        index_name: Target index.
        embeddings: Embeddings matched to ``texts`` by position.
        texts: Texts to index.
        metadatas: Optional metadata dicts matched to ``texts`` by position;
            ``{}`` is stored when omitted.
        ids: Optional document ids matched to ``texts`` by position; random
            UUID4 strings are generated when omitted.
        vector_field: Source field that receives the embedding.
        text_field: Source field that receives the text.
        mapping: Index-creation body, used only when the index is missing.
        max_chunk_bytes: Forwarded to ``opensearchpy.helpers.bulk``.
        is_aoss: When True, the id is sent under ``"id"`` instead of
            ``"_id"`` and the index refresh is skipped.

    Returns:
        The document ids, in the same order as ``texts``.
    """
    bulk = _import_bulk()
    not_found_error = _import_not_found_error()

    try:
        client.indices.get(index=index_name)
    except not_found_error:
        client.indices.create(index=index_name, body=mapping or {})

    # AOSS requests carry the id under "id"; other clusters use "_id".
    id_key = "id" if is_aoss else "_id"
    requests = []
    return_ids = []
    for i, text in enumerate(texts):
        metadata = metadatas[i] if metadatas else {}
        _id = ids[i] if ids else str(uuid.uuid4())
        requests.append(
            {
                "_op_type": "index",
                "_index": index_name,
                vector_field: embeddings[i],
                text_field: text,
                "metadata": metadata,
                id_key: _id,
            }
        )
        return_ids.append(_id)

    bulk(client, requests, max_chunk_bytes=max_chunk_bytes)
    # Refresh is skipped for AOSS.
    # NOTE(review): presumably because AOSS does not expose the refresh API.
    if not is_aoss:
        client.indices.refresh(index=index_name)
    return return_ids
118

119

120
def _default_approximate_search_query(
    query_vector: List[float],
    k: int = 4,
    vector_field: str = "embedding",
) -> Dict:
    """Build the default approximate k-NN search body.

    Args:
        query_vector: Query embedding.
        k: Number of neighbours requested; also used as the result ``size``.
        vector_field: Name of the vector field to search.

    Returns:
        ``{"size": k, "query": {"knn": {vector_field: {"vector": ..., "k": k}}}}``
    """
    knn_clause = {"vector": query_vector, "k": k}
    body: Dict = {"size": k}
    body["query"] = {"knn": {vector_field: knn_clause}}
    return body
130

131

132
def _parse_filters(filters: Optional[MetadataFilters]) -> Any:
    """Convert ``filters`` into a list of OpenSearch filter clauses.

    Each legacy filter becomes ``{key: json.loads(str(value))}``. The filter
    key is used verbatim as the clause name, and the value must be (or
    stringify to) valid JSON; otherwise ``json.loads`` raises.

    Returns:
        A list of clauses; ``[]`` when ``filters`` is None.
    """
    if filters is None:
        return []
    return [
        {flt.key: json.loads(str(flt.value))} for flt in filters.legacy_filters()
    ]
139

140

141
def _knn_search_query(
    embedding_field: str,
    query_embedding: List[float],
    k: int,
    filters: Optional[MetadataFilters] = None,
) -> Dict:
    """Build a k-NN search body.

    When ``filters`` is None an approximate k-NN query is used. Otherwise
    (even for a MetadataFilters with no entries) an exhaustive, exact
    ``script_score`` query using Painless scripting is built, scored with
    ``l2Squared``. Approximate k-NN search does not support pre-filtering.

    Args:
        embedding_field: Name of the vector field to search.
        query_embedding: Vector embedding to query.
        k: Maximum number of results.
        filters: Optional filters to apply before the search.
            Supports filter-context queries documented at
            https://opensearch.org/docs/latest/query-dsl/query-filter-context/

    Returns:
        A search body returning up to k docs closest to ``query_embedding``.
    """
    if filters is not None:
        # https://opensearch.org/docs/latest/search-plugins/knn/painless-functions/
        return _default_painless_scripting_query(
            query_embedding,
            k,
            space_type="l2Squared",
            pre_filter={"bool": {"filter": _parse_filters(filters)}},
            vector_field=embedding_field,
        )
    return _default_approximate_search_query(
        query_embedding, k, vector_field=embedding_field
    )
181

182

183
def _hybrid_search_query(
    text_field: str,
    query_str: str,
    embedding_field: str,
    query_embedding: List[float],
    k: int,
    filters: Optional[MetadataFilters] = None,
) -> Dict:
    """Build a ``hybrid`` search body from a lexical and a vector sub-query.

    The lexical sub-query is a ``bool`` query with a ``match`` on
    ``text_field``. The vector sub-query is the ``query`` part of the k-NN
    body produced for the same filters. Parsed filters, if any, are also
    attached as the ``filter`` of the lexical ``bool`` query.

    Args:
        text_field: Field matched against ``query_str``.
        query_str: Text for the lexical sub-query.
        embedding_field: Name of the vector field to search.
        query_embedding: Query vector.
        k: Result ``size`` and k-NN ``k``.
        filters: Optional metadata filters.

    Returns:
        ``{"size": k, "query": {"hybrid": {"queries": [lexical, vector]}}}``
    """
    vector_query = _knn_search_query(embedding_field, query_embedding, k, filters)[
        "query"
    ]
    bool_query = {"must": {"match": {text_field: {"query": query_str}}}}

    filter_clauses = _parse_filters(filters)
    if filter_clauses:
        bool_query["filter"] = filter_clauses

    sub_queries = [{"bool": bool_query}, vector_query]
    return {"size": k, "query": {"hybrid": {"queries": sub_queries}}}
201

202

203
def __get_painless_scripting_source(
    space_type: str, vector_field: str = "embedding"
) -> str:
    """Return the Painless script source used to score documents.

    The base expression is
    ``(1.0 + <space_type>(params.query_value, doc['<vector_field>']))``.
    For ``"cosineSimilarity"`` it is returned as-is. For any other space
    type its reciprocal ``1/(...)`` is returned, so a smaller function value
    yields a higher score.
    """
    call = f"{space_type}(params.query_value, doc['{vector_field}'])"
    shifted = f"(1.0 + {call})"
    return shifted if space_type == "cosineSimilarity" else f"1/{shifted}"
212

213

214
def _default_painless_scripting_query(
    query_vector: List[float],
    k: int = 4,
    space_type: str = "l2Squared",
    pre_filter: Optional[Union[Dict, List]] = None,
    vector_field: str = "embedding",
) -> Dict:
    """Build an exact k-NN search body scored by a Painless script.

    Args:
        query_vector: Query embedding, exposed to the script as
            ``params.query_value``.
        k: Number of results (``size``).
        space_type: Painless vector function used by the scoring script
            (e.g. ``"l2Squared"`` or ``"cosineSimilarity"``).
        pre_filter: Query selecting the documents to score. When empty or
            omitted, a ``match_all`` query is used.
        vector_field: Name of the vector field read by the script.

    Returns:
        A ``script_score`` search body.
    """
    import copy

    if not pre_filter:
        # Deep-copy so a caller mutating the returned body cannot corrupt the
        # shared module-level MATCH_ALL_QUERY constant.
        pre_filter = copy.deepcopy(MATCH_ALL_QUERY)

    source = __get_painless_scripting_source(space_type, vector_field)
    return {
        "size": k,
        "query": {
            "script_score": {
                "query": pre_filter,
                "script": {
                    "source": source,
                    "params": {
                        "field": vector_field,
                        "query_value": query_vector,
                    },
                },
            }
        },
    }
241

242

243
def _is_aoss_enabled(http_auth: Any) -> bool:
    """Tell whether ``http_auth`` is configured for the ``"aoss"`` service.

    Returns True only when ``http_auth`` has a ``service`` attribute equal to
    ``"aoss"`` (presumably Amazon OpenSearch Serverless). ``None`` or objects
    without ``service`` give False.
    """
    service = getattr(http_auth, "service", None)
    return bool(service == "aoss")
252

253

254
class OpensearchVectorClient:
    """Object encapsulating an Opensearch index that has vector search enabled.

    If the index does not yet exist, it is created during init.
    Therefore, the underlying index is assumed to either:
    1) not exist yet or 2) be created due to previous usage of this class.

    Args:
        endpoint (str): URL (http/https) of the OpenSearch endpoint.
        index (str): Name of the OpenSearch index.
        dim (int): Dimension of the vector.
        embedding_field (str): Name of the field in the index to store
            embedding array in.
        text_field (str): Name of the field to grab text from.
        method (Optional[dict]): Opensearch "method" JSON obj for configuring
            the KNN index.
            This includes engine, metric, and other config params. Defaults to:
            {"name": "hnsw", "space_type": "l2", "engine": "nmslib",
            "parameters": {"ef_construction": 256, "m": 48}}
        max_chunk_bytes (int): Forwarded as ``max_chunk_bytes`` to
            ``opensearchpy.helpers.bulk`` when indexing. Defaults to 1 MiB.
        search_pipeline (Optional[str]): Search pipeline used for hybrid
            queries; required for ``VectorStoreQueryMode.HYBRID``.
        **kwargs: Optional arguments passed to the OpenSearch client from
            opensearch-py. If ``http_auth`` has ``service == "aoss"``, index
            refreshes are skipped.
    """

    def __init__(
        self,
        endpoint: str,
        index: str,
        dim: int,
        embedding_field: str = "embedding",
        text_field: str = "content",
        method: Optional[dict] = None,
        max_chunk_bytes: int = 1 * 1024 * 1024,
        search_pipeline: Optional[str] = None,
        **kwargs: Any,
    ):
        """Store settings, build the client and create the index if missing."""
        if method is None:
            method = {
                "name": "hnsw",
                "space_type": "l2",
                "engine": "nmslib",
                "parameters": {"ef_construction": 256, "m": 48},
            }
        if embedding_field is None:
            embedding_field = "embedding"
        self._embedding_field = embedding_field

        self._endpoint = endpoint
        self._dim = dim
        self._index = index
        self._text_field = text_field
        self._max_chunk_bytes = max_chunk_bytes

        self._search_pipeline = search_pipeline
        http_auth = kwargs.get("http_auth")
        self.is_aoss = _is_aoss_enabled(http_auth=http_auth)
        # initialize mapping
        idx_conf = {
            "settings": {"index": {"knn": True, "knn.algo_param.ef_search": 100}},
            "mappings": {
                "properties": {
                    embedding_field: {
                        "type": "knn_vector",
                        "dimension": dim,
                        "method": method,
                    },
                }
            },
        }
        self._os_client = _get_opensearch_client(self._endpoint, **kwargs)
        not_found_error = _import_not_found_error()
        try:
            self._os_client.indices.get(index=self._index)
        except not_found_error:
            self._os_client.indices.create(index=self._index, body=idx_conf)
            # Skip refresh for AOSS, matching _bulk_ingest_embeddings, which
            # never refreshes when is_aoss is set.
            if not self.is_aoss:
                self._os_client.indices.refresh(index=self._index)

    def index_results(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]:
        """Store nodes' embeddings, text and metadata in the index.

        Each node is indexed under its ``node_id``. Its text is taken with
        ``MetadataMode.NONE``, and its metadata is serialized with
        ``node_to_metadata_dict(..., remove_text=True)``. Extra ``kwargs``
        are ignored.

        Returns:
            The ids returned by the bulk ingestion (the node ids, in order).
        """
        embeddings: List[List[float]] = []
        texts: List[str] = []
        metadatas: List[dict] = []
        ids: List[str] = []
        for node in nodes:
            ids.append(node.node_id)
            embeddings.append(node.get_embedding())
            texts.append(node.get_content(metadata_mode=MetadataMode.NONE))
            metadatas.append(node_to_metadata_dict(node, remove_text=True))

        return _bulk_ingest_embeddings(
            self._os_client,
            self._index,
            embeddings,
            texts,
            metadatas=metadatas,
            ids=ids,
            vector_field=self._embedding_field,
            text_field=self._text_field,
            mapping=None,
            max_chunk_bytes=self._max_chunk_bytes,
            is_aoss=self.is_aoss,
        )

    def delete_doc_id(self, doc_id: str) -> None:
        """Delete the single index document whose ``_id`` is ``doc_id``.

        Args:
            doc_id (str): document id
        """
        self._os_client.delete(index=self._index, id=doc_id)

    def query(
        self,
        query_mode: VectorStoreQueryMode,
        query_str: Optional[str],
        query_embedding: List[float],
        k: int,
        filters: Optional[MetadataFilters] = None,
    ) -> VectorStoreQueryResult:
        """Run a hybrid or k-NN search against the index.

        Args:
            query_mode: ``VectorStoreQueryMode.HYBRID`` runs a lexical + vector
                hybrid query with the configured ``search_pipeline``. Any
                other mode runs a k-NN query.
            query_str: Text for the lexical part of a hybrid query.
            query_embedding: Query vector.
            k: Maximum number of results.
            filters: Optional metadata filters.

        Returns:
            A VectorStoreQueryResult with the nodes, their ids and the raw
            ``_score`` of each hit.

        Raises:
            ValueError: In hybrid mode when ``query_str`` is None or no
                ``search_pipeline`` was configured.
        """
        if query_mode == VectorStoreQueryMode.HYBRID:
            if query_str is None or self._search_pipeline is None:
                raise ValueError(INVALID_HYBRID_QUERY_ERROR)
            search_query = _hybrid_search_query(
                self._text_field,
                query_str,
                self._embedding_field,
                query_embedding,
                k,
                filters=filters,
            )
            params = {"search_pipeline": self._search_pipeline}
        else:
            search_query = _knn_search_query(
                self._embedding_field, query_embedding, k, filters=filters
            )
            params = None

        res = self._os_client.search(
            index=self._index, body=search_query, params=params
        )
        nodes = []
        ids = []
        scores = []
        for hit in res["hits"]["hits"]:
            source = hit["_source"]
            node_id = hit["_id"]
            text = source[self._text_field]
            metadata = source.get("metadata", None)

            try:
                # Rebuild the node from the serialized metadata written by
                # index_results; the text was stored separately.
                node = metadata_dict_to_node(metadata)
                node.text = text
            except Exception:
                # TODO: Legacy support for old nodes
                node_info = source.get("node_info")
                relationships = source.get("relationships") or {}
                start_char_idx = None
                end_char_idx = None
                if isinstance(node_info, dict):
                    start_char_idx = node_info.get("start", None)
                    end_char_idx = node_info.get("end", None)

                node = TextNode(
                    text=text,
                    metadata=metadata,
                    id_=node_id,
                    start_char_idx=start_char_idx,
                    end_char_idx=end_char_idx,
                    relationships=relationships,
                    extra_info=source,
                )
            ids.append(node_id)
            nodes.append(node)
            scores.append(hit["_score"])
        return VectorStoreQueryResult(nodes=nodes, ids=ids, similarities=scores)
429

430

431
class OpensearchVectorStore(VectorStore):
    """Elasticsearch/Opensearch vector store.

    Forwards ``add``, ``delete`` and ``query`` to the wrapped
    ``OpensearchVectorClient``.

    Args:
        client (OpensearchVectorClient): Vector index client to use
            for data insertion/querying.
    """

    stores_text: bool = True

    def __init__(self, client: OpensearchVectorClient) -> None:
        """Keep a reference to the index client."""
        self._client = client

    @property
    def client(self) -> Any:
        """The wrapped OpensearchVectorClient."""
        return self._client

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Index ``nodes`` (which must carry embeddings).

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings.
                ``add_kwargs`` is ignored.

        Returns:
            The ``node_id`` of each node, in input order.
        """
        self._client.index_results(nodes)
        return [node.node_id for node in nodes]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete nodes by ``ref_doc_id``.

        NOTE(review): this forwards to ``delete_doc_id`` on the client. For
        ``OpensearchVectorClient`` that deletes the single document whose
        ``_id`` equals ``ref_doc_id``. Nodes are indexed under their own
        ``node_id``, so nodes whose id differs from ``ref_doc_id`` are
        presumably not removed. Confirm the intended semantics.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
        """
        self._client.delete_doc_id(ref_doc_id)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Return the top ``query.similarity_top_k`` matches for ``query``.

        Mode, query string, embedding and filters are taken from ``query``.
        Extra ``kwargs`` are ignored.

        Args:
            query (VectorStoreQuery): Store query object.
        """
        embedding = cast(List[float], query.query_embedding)
        mode, text, top_k = query.mode, query.query_str, query.similarity_top_k
        return self._client.query(
            mode,
            text,
            embedding,
            top_k,
            filters=query.filters,
        )
493

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.