llm-applications / generate.py
import json
import pickle
import re
import time
from pathlib import Path

from IPython.display import JSON, clear_output, display
from rank_bm25 import BM25Okapi
from tqdm import tqdm

from rag.config import EFS_DIR, ROOT_DIR
from rag.embed import get_embedding_model
from rag.index import load_index
from rag.rerank import custom_predict, get_reranked_indices
from rag.search import lexical_search, semantic_search
from rag.utils import get_client, get_num_tokens, trim


def response_stream(chat_completion):
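    """Yield streamed content tokens from a chat completion, skipping empty deltas."""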
    for chunk in chat_completion:
        content = chunk.choices[0].delta.content
        if content is not None:
            yield content


def prepare_response(chat_completion, stream):
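    """Return a token generator when streaming, otherwise the full message content."""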
    if stream:
        return response_stream(chat_completion)
    else:
        return chat_completion.choices[0].message.content


def send_request(
    llm,
    messages,
    max_tokens=None,
    temperature=0.0,
    stream=False,
    max_retries=1,
    retry_interval=60,
):
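    """Send a chat completion request, retrying up to max_retries times on exceptions."""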
    retry_count = 0
    client = get_client(llm=llm)
    while retry_count <= max_retries:
        try:
            chat_completion = client.chat.completions.create(
                model=llm,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=stream,
                messages=messages,
            )
            return prepare_response(chat_completion, stream=stream)

        except Exception as e:
            print(f"Exception: {e}")
            time.sleep(retry_interval)  # default is per-minute rate limits
            retry_count += 1
    return ""


def generate_response(
    llm,
    max_tokens=None,
    temperature=0.0,
    stream=False,
    system_content="",
    assistant_content="",
    user_content="",
    max_retries=1,
    retry_interval=60,
):
    """Generate response from an LLM."""
    messages = [
        {"role": role, "content": content}
        for role, content in [
            ("system", system_content),
            ("assistant", assistant_content),
            ("user", user_content),
        ]
        if content
    ]
    return send_request(llm, messages, max_tokens, temperature, stream, max_retries, retry_interval)


class QueryAgent:
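    """RAG agent: semantic search (plus optional lexical search and reranking) over
    indexed chunks, followed by LLM generation over the retrieved context."""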
    def __init__(
        self,
        embedding_model_name="thenlper/gte-base",
        chunks=None,
        lexical_index=None,
        reranker=None,
        llm="meta-llama/Llama-2-70b-chat-hf",
        temperature=0.0,
        max_context_length=4096,
        system_content="",
        assistant_content="",
    ):
        # Embedding model
        self.embedding_model = get_embedding_model(
            embedding_model_name=embedding_model_name,
            model_kwargs={"device": "cuda"},
            encode_kwargs={"device": "cuda", "batch_size": 100},
        )

        # Lexical search
        self.chunks = chunks
        self.lexical_index = lexical_index

        # Reranker
        self.reranker = reranker

        # LLM
        self.llm = llm
        self.temperature = temperature
        self.context_length = int(
            0.5 * max_context_length
        ) - get_num_tokens(  # 50% of total context reserved for input
            system_content + assistant_content
        )
        self.max_tokens = int(
            0.5 * max_context_length
        )  # max sampled output (the other 50% of total context)
        self.system_content = system_content
        self.assistant_content = assistant_content

    def __call__(
        self,
        query,
        num_chunks=5,
        lexical_search_k=1,
        rerank_threshold=0.2,
        rerank_k=7,
        stream=True,
    ):
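        """Retrieve and optionally rerank context for the query, then generate an answer."""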
        # Get top_k context
        context_results = semantic_search(
            query=query, embedding_model=self.embedding_model, k=num_chunks
        )

        # Add lexical search results
        if self.lexical_index:
            lexical_context = lexical_search(
                index=self.lexical_index, query=query, chunks=self.chunks, k=lexical_search_k
            )
            # Insert after <lexical_search_k> worth of semantic results
            context_results[lexical_search_k:lexical_search_k] = lexical_context

        # Rerank
        if self.reranker:
            predicted_tag = custom_predict(
                inputs=[query], classifier=self.reranker, threshold=rerank_threshold
            )[0]
            if predicted_tag != "other":
                sources = [item["source"] for item in context_results]
                reranked_indices = get_reranked_indices(sources, predicted_tag)
                context_results = [context_results[i] for i in reranked_indices]
            context_results = context_results[:rerank_k]

        # Generate response
        document_ids = [item["id"] for item in context_results]
        context = [item["text"] for item in context_results]
        sources = list(set(item["source"] for item in context_results))  # dedupe, but keep a list so the result stays JSON-serializable
        user_content = f"query: {query}, context: {context}"
        answer = generate_response(
            llm=self.llm,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            stream=stream,
            system_content=self.system_content,
            assistant_content=self.assistant_content,
            user_content=trim(user_content, self.context_length),
        )

        # Result
        result = {
            "question": query,
            "sources": sources,
            "document_ids": document_ids,
            "answer": answer,
            "llm": self.llm,
        }
        return result


# Generate responses
def generate_responses(
    experiment_name,
    chunk_size,
    chunk_overlap,
    num_chunks,
    embedding_model_name,
    embedding_dim,
    use_lexical_search,
    lexical_search_k,
    use_reranking,
    rerank_threshold,
    rerank_k,
    llm,
    temperature,
    max_context_length,
    system_content,
    assistant_content,
    docs_dir,
    experiments_dir,
    references_fp,
    num_samples=None,
    sql_dump_fp=None,
):
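    """Run one RAG experiment end to end: build indices, answer the reference questions,
    and save the responses together with the experiment config."""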
    # Build index
    chunks = load_index(
        embedding_model_name=embedding_model_name,
        embedding_dim=embedding_dim,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        docs_dir=docs_dir,
        sql_dump_fp=sql_dump_fp,
    )

    # Lexical index
    lexical_index = None
    if use_lexical_search:
        texts = [re.sub(r"[^a-zA-Z0-9]", " ", chunk[1]).lower().split() for chunk in chunks]
        lexical_index = BM25Okapi(texts)

    # Reranker
    reranker = None
    if use_reranking:
        reranker_fp = Path(EFS_DIR, "reranker.pkl")
        with open(reranker_fp, "rb") as file:
            reranker = pickle.load(file)

    # Query agent
    agent = QueryAgent(
        embedding_model_name=embedding_model_name,
        chunks=chunks,
        lexical_index=lexical_index,
        reranker=reranker,
        llm=llm,
        temperature=temperature,
        max_context_length=max_context_length,  # forward the context budget instead of relying on the 4096 default
        system_content=system_content,
        assistant_content=assistant_content,
    )

    # Generate responses
    results = []
    with open(Path(references_fp), "r") as f:
        questions = [item["question"] for item in json.load(f)][:num_samples]
    for query in tqdm(questions):
        result = agent(
            query=query,
            num_chunks=num_chunks,
            lexical_search_k=lexical_search_k,
            rerank_threshold=rerank_threshold,
            rerank_k=rerank_k,
            stream=False,
        )
        results.append(result)
        clear_output(wait=True)
        display(JSON(json.dumps(result, indent=2)))

    # Save to file
    responses_fp = Path(ROOT_DIR, experiments_dir, "responses", f"{experiment_name}.json")
    responses_fp.parent.mkdir(parents=True, exist_ok=True)
    config = {
        "experiment_name": experiment_name,
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        "num_chunks": num_chunks,
        "embedding_model_name": embedding_model_name,
        "llm": llm,
        "temperature": temperature,
        "max_context_length": max_context_length,
        "system_content": system_content,
        "assistant_content": assistant_content,
        "docs_dir": str(docs_dir),
        "experiments_dir": str(experiments_dir),
        "references_fp": str(references_fp),
        "num_samples": len(questions),
    }
    responses = {
        "config": config,
        "results": results,
    }
    with open(responses_fp, "w") as fp:
        json.dump(responses, fp, indent=4)
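

# Example usage sketch. All values below are illustrative placeholders: paths, prompt,
# and hyperparameters are assumptions, not values from this repo. It presumes the docs
# have already been indexed, a GPU is available for the embedding model, and the
# references file is a JSON list of {"question": ...} records.
if __name__ == "__main__":
    generate_responses(
        experiment_name="example-experiment",  # hypothetical experiment name
        chunk_size=300,
        chunk_overlap=50,
        num_chunks=5,
        embedding_model_name="thenlper/gte-base",
        embedding_dim=768,
        use_lexical_search=True,
        lexical_search_k=1,
        use_reranking=False,
        rerank_threshold=0.2,
        rerank_k=7,
        llm="meta-llama/Llama-2-70b-chat-hf",
        temperature=0.0,
        max_context_length=4096,
        system_content="Answer the query using only the provided context.",  # illustrative prompt
        assistant_content="",
        docs_dir=Path(EFS_DIR, "docs"),  # hypothetical location of the crawled docs
        experiments_dir="experiments",
        references_fp=Path(ROOT_DIR, "datasets", "eval-dataset.json"),  # hypothetical references file
        num_samples=5,
    )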