rag-demystified

Форк
0
/
llama_index_baseline.py 
207 строк · 6.8 Кб
1
import os
from pathlib import Path

import requests
import tiktoken
from llama_index import (
    VectorStoreIndex,
    SummaryIndex,
    SimpleKeywordTableIndex,
    SimpleDirectoryReader,
    ServiceContext,
)
from llama_index.schema import IndexNode
from llama_index.tools import QueryEngineTool, ToolMetadata
from llama_index.llms import OpenAI, AzureOpenAI
from llama_index.query_engine import SubQuestionQueryEngine
from llama_index.agent import OpenAIAgent
from llama_index.embeddings import HuggingFaceEmbedding, OpenAIEmbedding
from llama_index.callbacks import CallbackManager, TokenCountingHandler
from llama_index.response_synthesizers import get_response_synthesizer

api_type = ""
23
api_base = ""
24
api_version = ""
25
api_key = ""
26

27

28
embed_model_name = "hugging_face"
29

30
if embed_model_name == "hugging_face":
31
    embed_model = HuggingFaceEmbedding(
32
        model_name="sentence-transformers/all-mpnet-base-v2", max_length=512
33
    )
34
elif embed_model_name == "text-embedding-ada-002":
35
    embed_model = OpenAIEmbedding(
36
        model="text-embedding-ada-002",
37
        deployment_name="text-embedding-ada-002",
38
        api_key=api_key,
39
        api_base=api_base,
40
        api_type=api_type,
41
        api_version=api_version,
42
    )
43

44
llm = AzureOpenAI(
45
    model="gpt-3.5-turbo",
46
    engine="gpt-35-turbo",
47
    api_key=api_key,
48
    api_base=api_base,
49
    api_type=api_type,
50
    api_version=api_version,
51
)
52

53
token_counter = TokenCountingHandler(
54
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
55
)
56

57
callback_manager = CallbackManager([token_counter])
58

59
service_context = ServiceContext.from_defaults(
60
    # system_prompt=system_prompt,
61
    llm=llm,
62
    callback_manager=callback_manager,
63
    embed_model=embed_model,
64
)
65

66

67
def print_token_count(token_counter, embed_model, model="gpt-35-turbo"):
68
    print(
69
        "Embedding Tokens: ",
70
        token_counter.total_embedding_token_count,
71
        "\n",
72
        "LLM Prompt Tokens: ",
73
        token_counter.prompt_llm_token_count,
74
        "\n",
75
        "LLM Completion Tokens: ",
76
        token_counter.completion_llm_token_count,
77
        "\n",
78
        "Total LLM Token Count: ",
79
        token_counter.total_llm_token_count,
80
        "\n",
81
    )
82
    pricing = {
83
        'gpt-35-turbo': {'prompt': 0.0015, 'completion': 0.002},
84
        'gpt-35-turbo-16k': {'prompt': 0.003, 'completion': 0.004},
85
        'gpt-4-0613': {'prompt': 0.03, 'completion': 0.06},
86
        'gpt-4-32k': {'prompt': 0.06, 'completion': 0.12},
87
        'embedding': {'hugging_face': 0, 'text-embedding-ada-002': 0.0001}
88
    }
89
    print(
90
        "Embedding Cost: ",
91
        pricing['embedding'][embed_model] * token_counter.total_embedding_token_count/1000,
92
        "\n",
93
        "LLM Prompt Cost: ",
94
        pricing[model]["prompt"] * token_counter.prompt_llm_token_count/1000,
95
        "\n",
96
        "LLM Completion Cost: ",
97
        pricing[model]["completion"] * token_counter.completion_llm_token_count/1000,
98
        "\n",
99
        "Total LLM Cost: ",
100
        pricing[model]["prompt"] * token_counter.prompt_llm_token_count/1000 + pricing[model]["completion"] * token_counter.completion_llm_token_count/1000,
101
        "\n",
102
        "Total cost: ",
103
        pricing['embedding'][embed_model] * token_counter.total_embedding_token_count/1000 + pricing[model]["prompt"] * token_counter.prompt_llm_token_count/1000 + pricing[model]["completion"] * token_counter.completion_llm_token_count/1000,
104
    )
105

106

107
if __name__ == "__main__":
108
    wiki_titles = ["Toronto", "Chicago", "Houston", "Boston", "Atlanta"]
109

110
    for title in wiki_titles:
111
        response = requests.get(
112
            "https://en.wikipedia.org/w/api.php",
113
            params={
114
                "action": "query",
115
                "format": "json",
116
                "titles": title,
117
                "prop": "extracts",
118
                # 'exintro': True,
119
                "explaintext": True,
120
            },
121
        ).json()
122
        page = next(iter(response["query"]["pages"].values()))
123
        wiki_text = page["extract"]
124

125
        data_path = Path("data")
126
        if not data_path.exists():
127
            Path.mkdir(data_path)
128

129
        with open(data_path / f"{title}.txt", "w") as fp:
130
            fp.write(wiki_text)
131

132
    # Load all wiki documents
133
    city_docs = {}
134
    for wiki_title in wiki_titles:
135
        city_docs[wiki_title] = SimpleDirectoryReader(
136
            input_files=[f"data/{wiki_title}.txt"]
137
        ).load_data()
138

139
    # # Build agents dictionary
140
    # agents = {}
141

142
    query_engine_tools = []
143
    for wiki_title in wiki_titles:
144
        # build vector index
145
        vector_index = VectorStoreIndex.from_documents(
146
            city_docs[wiki_title], service_context=service_context
147
        )
148
        # build summary index
149
        summary_index = SummaryIndex.from_documents(
150
            city_docs[wiki_title], service_context=service_context
151
        )
152
        # define query engines
153
        vector_query_engine = vector_index.as_query_engine()
154
        list_query_engine = summary_index.as_query_engine()
155

156
        # define tools
157
        query_engine_tools_per_doc = [
158
            QueryEngineTool(
159
                query_engine=vector_query_engine,
160
                metadata=ToolMetadata(
161
                    name=f"vector_tool_{wiki_title}",
162
                    description="Useful for questions related to specific aspects of"
163
                                f" {wiki_title} (e.g. the history, arts and culture,"
164
                                " sports, demographics, or more).",
165
                ),
166
            ),
167
            QueryEngineTool(
168
                query_engine=list_query_engine,
169
                metadata=ToolMetadata(
170
                    name=f"summary_tool_{wiki_title}",
171
                    description="Useful for any requests that require a holistic summary"
172
                                f" of EVERYTHING about {wiki_title}. For questions about"
173
                                " more specific sections, please use the"
174
                                f" vector_tool_{wiki_title}.",
175
                ),
176
            ),
177
        ]
178

179
        query_engine_tools.extend(query_engine_tools_per_doc)
180

181
        # build agent
182
        # function_llm = OpenAI(model="gpt-3.5-turbo-0613")
183
        # agent = OpenAIAgent.from_tools(
184
        #     query_engine_tools,
185
        #     llm=llm,
186
        #     verbose=True,
187
        # )
188

189
        # agents[wiki_title] = agent
190

191
    response_synthesizer = get_response_synthesizer(
192
        service_context=service_context,
193
        response_mode="compact",
194
    )
195

196
    sub_query_engine = SubQuestionQueryEngine.from_defaults(
197
        query_engine_tools=query_engine_tools,
198
        response_synthesizer=response_synthesizer,
199
        service_context=service_context,
200
        use_async=False,
201
        verbose=True,
202
    )
203

204
    question = "Which are the sports teams in Toronto?"
205
    print("Question: ", question)
206
    response = sub_query_engine.query(question)
207
    print_token_count(token_counter, embed_model_name)
208

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.