llama-index

Форк
0
347 строк · 10.8 Кб
1
"""Chroma vector store."""
2

3
import logging
4
import math
5
from typing import Any, Dict, Generator, List, Optional, cast
6

7
from llama_index.legacy.bridge.pydantic import Field, PrivateAttr
8
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
9
from llama_index.legacy.utils import truncate_text
10
from llama_index.legacy.vector_stores.types import (
11
    BasePydanticVectorStore,
12
    MetadataFilters,
13
    VectorStoreQuery,
14
    VectorStoreQueryResult,
15
)
16
from llama_index.legacy.vector_stores.utils import (
17
    legacy_metadata_dict_to_node,
18
    metadata_dict_to_node,
19
    node_to_metadata_dict,
20
)
21

22
# Module-level logger; ChromaVectorStore.query writes its per-result debug output here.
logger = logging.getLogger(__name__)
23

24

25
def _transform_chroma_filter_condition(condition: str) -> str:
26
    """Translate standard metadata filter op to Chroma specific spec."""
27
    if condition == "and":
28
        return "$and"
29
    elif condition == "or":
30
        return "$or"
31
    else:
32
        raise ValueError(f"Filter condition {condition} not supported")
33

34

35
def _transform_chroma_filter_operator(operator: str) -> str:
36
    """Translate standard metadata filter operator to Chroma specific spec."""
37
    if operator == "!=":
38
        return "$ne"
39
    elif operator == "==":
40
        return "$eq"
41
    elif operator == ">":
42
        return "$gt"
43
    elif operator == "<":
44
        return "$lt"
45
    elif operator == ">=":
46
        return "$gte"
47
    elif operator == "<=":
48
        return "$lte"
49
    else:
50
        raise ValueError(f"Filter operator {operator} not supported")
51

52

53
def _to_chroma_filter(
    standard_filters: MetadataFilters,
) -> dict:
    """Translate standard metadata filters to a Chroma ``where`` clause.

    Args:
        standard_filters: Filters whose ``condition`` (default ``"and"``)
            combines the individual ``filters``.

    Returns:
        ``{}`` when there are no filters; the single clause itself when
        there is exactly one; otherwise ``{"$and"|"$or": [clause, ...]}``.
        Each clause is ``{key: {chroma_op: value}}``, or ``{key: value}``
        when the filter has no operator.

    Raises:
        ValueError: If the condition or an operator is not supported.
    """
    # Resolved before looking at the filters so an unsupported condition is
    # reported even when the filter list is empty.
    condition = _transform_chroma_filter_condition(
        standard_filters.condition or "and"
    )

    clauses = []
    for metadata_filter in standard_filters.filters or []:
        if metadata_filter.operator:
            chroma_op = _transform_chroma_filter_operator(metadata_filter.operator)
            clauses.append({metadata_filter.key: {chroma_op: metadata_filter.value}})
        else:
            # `{key: value}` is Chroma's shorthand for an equality match.
            clauses.append({metadata_filter.key: metadata_filter.value})

    if not clauses:
        return {}
    if len(clauses) == 1:
        # A lone clause is returned bare instead of wrapped in $and/$or.
        # NOTE(review): Chroma's validation appears to require at least two
        # operands for $and/$or — confirm for the pinned chromadb version.
        return clauses[0]
    return {condition: clauses}
82

83

84
# Message for the ImportError raised whenever the optional `chromadb`
# package cannot be imported.
import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"
85

86
# Maximum number of nodes sent in one `collection.add` call: ChromaVectorStore.add
# splits its input with `chunk_list(nodes, MAX_CHUNK_SIZE)`.
# NOTE(review): the real per-call limit presumably depends on the Chroma
# version/backend — confirm this value against the pinned chromadb release.
MAX_CHUNK_SIZE = 41665  # One less than the max chunk size for ChromaDB
87

88

89
def chunk_list(
    lst: List[BaseNode], max_chunk_size: int
) -> Generator[List[BaseNode], None, None]:
    """Yield successive slices of at most ``max_chunk_size`` items from ``lst``.

    Args:
        lst (List[BaseNode]): list of nodes with embeddings
        max_chunk_size (int): maximum number of items per yielded slice;
            must be positive.

    Yields:
        List[BaseNode]: consecutive slices of ``lst``, in order; the last one
        may be shorter. Nothing is yielded for an empty ``lst``.

    Raises:
        ValueError: If ``max_chunk_size`` is not positive (raised when the
            generator is first advanced). A negative size would otherwise
            silently yield nothing and drop every item.
    """
    if max_chunk_size <= 0:
        raise ValueError(f"max_chunk_size must be positive, got {max_chunk_size}")
    for start in range(0, len(lst), max_chunk_size):
        yield lst[start : start + max_chunk_size]
103

104

105
class ChromaVectorStore(BasePydanticVectorStore):
    """Chroma vector store.

    In this vector store, embeddings are stored within a ChromaDB collection.

    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance. If omitted, a collection named
            ``collection_name`` is fetched or created through a
            ``chromadb.HttpClient`` built from ``host``, ``port``, ``ssl``
            and ``headers``.
        collection_name (Optional[str]): Collection name; used to open the
            collection when ``chroma_collection`` is not given.
        host (Optional[str]): Chroma server host for the HTTP client.
        port (Optional[str]): Chroma server port for the HTTP client.
        ssl (bool): Whether the HTTP client connects over SSL.
        headers (Optional[Dict[str, str]]): Extra HTTP headers for the client.
        persist_dir (Optional[str]): Recorded on the store; ``__init__``
            never opens a persistent client itself (``from_params`` does).
        collection_kwargs (Optional[dict]): Extra keyword arguments passed to
            ``get_or_create_collection``.
    """

    stores_text: bool = True
    flat_metadata: bool = True

    collection_name: Optional[str]
    host: Optional[str]
    port: Optional[str]
    ssl: bool
    headers: Optional[Dict[str, str]]
    persist_dir: Optional[str]
    collection_kwargs: Dict[str, Any] = Field(default_factory=dict)

    _collection: Any = PrivateAttr()

    def __init__(
        self,
        chroma_collection: Optional[Any] = None,
        collection_name: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Init params.

        Raises:
            ImportError: If ``chromadb`` is not installed.
        """
        try:
            import chromadb
        except ImportError as exc:
            raise ImportError(import_err_msg) from exc
        from chromadb.api.models.Collection import Collection

        # Normalize once: `None` cannot be `**`-unpacked below.
        collection_kwargs = collection_kwargs or {}

        if chroma_collection is None:
            client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
            self._collection = client.get_or_create_collection(
                name=collection_name, **collection_kwargs
            )
        else:
            self._collection = cast(Collection, chroma_collection)

        super().__init__(
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_name=collection_name,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs,
        )

    @classmethod
    def from_collection(cls, collection: Any) -> "ChromaVectorStore":
        """Wrap an existing chromadb ``Collection`` in a store.

        Raises:
            ImportError: If ``chromadb`` is not installed.
            TypeError: If ``collection`` is not a chromadb ``Collection``.
        """
        try:
            from chromadb import Collection
        except ImportError as exc:
            raise ImportError(import_err_msg) from exc

        if not isinstance(collection, Collection):
            # TypeError subclasses Exception, so existing handlers still match.
            raise TypeError("argument is not chromadb collection instance")

        return cls(chroma_collection=collection)

    @classmethod
    def from_params(
        cls,
        collection_name: str,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> "ChromaVectorStore":
        """Open (or create) ``collection_name`` and wrap it in a store.

        A ``chromadb.PersistentClient`` is used when ``persist_dir`` is set;
        otherwise a ``chromadb.HttpClient`` when both ``host`` and ``port``
        are set.

        Raises:
            ImportError: If ``chromadb`` is not installed.
            ValueError: If neither ``persist_dir`` nor ``host``/``port`` is given.
        """
        try:
            import chromadb
        except ImportError as exc:
            raise ImportError(import_err_msg) from exc

        # `None` default instead of a shared mutable `{}` default.
        collection_kwargs = collection_kwargs or {}

        if persist_dir:
            client = chromadb.PersistentClient(path=persist_dir)
        elif host and port:
            client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
        else:
            raise ValueError(
                "Either `persist_dir` or (`host`,`port`) must be specified"
            )
        collection = client.get_or_create_collection(
            name=collection_name, **collection_kwargs
        )
        return cls(
            chroma_collection=collection,
            # Forwarded so the store's `collection_name` field is populated.
            collection_name=collection_name,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "ChromaVectorStore"

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to the Chroma collection.

        Nodes are sent in batches of at most ``MAX_CHUNK_SIZE``; each entry
        carries the node's embedding, id, metadata (text removed) and its
        content rendered with ``MetadataMode.NONE`` as the document.

        Args:
            nodes (List[BaseNode]): list of nodes with embeddings

        Returns:
            List[str]: ids of the added nodes, in input order.

        Raises:
            ValueError: If the collection is not initialized.
        """
        if not self._collection:
            raise ValueError("Collection not initialized")

        all_ids: List[str] = []
        for node_chunk in chunk_list(nodes, MAX_CHUNK_SIZE):
            embeddings = []
            metadatas = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                metadata_dict = node_to_metadata_dict(
                    node, remove_text=True, flat_metadata=self.flat_metadata
                )
                # None values are stored as "" — presumably because Chroma
                # does not accept None metadata values.
                metadatas.append(
                    {
                        key: "" if value is None else value
                        for key, value in metadata_dict.items()
                    }
                )
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)

        return all_ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete every entry whose ``document_id`` metadata equals ``ref_doc_id``.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        self._collection.delete(where={"document_id": ref_doc_id})

    @property
    def client(self) -> Any:
        """Return the underlying Chroma collection."""
        return self._collection

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query the collection for the nodes most similar to ``query``.

        Args:
            query (VectorStoreQuery): supplies ``query_embedding``,
                ``similarity_top_k`` and optional metadata ``filters``.
            **kwargs: forwarded to ``Collection.query``. A raw Chroma
                ``where`` clause may be given here only when
                ``query.filters`` is None.

        Returns:
            VectorStoreQueryResult: nodes, similarity scores computed as
            ``exp(-distance)``, and ids, in the order Chroma returns them.

        Raises:
            ValueError: If filters are given both via ``query.filters`` and
                ``kwargs["where"]``, or a filter is not supported.
        """
        if query.filters is not None:
            if "where" in kwargs:
                raise ValueError(
                    "Cannot specify metadata filters via both query and kwargs. "
                    "Use kwargs only for chroma specific items that are "
                    "not supported via the generic query interface."
                )
            where = _to_chroma_filter(query.filters)
        else:
            where = kwargs.pop("where", {})

        results = self._collection.query(
            query_embeddings=query.query_embedding,
            n_results=query.similarity_top_k,
            # An empty clause means "no filter"; None is Chroma's default for
            # that, whereas some Chroma versions reject an empty dict.
            where=where or None,
            **kwargs,
        )

        # Results are grouped per query embedding; only the first is used.
        logger.debug("> Top %d nodes:", len(results["documents"][0]))
        nodes = []
        similarities = []
        ids = []
        for node_id, text, metadata, distance in zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            try:
                node = metadata_dict_to_node(metadata)
                node.set_content(text)
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(
                    metadata
                )

                node = TextNode(
                    text=text,
                    id_=node_id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )

            nodes.append(node)

            similarity_score = math.exp(-distance)
            similarities.append(similarity_score)

            logger.debug(
                "> [Node %s] [Similarity score: %s] %s",
                node_id,
                similarity_score,
                truncate_text(str(text), 100),
            )
            ids.append(node_id)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
348

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.