1
{
2
 "cells": [
3
  {
4
   "cell_type": "markdown",
5
   "id": "3d4e2931",
6
   "metadata": {},
7
   "source": [
8
    "## CodeGemma Parameter Efficient Fine-Tuning with LoRA using NeMo Framework\n",
9
    "\n",
10
    "CodeGemma is a groundbreaking new open model in the Gemini family of models from Google. CodeGemma is just as powerful as previous models but compact enough to run locally on NVIDIA RTX GPUs. CodeGemma is available in 2 sizes: 2B and 7B parameters. With NVIDIA NeMo, you can customize CodeGemma to fit your usecase and deploy an optimized model on your NVIDIA GPU.\n",
11
    "\n",
12
    "In this tutorial, we'll go over a specific kind of customization -- Low-rank adapter tuning to follow a specific output format (also known as LoRA). To learn how to perform full parameter supervised fine-tuning for instruction following (also known as SFT), see the [SFT notebook on Gemma Base Model](https://github.com/NVIDIA/GenerativeAIExamples/blob/main/models/Gemma/sft.ipynb). For LoRA, we'll perform all operations within the notebook on a single GPU. The compute resources needed for training depend on which CodeGemma model you use. For the 7 billion parameter variant, you'll need a GPU with 80GB of memory. For the 2 billion parameter model, 40GB will do.\n",
13
    "\n",
14
    "We'll also learn how to export your custom model to TensorRT-LLM, an open-source library that accelerates and optimizes inference performance of the latest LLMs on the NVIDIA AI platform."
15
   ]
16
  },
17
  {
18
   "cell_type": "markdown",
19
   "id": "6a0c3e84",
20
   "metadata": {},
21
   "source": [
22
    "## Introduction\n",
23
    "\n",
24
    "[LoRA tuning](https://arxiv.org/abs/2106.09685) is a parameter efficient method for fine-tuning models, where we freeze the base model parameters and update an auxiliary \"adapter\" with many fewer weights. At inference time, the adapter weights are combined with the base model weights to produce a new model, customized for a particular use case or dataset. Because this adapter is so much smaller than the base model, it can be trained with far fewer resources than it would take to fine-tune the entire model. In this notebook, we'll show you how to LoRA-tune small models like the CodeGemma models on a single A100 GPU.\n",
25
    "\n",
26
    "For this example, we're going to tune our CodeGemma model on the [Alpaca Python Code Instructions Dataset](https://huggingface.co/datasets/iamtarun/python_code_instructions_18k_alpaca) and tuning our model to enhance its instruction following ability for generating Python code."
27
   ]
28
  },
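  {
   "cell_type": "markdown",
   "id": "a10bc001",
   "metadata": {},
   "source": [
    "As a rough, standalone illustration of the idea (not part of the NeMo workflow), LoRA keeps a frozen weight matrix W and learns a low-rank update, so the effective weight becomes W + (alpha / r) * B @ A, where A and B are small matrices of rank r. The toy sketch below uses made-up shapes purely to show how few parameters the adapter adds:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a10bc002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy LoRA arithmetic with made-up sizes -- illustrative only, not the NeMo implementation.\n",
    "import numpy as np\n",
    "\n",
    "d_model, rank, alpha = 1024, 8, 16          # hypothetical dimensions and rank\n",
    "W = np.random.randn(d_model, d_model)       # frozen base weight\n",
    "A = np.random.randn(rank, d_model) * 0.01   # trainable low-rank factor\n",
    "B = np.zeros((d_model, rank))               # trainable low-rank factor (starts at zero)\n",
    "\n",
    "W_merged = W + (alpha / rank) * (B @ A)     # weight used at inference after merging\n",
    "print(f\"base params: {W.size:,}  adapter params: {A.size + B.size:,}\")"
   ]
  },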
29
  {
30
   "cell_type": "markdown",
31
   "id": "8df62050",
32
   "metadata": {},
33
   "source": [
34
    "## Download the Pretrained CodeGemma Model\n",
35
    "\n",
36
    "For all of our customization and deployment processes, we'll need to start off with a pre-trained version of CodeGemma in the `.nemo` format. You can download the base model in `.nemo` format from the NVIDIA GPU Cloud, or convert checkpoints from another framework into a `.nemo` file. You can choose to use the 2B parameter or 7B parameter CodeGemma models for this notebook -- the 2B model will be faster to customize, but the 7B model will be more capable.\n",
37
    "\n",
38
    "You can download either model from the NVIDIA NGC Catalog, using the NGC CLI. The instructions to install and configure the NGC CLI can be found [here](https://ngc.nvidia.com/setup/installers/cli).\n",
39
    "\n",
40
    "To download the model, execute one of the following commands, based on which model you want to use:\n",
41
    "\n",
42
    "ngc registry model download-version \"nvidia/nemo/codegemma_2b_base:1.0\"\n",
43
    "\n",
44
    "or\n",
45
    "\n",
46
    "ngc registry model download-version \"nvidia/nemo/codegemma_7b_base:1.0\""
47
   ]
48
  },
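  {
   "cell_type": "markdown",
   "id": "b20cd001",
   "metadata": {},
   "source": [
    "If the NGC CLI is already installed and configured in your environment, you can also run the download directly from this notebook. The cell below fetches the 7B base model; swap in the 2B command above if you prefer the smaller model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b20cd002",
   "metadata": {},
   "outputs": [],
   "source": [
    "!ngc registry model download-version \"nvidia/nemo/codegemma_7b_base:1.0\""
   ]
  },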
49
  {
50
   "cell_type": "markdown",
51
   "id": "21facc49",
52
   "metadata": {},
53
   "source": [
54
    "## Getting NeMo Framework\n",
55
    "\n",
56
    "NVIDIA NeMo Framework is a generative AI framework built for researchers and PyTorch developers working on large language models (LLMs), multimodal models (MM), automatic speech recognition (ASR), and text-to-speech synthesis (TTS). The primary objective of NeMo is to provide a scalable framework for researchers and developers from industry and academia to more easily implement and design new generative AI models by being able to leverage existing code and pretrained models.\n",
57
    "\n",
58
    "If you haven't already, you can pull a container that includes the version of NeMo Framework and all dependencies needed for this notebook with the following:\n",
59
    "\n",
60
    "docker pull nvcr.io/nvidia/nemo:24.03.codegemma\n",
61
    "\n",
62
    "The best way to run this notebook is from within the container. You can do that by launching the container with the following command\n",
63
    "\n",
64
    "docker run -it --rm --gpus all --ipc=host --network host -v $(pwd):/workspace nvcr.io/nvidia/nemo:24.03.codegemma\n",
65
    "\n",
66
    "Then, from within the container, start the jupyter server with\n",
67
    "\n",
68
    "jupyter lab --no-browser --port=5000 --allow-root --ip 0.0.0.0"
69
   ]
70
  },
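  {
   "cell_type": "markdown",
   "id": "c30de001",
   "metadata": {},
   "source": [
    "Once the notebook is running inside the container, an optional sanity check confirms that a GPU is visible and that NeMo imports cleanly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c30de002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: confirm GPU visibility and the NeMo version inside the container.\n",
    "!nvidia-smi -L\n",
    "import torch, nemo\n",
    "print(\"CUDA available:\", torch.cuda.is_available())\n",
    "print(\"NeMo version:\", nemo.__version__)"
   ]
  },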
71
  {
72
   "cell_type": "markdown",
73
   "id": "02f23c5e",
74
   "metadata": {},
75
   "source": [
76
    "## Dataset Preparation\n",
77
    "\n",
78
    "Let's download Alpaca Python Code Instructions dataset from Hugging Face:"
79
   ]
80
  },
81
  {
82
   "cell_type": "code",
83
   "execution_count": null,
84
   "id": "444aa398",
85
   "metadata": {},
86
   "outputs": [],
87
   "source": [
88
    "!git lfs install\n",
89
    "!git clone https://huggingface.co/datasets/iamtarun/python_code_instructions_18k_alpaca"
90
   ]
91
  },
92
  {
93
   "cell_type": "markdown",
94
   "id": "99c914d4",
95
   "metadata": {},
96
   "source": [
97
    "Finally, the following code snippets convert the dataset into the JSONL format that NeMo defaults for PEFT. Meanwhile, we will reformat the data into list of (prompt, completion) pairs that our model can appropriately handle. Please refer to the printout for the original code instruction data format."
98
   ]
99
  },
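  {
   "cell_type": "markdown",
   "id": "d40ef001",
   "metadata": {},
   "source": [
    "For reference, each line of the JSONL files we are about to write will be a single JSON object with an \"input\" (the prompt) and an \"output\" (the completion). The record below is a hypothetical example that only shows the shape of the data; the real records are produced by the conversion code that follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d40ef002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical record showing the {\"input\": ..., \"output\": ...} shape used in the PEFT data files.\n",
    "example_record = {\n",
    "    \"input\": \"Write a Python function that adds two numbers.\",\n",
    "    \"output\": \"def add(a, b):\\n    return a + b\",\n",
    "}\n",
    "print(example_record)"
   ]
  },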
100
  {
101
   "cell_type": "code",
102
   "execution_count": null,
103
   "id": "bcc7f461",
104
   "metadata": {},
105
   "outputs": [],
106
   "source": [
107
    "import pandas as pd\n",
108
    "import glob\n",
109
    "from random import seed, shuffle\n",
110
    "from huggingface_hub import login\n",
111
    "\n",
112
    "login(token='your_huggingface_access_token')\n",
113
    "parquet_file_path = glob.glob('./python_code_instructions_18k_alpaca/data/*.parquet')\n",
114
    "parquet_file_list = ''.join(parquet_file_path)\n",
115
    "df = pd.read_parquet(parquet_file_list)\n",
116
    "instruct2code_list = df.to_dict('records')\n",
117
    "\n",
118
    "seed(2)\n",
119
    "val_percent = 5\n",
120
    "test_percent = 5\n",
121
    "instruct2code_list = instruct2code_list[:len(instruct2code_list)] \n",
122
    "num_train = int(len(instruct2code_list) * (100 - val_percent - test_percent) / 100)\n",
123
    "num_val = int(len(instruct2code_list)*(val_percent)/100)\n",
124
    "shuffle(instruct2code_list)\n",
125
    "\n",
126
    "instruct2code_list_train = instruct2code_list[:num_train]\n",
127
    "instruct2code_list_val = instruct2code_list[num_train : num_train + num_val]\n",
128
    "instruct2code_list_test = instruct2code_list[num_train + num_val:]\n",
129
    "print(f\"=== Input prompt example from the training split:\\n{instruct2code_list_train[5]['prompt']}\\n\") \n",
130
    "print(f\"=== Output completion example from the validation split:\\n{instruct2code_list_val[5]['output']}\")"
131
   ]
132
  },
133
  {
134
   "cell_type": "code",
135
   "execution_count": null,
136
   "id": "d30c4c0c",
137
   "metadata": {},
138
   "outputs": [],
139
   "source": [
140
    "\n",
141
    "import json\n",
142
    "def write_jsonl(fname, json_objs):\n",
143
    "    with open(fname, 'wt') as f:\n",
144
    "        for o in json_objs:\n",
145
    "            f.write(json.dumps(o)+\"\\n\")\n",
146
    "def form_instruction(pair):\n",
147
    "    outpout_loc = pair.find('### Output')\n",
148
    "    return(pair[:outpout_loc])\n",
149
    "def convert_to_jsonl(instruct2code_list, output_path):\n",
150
    "    json_objs = []\n",
151
    "    for pair in instruct2code_list:\n",
152
    "        prompt = form_instruction(pair['prompt'])\n",
153
    "        completion = pair['output']\n",
154
    "        json_objs.append({\"input\": prompt, \"output\": completion})\n",
155
    "    write_jsonl(output_path, json_objs)\n",
156
    "    return json_objs\n",
157
    "\n",
158
    "print(len(instruct2code_list_train))\n",
159
    "train_json_objs = convert_to_jsonl(instruct2code_list_train, \"alpaca_python_train.jsonl\")\n",
160
    "val_json_objs= convert_to_jsonl(instruct2code_list_val, \"alpaca_python_val.jsonl\")\n",
161
    "test_json_objs = convert_to_jsonl(instruct2code_list_test, \"alpaca_python_test.jsonl\")"
162
   ]
163
  },
164
  {
165
   "cell_type": "markdown",
166
   "id": "5284142f",
167
   "metadata": {},
168
   "source": [
169
    "Here's an example of what the data looks like after reformatting:"
170
   ]
171
  },
172
  {
173
   "cell_type": "code",
174
   "execution_count": null,
175
   "id": "1c892887",
176
   "metadata": {},
177
   "outputs": [],
178
   "source": [
179
    "train_json_objs[0]"
180
   ]
181
  },
182
  {
183
   "cell_type": "markdown",
184
   "id": "18b28a24",
185
   "metadata": {},
186
   "source": [
187
    "## LoRA Configuration and Training\n",
188
    "\n",
189
    "NeMo Framework provides support for configuration and training. To proceed with the training, you'll find a script at `/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py`. The script uses config parameters to control many of its operations. An example config file allows you to quickly see what options you can change and carry out different experiments. We can start by downloading the example config file, `megatron_gpt_peft_tuning_config.yaml` from github. The file is referenced to configure the parameters for the running PEFT training jobs in NeMo with LoRA technique for language model tuning. \n",
190
    "\n"
191
   ]
192
  },
193
  {
194
   "cell_type": "code",
195
   "execution_count": null,
196
   "id": "b40c60b1",
197
   "metadata": {},
198
   "outputs": [],
199
   "source": [
200
    "!wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml"
201
   ]
202
  },
203
  {
204
   "cell_type": "markdown",
205
   "id": "54e5428e",
206
   "metadata": {},
207
   "source": [
208
    "To see all of the different configuration options available, you can take a look at the file we downloaded. For this example, we're going to update a couple of settings to point to our newly-prepared datasets and to make sure the LoRA tuning runs on our A100. Feel free to experiment with these different options -- you can swap in your own datasets and change the training settings depending on what GPU you're using.\n",
209
    "\n",
210
    "For data our data configuration, we'll point to the `jsonl` files we wrote out earlier. `concat_sampling_probabilities` determines what percentage of the finetuning data you would like to come from each file -- in our example we only have 1 training file so we choose [1.0]"
211
   ]
212
  },
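  {
   "cell_type": "markdown",
   "id": "e50fa001",
   "metadata": {},
   "source": [
    "To browse those options without opening the file by hand, you can load the downloaded YAML with OmegaConf (available in the NeMo container) and print the sections we will override, such as the PEFT and training-data settings. The section names below come from the default config and may shift between NeMo releases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e50fa002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the default PEFT and training-data sections of the example config we downloaded.\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "cfg = OmegaConf.load(\"megatron_gpt_finetuning_config.yaml\")\n",
    "print(OmegaConf.to_yaml(cfg.model.peft))\n",
    "print(OmegaConf.to_yaml(cfg.model.data.train_ds))"
   ]
  },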
213
  {
214
   "cell_type": "markdown",
215
   "id": "89e40bd9",
216
   "metadata": {},
217
   "source": [
218
    "For our model settings, we don't have much to change since we're reading in a pretrained model and can inherit the values that were already set. We need to point to our existing `.nemo` file, specify that we want to use LoRA as our scheme for finetuning, and choose our parallelism and batch size values. The values below should be appropriate for a single A100 GPU.\n",
219
    "\n",
220
    "Make sure to change the `restore_from_path` setting with the path to the `.nemo` checkpoint!"
221
   ]
222
  },
223
  {
224
   "cell_type": "markdown",
225
   "id": "409717ce",
226
   "metadata": {},
227
   "source": [
228
    "Finally, we set some options for the `Trainer`. We'll be training on 1 GPU on a single node, at bfloat16 precision. For this example we'll train for 2000 steps, with a validation check every after every 200 iterations."
229
   ]
230
  },
231
  {
232
   "cell_type": "markdown",
233
   "id": "c91a7095",
234
   "metadata": {},
235
   "source": [
236
    "After setting the `Trainer` object configurations to handle our training loop, we set configurations for an experiment manager to handle checkpointing and logging. We can load our model from disk into memory. "
237
   ]
238
  },
239
  {
240
   "cell_type": "markdown",
241
   "id": "58478c4f",
242
   "metadata": {},
243
   "source": [
244
    "Now, let's see how to add the LoRA Adapter to our model and train it. We can specify that we want to use LoRA by using the `model.peft.peft_scheme` configuration to `lora`, which stores the types of applicable adapter and the hyperparameters required to initialize the adapter module.\n",
245
    "\n",
246
    "We're now ready to start training! As the training loop runs, you'll see the validation loss drop significantly -- even with this short demonstration."
247
   ]
248
  },
249
  {
250
   "cell_type": "code",
251
   "execution_count": null,
252
   "id": "3513c286",
253
   "metadata": {},
254
   "outputs": [],
255
   "source": [
256
    "%%bash\n",
257
    "\n",
258
    "PEFT_SCHEME='lora'\n",
259
    "MODEL_SIZE=7b\n",
260
    "MBS=1\n",
261
    "TP=1\n",
262
    "PP=1\n",
263
    "NUM_DEVICES=1\n",
264
    "GBS=8\n",
265
    "SEQ_LEN=4096\n",
266
    "\n",
267
    "EXTRA_ARGS=\"\n",
268
    "        +model.fp8=False \\\n",
269
    "        +model.fp8_e4m3=False \\\n",
270
    "        +model.fp8_hybrid=True \\\n",
271
    "        +model.fp8_margin=0 \\\n",
272
    "        +model.fp8_interval=1 \\\n",
273
    "        +model.fp8_amax_history_len=128 \\\n",
274
    "        +model.fp8_amax_compute_algo=max \"\n",
275
    "\n",
276
    "TRAIN_DS=[alpaca_python_train.jsonl]\n",
277
    "VALID_DS=[alpaca_python_val.jsonl]\n",
278
    "GBS=128\n",
279
    "PACKED=False\n",
280
    "MODEL=codegemma-7b_fromhf.nemo\n",
281
    "EXP_DIR=nemo_experiments\n",
282
    "    \n",
283
    "torchrun --nproc_per_node=1 \\\n",
284
    "/opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \\\n",
285
    "    trainer.devices=${NUM_DEVICES} \\\n",
286
    "        trainer.num_nodes=1 \\\n",
287
    "        trainer.val_check_interval=200 \\\n",
288
    "        trainer.max_steps=2000 \\\n",
289
    "        +trainer.num_sanity_val_steps=0 \\\n",
290
    "        +trainer.limit_val_batches=3 \\\n",
291
    "        model.megatron_amp_O2=True \\\n",
292
    "        exp_manager.resume_if_exists=False \\\n",
293
    "        exp_manager.exp_dir=\"${EXP_DIR}\" \\\n",
294
    "        exp_manager.checkpoint_callback_params.save_top_k=0 \\\n",
295
    "        model.tensor_model_parallel_size=${TP} \\\n",
296
    "        model.pipeline_model_parallel_size=${PP} \\\n",
297
    "        model.micro_batch_size=${MBS} \\\n",
298
    "        model.global_batch_size=${GBS} \\\n",
299
    "        model.restore_from_path=${MODEL} \\\n",
300
    "        model.data.train_ds.num_workers=0 \\\n",
301
    "        model.data.validation_ds.num_workers=0 \\\n",
302
    "        +model.data.train_ds.packed_sequence=${PACKED} \\\n",
303
    "        ++model.sequence_parallel=False \\\n",
304
    "        +model.log_token_counts=True \\\n",
305
    "        model.data.train_ds.file_names=${TRAIN_DS} \\\n",
306
    "        model.data.train_ds.concat_sampling_probabilities=[1.0] \\\n",
307
    "        model.data.validation_ds.file_names=${VALID_DS} \\\n",
308
    "        model.peft.peft_scheme=${PEFT_SCHEME} \\\n",
309
    "        model.peft.lora_tuning.target_modules=[attention_qkv] \\\n",
310
    "        model.data.train_ds.max_seq_length=${SEQ_LEN} \\\n",
311
    "        model.data.validation_ds.max_seq_length=${SEQ_LEN} \\\n",
312
    "        +model.apply_rope_fusion=True \\\n",
313
    "        ${EXTRA_ARGS} \\\n",
314
    "        trainer.precision=bf16 \\\n",
315
    "        model.answer_only_loss=True"
316
   ]
317
  },
318
  {
319
   "cell_type": "markdown",
320
   "id": "275984c7",
321
   "metadata": {},
322
   "source": [
323
    "Once training is completed you should see a saved '.nemo' file in the nemo_experiments folder. This checkpoint will only contain the trained adapter weights, and not the frozen base model weights."
324
   ]
325
  },
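  {
   "cell_type": "markdown",
   "id": "f60ab001",
   "metadata": {},
   "source": [
    "You can locate the adapter checkpoint with a quick search; the exact sub-directory depends on the experiment name NeMo assigns:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f60ab002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the LoRA adapter checkpoint produced by training (the path depends on the experiment name).\n",
    "!find nemo_experiments -name \"*.nemo\""
   ]
  },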
326
  {
327
   "cell_type": "markdown",
328
   "id": "21dabcf2",
329
   "metadata": {},
330
   "source": [
331
    "Next, we'll need to merge the weights of the base model and the weights of the adapter. If you're using the `NeMo Framework` container, you'll find a script for this at `/opt/NeMo/scripts/nlp_language_modeling/merge_lora_weights/merge.py`. Otherwise, you can download the standalone script from GitHub at https://raw.githubusercontent.com/NVIDIA/NeMo/main/scripts/nlp_language_modeling/merge_lora_weights/merge.py\n",
332
    "\n",
333
    "To merge weights using the merge script, you'll need the path to the base model and trained adapter, as well as a path to save the merged model to."
334
   ]
335
  },
336
  {
337
   "cell_type": "code",
338
   "execution_count": null,
339
   "id": "794cf054",
340
   "metadata": {},
341
   "outputs": [],
342
   "source": [
343
    "%%bash\n",
344
    "python /opt/NeMo/scripts/nlp_language_modeling/merge_lora_weights/merge.py \\\n",
345
    "    trainer.accelerator=gpu \\\n",
346
    "    tensor_model_parallel_size=1 \\\n",
347
    "    pipeline_model_parallel_size=1 \\\n",
348
    "    gpt_model_file=codegemma-7b_fromhf.nemo \\\n",
349
    "    lora_model_path=megatron_gpt_peft_lora_tuning.nemo \\\n",
350
    "    merged_model_path=gemma_lora_alpaca_python_merged.nemo"
351
   ]
352
  },
353
  {
354
   "cell_type": "markdown",
355
   "id": "09bbc702",
356
   "metadata": {},
357
   "source": [
358
    "With our merged model weights, we can run evaluation on test dataset using `megatron_gpt_peft_eval.py`. We set the Set the appropriate model checkpoint path, test file path, batch sizes, number of tokens etc. and run evaluation on the test file."
359
   ]
360
  },
361
  {
362
   "cell_type": "code",
363
   "execution_count": null,
364
   "id": "fb943ee7",
365
   "metadata": {},
366
   "outputs": [],
367
   "source": [
368
    "%%bash\n",
369
    "\n",
370
    "python /opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py \\\n",
371
    "    model.restore_from_path=gemma_lora_alpaca_python_merged.nemo \\\n",
372
    "    trainer.devices=1 \\\n",
373
    "    model.global_batch_size=8 \\\n",
374
    "    model.data.test_ds.file_names=[\"alpaca_python_test.jsonl\"] \\\n",
375
    "    model.data.test_ds.names=[\"alpaca_python_test_set\"] \\\n",
376
    "    model.data.test_ds.global_batch_size=8 \\\n",
377
    "    model.data.test_ds.micro_batch_size=1 \\\n",
378
    "    model.data.test_ds.tokens_to_generate=20 \\\n",
379
    "    model.tensor_model_parallel_size=1 \\\n",
380
    "    model.pipeline_model_parallel_size=1 \\\n",
381
    "    inference.greedy=True \\\n",
382
    "    model.data.test_ds.output_file_path_prefix=/results \\\n",
383
    "    model.data.test_ds.write_predictions_to_file=True"
384
   ]
385
  },
386
  {
387
   "cell_type": "markdown",
388
   "id": "579048ac",
389
   "metadata": {},
390
   "source": [
391
    "Check the output from the result file:"
392
   ]
393
  },
394
  {
395
   "cell_type": "code",
396
   "execution_count": null,
397
   "id": "55756a1c",
398
   "metadata": {},
399
   "outputs": [],
400
   "source": [
401
    "!tail -n 4 /results_test_alpaca_python_test_set_inputs_preds_labels.jsonl"
402
   ]
403
  },
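  {
   "cell_type": "markdown",
   "id": "a70bc001",
   "metadata": {},
   "source": [
    "To look at one record in a more readable way, you can load the first line of the predictions file and print each field. The snippet prints the keys rather than assuming specific names, since the exact field names can vary between NeMo versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a70bc002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretty-print the first record from the predictions file without assuming its key names.\n",
    "import json\n",
    "\n",
    "with open(\"/results_test_alpaca_python_test_set_inputs_preds_labels.jsonl\") as f:\n",
    "    record = json.loads(f.readline())\n",
    "for key, value in record.items():\n",
    "    print(f\"--- {key} ---\\n{value}\\n\")"
   ]
  },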
404
  {
405
   "cell_type": "markdown",
406
   "id": "389f2115",
407
   "metadata": {},
408
   "source": [
409
    "Note, This is only a sample output (based of a toy LoRA example) and your output may vary. The performance can be further improved by fine tuning the model for more steps."
410
   ]
411
  },
412
  {
413
   "cell_type": "markdown",
414
   "id": "9fa0eb3a",
415
   "metadata": {},
416
   "source": [
417
    "Finally, let's continue on to the \"Exporting to TensorRT-LLM\" section, to learn how to export our new model for optimized inference using TensorRT-LLM! "
418
   ]
419
  },
420
  {
421
   "cell_type": "markdown",
422
   "id": "86e36e7b",
423
   "metadata": {},
424
   "source": [
425
    "## Exporting to TensorRT-LLM\n",
426
    "\n",
427
    "TensorRT-LLM is an open-source library for optimizing inference performance to acheive state-of-the-art speed on NVDIA GPUs. The NeMo framework offers an easy way to compile .nemo models into optimized TensorRT-LLM engines which you can run locally embedded in another application, or serve to other applications using a server like Triton Inference Server.\n",
428
    "\n",
429
    "To start with, lets create a folder where our exported model will land"
430
   ]
431
  },
432
  {
433
   "cell_type": "code",
434
   "execution_count": null,
435
   "id": "3b32a604",
436
   "metadata": {},
437
   "outputs": [],
438
   "source": [
439
    "!mkdir codegemma_trt_llm"
440
   ]
441
  },
442
  {
443
   "cell_type": "markdown",
444
   "id": "a1f8ea73",
445
   "metadata": {},
446
   "source": [
447
    "With our merged model weights, we just need to create an instance of the TensorRTLLM class and call the TensorRTLLM.export() function -- pointing the nemo_checkpoint_path argument to the newly merged model from above.\n",
448
    "\n",
449
    "This creates a couple of files in the folder we created -- an engine file that holds the weights and the compiled execution graph of the model, a tokenizer.model file which holds the tokenizer information, and config.json which holds some metadata about the model (along with model.cache, which caches some operations and makes it faster to re-compile the model in the future.)"
450
   ]
451
  },
452
  {
453
   "cell_type": "code",
454
   "execution_count": null,
455
   "id": "f2eca7c1",
456
   "metadata": {},
457
   "outputs": [],
458
   "source": [
459
    "from nemo.export import TensorRTLLM\n",
460
    "trt_llm_exporter = TensorRTLLM(model_dir=\"gemma_alpaca_python_merged_trt_llm\")\n",
461
    "trt_llm_exporter.export(nemo_checkpoint_path=\"gemma_lora_alpaca_python_merged.nemo\", model_type=\"gemma\", n_gpus=1)"
462
   ]
463
  },
464
  {
465
   "cell_type": "markdown",
466
   "id": "2de9dde1",
467
   "metadata": {},
468
   "source": [
469
    "With the model exported into TensorRTLLM, we can perform very fast inference:"
470
   ]
471
  },
472
  {
473
   "cell_type": "code",
474
   "execution_count": null,
475
   "id": "072e9474",
476
   "metadata": {},
477
   "outputs": [],
478
   "source": [
479
    "trt_llm_exporter.forward([\"Implement Fibonacci sequence in Python\"])"
480
   ]
481
  },
482
  {
483
   "cell_type": "markdown",
484
   "id": "9b2f968a",
485
   "metadata": {},
486
   "source": [
487
    "There's also a convenient function to deploy a the model as a service, backed by Triton Inference Server:"
488
   ]
489
  },
490
  {
491
   "cell_type": "code",
492
   "execution_count": null,
493
   "id": "70fb870a",
494
   "metadata": {},
495
   "outputs": [],
496
   "source": [
497
    "from nemo.deploy import DeployPyTriton\n",
498
    "\n",
499
    "nm = DeployPyTriton(model=trt_llm_exporter, triton_model_name=\"gemma\")\n",
500
    "nm.deploy()\n",
501
    "nm.serve()"
502
   ]
503
  }
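,
  {
   "cell_type": "markdown",
   "id": "b80cd001",
   "metadata": {},
   "source": [
    "Note that `nm.serve()` blocks the current process while the Triton server is running. From a separate process or notebook, you can then send requests to the service. The sketch below uses NeMo's `NemoQuery` helper; the exact argument names (for example, the token-limit parameter) may differ between NeMo releases, so treat it as a starting point rather than a definitive client."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b80cd002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query the Triton-backed service from another process (argument names may vary by NeMo version).\n",
    "from nemo.deploy import NemoQuery\n",
    "\n",
    "nq = NemoQuery(url=\"localhost:8000\", model_name=\"gemma\")\n",
    "output = nq.query_llm(prompts=[\"Implement Fibonacci sequence in Python\"], max_output_token=200)\n",
    "print(output)"
   ]
  }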
504
 ],
505
 "metadata": {
506
  "kernelspec": {
507
   "display_name": "Python 3 (ipykernel)",
508
   "language": "python",
509
   "name": "python3"
510
  },
511
  "language_info": {
512
   "codemirror_mode": {
513
    "name": "ipython",
514
    "version": 3
515
   },
516
   "file_extension": ".py",
517
   "mimetype": "text/x-python",
518
   "name": "python",
519
   "nbconvert_exporter": "python",
520
   "pygments_lexer": "ipython3",
521
   "version": "3.10.12"
522
  }
523
 },
524
 "nbformat": 4,
525
 "nbformat_minor": 5
526
}
527
