rags

Форк
0
/
st_utils.py 
191 строка · 6.2 Кб
1
"""Streamlit utils."""
2
from core.agent_builder.loader import (
3
    load_meta_agent_and_tools,
4
    AgentCacheRegistry,
5
)
6
from core.agent_builder.base import BaseRAGAgentBuilder
7
from core.param_cache import ParamCache
8
from core.constants import (
9
    AGENT_CACHE_DIR,
10
)
11
from typing import Optional, cast
12
from pydantic import BaseModel
13

14
from llama_index.agent.types import BaseAgent
15
import streamlit as st
16

17

18
def update_selected_agent_with_id(selected_id: Optional[str] = None) -> None:
    """Switch the session to the agent identified by ``selected_id``.

    The sentinel choice ``"Create a new agent"`` is stored as ``None``. The
    builder agent, agent builder and selected cache derived from the previous
    selection are reset to ``None`` so that ``get_current_state`` rebuilds
    them for the new selection.
    """
    new_id = None if selected_id == "Create a new agent" else selected_id
    st.session_state.selected_id = new_id

    # Drop every object derived from the old selection.
    for stale_key in ("builder_agent", "agent_builder", "selected_cache"):
        st.session_state[stale_key] = None
31

32

33
def update_selected_agent() -> None:
    """Widget ``on_change`` callback that applies the current agent selection.

    Used by the sidebar agent radio and by the multimodal checkbox. Reads the
    ``agent_selector`` widget value from session state and forwards it to
    ``update_selected_agent_with_id``.
    """
    update_selected_agent_with_id(st.session_state.agent_selector)
39

40

41
def get_cached_is_multimodal() -> bool:
    """Return whether the currently selected cached agent is multimodal.

    Returns ``False`` when no cache is selected (the ``selected_cache`` key
    is missing or ``None``). Otherwise returns whether the cache's
    ``builder_type`` equals ``"multimodal"``.
    """
    cached = st.session_state.get("selected_cache")
    if cached is None:
        return False
    param_cache = cast(ParamCache, cached)
    return bool(param_cache.builder_type == "multimodal")
52

53

54
def get_is_multimodal() -> bool:
    """Return the multimodal checkbox value stored in session state.

    Side effect: initialises ``is_multimodal_st`` to ``False`` the first
    time it is requested.
    """
    state = st.session_state
    if "is_multimodal_st" in state:
        return state.is_multimodal_st
    state.is_multimodal_st = False
    return state.is_multimodal_st
59

60

61
def add_builder_config() -> None:
    """Render the "Builder Config (Advanced)" expander.

    It currently contains one checkbox that toggles multimodal search.

    - The checkbox is bound to the ``is_multimodal_st`` session key.
    - It defaults to the selected cached agent's setting.
    - It is disabled while a cached agent is selected.
    - Changing it calls ``update_selected_agent``.
    """
    with st.expander("Builder Config (Advanced)"):
        # When an existing agent is selected, its builder type comes from its
        # cache, so the option is shown read-only.
        has_selected_cache = st.session_state.get("selected_cache") is not None

        st.checkbox(
            "Enable multimodal search (beta)",
            key="is_multimodal_st",
            on_change=update_selected_agent,
            value=get_cached_is_multimodal(),
            disabled=has_selected_cache,
        )
80

81

82
def add_sidebar() -> None:
    """Render the sidebar agent selector.

    Refreshes ``st.session_state.cur_agent_ids`` from the session's agent
    registry. It then shows a radio listing "Create a new agent" followed by
    every registered agent id.

    The pre-selected option is the session's ``selected_id`` when that id is
    among the listed choices; otherwise "Create a new agent" is selected.
    """
    with st.sidebar:
        agent_registry = cast(AgentCacheRegistry, st.session_state.agent_registry)
        st.session_state.cur_agent_ids = agent_registry.get_agent_ids()
        choices = ["Create a new agent"] + st.session_state.cur_agent_ids

        # Default to "Create a new agent" (index 0).
        # ``selected_id`` may name an agent that is no longer returned by the
        # registry (e.g. if it was removed). ``list.index`` would then raise
        # ValueError and break the page, so fall back to the default instead.
        index = 0
        selected_id = st.session_state.get("selected_id")
        if selected_id is not None and selected_id in choices:
            index = choices.index(selected_id)

        # Display the agent choices.
        st.radio(
            "Agents",
            choices,
            index=index,
            on_change=update_selected_agent,
            key="agent_selector",
        )
102

103

104
class CurrentSessionState(BaseModel):
    """Snapshot of the agent-builder state kept in Streamlit session state.

    Instances are produced by ``get_current_state``.
    """

    agent_registry: AgentCacheRegistry
    # ``None`` when "Create a new agent" is selected.
    selected_id: Optional[str]
    # Cache of the selected existing agent; ``None`` when none is selected.
    selected_cache: Optional[ParamCache]
    agent_builder: BaseRAGAgentBuilder
    # The agent builder's own cache (``agent_builder.cache``).
    cache: ParamCache
    builder_agent: BaseAgent

    class Config:
        # Allow field types that pydantic has no built-in validator for
        # (presumably the project / llama_index classes above).
        arbitrary_types_allowed = True
117

118

119
def get_current_state() -> CurrentSessionState:
    """Initialise missing session-state entries and return a state snapshot.

    Entries populated or derived:

    - ``agent_registry``: created once per session from ``AGENT_CACHE_DIR``.
    - ``cur_agent_ids``: agent ids known to the registry.
    - ``selected_id``: defaults to ``None`` ("Create a new agent").
    - ``selected_cache``: loaded from the registry for ``selected_id``.
    - ``builder_agent`` / ``agent_builder``: built via
      ``load_meta_agent_and_tools`` whenever either one is missing.

    Returns:
        A ``CurrentSessionState`` built from the values above.
    """
    state = st.session_state

    # Construct the registry only when the session has none yet. Streamlit
    # reruns the script on each interaction, so constructing it
    # unconditionally would create a throwaway registry on every rerun.
    if "agent_registry" not in state.keys():
        state.agent_registry = AgentCacheRegistry(str(AGENT_CACHE_DIR))
    agent_registry = cast(AgentCacheRegistry, state.agent_registry)

    if "cur_agent_ids" not in state.keys():
        state.cur_agent_ids = agent_registry.get_agent_ids()

    if "selected_id" not in state.keys():
        state.selected_id = None

    # Load the selected agent's cache if it is missing or was cleared.
    if state.get("selected_cache") is None:
        if state.selected_id is None:
            state.selected_cache = None
        else:
            state.selected_cache = agent_registry.get_agent_cache(
                state.selected_id
            )

    # Build the builder agent / agent builder if either is missing.
    if state.get("builder_agent") is None or state.get("agent_builder") is None:
        if state.selected_cache is not None:
            # Rebuild from the selected agent's cache.
            builder_agent, agent_builder = load_meta_agent_and_tools(
                cache=state.selected_cache,
                agent_registry=agent_registry,
                # NOTE: we will probably generalize this later into different
                # builder configs
                is_multimodal=get_cached_is_multimodal(),
            )
        else:
            # Start a new agent using the multimodal checkbox value.
            builder_agent, agent_builder = load_meta_agent_and_tools(
                agent_registry=agent_registry,
                is_multimodal=get_is_multimodal(),
            )

        state.builder_agent = builder_agent
        state.agent_builder = agent_builder

    return CurrentSessionState(
        agent_registry=state.agent_registry,
        selected_id=state.selected_id,
        selected_cache=state.selected_cache,
        agent_builder=state.agent_builder,
        cache=state.agent_builder.cache,
        builder_agent=state.builder_agent,
    )
192

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.