llama-index

Форк
0
171 строка · 6.1 Кб
1
"""Pathway Retriever."""
2

3
import logging
4
from typing import Any, Callable, List, Optional, Tuple, Union
5

6
from llama_index.legacy.callbacks.base import CallbackManager
7
from llama_index.legacy.constants import DEFAULT_SIMILARITY_TOP_K
8
from llama_index.legacy.core.base_retriever import BaseRetriever
9
from llama_index.legacy.embeddings import BaseEmbedding
10
from llama_index.legacy.indices.query.schema import QueryBundle
11
from llama_index.legacy.ingestion.pipeline import run_transformations
12
from llama_index.legacy.schema import (
13
    BaseNode,
14
    NodeWithScore,
15
    QueryBundle,
16
    TextNode,
17
    TransformComponent,
18
)
19

20
# Logger scoped to this module's import name.
logger = logging.getLogger(name=__name__)
21

22

23
def node_transformer(x: str) -> List[BaseNode]:
    """Wrap a raw text string into a single-element list holding one ``TextNode``.

    Used as the first step of the ingestion pipeline assembled by
    ``PathwayVectorServer``, turning the document text into nodes that the
    following transformations can consume.

    Args:
        x: raw document text.

    Returns:
        A list containing exactly one ``TextNode`` whose text is ``x``.
    """
    node = TextNode(text=x)
    return [node]
25

26

27
def node_to_pathway(x: List[BaseNode]) -> List[Tuple[str, dict]]:
    """Convert transformed nodes into the ``(text, metadata)`` pairs Pathway expects.

    Used as the last step of the ingestion pipeline assembled by
    ``PathwayVectorServer``; it receives the whole list of nodes produced by the
    preceding transformations (hence the ``List[BaseNode]`` annotation -- the
    argument is iterated, it is not a single node).

    Args:
        x: nodes produced by the transformation pipeline. Each node must expose
            ``.text`` (i.e. be a ``TextNode`` or subclass).

    Returns:
        One ``(text, metadata)`` tuple per node, in input order.
    """
    # ``metadata`` is the canonical node field; ``extra_info`` is a legacy alias
    # returning the same dict.
    return [(node.text, node.metadata) for node in x]
29

30

31
class PathwayVectorServer:
    """
    Build an autoupdating document indexing pipeline
    for approximate nearest neighbor search.

    Args:
        docs (list): Pathway tables, may be pw.io connectors or custom tables.

        transformations (List[TransformComponent]): list of transformation steps, has to
            include embedding as last step, optionally splitter and other
            TransformComponent in the middle. The list passed in is not modified.

        parser (Callable[[bytes], list[tuple[str, dict]]]): optional, callable that
            parses file contents into a list of documents. If None, defaults to `utf-8` decoding of the file contents. Defaults to None.

        **kwargs: forwarded unchanged to pathway's ``VectorStoreServer``.

    Raises:
        ImportError: if the ``pathway`` package is not installed.
        ValueError: if ``transformations`` is empty/None or its last element is
            not a ``BaseEmbedding``.
    """

    def __init__(
        self,
        *docs: Any,
        transformations: List[Union[TransformComponent, Callable[[Any], Any]]],
        parser: Optional[Callable[[bytes], List[Tuple[str, dict]]]] = None,
        **kwargs: Any,
    ) -> None:
        try:
            from pathway.xpacks.llm import vector_store
        except ImportError as err:
            raise ImportError(
                "Could not import pathway python package. "
                "Please install it with `pip install pathway`."
            ) from err

        if not transformations:
            raise ValueError("Transformations list cannot be None or empty.")

        if not isinstance(transformations[-1], BaseEmbedding):
            raise ValueError(
                f"Last step of transformations should be an instance of {BaseEmbedding.__name__}, "
                f"found {type(transformations[-1])}."
            )

        # The embedder is handed to pathway separately; everything before it
        # forms the text -> chunks pipeline.
        embedder: BaseEmbedding = transformations[-1]  # type: ignore

        def embedding_callable(x: str) -> List[float]:
            return embedder.get_text_embedding(x)

        # Build a private pipeline instead of pop/insert/append on the caller's
        # list: mutating it would strip the embedder and leave node_to_pathway as
        # the last element, so reusing the same list (e.g. for a second server)
        # would then fail the BaseEmbedding check above.
        steps: List[Any] = [
            node_transformer,  # str -> [TextNode]
            *transformations[:-1],
            node_to_pathway,  # [TextNode] -> [(str, dict)]
        ]

        def generic_transformer(x: str) -> List[Tuple[str, dict]]:
            # ``x`` is the document text; the first step wraps it into nodes.
            return run_transformations(x, steps)  # type: ignore

        self.vector_store_server = vector_store.VectorStoreServer(
            *docs,
            embedder=embedding_callable,
            parser=parser,
            splitter=generic_transformer,
            **kwargs,
        )

    def run_server(
        self,
        host: str,
        port: Union[str, int],
        threaded: bool = False,
        with_cache: bool = True,
        cache_backend: Any = None,
    ) -> Any:
        """
        Run the server and start answering queries.

        Args:
            host (str): host to bind the HTTP listener
            port (str | int): port to bind the HTTP listener
            threaded (bool): if True, run in a thread. Else block computation
            with_cache (bool): if True, embedding requests for the same contents are cached
            cache_backend: the backend to use for caching if it is enabled. The
              default is the disk cache, hosted locally in the folder ``./Cache``. You
              can use ``Backend`` class of the [`persistence API`]
              (/developers/api-docs/persistence-api/#pathway.persistence.Backend)
              to override it.

        Returns:
            If threaded, return the Thread object. Else, does not return.

        Raises:
            ImportError: if the ``pathway`` package is not installed.
        """
        try:
            import pathway as pw
        except ImportError as err:
            raise ImportError(
                "Could not import pathway python package. "
                "Please install it with `pip install pathway`."
            ) from err
        # Only create the default on-disk backend when caching is enabled and
        # the caller did not supply one; otherwise pass the argument through.
        if with_cache and cache_backend is None:
            cache_backend = pw.persistence.Backend.filesystem("./Cache")
        return self.vector_store_server.run_server(
            host,
            port,
            threaded=threaded,
            with_cache=with_cache,
            cache_backend=cache_backend,
        )
131

132

133
class PathwayRetriever(BaseRetriever):
    """Pathway retriever.

    Pathway is an open data processing framework for building data
    transformation pipelines over live, changing data sources.

    This is the client that implements the Retriever API for
    ``PathwayVectorServer``: every query is sent to the server at the given
    host/port through pathway's ``VectorStoreClient``.

    Args:
        host: host of the vector store server.
        port: port of the vector store server.
        similarity_top_k: number of results requested per query.
        callback_manager: optional callback manager passed to ``BaseRetriever``.
    """

    def __init__(
        self,
        host: str,
        port: Union[str, int],
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Create the client; raises ImportError when ``pathway`` is missing."""
        try:
            from pathway.xpacks.llm.vector_store import VectorStoreClient
        except ImportError:
            raise ImportError(
                "`pathway` package not found, please run `pip install pathway`"
            )
        self.client = VectorStoreClient(host, port)
        self.similarity_top_k = similarity_top_k
        super().__init__(callback_manager)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Query the server and return the hits, most similar first."""
        results = self.client(query=query_bundle.query_str, k=self.similarity_top_k)
        scored: List[NodeWithScore] = []
        for hit in results:
            node = TextNode(text=hit["text"], extra_info=hit["metadata"])
            # Convert the cosine distance into a similarity score
            # (higher is more similar).
            scored.append(NodeWithScore(node=node, score=1 - hit["dist"]))
        scored.sort(key=lambda item: item.score or 0.0, reverse=True)
        return scored
172

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.