# Langchain-Chatchat: LLM-planned search across multiple knowledge bases.
1from __future__ import annotations2import json3import re4import warnings5from typing import Dict6from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun7from langchain.chains.llm import LLMChain8from langchain.pydantic_v1 import Extra, root_validator9from langchain.schema import BasePromptTemplate10from langchain.schema.language_model import BaseLanguageModel11from typing import List, Any, Optional12from langchain.prompts import PromptTemplate13from server.chat.knowledge_base_chat import knowledge_base_chat14from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS15import asyncio16from server.agent import model_container17from pydantic import BaseModel, Field18
async def search_knowledge_base_iter(database: str, query: str) -> str:
    """Query a single knowledge base and collect the streamed answer.

    Args:
        database: Name of the knowledge base to search.
        query: The question to ask against that knowledge base.

    Returns:
        The full answer text, concatenated from the streamed JSON chunks.
        (The chunks' "docs" field is ignored; the original code read it into
        an unused local.)
    """
    response = await knowledge_base_chat(query=query,
                                         knowledge_base_name=database,
                                         model_name=model_container.MODEL.model_name,
                                         temperature=0.01,
                                         history=[],
                                         top_k=VECTOR_SEARCH_TOP_K,
                                         max_tokens=MAX_TOKENS,
                                         prompt_name="default",
                                         score_threshold=SCORE_THRESHOLD,
                                         stream=False)

    contents = ""
    # Each streamed chunk is a JSON string with "answer" (and "docs") keys.
    async for data in response.body_iterator:
        data = json.loads(data)
        contents += data["answer"]
    return contents
38
async def search_knowledge_multiple(queries) -> List[str]:
    """Run several knowledge-base lookups concurrently.

    ``queries`` is a list of ``(database, query)`` tuples; one search task is
    spawned per tuple and the answers are gathered in input order.  Each
    answer is prefixed with a header naming the knowledge base it came from.
    """
    answers = await asyncio.gather(
        *(search_knowledge_base_iter(db, q) for db, q in queries)
    )
    return [
        f"\n查询到 {database} 知识库的相关信息:\n{result}"
        for (database, _), result in zip(queries, answers)
    ]
51
52def search_knowledge(queries) -> str:53responses = asyncio.run(search_knowledge_multiple(queries))54# 输出每个整合的查询结果55contents = ""56for response in responses:57contents += response + "\n\n"58return contents59
60
# Prompt that instructs the LLM to decompose the user's question into one
# "knowledge_base_name,query" line per knowledge base, emitted inside a
# ```text fence.  The fence is the marker later extracted by
# LLMKnowledgeChain._process_llm_result, so the template insists the model
# outputs it verbatim.  The prompt body is runtime LLM input (Chinese);
# do not edit its wording.
_PROMPT_TEMPLATE = """
用户会提出一个需要你查询知识库的问题,你应该对问题进行理解和拆解,并在知识库中查询相关的内容。

对于每个知识库,你输出的内容应该是一个一行的字符串,这行字符串包含知识库名称和查询内容,中间用逗号隔开,不要有多余的文字和符号。你可以同时查询多个知识库,下面这个例子就是同时查询两个知识库的内容。

例子:

robotic,机器人男女比例是多少
bigdata,大数据的就业情况如何


这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能,你应该参考他们的功能来帮助你思考


{database_names}

你的回答格式应该按照下面的内容,请注意```text 等标记都必须输出,这是我用来提取答案的标记。
不要输出中文的逗号,不要输出引号。

Question: ${{用户的问题}}

```text
${{知识库名称,查询问题,不要带有任何除了,之外的符号,比如不要输出中文的逗号,不要输出引号}}

```output
数据库查询的结果

现在,我们开始作答
问题: {question}
"""

# {question} is the user's question; {database_names} is a rendered listing
# of the available knowledge bases (name: description), filled per call.
PROMPT = PromptTemplate(
    input_variables=["question", "database_names"],
    template=_PROMPT_TEMPLATE,
)
96
97
class LLMKnowledgeChain(LLMChain):
    """Chain that lets an LLM plan and execute knowledge-base searches.

    The LLM is prompted (see ``PROMPT``) to emit one ``db,query`` line per
    knowledge base inside a ```text fence; the fence content is parsed into
    ``(database, query)`` tuples and searched via ``search_knowledge``, and
    the combined result is returned as the chain's answer.
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated] Prompt to use to translate to python if necessary."""
    # Mapping of knowledge-base name -> description; refreshed from
    # model_container.DATABASE on every call.
    database_names: Optional[Dict[str, str]] = None
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Support the deprecated ``llm=`` constructor argument by building
        the equivalent ``llm_chain`` on the fly."""
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    @staticmethod
    def _extract_queries(expression: str) -> List[tuple]:
        """Parse the LLM's ``db,query`` lines into ``(db, query)`` tuples.

        Quotes, smart quotes and stray code fences are stripped first.  Both
        the ASCII comma and the full-width comma are accepted (the prompt
        forbids the latter, but models emit it anyway), and lines without a
        separator are skipped.  The original code had a try/except whose
        handler re-ran the identical failing comprehension, so any malformed
        line raised an uncaught IndexError.
        """
        cleaned = (expression.replace("\"", "").replace("“", "")
                   .replace("”", "").replace("```", "").strip())
        queries = []
        for line in cleaned.split("\n"):
            parts = re.split(r"[,,]", line, maxsplit=1)
            if len(parts) == 2 and parts[0].strip():
                queries.append((parts[0].strip(), parts[1].strip()))
        return queries

    def _evaluate_expression(self, queries) -> str:
        """Run the knowledge-base searches; on failure return the error text
        instead of raising, so the chain still yields an answer."""
        try:
            output = search_knowledge(queries)
        except Exception as e:
            output = "输入的信息有误或不存在知识库,错误信息如下:\n"
            return output + str(e)
        return output

    def _process_llm_result(
            self,
            llm_output: str,
            run_manager: CallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Extract the ```text fence from the LLM output, run the searches,
        and wrap the combined result as the chain answer."""
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)

        llm_output = llm_output.strip()
        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
        if text_match:
            queries = self._extract_queries(text_match.group(1))
            run_manager.on_text("知识库查询内容:\n\n" + str(queries) + " \n\n",
                                color="blue", verbose=self.verbose)
            output = self._evaluate_expression(queries)
            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            # Keep the same "Answer: " prefix as the async variant for
            # consistent downstream handling (the sync path used to omit it).
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            return {self.output_key: f"输入的格式不对:\n {llm_output}"}
        return {self.output_key: answer}

    async def _aprocess_llm_result(
            self,
            llm_output: str,
            run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Async counterpart of ``_process_llm_result``."""
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"```text(.*)", llm_output, re.DOTALL)
        if text_match:
            queries = self._extract_queries(text_match.group(1))
            await run_manager.on_text("知识库查询内容:\n\n" + str(queries) + " \n\n",
                                      color="blue", verbose=self.verbose)
            output = self._evaluate_expression(queries)
            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            raise ValueError(f"unknown format from LLM: {llm_output}")
        return {self.output_key: answer}

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Sync entry point: render the prompt with the available knowledge
        bases, ask the LLM, then post-process its answer."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key])
        self.database_names = model_container.DATABASE
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = self.llm_chain.predict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, _run_manager)

    async def _acall(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async entry point; mirrors ``_call``."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        await _run_manager.on_text(inputs[self.input_key])
        self.database_names = model_container.DATABASE
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = await self.llm_chain.apredict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        # BUG FIX: the original passed an extra positional argument
        # (inputs[self.input_key]) that _aprocess_llm_result does not accept,
        # raising TypeError on every async invocation.
        return await self._aprocess_llm_result(llm_output, _run_manager)

    @property
    def _chain_type(self) -> str:
        return "llm_knowledge_chain"

    @classmethod
    def from_llm(
            cls,
            llm: BaseLanguageModel,
            prompt: BasePromptTemplate = PROMPT,
            **kwargs: Any,
    ) -> LLMKnowledgeChain:
        """Preferred constructor: build the chain from an LLM and a prompt."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
267
268def search_knowledgebase_complex(query: str):269model = model_container.MODEL270llm_knowledge = LLMKnowledgeChain.from_llm(model, verbose=True, prompt=PROMPT)271ans = llm_knowledge.run(query)272return ans273
274class KnowledgeSearchInput(BaseModel):275location: str = Field(description="The query to be searched")276
277if __name__ == "__main__":278result = search_knowledgebase_complex("机器人和大数据在代码教学上有什么区别")279print(result)280
# Example of a correctly split (database, query) list:
# queries = [
#     ("bigdata", "大数据专业的男女比例"),
#     ("robotic", "机器人专业的优势")
# ]
# result = search_knowledge(queries)
# print(result)
288