{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cbd5a364",
   "metadata": {},
   "source": [
    "# HF Checkpoints with LlamaIndex and LangChain\n",
    "\n",
    "This notebook demonstrates how to plug a local LLM from [HuggingFace Hub Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) and the [all-MiniLM-L6-v2 embedding from HuggingFace](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) into [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/) with these customizations.\n",
    "\n",
    "The custom plug-ins shown in this notebook can be replaced. For example, you can swap out the [HuggingFace Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) checkpoint for the [HuggingFace checkpoint from Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1); a sketch of this swap follows the loading cell in Step 1.\n",
    "\n",
    "\n",
    "<div class=\"alert alert-block alert-info\">\n",
    "    \n",
    "⚠️ The notebook before this one, `08_Option(1)_llama_index_with_NVIDIA_AI_endpoint.ipynb`, contains the same exercise as this notebook, but it calls NVIDIA AI Catalog models via API instead of pulling the model checkpoints from the HuggingFace model hub and loading them from host to devices (i.e. GPUs).\n",
    "\n",
    "Note that, since we load the checkpoints here, working through this entire notebook will be significantly slower.\n",
    "\n",
    "If you do decide to go through this notebook, please check the **Prerequisites** section below.\n",
    "\n",
    "LlamaIndex supports a continuously growing set of retrieval techniques; this notebook only shows how to quickly replace components such as the LLM and embedding with the user's choice. Read the [llama-index documentation](https://docs.llamaindex.ai/en/stable/) for the latest information.\n",
    "\n",
    "</div>\n",
    "\n",
    "### Prerequisites\n",
    "In order to successfully run this notebook, you will need the following:\n",
    "\n",
    "1. Approval to use the checkpoints, obtained by applying at [meta-llama](https://huggingface.co/meta-llama)\n",
    "2. At least 2 NVIDIA GPUs, each with at least 32 GB of memory, preferably of the Ampere architecture\n",
    "3. Docker and the [NVIDIA Container Toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) installed\n",
    "4. An [NVIDIA NGC](https://www.nvidia.com/en-us/gpu-cloud/) registration, with the ability to pull and run NGC PyTorch containers\n",
    "5. The necessary Python dependencies installed. Note: if you are using the [Dockerfile.gpu_notebook](https://raw.githubusercontent.com/NVIDIA/GenerativeAIExamples/main/notebooks/Dockerfile.gpu_notebook), it should already prepare the environment for you. Otherwise, please refer to that Dockerfile for environment setup.\n",
    "\n",
    "In this notebook, we will cover the following custom plug-in components:\n",
    "\n",
    "    - An LLM loaded locally from [HuggingFace Hub Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) and wrapped for llama-index\n",
    "    \n",
    "    - A [HuggingFace embedding all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)\n",
    "    \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3786d093",
   "metadata": {},
   "source": [
    "### Step 1 - Load [HuggingFace Hub Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)\n",
    "\n",
    "\n",
    "Note: Scroll down and make sure you supply the **hf_token in the code block below, replacing [FILL_IN] with your HuggingFace token**.\n",
    "For how to generate the token from HuggingFace, please follow the instructions in [this link](https://huggingface.co/docs/transformers.js/guides/private)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e080a53",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Uncomment the line below if you have not yet installed the Python dependencies\n",
    "#!pip install accelerate transformers==4.33.1 --upgrade"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fcd1582",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import sys\n",
    "\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
    "logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))\n",
    "import os\n",
    "from IPython.display import Markdown, display\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer\n",
    "import torch\n",
    "\n",
    "def load_hf_model(model_name_or_path, device, num_gpus, hf_auth_token, debug=False):\n",
    "    \"\"\"Load an HF locally saved checkpoint.\"\"\"\n",
    "    if device == \"cpu\":\n",
    "        kwargs = {}\n",
    "    elif device == \"cuda\":\n",
    "        kwargs = {\"torch_dtype\": torch.float16}\n",
    "        if num_gpus == \"auto\":\n",
    "            kwargs[\"device_map\"] = \"auto\"\n",
    "        else:\n",
    "            num_gpus = int(num_gpus)\n",
    "            if num_gpus != 1:\n",
    "                kwargs.update(\n",
    "                    {\n",
    "                        \"device_map\": \"auto\",\n",
    "                        \"max_memory\": {i: \"13GiB\" for i in range(num_gpus)},\n",
    "                    }\n",
    "                )\n",
    "    elif device == \"mps\":\n",
    "        kwargs = {\"torch_dtype\": torch.float16}\n",
    "        # Avoid bugs in mps backend by not using in-place operations.\n",
    "        print(\"mps not supported\")\n",
    "    else:\n",
    "        raise ValueError(f\"Invalid device: {device}\")\n",
    "\n",
    "    if hf_auth_token is None:\n",
    "        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            model_name_or_path, low_cpu_mem_usage=True, **kwargs\n",
    "        )\n",
    "    else:\n",
    "        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_auth_token=hf_auth_token, use_fast=False)\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            model_name_or_path, low_cpu_mem_usage=True, use_auth_token=hf_auth_token, **kwargs\n",
    "        )\n",
    "\n",
    "    if device == \"cuda\" and num_gpus == 1:\n",
    "        model.to(device)\n",
    "\n",
    "    if debug:\n",
    "        print(model)\n",
    "\n",
    "    return model, tokenizer\n",
    "\n",
    "\n",
    "\n",
    "# Define the variable holding the llama2 weights name\n",
    "model_name_or_path = \"meta-llama/Llama-2-13b-chat-hf\"\n",
    "# Set the auth token variable from Hugging Face\n",
    "hf_token = \"[FILL_IN]\"\n",
    "device = \"cuda\"\n",
    "num_gpus = 2\n",
    "\n",
    "# Load the model and create the tokenizer\n",
    "model, tokenizer = load_hf_model(model_name_or_path, device, num_gpus, hf_auth_token=hf_token, debug=False)\n",
    "# Set up a prompt\n",
    "prompt = \"### User:What is the fastest car in  \\\n",
    "          the world and how much does it cost? \\\n",
    "          ### Assistant:\"\n",
    "# Pass the prompt to the tokenizer\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "# Set up the text streamer\n",
    "streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)"
   ]
  },
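  {
   "cell_type": "markdown",
   "id": "mistral-swap-note",
   "metadata": {},
   "source": [
    "As mentioned in the introduction, the Llama-2 checkpoint can be swapped out. The commented-out cell below is a minimal, untested sketch of pointing `load_hf_model` at the [Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) checkpoint instead; the memory settings and number of GPUs you need may differ from the Llama-2 values used above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mistral-swap-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Optional (sketch): load the Mistral-7B-Instruct checkpoint instead of Llama-2.\n",
    "## Mistral-7B is smaller than Llama-2-13b and typically fits on a single 32 GB GPU in fp16.\n",
    "# model, tokenizer = load_hf_model(\n",
    "#     \"mistralai/Mistral-7B-Instruct-v0.1\", device=\"cuda\", num_gpus=1, hf_auth_token=hf_token, debug=False\n",
    "# )"
   ]
  },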
  {
   "cell_type": "markdown",
   "id": "10945486",
   "metadata": {},
   "source": [
    "Run a quick test and watch the model stream its generated response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fd772cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = model.generate(**inputs, streamer=streamer, use_cache=True, max_new_tokens=100)\n",
    "# Convert the output tokens back to text\n",
    "output_text = tokenizer.decode(output[0], skip_special_tokens=True)\n",
    "output_text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8735d640",
   "metadata": {},
   "source": [
    "### Step 2 - Construct prompt template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e53c0e45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the prompt wrapper...but for llama index\n",
    "from llama_index.prompts.prompts import SimpleInputPrompt\n",
    "# Create a system prompt\n",
    "system_prompt = \"\"\"<<SYS>>\n",
    "You are a helpful, respectful and honest assistant. Always answer as\n",
    "helpfully as possible, while being safe. Your answers should not include\n",
    "any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.\n",
    "Please ensure that your responses are socially unbiased and positive in nature.\n",
    "\n",
    "If a question does not make any sense, or is not factually coherent, explain\n",
    "why instead of answering something not correct. If you don't know the answer\n",
    "to a question, please don't share false information.\n",
    "\n",
    "Your goal is to provide answers relating to the financial performance of\n",
    "the company.<</SYS>>[INST]\n",
    "\"\"\"\n",
    "# Throw together the query wrapper\n",
    "query_wrapper_prompt = SimpleInputPrompt(\"{query_str} [/INST]\")\n",
    "## do a test query\n",
    "query_str = 'What can you help me with?'\n",
    "query_wrapper_prompt.format(query_str=query_str)\n"
   ]
  },
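  {
   "cell_type": "code",
   "execution_count": null,
   "id": "inspect-assembled-prompt",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (sketch): print the system prompt and the wrapped query together,\n",
    "# just to illustrate the Llama-2 style <<SYS>>/[INST]/[/INST] layout these two pieces produce.\n",
    "# The HuggingFaceLLM wrapper in Step 4 assembles the final prompt for you at query time.\n",
    "print(system_prompt + query_wrapper_prompt.format(query_str=query_str))"
   ]
  },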
  {
   "cell_type": "markdown",
   "id": "58a675f3",
   "metadata": {},
   "source": [
    "### Step 3 - Load the chosen HuggingFace embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f14a0a8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create and download an embeddings instance, wrapping the HuggingFace embedding in a LangChain embedding\n",
    "# Bring in the embeddings wrapper\n",
    "from llama_index.embeddings import LangchainEmbedding\n",
    "# Bring in HF embeddings - need these to represent document chunks\n",
    "from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
    "embeddings = LangchainEmbedding(\n",
    "    HuggingFaceEmbeddings(model_name=\"all-MiniLM-L6-v2\")\n",
    ")\n"
   ]
  },
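  {
   "cell_type": "code",
   "execution_count": null,
   "id": "embedding-sanity-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (sketch): embed a short string and look at the vector size.\n",
    "# all-MiniLM-L6-v2 produces 384-dimensional embeddings.\n",
    "sample_vector = embeddings.get_text_embedding(\"Hello, world!\")\n",
    "len(sample_vector)"
   ]
  },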
  {
   "cell_type": "markdown",
   "id": "2df926bc",
   "metadata": {},
   "source": [
    "### Step 4 - Wrap the locally loaded HuggingFace LLM into llama-index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e541fb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the llama index HF wrapper\n",
    "from llama_index.llms import HuggingFaceLLM\n",
    "# Create an HF LLM using the llama index wrapper\n",
    "llm = HuggingFaceLLM(context_window=4096,\n",
    "                    max_new_tokens=256,\n",
    "                    system_prompt=system_prompt,\n",
    "                    query_wrapper_prompt=query_wrapper_prompt,\n",
    "                    model=model,\n",
    "                    tokenizer=tokenizer)\n"
   ]
  },
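  {
   "cell_type": "code",
   "execution_count": null,
   "id": "llm-wrapper-smoke-test",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Optional smoke test (sketch): ask the wrapped LLM a question directly, before any\n",
    "## retrieval is involved. `complete` is llama-index's text-completion call; uncomment to run\n",
    "## (generation on the local checkpoint will take a while).\n",
    "# print(llm.complete(\"What can you help me with?\"))"
   ]
  },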
  {
   "cell_type": "markdown",
   "id": "dd3275e7",
   "metadata": {},
   "source": [
    "### Step 5 - Wrap the custom embedding and the locally loaded HuggingFace LLM into llama-index's ServiceContext"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fa1aeeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bring in what we need to change the service context\n",
    "from llama_index import set_global_service_context\n",
    "from llama_index import ServiceContext"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6284cadd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a new service context instance\n",
    "service_context = ServiceContext.from_defaults(\n",
    "    chunk_size=1024,\n",
    "    llm=llm,\n",
    "    embed_model=embeddings\n",
    ")\n",
    "# And set the global service context\n",
    "set_global_service_context(service_context)\n"
   ]
  },
296
  {
297
   "cell_type": "markdown",
298
   "id": "88937615",
299
   "metadata": {},
300
   "source": [
301
    "### Step 6a - Load the text data using llama-index's SimpleDirectoryReader and we will be using the built-in [VectorStoreIndex](https://docs.llamaindex.ai/en/latest/community/integrations/vector_stores.html)"
302
   ]
303
  },
304
  {
305
   "cell_type": "code",
306
   "execution_count": null,
307
   "id": "63d060f6",
308
   "metadata": {},
309
   "outputs": [],
310
   "source": [
311
    "#create query engine with cross encoder reranker\n",
312
    "from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext\n",
313
    "import torch\n",
314
    "\n",
315
    "documents = SimpleDirectoryReader(\"./toy_data\").load_data()\n",
316
    "index = VectorStoreIndex.from_documents(documents, service_context=service_context)\n"
317
   ]
318
  },
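  {
   "cell_type": "code",
   "execution_count": null,
   "id": "persist-index-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Optional (sketch): persist the freshly built index to disk so it can be reloaded later\n",
    "## without re-embedding the documents. The \"./toy_index\" directory name is just an example.\n",
    "# index.storage_context.persist(persist_dir=\"./toy_index\")"
   ]
  },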
  {
   "cell_type": "markdown",
   "id": "fe5db8f7",
   "metadata": {},
   "source": [
    "### Step 6b - Create the query engine we will use to ask questions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fc661a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the index query engine using the LLM\n",
    "query_engine = index.as_query_engine()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d6fe0b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test out a query in natural language\n",
    "response = query_engine.query(\"Tell me about Sweden's population?\")"
   ]
  },
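  {
   "cell_type": "code",
   "execution_count": null,
   "id": "render-response",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: render the answer with the Markdown/display helpers imported earlier.\n",
    "# This only affects presentation; the source metadata and raw text are inspected below.\n",
    "display(Markdown(str(response)))"
   ]
  },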
349
  {
350
   "cell_type": "code",
351
   "execution_count": null,
352
   "id": "4e402baa",
353
   "metadata": {},
354
   "outputs": [],
355
   "source": [
356
    "response.metadata"
357
   ]
358
  },
359
  {
360
   "cell_type": "code",
361
   "execution_count": null,
362
   "id": "23748347",
363
   "metadata": {},
364
   "outputs": [],
365
   "source": [
366
    "response.response"
367
   ]
368
  }
369
 ],
370
 "metadata": {
371
  "kernelspec": {
372
   "display_name": "Python 3 (ipykernel)",
373
   "language": "python",
374
   "name": "python3"
375
  },
376
  "language_info": {
377
   "codemirror_mode": {
378
    "name": "ipython",
379
    "version": 3
380
   },
381
   "file_extension": ".py",
382
   "mimetype": "text/x-python",
383
   "name": "python",
384
   "nbconvert_exporter": "python",
385
   "pygments_lexer": "ipython3",
386
   "version": "3.10.6"
387
  }
388
 },
389
 "nbformat": 4,
390
 "nbformat_minor": 5
391
}