import logging
from typing import Any, List, NamedTuple, Optional, Type

from llama_index.legacy.bridge.pydantic import PrivateAttr
from llama_index.legacy.schema import BaseNode, MetadataMode, TextNode
from llama_index.legacy.vector_stores.types import (
    BasePydanticVectorStore,
    FilterOperator,
    MetadataFilters,
    VectorStoreQuery,
    VectorStoreQueryMode,
    VectorStoreQueryResult,
)
from llama_index.legacy.vector_stores.utils import (
    metadata_dict_to_node,
    node_to_metadata_dict,
)


class DBEmbeddingRow(NamedTuple):
    node_id: str  # FIXME: verify this type hint
    text: str
    metadata: dict
    similarity: float


_logger = logging.getLogger(__name__)


def get_data_model(
    base: Type,
    index_name: str,
    schema_name: str,
    hybrid_search: bool,
    text_search_config: str,
    cache_okay: bool,
    embed_dim: int = 1536,
    use_jsonb: bool = False,
) -> Any:
    """
41
    This part create a dynamic sqlalchemy model with a new table.
42
    """
43
    from pgvector.sqlalchemy import Vector
    from sqlalchemy import Column, Computed
    from sqlalchemy.dialects.postgresql import BIGINT, JSON, JSONB, TSVECTOR, VARCHAR
    from sqlalchemy.schema import Index
    from sqlalchemy.types import TypeDecorator

    class TSVector(TypeDecorator):
        impl = TSVECTOR
        cache_ok = cache_okay

    tablename = "data_%s" % index_name  # dynamic table name
    class_name = "Data%s" % index_name  # dynamic class name
    indexname = "%s_idx" % index_name  # dynamic index name

    metadata_dtype = JSONB if use_jsonb else JSON

    if hybrid_search:

        class HybridAbstractData(base):  # type: ignore
            __abstract__ = True  # this line is necessary
            id = Column(BIGINT, primary_key=True, autoincrement=True)
            text = Column(VARCHAR, nullable=False)
            metadata_ = Column(metadata_dtype)
            node_id = Column(VARCHAR)
            embedding = Column(Vector(embed_dim))  # type: ignore
            text_search_tsv = Column(  # type: ignore
                TSVector(),
                Computed(
                    "to_tsvector('%s', text)" % text_search_config, persisted=True
                ),
            )

        model = type(
            class_name,
            (HybridAbstractData,),
            {"__tablename__": tablename, "__table_args__": {"schema": schema_name}},
        )

        Index(
            indexname,
            model.text_search_tsv,  # type: ignore
            postgresql_using="gin",
        )
    else:

        class AbstractData(base):  # type: ignore
            __abstract__ = True  # this line is necessary
            id = Column(BIGINT, primary_key=True, autoincrement=True)
            text = Column(VARCHAR, nullable=False)
            metadata_ = Column(metadata_dtype)
            node_id = Column(VARCHAR)
            embedding = Column(Vector(embed_dim))  # type: ignore

        model = type(
            class_name,
            (AbstractData,),
            {"__tablename__": tablename, "__table_args__": {"schema": schema_name}},
        )

    return model
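
# Illustrative sketch, not part of the original module: with placeholder
# arguments, get_data_model() returns a mapped class whose table is
# "<schema>.data_<index_name>"; hybrid_search=True additionally adds a
# generated tsvector column with a GIN index over it.
#
#     from sqlalchemy.orm import declarative_base
#
#     DocsModel = get_data_model(
#         declarative_base(),
#         "docs",
#         "public",
#         hybrid_search=False,
#         text_search_config="english",
#         cache_okay=False,
#         embed_dim=1536,
#         use_jsonb=True,
#     )
#     assert DocsModel.__tablename__ == "data_docs"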


class PGVectorStore(BasePydanticVectorStore):
    from sqlalchemy.sql.selectable import Select

    stores_text = True
    flat_metadata = False

    connection_string: str
    async_connection_string: str
    table_name: str
    schema_name: str
    embed_dim: int
    hybrid_search: bool
    text_search_config: str
    cache_ok: bool
    perform_setup: bool
    debug: bool
    use_jsonb: bool

    _base: Any = PrivateAttr()
    _table_class: Any = PrivateAttr()
    _engine: Any = PrivateAttr()
    _session: Any = PrivateAttr()
    _async_engine: Any = PrivateAttr()
    _async_session: Any = PrivateAttr()
    _is_initialized: bool = PrivateAttr(default=False)

    def __init__(
        self,
        connection_string: str,
        async_connection_string: str,
        table_name: str,
        schema_name: str,
        hybrid_search: bool = False,
        text_search_config: str = "english",
        embed_dim: int = 1536,
        cache_ok: bool = False,
        perform_setup: bool = True,
        debug: bool = False,
        use_jsonb: bool = False,
    ) -> None:
        try:
            import asyncpg  # noqa
            import pgvector  # noqa
            import psycopg2  # noqa
            import sqlalchemy
            import sqlalchemy.ext.asyncio  # noqa
        except ImportError:
            raise ImportError(
                "`sqlalchemy[asyncio]`, `pgvector`, `psycopg2-binary` and `asyncpg` "
                "packages must be installed"
            )

        table_name = table_name.lower()
        schema_name = schema_name.lower()

        if hybrid_search and text_search_config is None:
            raise ValueError(
                "Sparse vector index creation requires "
                "a text search configuration specification."
            )

        from sqlalchemy.orm import declarative_base

        # sqlalchemy model
        self._base = declarative_base()
        self._table_class = get_data_model(
            self._base,
            table_name,
            schema_name,
            hybrid_search,
            text_search_config,
            cache_ok,
            embed_dim=embed_dim,
            use_jsonb=use_jsonb,
        )

        super().__init__(
            connection_string=connection_string,
            async_connection_string=async_connection_string,
            table_name=table_name,
            schema_name=schema_name,
            hybrid_search=hybrid_search,
            text_search_config=text_search_config,
            embed_dim=embed_dim,
            cache_ok=cache_ok,
            perform_setup=perform_setup,
            debug=debug,
            use_jsonb=use_jsonb,
        )

    async def close(self) -> None:
        if not self._is_initialized:
            return

        # sessionmaker does not expose close_all(); close_all_sessions() is
        # the supported SQLAlchemy API for closing every tracked session.
        from sqlalchemy.orm import close_all_sessions

        close_all_sessions()
        self._engine.dispose()

        await self._async_engine.dispose()

    @classmethod
    def class_name(cls) -> str:
        return "PGVectorStore"

    @classmethod
    def from_params(
        cls,
        host: Optional[str] = None,
        port: Optional[str] = None,
        database: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        table_name: str = "llamaindex",
        schema_name: str = "public",
        connection_string: Optional[str] = None,
        async_connection_string: Optional[str] = None,
        hybrid_search: bool = False,
        text_search_config: str = "english",
        embed_dim: int = 1536,
        cache_ok: bool = False,
        perform_setup: bool = True,
        debug: bool = False,
        use_jsonb: bool = False,
    ) -> "PGVectorStore":
        """Construct a PGVectorStore from database connection parameters."""
        conn_str = (
            connection_string
            or f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
        )
        async_conn_str = async_connection_string or (
            f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
        )
        return cls(
            connection_string=conn_str,
            async_connection_string=async_conn_str,
            table_name=table_name,
            schema_name=schema_name,
            hybrid_search=hybrid_search,
            text_search_config=text_search_config,
            embed_dim=embed_dim,
            cache_ok=cache_ok,
            perform_setup=perform_setup,
            debug=debug,
            use_jsonb=use_jsonb,
        )
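
    # Illustrative usage sketch (all connection values below are
    # placeholders): from_params() only assembles the sync and async
    # SQLAlchemy URLs and forwards the remaining options to __init__.
    #
    #     store = PGVectorStore.from_params(
    #         host="localhost",
    #         port="5432",
    #         database="vector_db",
    #         user="postgres",
    #         password="password",
    #         table_name="paul_graham_essay",
    #         embed_dim=1536,
    #     )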

    @property
    def client(self) -> Any:
        if not self._is_initialized:
            return None
        return self._engine

    def _connect(self) -> Any:
        from sqlalchemy import create_engine
        from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
        from sqlalchemy.orm import sessionmaker

        self._engine = create_engine(self.connection_string, echo=self.debug)
        self._session = sessionmaker(self._engine)

        self._async_engine = create_async_engine(self.async_connection_string)
        self._async_session = sessionmaker(self._async_engine, class_=AsyncSession)  # type: ignore

    def _create_schema_if_not_exists(self) -> None:
        with self._session() as session, session.begin():
            from sqlalchemy import text

            # Check whether the specified schema already exists
            check_schema_statement = text(
                f"SELECT schema_name FROM information_schema.schemata WHERE schema_name = '{self.schema_name}'"
            )
            result = session.execute(check_schema_statement).fetchone()

            # If the schema does not exist, create it
            if not result:
                create_schema_statement = text(
                    f"CREATE SCHEMA IF NOT EXISTS {self.schema_name}"
                )
                session.execute(create_schema_statement)

            session.commit()

    def _create_tables_if_not_exists(self) -> None:
        with self._session() as session, session.begin():
            self._base.metadata.create_all(session.connection())

    def _create_extension(self) -> None:
        import sqlalchemy

        with self._session() as session, session.begin():
            statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector")
            session.execute(statement)
            session.commit()

    def _initialize(self) -> None:
        if not self._is_initialized:
            self._connect()
            if self.perform_setup:
                self._create_extension()
                self._create_schema_if_not_exists()
                self._create_tables_if_not_exists()
            self._is_initialized = True

    def _node_to_table_row(self, node: BaseNode) -> Any:
        return self._table_class(
            node_id=node.node_id,
            embedding=node.get_embedding(),
            text=node.get_content(metadata_mode=MetadataMode.NONE),
            metadata_=node_to_metadata_dict(
                node,
                remove_text=True,
                flat_metadata=self.flat_metadata,
            ),
        )

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        self._initialize()
        ids = []
        with self._session() as session, session.begin():
            for node in nodes:
                ids.append(node.node_id)
                item = self._node_to_table_row(node)
                session.add(item)
            session.commit()
        return ids

    async def async_add(self, nodes: List[BaseNode], **kwargs: Any) -> List[str]:
        self._initialize()
        ids = []
        async with self._async_session() as session, session.begin():
            for node in nodes:
                ids.append(node.node_id)
                item = self._node_to_table_row(node)
                session.add(item)
            await session.commit()
        return ids

    def _to_postgres_operator(self, operator: FilterOperator) -> str:
        if operator == FilterOperator.EQ:
            return "="
        elif operator == FilterOperator.GT:
            return ">"
        elif operator == FilterOperator.LT:
            return "<"
        elif operator == FilterOperator.NE:
            return "!="
        elif operator == FilterOperator.GTE:
            return ">="
        elif operator == FilterOperator.LTE:
            return "<="
        elif operator == FilterOperator.IN:
            return "@>"
        else:
            _logger.warning(f"Unknown operator: {operator}, fallback to '='")
            return "="

    def _apply_filters_and_limit(
        self,
        stmt: Select,
        limit: int,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> Any:
        import sqlalchemy

        sqlalchemy_conditions = {
            "or": sqlalchemy.sql.or_,
            "and": sqlalchemy.sql.and_,
        }

        if metadata_filters:
            if metadata_filters.condition not in sqlalchemy_conditions:
                raise ValueError(
                    f"Invalid condition: {metadata_filters.condition}. "
                    f"Must be one of {list(sqlalchemy_conditions.keys())}"
                )
            stmt = stmt.where(  # type: ignore
                sqlalchemy_conditions[metadata_filters.condition](
                    *(
                        (
                            sqlalchemy.text(
                                f"metadata_::jsonb->'{filter_.key}' "
                                f"{self._to_postgres_operator(filter_.operator)} "
                                f"'[\"{filter_.value}\"]'"
                            )
                            if filter_.operator == FilterOperator.IN
                            else sqlalchemy.text(
                                f"metadata_->>'{filter_.key}' "
                                f"{self._to_postgres_operator(filter_.operator)} "
                                f"'{filter_.value}'"
                            )
                        )
                        for filter_ in metadata_filters.filters
                    )
                )
            )
        return stmt.limit(limit)  # type: ignore
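
    # For orientation (illustrative, assuming a metadata key "author"): an EQ
    # filter renders as  metadata_->>'author' = 'alice'  while an IN filter
    # uses JSONB containment,  metadata_::jsonb->'author' @> '["alice"]'.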

    def _build_query(
        self,
        embedding: Optional[List[float]],
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> Any:
        from sqlalchemy import select, text

        stmt = select(  # type: ignore
            self._table_class.id,
            self._table_class.node_id,
            self._table_class.text,
            self._table_class.metadata_,
            self._table_class.embedding.cosine_distance(embedding).label("distance"),
        ).order_by(text("distance asc"))

        return self._apply_filters_and_limit(stmt, limit, metadata_filters)

    def _query_with_score(
        self,
        embedding: Optional[List[float]],
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
        **kwargs: Any,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_query(embedding, limit, metadata_filters)
        with self._session() as session, session.begin():
            from sqlalchemy import text

            if kwargs.get("ivfflat_probes"):
                session.execute(
                    text(f"SET ivfflat.probes = {kwargs.get('ivfflat_probes')}")
                )
            if kwargs.get("hnsw_ef_search"):
                session.execute(
                    text(f"SET hnsw.ef_search = {kwargs.get('hnsw_ef_search')}")
                )

            res = session.execute(
                stmt,
            )
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    # pgvector returns cosine distance; 1 - distance recovers
                    # cosine similarity
                    similarity=(1 - item.distance) if item.distance is not None else 0,
                )
                for item in res.all()
            ]

    async def _aquery_with_score(
        self,
        embedding: Optional[List[float]],
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
        **kwargs: Any,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_query(embedding, limit, metadata_filters)
        async with self._async_session() as async_session, async_session.begin():
            from sqlalchemy import text

            if kwargs.get("hnsw_ef_search"):
                await async_session.execute(
                    text(f"SET hnsw.ef_search = {kwargs.get('hnsw_ef_search')}")
                )
            if kwargs.get("ivfflat_probes"):
                await async_session.execute(
                    text(f"SET ivfflat.probes = {kwargs.get('ivfflat_probes')}")
                )

            res = await async_session.execute(stmt)
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    similarity=(1 - item.distance) if item.distance is not None else 0,
                )
                for item in res.all()
            ]

    def _build_sparse_query(
        self,
        query_str: Optional[str],
        limit: int,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> Any:
        from sqlalchemy import select, type_coerce
        from sqlalchemy.sql import func, text
        from sqlalchemy.types import UserDefinedType

        class REGCONFIG(UserDefinedType):
            def get_col_spec(self, **kw: Any) -> str:
                return "regconfig"

        if query_str is None:
            raise ValueError("query_str must be specified for a sparse vector query.")

        ts_query = func.plainto_tsquery(
            type_coerce(self.text_search_config, REGCONFIG), query_str
        )
        stmt = (
            select(  # type: ignore
                self._table_class.id,
                self._table_class.node_id,
                self._table_class.text,
                self._table_class.metadata_,
                func.ts_rank(self._table_class.text_search_tsv, ts_query).label("rank"),
            )
            .where(self._table_class.text_search_tsv.op("@@")(ts_query))
            .order_by(text("rank desc"))
        )

        return self._apply_filters_and_limit(stmt, limit, metadata_filters)
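
    # Rough shape of the generated SQL (illustrative, with 'english' standing
    # in for self.text_search_config):
    #
    #     SELECT id, node_id, text, metadata_,
    #            ts_rank(text_search_tsv, plainto_tsquery('english', :q)) AS rank
    #     FROM data_<table> WHERE text_search_tsv @@ plainto_tsquery('english', :q)
    #     ORDER BY rank DESC LIMIT :limit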

    async def _async_sparse_query_with_rank(
        self,
        query_str: Optional[str] = None,
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_sparse_query(query_str, limit, metadata_filters)
        async with self._async_session() as async_session, async_session.begin():
            res = await async_session.execute(stmt)
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    similarity=item.rank,
                )
                for item in res.all()
            ]

    def _sparse_query_with_rank(
        self,
        query_str: Optional[str] = None,
        limit: int = 10,
        metadata_filters: Optional[MetadataFilters] = None,
    ) -> List[DBEmbeddingRow]:
        stmt = self._build_sparse_query(query_str, limit, metadata_filters)
        with self._session() as session, session.begin():
            res = session.execute(stmt)
            return [
                DBEmbeddingRow(
                    node_id=item.node_id,
                    text=item.text,
                    metadata=item.metadata_,
                    similarity=item.rank,
                )
                for item in res.all()
            ]

    async def _async_hybrid_query(
        self, query: VectorStoreQuery, **kwargs: Any
    ) -> List[DBEmbeddingRow]:
        import asyncio

        if query.alpha is not None:
            _logger.warning("postgres hybrid search does not support alpha parameter.")

        sparse_top_k = query.sparse_top_k or query.similarity_top_k

        # run the dense (embedding) and sparse (full-text) queries concurrently
        results = await asyncio.gather(
            self._aquery_with_score(
                query.query_embedding,
                query.similarity_top_k,
                query.filters,
                **kwargs,
            ),
            self._async_sparse_query_with_rank(
                query.query_str, sparse_top_k, query.filters
            ),
        )

        dense_results, sparse_results = results
        all_results = dense_results + sparse_results
        return _dedup_results(all_results)

    def _hybrid_query(
        self, query: VectorStoreQuery, **kwargs: Any
    ) -> List[DBEmbeddingRow]:
        if query.alpha is not None:
            _logger.warning("postgres hybrid search does not support alpha parameter.")

        sparse_top_k = query.sparse_top_k or query.similarity_top_k

        dense_results = self._query_with_score(
            query.query_embedding,
            query.similarity_top_k,
            query.filters,
            **kwargs,
        )

        sparse_results = self._sparse_query_with_rank(
            query.query_str, sparse_top_k, query.filters
        )

        all_results = dense_results + sparse_results
        return _dedup_results(all_results)
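
    # Illustrative: a hybrid query needs both an embedding and a query string,
    # e.g. VectorStoreQuery(query_embedding=emb, query_str="ships",
    # mode=VectorStoreQueryMode.HYBRID, similarity_top_k=4, sparse_top_k=10).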

    def _db_rows_to_query_result(
        self, rows: List[DBEmbeddingRow]
    ) -> VectorStoreQueryResult:
        nodes = []
        similarities = []
        ids = []
        for db_embedding_row in rows:
            try:
                node = metadata_dict_to_node(db_embedding_row.metadata)
                node.set_content(str(db_embedding_row.text))
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                node = TextNode(
                    id_=db_embedding_row.node_id,
                    text=db_embedding_row.text,
                    metadata=db_embedding_row.metadata,
                )
            similarities.append(db_embedding_row.similarity)
            ids.append(db_embedding_row.node_id)
            nodes.append(node)

        return VectorStoreQueryResult(
            nodes=nodes,
            similarities=similarities,
            ids=ids,
        )

    async def aquery(
        self, query: VectorStoreQuery, **kwargs: Any
    ) -> VectorStoreQueryResult:
        self._initialize()
        if query.mode == VectorStoreQueryMode.HYBRID:
            results = await self._async_hybrid_query(query, **kwargs)
        elif query.mode in [
            VectorStoreQueryMode.SPARSE,
            VectorStoreQueryMode.TEXT_SEARCH,
        ]:
            sparse_top_k = query.sparse_top_k or query.similarity_top_k
            results = await self._async_sparse_query_with_rank(
                query.query_str, sparse_top_k, query.filters
            )
        elif query.mode == VectorStoreQueryMode.DEFAULT:
            results = await self._aquery_with_score(
                query.query_embedding,
                query.similarity_top_k,
                query.filters,
                **kwargs,
            )
        else:
            raise ValueError(f"Invalid query mode: {query.mode}")

        return self._db_rows_to_query_result(results)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        self._initialize()
        if query.mode == VectorStoreQueryMode.HYBRID:
            results = self._hybrid_query(query, **kwargs)
        elif query.mode in [
            VectorStoreQueryMode.SPARSE,
            VectorStoreQueryMode.TEXT_SEARCH,
        ]:
            sparse_top_k = query.sparse_top_k or query.similarity_top_k
            results = self._sparse_query_with_rank(
                query.query_str, sparse_top_k, query.filters
            )
        elif query.mode == VectorStoreQueryMode.DEFAULT:
            results = self._query_with_score(
                query.query_embedding,
                query.similarity_top_k,
                query.filters,
                **kwargs,
            )
        else:
            raise ValueError(f"Invalid query mode: {query.mode}")

        return self._db_rows_to_query_result(results)

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        import sqlalchemy

        self._initialize()
        with self._session() as session, session.begin():
            stmt = sqlalchemy.text(
                f"DELETE FROM {self.schema_name}.data_{self.table_name} WHERE "
                f"(metadata_->>'doc_id')::text = '{ref_doc_id}' "
            )

            session.execute(stmt)
            session.commit()


def _dedup_results(results: List[DBEmbeddingRow]) -> List[DBEmbeddingRow]:
    seen_ids = set()
    deduped_results = []
    for result in results:
        if result.node_id not in seen_ids:
            deduped_results.append(result)
            seen_ids.add(result.node_id)
    return deduped_results
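

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library. All connection
    # values are placeholders; it assumes a reachable Postgres instance with
    # the pgvector extension available, and uses a toy 3-dimensional
    # embedding purely for illustration.
    store = PGVectorStore.from_params(
        host="localhost",
        port="5432",
        database="vector_db",
        user="postgres",
        password="password",
        table_name="demo",
        embed_dim=3,
    )
    store.add([TextNode(text="hello world", embedding=[0.1, 0.2, 0.3])])
    result = store.query(
        VectorStoreQuery(query_embedding=[0.1, 0.2, 0.3], similarity_top_k=1)
    )
    print(result.ids, result.similarities)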