local-llm-with-rag / app.py
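"""
Command-line RAG demo: loads documents from a directory, splits them into
chunks, embeds them into a Chroma vector store with Ollama embeddings, and
answers questions interactively with a local LLM served by Ollama.
"""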
from langchain_community.llms import Ollama
from langchain_community.embeddings import OllamaEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from models import check_if_model_is_available
from document_loader import load_documents
import argparse
import sys

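# Split documents into ~1000-character chunks with 100 characters of overlap,
# so that related sentences are not cut apart at chunk boundaries.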
TEXT_SPLITTER = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

PROMPT_TEMPLATE = """
## Instruction:
You are a helpful assistant who answers questions in a clear and distinct way,
based on the provided research.

## Research:
{context}

## Question:
{question}
"""

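# Prompt used by RetrievalQA: {context} receives the retrieved chunks and
# {question} receives the user's query.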
PROMPT = PromptTemplate(
    template=PROMPT_TEMPLATE, input_variables=["context", "question"]
)

def load_documents_into_database(model_name: str, documents_path: str) -> Chroma:
    """
    Loads documents from the specified directory into a Chroma database
    after splitting the text into chunks.

    Args:
        model_name (str): The name of the embedding model to use.
        documents_path (str): The path to the directory containing documents.

    Returns:
        Chroma: The Chroma database with the loaded documents.
    """
    print("Loading documents")
    raw_documents = load_documents(documents_path)
    documents = TEXT_SPLITTER.split_documents(raw_documents)

    print("Creating embeddings and loading documents into Chroma")
    db = Chroma.from_documents(
        documents,
        OllamaEmbeddings(model=model_name),
    )
    return db
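# Note: without a persist_directory argument, Chroma keeps this index in
# memory only; to reuse the index across runs you could, for example, pass
# persist_directory="chroma_db" to Chroma.from_documents above.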

def main(llm_model_name: str, embedding_model_name: str, documents_path: str) -> None:
    # Check that the models are available; if not, attempt to pull them.
    try:
        check_if_model_is_available(llm_model_name)
        check_if_model_is_available(embedding_model_name)
    except Exception as e:
        print(e)
        sys.exit(1)

    # Create the vector database from the documents.
    try:
        db = load_documents_into_database(embedding_model_name, documents_path)
    except FileNotFoundError as e:
        print(e)
        sys.exit(1)

    llm = Ollama(
        model=llm_model_name,
        callbacks=[StreamingStdOutCallbackHandler()],  # stream tokens to stdout
    )

    qa_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=db.as_retriever(search_kwargs={"k": 8}),  # top-8 chunks per query
        chain_type_kwargs={"prompt": PROMPT},
    )

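    # Interactive loop: read a question and run it through the chain;
    # the answer streams to stdout via the callback handler.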
    while True:
        try:
            user_input = input(
                "\n\nPlease enter your question (or type 'exit' to end): "
            )
            if user_input.lower() == "exit":
                break

            qa_chain.invoke({"query": user_input})
        except KeyboardInterrupt:
            break

def parse_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run a local LLM with RAG, using Ollama.")
    parser.add_argument(
        "-m",
        "--model",
        default="mistral",
        help="The name of the LLM model to use.",
    )
    parser.add_argument(
        "-e",
        "--embedding_model",
        default="nomic-embed-text",
        help="The name of the embedding model to use.",
    )
    parser.add_argument(
        "-p",
        "--path",
        default="Research",
        help="The path to the directory containing documents to load.",
    )
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_arguments()
    main(args.model, args.embedding_model, args.path)
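# Example invocation (assumes a local Ollama server is running and the models
# below are available or can be pulled):
#
#   python app.py --model mistral --embedding_model nomic-embed-text --path Research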