import json
import uuid
from pathlib import Path
from typing import List, Optional, cast

from pydantic import BaseModel, Field

from llama_index import (
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.chat_engine.types import BaseChatEngine

from core.utils import (
    RAGParams,
    construct_agent,
    construct_mm_agent,
    get_tool_objects,
    load_data,
)
class ParamCache(BaseModel):
    """Cache for RAG agent builder.

    Created a wrapper class around a dict in case we wanted to more explicitly
    type different items in the cache.

    """

    class Config:
        # Needed because VectorStoreIndex / BaseChatEngine are not pydantic models.
        arbitrary_types_allowed = True

    # System prompt driving the agent's behavior.
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt for RAG agent."
    )
    # Data sources: any combination of local files, URLs, and a directory.
    file_names: List[str] = Field(
        default_factory=list, description="File names as data source (if specified)"
    )
    urls: List[str] = Field(
        default_factory=list, description="URLs as data source (if specified)"
    )
    directory: Optional[str] = Field(
        default=None, description="Directory as data source (if specified)"
    )

    # Loaded documents (not serialized; rebuilt from the sources on load).
    docs: List = Field(default_factory=list, description="Documents for RAG agent.")
    # Names of extra tools, resolved to objects via get_tool_objects().
    tools: List[str] = Field(
        default_factory=list, description="Additional tools for RAG agent (e.g. web)"
    )
    # RAG configuration (chunk size, top-k, etc. — see core.utils.RAGParams).
    rag_params: RAGParams = Field(
        default_factory=RAGParams, description="RAG parameters for RAG agent."
    )
    builder_type: str = Field(
        default="default", description="Builder type (default, multimodal)."
    )

    # Runtime state: index + agent are reconstructed on load, not serialized.
    vector_index: Optional[VectorStoreIndex] = Field(
        default=None, description="Vector index for RAG agent."
    )
    agent_id: str = Field(
        default_factory=lambda: f"Agent_{str(uuid.uuid4())}",
        description="Agent ID for RAG agent.",
    )
    agent: Optional[BaseChatEngine] = Field(default=None, description="RAG agent.")

    def save_to_disk(self, save_dir: str) -> None:
        """Save cache to disk.

        Persists the vector index under ``<save_dir>/storage`` and the
        JSON-serializable parameters under ``<save_dir>/cache.json``.
        The agent and loaded docs are intentionally NOT serialized; they are
        rebuilt from the persisted pieces in ``load_from_disk``.

        Raises:
            ValueError: if no vector index has been built yet.

        """
        # Only the JSON-safe subset of fields goes into cache.json.
        dict_to_serialize = {
            "system_prompt": self.system_prompt,
            "file_names": self.file_names,
            "urls": self.urls,
            "directory": self.directory,
            "tools": self.tools,
            "rag_params": self.rag_params.dict(),
            "builder_type": self.builder_type,
            "agent_id": self.agent_id,
        }

        if self.vector_index is None:
            raise ValueError("Must specify vector index in order to save.")
        self.vector_index.storage_context.persist(Path(save_dir) / "storage")

        # Race-free equivalent of exists()/mkdir(parents=True).
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        with open(Path(save_dir) / "cache.json", "w") as f:
            json.dump(dict_to_serialize, f)

    @classmethod
    def load_from_disk(cls, save_dir: str) -> "ParamCache":
        """Load cache from disk.

        Reverses ``save_to_disk``: reloads the persisted vector index,
        re-reads the documents from their original sources, and
        reconstructs the agent.

        Args:
            save_dir: directory previously passed to ``save_to_disk``.

        Returns:
            A fully populated ``ParamCache`` instance.

        """
        with open(Path(save_dir) / "cache.json", "r") as f:
            cache_dict = json.load(f)

        storage_context = StorageContext.from_defaults(
            persist_dir=str(Path(save_dir) / "storage")
        )
        if cache_dict["builder_type"] == "multimodal":
            # Local import: multimodal support is optional at module load time.
            from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex

            vector_index: VectorStoreIndex = cast(
                MultiModalVectorStoreIndex, load_index_from_storage(storage_context)
            )
        else:
            vector_index = cast(
                VectorStoreIndex, load_index_from_storage(storage_context)
            )

        # Rehydrate the typed params object from its serialized dict.
        cache_dict["rag_params"] = RAGParams(**cache_dict["rag_params"])

        # Reload docs from the original sources (docs are not persisted).
        cache_dict["docs"] = load_data(
            file_names=cache_dict["file_names"],
            urls=cache_dict["urls"],
            directory=cache_dict["directory"],
        )
        # Resolve tool names back into tool objects.
        additional_tools = get_tool_objects(cache_dict["tools"])

        if cache_dict["builder_type"] == "multimodal":
            vector_index = cast(MultiModalVectorStoreIndex, vector_index)
            agent, _ = construct_mm_agent(
                cache_dict["system_prompt"],
                cache_dict["rag_params"],
                cache_dict["docs"],
                mm_vector_index=vector_index,
            )
        else:
            agent, _ = construct_agent(
                cache_dict["system_prompt"],
                cache_dict["rag_params"],
                cache_dict["docs"],
                vector_index=vector_index,
                additional_tools=additional_tools,
            )
        cache_dict["vector_index"] = vector_index
        cache_dict["agent"] = agent
        return cls(**cache_dict)