import argparse
import os
from typing import List

import chromadb
import google.generativeai as genai
from chromadb.utils import embedding_functions

# Single shared Gemini chat model used to answer queries.
# NOTE(review): "gemini-pro" is a legacy model alias — confirm it is still
# served before deploying; newer deployments use "gemini-1.5-flash"/"-pro".
model = genai.GenerativeModel("gemini-pro")
def build_prompt(query: str, context: List[str]) -> str:
    """Build a prompt for the LLM.

    Combines the original query with the retrieved context and asks the
    model to answer the question based only on what's in the context, not
    what's in its weights.

    Args:
        query (str): The original query.
        context (List[str]): The context of the query, returned by
            embedding search.

    Returns:
        str: A prompt for the LLM.
    """
    # Instruction portion: constrain the model to the provided context and
    # tell it how to behave when the context is insufficient.
    base_prompt = {
        "content": "I am going to ask you a question, which I would like you to answer"
        " based only on the provided context, and not any other information."
        " If there is not enough information in the context to answer the question,"
        ' say "I am not sure", then try to make a guess.'
        " Break your answer up into nicely readable paragraphs.",
    }
    # User portion: the question plus all retrieved context snippets,
    # joined with spaces into a single string.
    user_prompt = {
        "content": f" The question is '{query}'. Here is all the context you have:"
        f'{(" ").join(context)}',
    }

    # Gemini takes one flat text prompt, so concatenate both parts.
    system = f"{base_prompt['content']} {user_prompt['content']}"

    return system
def get_gemini_response(query: str, context: List[str]) -> str:
    """Query the Gemini API to get a response to the question.

    Args:
        query (str): The original query.
        context (List[str]): The context of the query, returned by
            embedding search.

    Returns:
        str: The model's text response to the question.
    """
    # Network call to the Gemini API; raises on auth/quota errors.
    response = model.generate_content(build_prompt(query, context))
    # .text extracts the plain-text answer from the response object.
    return response.text
def main(
    collection_name: str = "documents_collection", persist_directory: str = "."
) -> None:
    """Run an interactive query loop against a persisted Chroma collection.

    Prompts the user for questions on stdin, retrieves the top matching
    documents from the collection, asks Gemini to answer using only that
    context, and prints the answer plus the source documents.

    Args:
        collection_name (str): Name of the Chroma collection to query.
        persist_directory (str): Directory where the collection is stored.
    """
    # Resolve the Google API key: prompt interactively if it isn't in the
    # environment, and configure the genai client with it.
    if "GOOGLE_API_KEY" not in os.environ:
        gapikey = input("Please enter your Google API Key: ")
        genai.configure(api_key=gapikey)
        google_api_key = gapikey
    else:
        google_api_key = os.environ["GOOGLE_API_KEY"]

    # Instantiate a persistent Chroma client pointing at the stored data.
    client = chromadb.PersistentClient(path=persist_directory)

    # RETRIEVAL_QUERY task type: embeddings optimized for query-side search.
    embedding_function = embedding_functions.GoogleGenerativeAIEmbeddingFunction(
        api_key=google_api_key, task_type="RETRIEVAL_QUERY"
    )

    collection = client.get_collection(
        name=collection_name, embedding_function=embedding_function
    )

    # Interactive loop; exit with Ctrl+C.
    while True:
        query = input("Query: ")
        if not query:
            print("Please enter a question. Ctrl+C to Quit.\n")
            continue
        print("\nThinking...\n")

        # Retrieve the 5 nearest documents and their metadata.
        results = collection.query(
            query_texts=[query], n_results=5, include=["documents", "metadatas"]
        )

        # One "filename: line N" entry per retrieved document; metadatas is
        # a list per query, so [0] selects results for our single query.
        sources = "\n".join(
            f"{result['filename']}: line {result['line_number']}"
            for result in results["metadatas"][0]
        )

        response = get_gemini_response(query, results["documents"][0])

        print(response)
        print("\n")
        print(f"Source documents:\n{sources}")
        print("\n")
if __name__ == "__main__":
    # CLI entry point: read collection location/name from arguments and
    # start the interactive query loop.
    parser = argparse.ArgumentParser(
        description="Load documents from a directory into a Chroma collection"
    )

    parser.add_argument(
        "--persist_directory",
        type=str,
        default="chroma_storage",
        help="The directory where you want to store the Chroma collection",
    )

    parser.add_argument(
        "--collection_name",
        type=str,
        default="documents_collection",
        help="The name of the Chroma collection",
    )

    args = parser.parse_args()

    main(
        collection_name=args.collection_name,
        persist_directory=args.persist_directory,
    )