# ollama — local RAG over a PDF with LangChain (Chroma + GPT4All embeddings + Ollama llama2).
# Standard library
import os
import sys

# Third-party: LangChain components for the RAG pipeline
from langchain import PromptTemplate
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.document_loaders import OnlinePDFLoader
from langchain.embeddings import GPT4AllEmbeddings
from langchain.llms import Ollama
from langchain.vectorstores import Chroma
class SuppressStdout:
    """Context manager that silences stdout and stderr for its duration.

    On entry, both streams are redirected to ``os.devnull``; on exit the
    originals are restored and the devnull handle is closed.

    Fixes a leak in the original: it opened *two* devnull files and closed
    only the one bound to ``sys.stdout``, leaking the stderr handle each
    time the context manager was used. A single shared handle is opened
    and closed exactly once.
    """

    def __enter__(self):
        # Save the real streams so __exit__ can restore them.
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        # One devnull handle serves both streams.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        sys.stderr = self._devnull

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first, then close, so nothing writes to a closed file.
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr
        self._devnull.close()
# Fetch the PDF and break it into retrieval-sized chunks.
pdf_loader = OnlinePDFLoader("https://d18rn0p25nwr6d.cloudfront.net/CIK-0001813756/975b3e9b-268e-4798-a9e4-2a9a7c92dc10.pdf")
documents = pdf_loader.load()

from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = splitter.split_documents(documents)

# Index the chunks in Chroma; the embedding backends print noisy progress
# output, so build the store with both streams suppressed.
with SuppressStdout():
    vectorstore = Chroma.from_documents(documents=all_splits, embedding=GPT4AllEmbeddings())
# Prompt that constrains the model to the retrieved context and asks for
# short answers. Built ONCE, before the loop: in the original, the prompt,
# the LLM client, and the QA chain were all reconstructed on every user
# query even though none of them depends on the query — pure wasted work.
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use three sentences maximum and keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=template,
)

# Streaming callback prints tokens to stdout as the model generates them.
llm = Ollama(
    model="llama2:13b",
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
)
qa_chain = RetrievalQA.from_chain_type(
    llm,
    retriever=vectorstore.as_retriever(),
    chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
)

# Interactive REPL: "exit" quits, blank input is ignored.
while True:
    query = input("\nQuery: ")
    if query == "exit":
        break
    if query.strip() == "":
        continue

    # Answer is streamed to stdout by the callback handler.
    result = qa_chain({"query": query})