# llama-index · 199 lines · 7.5 KB (source-listing header)
"""Langchain memory wrapper (for LlamaIndex)."""

from typing import Any, Dict, List, Optional

from llama_index.legacy.bridge.langchain import (
    AIMessage,
    BaseChatMemory,
    BaseMessage,
    HumanMessage,
)
from llama_index.legacy.bridge.langchain import BaseMemory as Memory
from llama_index.legacy.bridge.pydantic import Field
from llama_index.legacy.indices.base import BaseIndex
from llama_index.legacy.schema import Document
from llama_index.legacy.utils import get_new_id
17
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
    """Get prompt input key.

    Copied over from langchain.

    Args:
        inputs: Raw chain inputs keyed by variable name.
        memory_variables: Variable names supplied by memory (excluded as candidates).

    Returns:
        The single input key that is neither a memory variable nor "stop".

    Raises:
        ValueError: If there is not exactly one candidate key.

    """
    # "stop" is a special key that can be passed as input but is not used to
    # format the prompt.
    prompt_input_keys = list(set(inputs).difference([*memory_variables, "stop"]))
    if len(prompt_input_keys) != 1:
        # NOTE: comma added so the message matches the sibling
        # "One output key expected, got ..." raised in save_context.
        raise ValueError(f"One input key expected, got {prompt_input_keys}")
    return prompt_input_keys[0]
31
class GPTIndexMemory(Memory):
    """Langchain memory wrapper (for LlamaIndex).

    Args:
        human_prefix (str): Prefix for human input. Defaults to "Human".
        ai_prefix (str): Prefix for AI output. Defaults to "AI".
        memory_key (str): Key for memory. Defaults to "history".
        index (BaseIndex): LlamaIndex instance.
        query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query.
        input_key (Optional[str]): Input key. Defaults to None.
        output_key (Optional[str]): Output key. Defaults to None.

    """

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    memory_key: str = "history"
    index: BaseIndex
    query_kwargs: Dict = Field(default_factory=dict)
    output_key: Optional[str] = None
    input_key: Optional[str] = None

    @property
    def memory_variables(self) -> List[str]:
        """Return memory variables."""
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        # Use the explicitly configured key when present; otherwise fall back
        # to langchain-style inference over the chain inputs.
        if self.input_key is not None:
            return self.input_key
        return get_prompt_input_key(inputs, self.memory_variables)

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Return key-value pairs given the text input to the chain."""
        key = self._get_prompt_input_key(inputs)
        query_str = inputs[key]

        # TODO: wrap in prompt
        # TODO: add option to return the raw text
        # NOTE: currently it's a hack
        engine = self.index.as_query_engine(**self.query_kwargs)
        return {self.memory_key: str(engine.query(query_str))}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this model run to memory."""
        input_key = self._get_prompt_input_key(inputs)
        output_key = self.output_key
        if output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            (output_key,) = outputs.keys()
        # Persist the turn as a single "Human: ...\nAI: ..." document.
        turn_lines = [
            f"{self.human_prefix}: " + inputs[input_key],
            f"{self.ai_prefix}: " + outputs[output_key],
        ]
        self.index.insert(Document(text="\n".join(turn_lines)))

    def clear(self) -> None:
        """Clear memory contents."""

    def __repr__(self) -> str:
        """Return representation."""
        return "GPTIndexMemory()"
100
class GPTIndexChatMemory(BaseChatMemory):
    """Langchain chat memory wrapper (for LlamaIndex).

    Args:
        human_prefix (str): Prefix for human input. Defaults to "Human".
        ai_prefix (str): Prefix for AI output. Defaults to "AI".
        memory_key (str): Key for memory. Defaults to "history".
        index (BaseIndex): LlamaIndex instance.
        query_kwargs (Dict[str, Any]): Keyword arguments for LlamaIndex query.
        input_key (Optional[str]): Input key. Defaults to None.
        output_key (Optional[str]): Output key. Defaults to None.

    """

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    memory_key: str = "history"
    index: BaseIndex
    query_kwargs: Dict = Field(default_factory=dict)
    output_key: Optional[str] = None
    input_key: Optional[str] = None

    # When True, load_memory_variables returns source nodes/messages rather
    # than the synthesized response string.
    return_source: bool = False
    # Maps generated message ids (also used as Document ids) to chat messages.
    id_to_message: Dict[str, BaseMessage] = Field(default_factory=dict)

    @property
    def memory_variables(self) -> List[str]:
        """Return memory variables."""
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key
        return prompt_input_key

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Return key-value pairs given the text input to the chain.

        NOTE: when return_source and return_messages are both set, the value is
        a List[BaseMessage] despite the declared Dict[str, str] return type.
        """
        prompt_input_key = self._get_prompt_input_key(inputs)
        query_str = inputs[prompt_input_key]

        query_engine = self.index.as_query_engine(**self.query_kwargs)
        response_obj = query_engine.query(query_str)
        if self.return_source:
            source_nodes = response_obj.source_nodes
            if self.return_messages:
                # get source messages from ids (set for O(1) membership;
                # iteration over id_to_message preserves insertion order)
                source_ids = {sn.node.node_id for sn in source_nodes}
                source_messages = [
                    msg
                    for msg_id, msg in self.id_to_message.items()
                    if msg_id in source_ids
                ]
                # NOTE: type List[BaseMessage]
                response: Any = source_messages
            else:
                source_texts = [sn.node.get_content() for sn in source_nodes]
                response = "\n\n".join(source_texts)
        else:
            response = str(response_obj)
        return {self.memory_key: response}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save the context of this model run to memory."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        if self.output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            output_key = next(iter(outputs.keys()))
        else:
            output_key = self.output_key

        # a bit different than existing langchain implementation
        # because we want to track id's for messages
        human_message = HumanMessage(content=inputs[prompt_input_key])
        human_message_id = get_new_id(set(self.id_to_message.keys()))
        ai_message = AIMessage(content=outputs[output_key])
        ai_message_id = get_new_id(
            set(self.id_to_message.keys()).union({human_message_id})
        )

        self.chat_memory.messages.append(human_message)
        self.chat_memory.messages.append(ai_message)

        self.id_to_message[human_message_id] = human_message
        self.id_to_message[ai_message_id] = ai_message

        # Insert each side of the turn as its own document so source nodes can
        # be mapped back to individual messages via their shared ids.
        human_txt = f"{self.human_prefix}: " + inputs[prompt_input_key]
        ai_txt = f"{self.ai_prefix}: " + outputs[output_key]
        human_doc = Document(text=human_txt, id_=human_message_id)
        ai_doc = Document(text=ai_txt, id_=ai_message_id)
        self.index.insert(human_doc)
        self.index.insert(ai_doc)

    def clear(self) -> None:
        """Clear memory contents."""

    def __repr__(self) -> str:
        """Return representation."""
        # Bug fix: previously returned "GPTIndexMemory()" (copy-paste from the
        # non-chat wrapper above).
        return "GPTIndexChatMemory()"