Langchain-Chatchat
234 lines · 8.3 KB
1from __future__ import annotations2import re3import warnings4from typing import Dict5
6from langchain.callbacks.manager import (7AsyncCallbackManagerForChainRun,8CallbackManagerForChainRun,9)
10from langchain.chains.llm import LLMChain11from langchain.pydantic_v1 import Extra, root_validator12from langchain.schema import BasePromptTemplate13from langchain.schema.language_model import BaseLanguageModel14from typing import List, Any, Optional15from langchain.prompts import PromptTemplate16import sys17import os18import json19
20sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))21from server.chat.knowledge_base_chat import knowledge_base_chat22from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS23
24import asyncio25from server.agent import model_container26from pydantic import BaseModel, Field27
async def search_knowledge_base_iter(database: str, query: str) -> str:
    """Query one knowledge base and collect the streamed answer into a string.

    Args:
        database: Name of the knowledge base to search.
        query: The user question to answer from that knowledge base.

    Returns:
        The concatenated ``answer`` text from the knowledge-base chat endpoint.
    """
    response = await knowledge_base_chat(query=query,
                                         knowledge_base_name=database,
                                         model_name=model_container.MODEL.model_name,
                                         temperature=0.01,
                                         history=[],
                                         top_k=VECTOR_SEARCH_TOP_K,
                                         max_tokens=MAX_TOKENS,
                                         prompt_name="knowledge_base_chat",
                                         score_threshold=SCORE_THRESHOLD,
                                         stream=False)

    contents = ""
    # Each streamed chunk is a JSON string; accumulate its "answer" field.
    # NOTE(review): the original also bound data["docs"] on every iteration but
    # never used it — dead assignment removed. TODO confirm chunks always
    # carry an "answer" key (a missing key would raise KeyError here, as before).
    async for chunk in response.body_iterator:
        payload = json.loads(chunk)
        contents += payload["answer"]
    return contents
47
# Prompt for the knowledge-base selection step. The LLM is given the list of
# available knowledge bases ({database_names}) and must emit the chosen base's
# name inside a ```text fenced block — _process_llm_result extracts it with a
# regex keyed on exactly those markers, so the fences are part of the contract.
# The template text is user-facing (Chinese) and must not be altered.
_PROMPT_TEMPLATE = """
用户会提出一个需要你查询知识库的问题,你应该按照我提供的思想进行思考
Question: ${{用户的问题}}
这些数据库是你能访问的,冒号之前是他们的名字,冒号之后是他们的功能:

{database_names}

你的回答格式应该按照下面的内容,请注意,格式内的```text 等标记都必须输出,这是我用来提取答案的标记。
```text
${{知识库的名称}}
```
```output
数据库查询的结果
```
答案: ${{答案}}

现在,这是我的问题:
问题: {question}

"""

# Shared prompt instance; ``database_names`` is filled per-call from the
# chain's configured knowledge-base mapping.
PROMPT = PromptTemplate(
    input_variables=["question", "database_names"],
    template=_PROMPT_TEMPLATE,
)
72
73
class LLMKnowledgeChain(LLMChain):
    """Chain that lets an LLM pick a knowledge base and answer from it.

    The LLM is prompted (see ``PROMPT``) to emit the chosen knowledge-base
    name inside a ```text fenced block; that name plus the original question
    are forwarded to ``search_knowledge_base_iter`` and the retrieved answer
    is returned under ``output_key``.
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated] Prompt to use to translate to python if necessary."""
    # Mapping of knowledge-base name -> human-readable description, rendered
    # into the prompt so the LLM can choose among them.
    database_names: Dict[str, str] = model_container.DATABASE
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn on direct ``llm=`` construction and derive ``llm_chain`` from it."""
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMKnowledgeChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _evaluate_expression(self, dataset: str, query: str) -> str:
        """Synchronously run the knowledge-base lookup and return the answer text.

        Deliberately best-effort: any failure (unknown knowledge base, backend
        error) is reported as a fixed user-facing message instead of raising,
        so the chain always yields an answer. (The original had a duplicated,
        unreachable ``return output`` — collapsed to a single return path.)
        """
        try:
            output = asyncio.run(search_knowledge_base_iter(dataset, query))
        except Exception:
            output = "输入的信息有误或不存在知识库"
        return output

    def _process_llm_result(
            self,
            llm_output: str,
            llm_input: str,
            run_manager: CallbackManagerForChainRun
    ) -> Dict[str, str]:
        """Parse the LLM's reply, query the chosen knowledge base, format the answer.

        Args:
            llm_output: Raw completion; expected to start with a ```text block
                naming the knowledge base.
            llm_input: The original user question, forwarded to the lookup.
            run_manager: Callback manager for verbose tracing.
        """
        run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            database = text_match.group(1).strip()
            output = self._evaluate_expression(database, llm_input)
            run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            # Malformed completion: surface it to the caller instead of raising.
            return {self.output_key: f"输入的格式不对: {llm_output}"}
        return {self.output_key: answer}

    async def _aprocess_llm_result(
            self,
            llm_output: str,
            llm_input: str,
            run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, str]:
        """Async counterpart of ``_process_llm_result``.

        BUG FIX: this method previously took no ``llm_input`` parameter even
        though ``_acall`` passes one (TypeError at call time), and it invoked
        ``self._evaluate_expression(expression)`` with a single argument
        (second TypeError). It also skipped ``.strip()`` on the extracted
        database name and raised ``ValueError`` on malformed output where the
        sync path returns an error payload. All four now match the sync path.
        """
        await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
        llm_output = llm_output.strip()
        text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
        if text_match:
            database = text_match.group(1).strip()
            output = self._evaluate_expression(database, llm_input)
            await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
            await run_manager.on_text(output, color="yellow", verbose=self.verbose)
            answer = "Answer: " + output
        elif llm_output.startswith("Answer:"):
            answer = llm_output
        elif "Answer:" in llm_output:
            answer = "Answer: " + llm_output.split("Answer:")[-1]
        else:
            # Keep parity with the sync path: report instead of raising.
            return {self.output_key: f"输入的格式不对: {llm_output}"}
        return {self.output_key: answer}

    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Render the prompt, run the LLM, and resolve its knowledge-base pick."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key])
        # Render available knowledge bases as `"name":"description"` lines.
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = self.llm_chain.predict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return self._process_llm_result(llm_output, inputs[self.input_key], _run_manager)

    async def _acall(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Async version of ``_call``."""
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        await _run_manager.on_text(inputs[self.input_key])
        data_formatted_str = ',\n'.join([f' "{k}":"{v}"' for k, v in self.database_names.items()])
        llm_output = await self.llm_chain.apredict(
            database_names=data_formatted_str,
            question=inputs[self.input_key],
            stop=["```output"],
            callbacks=_run_manager.get_child(),
        )
        return await self._aprocess_llm_result(llm_output, inputs[self.input_key], _run_manager)

    @property
    def _chain_type(self) -> str:
        return "llm_knowledge_chain"

    @classmethod
    def from_llm(
            cls,
            llm: BaseLanguageModel,
            prompt: BasePromptTemplate = PROMPT,
            **kwargs: Any,
    ) -> LLMKnowledgeChain:
        """Preferred constructor: wrap *llm* in an ``LLMChain`` with *prompt*."""
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)
220
def search_knowledgebase_once(query: str):
    """Answer *query* once by letting the shared model pick and search a knowledge base."""
    chain = LLMKnowledgeChain.from_llm(
        model_container.MODEL, verbose=True, prompt=PROMPT
    )
    return chain.run(query)
227
class KnowledgeSearchInput(BaseModel):
    """Pydantic input schema for the knowledge-search tool."""

    # NOTE(review): the field is named "location" yet, per its description, it
    # holds the search query — presumably copied from another tool's schema.
    # Renaming it would change the tool's argument schema for existing
    # callers, so it is only flagged here. TODO confirm and rename to "query".
    location: str = Field(description="The query to be searched")
231
if __name__ == "__main__":
    # Manual smoke test: run a single knowledge-base query and show the answer.
    print(search_knowledgebase_once("大数据的男女比例"))