Langchain-Chatchat

Форк
0
/
search_knowledgebase_complex.py 
287 строк · 11.2 Кб
1
from __future__ import annotations
2
import json
3
import re
4
import warnings
5
from typing import Dict
6
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
7
from langchain.chains.llm import LLMChain
8
from langchain.pydantic_v1 import Extra, root_validator
9
from langchain.schema import BasePromptTemplate
10
from langchain.schema.language_model import BaseLanguageModel
11
from typing import List, Any, Optional
12
from langchain.prompts import PromptTemplate
13
from server.chat.knowledge_base_chat import knowledge_base_chat
14
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
15
import asyncio
16
from server.agent import model_container
17
from pydantic import BaseModel, Field
18

19
async def search_knowledge_base_iter(database: str, query: str) -> str:
20
    response = await knowledge_base_chat(query=query,
21
                                         knowledge_base_name=database,
22
                                         model_name=model_container.MODEL.model_name,
23
                                         temperature=0.01,
24
                                         history=[],
25
                                         top_k=VECTOR_SEARCH_TOP_K,
26
                                         max_tokens=MAX_TOKENS,
27
                                         prompt_name="default",
28
                                         score_threshold=SCORE_THRESHOLD,
29
                                         stream=False)
30

31
    contents = ""
32
    async for data in response.body_iterator:  # 这里的data是一个json字符串
33
        data = json.loads(data)
34
        contents += data["answer"]
35
        docs = data["docs"]
36
    return contents
37

38

39
async def search_knowledge_multiple(queries) -> List[str]:
40
    # queries 应该是一个包含多个 (database, query) 元组的列表
41
    tasks = [search_knowledge_base_iter(database, query) for database, query in queries]
42
    results = await asyncio.gather(*tasks)
43
    # 结合每个查询结果,并在每个查询结果前添加一个自定义的消息
44
    combined_results = []
45
    for (database, _), result in zip(queries, results):
46
        message = f"\n查询到 {database} 知识库的相关信息:\n{result}"
47
        combined_results.append(message)
48

49
    return combined_results
50

51

52
def search_knowledge(queries) -> str:
53
    responses = asyncio.run(search_knowledge_multiple(queries))
54
    # 输出每个整合的查询结果
55
    contents = ""
56
    for response in responses:
57
        contents += response + "\n\n"
58
    return contents
59

60

61
_PROMPT_TEMPLATE = """
62
用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。
63

64
对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。
65

66
例子:
67

68
robotic,机器人男女比例是多少
69
bigdata,大数据的就业情况如何 
70

71

72
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考
73

74

75
{database_names}
76

77
你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。
78
不要输出中文的逗号,不要输出引号。
79

80
Question: ${{用户的问题}}
81

82
```text
83
${{知识库名称,查询问题,不要带有任何除了,之外的符号,比如不要输出中文的逗号,不要输出引号}}
84

85
```output
86
数据库查询的结果
87

88
现在,我们开始作答
89
问题: {question}
90
"""
91

92
PROMPT = PromptTemplate(
93
    input_variables=["question", "database_names"],
94
    template=_PROMPT_TEMPLATE,
95
)
96

97

98
class LLMKnowledgeChain(LLMChain):
99
    llm_chain: LLMChain
100
    llm: Optional[BaseLanguageModel] = None
101
    """[Deprecated] LLM wrapper to use."""
102
    prompt: BasePromptTemplate = PROMPT
103
    """[Deprecated] Prompt to use to translate to python if necessary."""
104
    database_names: Dict[str, str] = None
105
    input_key: str = "question"  #: :meta private:
106
    output_key: str = "answer"  #: :meta private:
107

108
    class Config:
109
        """Configuration for this pydantic object."""
110

111
        extra = Extra.forbid
112
        arbitrary_types_allowed = True
113

114
    @root_validator(pre=True)
115
    def raise_deprecation(cls, values: Dict) -> Dict:
116
        if "llm" in values:
117
            warnings.warn(
118
                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
119
                "Please instantiate with llm_chain argument or using the from_llm "
120
                "class method."
121
            )
122
            if "llm_chain" not in values and values["llm"] is not None:
123
                prompt = values.get("prompt", PROMPT)
124
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
125
        return values
126

127
    @property
128
    def input_keys(self) -> List[str]:
129
        """Expect input key.
130

131
        :meta private:
132
        """
133
        return [self.input_key]
134

135
    @property
136
    def output_keys(self) -> List[str]:
137
        """Expect output key.
138

139
        :meta private:
140
        """
141
        return [self.output_key]
142

143
    def _evaluate_expression(self, queries) -> str:
144
        try:
145
            output = search_knowledge(queries)
146
        except Exception as e:
147
            output = "输入的信息有误或不存在知识库,错误信息如下:\n"
148
            return output + str(e)
149
        return output
150

151
    def _process_llm_result(
152
            self,
153
            llm_output: str,
154
            run_manager: CallbackManagerForChainRun
155
    ) -> Dict[str, str]:
156

157
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
158

159
        llm_output = llm_output.strip()
160
        # text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
161
        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
162
        if text_match:
163
            expression = text_match.group(1).strip()
164
            cleaned_input_str = (expression.replace("\"", "").replace("“", "").
165
                                 replace("”", "").replace("```", "").strip())
166
            lines = cleaned_input_str.split("\n")
167
            # 使用逗号分割每一行,然后形成一个(数据库,查询)元组的列表
168

169
            try:
170
                queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
171
            except:
172
                queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
173
            run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue", verbose=self.verbose)
174
            output = self._evaluate_expression(queries)
175
            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
176
            run_manager.on_text(output, color="yellow", verbose=self.verbose)
177
            answer = "Answer: " + output
178
        elif llm_output.startswith("Answer:"):
179
            answer = llm_output
180
        elif "Answer:" in llm_output:
181
            answer = llm_output.split("Answer:")[-1]
182
        else:
183
            return {self.output_key: f"输入的格式不对:\n {llm_output}"}
184
        return {self.output_key: answer}
185

186
    async def _aprocess_llm_result(
187
            self,
188
            llm_output: str,
189
            run_manager: AsyncCallbackManagerForChainRun,
190
    ) -> Dict[str, str]:
191
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
192
        llm_output = llm_output.strip()
193
        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
194
        if text_match:
195

196
            expression = text_match.group(1).strip()
197
            cleaned_input_str = (
198
                expression.replace("\"", "").replace("“", "").replace("”", "").replace("```", "").strip())
199
            lines = cleaned_input_str.split("\n")
200
            try:
201
                queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
202
            except:
203
                queries = [(line.split(",")[0].strip(), line.split(",")[1].strip()) for line in lines]
204
            await run_manager.on_text("知识库查询询内容:\n\n" + str(queries) + " \n\n", color="blue",
205
                                      verbose=self.verbose)
206

207
            output = self._evaluate_expression(queries)
208
            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
209
            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
210
            answer = "Answer: " + output
211
        elif llm_output.startswith("Answer:"):
212
            answer = llm_output
213
        elif "Answer:" in llm_output:
214
            answer = "Answer: " + llm_output.split("Answer:")[-1]
215
        else:
216
            raise ValueError(f"unknown format from LLM: {llm_output}")
217
        return {self.output_key: answer}
218

219
    def _call(
220
            self,
221
            inputs: Dict[str, str],
222
            run_manager: Optional[CallbackManagerForChainRun] = None,
223
    ) -> Dict[str, str]:
224
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
225
        _run_manager.on_text(inputs[self.input_key])
226
        self.database_names = model_container.DATABASE
227
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
228
        llm_output = self.llm_chain.predict(
229
            database_names=data_formatted_str,
230
            question=inputs[self.input_key],
231
            stop=["```output"],
232
            callbacks=_run_manager.get_child(),
233
        )
234
        return self._process_llm_result(llm_output, _run_manager)
235

236
    async def _acall(
237
            self,
238
            inputs: Dict[str, str],
239
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
240
    ) -> Dict[str, str]:
241
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
242
        await _run_manager.on_text(inputs[self.input_key])
243
        self.database_names = model_container.DATABASE
244
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
245
        llm_output = await self.llm_chain.apredict(
246
            database_names=data_formatted_str,
247
            question=inputs[self.input_key],
248
            stop=["```output"],
249
            callbacks=_run_manager.get_child(),
250
        )
251
        return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)
252

253
    @property
254
    def _chain_type(self) -> str:
255
        return "llm_knowledge_chain"
256

257
    @classmethod
258
    def from_llm(
259
            cls,
260
            llm: BaseLanguageModel,
261
            prompt: BasePromptTemplate = PROMPT,
262
            **kwargs: Any,
263
    ) -> LLMKnowledgeChain:
264
        llm_chain = LLMChain(llm=llm, prompt=prompt)
265
        return cls(llm_chain=llm_chain, **kwargs)
266

267

268
def search_knowledgebase_complex(query: str):
269
    model = model_container.MODEL
270
    llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)
271
    ans = llm_knowledge.run(query)
272
    return ans
273

274
class KnowledgeSearchInput(BaseModel):
275
    location: str = Field(description="The query to be searched")
276

277
if __name__ == "__main__":
278
    result = search_knowledgebase_complex("机器人和大数据在代码教学上有什么区别")
279
    print(result)
280

281
# 这是一个正常的切割
282
#     queries = [
283
#         ("bigdata", "大数据专业的男女比例"),
284
#         ("robotic", "机器人专业的优势")
285
#     ]
286
#     result = search_knowledge(queries)
287
#     print(result)
288

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.