import logging
from typing import Any, List, NamedTuple, Optional, Type

from llama_index.legacy.bridge.pydantic import PrivateAttr
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
from llama_index.legacy.vector_stores.types import (
    BasePydanticVectorStore,
    MetadataFilters,
    VectorStoreQuery,
    VectorStoreQueryMode,
    VectorStoreQueryResult,
)
from llama_index.legacy.vector_stores.utils import (
    metadata_dict_to_node,
    node_to_metadata_dict,
)


class DBEmbeddingRow(NamedTuple):
    node_id: str
    text: str
    metadata: dict
    similarity: float


_logger = logging.getLogger(__name__)


def get_data_model(
    base: Type,
    index_name: str,
    schema_name: str,
    hybrid_search: bool,
    text_search_config: str,
    cache_okay: bool,
    embed_dim: int = 1536,
    m: int = 16,
    ef_construction: int = 128,
    ef: int = 64,
) -> Any:
    """Create a dynamic SQLAlchemy model bound to a new table."""
    from sqlalchemy import Column, Computed
    from sqlalchemy.dialects.postgresql import (
        ARRAY,
        BIGINT,
        JSON,
        REAL,
        TSVECTOR,
        VARCHAR,
    )
    from sqlalchemy.schema import Index
    from sqlalchemy.types import TypeDecorator

    class TSVector(TypeDecorator):
        impl = TSVECTOR
        cache_ok = cache_okay

    tablename = "data_%s" % index_name  # dynamic table name
    class_name = "Data%s" % index_name  # dynamic class name
    indexname = "%s_idx" % index_name  # dynamic index name
    hnsw_indexname = "%s_hnsw_idx" % index_name  # dynamic hnsw index name

    if hybrid_search:

        class HybridAbstractData(base):  # type: ignore
            __abstract__ = True  # this line is necessary
            id = Column(BIGINT, primary_key=True, autoincrement=True)
            text = Column(VARCHAR, nullable=False)
            metadata_ = Column(JSON)
            node_id = Column(VARCHAR)
            embedding = Column(ARRAY(REAL, embed_dim))  # type: ignore
            text_search_tsv = Column(  # type: ignore
                TSVector(),
                Computed(
                    "to_tsvector('%s', text)" % text_search_config, persisted=True
                ),
            )

        model = type(
            class_name,
            (HybridAbstractData,),
            {"__tablename__": tablename, "__table_args__": {"schema": schema_name}},
        )

        Index(
            indexname,
            model.text_search_tsv,  # type: ignore
            postgresql_using="gin",
        )
    else:

        class AbstractData(base):  # type: ignore
            __abstract__ = True  # this line is necessary
            id = Column(BIGINT, primary_key=True, autoincrement=True)
            text = Column(VARCHAR, nullable=False)
            metadata_ = Column(JSON)
            node_id = Column(VARCHAR)
            embedding = Column(ARRAY(REAL, embed_dim))  # type: ignore

        model = type(
            class_name,
            (AbstractData,),
            {"__tablename__": tablename, "__table_args__": {"schema": schema_name}},
        )

    Index(
        hnsw_indexname,
        model.embedding,  # type: ignore
        postgresql_using="hnsw",
        postgresql_with={
            "m": m,
            "ef_construction": ef_construction,
            "ef": ef,
            "dim": embed_dim,
        },
        postgresql_ops={"embedding": "dist_cos_ops"},
    )
    return model
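
# For reference, the model above corresponds roughly to the following DDL.
# This is a sketch, not the exact SQL SQLAlchemy emits; the generated-column
# line only applies when hybrid_search is enabled, and the parameter values
# shown are the defaults:
#
#   CREATE TABLE <schema>.data_<index_name> (
#       id BIGSERIAL PRIMARY KEY,
#       text VARCHAR NOT NULL,
#       metadata_ JSON,
#       node_id VARCHAR,
#       embedding REAL[],
#       text_search_tsv TSVECTOR
#           GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
#   );
#   CREATE INDEX <index_name>_hnsw_idx ON <schema>.data_<index_name>
#       USING hnsw (embedding dist_cos_ops)
#       WITH (m = 16, ef_construction = 128, ef = 64, dim = 1536);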


class LanternVectorStore(BasePydanticVectorStore):
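    """Postgres vector store backed by the Lantern extension.

    Rows live in a dynamically generated table (see ``get_data_model``) and
    dense queries run against Lantern's HNSW index; hybrid mode additionally
    issues a full-text (tsvector) query and merges the deduplicated results.
    """
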
    from sqlalchemy.sql.selectable import Select

    stores_text = True
    flat_metadata = False

    connection_string: str
    async_connection_string: str
    table_name: str
    schema_name: str
    embed_dim: int
    hybrid_search: bool
    text_search_config: str
    cache_ok: bool
    perform_setup: bool
    debug: bool

    _base: Any = PrivateAttr()
    _table_class: Any = PrivateAttr()
    _engine: Any = PrivateAttr()
    _session: Any = PrivateAttr()
    _async_engine: Any = PrivateAttr()
    _async_session: Any = PrivateAttr()
    _is_initialized: bool = PrivateAttr(default=False)

    def __init__(
        self,
        connection_string: str,
        async_connection_string: str,
        table_name: str,
        schema_name: str,
        hybrid_search: bool = False,
        text_search_config: str = "english",
        embed_dim: int = 1536,
        m: int = 16,
        ef_construction: int = 128,
        ef: int = 64,
        cache_ok: bool = False,
        perform_setup: bool = True,
        debug: bool = False,
    ) -> None:
        try:
            import asyncpg  # noqa
            import psycopg2  # noqa
            import sqlalchemy
            import sqlalchemy.ext.asyncio  # noqa
        except ImportError:
            raise ImportError(
                "`sqlalchemy[asyncio]`, `psycopg2-binary` and `asyncpg` "
                "packages must be installed"
            )

        table_name = table_name.lower()
        schema_name = schema_name.lower()

        if hybrid_search and text_search_config is None:
            raise ValueError(
                "Sparse vector index creation requires "
                "a text search configuration specification."
            )

        from sqlalchemy.orm import declarative_base

        # sqlalchemy model
        self._base = declarative_base()
        self._table_class = get_data_model(
            self._base,
            table_name,
            schema_name,
            hybrid_search,
            text_search_config,
            cache_ok,
            embed_dim=embed_dim,
            m=m,
            ef_construction=ef_construction,
            ef=ef,
        )

        super().__init__(
            connection_string=connection_string,
            async_connection_string=async_connection_string,
            table_name=table_name,
            schema_name=schema_name,
            hybrid_search=hybrid_search,
            text_search_config=text_search_config,
            embed_dim=embed_dim,
            cache_ok=cache_ok,
            perform_setup=perform_setup,
            debug=debug,
        )

    async def close(self) -> None:
        if not self._is_initialized:
            return

        self._session.close_all()
        self._engine.dispose()

        await self._async_engine.dispose()

    @classmethod
    def class_name(cls) -> str:
        return "LanternStore"

    @classmethod
    def from_params(
        cls,
        host: Optional[str] = None,
        port: Optional[str] = None,
        database: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        table_name: str = "llamaindex",
        schema_name: str = "public",
        connection_string: Optional[str] = None,
        async_connection_string: Optional[str] = None,
        hybrid_search: bool = False,
        text_search_config: str = "english",
        embed_dim: int = 1536,
        m: int = 16,
        ef_construction: int = 128,
        ef: int = 64,
        cache_ok: bool = False,
        perform_setup: bool = True,
        debug: bool = False,
    ) -> "LanternVectorStore":
        """Construct a LanternVectorStore from database parameters,
        building the sync and async connection strings if not given."""
        conn_str = (
            connection_string
            or f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
        )
        async_conn_str = async_connection_string or (
            f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
        )
        return cls(
            connection_string=conn_str,
            async_connection_string=async_conn_str,
            table_name=table_name,
            schema_name=schema_name,
            hybrid_search=hybrid_search,
            text_search_config=text_search_config,
            embed_dim=embed_dim,
            m=m,
            ef_construction=ef_construction,
            ef=ef,
            cache_ok=cache_ok,
            perform_setup=perform_setup,
            debug=debug,
        )
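
    # A hypothetical construction example (connection values are placeholders,
    # not part of the library):
    #
    #   store = LanternVectorStore.from_params(
    #       host="localhost",
    #       port="5432",
    #       database="vector_db",
    #       user="postgres",
    #       password="password",
    #       table_name="llamaindex",
    #       embed_dim=1536,  # must match the embedding model's output size
    #   )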

    @property
    def client(self) -> Any:
        if not self._is_initialized:
            return None
        return self._engine

    def _connect(self) -> Any:
        from sqlalchemy import create_engine
        from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
        from sqlalchemy.orm import sessionmaker

        self._engine = create_engine(self.connection_string, echo=self.debug)
        self._session = sessionmaker(self._engine)

        self._async_engine = create_async_engine(self.async_connection_string)
        self._async_session = sessionmaker(self._async_engine, class_=AsyncSession)  # type: ignore

    def _create_schema_if_not_exists(self) -> None:
        with self._session() as session, session.begin():
            from sqlalchemy import text

            statement = text(f"CREATE SCHEMA IF NOT EXISTS {self.schema_name}")
            session.execute(statement)
            session.commit()

    def _create_tables_if_not_exists(self) -> None:
        with self._session() as session, session.begin():
            self._base.metadata.create_all(session.connection())

    def _create_extension(self) -> None:
        import sqlalchemy

        with self._session() as session, session.begin():
            statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS lantern")
            session.execute(statement)
            session.commit()

    def _initialize(self) -> None:
        if not self._is_initialized:
            self._connect()
            if self.perform_setup:
                self._create_extension()
                self._create_schema_if_not_exists()
                self._create_tables_if_not_exists()
            self._is_initialized = True

    def _node_to_table_row(self, node: BaseNode) -> Any:
        return self._table_class(
            node_id=node.node_id,
            embedding=node.get_embedding(),
            text=node.get_content(metadata_mode=MetadataMode.NONE),
            metadata_=node_to_metadata_dict(
                node,
                remove_text=True,
                flat_metadata=self.flat_metadata,
            ),
        )

    def add(self, nodes: List[BaseNode]) -> List[str]:
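        """Write nodes (text, metadata, and embedding) to the table and
        return the IDs of the inserted nodes."""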
        self._initialize()
        ids = []
        with self._session() as session, session.begin():
            for node in nodes:
                ids.append(node.node_id)
                item = self._node_to_table_row(node)
                session.add(item)
            session.commit()
        return ids

    async def async_add(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]:
        self._initialize()
        ids = []
        async with self._async_session() as session, session.begin():
            for node in nodes:
                ids.append(node.node_id)
                item = self._node_to_table_row(node)
                session.add(item)
            await session.commit()
        return ids

    def _apply_filters_and_limit(
        self,
        stmt: Select,
        limit: int,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> Any:
        import sqlalchemy

        if metadata_filters:
            for filter_ in metadata_filters.legacy_filters():
                bind_parameter = f"value_{filter_.key}"
                stmt = stmt.where(  # type: ignore
                    sqlalchemy.text(f"metadata_->>'{filter_.key}' = :{bind_parameter}")
                )
                stmt = stmt.params(  # type: ignore
                    **{bind_parameter: str(filter_.value)}
                )
        return stmt.limit(limit)  # type: ignore

    def _build_query(
        self,
        embedding: Optional[List[float]],
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> Any:
        from sqlalchemy import func, select

        # Select each row together with its cosine distance, ordering by the
        # "<=>" distance operator so the HNSW index can drive the scan.
        stmt = select(  # type: ignore
            self._table_class,
            func.cos_dist(self._table_class.embedding, embedding),
        ).order_by(self._table_class.embedding.op("<=>")(embedding))

        return self._apply_filters_and_limit(stmt, limit, metadata_filters)

    def _prepare_query(self, session: Any, limit: int) -> None:
        from sqlalchemy import text

        session.execute(text("SET enable_seqscan=OFF"))  # always use the index
        session.execute(text(f"SET hnsw.init_k={limit}"))  # fetch at least `limit` candidates

    async def _aprepare_query(self, session: Any, limit: int) -> None:
        from sqlalchemy import text

        await session.execute(text("SET enable_seqscan=OFF"))  # always use the index
        await session.execute(text(f"SET hnsw.init_k={limit}"))  # fetch at least `limit` candidates

    def _query_with_score(
        self,
        embedding: Optional[List[float]],
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_query(embedding, limit, metadata_filters)
        with self._session() as session, session.begin():
            self._prepare_query(session, limit)
            res = session.execute(stmt)
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    similarity=(1 - distance) if distance is not None else 0,
                )
                for item, distance in res.all()
            ]

    async def _aquery_with_score(
        self,
        embedding: Optional[List[float]],
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_query(embedding, limit, metadata_filters)
        async with self._async_session() as async_session, async_session.begin():
            await self._aprepare_query(async_session, limit)
            res = await async_session.execute(stmt)
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    similarity=(1 - distance) if distance is not None else 0,
                )
                for item, distance in res.all()
            ]

    def _build_sparse_query(
        self,
        query_str: Optional[str],
        limit: int,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> Any:
        from sqlalchemy import select, type_coerce
        from sqlalchemy.sql import func, text
        from sqlalchemy.types import UserDefinedType

        class REGCONFIG(UserDefinedType):
            def get_col_spec(self, **kw: Any) -> str:
                return "regconfig"

        if query_str is None:
            raise ValueError("query_str must be specified for a sparse vector query.")

        ts_query = func.plainto_tsquery(
            type_coerce(self.text_search_config, REGCONFIG), query_str
        )
        stmt = (
            select(  # type: ignore
                self._table_class,
                func.ts_rank(self._table_class.text_search_tsv, ts_query).label("rank"),
            )
            .where(self._table_class.text_search_tsv.op("@@")(ts_query))
            .order_by(text("rank desc"))
        )

        return self._apply_filters_and_limit(stmt, limit, metadata_filters)

    async def _async_sparse_query_with_rank(
        self,
        query_str: Optional[str] = None,
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_sparse_query(query_str, limit, metadata_filters)
        async with self._async_session() as async_session, async_session.begin():
            res = await async_session.execute(stmt)
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    similarity=rank,
                )
                for item, rank in res.all()
            ]

    def _sparse_query_with_rank(
        self,
        query_str: Optional[str] = None,
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_sparse_query(query_str, limit, metadata_filters)
        with self._session() as session, session.begin():
            res = session.execute(stmt)
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    similarity=rank,
                )
                for item, rank in res.all()
            ]

    async def _async_hybrid_query(
        self, query: VectorStoreQuery
    ) -> List[DBEmbeddingRow]:
        import asyncio

        if query.alpha is not None:
            _logger.warning("postgres hybrid search does not support alpha parameter.")

        sparse_top_k = query.sparse_top_k or query.similarity_top_k

        results = await asyncio.gather(
            self._aquery_with_score(
                query.query_embedding, query.similarity_top_k, query.filters
            ),
            self._async_sparse_query_with_rank(
                query.query_str, sparse_top_k, query.filters
            ),
        )

        dense_results, sparse_results = results
        all_results = dense_results + sparse_results
        return _dedup_results(all_results)

    def _hybrid_query(self, query: VectorStoreQuery) -> List[DBEmbeddingRow]:
        if query.alpha is not None:
            _logger.warning("postgres hybrid search does not support alpha parameter.")

        sparse_top_k = query.sparse_top_k or query.similarity_top_k

        dense_results = self._query_with_score(
            query.query_embedding, query.similarity_top_k, query.filters
        )

        sparse_results = self._sparse_query_with_rank(
            query.query_str, sparse_top_k, query.filters
        )

        all_results = dense_results + sparse_results
        return _dedup_results(all_results)

    def _db_rows_to_query_result(
        self, rows: List[DBEmbeddingRow]
    ) -> VectorStoreQueryResult:
        nodes = []
        similarities = []
        ids = []
        for db_embedding_row in rows:
            try:
                node = metadata_dict_to_node(db_embedding_row.metadata)
                node.set_content(str(db_embedding_row.text))
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                node = TextNode(
                    id_=db_embedding_row.node_id,
                    text=db_embedding_row.text,
                    metadata=db_embedding_row.metadata,
                )
            similarities.append(db_embedding_row.similarity)
            ids.append(db_embedding_row.node_id)
            nodes.append(node)

        return VectorStoreQueryResult(
            nodes=nodes,
            similarities=similarities,
            ids=ids,
        )

    async def aquery(
        self, query: VectorStoreQuery, **kwargs: Any
    ) -> VectorStoreQueryResult:
        self._initialize()
        if query.mode == VectorStoreQueryMode.HYBRID:
            results = await self._async_hybrid_query(query)
        elif query.mode in [
            VectorStoreQueryMode.SPARSE,
            VectorStoreQueryMode.TEXT_SEARCH,
        ]:
            sparse_top_k = query.sparse_top_k or query.similarity_top_k
            results = await self._async_sparse_query_with_rank(
                query.query_str, sparse_top_k, query.filters
            )
        elif query.mode == VectorStoreQueryMode.DEFAULT:
            results = await self._aquery_with_score(
                query.query_embedding, query.similarity_top_k, query.filters
            )
        else:
            raise ValueError(f"Invalid query mode: {query.mode}")

        return self._db_rows_to_query_result(results)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
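        """Dispatch on ``query.mode``: DEFAULT runs a dense HNSW search,
        SPARSE/TEXT_SEARCH run a Postgres full-text query, and HYBRID runs
        both and merges the deduplicated results."""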
        self._initialize()
        if query.mode == VectorStoreQueryMode.HYBRID:
            results = self._hybrid_query(query)
        elif query.mode in [
            VectorStoreQueryMode.SPARSE,
            VectorStoreQueryMode.TEXT_SEARCH,
        ]:
            sparse_top_k = query.sparse_top_k or query.similarity_top_k
            results = self._sparse_query_with_rank(
                query.query_str, sparse_top_k, query.filters
            )
        elif query.mode == VectorStoreQueryMode.DEFAULT:
            results = self._query_with_score(
                query.query_embedding, query.similarity_top_k, query.filters
            )
        else:
            raise ValueError(f"Invalid query mode: {query.mode}")

        return self._db_rows_to_query_result(results)

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        import sqlalchemy

        self._initialize()
        with self._session() as session, session.begin():
            # Bind the doc id as a parameter rather than interpolating it into
            # the SQL string, to avoid quoting and injection issues.
            stmt = sqlalchemy.text(
                f"DELETE FROM {self.schema_name}.data_{self.table_name} "
                "WHERE (metadata_->>'doc_id')::text = :ref_doc_id"
            )

            session.execute(stmt, {"ref_doc_id": ref_doc_id})
            session.commit()


def _dedup_results(results: List[DBEmbeddingRow]) -> List[DBEmbeddingRow]:
    seen_ids = set()
    deduped_results = []
    for result in results:
        if result.node_id not in seen_ids:
            deduped_results.append(result)
            seen_ids.add(result.node_id)
    return deduped_results
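

# The snippet below is an illustrative, hypothetical usage sketch (not part of
# the library): it assumes a reachable Postgres instance with the Lantern
# extension installed, and uses placeholder credentials and a dummy embedding.
#
#   from llama_index.legacy.schema import TextNode
#   from llama_index.legacy.vector_stores.types import VectorStoreQuery
#
#   store = LanternVectorStore.from_params(
#       host="localhost", port="5432", database="vector_db",
#       user="postgres", password="password", table_name="demo",
#       embed_dim=4,  # tiny dimension, for illustration only
#   )
#   node = TextNode(text="hello lantern", embedding=[0.1, 0.2, 0.3, 0.4])
#   store.add([node])
#   result = store.query(
#       VectorStoreQuery(query_embedding=[0.1, 0.2, 0.3, 0.4], similarity_top_k=1)
#   )
#   print(result.ids)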