build-your-own-rag-chatbot

Форк
0
156 строк · 6.0 Кб
1
import streamlit as st
2
import tempfile, os
3
from langchain_openai import OpenAIEmbeddings
4
from langchain_openai import ChatOpenAI
5
from langchain_community.vectorstores import AstraDB
6
from langchain.schema.runnable import RunnableMap
7
from langchain.prompts import ChatPromptTemplate
8
from langchain.callbacks.base import BaseCallbackHandler
9
from langchain.text_splitter import RecursiveCharacterTextSplitter
10
from langchain_community.document_loaders import PyPDFLoader
11

12
# LangChain callback that streams model output into a Streamlit element
class StreamHandler(BaseCallbackHandler):
    """Render LLM tokens into a Streamlit container as they arrive.

    Every new token is appended to an internal text buffer, and the whole
    buffer is re-rendered as Markdown followed by a "▌" cursor glyph. The
    cursor shows that the answer is still being produced.
    """

    def __init__(self, container, initial_text=""):
        # container: any Streamlit element with a .markdown() method,
        # e.g. the st.empty() placeholder created by the chat UI.
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs):
        """Append *token* to the buffer and redraw it with a trailing cursor."""
        self.text = self.text + token
        cursor = "▌"
        self.container.markdown(self.text + cursor)
21

22
# Function for vectorizing uploaded data into Astra DB
def vectorize_text(uploaded_files, vector_store):
    """Split uploaded text files into chunks and add them to *vector_store*.

    Args:
        uploaded_files: Iterable of Streamlit uploaded-file objects. ``None``
            entries are skipped, and ``None`` itself is treated as "no files".
        vector_store: LangChain vector store that receives the chunks via
            ``add_documents``.

    Only files whose name ends in ``.txt`` (case-insensitive) are processed;
    any other file is skipped. Each file's text is split with
    ``chunk_size=1500`` and ``chunk_overlap=100``. Every chunk is tagged with
    ``{'source': <file name>}`` metadata, and an ``st.info`` message reports
    how many chunks were loaded for that file.
    """
    # The splitter settings do not depend on the file, so build it once.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1500,
        chunk_overlap=100,
    )

    for uploaded_file in uploaded_files or []:
        if uploaded_file is None:
            continue
        print(f"Processing: {uploaded_file}")

        # Require a real ".txt" extension; a bare endswith('txt') would also
        # accept names such as "notestxt".
        if not uploaded_file.name.lower().endswith('.txt'):
            continue

        # The text is decoded in memory, so no temporary file is needed.
        # getvalue() returns the whole upload regardless of the current
        # stream position, whereas read() returns b"" once the stream has
        # been consumed.
        text = uploaded_file.getvalue().decode()

        texts = text_splitter.create_documents([text], [{'source': uploaded_file.name}])
        vector_store.add_documents(texts)
        st.info(f"Loaded {len(texts)} chunks")
47

48
# Cache prompt for future runs
@st.cache_data()
def load_prompt():
    """Return the RAG chat prompt, cached across Streamlit reruns.

    The template is sent as a single system message and has two input
    variables: ``{context}`` for the retrieved document text and
    ``{question}`` for the user's question.
    """
    template = """You're a helpful AI assistant tasked with answering the user's questions.
You're friendly and you answer extensively with multiple sentences. You prefer to use bullet points to summarize.

CONTEXT:
{context}

QUESTION:
{question}

YOUR ANSWER:"""
    return ChatPromptTemplate.from_messages([("system", template)])
62

63
# Build the OpenAI chat model once per API key and reuse it on later reruns
@st.cache_resource()
def load_chat_model(openai_api_key):
    """Create the streaming ``gpt-3.5-turbo`` chat model.

    ``st.cache_resource`` keys the cache on *openai_api_key*, so every rerun
    that uses the same key gets the same model object back.
    """
    model_settings = dict(
        openai_api_key=openai_api_key,
        temperature=0.3,
        model='gpt-3.5-turbo',
        streaming=True,  # needed so StreamHandler receives tokens one by one
        verbose=True,
    )
    return ChatOpenAI(**model_settings)
73

74
# Cache the Astra DB Vector Store for future runs
@st.cache_resource(show_spinner='Connecting to Astra DB Vector Store')
def load_vector_store(_astra_db_endpoint, astra_db_secret, openai_api_key):
    """Connect to the ``my_store`` Astra DB collection as a LangChain vector store.

    Args:
        _astra_db_endpoint: Astra DB API endpoint URL.
        astra_db_secret: Astra DB application token.
        openai_api_key: OpenAI key used by the embedding model.

    Returns:
        An ``AstraDB`` vector store that embeds text with ``OpenAIEmbeddings``.

    NOTE(review): Streamlit does not hash parameters whose names start with
    an underscore. The endpoint is therefore not part of the cache key, so
    changing only the endpoint returns the previously cached store. Renaming
    the parameter would fix this but changes the keyword interface.
    TODO confirm.
    """
    # Use the argument rather than a module-level global, so the function
    # connects to whatever endpoint its caller passes in.
    return AstraDB(
        embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
        collection_name="my_store",
        api_endpoint=_astra_db_endpoint,
        token=astra_db_secret,
    )
85

86
# Cache the Retriever for future runs
@st.cache_resource(show_spinner='Getting retriever')
def load_retriever(_vector_store):
    """Return a retriever over *_vector_store* that fetches the 5 best matches.

    The leading underscore tells ``st.cache_resource`` not to hash the
    vector store argument.

    NOTE(review): with no hashed arguments, the first retriever built is
    cached and returned for every later call, even if a different store is
    passed in. TODO confirm this is acceptable when credentials change.
    """
    # Use the argument rather than the module-level `vector_store` global.
    return _vector_store.as_retriever(
        search_kwargs={"k": 5}
    )
94

95
# Keep the chat history in session state so it survives Streamlit reruns;
# a new session starts with an empty history.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Page header
st.title("Your personal Efficiency Booster")
st.markdown(
    "Generative AI is considered to bring the next Industrial Revolution.  \n"
    "Why? Studies show a **37% efficiency boost** in day to day work activities!"
)

# Credentials, entered in the sidebar as masked inputs
astra_db_endpoint = st.sidebar.text_input('Astra DB Endpoint', type="password")
astra_db_secret = st.sidebar.text_input('Astra DB Secret', type="password")
openai_api_key = st.sidebar.text_input('OpenAI API Key', type="password")

# Replay every stored turn (user and bot) on each rerun
for entry in st.session_state.messages:
    with st.chat_message(entry['role']):
        st.markdown(entry['content'])
112

113
def _format_docs(docs):
    """Join the text of retrieved LangChain documents into one context string.

    Without this, the prompt would receive the repr of a list of
    ``Document`` objects instead of plain text.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Only build the RAG pipeline once all three credentials look plausible.
# These are prefix checks only; the credentials are not validated here.
if not openai_api_key.startswith('sk-') or not astra_db_endpoint.startswith('https') or not astra_db_secret.startswith('AstraCS'):
    st.warning('Please enter your Astra DB Endpoint, Astra DB Secret and Open AI API Key!', icon='⚠')

else:
    prompt = load_prompt()
    chat_model = load_chat_model(openai_api_key)
    vector_store = load_vector_store(astra_db_endpoint, astra_db_secret, openai_api_key)
    retriever = load_retriever(vector_store)

    # RAG chain: retrieve context for the question, fill the prompt, call the
    # model. It is built once per rerun instead of once per question.
    inputs = RunnableMap({
        # retriever.invoke replaces the deprecated get_relevant_documents.
        'context': lambda x: _format_docs(retriever.invoke(x['question'])),
        'question': lambda x: x['question'],
    })
    chain = inputs | prompt | chat_model

    # Sidebar upload form for new data to be vectorized into Astra DB
    with st.sidebar:
        st.divider()
        uploaded_files = st.file_uploader('Upload a document for additional context', type=['txt'], accept_multiple_files=True)
        if st.button('Save to Astra DB'):
            if uploaded_files:
                vectorize_text(uploaded_files, vector_store)
            else:
                st.warning('Please upload at least one .txt file first.')

    if question := st.chat_input("What's up?"):
        # Store the user's question so it is redrawn on the next rerun
        st.session_state.messages.append({"role": "human", "content": question})

        # Draw the user's question
        with st.chat_message('human'):
            st.markdown(question)

        # Placeholder inside the assistant bubble; StreamHandler fills it
        # token by token while the model streams.
        with st.chat_message('assistant'):
            response_placeholder = st.empty()

        # Generate the answer by calling OpenAI's chat model
        response = chain.invoke(
            {'question': question},
            config={'callbacks': [StreamHandler(response_placeholder)]},
        )
        answer = response.content

        # Store the bot's answer so it is redrawn on the next rerun
        st.session_state.messages.append({"role": "ai", "content": answer})

        # Replace the streamed text with the final answer, without the cursor
        response_placeholder.markdown(answer)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.