rags
/
st_utils.py
191 строка · 6.2 Кб
1"""Streamlit utils."""
2from core.agent_builder.loader import (3load_meta_agent_and_tools,4AgentCacheRegistry,5)
6from core.agent_builder.base import BaseRAGAgentBuilder7from core.param_cache import ParamCache8from core.constants import (9AGENT_CACHE_DIR,10)
11from typing import Optional, cast12from pydantic import BaseModel13
14from llama_index.agent.types import BaseAgent15import streamlit as st16
17
def update_selected_agent_with_id(selected_id: Optional[str] = None) -> None:
    """Select ``selected_id`` and discard all state derived from the previous selection.

    Args:
        selected_id: Id of the agent to select. The radio label
            ``"Create a new agent"`` (as well as ``None``) means that no
            existing agent is selected, and is stored as ``None``.
    """
    if selected_id == "Create a new agent":
        selected_id = None
    st.session_state.selected_id = selected_id

    # Reset the derived entries (in this order) to None; get_current_state()
    # rebuilds each of them when it next finds them set to None.
    for stale_key in ("builder_agent", "agent_builder", "selected_cache"):
        setattr(st.session_state, stale_key, None)
32
# Widget callback: used by the sidebar agent radio and by the multimodal checkbox.
def update_selected_agent() -> None:
    """Apply the agent currently chosen in the sidebar radio (``agent_selector``)."""
    update_selected_agent_with_id(st.session_state.agent_selector)
40
def get_cached_is_multimodal() -> bool:
    """Return whether the currently selected agent's cache is multimodal.

    Reads ``st.session_state.selected_cache`` (a ``ParamCache``).

    Returns:
        ``False`` when no cache is selected (the key is absent or holds
        ``None``). Otherwise ``True`` exactly when the cache's
        ``builder_type`` equals ``"multimodal"``.
    """
    # session_state is dict-like; .get() covers both "missing" and "None".
    selected_cache = st.session_state.get("selected_cache")
    if selected_cache is None:
        return False
    return cast(ParamCache, selected_cache).builder_type == "multimodal"
53
def get_is_multimodal() -> bool:
    """Return the multimodal checkbox value (``is_multimodal_st``).

    The value defaults to ``False`` when the key has never been set.
    """
    session = st.session_state
    if "is_multimodal_st" in session.keys():
        return session.is_multimodal_st
    # First access: seed the key so later reads see a stable default.
    session.is_multimodal_st = False
    return session.is_multimodal_st
60
def add_builder_config() -> None:
    """Render the "Builder Config (Advanced)" expander.

    The expander holds one checkbox, "Enable multimodal search (beta)", bound
    to ``st.session_state.is_multimodal_st``. Its initial value comes from
    ``get_cached_is_multimodal()``. It is disabled whenever a cached agent is
    selected. Toggling it calls ``update_selected_agent``.
    """
    with st.expander("Builder Config (Advanced)"):
        # Lock the option while an existing agent's cache is selected
        # (presumably because that agent's builder type is already fixed).
        is_locked = (
            "selected_cache" in st.session_state.keys()
            and st.session_state.selected_cache is not None
        )

        st.checkbox(
            "Enable multimodal search (beta)",
            key="is_multimodal_st",
            on_change=update_selected_agent,
            value=get_cached_is_multimodal(),
            disabled=is_locked,
        )
81
def add_sidebar() -> None:
    """Render the sidebar radio used to pick an agent or start a new one.

    The choices are ``"Create a new agent"`` followed by every agent id in the
    session's ``AgentCacheRegistry``. The refreshed id list is also stored in
    ``st.session_state.cur_agent_ids``. Changing the selection calls
    ``update_selected_agent``.
    """
    with st.sidebar:
        agent_registry = cast(AgentCacheRegistry, st.session_state.agent_registry)
        st.session_state.cur_agent_ids = agent_registry.get_agent_ids()
        choices = ["Create a new agent"] + st.session_state.cur_agent_ids

        # Pre-select the current agent. Fall back to index 0 ("Create a new
        # agent") when nothing is selected, or when the selected id is no
        # longer present in the registry. In that case list.index() would
        # raise ValueError and crash the page.
        index = 0
        selected_id = st.session_state.get("selected_id")
        if selected_id is not None and selected_id in choices:
            index = choices.index(selected_id)

        # display buttons
        st.radio(
            "Agents",
            choices,
            index=index,
            on_change=update_selected_agent,
            key="agent_selector",
        )
103
class CurrentSessionState(BaseModel):
    """Typed snapshot of the agent-related values kept in Streamlit session state.

    Produced by ``get_current_state``. Each field mirrors the session-state
    entry of the same name, except ``cache``, which is ``agent_builder.cache``.
    """

    # Registry of stored agents.
    agent_registry: AgentCacheRegistry
    # Id of the selected agent; None when "Create a new agent" is chosen.
    selected_id: Optional[str]
    # Parameter cache of the selected agent; None when no agent is selected.
    selected_cache: Optional[ParamCache]
    # Builder returned by load_meta_agent_and_tools.
    agent_builder: BaseRAGAgentBuilder
    # The builder's own parameter cache (agent_builder.cache).
    cache: ParamCache
    # Agent returned by load_meta_agent_and_tools alongside agent_builder.
    builder_agent: BaseAgent

    class Config:
        # Accept field types pydantic cannot validate natively
        # (presumably the project classes annotated above).
        arbitrary_types_allowed = True
118
def get_current_state() -> CurrentSessionState:
    """Fill in any missing agent state in ``st.session_state`` and snapshot it.

    Each entry is only created when missing or ``None``, so this can be
    called on every Streamlit rerun:

    - ``agent_registry``: an ``AgentCacheRegistry`` rooted at ``AGENT_CACHE_DIR``
    - ``cur_agent_ids``: the registry's agent ids
    - ``selected_id``: ``None`` (no agent selected)
    - ``selected_cache``: the selected agent's cache loaded from the registry,
      or ``None`` when no agent is selected
    - ``builder_agent`` / ``agent_builder``: created by
      ``load_meta_agent_and_tools``, from ``selected_cache`` when one is set,
      otherwise from a new cache

    Returns:
        A ``CurrentSessionState`` built from the session values, with
        ``cache`` taken from ``agent_builder.cache``.
    """
    session = st.session_state

    # Agent registry: construct it only when missing, so a registry is not
    # built and immediately discarded on every rerun.
    if "agent_registry" not in session.keys():
        session.agent_registry = AgentCacheRegistry(str(AGENT_CACHE_DIR))
    agent_registry = cast(AgentCacheRegistry, session.agent_registry)

    if "cur_agent_ids" not in session.keys():
        session.cur_agent_ids = agent_registry.get_agent_ids()

    if "selected_id" not in session.keys():
        session.selected_id = None

    # Selected cache: (re)load from the registry when it is missing or None.
    if session.get("selected_cache") is None:
        if session.selected_id is None:
            session.selected_cache = None
        else:
            session.selected_cache = agent_registry.get_agent_cache(
                session.selected_id
            )

    # Builder agent / agent builder: rebuild when either is missing or None
    # (update_selected_agent_with_id resets both when the selection changes).
    if session.get("builder_agent") is None or session.get("agent_builder") is None:
        if session.selected_cache is not None:
            # create builder agent / tools from selected cache
            builder_agent, agent_builder = load_meta_agent_and_tools(
                cache=session.selected_cache,
                agent_registry=agent_registry,
                # NOTE: we will probably generalize this later into different
                # builder configs
                is_multimodal=get_cached_is_multimodal(),
            )
        else:
            # create builder agent / tools from new cache
            builder_agent, agent_builder = load_meta_agent_and_tools(
                agent_registry=agent_registry,
                is_multimodal=get_is_multimodal(),
            )

        session.builder_agent = builder_agent
        session.agent_builder = agent_builder

    return CurrentSessionState(
        agent_registry=session.agent_registry,
        selected_id=session.selected_id,
        selected_cache=session.selected_cache,
        agent_builder=session.agent_builder,
        cache=session.agent_builder.cache,
        builder_agent=session.builder_agent,
    )