llama-index

Форк
0
157 строк · 5.8 Кб
1
import json
2
from typing import Any, Callable, Dict, List, Optional
3

4
from llama_index.legacy.bridge.pydantic import Field, root_validator
5
from llama_index.legacy.core.llms.types import ChatMessage, MessageRole
6
from llama_index.legacy.llms.llm import LLM
7
from llama_index.legacy.llms.types import ChatMessage, MessageRole
8
from llama_index.legacy.memory.types import DEFAULT_CHAT_STORE_KEY, BaseMemory
9
from llama_index.legacy.storage.chat_store import BaseChatStore, SimpleChatStore
10
from llama_index.legacy.utils import get_tokenizer
11

12
# Fraction of the LLM's context window used as the token limit when
# ChatMemoryBuffer.from_defaults receives an ``llm`` but no (or a falsy)
# ``token_limit``.
DEFAULT_TOKEN_LIMIT_RATIO = 0.75
# Token limit used by ChatMemoryBuffer.from_defaults when neither ``llm`` nor
# ``token_limit`` is given.
DEFAULT_TOKEN_LIMIT = 3000


class ChatMemoryBuffer(BaseMemory):
    """Simple buffer for storing chat history.

    Messages are kept in ``chat_store`` under ``chat_store_key``. ``get``
    returns the most recent messages whose combined token count (measured with
    ``tokenizer_fn``) fits within ``token_limit``.
    """

    token_limit: int
    tokenizer_fn: Callable[[str], List] = Field(
        # Excluded from serialization (a callable cannot be dumped to JSON);
        # ``validate_memory`` fills in the default tokenizer when loading.
        default_factory=get_tokenizer,
        exclude=True,
    )
    chat_store: BaseChatStore = Field(default_factory=SimpleChatStore)
    chat_store_key: str = Field(default=DEFAULT_CHAT_STORE_KEY)

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "ChatMemoryBuffer"

    @root_validator(pre=True)
    def validate_memory(cls, values: dict) -> dict:
        """Validate ``token_limit`` and supply a default ``tokenizer_fn``.

        Raises:
            ValueError: If ``token_limit`` is missing, ``None`` or below 1.
        """
        # Validate token limit. An explicit ``None`` is treated as "not set"
        # instead of failing on ``None < 1`` with an unrelated TypeError.
        token_limit = values.get("token_limit")
        if token_limit is None or token_limit < 1:
            raise ValueError("Token limit must be set and greater than 0.")

        # Validate tokenizer -- this avoids errors when loading from json/dict
        if values.get("tokenizer_fn", None) is None:
            values["tokenizer_fn"] = get_tokenizer()

        return values

    @classmethod
    def from_defaults(
        cls,
        chat_history: Optional[List[ChatMessage]] = None,
        llm: Optional[LLM] = None,
        chat_store: Optional[BaseChatStore] = None,
        chat_store_key: str = DEFAULT_CHAT_STORE_KEY,
        token_limit: Optional[int] = None,
        tokenizer_fn: Optional[Callable[[str], List]] = None,
    ) -> "ChatMemoryBuffer":
        """Create a chat memory buffer, optionally sized from an LLM.

        Args:
            chat_history: Messages to seed the store with under
                ``chat_store_key`` (written into ``chat_store`` if given).
            llm: If given and ``token_limit`` is falsy, the limit defaults to
                ``DEFAULT_TOKEN_LIMIT_RATIO`` of the LLM's context window.
            chat_store: Store to use; a new ``SimpleChatStore`` otherwise.
            chat_store_key: Key under which messages are stored.
            token_limit: Explicit token limit.
            tokenizer_fn: Tokenizer; defaults to ``get_tokenizer()``.
        """
        if llm is not None:
            context_window = llm.metadata.context_window
            token_limit = token_limit or int(context_window * DEFAULT_TOKEN_LIMIT_RATIO)
        elif token_limit is None:
            token_limit = DEFAULT_TOKEN_LIMIT

        if chat_history is not None:
            chat_store = chat_store or SimpleChatStore()
            chat_store.set_messages(chat_store_key, chat_history)

        return cls(
            token_limit=token_limit,
            tokenizer_fn=tokenizer_fn or get_tokenizer(),
            chat_store=chat_store or SimpleChatStore(),
            chat_store_key=chat_store_key,
        )

    def to_string(self) -> str:
        """Convert memory to a JSON string (``tokenizer_fn`` is excluded)."""
        return self.json()

    @classmethod
    def from_string(cls, json_str: str) -> "ChatMemoryBuffer":
        """Create a chat memory buffer from a JSON string."""
        dict_obj = json.loads(json_str)
        return cls.from_dict(dict_obj)

    def to_dict(self, **kwargs: Any) -> dict:
        """Convert memory to dict."""
        return self.dict()

    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> "ChatMemoryBuffer":
        """Create a chat memory buffer from a dict; ``data`` is not modified.

        Accepts both the current format (``chat_store``) and the legacy
        format that stored messages directly under ``chat_history``.
        """
        from llama_index.legacy.storage.chat_store.loading import load_chat_store

        # Shallow copy so the pops below do not mutate the caller's dict.
        data = dict(data)

        # NOTE: this handles backwards compatibility with the old chat history
        if "chat_history" in data:
            chat_history = data.pop("chat_history")
            # Store under the key this buffer will read from; otherwise a
            # non-default ``chat_store_key`` would make the history invisible.
            chat_store_key = data.get("chat_store_key") or DEFAULT_CHAT_STORE_KEY
            data["chat_store"] = SimpleChatStore(store={chat_store_key: chat_history})
        elif "chat_store" in data:
            data["chat_store"] = load_chat_store(data.pop("chat_store"))

        return cls(**data)

    def get(self, initial_token_count: int = 0, **kwargs: Any) -> List[ChatMessage]:
        """Get the most recent chat history that fits within the token limit.

        Messages are dropped from the oldest end until the tokens of the
        remaining messages plus ``initial_token_count`` fit in
        ``token_limit``. After trimming, assistant messages at the start of
        the window are dropped as well.

        Args:
            initial_token_count: Tokens already used elsewhere that count
                against ``token_limit``.

        Returns:
            The trimmed messages, or ``[]`` if nothing fits.

        Raises:
            ValueError: If ``initial_token_count`` exceeds ``token_limit``.
        """
        if initial_token_count > self.token_limit:
            raise ValueError("Initial token count exceeds token limit")

        # Fetch once; the token counter reuses this list instead of re-reading
        # the store on every trimming step.
        chat_history = self.get_all()

        message_count = len(chat_history)
        token_count = (
            self._token_count_for_message_count(message_count, chat_history)
            + initial_token_count
        )

        while token_count > self.token_limit and message_count > 1:
            message_count -= 1
            # we cannot have an assistant message at the start of the chat
            # history: keep dropping until the window starts with a
            # non-assistant message (or becomes empty)
            while (
                message_count > 0
                and chat_history[-message_count].role == MessageRole.ASSISTANT
            ):
                message_count -= 1

            token_count = (
                self._token_count_for_message_count(message_count, chat_history)
                + initial_token_count
            )

        # catch one message longer than token limit
        if token_count > self.token_limit or message_count <= 0:
            return []

        return chat_history[-message_count:]

    def get_all(self) -> List[ChatMessage]:
        """Get all chat history."""
        return self.chat_store.get_messages(self.chat_store_key)

    def put(self, message: ChatMessage) -> None:
        """Append a message to the chat history."""
        self.chat_store.add_message(self.chat_store_key, message)

    def set(self, messages: List[ChatMessage]) -> None:
        """Replace the chat history with ``messages``."""
        self.chat_store.set_messages(self.chat_store_key, messages)

    def reset(self) -> None:
        """Reset chat history."""
        self.chat_store.delete_messages(self.chat_store_key)

    def _token_count_for_message_count(
        self,
        message_count: int,
        chat_history: Optional[List[ChatMessage]] = None,
    ) -> int:
        """Count tokens in the last ``message_count`` messages.

        Args:
            message_count: Number of most recent messages to count.
            chat_history: Pre-fetched history; read from the store when None.

        Returns:
            Token count of the messages' contents joined with spaces; 0 when
            ``message_count`` <= 0.
        """
        if message_count <= 0:
            return 0

        if chat_history is None:
            chat_history = self.get_all()
        msg_str = " ".join(str(m.content) for m in chat_history[-message_count:])
        return len(self.tokenizer_fn(msg_str))


Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.