Langchain-Chatchat

Форк
0
/
custom_template.py 
67 строк · 2.6 Кб
1
from __future__ import annotations
2
from langchain.agents import Tool, AgentOutputParser
3
from langchain.prompts import StringPromptTemplate
4
from typing import List
5
from langchain.schema import AgentAction, AgentFinish
6

7
from configs import SUPPORT_AGENT_MODEL
8
from server.agent import model_container
9
class CustomPromptTemplate(StringPromptTemplate):
10
    template: str
11
    tools: List[Tool]
12

13
    def format(self, **kwargs) -> str:
14
        intermediate_steps = kwargs.pop("intermediate_steps")
15
        thoughts = ""
16
        for action, observation in intermediate_steps:
17
            thoughts += action.log
18
            thoughts += f"\nObservation: {observation}\nThought: "
19
        kwargs["agent_scratchpad"] = thoughts
20
        kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools])
21
        kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools])
22
        return self.template.format(**kwargs)
23

24
class CustomOutputParser(AgentOutputParser):
25
    begin: bool = False
26
    def __init__(self):
27
        super().__init__()
28
        self.begin = True
29

30
    def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
31
        if not any(agent in model_container.MODEL for agent in SUPPORT_AGENT_MODEL) and self.begin:
32
            self.begin = False
33
            stop_words = ["Observation:"]
34
            min_index = len(llm_output)
35
            for stop_word in stop_words:
36
                index = llm_output.find(stop_word)
37
                if index != -1 and index < min_index:
38
                    min_index = index
39
                llm_output = llm_output[:min_index]
40

41
        if "Final Answer:" in llm_output:
42
            self.begin = True
43
            return AgentFinish(
44
                return_values={"output": llm_output.split("Final Answer:", 1)[-1].strip()},
45
                log=llm_output,
46
            )
47
        parts = llm_output.split("Action:")
48
        if len(parts) < 2:
49
            return AgentFinish(
50
                return_values={"output": f"调用agent工具失败,该回答为大模型自身能力的回答:\n\n `{llm_output}`"},
51
                log=llm_output,
52
            )
53

54
        action = parts[1].split("Action Input:")[0].strip()
55
        action_input = parts[1].split("Action Input:")[1].strip()
56
        try:
57
            ans = AgentAction(
58
                tool=action,
59
                tool_input=action_input.strip(" ").strip('"'),
60
                log=llm_output
61
            )
62
            return ans
63
        except:
64
            return AgentFinish(
65
                return_values={"output": f"调用agent失败: `{llm_output}`"},
66
                log=llm_output,
67
            )
68

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.