"""Weaviate Vector store index.
2

3
An index that is built on top of an existing vector store.
4

5
"""
6

7
import logging
8
from typing import Any, Dict, List, Optional, cast
9
from uuid import uuid4
10

11
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
12
from llama_index.legacy.schema import BaseNode
13
from llama_index.legacy.vector_stores.types import (
14
    BasePydanticVectorStore,
15
    MetadataFilters,
16
    VectorStoreQuery,
17
    VectorStoreQueryMode,
18
    VectorStoreQueryResult,
19
)
20
from llama_index.legacy.vector_stores.utils import DEFAULT_TEXT_KEY
21
from llama_index.legacy.vector_stores.weaviate_utils import (
22
    add_node,
23
    class_schema_exists,
24
    create_default_schema,
25
    get_all_properties,
26
    get_node_similarity,
27
    parse_get_response,
28
    to_node,
29
)
30

31
logger = logging.getLogger(__name__)
32

33
import_err_msg = (
34
    "`weaviate` package not found, please run `pip install weaviate-client`"
35
)
36

37

38
def _transform_weaviate_filter_condition(condition: str) -> str:
    """Translate standard metadata filter condition to Weaviate specific spec."""
    if condition == "and":
        return "And"
    elif condition == "or":
        return "Or"
    else:
        raise ValueError(f"Filter condition {condition} not supported")


def _transform_weaviate_filter_operator(operator: str) -> str:
    """Translate standard metadata filter operator to Weaviate specific spec."""
    if operator == "!=":
        return "NotEqual"
    elif operator == "==":
        return "Equal"
    elif operator == ">":
        return "GreaterThan"
    elif operator == "<":
        return "LessThan"
    elif operator == ">=":
        return "GreaterThanEqual"
    elif operator == "<=":
        return "LessThanEqual"
    else:
        raise ValueError(f"Filter operator {operator} not supported")


def _to_weaviate_filter(standard_filters: MetadataFilters) -> Dict[str, Any]:
    filters_list = []
    condition = standard_filters.condition or "and"
    condition = _transform_weaviate_filter_condition(condition)

    if standard_filters.filters:
        for filter in standard_filters.filters:
            value_type = "valueText"
            if isinstance(filter.value, float):
                value_type = "valueNumber"
            elif isinstance(filter.value, int):
                value_type = "valueNumber"
            elif isinstance(filter.value, str) and filter.value.isnumeric():
                filter.value = float(filter.value)
                value_type = "valueNumber"
            filters_list.append(
                {
                    "path": filter.key,
                    "operator": _transform_weaviate_filter_operator(filter.operator),
                    value_type: filter.value,
                }
            )
    else:
        return {}

    if len(filters_list) == 1:
        # If there is only one filter, return it directly
        return filters_list[0]

    return {"operands": filters_list, "operator": condition}
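

# Illustrative example (an assumption for clarity, not part of the original module):
# a `MetadataFilters` holding two filters joined with the "and" condition, e.g.
# `author == "alice"` and `year >= 2020`, translates via `_to_weaviate_filter` into
# a nested Weaviate `where` dict along the lines of:
#
#     {
#         "operands": [
#             {"path": "author", "operator": "Equal", "valueText": "alice"},
#             {"path": "year", "operator": "GreaterThanEqual", "valueNumber": 2020},
#         ],
#         "operator": "And",
#     }
#
# A single filter is returned directly, without the surrounding "operands" wrapper.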


class WeaviateVectorStore(BasePydanticVectorStore):
    """Weaviate vector store.

    In this vector store, embeddings and docs are stored within a
    Weaviate collection.

    During query time, the index uses Weaviate to query for the top
    k most similar nodes.

    Args:
        weaviate_client (weaviate.Client): WeaviateClient
            instance from `weaviate-client` package
        index_name (Optional[str]): name for Weaviate classes

    """

    stores_text: bool = True

    index_name: str
    url: Optional[str]
    text_key: str
    auth_config: Dict[str, Any] = Field(default_factory=dict)
    client_kwargs: Dict[str, Any] = Field(default_factory=dict)

    _client = PrivateAttr()

    def __init__(
        self,
        weaviate_client: Optional[Any] = None,
        class_prefix: Optional[str] = None,
        index_name: Optional[str] = None,
        text_key: str = DEFAULT_TEXT_KEY,
        auth_config: Optional[Any] = None,
        client_kwargs: Optional[Dict[str, Any]] = None,
        url: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        try:
            import weaviate  # noqa
            from weaviate import AuthApiKey, Client
        except ImportError:
            raise ImportError(import_err_msg)

        if weaviate_client is None:
            if isinstance(auth_config, dict):
                auth_config = AuthApiKey(**auth_config)

            client_kwargs = client_kwargs or {}
            self._client = Client(
                url=url, auth_client_secret=auth_config, **client_kwargs
            )
        else:
            self._client = cast(Client, weaviate_client)

        # validate class prefix starts with a capital letter
        if class_prefix is not None:
            logger.warning("class_prefix is deprecated, please use index_name")
            # legacy, kept for backward compatibility
            index_name = f"{class_prefix}_Node"

        index_name = index_name or f"LlamaIndex_{uuid4().hex}"
        if not index_name[0].isupper():
            raise ValueError(
                "Index name must start with a capital letter, e.g. 'LlamaIndex'"
            )

        # create the default schema if it does not exist
        if not class_schema_exists(self._client, index_name):
            create_default_schema(self._client, index_name)

        super().__init__(
            url=url,
            index_name=index_name,
            text_key=text_key,
            auth_config=auth_config.__dict__ if auth_config else {},
            client_kwargs=client_kwargs or {},
        )

    @classmethod
    def from_params(
        cls,
        url: str,
        auth_config: Any,
        index_name: Optional[str] = None,
        text_key: str = DEFAULT_TEXT_KEY,
        client_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> "WeaviateVectorStore":
        """Create WeaviateVectorStore from config."""
        try:
            import weaviate  # noqa
            from weaviate import AuthApiKey, Client  # noqa
        except ImportError:
            raise ImportError(import_err_msg)

        client_kwargs = client_kwargs or {}
        weaviate_client = Client(
            url=url, auth_client_secret=auth_config, **client_kwargs
        )
        return cls(
            weaviate_client=weaviate_client,
            url=url,
            auth_config=auth_config.__dict__,
            client_kwargs=client_kwargs,
            index_name=index_name,
            text_key=text_key,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "WeaviateVectorStore"

    @property
    def client(self) -> Any:
        """Get client."""
        return self._client

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        """
        ids = [r.node_id for r in nodes]

        with self._client.batch as batch:
            for node in nodes:
                add_node(
                    self._client,
                    node,
                    self.index_name,
                    batch=batch,
                    text_key=self.text_key,
                )
        return ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        where_filter = {
            "path": ["ref_doc_id"],
            "operator": "Equal",
            "valueText": ref_doc_id,
        }
        if "filter" in delete_kwargs and delete_kwargs["filter"] is not None:
            where_filter = {
                "operator": "And",
                "operands": [where_filter, delete_kwargs["filter"]],  # type: ignore
            }

        query = (
            self._client.query.get(self.index_name)
            .with_additional(["id"])
            .with_where(where_filter)
            .with_limit(10000)  # 10,000 is the max weaviate can fetch
        )

        query_result = query.do()
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[self.index_name]
        for entry in entries:
            self._client.data_object.delete(entry["_additional"]["id"], self.index_name)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes."""
        all_properties = get_all_properties(self._client, self.index_name)

        # build query
        query_builder = self._client.query.get(self.index_name, all_properties)

        # list of documents to constrain search
        if query.doc_ids:
            filter_with_doc_ids = {
                "operator": "Or",
                "operands": [
                    {"path": ["doc_id"], "operator": "Equal", "valueText": doc_id}
                    for doc_id in query.doc_ids
                ],
            }
            query_builder = query_builder.with_where(filter_with_doc_ids)

        if query.node_ids:
            filter_with_node_ids = {
                "operator": "Or",
                "operands": [
                    {"path": ["id"], "operator": "Equal", "valueText": node_id}
                    for node_id in query.node_ids
                ],
            }
            query_builder = query_builder.with_where(filter_with_node_ids)

        query_builder = query_builder.with_additional(
            ["id", "vector", "distance", "score"]
        )

        vector = query.query_embedding
        similarity_key = "distance"
        if query.mode == VectorStoreQueryMode.DEFAULT:
            logger.debug("Using vector search")
            if vector is not None:
                query_builder = query_builder.with_near_vector(
                    {
                        "vector": vector,
                    }
                )
        elif query.mode == VectorStoreQueryMode.HYBRID:
            logger.debug(f"Using hybrid search with alpha {query.alpha}")
            similarity_key = "score"
            if vector is not None and query.query_str:
                query_builder = query_builder.with_hybrid(
                    query=query.query_str,
                    alpha=query.alpha,
                    vector=vector,
                )

        if query.filters is not None:
            filter = _to_weaviate_filter(query.filters)
            query_builder = query_builder.with_where(filter)
        elif "filter" in kwargs and kwargs["filter"] is not None:
            query_builder = query_builder.with_where(kwargs["filter"])

        query_builder = query_builder.with_limit(query.similarity_top_k)
        logger.debug(f"Using limit of {query.similarity_top_k}")

        # execute query
        query_result = query_builder.do()

        # parse results
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[self.index_name]

        similarities = []
        nodes: List[BaseNode] = []
        node_ids = []

        for i, entry in enumerate(entries):
            if i < query.similarity_top_k:
                similarities.append(get_node_similarity(entry, similarity_key))
                nodes.append(to_node(entry, text_key=self.text_key))
                node_ids.append(nodes[-1].node_id)
            else:
                break

        return VectorStoreQueryResult(
            nodes=nodes, ids=node_ids, similarities=similarities
        )
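

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes a
# Weaviate instance reachable at the placeholder URL below and uses a made-up
# index name and embedding; a real pipeline would compute embeddings with an
# embedding model and pass real connection/auth settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import weaviate

    from llama_index.legacy.schema import TextNode

    # Connect to a locally running Weaviate instance (placeholder URL).
    client = weaviate.Client(url="http://localhost:8080")
    vector_store = WeaviateVectorStore(
        weaviate_client=client, index_name="LlamaIndexDemo"
    )

    # Nodes must already carry embeddings when passed to `add`.
    node = TextNode(text="hello weaviate", embedding=[0.1, 0.2, 0.3])
    vector_store.add([node])

    # Query for the single most similar node to the same placeholder vector.
    result = vector_store.query(
        VectorStoreQuery(query_embedding=[0.1, 0.2, 0.3], similarity_top_k=1)
    )
    print(result.nodes)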