chromadb

Форк
0
143 строки · 4.3 Кб
1
import argparse
2
import os
3
from typing import List
4

5
import google.generativeai as genai
6
import chromadb
7
from chromadb.utils import embedding_functions
8

9
model = genai.GenerativeModel("gemini-pro")
10

11

12
def build_prompt(query: str, context: List[str]) -> str:
13
    """
14
    Builds a prompt for the LLM. #
15

16
    This function builds a prompt for the LLM. It takes the original query,
17
    and the returned context, and asks the model to answer the question based only
18
    on what's in the context, not what's in its weights.
19

20
    Args:
21
    query (str): The original query.
22
    context (List[str]): The context of the query, returned by embedding search.
23

24
    Returns:
25
    A prompt for the LLM (str).
26
    """
27

28
    base_prompt = {
29
        "content": "I am going to ask you a question, which I would like you to answer"
30
        " based only on the provided context, and not any other information."
31
        " If there is not enough information in the context to answer the question,"
32
        ' say "I am not sure", then try to make a guess.'
33
        " Break your answer up into nicely readable paragraphs.",
34
    }
35
    user_prompt = {
36
        "content": f" The question is '{query}'. Here is all the context you have:"
37
        f'{(" ").join(context)}',
38
    }
39

40
    # combine the prompts to output a single prompt string
41
    system = f"{base_prompt['content']} {user_prompt['content']}"
42

43
    return system
44

45

46
def get_gemini_response(query: str, context: List[str]) -> str:
47
    """
48
    Queries the Gemini API to get a response to the question.
49

50
    Args:
51
    query (str): The original query.
52
    context (List[str]): The context of the query, returned by embedding search.
53

54
    Returns:
55
    A response to the question.
56
    """
57

58
    response = model.generate_content(build_prompt(query, context))
59

60
    return response.text
61

62

63
def main(
64
    collection_name: str = "documents_collection", persist_directory: str = "."
65
) -> None:
66
    # Check if the GOOGLE_API_KEY environment variable is set. Prompt the user to set it if not.
67
    google_api_key = None
68
    if "GOOGLE_API_KEY" not in os.environ:
69
        gapikey = input("Please enter your Google API Key: ")
70
        genai.configure(api_key=gapikey)
71
        google_api_key = gapikey
72
    else:
73
        google_api_key = os.environ["GOOGLE_API_KEY"]
74

75
    # Instantiate a persistent chroma client in the persist_directory.
76
    # This will automatically load any previously saved collections.
77
    # Learn more at docs.trychroma.com
78
    client = chromadb.PersistentClient(path=persist_directory)
79

80
    # create embedding function
81
    embedding_function = embedding_functions.GoogleGenerativeAIEmbeddingFunction(api_key=google_api_key, task_type="RETRIEVAL_QUERY")
82

83
    # Get the collection.
84
    collection = client.get_collection(
85
        name=collection_name, embedding_function=embedding_function
86
    )
87

88
    # We use a simple input loop.
89
    while True:
90
        # Get the user's query
91
        query = input("Query: ")
92
        if len(query) == 0:
93
            print("Please enter a question. Ctrl+C to Quit.\n")
94
            continue
95
        print("\nThinking...\n")
96

97
        # Query the collection to get the 5 most relevant results
98
        results = collection.query(
99
            query_texts=[query], n_results=5, include=["documents", "metadatas"]
100
        )
101

102
        sources = "\n".join(
103
            [
104
                f"{result['filename']}: line {result['line_number']}"
105
                for result in results["metadatas"][0]  # type: ignore
106
            ]
107
        )
108

109
        # Get the response from Gemini
110
        response = get_gemini_response(query, results["documents"][0])  # type: ignore
111

112
        # Output, with sources
113
        print(response)
114
        print("\n")
115
        print(f"Source documents:\n{sources}")
116
        print("\n")
117

118

119
if __name__ == "__main__":
120
    parser = argparse.ArgumentParser(
121
        description="Load documents from a directory into a Chroma collection"
122
    )
123

124
    parser.add_argument(
125
        "--persist_directory",
126
        type=str,
127
        default="chroma_storage",
128
        help="The directory where you want to store the Chroma collection",
129
    )
130
    parser.add_argument(
131
        "--collection_name",
132
        type=str,
133
        default="documents_collection",
134
        help="The name of the Chroma collection",
135
    )
136

137
    # Parse arguments
138
    args = parser.parse_args()
139

140
    main(
141
        collection_name=args.collection_name,
142
        persist_directory=args.persist_directory,
143
    )
144

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.