"""LanceDB vector store."""

import logging
from typing import Any, List, Optional

import numpy as np
from pandas import DataFrame

from llama_index.legacy.schema import (
    BaseNode,
    MetadataMode,
    NodeRelationship,
    RelatedNodeInfo,
    TextNode,
)
from llama_index.legacy.vector_stores.types import (
    MetadataFilters,
    VectorStore,
    VectorStoreQuery,
    VectorStoreQueryResult,
)
from llama_index.legacy.vector_stores.utils import (
    DEFAULT_DOC_ID_KEY,
    DEFAULT_TEXT_KEY,
    legacy_metadata_dict_to_node,
    metadata_dict_to_node,
    node_to_metadata_dict,
)

_logger = logging.getLogger(__name__)


def _to_lance_filter(standard_filters: MetadataFilters) -> Any:
    """Translate standard metadata filters to Lance specific spec."""
    filters = []
    for filter in standard_filters.legacy_filters():
        if isinstance(filter.value, str):
            filters.append(filter.key + ' = "' + filter.value + '"')
        else:
            filters.append(filter.key + " = " + str(filter.value))
    return " AND ".join(filters)


def _to_llama_similarities(results: DataFrame) -> List[float]:
    keys = results.keys()
    normalized_similarities: np.ndarray
    if "score" in keys:
        normalized_similarities = np.exp(results["score"] - np.max(results["score"]))
    elif "_distance" in keys:
        normalized_similarities = np.exp(-results["_distance"])
    else:
        normalized_similarities = np.linspace(1, 0, len(results))
    return normalized_similarities.tolist()
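
# Worked example (illustrative): a result frame whose "_distance" column holds
# [0.0, 0.69] maps to similarities of roughly [1.0, 0.5], since exp(-0.0) = 1.0
# and exp(-0.69) ≈ 0.5. If neither "score" nor "_distance" is present, rows get
# evenly spaced similarities from 1 down to 0.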


class LanceDBVectorStore(VectorStore):
    """
    The LanceDB Vector Store.

    Stores text and embeddings in LanceDB. The vector store will open an existing
        LanceDB dataset or create the dataset if it does not exist.

    Args:
        uri (str, required): Location where LanceDB will store its files.
        table_name (str, optional): The table name where the embeddings will be stored.
            Defaults to "vectors".
        vector_column_name (str, optional): The vector column name in the table if different from default.
            Defaults to "vector", in keeping with lancedb convention.
        nprobes (int, optional): The number of probes used.
            A higher number makes search more accurate but also slower.
            Defaults to 20.
        refine_factor (int, optional): Refine the results by reading extra elements
            and re-ranking them in memory.
            Defaults to None.

    Raises:
        ImportError: Unable to import `lancedb`.

    Returns:
        LanceDBVectorStore: VectorStore that supports creating LanceDB datasets and
            querying them.
    """

    stores_text = True
    flat_metadata: bool = True

    def __init__(
        self,
        uri: str,
        table_name: str = "vectors",
        vector_column_name: str = "vector",
        nprobes: int = 20,
        refine_factor: Optional[int] = None,
        text_key: str = DEFAULT_TEXT_KEY,
        doc_id_key: str = DEFAULT_DOC_ID_KEY,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        import_err_msg = "`lancedb` package not found, please run `pip install lancedb`"
        try:
            import lancedb
        except ImportError:
            raise ImportError(import_err_msg)

        self.connection = lancedb.connect(uri)
        self.uri = uri
        self.table_name = table_name
        self.vector_column_name = vector_column_name
        self.nprobes = nprobes
        self.text_key = text_key
        self.doc_id_key = doc_id_key
        self.refine_factor = refine_factor

    @property
    def client(self) -> None:
        """Get client."""
        return

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        data = []
        ids = []
        for node in nodes:
            metadata = node_to_metadata_dict(
                node, remove_text=False, flat_metadata=self.flat_metadata
            )
            # One row per node: node id, source doc id, embedding vector,
            # raw text, and the flattened metadata dict.
            append_data = {
                "id": node.node_id,
                "doc_id": node.ref_doc_id,
                "vector": node.get_embedding(),
                "text": node.get_content(metadata_mode=MetadataMode.NONE),
                "metadata": metadata,
            }
            data.append(append_data)
            ids.append(node.node_id)

        # Append to the existing table, or create it on first use.
        if self.table_name in self.connection.table_names():
            tbl = self.connection.open_table(self.table_name)
            tbl.add(data)
        else:
            self.connection.create_table(self.table_name, data)
        return ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes with the given ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        table = self.connection.open_table(self.table_name)
        # Filter on the "doc_id" column, which is the column `add()` writes
        # the source document id into.
        table.delete('doc_id = "' + ref_doc_id + '"')

    def query(
        self,
        query: VectorStoreQuery,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes."""
        if query.filters is not None:
            if "where" in kwargs:
                raise ValueError(
                    "Cannot specify filter via both query and kwargs. "
                    "Use kwargs only for lancedb specific items that are "
                    "not supported via the generic query interface."
                )
            where = _to_lance_filter(query.filters)
        else:
            where = kwargs.pop("where", None)

        table = self.connection.open_table(self.table_name)
        lance_query = (
            table.search(
                query=query.query_embedding,
                vector_column_name=self.vector_column_name,
            )
            .limit(query.similarity_top_k)
            .where(where)
            .nprobes(self.nprobes)
        )

        if self.refine_factor is not None:
            lance_query = lance_query.refine_factor(self.refine_factor)

        results = lance_query.to_pandas()
        nodes = []
        # Rebuild llama-index nodes from the result rows.
        for _, item in results.iterrows():
            try:
                node = metadata_dict_to_node(item.metadata)
                node.embedding = list(item[self.vector_column_name])
            except Exception:
                # deprecated legacy logic for backward compatibility
                _logger.debug(
                    "Failed to parse Node metadata, fallback to legacy logic."
                )
                if "metadata" in item:
                    metadata, node_info, _relation = legacy_metadata_dict_to_node(
                        item.metadata, text_key=self.text_key
                    )
                else:
                    metadata, node_info = {}, {}
                node = TextNode(
                    text=item[self.text_key] or "",
                    id_=item.id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships={
                        NodeRelationship.SOURCE: RelatedNodeInfo(
                            node_id=item[self.doc_id_key]
                        ),
                    },
                )

            nodes.append(node)

        return VectorStoreQueryResult(
            nodes=nodes,
            similarities=_to_llama_similarities(results),
            ids=results["id"].tolist(),
        )
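

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). Assumes
# `lancedb` is installed; the URI, table name, and 3-dimensional embedding
# below are made-up values for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # A toy node with a pre-computed embedding and an explicit source document,
    # so that the "doc_id" column written by `add()` is populated.
    node = TextNode(
        text="hello lancedb",
        embedding=[0.1, 0.2, 0.3],
        relationships={
            NodeRelationship.SOURCE: RelatedNodeInfo(node_id="doc-1"),
        },
    )

    # Create (or open) a local LanceDB dataset and persist the node.
    store = LanceDBVectorStore(uri="/tmp/lancedb_demo", table_name="vectors")
    store.add([node])

    # Query with the same embedding; the node itself should come back first.
    result = store.query(
        VectorStoreQuery(query_embedding=[0.1, 0.2, 0.3], similarity_top_k=1)
    )
    print(result.ids, result.similarities)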
