build-your-own-rag-chatbot
/
app_colab.py
156 lines · 6.0 KB
1import streamlit as st2import tempfile, os3from langchain_openai import OpenAIEmbeddings4from langchain_openai import ChatOpenAI5from langchain_community.vectorstores import AstraDB6from langchain.schema.runnable import RunnableMap7from langchain.prompts import ChatPromptTemplate8from langchain.callbacks.base import BaseCallbackHandler9from langchain.text_splitter import RecursiveCharacterTextSplitter10from langchain_community.document_loaders import PyPDFLoader11
# Streaming callback handler that renders partial LLM output as it arrives
class StreamHandler(BaseCallbackHandler):
    """Streams LLM tokens into a Streamlit container, redrawing on each token."""

    def __init__(self, container, initial_text=""):
        # `container` is a Streamlit placeholder (e.g. st.empty()) we write into.
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs):
        # Accumulate the new token, then redraw the whole message followed by
        # a block cursor so the user sees the answer is still being generated.
        self.text = self.text + token
        self.container.markdown(self.text + "▌")
# Function for vectorizing uploaded data into Astra DB
def vectorize_text(uploaded_files, vector_store):
    """Split each uploaded text file into chunks and index them in Astra DB.

    Args:
        uploaded_files: iterable of Streamlit UploadedFile objects (entries
            may be None when nothing was selected).
        vector_store: vector store exposing add_documents(documents).
    """
    for uploaded_file in uploaded_files:
        if uploaded_file is None:
            continue

        # Write the raw upload to a temporary file. Fix: use the context
        # manager so the temp directory is removed deterministically — the
        # original created TemporaryDirectory() without `with` and relied on
        # garbage collection for cleanup.
        with tempfile.TemporaryDirectory() as temp_dir:
            print(f"""Processing: {uploaded_file}""")
            temp_filepath = os.path.join(temp_dir, uploaded_file.name)
            with open(temp_filepath, 'wb') as f:
                f.write(uploaded_file.getvalue())

            # Process TXT uploads: decode, chunk, and index
            if uploaded_file.name.endswith('txt'):
                contents = [uploaded_file.read().decode()]

                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=1500,
                    chunk_overlap=100
                )

                # Attach the file name as source metadata on every chunk
                texts = text_splitter.create_documents(contents, [{'source': uploaded_file.name}])
                vector_store.add_documents(texts)
                st.info(f"Loaded {len(texts)} chunks")
# Cache prompt for future runs
@st.cache_data()
def load_prompt():
    """Build the system prompt template for the RAG chain (cached).

    Returns:
        ChatPromptTemplate expecting 'context' and 'question' inputs.
    """
    # Fix: corrected the misspelling "assistent" -> "assistant" in the prompt.
    template = """You're a helpful AI assistant tasked to answer the user's questions.
You're friendly and you answer extensively with multiple sentences. You prefer to use bulletpoints to summarize.

CONTEXT:
{context}

QUESTION:
{question}

YOUR ANSWER:"""
    return ChatPromptTemplate.from_messages([("system", template)])
# Cache OpenAI Chat Model for future runs
@st.cache_resource()
def load_chat_model(openai_api_key):
    """Return a streaming gpt-3.5-turbo chat model (cached per API key)."""
    chat_model = ChatOpenAI(
        openai_api_key=openai_api_key,
        temperature=0.3,
        model='gpt-3.5-turbo',
        streaming=True,   # stream tokens so StreamHandler can render them live
        verbose=True
    )
    return chat_model
# Cache the Astra DB Vector Store for future runs
@st.cache_resource(show_spinner='Connecting to Astra DB Vector Store')
def load_vector_store(_astra_db_endpoint, astra_db_secret, openai_api_key):
    """Connect to the Astra DB vector store with OpenAI embeddings (cached).

    The leading underscore on `_astra_db_endpoint` tells Streamlit to exclude
    that argument from the cache key hashing.
    """
    # Fix: the original body referenced the module-level global
    # `astra_db_endpoint` instead of the `_astra_db_endpoint` parameter —
    # it only worked by accident because this file is a flat script that
    # defines a global of that name before calling this function.
    vector_store = AstraDB(
        embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
        collection_name="my_store",
        api_endpoint=_astra_db_endpoint,
        token=astra_db_secret
    )
    return vector_store
# Cache the Retriever for future runs
@st.cache_resource(show_spinner='Getting retriever')
def load_retriever(_vector_store):
    """Return a retriever over the vector store fetching the top 5 chunks.

    The leading underscore on `_vector_store` tells Streamlit not to hash
    this (unhashable) argument for the cache key.
    """
    # Fix: the original body referenced the module-level global
    # `vector_store` instead of the `_vector_store` parameter; it only
    # worked because the script happens to define such a global before
    # this function is called.
    retriever = _vector_store.as_retriever(
        search_kwargs={"k": 5}
    )
    return retriever
# Initialize the chat history in session state on the first run
if 'messages' not in st.session_state:
    st.session_state.messages = []

# Page title and introductory blurb
st.title("Your personal Efficiency Booster")
st.markdown("""Generative AI is considered to bring the next Industrial Revolution.
Why? Studies show a **37% efficiency boost** in day to day work activities!""")

# Collect the three secrets from the sidebar, masked like passwords.
# NOTE: these global names are read elsewhere in the file — keep them as-is.
astra_db_endpoint = st.sidebar.text_input('Astra DB Endpoint', type="password")
astra_db_secret = st.sidebar.text_input('Astra DB Secret', type="password")
openai_api_key = st.sidebar.text_input('OpenAI API Key', type="password")

# Redraw the whole conversation so far on every Streamlit rerun
for msg in st.session_state.messages:
    st.chat_message(msg['role']).markdown(msg['content'])
# Gate the chat UI until all three credentials look plausible
credentials_ok = (
    openai_api_key.startswith('sk-')
    and astra_db_endpoint.startswith('https')
    and astra_db_secret.startswith('AstraCS')
)

if not credentials_ok:
    st.warning('Please enter your Astra DB Endpoint, Astra DB Secret and Open AI API Key!', icon='⚠')
else:
    prompt = load_prompt()
    chat_model = load_chat_model(openai_api_key)
    # NOTE: keep this global named `vector_store` — load_retriever's body
    # references it by that name.
    vector_store = load_vector_store(astra_db_endpoint, astra_db_secret, openai_api_key)
    retriever = load_retriever(vector_store)

    # Sidebar form to upload and vectorize additional context documents
    with st.sidebar:
        st.divider()
        uploads = st.file_uploader('Upload a document for additional context', type=['txt'], accept_multiple_files=True)
        if st.button('Save to Astra DB'):
            vectorize_text(uploads, vector_store)

    if question := st.chat_input("What's up?"):
        # Remember the user's question so reruns can redraw it
        st.session_state.messages.append({"role": "human", "content": question})

        # Echo the question immediately
        with st.chat_message('human'):
            st.markdown(question)

        # Placeholder the streaming handler fills token by token
        with st.chat_message('assistant'):
            response_placeholder = st.empty()

        # RAG chain: retrieve relevant chunks, fill the prompt, stream the answer
        inputs = RunnableMap({
            'context': lambda x: retriever.get_relevant_documents(x['question']),
            'question': lambda x: x['question']
        })
        chain = inputs | prompt | chat_model
        response = chain.invoke({'question': question}, config={'callbacks': [StreamHandler(response_placeholder)]})
        answer = response.content

        # Remember the assistant's answer for future reruns
        st.session_state.messages.append({"role": "ai", "content": answer})

        # Replace the streamed text (which ends with the cursor glyph) by the final answer
        response_placeholder.markdown(answer)