llama-index

Форк
0
189 строк · 6.7 Кб
1
"""Qdrant reader."""
2

3
from typing import Dict, List, Optional, cast
4

5
from llama_index.legacy.readers.base import BaseReader
6
from llama_index.legacy.schema import Document
7

8

9
class QdrantReader(BaseReader):
    """Qdrant reader.

    Retrieve documents from existing Qdrant collections.

    All constructor arguments are forwarded unchanged to
    `qdrant_client.QdrantClient`.

    Args:
        location:
            If `:memory:` - use in-memory Qdrant instance.
            If `str` - use it as a `url` parameter.
            If `None` - use default values for `host` and `port`.
        url:
            either host or str of
            "Optional[scheme], host, Optional[port], Optional[prefix]".
            Default: `None`
        port: Port of the REST API interface. Default: 6333
        grpc_port: Port of the gRPC interface. Default: 6334
        prefer_grpc: If `True` - use gRPC interface whenever possible in custom
            methods. Default: `False`
        https: If `True` - use HTTPS(SSL) protocol. Default: `None`
        api_key: API key for authentication in Qdrant Cloud. Default: `None`
        prefix:
            If not `None` - add `prefix` to the REST URL path.
            Example: `service/v1` will result in
            `http://localhost:6333/service/v1/{qdrant-endpoint}` for REST API.
            Default: `None`
        timeout:
            Timeout for REST and gRPC API requests.
            Default: 5.0 seconds for REST and unlimited for gRPC
        host: Host name of Qdrant service. If url and host are None, set to 'localhost'.
            Default: `None`
        path: Forwarded to `QdrantClient` as `path` (presumably a local on-disk
            storage path — confirm against the qdrant-client docs).
            Default: `None`
    """

    def __init__(
        self,
        location: Optional[str] = None,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        path: Optional[str] = None,
    ):
        """Initialize with parameters.

        Raises:
            ImportError: If the `qdrant-client` package is not installed.
        """
        import_err_msg = (
            "`qdrant-client` package not found, please run `pip install qdrant-client`"
        )
        try:
            import qdrant_client
        except ImportError as e:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(import_err_msg) from e

        self._client = qdrant_client.QdrantClient(
            location=location,
            url=url,
            port=port,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            https=https,
            api_key=api_key,
            prefix=prefix,
            timeout=timeout,
            host=host,
            path=path,
        )

    def load_data(
        self,
        collection_name: str,
        query_vector: List[float],
        should_search_mapping: Optional[Dict[str, str]] = None,
        must_search_mapping: Optional[Dict[str, str]] = None,
        must_not_search_mapping: Optional[Dict[str, str]] = None,
        rang_search_mapping: Optional[Dict[str, Dict[str, float]]] = None,
        limit: int = 10,
    ) -> List[Document]:
        """Load data from Qdrant.

        Runs a filtered vector search against `collection_name` and converts
        every returned point into a `Document`.

        Args:
            collection_name (str): Name of the Qdrant collection.
            query_vector (List[float]): Query vector.
            should_search_mapping (Optional[Dict[str, str]]): Mapping from field name
                to query string. Each entry becomes a `MatchText` condition in
                the filter's `should` clause.
            must_search_mapping (Optional[Dict[str, str]]): Mapping from field name
                to query string. Each entry becomes a `MatchValue` condition in
                the filter's `must` clause.
            must_not_search_mapping (Optional[Dict[str, str]]): Mapping from field
                name to query string. Each entry becomes a `MatchValue`
                condition in the filter's `must_not` clause.
            rang_search_mapping (Optional[Dict[str, Dict[str, float]]]): Mapping from
                field name to range query; the inner dict may hold any of the
                `gte`, `lte`, `gt`, `lt` bounds. Each entry is added to the
                `should` clause, alongside the `should_search_mapping`
                conditions.
            limit (int): Number of results to return.

        Example:
            reader = QdrantReader()
            reader.load_data(
                 collection_name="test_collection",
                 query_vector=[0.1, 0.2, 0.3],
                 should_search_mapping={"text_field": "text"},
                 must_search_mapping={"text_field": "text"},
                 must_not_search_mapping={"text_field": "text"},
                 # gte, lte, gt, lt supported
                 rang_search_mapping={"text_field": {"gte": 0.1, "lte": 0.2}},
                 limit=10
            )

        Returns:
            List[Document]: A list of documents. `doc_id`, `text` and `metadata`
            are read from each point's payload (keys that are missing or null
            are not passed to `Document`); the point's vector is used as the
            embedding.

        Raises:
            ValueError: If a returned point's vector is neither `None` nor a
                list (presumably a mapping when the collection uses named
                vectors — confirm against qdrant-client).
        """
        from qdrant_client.http.models import (
            FieldCondition,
            Filter,
            MatchText,
            MatchValue,
            Range,
        )

        should_conditions = [
            FieldCondition(key=key, match=MatchText(text=value))
            for key, value in (should_search_mapping or {}).items()
        ]
        must_conditions = [
            FieldCondition(key=key, match=MatchValue(value=value))
            for key, value in (must_search_mapping or {}).items()
        ]
        must_not_conditions = [
            FieldCondition(key=key, match=MatchValue(value=value))
            for key, value in (must_not_search_mapping or {}).items()
        ]
        # NOTE(review): range conditions are OR-ed into `should` together with the
        # text matches rather than required via `must`. Kept as-is for backward
        # compatibility — confirm this is the intended semantics.
        should_conditions.extend(
            FieldCondition(
                key=key,
                range=Range(
                    gte=bounds.get("gte"),
                    lte=bounds.get("lte"),
                    gt=bounds.get("gt"),
                    lt=bounds.get("lt"),
                ),
            )
            for key, bounds in (rang_search_mapping or {}).items()
        )

        response = self._client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            # Empty clauses are sent as None ("no condition") instead of [] so an
            # empty `should` list can never be interpreted as "match nothing".
            query_filter=Filter(
                must=must_conditions or None,
                must_not=must_not_conditions or None,
                should=should_conditions or None,
            ),
            with_vectors=True,
            with_payload=True,
            limit=limit,
        )

        documents = []
        for point in response:
            payload = point.payload or {}
            vector = point.vector
            # `typing.cast` is a no-op at runtime and never raises, so the vector
            # type has to be checked explicitly.
            if vector is not None and not isinstance(vector, list):
                raise ValueError("Could not cast vector to List[float].")
            # Forward only the payload fields that are actually present.
            # NOTE(review): assumes Document supplies its own defaults for omitted
            # fields (explicit None was presumably rejected for these non-Optional
            # fields) — confirm against llama_index.legacy.schema.
            fields = {
                "id_": payload.get("doc_id"),
                "text": payload.get("text"),
                "metadata": payload.get("metadata"),
            }
            documents.append(
                Document(
                    embedding=cast(Optional[List[float]], vector),
                    **{k: v for k, v in fields.items() if v is not None},
                )
            )

        return documents
190

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.