Langchain-Chatchat

Форк
0
/
search_knowledgebase_once.py 
234 строки · 8.3 Кб
1
from __future__ import annotations
2
import re
3
import warnings
4
from typing import Dict
5

6
from langchain.callbacks.manager import (
7
    AsyncCallbackManagerForChainRun,
8
    CallbackManagerForChainRun,
9
)
10
from langchain.chains.llm import LLMChain
11
from langchain.pydantic_v1 import Extra, root_validator
12
from langchain.schema import BasePromptTemplate
13
from langchain.schema.language_model import BaseLanguageModel
14
from typing import List, Any, Optional
15
from langchain.prompts import PromptTemplate
16
import sys
17
import os
18
import json
19

20
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
21
from server.chat.knowledge_base_chat import knowledge_base_chat
22
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
23

24
import asyncio
25
from server.agent import model_container
26
from pydantic import BaseModel, Field
27

28
async def search_knowledge_base_iter(database: str, query: str) -> str:
    """Query one knowledge base and return the concatenated answer text.

    Args:
        database: Name of the knowledge base; forwarded as
            ``knowledge_base_name``.
        query: The question forwarded to ``knowledge_base_chat``.

    Returns:
        The ``"answer"`` field of every chunk yielded by the response body,
        joined in arrival order.

    Raises:
        KeyError: If a chunk has no ``"answer"`` field.
        ValueError: If a chunk is not valid JSON (``json.JSONDecodeError``).
    """
    response = await knowledge_base_chat(
        query=query,
        knowledge_base_name=database,
        model_name=model_container.MODEL.model_name,
        temperature=0.01,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="knowledge_base_chat",
        score_threshold=SCORE_THRESHOLD,
        stream=False,
    )

    # Each item from body_iterator is a JSON string. Only its "answer" field
    # is used; the accompanying "docs" field is deliberately not read, so a
    # chunk without it no longer raises KeyError.
    answers = [
        json.loads(chunk)["answer"]
        async for chunk in response.body_iterator
    ]
    return "".join(answers)
46

47

48
# Prompt that asks the LLM to choose which knowledge base to query.
# `{database_names}` and `{question}` are the two template variables; the
# doubled braces (`${{...}}`) are escaped so that they are emitted literally
# rather than treated as variables. The ```text fence is the marker that
# LLMKnowledgeChain later searches for to extract the chosen knowledge base.
_PROMPT_TEMPLATE = """
用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考
Question: ${{用户的问题}}
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能:

{database_names}

你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
```text
${{知识库的名称}}
```
```output
数据库查询的结果
```
答案: ${{答案}}

现在,这是我的问题:
问题: {question}

"""

# Default prompt object used by LLMKnowledgeChain and search_knowledgebase_once.
PROMPT = PromptTemplate(
    template=_PROMPT_TEMPLATE,
    input_variables=["question", "database_names"],
)
72

73

74
class LLMKnowledgeChain(LLMChain):
    """Chain that lets the LLM pick a knowledge base and then queries it.

    The wrapped ``llm_chain`` is prompted with the available knowledge bases
    and the user's question. If its output starts with a ```text fenced
    block, the text inside is taken as the knowledge-base name and the
    original question is run against that knowledge base through
    ``search_knowledge_base_iter``. The synchronous (``_call``) and
    asynchronous (``_acall``) paths produce the same results.
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated] Prompt to use to translate to python if necessary."""
    database_names: Dict[str, str] = model_container.DATABASE
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn when built from a bare ``llm`` and wrap it in an ``LLMChain``.

        If ``llm`` is supplied without ``llm_chain``, an ``LLMChain`` is
        created from it (using ``prompt`` if given, otherwise ``PROMPT``).
        """
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _evaluate_expression(self, dataset, query) -> str:
        """Run ``query`` against knowledge base ``dataset`` synchronously.

        Any exception raised by the lookup is turned into a fixed error
        message so that the chain always returns a string.
        """
        try:
            return asyncio.run(search_knowledge_base_iter(dataset, query))
        except Exception:
            # Deliberate best-effort: report the failure as the answer text.
            return "输入的信息有误或不存在知识库"

    async def _aevaluate_expression(self, dataset, query) -> str:
        """Async counterpart of ``_evaluate_expression``.

        Awaits the lookup coroutine directly instead of going through
        ``asyncio.run``, which cannot be used from a running event loop.
        """
        try:
            return await search_knowledge_base_iter(dataset, query)
        except Exception:
            return "输入的信息有误或不存在知识库"

    def _answer_without_lookup(self, llm_output: str) -> Dict[str, str]:
        """Build the result when ``llm_output`` has no leading ```text block.

        Returns the text from the last ``Answer:`` marker onwards, or a
        format-error message when no such marker is present.
        """
        # NOTE(review): the default PROMPT asks the model to write "答案:",
        # while only "Answer:" is recognised here — confirm whether both
        # markers should be accepted.
        if llm_output.startswith("Answer:"):
            return {self.output_key: llm_output}
        if "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
            return {self.output_key: answer}
        return {self.output_key: f"输入的格式不对: {llm_output}"}

    def _format_database_names(self) -> str:
        """Render ``database_names`` as one ``"name":"description"`` per line."""
        return ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])

    def _process_llm_result(
            self,
            llm_output: str,
            llm_input: str,
            run_manager: CallbackManagerForChainRun
    ) -> Dict[str, str]:
        """Turn the raw LLM output into the chain result (sync path).

        Args:
            llm_output: Text produced by ``llm_chain``.
            llm_input: The user's original question; used as the search query.
            run_manager: Callback manager that receives progress text.

        Returns:
            A one-entry dict keyed by ``output_key``.
        """
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)

        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if not text_match:
            return self._answer_without_lookup(llm_output)

        database = text_match.group(1).strip()
        output = self._evaluate_expression(database, llm_input)
        run_manager.on_text("\nAnswer: ", verbose=self.verbose)
        run_manager.on_text(output, color="yellow", verbose=self.verbose)
        return {self.output_key: "Answer: " + output}

    async def _aprocess_llm_result(
            self,
            llm_output: str,
            llm_input: str,
            run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Turn the raw LLM output into the chain result (async path).

        Takes the same arguments as ``_process_llm_result`` — which is how
        ``_acall`` invokes it — and returns the same values.
        """
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)

        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if not text_match:
            return self._answer_without_lookup(llm_output)

        database = text_match.group(1).strip()
        output = await self._aevaluate_expression(database, llm_input)
        await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
        await run_manager.on_text(output, color="yellow", verbose=self.verbose)
        return {self.output_key: "Answer: " + output}

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Ask the LLM which knowledge base to use, then query it."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        _run_manager.on_text(question)
        llm_output = self.llm_chain.predict(
            database_names=self._format_database_names(),
            question=question,
            # Stop before the model starts writing the ```output section
            # itself; the lookup result is produced by this chain instead.
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, question, _run_manager)

    async def _acall(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async version of ``_call``."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        await _run_manager.on_text(question)
        llm_output = await self.llm_chain.apredict(
            database_names=self._format_database_names(),
            question=question,
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return await self._aprocess_llm_result(llm_output, question, _run_manager)

    @property
    def _chain_type(self) -> str:
        return "llm_knowledge_chain"

    @classmethod
    def from_llm(
            cls,
            llm: BaseLanguageModel,
            prompt: BasePromptTemplate = PROMPT,
            **kwargs: Any,
    ) -> LLMKnowledgeChain:
        """Create the chain from a language model and an optional prompt.

        Extra keyword arguments are passed to the class constructor.
        """
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
219

220

221
def search_knowledgebase_once(query: str):
    """Answer ``query`` with one run of an ``LLMKnowledgeChain``.

    The chain is built from ``model_container.MODEL`` with the module-level
    ``PROMPT`` and verbose output enabled; the value returned by its
    ``run`` method is passed back unchanged.
    """
    chain = LLMKnowledgeChain.from_llm(
        model_container.MODEL,
        verbose=True,
        prompt=PROMPT,
    )
    return chain.run(query)
226

227

228
class KnowledgeSearchInput(BaseModel):
    # Input model with a single required string field. The field is named
    # `location`, but its description says it carries the search query.
    location: str = Field(..., description="The query to be searched")
230

231

232
if __name__ == "__main__":
    # Manual smoke test: run a single lookup and print what it returns.
    print(search_knowledgebase_once("大数据的男女比例"))
235

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.