llama-index

Форк
0
142 строки · 4.4 Кб
1
"""Milvus reader."""
2

3
from typing import Any, Dict, List, Optional
4
from uuid import uuid4
5

6
from llama_index.legacy.readers.base import BaseReader
7
from llama_index.legacy.schema import Document
8

9

10
class MilvusReader(BaseReader):
    """Reader that runs a vector search on a Milvus collection.

    A connection alias is resolved when the reader is constructed (an already
    open connection to the same ``host:port`` is reused when available), and
    :meth:`load_data` turns the search hits into ``Document`` objects.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 19530,
        user: str = "",
        password: str = "",
        use_secure: bool = False,
    ):
        """Store the connection settings and open (or reuse) a connection.

        Args:
            host (str): Milvus server host.
            port (int): Milvus server port.
            user (str): User name passed to ``connections.connect``.
            password (str): Password passed to ``connections.connect``.
            use_secure (bool): Forwarded as ``secure`` to
                ``connections.connect``.

        Raises:
            ImportError: If ``pymilvus`` is not installed.
        """
        import_err_msg = (
            "`pymilvus` package not found, please run `pip install pymilvus`"
        )
        try:
            import pymilvus  # noqa
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(import_err_msg) from err

        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.use_secure = use_secure
        self.collection = None

        # Per-index-type search parameter templates. The metric type stored
        # here is a placeholder: `_create_search_params` replaces it with the
        # metric type of the collection's index in the copy it returns.
        self.default_search_params = {
            "IVF_FLAT": {"metric_type": "IP", "params": {"nprobe": 10}},
            "IVF_SQ8": {"metric_type": "IP", "params": {"nprobe": 10}},
            "IVF_PQ": {"metric_type": "IP", "params": {"nprobe": 10}},
            "HNSW": {"metric_type": "IP", "params": {"ef": 10}},
            "RHNSW_FLAT": {"metric_type": "IP", "params": {"ef": 10}},
            "RHNSW_SQ": {"metric_type": "IP", "params": {"ef": 10}},
            "RHNSW_PQ": {"metric_type": "IP", "params": {"ef": 10}},
            "IVF_HNSW": {"metric_type": "IP", "params": {"nprobe": 10, "ef": 10}},
            "ANNOY": {"metric_type": "IP", "params": {"search_k": 10}},
            "AUTOINDEX": {"metric_type": "IP", "params": {}},
        }

        # Any exception raised while connecting propagates to the caller.
        self._create_connection_alias()

    def load_data(
        self,
        query_vector: List[float],
        collection_name: str,
        expr: Any = None,
        search_params: Optional[dict] = None,
        limit: int = 10,
    ) -> List[Document]:
        """Load data from Milvus.

        Opens ``collection_name``, loads it, searches the ``"embedding"``
        field with ``query_vector`` and builds one ``Document`` per hit from
        the ``"doc_id"`` and ``"text"`` output fields.

        Args:
            query_vector (List[float]): Query vector.
            collection_name (str): Name of the Milvus collection.
            expr (Any): Forwarded unchanged as ``expr`` to
                ``Collection.search``. Defaults to None.
            search_params (Optional[dict]): Forwarded as ``param`` to
                ``Collection.search``. When None, parameters are derived from
                the collection's first index.
            limit (int): Number of results to return.

        Returns:
            List[Document]: A list of documents.
        """
        from pymilvus import Collection

        # Exceptions from opening, loading or searching the collection are
        # deliberately left to propagate to the caller.
        self.collection = Collection(collection_name, using=self.alias)
        assert self.collection is not None  # narrows the Optional attribute
        self.collection.load()

        if search_params is None:
            search_params = self._create_search_params()

        res = self.collection.search(
            [query_vector],
            "embedding",
            param=search_params,
            expr=expr,
            output_fields=["doc_id", "text"],
            limit=limit,
        )

        # Only one query vector is sent, so only the first result set is read.
        # TODO: In future append embedding when more efficient
        return [
            Document(
                id_=hit.entity.get("doc_id"),
                text=hit.entity.get("text"),
            )
            for hit in res[0]
        ]

    def _create_connection_alias(self) -> None:
        """Set ``self.alias`` to a connection alias for ``host:port``.

        An existing connection whose address equals ``host:port`` is reused;
        otherwise a new connection is opened under a random alias.
        """
        from pymilvus import connections

        self.alias = None
        target_address = f"{self.host}:{self.port}"

        # Attempt to reuse an open connection.
        # NOTE: the match is on address only; the user/password of the
        # existing connection are not compared with this reader's settings.
        for entry in connections.list_connections():
            alias, handler = entry[0], entry[1]
            addr = connections.get_connection_addr(alias)
            if handler and ("address" in addr) and (addr["address"] == target_address):
                self.alias = alias
                break

        # Nothing to reuse: connect with the settings given to the constructor.
        if self.alias is None:
            self.alias = uuid4().hex
            connections.connect(
                alias=self.alias,
                host=self.host,
                port=self.port,
                user=self.user,  # type: ignore
                password=self.password,  # type: ignore
                secure=self.use_secure,
            )

    def _create_search_params(self) -> Dict[str, Any]:
        """Build search parameters from the loaded collection's first index.

        The template for the index type is taken from
        ``self.default_search_params`` and its ``metric_type`` is replaced by
        the metric type of the index. A fresh dict is returned so the shared
        templates are never modified.

        Returns:
            Dict[str, Any]: Search parameters with ``metric_type`` and
            ``params`` keys.

        Raises:
            KeyError: If the index type has no entry in
                ``self.default_search_params``.
        """
        assert self.collection is not None
        index = self.collection.indexes[0]._index_params
        defaults = self.default_search_params[index["index_type"]]
        # Copy both levels: writing into `defaults` directly would change the
        # template for every later call and expose it to caller mutation.
        return {
            **defaults,
            "params": dict(defaults["params"]),
            "metric_type": index["metric_type"],
        }
143

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.