peft
692 строки · 31.0 Кб
1{
2"cells": [
3{
4"cell_type": "code",
5"execution_count": 1,
6"id": "9ff5004e",
7"metadata": {},
8"outputs": [
9{
10"name": "stdout",
11"output_type": "stream",
12"text": [
13"\n",
14"===================================BUG REPORT===================================\n",
15"Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
16"For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
17"================================================================================\n",
18"CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
19"CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
20"CUDA SETUP: Detected CUDA version 117\n",
21"CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
22]
23}
24],
25"source": [
"import argparse\n",
"import math\n",
"import os\n",
"\n",
"import evaluate\n",
"import torch\n",
"from datasets import load_dataset\n",
"from peft import (\n",
"    PeftConfig,\n",
"    PeftModel,\n",
"    PeftType,\n",
"    PrefixTuningConfig,\n",
"    PromptEncoderConfig,\n",
"    PromptTuningConfig,\n",
"    get_peft_config,\n",
"    get_peft_model,\n",
"    get_peft_model_state_dict,\n",
"    set_peft_model_state_dict,\n",
")\n",
"from torch.optim import AdamW\n",
"from torch.utils.data import DataLoader\n",
"from tqdm import tqdm\n",
"from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed"
47]
48},
49{
50"cell_type": "code",
51"execution_count": 2,
52"id": "e32c4a9e",
53"metadata": {},
54"outputs": [],
55"source": [
"# Run configuration: every value a reader may want to tune lives in this cell.\n",
"batch_size = 32\n",
"model_name_or_path = \"roberta-large\"\n",
"task = \"mrpc\"  # GLUE subset used for training and evaluation\n",
"peft_type = PeftType.PROMPT_TUNING\n",
"# Fall back to CPU so the notebook still runs (slowly) on a machine without CUDA.\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"num_epochs = 20\n",
"seed = 42\n",
"\n",
"# Seed before the model is created and the train dataloader shuffles, so the randomly\n",
"# initialised classification head and the batch order are reproducible across runs.\n",
"set_seed(seed)"
62]
63},
64{
65"cell_type": "code",
66"execution_count": 3,
67"id": "622fe9c8",
68"metadata": {},
69"outputs": [],
70"source": [
"# Learning rate handed to the AdamW optimizer further down.\n",
"lr = 1e-3\n",
"\n",
"# PEFT method: prompt tuning with 10 virtual tokens for a sequence-classification task.\n",
"peft_config = PromptTuningConfig(\n",
"    task_type=\"SEQ_CLS\",\n",
"    num_virtual_tokens=10,\n",
")"
73]
74},
75{
76"cell_type": "code",
77"execution_count": 4,
78"id": "74e9efe0",
79"metadata": {},
80"outputs": [
81{
82"name": "stderr",
83"output_type": "stream",
84"text": [
85"Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
86]
87},
88{
89"data": {
90"application/vnd.jupyter.widget-view+json": {
91"model_id": "76198cec552441818ff107910275e5be",
92"version_major": 2,
93"version_minor": 0
94},
95"text/plain": [
96" 0%| | 0/3 [00:00<?, ?it/s]"
97]
98},
99"metadata": {},
100"output_type": "display_data"
101},
102{
103"name": "stderr",
104"output_type": "stream",
105"text": [
106"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fa7887f9eaa03ae.arrow\n",
107"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dc593149bbeafe80.arrow\n",
108"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-140ebe5b70e09817.arrow\n"
109]
110}
111],
112"source": [
"# Checkpoints whose name contains gpt/opt/bloom are padded on the left; all others (e.g. roberta) on the right.\n",
"padding_side = \"left\" if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")) else \"right\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
"# Some tokenizers have no pad token; fall back to EOS so batches can still be padded.\n",
"if getattr(tokenizer, \"pad_token_id\", None) is None:\n",
"    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
"\n",
"# Named `raw_datasets` (not `datasets`) so it does not shadow the `datasets` library name.\n",
"raw_datasets = load_dataset(\"glue\", task)\n",
"metric = evaluate.load(\"glue\", task)\n",
"\n",
"\n",
"def tokenize_function(examples):\n",
"    \"\"\"Tokenize a batch of sentence1/sentence2 pairs; padding is deferred to `collate_fn`.\"\"\"\n",
"    # max_length=None => use the model max length (it's actually the default)\n",
"    outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
"    return outputs\n",
"\n",
"\n",
"tokenized_datasets = raw_datasets.map(\n",
"    tokenize_function,\n",
"    batched=True,\n",
"    remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
")\n",
"\n",
"# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
"# transformers library\n",
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
"\n",
"\n",
"def collate_fn(examples):\n",
"    \"\"\"Pad a list of tokenized examples to the longest one in the batch and return PyTorch tensors.\"\"\"\n",
"    return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
"\n",
"\n",
"# Instantiate dataloaders.\n",
"train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
"eval_dataloader = DataLoader(\n",
"    tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
")"
152]
153},
154{
155"cell_type": "code",
156"execution_count": null,
157"id": "a3c15af0",
158"metadata": {},
159"outputs": [],
160"source": [
"# Load the pretrained backbone with a newly initialised classification head, then wrap it\n",
"# with the prompt-tuning adapter described by `peft_config`.\n",
"model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
"model = get_peft_model(model, peft_config)\n",
"# Summarise trainable vs. total parameters instead of dumping the (very long) module repr.\n",
"model.print_trainable_parameters()"
165]
166},
167{
168"cell_type": "code",
169"execution_count": 6,
170"id": "6d3c5edb",
171"metadata": {},
172"outputs": [],
173"source": [
"# Hand AdamW only the parameters get_peft_model left trainable; parameters with\n",
"# requires_grad=False never receive gradients, so the updates are unchanged.\n",
"trainable_params = [p for p in model.parameters() if p.requires_grad]\n",
"optimizer = AdamW(params=trainable_params, lr=lr)\n",
"\n",
"# Instantiate scheduler: linear warm-up over the first 6% of all steps, then linear decay.\n",
"num_training_steps = len(train_dataloader) * num_epochs\n",
"lr_scheduler = get_linear_schedule_with_warmup(\n",
"    optimizer=optimizer,\n",
"    # The scheduler takes an integer step count; 0.06 * steps is a float, so round it up.\n",
"    num_warmup_steps=math.ceil(0.06 * num_training_steps),\n",
"    num_training_steps=num_training_steps,\n",
")"
182]
183},
184{
185"cell_type": "code",
186"execution_count": 7,
187"id": "4d279225",
188"metadata": {},
189"outputs": [
190{
191"name": "stderr",
192"output_type": "stream",
193"text": [
194" 0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
195"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [02:09<00:00, 1.13s/it]\n",
196"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00, 1.62it/s]\n"
197]
198},
199{
200"name": "stdout",
201"output_type": "stream",
202"text": [
203"epoch 0: {'accuracy': 0.678921568627451, 'f1': 0.7956318252730109}\n"
204]
205},
206{
207"name": "stderr",
208"output_type": "stream",
209"text": [
210"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:50<00:00, 1.04it/s]\n",
211"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.22it/s]\n"
212]
213},
214{
215"name": "stdout",
216"output_type": "stream",
217"text": [
218"epoch 1: {'accuracy': 0.696078431372549, 'f1': 0.8171091445427728}\n"
219]
220},
221{
222"name": "stderr",
223"output_type": "stream",
224"text": [
225"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00, 1.19it/s]\n",
226"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:06<00:00, 2.00it/s]\n"
227]
228},
229{
230"name": "stdout",
231"output_type": "stream",
232"text": [
233"epoch 2: {'accuracy': 0.6985294117647058, 'f1': 0.8161434977578476}\n"
234]
235},
236{
237"name": "stderr",
238"output_type": "stream",
239"text": [
240"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:37<00:00, 1.18it/s]\n",
241"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:06<00:00, 2.09it/s]\n"
242]
243},
244{
245"name": "stdout",
246"output_type": "stream",
247"text": [
248"epoch 3: {'accuracy': 0.7058823529411765, 'f1': 0.7979797979797979}\n"
249]
250},
251{
252"name": "stderr",
253"output_type": "stream",
254"text": [
255"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [02:03<00:00, 1.07s/it]\n",
256"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:07<00:00, 1.71it/s]\n"
257]
258},
259{
260"name": "stdout",
261"output_type": "stream",
262"text": [
263"epoch 4: {'accuracy': 0.696078431372549, 'f1': 0.8132530120481929}\n"
264]
265},
266{
267"name": "stderr",
268"output_type": "stream",
269"text": [
270"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:53<00:00, 1.01it/s]\n",
271"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.19it/s]\n"
272]
273},
274{
275"name": "stdout",
276"output_type": "stream",
277"text": [
278"epoch 5: {'accuracy': 0.7107843137254902, 'f1': 0.8121019108280254}\n"
279]
280},
281{
282"name": "stderr",
283"output_type": "stream",
284"text": [
285"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:35<00:00, 1.20it/s]\n",
286"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.20it/s]\n"
287]
288},
289{
290"name": "stdout",
291"output_type": "stream",
292"text": [
293"epoch 6: {'accuracy': 0.6911764705882353, 'f1': 0.7692307692307693}\n"
294]
295},
296{
297"name": "stderr",
298"output_type": "stream",
299"text": [
300"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00, 1.20it/s]\n",
301"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.18it/s]\n"
302]
303},
304{
305"name": "stdout",
306"output_type": "stream",
307"text": [
308"epoch 7: {'accuracy': 0.7156862745098039, 'f1': 0.8209876543209876}\n"
309]
310},
311{
312"name": "stderr",
313"output_type": "stream",
314"text": [
315"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:35<00:00, 1.20it/s]\n",
316"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.22it/s]\n"
317]
318},
319{
320"name": "stdout",
321"output_type": "stream",
322"text": [
323"epoch 8: {'accuracy': 0.7205882352941176, 'f1': 0.8240740740740742}\n"
324]
325},
326{
327"name": "stderr",
328"output_type": "stream",
329"text": [
330"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00, 1.19it/s]\n",
331"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.21it/s]\n"
332]
333},
334{
335"name": "stdout",
336"output_type": "stream",
337"text": [
338"epoch 9: {'accuracy': 0.7205882352941176, 'f1': 0.8229813664596273}\n"
339]
340},
341{
342"name": "stderr",
343"output_type": "stream",
344"text": [
345"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00, 1.20it/s]\n",
346"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.35it/s]\n"
347]
348},
349{
350"name": "stdout",
351"output_type": "stream",
352"text": [
353"epoch 10: {'accuracy': 0.7156862745098039, 'f1': 0.8164556962025317}\n"
354]
355},
356{
357"name": "stderr",
358"output_type": "stream",
359"text": [
360"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:35<00:00, 1.20it/s]\n",
361"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.22it/s]\n"
362]
363},
364{
365"name": "stdout",
366"output_type": "stream",
367"text": [
368"epoch 11: {'accuracy': 0.7058823529411765, 'f1': 0.8113207547169811}\n"
369]
370},
371{
372"name": "stderr",
373"output_type": "stream",
374"text": [
375"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:32<00:00, 1.24it/s]\n",
376"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.48it/s]\n"
377]
378},
379{
380"name": "stdout",
381"output_type": "stream",
382"text": [
383"epoch 12: {'accuracy': 0.7009803921568627, 'f1': 0.7946127946127945}\n"
384]
385},
386{
387"name": "stderr",
388"output_type": "stream",
389"text": [
390"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:32<00:00, 1.24it/s]\n",
391"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.38it/s]\n"
392]
393},
394{
395"name": "stdout",
396"output_type": "stream",
397"text": [
398"epoch 13: {'accuracy': 0.7230392156862745, 'f1': 0.8186195826645265}\n"
399]
400},
401{
402"name": "stderr",
403"output_type": "stream",
404"text": [
405"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:29<00:00, 1.29it/s]\n",
406"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.31it/s]\n"
407]
408},
409{
410"name": "stdout",
411"output_type": "stream",
412"text": [
413"epoch 14: {'accuracy': 0.7058823529411765, 'f1': 0.8130841121495327}\n"
414]
415},
416{
417"name": "stderr",
418"output_type": "stream",
419"text": [
420"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:30<00:00, 1.27it/s]\n",
421"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.39it/s]\n"
422]
423},
424{
425"name": "stdout",
426"output_type": "stream",
427"text": [
428"epoch 15: {'accuracy': 0.7181372549019608, 'f1': 0.8194662480376768}\n"
429]
430},
431{
432"name": "stderr",
433"output_type": "stream",
434"text": [
435"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:28<00:00, 1.29it/s]\n",
436"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.35it/s]\n"
437]
438},
439{
440"name": "stdout",
441"output_type": "stream",
442"text": [
443"epoch 16: {'accuracy': 0.7254901960784313, 'f1': 0.8181818181818181}\n"
444]
445},
446{
447"name": "stderr",
448"output_type": "stream",
449"text": [
450"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:30<00:00, 1.27it/s]\n",
451"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.30it/s]\n"
452]
453},
454{
455"name": "stdout",
456"output_type": "stream",
457"text": [
458"epoch 17: {'accuracy': 0.7205882352941176, 'f1': 0.820754716981132}\n"
459]
460},
461{
462"name": "stderr",
463"output_type": "stream",
464"text": [
465"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:30<00:00, 1.27it/s]\n",
466"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.36it/s]\n"
467]
468},
469{
470"name": "stdout",
471"output_type": "stream",
472"text": [
473"epoch 18: {'accuracy': 0.7254901960784313, 'f1': 0.821656050955414}\n"
474]
475},
476{
477"name": "stderr",
478"output_type": "stream",
479"text": [
480"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:28<00:00, 1.29it/s]\n",
481"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.43it/s]"
482]
483},
484{
485"name": "stdout",
486"output_type": "stream",
487"text": [
488"epoch 19: {'accuracy': 0.7303921568627451, 'f1': 0.8242811501597445}\n"
489]
490},
491{
492"name": "stderr",
493"output_type": "stream",
494"text": [
495"\n"
496]
497}
498],
499"source": [
"model.to(device)\n",
"for epoch in range(num_epochs):\n",
"    # --- training pass ---\n",
"    model.train()\n",
"    for batch in tqdm(train_dataloader):\n",
"        # Rebind explicitly instead of relying on BatchEncoding.to mutating the batch in place.\n",
"        batch = batch.to(device)\n",
"        outputs = model(**batch)\n",
"        loss = outputs.loss\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        lr_scheduler.step()\n",
"        optimizer.zero_grad()\n",
"\n",
"    # --- evaluation pass on the validation split ---\n",
"    model.eval()\n",
"    for batch in tqdm(eval_dataloader):\n",
"        batch = batch.to(device)\n",
"        with torch.no_grad():\n",
"            outputs = model(**batch)\n",
"        predictions = outputs.logits.argmax(dim=-1)\n",
"        metric.add_batch(\n",
"            predictions=predictions,\n",
"            references=batch[\"labels\"],\n",
"        )\n",
"\n",
"    # Aggregate the batches added above into accuracy/F1 for this epoch.\n",
"    eval_metric = metric.compute()\n",
"    print(f\"epoch {epoch}:\", eval_metric)"
526]
527},
528{
529"cell_type": "markdown",
530"id": "e1ff3f44",
531"metadata": {},
532"source": [
533"## Share adapters on the 🤗 Hub"
534]
535},
536{
537"cell_type": "code",
538"execution_count": 8,
539"id": "0bf79cb5",
540"metadata": {},
541"outputs": [
542{
543"data": {
544"text/plain": [
545"CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-prompt-tuning/commit/893a909d8499aa8778d58c781d43c3a8d9360de8', commit_message='Upload model', commit_description='', oid='893a909d8499aa8778d58c781d43c3a8d9360de8', pr_url=None, pr_revision=None, pr_num=None)"
546]
547},
548"execution_count": 8,
549"metadata": {},
550"output_type": "execute_result"
551}
552],
553"source": [
"# Hub repository that receives the trained PEFT model; the cell below loads it back from here.\n",
"hub_repo_id = \"smangrul/roberta-large-peft-prompt-tuning\"\n",
"# use_auth_token=True authenticates with the locally stored Hugging Face token.\n",
"model.push_to_hub(hub_repo_id, use_auth_token=True)"
555]
556},
557{
558"cell_type": "markdown",
559"id": "73870ad7",
560"metadata": {},
561"source": [
562"## Load adapters from the Hub\n",
563"\n",
"You can also load the trained adapter directly from the Hub and re-evaluate it on the validation set with the cell below:"
565]
566},
567{
568"cell_type": "code",
569"execution_count": 9,
570"id": "0654a552",
571"metadata": {},
572"outputs": [
573{
574"data": {
575"application/vnd.jupyter.widget-view+json": {
576"model_id": "24581bb98582444ca6114b9fa267847f",
577"version_major": 2,
578"version_minor": 0
579},
580"text/plain": [
581"Downloading: 0%| | 0.00/368 [00:00<?, ?B/s]"
582]
583},
584"metadata": {},
585"output_type": "display_data"
586},
587{
588"name": "stderr",
589"output_type": "stream",
590"text": [
591"Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']\n",
592"- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
593"- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
594"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']\n",
595"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
596]
597},
598{
599"data": {
600"application/vnd.jupyter.widget-view+json": {
601"model_id": "f1584da4d1c54cc3873a515182674980",
602"version_major": 2,
603"version_minor": 0
604},
605"text/plain": [
606"Downloading: 0%| | 0.00/4.25M [00:00<?, ?B/s]"
607]
608},
609"metadata": {},
610"output_type": "display_data"
611},
612{
613"name": "stderr",
614"output_type": "stream",
615"text": [
616" 0%| | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
617"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.58it/s]"
618]
619},
620{
621"name": "stdout",
622"output_type": "stream",
623"text": [
624"{'accuracy': 0.7303921568627451, 'f1': 0.8242811501597445}\n"
625]
626},
627{
628"name": "stderr",
629"output_type": "stream",
630"text": [
631"\n"
632]
633}
634],
635"source": [
"peft_model_id = \"smangrul/roberta-large-peft-prompt-tuning\"\n",
"# The adapter's config records which base checkpoint it was trained on.\n",
"config = PeftConfig.from_pretrained(peft_model_id)\n",
"inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
"# Rebinds the global `tokenizer`, which `collate_fn` looks up when padding eval batches.\n",
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
"\n",
"# Load the prompt-tuning adapter from the Hub on top of the base model.\n",
"inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
"\n",
"inference_model.to(device)\n",
"inference_model.eval()\n",
"for batch in tqdm(eval_dataloader):\n",
"    batch = batch.to(device)\n",
"    with torch.no_grad():\n",
"        outputs = inference_model(**batch)\n",
"    predictions = outputs.logits.argmax(dim=-1)\n",
"    metric.add_batch(\n",
"        predictions=predictions,\n",
"        references=batch[\"labels\"],\n",
"    )\n",
"\n",
"eval_metric = metric.compute()\n",
"print(eval_metric)"
663]
664}
665],
666"metadata": {
667"kernelspec": {
668"display_name": "Python 3 (ipykernel)",
669"language": "python",
670"name": "python3"
671},
672"language_info": {
673"codemirror_mode": {
674"name": "ipython",
675"version": 3
676},
677"file_extension": ".py",
678"mimetype": "text/x-python",
679"name": "python",
680"nbconvert_exporter": "python",
681"pygments_lexer": "ipython3",
682"version": "3.10.4"
683},
684"vscode": {
685"interpreter": {
686"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
687}
688}
689},
690"nbformat": 4,
691"nbformat_minor": 5
692}
693