llm-applications
import json
import pickle
import re
import time
from pathlib import Path

from IPython.display import JSON, clear_output, display
from rank_bm25 import BM25Okapi
from tqdm import tqdm

from rag.config import EFS_DIR, ROOT_DIR
from rag.embed import get_embedding_model
from rag.index import load_index
from rag.rerank import custom_predict, get_reranked_indices
from rag.search import lexical_search, semantic_search
from rag.utils import get_client, get_num_tokens, trim


def response_stream(chat_completion):
    for chunk in chat_completion:
        content = chunk.choices[0].delta.content
        if content is not None:
            yield content


def prepare_response(chat_completion, stream):
    if stream:
        return response_stream(chat_completion)
    else:
        return chat_completion.choices[0].message.content
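
# A minimal consumption sketch (an illustration, not part of the original module):
# with stream=True, an OpenAI-compatible client returns an iterable of chunks,
# and prepare_response wraps it in a generator of text deltas:
#
# chat_completion = client.chat.completions.create(model=llm, stream=True, messages=messages)
# for text in prepare_response(chat_completion, stream=True):
#     print(text, end="", flush=True)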


def send_request(
    llm,
    messages,
    max_tokens=None,
    temperature=0.0,
    stream=False,
    max_retries=1,
    retry_interval=60,
):
    retry_count = 0
    client = get_client(llm=llm)
    while retry_count <= max_retries:
        try:
            chat_completion = client.chat.completions.create(
                model=llm,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=stream,
                messages=messages,
            )
            return prepare_response(chat_completion, stream=stream)
        except Exception as e:
            print(f"Exception: {e}")
            time.sleep(retry_interval)  # default is per-minute rate limits
            retry_count += 1
    return ""
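
# Note: retries sleep a fixed retry_interval between attempts (rate limits are
# often per-minute); a hedged alternative would be exponential backoff, e.g.
# time.sleep(retry_interval * 2**retry_count).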


def generate_response(
    llm,
    max_tokens=None,
    temperature=0.0,
    stream=False,
    system_content="",
    assistant_content="",
    user_content="",
    max_retries=1,
    retry_interval=60,
):
    """Generate response from an LLM."""
    messages = [
        {"role": role, "content": content}
        for role, content in [
            ("system", system_content),
            ("assistant", assistant_content),
            ("user", user_content),
        ]
        if content
    ]
    return send_request(llm, messages, max_tokens, temperature, stream, max_retries, retry_interval)
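
# Example call (a sketch; assumes get_client can resolve credentials for this model):
# answer = generate_response(
#     llm="meta-llama/Llama-2-70b-chat-hf",
#     max_tokens=256,
#     system_content="Answer the query using the context provided.",
#     user_content="What is Ray?",
# )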


class QueryAgent:
    def __init__(
        self,
        embedding_model_name="thenlper/gte-base",
        chunks=None,
        lexical_index=None,
        reranker=None,
        llm="meta-llama/Llama-2-70b-chat-hf",
        temperature=0.0,
        max_context_length=4096,
        system_content="",
        assistant_content="",
    ):
        # Embedding model
        self.embedding_model = get_embedding_model(
            embedding_model_name=embedding_model_name,
            model_kwargs={"device": "cuda"},
            encode_kwargs={"device": "cuda", "batch_size": 100},
        )

        # Lexical search
        self.chunks = chunks
        self.lexical_index = lexical_index

        # Reranker
        self.reranker = reranker

        # LLM
        self.llm = llm
        self.temperature = temperature
        self.context_length = int(
            0.5 * max_context_length
        ) - get_num_tokens(  # 50% of total context reserved for input
            system_content + assistant_content
        )
        self.max_tokens = int(
            0.5 * max_context_length
        )  # max sampled output (the other 50% of total context)
        self.system_content = system_content
        self.assistant_content = assistant_content
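
        # Worked budget example (an illustration, assuming max_context_length=4096
        # and ~50 tokens of system + assistant content): context_length = 2048 - 50
        # = 1998 tokens for the trimmed user content, and max_tokens = 2048 for output.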

    def __call__(
        self,
        query,
        num_chunks=5,
        lexical_search_k=1,
        rerank_threshold=0.2,
        rerank_k=7,
        stream=True,
    ):
        # Get top_k context
        context_results = semantic_search(
            query=query, embedding_model=self.embedding_model, k=num_chunks
        )

        # Add lexical search results
        if self.lexical_index:
            lexical_context = lexical_search(
                index=self.lexical_index, query=query, chunks=self.chunks, k=lexical_search_k
            )
            # Insert after <lexical_search_k> worth of semantic results
            context_results[lexical_search_k:lexical_search_k] = lexical_context

        # Rerank
        if self.reranker:
            predicted_tag = custom_predict(
                inputs=[query], classifier=self.reranker, threshold=rerank_threshold
            )[0]
            if predicted_tag != "other":
                sources = [item["source"] for item in context_results]
                reranked_indices = get_reranked_indices(sources, predicted_tag)
                context_results = [context_results[i] for i in reranked_indices]
            context_results = context_results[:rerank_k]

        # Generate response
        document_ids = [item["id"] for item in context_results]
        context = [item["text"] for item in context_results]
        # Dedupe sources but keep them JSON-serializable (a set would break json.dump below).
        sources = sorted(set(item["source"] for item in context_results))
        user_content = f"query: {query}, context: {context}"
        answer = generate_response(
            llm=self.llm,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            stream=stream,
            system_content=self.system_content,
            assistant_content=self.assistant_content,
            user_content=trim(user_content, self.context_length),
        )

        # Result
        result = {
            "question": query,
            "sources": sources,
            "document_ids": document_ids,
            "answer": answer,
            "llm": self.llm,
        }
        return result
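
# Example usage (a sketch; assumes an index built via load_index and a configured
# client for the default llm; the query and chunking values below are illustrative):
# chunks = load_index(embedding_model_name="thenlper/gte-base", embedding_dim=768,
#                     chunk_size=300, chunk_overlap=50, docs_dir=docs_dir)
# agent = QueryAgent(chunks=chunks, system_content="Answer the query using the context provided.")
# result = agent(query="How do I set the batch size in map_batches?", stream=False)
# print(result["answer"])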


# Generate responses
def generate_responses(
    experiment_name,
    chunk_size,
    chunk_overlap,
    num_chunks,
    embedding_model_name,
    embedding_dim,
    use_lexical_search,
    lexical_search_k,
    use_reranking,
    rerank_threshold,
    rerank_k,
    llm,
    temperature,
    max_context_length,
    system_content,
    assistant_content,
    docs_dir,
    experiments_dir,
    references_fp,
    num_samples=None,
    sql_dump_fp=None,
):
    # Build index
    chunks = load_index(
        embedding_model_name=embedding_model_name,
        embedding_dim=embedding_dim,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        docs_dir=docs_dir,
        sql_dump_fp=sql_dump_fp,
    )

    # Lexical index
    lexical_index = None
    if use_lexical_search:
        texts = [re.sub(r"[^a-zA-Z0-9]", " ", chunk[1]).lower().split() for chunk in chunks]
        lexical_index = BM25Okapi(texts)
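        # Tokenization example: "Use `map_batches` on Ray 2.7!" becomes
        # ["use", "map", "batches", "on", "ray", "2", "7"] (every non-alphanumeric
        # character, underscores included, is replaced by whitespace before lowercasing).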

    # Reranker
    reranker = None
    if use_reranking:
        reranker_fp = Path(EFS_DIR, "reranker.pkl")
        with open(reranker_fp, "rb") as file:
            reranker = pickle.load(file)

    # Query agent
    agent = QueryAgent(
        embedding_model_name=embedding_model_name,
        chunks=chunks,
        lexical_index=lexical_index,
        reranker=reranker,
        llm=llm,
        temperature=temperature,
        max_context_length=max_context_length,  # forward the budget instead of the 4096 default
        system_content=system_content,
        assistant_content=assistant_content,
    )

    # Generate responses
    results = []
    with open(Path(references_fp), "r") as f:
        questions = [item["question"] for item in json.load(f)][:num_samples]
    for query in tqdm(questions):
        result = agent(
            query=query,
            num_chunks=num_chunks,
            lexical_search_k=lexical_search_k,
            rerank_threshold=rerank_threshold,
            rerank_k=rerank_k,
            stream=False,
        )
        results.append(result)
        clear_output(wait=True)
        display(JSON(result))  # IPython's JSON display takes a dict, not a pre-serialized string

    # Save to file
    responses_fp = Path(ROOT_DIR, experiments_dir, "responses", f"{experiment_name}.json")
    responses_fp.parent.mkdir(parents=True, exist_ok=True)
    config = {
        "experiment_name": experiment_name,
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        "num_chunks": num_chunks,
        "embedding_model_name": embedding_model_name,
        "llm": llm,
        "temperature": temperature,
        "max_context_length": max_context_length,
        "system_content": system_content,
        "assistant_content": assistant_content,
        "docs_dir": str(docs_dir),
        "experiments_dir": str(experiments_dir),
        "references_fp": str(references_fp),
        "num_samples": len(questions),
    }
    responses = {
        "config": config,
        "results": results,
    }
    with open(responses_fp, "w") as fp:
        json.dump(responses, fp, indent=4)
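
# A minimal invocation sketch (all values below are illustrative assumptions,
# not defaults from this module; assumes the referenced paths exist):
# generate_responses(
#     experiment_name="gte-base-llama-2-70b",
#     chunk_size=300,
#     chunk_overlap=50,
#     num_chunks=5,
#     embedding_model_name="thenlper/gte-base",
#     embedding_dim=768,
#     use_lexical_search=True,
#     lexical_search_k=1,
#     use_reranking=True,
#     rerank_threshold=0.2,
#     rerank_k=7,
#     llm="meta-llama/Llama-2-70b-chat-hf",
#     temperature=0.0,
#     max_context_length=4096,
#     system_content="Answer the query using the context provided. Be succinct.",
#     assistant_content="",
#     docs_dir=Path(EFS_DIR, "docs.ray.io/en/master/"),
#     experiments_dir="experiments",
#     references_fp=Path(ROOT_DIR, "experiments", "references", "gpt-4.json"),
#     num_samples=None,
# )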