{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a825ba6b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
      "================================================================================\n",
      "CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
      "CUDA SETUP: Detected CUDA version 117\n",
      "CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from peft import (\n",
    "    get_peft_config,\n",
    "    get_peft_model,\n",
    "    get_peft_model_state_dict,\n",
    "    set_peft_model_state_dict,\n",
    "    PeftType,\n",
    "    PrefixTuningConfig,\n",
    "    PromptEncoderConfig,\n",
    ")\n",
    "\n",
    "import evaluate\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2bd7cbb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "model_name_or_path = \"roberta-large\"\n",
    "task = \"mrpc\"\n",
    "peft_type = PeftType.P_TUNING\n",
    "device = \"cuda\"\n",
    "num_epochs = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "33d9b62e",
   "metadata": {},
   "outputs": [],
   "source": [
    "peft_config = PromptEncoderConfig(task_type=\"SEQ_CLS\", num_virtual_tokens=20, encoder_hidden_size=128)\n",
    "lr = 1e-3"
   ]
  },
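  {
   "cell_type": "markdown",
   "id": "p-tuning-config-note",
   "metadata": {},
   "source": [
    "In the config above, `num_virtual_tokens` is the number of trainable virtual prompt tokens and `encoder_hidden_size` is the hidden size of the small prompt encoder that P-tuning trains on top of the frozen base model. Since `PrefixTuningConfig` is already imported, the same notebook can be switched to prefix tuning by swapping in the config sketched below (the values shown are illustrative, not tuned)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "prefix-tuning-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: swap P-tuning for prefix tuning by replacing the config.\n",
    "# The hyperparameter values here are illustrative, not tuned.\n",
    "# peft_config = PrefixTuningConfig(task_type=\"SEQ_CLS\", num_virtual_tokens=20)"
   ]
  },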
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "152b6177",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a451b90675e0451489cc6426465afa32",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fa7887f9eaa03ae.arrow\n",
      "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dc593149bbeafe80.arrow\n",
      "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-140ebe5b70e09817.arrow\n"
     ]
    }
   ],
   "source": [
    "if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
    "    padding_side = \"left\"\n",
    "else:\n",
    "    padding_side = \"right\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
    "if getattr(tokenizer, \"pad_token_id\") is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "datasets = load_dataset(\"glue\", task)\n",
    "metric = evaluate.load(\"glue\", task)\n",
    "\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    # max_length=None => use the model max length (it's actually the default)\n",
    "    outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
    "    return outputs\n",
    "\n",
    "\n",
    "tokenized_datasets = datasets.map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
    ")\n",
    "\n",
    "# Rename the 'label' column to 'labels', which is the column name the transformers\n",
    "# models expect.\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
    "\n",
    "\n",
    "def collate_fn(examples):\n",
    "    return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
    "\n",
    "\n",
    "# Instantiate dataloaders.\n",
    "train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
    "eval_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
    ")"
   ]
  },
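  {
   "cell_type": "markdown",
   "id": "batch-shape-note",
   "metadata": {},
   "source": [
    "Optional sanity check (not part of the original run): because `collate_fn` uses `padding=\"longest\"`, each batch is padded only to the longest sequence it contains, so tensor shapes vary from batch to batch. The quick peek below confirms this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "batch-shape-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at one collated batch to confirm the dynamic padding (shapes depend on the sampled batch).\n",
    "sample_batch = next(iter(train_dataloader))\n",
    "print({k: v.shape for k, v in sample_batch.items()})"
   ]
  },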
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6bc8144",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "af41c571",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = AdamW(params=model.parameters(), lr=lr)\n",
    "\n",
    "# Instantiate scheduler\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,  # 0.06*(len(train_dataloader) * num_epochs),\n",
    "    num_training_steps=(len(train_dataloader) * num_epochs),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "90993c93",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                  | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:32<00:00,  3.54it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0: {'accuracy': 0.6985294117647058, 'f1': 0.8172362555720655}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:31<00:00,  3.61it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1: {'accuracy': 0.6936274509803921, 'f1': 0.806201550387597}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:31<00:00,  3.61it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2: {'accuracy': 0.7132352941176471, 'f1': 0.8224582701062216}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:31<00:00,  3.61it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3: {'accuracy': 0.7083333333333334, 'f1': 0.8199697428139183}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:31<00:00,  3.61it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4: {'accuracy': 0.7205882352941176, 'f1': 0.8246153846153846}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:31<00:00,  3.62it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5: {'accuracy': 0.7009803921568627, 'f1': 0.8200589970501474}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:32<00:00,  3.59it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6: {'accuracy': 0.7254901960784313, 'f1': 0.8292682926829268}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:31<00:00,  3.60it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7: {'accuracy': 0.7230392156862745, 'f1': 0.8269525267993874}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:34<00:00,  3.34it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8: {'accuracy': 0.7254901960784313, 'f1': 0.8297872340425533}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:31<00:00,  3.60it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9: {'accuracy': 0.7230392156862745, 'f1': 0.828006088280061}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:32<00:00,  3.58it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10: {'accuracy': 0.7181372549019608, 'f1': 0.8183254344391785}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:31<00:00,  3.60it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11: {'accuracy': 0.7132352941176471, 'f1': 0.803361344537815}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:31<00:00,  3.59it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12: {'accuracy': 0.7107843137254902, 'f1': 0.8206686930091186}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:32<00:00,  3.59it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13: {'accuracy': 0.7181372549019608, 'f1': 0.8254931714719272}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:32<00:00,  3.59it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14: {'accuracy': 0.7156862745098039, 'f1': 0.8253012048192772}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:32<00:00,  3.59it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15: {'accuracy': 0.7230392156862745, 'f1': 0.8242612752721618}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:32<00:00,  3.49it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  5.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16: {'accuracy': 0.7181372549019608, 'f1': 0.8200312989045383}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:32<00:00,  3.49it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17: {'accuracy': 0.7107843137254902, 'f1': 0.8217522658610272}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:31<00:00,  3.60it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18: {'accuracy': 0.7254901960784313, 'f1': 0.8292682926829268}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:31<00:00,  3.61it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19: {'accuracy': 0.7107843137254902, 'f1': 0.8206686930091186}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model.to(device)\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    for step, batch in enumerate(tqdm(train_dataloader)):\n",
    "        batch.to(device)\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    model.eval()\n",
    "    for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "        batch.to(device)\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "        predictions = outputs.logits.argmax(dim=-1)\n",
    "        predictions, references = predictions, batch[\"labels\"]\n",
    "        metric.add_batch(\n",
    "            predictions=predictions,\n",
    "            references=references,\n",
    "        )\n",
    "\n",
    "    eval_metric = metric.compute()\n",
    "    print(f\"epoch {epoch}:\", eval_metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a43bd9fb",
   "metadata": {},
   "source": [
    "## Share adapters on the 🤗 Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "871b75aa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-p-tuning/commit/fa7abe613f498c76df5e16c85d9c19c3019587a7', commit_message='Upload model', commit_description='', oid='fa7abe613f498c76df5e16c85d9c19c3019587a7', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.push_to_hub(\"smangrul/roberta-large-peft-p-tuning\", use_auth_token=True)"
   ]
  },
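  {
   "cell_type": "markdown",
   "id": "local-save-note",
   "metadata": {},
   "source": [
    "Pushing to the Hub is optional. Calling `save_pretrained` on the PEFT-wrapped model stores only the small prompt-encoder/adapter weights and their config locally; the directory name below is illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "local-save-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternatively, save just the PEFT weights to a local directory (path is illustrative).\n",
    "model.save_pretrained(\"roberta-large-peft-p-tuning\")"
   ]
  },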
  {
   "cell_type": "markdown",
   "id": "1c6a9036",
   "metadata": {},
   "source": [
    "## Load adapters from the Hub\n",
    "\n",
    "You can also directly load adapters from the Hub using the commands below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "91b0b8f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'roberta.pooler.dense.weight', 'lm_head.layer_norm.weight', 'roberta.pooler.dense.bias', 'lm_head.dense.weight', 'lm_head.bias']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e650799d58ec4bd1b21b6bc28ddf2069",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/4.29M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                   | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  7.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'accuracy': 0.7107843137254902, 'f1': 0.8206686930091186}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from peft import PeftModel, PeftConfig\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "\n",
    "peft_model_id = \"smangrul/roberta-large-peft-p-tuning\"\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
    "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
    "\n",
    "# Load the P-tuning weights on top of the base model\n",
    "inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
    "\n",
    "inference_model.to(device)\n",
    "inference_model.eval()\n",
    "for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "    batch.to(device)\n",
    "    with torch.no_grad():\n",
    "        outputs = inference_model(**batch)\n",
    "    predictions = outputs.logits.argmax(dim=-1)\n",
    "    predictions, references = predictions, batch[\"labels\"]\n",
    "    metric.add_batch(\n",
    "        predictions=predictions,\n",
    "        references=references,\n",
    "    )\n",
    "\n",
    "eval_metric = metric.compute()\n",
    "print(eval_metric)"
   ]
  },
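  {
   "cell_type": "markdown",
   "id": "single-pair-note",
   "metadata": {},
   "source": [
    "As a final check, the loaded adapter can score a single sentence pair directly. The sentences below are illustrative; for MRPC, a predicted label of 1 means the pair is a paraphrase."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "single-pair-inference",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classify one illustrative sentence pair with the loaded PEFT model.\n",
    "inputs = tokenizer(\n",
    "    \"The company posted record profits this quarter.\",\n",
    "    \"Record profits were reported by the company this quarter.\",\n",
    "    return_tensors=\"pt\",\n",
    ").to(device)\n",
    "with torch.no_grad():\n",
    "    logits = inference_model(**inputs).logits\n",
    "print(logits.argmax(dim=-1).item())  # 1 -> paraphrase, 0 -> not a paraphrase"
   ]
  },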
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a8d69d1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5 (v3.10.5:f377153967, Jun  6 2022, 12:36:10) [Clang 13.0.0 (clang-1300.0.29.30)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}