{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "912cd8c6-d405-4dfe-8897-46108e6a6af7",
   "metadata": {},
   "source": [
    "# RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "631b09a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: An OpenAI API key must be set here for initialization, even if you are not using OpenAI models.\n",
    "# In that case, assign a placeholder string (e.g., \"not_used\").\n",
    "import os\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"your-openai-key\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d7d995-7beb-40b5-9a44-afd350b7d221",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The Cinderella story used in this demo is in demo/sample.txt\n",
    "with open('demo/sample.txt', 'r') as file:\n",
    "    text = file.read()\n",
    "\n",
    "print(text[:100])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7d51ebd-5597-4fdd-8c37-32636395081b",
   "metadata": {},
   "source": [
    "1) **Building**: RAPTOR recursively embeds, clusters, and summarizes chunks of text to construct a tree with varying levels of summarization from the bottom up. You can create a tree from the text in `sample.txt` using `RA.add_documents(text)`.\n",
    "\n",
    "2) **Querying**: At inference time, the RAPTOR model retrieves information from this tree, integrating data across lengthy documents at different abstraction levels. You can query the tree with `RA.answer_question`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4f58830-9004-48a4-b50e-61a855511d24",
   "metadata": {},
   "source": [
    "### Building the tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3753fcf9-0a8e-4ab3-bf3a-6be38ef6cd1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from raptor import RetrievalAugmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e843edf",
   "metadata": {},
   "outputs": [],
   "source": [
    "RA = RetrievalAugmentation()\n",
    "\n",
    "# Construct the tree\n",
    "RA.add_documents(text)"
   ]
  },
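  {
   "cell_type": "markdown",
   "id": "a0b1c2d3",
   "metadata": {},
   "source": [
    "Before generating an answer, you can inspect what the tree actually returns for a query. A minimal sketch, assuming the `retrieve` method exposed by this repo's `RetrievalAugmentation` class; its exact signature and return value may differ between versions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4c5d6e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (assumption): RA.retrieve fetches context from the tree for a question\n",
    "# without running the QA model on it.\n",
    "context = RA.retrieve(\"How did Cinderella reach her happy ending?\")\n",
    "print(str(context)[:500])"
   ]
  },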
  {
   "cell_type": "markdown",
   "id": "f219d60a-1f0b-4cee-89eb-2ae026f13e63",
   "metadata": {},
   "source": [
    "### Querying from the tree\n",
    "\n",
    "```python\n",
    "question = \"...\"  # any question\n",
    "RA.answer_question(question)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4037c5-ad5a-424b-80e4-a67b8e00773b",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"How did Cinderella reach her happy ending?\"\n",
    "\n",
    "answer = RA.answer_question(question=question)\n",
    "\n",
    "print(\"Answer: \", answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5be7e57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the tree by calling RA.save(\"path/to/save\")\n",
    "SAVE_PATH = \"demo/cinderella\"\n",
    "RA.save(SAVE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e845de9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tree back by passing its path to RetrievalAugmentation\n",
    "RA = RetrievalAugmentation(tree=SAVE_PATH)\n",
    "\n",
    "answer = RA.answer_question(question=question)\n",
    "print(\"Answer: \", answer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "277ab6ea-1c79-4ed1-97de-1c2e39d6db2e",
   "metadata": {},
   "source": [
    "## Using other Open Source Models for Summarization/QA/Embeddings\n",
    "\n",
    "If you want to use other models, such as Llama or Mistral, you can easily define your own model classes and use them with RAPTOR."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f86cbe7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from raptor import BaseSummarizationModel, BaseQAModel, BaseEmbeddingModel, RetrievalAugmentationConfig\n",
    "from transformers import AutoTokenizer, pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe5cef43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# To use Gemma, you need to authenticate with Hugging Face.\n",
    "# Skip this step if you have already downloaded the model.\n",
    "from huggingface_hub import login\n",
    "login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "245b91a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, pipeline\n",
    "import torch\n",
    "\n",
    "# You can define your own summarization model by extending BaseSummarizationModel.\n",
    "class GEMMASummarizationModel(BaseSummarizationModel):\n",
    "    def __init__(self, model_name=\"google/gemma-2b-it\"):\n",
    "        # Initialize the tokenizer and the pipeline for the GEMMA model\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "        self.summarization_pipeline = pipeline(\n",
    "            \"text-generation\",\n",
    "            model=model_name,\n",
    "            model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
    "            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),  # Use \"cpu\" if CUDA is not available\n",
    "        )\n",
    "\n",
    "    def summarize(self, context, max_tokens=150):\n",
    "        # Format the prompt for summarization\n",
    "        messages = [\n",
    "            {\"role\": \"user\", \"content\": f\"Write a summary of the following, including as many key details as possible: {context}\"}\n",
    "        ]\n",
    "\n",
    "        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "\n",
    "        # Generate the summary using the pipeline\n",
    "        outputs = self.summarization_pipeline(\n",
    "            prompt,\n",
    "            max_new_tokens=max_tokens,\n",
    "            do_sample=True,\n",
    "            temperature=0.7,\n",
    "            top_k=50,\n",
    "            top_p=0.95\n",
    "        )\n",
    "\n",
    "        # Extract and return the generated summary, dropping the echoed prompt\n",
    "        summary = outputs[0][\"generated_text\"][len(prompt):].strip()\n",
    "        return summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a171496d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GEMMAQAModel(BaseQAModel):\n",
    "    def __init__(self, model_name=\"google/gemma-2b-it\"):\n",
    "        # Initialize the tokenizer and the pipeline for the model\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "        self.qa_pipeline = pipeline(\n",
    "            \"text-generation\",\n",
    "            model=model_name,\n",
    "            model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
    "            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),\n",
    "        )\n",
    "\n",
    "    def answer_question(self, context, question):\n",
    "        # Apply the chat template to the context and question\n",
    "        messages = [\n",
    "            {\"role\": \"user\", \"content\": f\"Given the context: {context} Give the best full answer to the question: {question}\"}\n",
    "        ]\n",
    "        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "\n",
    "        # Generate the answer using the pipeline\n",
    "        outputs = self.qa_pipeline(\n",
    "            prompt,\n",
    "            max_new_tokens=256,\n",
    "            do_sample=True,\n",
    "            temperature=0.7,\n",
    "            top_k=50,\n",
    "            top_p=0.95\n",
    "        )\n",
    "\n",
    "        # Extract and return the generated answer, dropping the echoed prompt\n",
    "        answer = outputs[0][\"generated_text\"][len(prompt):]\n",
    "        return answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "878f7c7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "class SBertEmbeddingModel(BaseEmbeddingModel):\n",
    "    def __init__(self, model_name=\"sentence-transformers/multi-qa-mpnet-base-cos-v1\"):\n",
    "        self.model = SentenceTransformer(model_name)\n",
    "\n",
    "    def create_embedding(self, text):\n",
    "        return self.model.encode(text)"
   ]
  },
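  {
   "cell_type": "markdown",
   "id": "d8e9f0a1",
   "metadata": {},
   "source": [
    "As a quick sanity check (not part of the original demo), you can embed a short string and inspect the result; `SentenceTransformer.encode` returns a NumPy array, and `multi-qa-mpnet-base-cos-v1` should produce 768-dimensional embeddings:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2f3a4b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: embed one sentence and print the embedding's shape\n",
    "embedding = SBertEmbeddingModel().create_embedding(\"Cinderella lost a slipper.\")\n",
    "print(embedding.shape)"
   ]
  },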
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "255791ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "RAC = RetrievalAugmentationConfig(\n",
    "    summarization_model=GEMMASummarizationModel(),\n",
    "    qa_model=GEMMAQAModel(),\n",
    "    embedding_model=SBertEmbeddingModel(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee46f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "RA = RetrievalAugmentation(config=RAC)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe05daf",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('demo/sample.txt', 'r') as file:\n",
    "    text = file.read()\n",
    "\n",
    "RA.add_documents(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eee5847",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"How did Cinderella reach her happy ending?\"\n",
    "\n",
    "answer = RA.answer_question(question=question)\n",
    "\n",
    "print(\"Answer: \", answer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "RAPTOR_env",
   "language": "python",
   "name": "raptor_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}