rags

Форк
0
/
param_cache.py 
156 строк · 5.1 Кб
1
"""Param cache."""
2

3
import json
import uuid
from pathlib import Path
from typing import List, Optional, cast

from llama_index import (
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.chat_engine.types import BaseChatEngine
from pydantic import BaseModel, Field

from core.utils import (
    RAGParams,
    construct_agent,
    construct_mm_agent,
    get_tool_objects,
    load_data,
)
21

22

23
class ParamCache(BaseModel):
    """Cache for the RAG agent builder.

    A typed wrapper around what would otherwise be a plain dict, so that the
    different items held in the cache (data sources, RAG parameters, vector
    index, agent) each carry an explicit type.

    ``save_to_disk`` writes only the JSON-serializable subset of the fields
    plus the vector index's storage; ``load_from_disk`` rebuilds the documents,
    tool objects, vector index and agent from that.
    """

    # Allow non-pydantic field types such as VectorStoreIndex / BaseChatEngine.
    class Config:
        arbitrary_types_allowed = True

    # system prompt
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt for RAG agent."
    )
    # data sources
    file_names: List[str] = Field(
        default_factory=list, description="File names as data source (if specified)"
    )
    urls: List[str] = Field(
        default_factory=list, description="URLs as data source (if specified)"
    )
    directory: Optional[str] = Field(
        default=None, description="Directory as data source (if specified)"
    )

    # Loaded documents; not written to cache.json, re-loaded on load_from_disk.
    docs: List = Field(default_factory=list, description="Documents for RAG agent.")
    # tools
    tools: List = Field(
        default_factory=list, description="Additional tools for RAG agent (e.g. web)"
    )
    # RAG params
    rag_params: RAGParams = Field(
        default_factory=RAGParams, description="RAG parameters for RAG agent."
    )

    # agent params
    builder_type: str = Field(
        default="default", description="Builder type (default, multimodal)."
    )
    vector_index: Optional[VectorStoreIndex] = Field(
        default=None, description="Vector index for RAG agent."
    )
    agent_id: str = Field(
        default_factory=lambda: f"Agent_{str(uuid.uuid4())}",
        description="Agent ID for RAG agent.",
    )
    agent: Optional[BaseChatEngine] = Field(default=None, description="RAG agent.")

    def save_to_disk(self, save_dir: str) -> None:
        """Save the cache to ``save_dir``.

        Persists the vector index's storage context under
        ``<save_dir>/storage`` and the serializable fields to
        ``<save_dir>/cache.json``. ``docs``, ``vector_index`` and ``agent``
        are not stored in the JSON; ``load_from_disk`` rebuilds them.

        Args:
            save_dir: Target directory; created (with parents) if missing.

        Raises:
            ValueError: If ``vector_index`` is not set.
        """
        # Validate first so nothing is written when the save cannot succeed.
        if self.vector_index is None:
            raise ValueError("Must specify vector index in order to save.")

        save_path = Path(save_dir)
        # Create the target directory before anything is persisted into it;
        # exist_ok avoids a check-then-create race.
        save_path.mkdir(parents=True, exist_ok=True)

        # NOTE: more complex than just calling dict() because we want to
        # only store serializable fields and be space-efficient.
        dict_to_serialize = {
            "system_prompt": self.system_prompt,
            "file_names": self.file_names,
            "urls": self.urls,
            "directory": self.directory,
            # TODO: figure out tools
            # Stored as-is and later passed to get_tool_objects(); presumably
            # tool names (must be JSON-serializable) — TODO confirm.
            "tools": self.tools,
            "rag_params": self.rag_params.dict(),
            "builder_type": self.builder_type,
            "agent_id": self.agent_id,
        }

        # store the vector store alongside the cache
        self.vector_index.storage_context.persist(save_path / "storage")

        with open(save_path / "cache.json", "w", encoding="utf-8") as f:
            json.dump(dict_to_serialize, f)

    @classmethod
    def load_from_disk(
        cls,
        save_dir: str,
    ) -> "ParamCache":
        """Load a cache previously written by ``save_to_disk``.

        Reads ``<save_dir>/cache.json``, reloads the vector index from
        ``<save_dir>/storage``, re-loads the source documents from the saved
        file names / URLs / directory, resolves the saved tools, and
        constructs the agent for the saved ``builder_type``.

        Args:
            save_dir: Directory previously passed to ``save_to_disk``.

        Returns:
            A fully populated ``ParamCache``.
        """
        save_path = Path(save_dir)
        with open(save_path / "cache.json", "r", encoding="utf-8") as f:
            cache_dict = json.load(f)

        # Both the index type and the agent constructor depend on this flag.
        is_multimodal = cache_dict["builder_type"] == "multimodal"

        storage_context = StorageContext.from_defaults(
            persist_dir=str(save_path / "storage")
        )
        if is_multimodal:
            # Imported only on the multimodal path (presumably to keep the
            # module optional for the default builder).
            from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex

            vector_index: VectorStoreIndex = cast(
                MultiModalVectorStoreIndex, load_index_from_storage(storage_context)
            )
        else:
            vector_index = cast(
                VectorStoreIndex, load_index_from_storage(storage_context)
            )

        # replace the serialized rag params with a RAGParams object
        cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"])

        # add in the fields that are not stored in cache.json
        # load docs
        cache_dict["docs"] = load_data(
            file_names=cache_dict["file_names"],
            urls=cache_dict["urls"],
            directory=cache_dict["directory"],
        )
        # load agent from index
        additional_tools = get_tool_objects(cache_dict["tools"])

        if is_multimodal:
            # MultiModalVectorStoreIndex is bound: same condition as above.
            # NOTE(review): additional_tools are not passed to the multimodal
            # agent — confirm whether construct_mm_agent should receive them.
            agent, _ = construct_mm_agent(
                cache_dict["system_prompt"],
                cache_dict["rag_params"],
                cache_dict["docs"],
                mm_vector_index=cast(MultiModalVectorStoreIndex, vector_index),
            )
        else:
            agent, _ = construct_agent(
                cache_dict["system_prompt"],
                cache_dict["rag_params"],
                cache_dict["docs"],
                vector_index=vector_index,
                additional_tools=additional_tools,
                # TODO: figure out tools
            )
        cache_dict["vector_index"] = vector_index
        cache_dict["agent"] = agent

        return cls(**cache_dict)
157

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.