llama-index

Форк
0
275 строк · 9.0 Кб
1
import enum
2
import uuid
3
from datetime import timedelta
4
from typing import Any, Dict, List, Optional
5

6
from llama_index.legacy.constants import DEFAULT_EMBEDDING_DIM
7
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
8
from llama_index.legacy.vector_stores.types import (
9
    MetadataFilters,
10
    VectorStore,
11
    VectorStoreQuery,
12
    VectorStoreQueryResult,
13
)
14
from llama_index.legacy.vector_stores.utils import (
15
    metadata_dict_to_node,
16
    node_to_metadata_dict,
17
)
18

19

20
class IndexType(enum.Enum):
    """Supported ANN index flavors for the embedding column."""

    # TimescaleDB's native DiskANN-style vector index.
    TIMESCALE_VECTOR = 1
    # pgvector's IVF-flat index.
    PGVECTOR_IVFFLAT = 2
    # pgvector's HNSW graph index.
    PGVECTOR_HNSW = 3
26

27

28
class TimescaleVectorStore(VectorStore):
    """Vector store backed by TimescaleDB via the `timescale_vector` client.

    Persists node text, metadata, and embeddings in a Postgres table
    (optionally time-partitioned) and supports similarity search with
    metadata and time-range filters.
    """

    stores_text = True  # node text is persisted alongside the embedding
    flat_metadata = False  # nested metadata dicts are allowed

    def __init__(
        self,
        service_url: str,
        table_name: str,
        num_dimensions: int = DEFAULT_EMBEDDING_DIM,
        time_partition_interval: Optional[timedelta] = None,
    ) -> None:
        """Initialize the store, its DB clients, and the backing tables.

        Args:
            service_url: Postgres/Timescale connection URL.
            table_name: target table name (lower-cased before use).
            num_dimensions: dimensionality of the stored embeddings.
            time_partition_interval: when set, the table is time-partitioned
                and row ids must be UUID v1 (time-encoded).

        Raises:
            ImportError: if the `timescale-vector` package is missing.
        """
        try:
            from timescale_vector import client  # noqa
        except ImportError:
            raise ImportError("`timescale-vector` package should be pre installed")

        self.service_url = service_url
        self.table_name: str = table_name.lower()
        self.num_dimensions = num_dimensions
        self.time_partition_interval = time_partition_interval

        self._create_clients()
        self._create_tables()

    async def close(self) -> None:
        """Close both the synchronous and asynchronous DB clients."""
        self._sync_client.close()
        await self._async_client.close()

    @classmethod
    def from_params(
        cls,
        service_url: str,
        table_name: str,
        num_dimensions: int = DEFAULT_EMBEDDING_DIM,
        time_partition_interval: Optional[timedelta] = None,
    ) -> "TimescaleVectorStore":
        """Alternate constructor mirroring `__init__`'s parameters."""
        return cls(
            service_url=service_url,
            table_name=table_name,
            num_dimensions=num_dimensions,
            time_partition_interval=time_partition_interval,
        )

    def _create_clients(self) -> None:
        """Create the sync and async `timescale_vector` clients."""
        from timescale_vector import client

        # In the normal case the id type is unrestricted, so allow arbitrary
        # text; time-partitioned tables require UUID v1 ids (time-encoded).
        id_type = "TEXT"
        if self.time_partition_interval is not None:
            id_type = "UUID"

        self._sync_client = client.Sync(
            self.service_url,
            self.table_name,
            self.num_dimensions,
            id_type=id_type,
            time_partition_interval=self.time_partition_interval,
        )
        self._async_client = client.Async(
            self.service_url,
            self.table_name,
            self.num_dimensions,
            id_type=id_type,
            time_partition_interval=self.time_partition_interval,
        )

    def _create_tables(self) -> None:
        """Create the backing table(s) if they do not exist."""
        self._sync_client.create_tables()

    def _node_to_row(self, node: BaseNode) -> Any:
        """Convert a node into a `[id, metadata, content, embedding]` row.

        For time-partitioned tables the row id must be a UUID v1, so a new
        one is generated unless the node id already parses as UUID v1.
        """
        metadata = node_to_metadata_dict(
            node,
            remove_text=True,
            flat_metadata=self.flat_metadata,
        )
        # Reuse the node id in the common case.
        row_id = node.node_id
        if self.time_partition_interval is not None:
            try:
                # Attempt to parse the UUID from the string.
                parsed_uuid = uuid.UUID(row_id)
                if parsed_uuid.version != 1:
                    row_id = str(uuid.uuid1())
            except ValueError:
                row_id = str(uuid.uuid1())
        return [
            row_id,
            metadata,
            node.get_content(metadata_mode=MetadataMode.NONE),
            node.embedding,
        ]

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Upsert nodes and return the row ids actually inserted."""
        rows_to_insert = [self._node_to_row(node) for node in nodes]
        ids = [row[0] for row in rows_to_insert]
        self._sync_client.upsert(rows_to_insert)
        return ids

    async def async_add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Async variant of `add`; returns the row ids actually inserted.

        BUGFIX: previously returned the original `node_id`s, which diverge
        from the stored ids when `_node_to_row` regenerates a UUID v1 for
        time-partitioned tables. Now returns the row ids, matching `add`.
        """
        rows_to_insert = [self._node_to_row(node) for node in nodes]
        ids = [row[0] for row in rows_to_insert]
        await self._async_client.upsert(rows_to_insert)
        return ids

    def _filter_to_dict(
        self, metadata_filters: Optional[MetadataFilters]
    ) -> Optional[Dict[str, str]]:
        """Flatten legacy metadata filters into a key->value dict (or None)."""
        if metadata_filters is None or len(metadata_filters.legacy_filters()) <= 0:
            return None
        return {f.key: f.value for f in metadata_filters.legacy_filters()}

    def _db_rows_to_query_result(self, rows: List) -> VectorStoreQueryResult:
        """Convert raw search rows into a `VectorStoreQueryResult`."""
        from timescale_vector import client

        nodes = []
        similarities = []
        ids = []
        for row in rows:
            try:
                node = metadata_dict_to_node(row[client.SEARCH_RESULT_METADATA_IDX])
                node.set_content(str(row[client.SEARCH_RESULT_CONTENTS_IDX]))
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                node = TextNode(
                    id_=row[client.SEARCH_RESULT_ID_IDX],
                    text=row[client.SEARCH_RESULT_CONTENTS_IDX],
                    metadata=row[client.SEARCH_RESULT_METADATA_IDX],
                )
            similarities.append(row[client.SEARCH_RESULT_DISTANCE_IDX])
            ids.append(row[client.SEARCH_RESULT_ID_IDX])
            nodes.append(node)

        return VectorStoreQueryResult(
            nodes=nodes,
            similarities=similarities,
            ids=ids,
        )

    def date_to_range_filter(self, **kwargs: Any) -> Any:
        """Build a `client.UUIDTimeRange` from time-filter kwargs, if any.

        Recognized keys: start_date, end_date, time_delta, start_inclusive,
        end_inclusive. Returns None when none of them are present.
        """
        constructor_args = {
            key: kwargs[key]
            for key in (
                "start_date",
                "end_date",
                "time_delta",
                "start_inclusive",
                "end_inclusive",
            )
            if key in kwargs
        }
        if not constructor_args:
            return None

        try:
            from timescale_vector import client
        except ImportError:
            raise ValueError(
                "Could not import timescale_vector python package. "
                "Please install it with `pip install timescale-vector`."
            )
        return client.UUIDTimeRange(**constructor_args)

    def _query_with_score(
        self,
        embedding: Optional[List[float]],
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Run a synchronous similarity search with optional filters."""
        filter_dict = self._filter_to_dict(metadata_filters)
        res = self._sync_client.search(
            embedding,
            limit,
            filter_dict,
            uuid_time_filter=self.date_to_range_filter(**kwargs),
        )
        return self._db_rows_to_query_result(res)

    async def _aquery_with_score(
        self,
        embedding: Optional[List[float]],
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Run an asynchronous similarity search with optional filters."""
        filter_dict = self._filter_to_dict(metadata_filters)
        res = await self._async_client.search(
            embedding,
            limit,
            filter_dict,
            uuid_time_filter=self.date_to_range_filter(**kwargs),
        )
        return self._db_rows_to_query_result(res)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Public sync query entry point."""
        return self._query_with_score(
            query.query_embedding, query.similarity_top_k, query.filters, **kwargs
        )

    async def aquery(
        self, query: VectorStoreQuery, **kwargs: Any
    ) -> VectorStoreQueryResult:
        """Public async query entry point."""
        return await self._aquery_with_score(
            query.query_embedding,
            query.similarity_top_k,
            query.filters,
            **kwargs,
        )

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete all rows whose metadata `doc_id` matches `ref_doc_id`."""
        metadata_filter: Dict[str, str] = {"doc_id": ref_doc_id}
        self._sync_client.delete_by_metadata(metadata_filter)

    DEFAULT_INDEX_TYPE = IndexType.TIMESCALE_VECTOR

    def create_index(
        self, index_type: IndexType = DEFAULT_INDEX_TYPE, **kwargs: Any
    ) -> None:
        """Create an ANN index of the requested type on the embedding column.

        Raises:
            ValueError: if the `timescale-vector` package is missing.
        """
        try:
            from timescale_vector import client
        except ImportError:
            raise ValueError(
                "Could not import timescale_vector python package. "
                "Please install it with `pip install timescale-vector`."
            )

        if index_type == IndexType.PGVECTOR_IVFFLAT:
            self._sync_client.create_embedding_index(client.IvfflatIndex(**kwargs))

        if index_type == IndexType.PGVECTOR_HNSW:
            self._sync_client.create_embedding_index(client.HNSWIndex(**kwargs))

        if index_type == IndexType.TIMESCALE_VECTOR:
            self._sync_client.create_embedding_index(
                client.TimescaleVectorIndex(**kwargs)
            )

    def drop_index(self) -> None:
        """Drop the ANN index on the embedding column, if present."""
        self._sync_client.drop_embedding_index()
276

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.