rag-demystified
/
llama_index_baseline.py
207 lines · 6.8 KB
1from pathlib import Path2
3import requests4
5from llama_index import (6VectorStoreIndex,7SummaryIndex,8SimpleKeywordTableIndex,9SimpleDirectoryReader,10ServiceContext,11)
12from llama_index.schema import IndexNode13from llama_index.tools import QueryEngineTool, ToolMetadata14from llama_index.llms import OpenAI, AzureOpenAI15from llama_index.query_engine import SubQuestionQueryEngine16from llama_index.agent import OpenAIAgent17from llama_index.embeddings import HuggingFaceEmbedding, OpenAIEmbedding18from llama_index.callbacks import CallbackManager, TokenCountingHandler19from llama_index.response_synthesizers import get_response_synthesizer20import tiktoken21
# --- Azure OpenAI credentials / endpoint (fill in before running) ---
api_type = ""
api_base = ""
api_version = ""
api_key = ""


# Embedding backend selector: "hugging_face" runs a free local
# sentence-transformers model; "text-embedding-ada-002" uses Azure OpenAI
# (billed per token — see the pricing table in print_token_count).
embed_model_name = "hugging_face"

if embed_model_name == "hugging_face":
    embed_model = HuggingFaceEmbedding(
        model_name="sentence-transformers/all-mpnet-base-v2", max_length=512
    )
elif embed_model_name == "text-embedding-ada-002":
    embed_model = OpenAIEmbedding(
        model="text-embedding-ada-002",
        deployment_name="text-embedding-ada-002",
        api_key=api_key,
        api_base=api_base,
        api_type=api_type,
        api_version=api_version,
    )

# LLM used both for sub-question decomposition and answer synthesis.
# NOTE(review): "model" is the OpenAI model name, "engine" the Azure
# deployment name — they are allowed to differ.
llm = AzureOpenAI(
    model="gpt-3.5-turbo",
    engine="gpt-35-turbo",
    api_key=api_key,
    api_base=api_base,
    api_type=api_type,
    api_version=api_version,
)

# Count embedding/LLM tokens during the run so the cost can be reported at
# the end (see print_token_count).
token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
)

callback_manager = CallbackManager([token_counter])

# Shared context injected into every index/query engine built below.
service_context = ServiceContext.from_defaults(
    # system_prompt=system_prompt,
    llm=llm,
    callback_manager=callback_manager,
    embed_model=embed_model,
)
66
67def print_token_count(token_counter, embed_model, model="gpt-35-turbo"):68print(69"Embedding Tokens: ",70token_counter.total_embedding_token_count,71"\n",72"LLM Prompt Tokens: ",73token_counter.prompt_llm_token_count,74"\n",75"LLM Completion Tokens: ",76token_counter.completion_llm_token_count,77"\n",78"Total LLM Token Count: ",79token_counter.total_llm_token_count,80"\n",81)82pricing = {83'gpt-35-turbo': {'prompt': 0.0015, 'completion': 0.002},84'gpt-35-turbo-16k': {'prompt': 0.003, 'completion': 0.004},85'gpt-4-0613': {'prompt': 0.03, 'completion': 0.06},86'gpt-4-32k': {'prompt': 0.06, 'completion': 0.12},87'embedding': {'hugging_face': 0, 'text-embedding-ada-002': 0.0001}88}89print(90"Embedding Cost: ",91pricing['embedding'][embed_model] * token_counter.total_embedding_token_count/1000,92"\n",93"LLM Prompt Cost: ",94pricing[model]["prompt"] * token_counter.prompt_llm_token_count/1000,95"\n",96"LLM Completion Cost: ",97pricing[model]["completion"] * token_counter.completion_llm_token_count/1000,98"\n",99"Total LLM Cost: ",100pricing[model]["prompt"] * token_counter.prompt_llm_token_count/1000 + pricing[model]["completion"] * token_counter.completion_llm_token_count/1000,101"\n",102"Total cost: ",103pricing['embedding'][embed_model] * token_counter.total_embedding_token_count/1000 + pricing[model]["prompt"] * token_counter.prompt_llm_token_count/1000 + pricing[model]["completion"] * token_counter.completion_llm_token_count/1000,104)105
106
107if __name__ == "__main__":108wiki_titles = ["Toronto", "Chicago", "Houston", "Boston", "Atlanta"]109
110for title in wiki_titles:111response = requests.get(112"https://en.wikipedia.org/w/api.php",113params={114"action": "query",115"format": "json",116"titles": title,117"prop": "extracts",118# 'exintro': True,119"explaintext": True,120},121).json()122page = next(iter(response["query"]["pages"].values()))123wiki_text = page["extract"]124
125data_path = Path("data")126if not data_path.exists():127Path.mkdir(data_path)128
129with open(data_path / f"{title}.txt", "w") as fp:130fp.write(wiki_text)131
132# Load all wiki documents133city_docs = {}134for wiki_title in wiki_titles:135city_docs[wiki_title] = SimpleDirectoryReader(136input_files=[f"data/{wiki_title}.txt"]137).load_data()138
139# # Build agents dictionary140# agents = {}141
142query_engine_tools = []143for wiki_title in wiki_titles:144# build vector index145vector_index = VectorStoreIndex.from_documents(146city_docs[wiki_title], service_context=service_context147)148# build summary index149summary_index = SummaryIndex.from_documents(150city_docs[wiki_title], service_context=service_context151)152# define query engines153vector_query_engine = vector_index.as_query_engine()154list_query_engine = summary_index.as_query_engine()155
156# define tools157query_engine_tools_per_doc = [158QueryEngineTool(159query_engine=vector_query_engine,160metadata=ToolMetadata(161name=f"vector_tool_{wiki_title}",162description="Useful for questions related to specific aspects of"163f" {wiki_title} (e.g. the history, arts and culture,"164" sports, demographics, or more).",165),166),167QueryEngineTool(168query_engine=list_query_engine,169metadata=ToolMetadata(170name=f"summary_tool_{wiki_title}",171description="Useful for any requests that require a holistic summary"172f" of EVERYTHING about {wiki_title}. For questions about"173" more specific sections, please use the"174f" vector_tool_{wiki_title}.",175),176),177]178
179query_engine_tools.extend(query_engine_tools_per_doc)180
181# build agent182# function_llm = OpenAI(model="gpt-3.5-turbo-0613")183# agent = OpenAIAgent.from_tools(184# query_engine_tools,185# llm=llm,186# verbose=True,187# )188
189# agents[wiki_title] = agent190
191response_synthesizer = get_response_synthesizer(192service_context=service_context,193response_mode="compact",194)195
196sub_query_engine = SubQuestionQueryEngine.from_defaults(197query_engine_tools=query_engine_tools,198response_synthesizer=response_synthesizer,199service_context=service_context,200use_async=False,201verbose=True,202)203
204question = "Which are the sports teams in Toronto?"205print("Question: ", question)206response = sub_query_engine.query(question)207print_token_count(token_counter, embed_model_name)208