"""Langchain memory wrapper (for LlamaIndex)."""
2

3
from typing import Any, Dict, List, Optional
4

5
from llama_index.legacy.bridge.langchain import (
6
    AIMessage,
7
    BaseChatMemory,
8
    BaseMessage,
9
    HumanMessage,
10
)
11
from llama_index.legacy.bridge.langchain import BaseMemory as Memory
12
from llama_index.legacy.bridge.pydantic import Field
13
from llama_index.legacy.indices.base import BaseIndex
14
from llama_index.legacy.schema import Document
15
from llama_index.legacy.utils import get_new_id
16

17

18
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
19
    """Get prompt input key.
20

21
    Copied over from langchain.
22

23
    """
24
    # "stop" is a special key that can be passed as input but is not used to
25
    # format the prompt.
26
    prompt_input_keys = list(set(inputs).difference([*memory_variables, "stop"]))
27
    if len(prompt_input_keys) != 1:
28
        raise ValueError(f"One input key expected got {prompt_input_keys}")
29
    return prompt_input_keys[0]
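
# Illustrative example (not part of the original module): the helper above
# removes the memory variables and the special "stop" key from the chain
# inputs and expects exactly one candidate key to remain.
#
#     >>> get_prompt_input_key(
#     ...     {"input": "hi there", "history": "...", "stop": ["\n"]},
#     ...     memory_variables=["history"],
#     ... )
#     'input'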


class GPTIndexMemory(Memory):
    """LangChain memory wrapper (for LlamaIndex).

    Args:
        human_prefix (str): Prefix for human input. Defaults to "Human".
        ai_prefix (str): Prefix for AI output. Defaults to "AI".
        memory_key (str): Key for memory. Defaults to "history".
        index (BaseIndex): LlamaIndex instance.
        query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query.
        input_key (Optional[str]): Input key. Defaults to None.
        output_key (Optional[str]): Output key. Defaults to None.

    """

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    memory_key: str = "history"
    index: BaseIndex
    query_kwargs: Dict = Field(default_factory=dict)
    output_key: Optional[str] = None
    input_key: Optional[str] = None

    @property
    def memory_variables(self) -> List[str]:
        """Return memory variables."""
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key
        return prompt_input_key

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Return key-value pairs given the text input to the chain."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        query_str = inputs[prompt_input_key]

        # TODO: wrap in prompt
        # TODO: add option to return the raw text
        # NOTE: currently it's a hack
        query_engine = self.index.as_query_engine(**self.query_kwargs)
        response = query_engine.query(query_str)
        return {self.memory_key: str(response)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this model run to memory."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        if self.output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            output_key = next(iter(outputs.keys()))
        else:
            output_key = self.output_key
        human = f"{self.human_prefix}: " + inputs[prompt_input_key]
        ai = f"{self.ai_prefix}: " + outputs[output_key]
        doc_text = f"{human}\n{ai}"
        doc = Document(text=doc_text)
        self.index.insert(doc)

    def clear(self) -> None:
        """Clear memory contents."""
        # NOTE: no-op; documents already inserted into the index are kept.

    def __repr__(self) -> str:
        """Return representation."""
        return "GPTIndexMemory()"


class GPTIndexChatMemory(BaseChatMemory):
    """LangChain chat memory wrapper (for LlamaIndex).

    Args:
        human_prefix (str): Prefix for human input. Defaults to "Human".
        ai_prefix (str): Prefix for AI output. Defaults to "AI".
        memory_key (str): Key for memory. Defaults to "history".
        index (BaseIndex): LlamaIndex instance.
        query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query.
        input_key (Optional[str]): Input key. Defaults to None.
        output_key (Optional[str]): Output key. Defaults to None.
        return_source (bool): Whether to return the retrieved source content
            instead of the synthesized response. Defaults to False.

    """

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    memory_key: str = "history"
    index: BaseIndex
    query_kwargs: Dict = Field(default_factory=dict)
    output_key: Optional[str] = None
    input_key: Optional[str] = None

    return_source: bool = False
    id_to_message: Dict[str, BaseMessage] = Field(default_factory=dict)

    @property
    def memory_variables(self) -> List[str]:
        """Return memory variables."""
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key
        return prompt_input_key

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return key-value pairs given the text input to the chain."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        query_str = inputs[prompt_input_key]

        query_engine = self.index.as_query_engine(**self.query_kwargs)
        response_obj = query_engine.query(query_str)
        if self.return_source:
            source_nodes = response_obj.source_nodes
            if self.return_messages:
                # get source messages from ids
                source_ids = [sn.node.node_id for sn in source_nodes]
                source_messages = [
                    m for id_, m in self.id_to_message.items() if id_ in source_ids
                ]
                # NOTE: type List[BaseMessage]
                response: Any = source_messages
            else:
                source_texts = [sn.node.get_content() for sn in source_nodes]
                response = "\n\n".join(source_texts)
        else:
            response = str(response_obj)
        return {self.memory_key: response}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this model run to memory."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        if self.output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            output_key = next(iter(outputs.keys()))
        else:
            output_key = self.output_key

        # a bit different than the existing langchain implementation
        # because we want to track ids for messages
        human_message = HumanMessage(content=inputs[prompt_input_key])
        human_message_id = get_new_id(set(self.id_to_message.keys()))
        ai_message = AIMessage(content=outputs[output_key])
        ai_message_id = get_new_id(
            set(self.id_to_message.keys()).union({human_message_id})
        )

        self.chat_memory.messages.append(human_message)
        self.chat_memory.messages.append(ai_message)

        self.id_to_message[human_message_id] = human_message
        self.id_to_message[ai_message_id] = ai_message

        human_txt = f"{self.human_prefix}: " + inputs[prompt_input_key]
        ai_txt = f"{self.ai_prefix}: " + outputs[output_key]
        human_doc = Document(text=human_txt, id_=human_message_id)
        ai_doc = Document(text=ai_txt, id_=ai_message_id)
        self.index.insert(human_doc)
        self.index.insert(ai_doc)

    def clear(self) -> None:
        """Clear memory contents."""
        # NOTE: no-op; documents already inserted into the index are kept.

    def __repr__(self) -> str:
        """Return representation."""
        return "GPTIndexChatMemory()"