peft

Форк
0
/
Prompt_Tuning.ipynb 
692 строки · 31.0 Кб
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": 1,
6
   "id": "9ff5004e",
7
   "metadata": {},
8
   "outputs": [
9
    {
10
     "name": "stdout",
11
     "output_type": "stream",
12
     "text": [
13
      "\n",
14
      "===================================BUG REPORT===================================\n",
15
      "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
16
      "For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
17
      "================================================================================\n",
18
      "CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
19
      "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
20
      "CUDA SETUP: Detected CUDA version 117\n",
21
      "CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
22
     ]
23
    }
24
   ],
25
   "source": [
26
    "import argparse\n",
27
    "import os\n",
28
    "\n",
29
    "import torch\n",
30
    "from torch.optim import AdamW\n",
31
    "from torch.utils.data import DataLoader\n",
32
    "from peft import (\n",
33
    "    get_peft_config,\n",
34
    "    get_peft_model,\n",
35
    "    get_peft_model_state_dict,\n",
36
    "    set_peft_model_state_dict,\n",
37
    "    PeftType,\n",
38
    "    PrefixTuningConfig,\n",
39
    "    PromptEncoderConfig,\n",
40
    "    PromptTuningConfig,\n",
41
    ")\n",
42
    "\n",
43
    "import evaluate\n",
44
    "from datasets import load_dataset\n",
45
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
46
    "from tqdm import tqdm"
47
   ]
48
  },
49
  {
50
   "cell_type": "code",
51
   "execution_count": 2,
52
   "id": "e32c4a9e",
53
   "metadata": {},
54
   "outputs": [],
55
   "source": [
56
    "batch_size = 32\n",
    "model_name_or_path = \"roberta-large\"\n",
    "task = \"mrpc\"\n",
    "peft_type = PeftType.PROMPT_TUNING\n",
    "# Fall back to CPU so the notebook can still run (slowly) on machines without a CUDA GPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "num_epochs = 20\n",
    "\n",
    "# `set_seed` (imported in the first cell) seeds the Python/NumPy/PyTorch RNGs so the freshly\n",
    "# initialized classifier head and the shuffled train dataloader are reproducible across runs.\n",
    "seed = 42\n",
    "set_seed(seed)"
62
   ]
63
  },
64
  {
65
   "cell_type": "code",
66
   "execution_count": 3,
67
   "id": "622fe9c8",
68
   "metadata": {},
69
   "outputs": [],
70
   "source": [
71
    "# Prompt-tuning adapter config: 10 virtual tokens, sequence-classification task.\n",
    "peft_config = PromptTuningConfig(\n",
    "    task_type=\"SEQ_CLS\",\n",
    "    num_virtual_tokens=10,\n",
    ")\n",
    "lr = 1e-3"
73
   ]
74
  },
75
  {
76
   "cell_type": "code",
77
   "execution_count": 4,
78
   "id": "74e9efe0",
79
   "metadata": {},
80
   "outputs": [
81
    {
82
     "name": "stderr",
83
     "output_type": "stream",
84
     "text": [
85
      "Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
86
     ]
87
    },
88
    {
89
     "data": {
90
      "application/vnd.jupyter.widget-view+json": {
91
       "model_id": "76198cec552441818ff107910275e5be",
92
       "version_major": 2,
93
       "version_minor": 0
94
      },
95
      "text/plain": [
96
       "  0%|          | 0/3 [00:00<?, ?it/s]"
97
      ]
98
     },
99
     "metadata": {},
100
     "output_type": "display_data"
101
    },
102
    {
103
     "name": "stderr",
104
     "output_type": "stream",
105
     "text": [
106
      "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fa7887f9eaa03ae.arrow\n",
107
      "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dc593149bbeafe80.arrow\n",
108
      "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-140ebe5b70e09817.arrow\n"
109
     ]
110
    }
111
   ],
112
   "source": [
113
    "# GPT/OPT/BLOOM-style checkpoints are padded on the left; all other models on the right.\n",
    "if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
    "    padding_side = \"left\"\n",
    "else:\n",
    "    padding_side = \"right\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
    "# Some tokenizers define no pad token; reuse EOS so `collate_fn` can pad batches.\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "# Named `raw_datasets` (not `datasets`) so it does not shadow the `datasets` library name.\n",
    "raw_datasets = load_dataset(\"glue\", task)\n",
    "metric = evaluate.load(\"glue\", task)\n",
    "\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize a batch of (sentence1, sentence2) pairs with truncation; padding is left to `collate_fn`.\"\"\"\n",
    "    # max_length=None => use the model max length (it's actually the default)\n",
    "    outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
    "    return outputs\n",
    "\n",
    "\n",
    "tokenized_datasets = raw_datasets.map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
    ")\n",
    "\n",
    "# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
    "# transformers library\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
    "\n",
    "\n",
    "def collate_fn(examples):\n",
    "    \"\"\"Pad a list of tokenized examples to the longest one in the batch and return PyTorch tensors.\"\"\"\n",
    "    return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
    "\n",
    "\n",
    "# Instantiate dataloaders.\n",
    "train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
    "eval_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
    ")"
152
   ]
153
  },
154
  {
155
   "cell_type": "code",
156
   "execution_count": null,
157
   "id": "a3c15af0",
158
   "metadata": {},
159
   "outputs": [],
160
   "source": [
161
    "model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
    "# Decoder-style classification heads (e.g. GPT-2's) use `config.pad_token_id` to find the last real\n",
    "# token; when the tokenizer cell borrowed EOS as the pad token, mirror that id into the model config.\n",
    "if model.config.pad_token_id is None:\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()\n",
    "model"
165
   ]
166
  },
167
  {
168
   "cell_type": "code",
169
   "execution_count": 6,
170
   "id": "6d3c5edb",
171
   "metadata": {},
172
   "outputs": [],
173
   "source": [
174
    "optimizer = AdamW(params=model.parameters(), lr=lr)\n",
    "\n",
    "# Instantiate scheduler: linear warmup over the first 6% of all steps, then linear decay.\n",
    "num_training_steps = len(train_dataloader) * num_epochs\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    # the scheduler takes an integer step count; 0.06 * steps is a float\n",
    "    num_warmup_steps=int(0.06 * num_training_steps),\n",
    "    num_training_steps=num_training_steps,\n",
    ")"
182
   ]
183
  },
184
  {
185
   "cell_type": "code",
186
   "execution_count": 7,
187
   "id": "4d279225",
188
   "metadata": {},
189
   "outputs": [
190
    {
191
     "name": "stderr",
192
     "output_type": "stream",
193
     "text": [
194
      "  0%|                                                                                                  | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
195
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [02:09<00:00,  1.13s/it]\n",
196
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00,  1.62it/s]\n"
197
     ]
198
    },
199
    {
200
     "name": "stdout",
201
     "output_type": "stream",
202
     "text": [
203
      "epoch 0: {'accuracy': 0.678921568627451, 'f1': 0.7956318252730109}\n"
204
     ]
205
    },
206
    {
207
     "name": "stderr",
208
     "output_type": "stream",
209
     "text": [
210
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:50<00:00,  1.04it/s]\n",
211
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.22it/s]\n"
212
     ]
213
    },
214
    {
215
     "name": "stdout",
216
     "output_type": "stream",
217
     "text": [
218
      "epoch 1: {'accuracy': 0.696078431372549, 'f1': 0.8171091445427728}\n"
219
     ]
220
    },
221
    {
222
     "name": "stderr",
223
     "output_type": "stream",
224
     "text": [
225
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00,  1.19it/s]\n",
226
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:06<00:00,  2.00it/s]\n"
227
     ]
228
    },
229
    {
230
     "name": "stdout",
231
     "output_type": "stream",
232
     "text": [
233
      "epoch 2: {'accuracy': 0.6985294117647058, 'f1': 0.8161434977578476}\n"
234
     ]
235
    },
236
    {
237
     "name": "stderr",
238
     "output_type": "stream",
239
     "text": [
240
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:37<00:00,  1.18it/s]\n",
241
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:06<00:00,  2.09it/s]\n"
242
     ]
243
    },
244
    {
245
     "name": "stdout",
246
     "output_type": "stream",
247
     "text": [
248
      "epoch 3: {'accuracy': 0.7058823529411765, 'f1': 0.7979797979797979}\n"
249
     ]
250
    },
251
    {
252
     "name": "stderr",
253
     "output_type": "stream",
254
     "text": [
255
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [02:03<00:00,  1.07s/it]\n",
256
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:07<00:00,  1.71it/s]\n"
257
     ]
258
    },
259
    {
260
     "name": "stdout",
261
     "output_type": "stream",
262
     "text": [
263
      "epoch 4: {'accuracy': 0.696078431372549, 'f1': 0.8132530120481929}\n"
264
     ]
265
    },
266
    {
267
     "name": "stderr",
268
     "output_type": "stream",
269
     "text": [
270
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:53<00:00,  1.01it/s]\n",
271
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.19it/s]\n"
272
     ]
273
    },
274
    {
275
     "name": "stdout",
276
     "output_type": "stream",
277
     "text": [
278
      "epoch 5: {'accuracy': 0.7107843137254902, 'f1': 0.8121019108280254}\n"
279
     ]
280
    },
281
    {
282
     "name": "stderr",
283
     "output_type": "stream",
284
     "text": [
285
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:35<00:00,  1.20it/s]\n",
286
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.20it/s]\n"
287
     ]
288
    },
289
    {
290
     "name": "stdout",
291
     "output_type": "stream",
292
     "text": [
293
      "epoch 6: {'accuracy': 0.6911764705882353, 'f1': 0.7692307692307693}\n"
294
     ]
295
    },
296
    {
297
     "name": "stderr",
298
     "output_type": "stream",
299
     "text": [
300
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00,  1.20it/s]\n",
301
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.18it/s]\n"
302
     ]
303
    },
304
    {
305
     "name": "stdout",
306
     "output_type": "stream",
307
     "text": [
308
      "epoch 7: {'accuracy': 0.7156862745098039, 'f1': 0.8209876543209876}\n"
309
     ]
310
    },
311
    {
312
     "name": "stderr",
313
     "output_type": "stream",
314
     "text": [
315
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:35<00:00,  1.20it/s]\n",
316
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.22it/s]\n"
317
     ]
318
    },
319
    {
320
     "name": "stdout",
321
     "output_type": "stream",
322
     "text": [
323
      "epoch 8: {'accuracy': 0.7205882352941176, 'f1': 0.8240740740740742}\n"
324
     ]
325
    },
326
    {
327
     "name": "stderr",
328
     "output_type": "stream",
329
     "text": [
330
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00,  1.19it/s]\n",
331
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.21it/s]\n"
332
     ]
333
    },
334
    {
335
     "name": "stdout",
336
     "output_type": "stream",
337
     "text": [
338
      "epoch 9: {'accuracy': 0.7205882352941176, 'f1': 0.8229813664596273}\n"
339
     ]
340
    },
341
    {
342
     "name": "stderr",
343
     "output_type": "stream",
344
     "text": [
345
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00,  1.20it/s]\n",
346
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.35it/s]\n"
347
     ]
348
    },
349
    {
350
     "name": "stdout",
351
     "output_type": "stream",
352
     "text": [
353
      "epoch 10: {'accuracy': 0.7156862745098039, 'f1': 0.8164556962025317}\n"
354
     ]
355
    },
356
    {
357
     "name": "stderr",
358
     "output_type": "stream",
359
     "text": [
360
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:35<00:00,  1.20it/s]\n",
361
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.22it/s]\n"
362
     ]
363
    },
364
    {
365
     "name": "stdout",
366
     "output_type": "stream",
367
     "text": [
368
      "epoch 11: {'accuracy': 0.7058823529411765, 'f1': 0.8113207547169811}\n"
369
     ]
370
    },
371
    {
372
     "name": "stderr",
373
     "output_type": "stream",
374
     "text": [
375
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:32<00:00,  1.24it/s]\n",
376
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.48it/s]\n"
377
     ]
378
    },
379
    {
380
     "name": "stdout",
381
     "output_type": "stream",
382
     "text": [
383
      "epoch 12: {'accuracy': 0.7009803921568627, 'f1': 0.7946127946127945}\n"
384
     ]
385
    },
386
    {
387
     "name": "stderr",
388
     "output_type": "stream",
389
     "text": [
390
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:32<00:00,  1.24it/s]\n",
391
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.38it/s]\n"
392
     ]
393
    },
394
    {
395
     "name": "stdout",
396
     "output_type": "stream",
397
     "text": [
398
      "epoch 13: {'accuracy': 0.7230392156862745, 'f1': 0.8186195826645265}\n"
399
     ]
400
    },
401
    {
402
     "name": "stderr",
403
     "output_type": "stream",
404
     "text": [
405
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:29<00:00,  1.29it/s]\n",
406
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.31it/s]\n"
407
     ]
408
    },
409
    {
410
     "name": "stdout",
411
     "output_type": "stream",
412
     "text": [
413
      "epoch 14: {'accuracy': 0.7058823529411765, 'f1': 0.8130841121495327}\n"
414
     ]
415
    },
416
    {
417
     "name": "stderr",
418
     "output_type": "stream",
419
     "text": [
420
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:30<00:00,  1.27it/s]\n",
421
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.39it/s]\n"
422
     ]
423
    },
424
    {
425
     "name": "stdout",
426
     "output_type": "stream",
427
     "text": [
428
      "epoch 15: {'accuracy': 0.7181372549019608, 'f1': 0.8194662480376768}\n"
429
     ]
430
    },
431
    {
432
     "name": "stderr",
433
     "output_type": "stream",
434
     "text": [
435
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:28<00:00,  1.29it/s]\n",
436
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.35it/s]\n"
437
     ]
438
    },
439
    {
440
     "name": "stdout",
441
     "output_type": "stream",
442
     "text": [
443
      "epoch 16: {'accuracy': 0.7254901960784313, 'f1': 0.8181818181818181}\n"
444
     ]
445
    },
446
    {
447
     "name": "stderr",
448
     "output_type": "stream",
449
     "text": [
450
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:30<00:00,  1.27it/s]\n",
451
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.30it/s]\n"
452
     ]
453
    },
454
    {
455
     "name": "stdout",
456
     "output_type": "stream",
457
     "text": [
458
      "epoch 17: {'accuracy': 0.7205882352941176, 'f1': 0.820754716981132}\n"
459
     ]
460
    },
461
    {
462
     "name": "stderr",
463
     "output_type": "stream",
464
     "text": [
465
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:30<00:00,  1.27it/s]\n",
466
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.36it/s]\n"
467
     ]
468
    },
469
    {
470
     "name": "stdout",
471
     "output_type": "stream",
472
     "text": [
473
      "epoch 18: {'accuracy': 0.7254901960784313, 'f1': 0.821656050955414}\n"
474
     ]
475
    },
476
    {
477
     "name": "stderr",
478
     "output_type": "stream",
479
     "text": [
480
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:28<00:00,  1.29it/s]\n",
481
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.43it/s]"
482
     ]
483
    },
484
    {
485
     "name": "stdout",
486
     "output_type": "stream",
487
     "text": [
488
      "epoch 19: {'accuracy': 0.7303921568627451, 'f1': 0.8242811501597445}\n"
489
     ]
490
    },
491
    {
492
     "name": "stderr",
493
     "output_type": "stream",
494
     "text": [
495
      "\n"
496
     ]
497
    }
498
   ],
499
   "source": [
500
    "model.to(device)\n",
    "for epoch in range(num_epochs):\n",
    "    # --- train: one pass over the training set ---\n",
    "    model.train()\n",
    "    for batch in tqdm(train_dataloader):\n",
    "        batch = batch.to(device)\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    # --- evaluate: accumulate argmax predictions for the GLUE `task` metric ---\n",
    "    model.eval()\n",
    "    for batch in tqdm(eval_dataloader):\n",
    "        batch = batch.to(device)\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "        predictions = outputs.logits.argmax(dim=-1)\n",
    "        metric.add_batch(\n",
    "            predictions=predictions,\n",
    "            references=batch[\"labels\"],\n",
    "        )\n",
    "\n",
    "    # compute() aggregates the batches added via add_batch during this epoch\n",
    "    eval_metric = metric.compute()\n",
    "    print(f\"epoch {epoch}:\", eval_metric)"
526
   ]
527
  },
528
  {
529
   "cell_type": "markdown",
530
   "id": "e1ff3f44",
531
   "metadata": {},
532
   "source": [
533
    "## Share adapters on the 🤗 Hub"
534
   ]
535
  },
536
  {
537
   "cell_type": "code",
538
   "execution_count": 8,
539
   "id": "0bf79cb5",
540
   "metadata": {},
541
   "outputs": [
542
    {
543
     "data": {
544
      "text/plain": [
545
       "CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-prompt-tuning/commit/893a909d8499aa8778d58c781d43c3a8d9360de8', commit_message='Upload model', commit_description='', oid='893a909d8499aa8778d58c781d43c3a8d9360de8', pr_url=None, pr_revision=None, pr_num=None)"
546
      ]
547
     },
548
     "execution_count": 8,
549
     "metadata": {},
550
     "output_type": "execute_result"
551
    }
552
   ],
553
   "source": [
554
    "# Upload the trained prompt-tuning adapter to the Hub repo below; `use_auth_token=True` is meant to use\n",
    "# the locally stored Hub login token. NOTE(review): the repo id is hard-coded to the author's namespace.\n",
    "model.push_to_hub(\"smangrul/roberta-large-peft-prompt-tuning\", use_auth_token=True)"
555
   ]
556
  },
557
  {
558
   "cell_type": "markdown",
559
   "id": "73870ad7",
560
   "metadata": {},
561
   "source": [
562
    "## Load adapters from the Hub\n",
563
    "\n",
564
    "You can also directly load adapters from the Hub using the commands below:"
565
   ]
566
  },
567
  {
568
   "cell_type": "code",
569
   "execution_count": 9,
570
   "id": "0654a552",
571
   "metadata": {},
572
   "outputs": [
573
    {
574
     "data": {
575
      "application/vnd.jupyter.widget-view+json": {
576
       "model_id": "24581bb98582444ca6114b9fa267847f",
577
       "version_major": 2,
578
       "version_minor": 0
579
      },
580
      "text/plain": [
581
       "Downloading:   0%|          | 0.00/368 [00:00<?, ?B/s]"
582
      ]
583
     },
584
     "metadata": {},
585
     "output_type": "display_data"
586
    },
587
    {
588
     "name": "stderr",
589
     "output_type": "stream",
590
     "text": [
591
      "Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']\n",
592
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
593
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
594
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']\n",
595
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
596
     ]
597
    },
598
    {
599
     "data": {
600
      "application/vnd.jupyter.widget-view+json": {
601
       "model_id": "f1584da4d1c54cc3873a515182674980",
602
       "version_major": 2,
603
       "version_minor": 0
604
      },
605
      "text/plain": [
606
       "Downloading:   0%|          | 0.00/4.25M [00:00<?, ?B/s]"
607
      ]
608
     },
609
     "metadata": {},
610
     "output_type": "display_data"
611
    },
612
    {
613
     "name": "stderr",
614
     "output_type": "stream",
615
     "text": [
616
      "  0%|                                                                                                   | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
617
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.58it/s]"
618
     ]
619
    },
620
    {
621
     "name": "stdout",
622
     "output_type": "stream",
623
     "text": [
624
      "{'accuracy': 0.7303921568627451, 'f1': 0.8242811501597445}\n"
625
     ]
626
    },
627
    {
628
     "name": "stderr",
629
     "output_type": "stream",
630
     "text": [
631
      "\n"
632
     ]
633
    }
634
   ],
635
   "source": [
636
    "import torch\n",
    "from peft import PeftModel, PeftConfig\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "peft_model_id = \"smangrul/roberta-large-peft-prompt-tuning\"\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "\n",
    "# `eval_dataloader`'s `collate_fn` pads with the global `tokenizer`, so rebuild it exactly like the\n",
    "# data cell did: same padding side and the same EOS fallback for a missing pad token.\n",
    "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path, padding_side=padding_side)\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
    "if inference_model.config.pad_token_id is None:\n",
    "    inference_model.config.pad_token_id = tokenizer.pad_token_id\n",
    "\n",
    "# Load the prompt-tuning adapter on top of the base model\n",
    "inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
    "\n",
    "inference_model.to(device)\n",
    "inference_model.eval()\n",
    "for batch in tqdm(eval_dataloader):\n",
    "    batch = batch.to(device)\n",
    "    with torch.no_grad():\n",
    "        outputs = inference_model(**batch)\n",
    "    predictions = outputs.logits.argmax(dim=-1)\n",
    "    metric.add_batch(\n",
    "        predictions=predictions,\n",
    "        references=batch[\"labels\"],\n",
    "    )\n",
    "\n",
    "eval_metric = metric.compute()\n",
    "print(eval_metric)"
663
   ]
664
  }
665
 ],
666
 "metadata": {
667
  "kernelspec": {
668
   "display_name": "Python 3 (ipykernel)",
669
   "language": "python",
670
   "name": "python3"
671
  },
672
  "language_info": {
673
   "codemirror_mode": {
674
    "name": "ipython",
675
    "version": 3
676
   },
677
   "file_extension": ".py",
678
   "mimetype": "text/x-python",
679
   "name": "python",
680
   "nbconvert_exporter": "python",
681
   "pygments_lexer": "ipython3",
682
   "version": "3.10.4"
683
  },
684
  "vscode": {
685
   "interpreter": {
686
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
687
   }
688
  }
689
 },
690
 "nbformat": 4,
691
 "nbformat_minor": 5
692
}
693

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.