{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "912cd8c6-d405-4dfe-8897-46108e6a6af7",
   "metadata": {},
   "source": [
    "# RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "631b09a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: An OpenAI API key must be set here for application initialization, even if not in use.\n",
    "# If you're not utilizing OpenAI models, assign a placeholder string (e.g., \"not_used\").\n",
    "import os\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"your-openai-key\""
   ]
  },
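  {
   "cell_type": "markdown",
   "id": "a0b1c2d3",
   "metadata": {},
   "source": [
    "If you prefer not to hard-code the key in the notebook, a minimal alternative (a sketch using the standard-library `getpass`, not part of RAPTOR) is to prompt for it at runtime:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e5f6a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: prompt for the key instead of hard-coding it (stdlib getpass).\n",
    "from getpass import getpass\n",
    "import os\n",
    "\n",
    "# Reuse an existing environment variable if present; otherwise ask interactively.\n",
    "os.environ[\"OPENAI_API_KEY\"] = os.environ.get(\"OPENAI_API_KEY\") or getpass(\"OpenAI API key (or placeholder): \")"
   ]
  },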
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d7d995-7beb-40b5-9a44-afd350b7d221",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The Cinderella story is provided in demo/sample.txt\n",
    "with open('demo/sample.txt', 'r') as file:\n",
    "    text = file.read()\n",
    "\n",
    "print(text[:100])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7d51ebd-5597-4fdd-8c37-32636395081b",
   "metadata": {},
   "source": [
    "1) **Building**: RAPTOR recursively embeds, clusters, and summarizes chunks of text to construct a tree with varying levels of summarization from the bottom up. You can create a tree from the text in 'sample.txt' using `RA.add_documents(text)`.\n",
    "\n",
    "2) **Querying**: At inference time, the RAPTOR model retrieves information from this tree, integrating data across lengthy documents at different abstraction levels. You can perform queries on the tree with `RA.answer_question`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4f58830-9004-48a4-b50e-61a855511d24",
   "metadata": {},
   "source": [
    "### Building the tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3753fcf9-0a8e-4ab3-bf3a-6be38ef6cd1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from raptor import RetrievalAugmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e843edf",
   "metadata": {},
   "outputs": [],
   "source": [
    "RA = RetrievalAugmentation()\n",
    "\n",
    "# construct the tree\n",
    "RA.add_documents(text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f219d60a-1f0b-4cee-89eb-2ae026f13e63",
   "metadata": {},
   "source": [
    "### Querying from the tree\n",
    "\n",
    "```python\n",
    "question = \"<any question>\"\n",
    "RA.answer_question(question)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4037c5-ad5a-424b-80e4-a67b8e00773b",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"How did Cinderella reach her happy ending?\"\n",
    "\n",
    "answer = RA.answer_question(question=question)\n",
    "\n",
    "print(\"Answer: \", answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5be7e57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the tree by calling RA.save(\"path/to/save\")\n",
    "SAVE_PATH = \"demo/cinderella\"\n",
    "RA.save(SAVE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e845de9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tree back by passing the saved path to RetrievalAugmentation\n",
    "\n",
    "RA = RetrievalAugmentation(tree=SAVE_PATH)\n",
    "\n",
    "answer = RA.answer_question(question=question)\n",
    "print(\"Answer: \", answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "277ab6ea-1c79-4ed1-97de-1c2e39d6db2e",
   "metadata": {},
   "source": [
    "## Using other Open Source Models for Summarization/QA/Embeddings\n",
    "\n",
    "If you want to use other models such as Llama or Mistral, you can easily define your own models and use them with RAPTOR."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f86cbe7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from raptor import BaseSummarizationModel, BaseQAModel, BaseEmbeddingModel, RetrievalAugmentationConfig\n",
    "from transformers import AutoTokenizer, pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe5cef43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# To use Gemma you need to authenticate with Hugging Face. Skip this step if you have already downloaded the model.\n",
    "from huggingface_hub import login\n",
    "login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "245b91a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, pipeline\n",
    "import torch\n",
    "\n",
    "# You can define your own summarization model by extending BaseSummarizationModel.\n",
    "class GEMMASummarizationModel(BaseSummarizationModel):\n",
    "    def __init__(self, model_name=\"google/gemma-2b-it\"):\n",
    "        # Initialize the tokenizer and the text-generation pipeline for the Gemma model\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "        self.summarization_pipeline = pipeline(\n",
    "            \"text-generation\",\n",
    "            model=model_name,\n",
    "            model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
    "            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),  # Falls back to CPU if CUDA is not available\n",
    "        )\n",
    "\n",
    "    def summarize(self, context, max_tokens=150):\n",
    "        # Format the prompt for summarization\n",
    "        messages = [\n",
    "            {\"role\": \"user\", \"content\": f\"Write a summary of the following, including as many key details as possible: {context}\"}\n",
    "        ]\n",
    "\n",
    "        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "\n",
    "        # Generate the summary using the pipeline\n",
    "        outputs = self.summarization_pipeline(\n",
    "            prompt,\n",
    "            max_new_tokens=max_tokens,\n",
    "            do_sample=True,\n",
    "            temperature=0.7,\n",
    "            top_k=50,\n",
    "            top_p=0.95\n",
    "        )\n",
    "\n",
    "        # The pipeline returns the prompt plus the completion; strip the prompt before returning\n",
    "        summary = outputs[0][\"generated_text\"][len(prompt):].strip()\n",
    "        return summary\n"
   ]
  },
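  {
   "cell_type": "markdown",
   "id": "b8c9d0e1",
   "metadata": {},
   "source": [
    "As a quick sanity check, you can call the custom summarizer directly before handing it to RAPTOR (a sketch: it assumes the `text` variable loaded above and downloads the Gemma weights on first use):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2a3b4c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check sketch: summarize the opening of the sample text with the custom model.\n",
    "# Note: this downloads the Gemma weights on first use.\n",
    "summarizer = GEMMASummarizationModel()\n",
    "print(summarizer.summarize(text[:500], max_tokens=100))"
   ]
  },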
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a171496d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GEMMAQAModel(BaseQAModel):\n",
    "    def __init__(self, model_name=\"google/gemma-2b-it\"):\n",
    "        # Initialize the tokenizer and the text-generation pipeline for the model\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "        self.qa_pipeline = pipeline(\n",
    "            \"text-generation\",\n",
    "            model=model_name,\n",
    "            model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
    "            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),\n",
    "        )\n",
    "\n",
    "    def answer_question(self, context, question):\n",
    "        # Apply the chat template to the context and question\n",
    "        messages = [\n",
    "            {\"role\": \"user\", \"content\": f\"Given Context: {context} Give the best full answer to the question: {question}\"}\n",
    "        ]\n",
    "        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "\n",
    "        # Generate the answer using the pipeline\n",
    "        outputs = self.qa_pipeline(\n",
    "            prompt,\n",
    "            max_new_tokens=256,\n",
    "            do_sample=True,\n",
    "            temperature=0.7,\n",
    "            top_k=50,\n",
    "            top_p=0.95\n",
    "        )\n",
    "\n",
    "        # The pipeline output includes the prompt; slice it off to keep only the answer\n",
    "        answer = outputs[0][\"generated_text\"][len(prompt):]\n",
    "        return answer"
   ]
  },
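  {
   "cell_type": "markdown",
   "id": "c6d7e8f9",
   "metadata": {},
   "source": [
    "The QA model can be smoke-tested the same way, outside of RAPTOR (a sketch; the context here is just the opening of the sample text and the question is arbitrary):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0f1a2b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke-test sketch: answer a question against a short context with the custom QA model.\n",
    "qa_model = GEMMAQAModel()\n",
    "print(qa_model.answer_question(context=text[:1000], question=\"Who is the main character of the story?\"))"
   ]
  },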
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "878f7c7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# You can define your own embedding model by extending BaseEmbeddingModel.\n",
    "class SBertEmbeddingModel(BaseEmbeddingModel):\n",
    "    def __init__(self, model_name=\"sentence-transformers/multi-qa-mpnet-base-cos-v1\"):\n",
    "        self.model = SentenceTransformer(model_name)\n",
    "\n",
    "    def create_embedding(self, text):\n",
    "        return self.model.encode(text)\n"
   ]
  },
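  {
   "cell_type": "markdown",
   "id": "d2e3f4a5",
   "metadata": {},
   "source": [
    "A one-line check that the embedding model returns a fixed-size vector (`multi-qa-mpnet-base-cos-v1` produces 768-dimensional embeddings; the sentence is just an example):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6a7b8c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# encode() returns a numpy array; its shape is the embedding dimension (768 here).\n",
    "embedding = SBertEmbeddingModel().create_embedding(\"The glass slipper fit perfectly.\")\n",
    "print(embedding.shape)"
   ]
  },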
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "255791ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "RAC = RetrievalAugmentationConfig(\n",
    "    summarization_model=GEMMASummarizationModel(),\n",
    "    qa_model=GEMMAQAModel(),\n",
    "    embedding_model=SBertEmbeddingModel(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee46f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "RA = RetrievalAugmentation(config=RAC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe05daf",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('demo/sample.txt', 'r') as file:\n",
    "    text = file.read()\n",
    "\n",
    "RA.add_documents(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eee5847",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"How did Cinderella reach her happy ending?\"\n",
    "\n",
    "answer = RA.answer_question(question=question)\n",
    "\n",
    "print(\"Answer: \", answer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "RAPTOR_env",
   "language": "python",
   "name": "raptor_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}