llama-index

Форк
0
176 строк · 5.5 Кб
1
"""ChatGPT Plugin vector store."""
2

3
import os
4
from typing import Any, Dict, List, Optional
5

6
import requests
7
from requests.adapters import HTTPAdapter, Retry
8

9
from llama_index.legacy.schema import (
10
    BaseNode,
11
    MetadataMode,
12
    NodeRelationship,
13
    RelatedNodeInfo,
14
    TextNode,
15
)
16
from llama_index.legacy.utils import get_tqdm_iterable
17
from llama_index.legacy.vector_stores.types import (
18
    VectorStore,
19
    VectorStoreQuery,
20
    VectorStoreQueryResult,
21
)
22

23

24
def convert_docs_to_json(nodes: List[BaseNode]) -> List[Dict]:
    """Serialize nodes into the document payloads sent to the plugin's upsert.

    Each node becomes a dict with ``id`` (the node id), ``text`` (the node
    content rendered with ``MetadataMode.NONE``) and ``source_id`` (the
    node's ``ref_doc_id``). Afterwards, any of ``source``, ``source_id``,
    ``url``, ``created_at`` and ``author`` found in ``node.metadata`` are
    copied over; a ``source_id`` in metadata replaces the ``ref_doc_id``.

    Args:
        nodes: Nodes to serialize.

    Returns:
        One payload dict per node, in the same order as ``nodes``.
    """
    # TODO: add information for other fields as well
    # fields taken from https://rb.gy/nmac9u
    optional_fields = ("source", "source_id", "url", "created_at", "author")

    def _node_payload(node: BaseNode) -> Dict:
        payload = {
            "id": node.node_id,
            "text": node.get_content(metadata_mode=MetadataMode.NONE),
            # NOTE: this is the doc_id to reference document
            "source_id": node.ref_doc_id,
            # "url": "...",
            # "created_at": ...,
            # "author": "..."",
        }
        meta = node.metadata
        if meta is not None:
            # Only keys actually present in metadata are copied, in the
            # fixed order of ``optional_fields``.
            payload.update(
                {field: meta[field] for field in optional_fields if field in meta}
            )
        return payload

    return [_node_payload(node) for node in nodes]
55

56

57
class ChatGPTRetrievalPluginClient(VectorStore):
    """ChatGPT Retrieval Plugin Client.

    In this client, we make use of the endpoints defined by ChatGPT
    (``/upsert``, ``/query`` and ``/delete``), called through a single
    ``requests.Session``.

    Args:
        endpoint_url (str): URL of the ChatGPT Retrieval Plugin. A trailing
            ``/`` is ignored.
        bearer_token (Optional[str]): Bearer token for the ChatGPT Retrieval Plugin.
            If falsy, the ``BEARER_TOKEN`` environment variable is used instead.
        retries (Optional[Retry]): Retry policy passed to the session's
            ``HTTPAdapter`` for both ``http://`` and ``https://`` URLs.
        batch_size (int): Number of documents sent per ``/upsert`` request.
            Must be at least 1.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """

    stores_text: bool = True
    is_embedding_query: bool = False

    def __init__(
        self,
        endpoint_url: str,
        bearer_token: Optional[str] = None,
        retries: Optional[Retry] = None,
        batch_size: int = 100,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        if batch_size < 1:
            # A zero step makes range() fail inside add(); a negative step
            # silently uploads nothing while add() still returns every id.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Strip a trailing slash so routes are built as ".../upsert",
        # not ".../" + "/upsert".
        self._endpoint_url = endpoint_url.rstrip("/")
        self._bearer_token = bearer_token or os.getenv("BEARER_TOKEN")
        self._retries = retries
        self._batch_size = batch_size

        self._s = requests.Session()
        # Mount the retrying adapter for both schemes. Mounting only "http://"
        # left HTTPS endpoints without the configured retry policy.
        for prefix in ("http://", "https://"):
            self._s.mount(prefix, HTTPAdapter(max_retries=self._retries))

    @property
    def client(self) -> None:
        """Get client.

        This store talks to the plugin over HTTP and exposes no underlying
        client object, so this always returns ``None``.
        """
        return None

    def _headers(self) -> Dict[str, str]:
        """Return request headers, with a bearer token only if one is configured."""
        # Avoid sending the literal credential "Bearer None" when no token is set.
        if not self._bearer_token:
            return {}
        return {"Authorization": f"Bearer {self._bearer_token}"}

    def _post(self, route: str, payload: Dict[str, Any]) -> requests.Response:
        """POST ``payload`` as JSON to ``{endpoint_url}/{route}`` via the session.

        Raises:
            requests.HTTPError: If the plugin responds with a 4xx/5xx status.
        """
        res = self._s.post(
            f"{self._endpoint_url}/{route}",
            headers=self._headers(),
            json=payload,
        )
        # Surface failures instead of silently ignoring them.
        res.raise_for_status()
        return res

    @staticmethod
    def _result_source_id(result: Dict[str, Any]) -> Optional[str]:
        """Extract the source document id from one query result, if any."""
        if "source_id" in result:
            return result["source_id"]
        # NOTE(review): the plugin's response models appear to nest
        # source_id under "metadata"; fall back to that shape. Confirm
        # against the deployed plugin version.
        metadata = result.get("metadata") or {}
        return metadata.get("source_id")

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Nodes are serialized with ``convert_docs_to_json`` and sent to
        ``/upsert`` in chunks of ``batch_size``.

        Args:
            nodes: Nodes to upload.

        Returns:
            The ``node_id`` of every node, in input order.

        Raises:
            requests.HTTPError: If any batch upload returns an error status.
        """
        docs_to_upload = convert_docs_to_json(nodes)
        iterable_docs = get_tqdm_iterable(
            range(0, len(docs_to_upload), self._batch_size),
            show_progress=True,
            desc="Uploading documents",
        )
        for i in iterable_docs:
            # Slicing past the end is safe; the last batch is just shorter.
            self._post(
                "upsert",
                {"documents": docs_to_upload[i : i + self._batch_size]},
            )

        return [node.node_id for node in nodes]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        Raises:
            requests.HTTPError: If the plugin responds with an error status.

        """
        # NOTE(review): add() uploads documents with "id" = node_id, while this
        # deletes by ids=[ref_doc_id]. Verify that the plugin matches these ids
        # to the uploaded documents (a filter on source_id may be intended).
        self._post("delete", {"ids": [ref_doc_id]})

    def query(
        self,
        query: VectorStoreQuery,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Get nodes for response.

        Sends ``query.query_str`` with ``query.similarity_top_k`` to ``/query``.

        Returns:
            Text nodes, scores and ids, in the order returned by the plugin.

        Raises:
            ValueError: If metadata filters are set or ``query_str`` is None.
            requests.HTTPError: If the plugin responds with an error status.
        """
        if query.filters is not None:
            raise ValueError("Metadata filters not implemented for ChatGPT Plugin yet.")

        if query.query_str is None:
            raise ValueError("query_str must be provided")
        # TODO: add metadata filter
        queries = [{"query": query.query_str, "top_k": query.similarity_top_k}]
        # Go through the session (not bare requests.post) so retries apply.
        res = self._post("query", {"queries": queries})

        nodes: List[TextNode] = []
        similarities: List[float] = []
        ids: List[str] = []
        query_results = res.json()["results"]
        # NOTE: exactly one query is sent, so only the first result set is read.
        first_results = query_results[0]["results"] if query_results else []
        for result in first_results:
            result_id = result["id"]
            source_id = self._result_source_id(result)
            relationships = {}
            if source_id is not None:
                relationships[NodeRelationship.SOURCE] = RelatedNodeInfo(
                    node_id=source_id
                )
            nodes.append(
                TextNode(
                    id_=result_id,
                    text=result["text"],
                    relationships=relationships,
                )
            )
            similarities.append(result["score"])
            ids.append(result_id)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
177

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.