google-research

Форк
0
/
finetuning.py 
903 строки · 32.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
#!/usr/bin/env python
17
# coding=utf-8
18
"""Fine-tuning the library models for sequence classification."""
19

20
import argparse
21
import dataclasses
22
import json
23
import logging
24
import math
25
import os
26
import random
27
import shutil
28
from typing import Any, Dict, List, Optional, Tuple
29

30
import accelerate
31
import datasets
32
from datasets import load_dataset
33
from datasets import load_metric
34
import numpy as np
35
import pandas as pd
36
import torch
37
from torch.utils.data import DataLoader
38
from tqdm.auto import tqdm
39
from transformers import AdamW
40
from transformers import AutoConfig
41
from transformers import AutoModelForSequenceClassification
42
from transformers import AutoTokenizer
43
from transformers import DataCollatorWithPadding
44
from transformers import default_data_collator
45
from transformers import get_scheduler
46
from transformers import set_seed
47
from transformers.configuration_utils import PretrainedConfig
48
from transformers.file_utils import ExplicitEnum
49
from transformers.modeling_utils import PreTrainedModel
50
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
51
from transformers.trainer_utils import IntervalStrategy
52

53
logger = logging.getLogger(__name__)
54

55

56
class Split(ExplicitEnum):
  """Names of the dataset splits handled by this script.

  The string values are used as keys of the `data_files` and `num_examples`
  mappings built in `finetune`.
  """
  TRAIN = 'train'
  EVAL = 'eval'
  TEST = 'test'
  INFER = 'infer'
61

62

63
@dataclasses.dataclass
class FTModelArguments:
  """Arguments pertaining to which config/tokenizer/model we are going to fine-tune from."""

  # Required: there is no default, so it must be given at construction time.
  model_name_or_path: str = dataclasses.field(
      metadata={
          'help':
              'Path to pretrained model or model identifier from huggingface.co/models.'
      })
  # Forwarded as `use_fast` to `AutoTokenizer.from_pretrained`.
  use_fast_tokenizer: Optional[bool] = dataclasses.field(
      default=True,
      metadata={
          'help':
              'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'
      },
  )
  # Forwarded as `cache_dir` to the `from_pretrained` calls.
  cache_dir: Optional[str] = dataclasses.field(
      default=None,
      metadata={
          'help':
              'Where do you want to store the pretrained models downloaded from huggingface.co.'
      },
  )
85

86

87
@dataclasses.dataclass
class FTDataArguments:
  """Arguments pertaining to what data we are going to input our model for training and evaluation."""

  # Defaults to None, hence Optional; `finetune` asserts that it is set.
  train_file: Optional[str] = dataclasses.field(
      default=None,
      metadata={'help': 'A csv or a json file containing the training data.'})
  eval_file: Optional[str] = dataclasses.field(
      default=None,
      metadata={'help': 'A csv or a json file containing the validation data.'})
  test_file: Optional[str] = dataclasses.field(
      default=None,
      metadata={'help': 'A csv or a json file containing the test data.'})
  infer_file: Optional[str] = dataclasses.field(
      default=None,
      metadata={
          'help': 'A csv or a json file containing the data to predict on.'
      })
  task_name: Optional[str] = dataclasses.field(
      default=None,
      metadata={'help': 'The name of the task to train on.'},
  )
  label_list: Optional[List[str]] = dataclasses.field(
      default=None, metadata={'help': 'The list of labels for the task.'})

  max_length: Optional[int] = dataclasses.field(
      default=128,
      metadata={
          'help':
              'The maximum total input sequence length after tokenization. Sequences longer '
              'than this will be truncated, sequences shorter will be padded.'
      },
  )
  # The help text refers to `max_length`, the name of the field defined above.
  pad_to_max_length: Optional[bool] = dataclasses.field(
      default=False,
      metadata={
          'help':
              'Whether to pad all samples to `max_length`. '
              'If False, will pad the samples dynamically when batching to the maximum length in the batch.'
      },
  )
127

128

129
@dataclasses.dataclass
class FTTrainingArguments:
  """Training arguments pertaining to the training loop itself."""

  # Required: there is no default, so it must be given at construction time.
  output_dir: str = dataclasses.field(
      metadata={
          'help':
              'The output directory where the model predictions and checkpoints will be written.'
      })
  do_train: Optional[bool] = dataclasses.field(
      default=False,
      metadata={'help': 'Whether to run training or not.'},
  )
  do_eval: Optional[bool] = dataclasses.field(
      default=False,
      metadata={
          'help': 'Whether to run evaluation on the validation set or not.'
      },
  )
  do_predict: Optional[bool] = dataclasses.field(
      default=False,
      metadata={
          'help': 'Whether to run inference on the inference set or not.'
      },
  )
  seed: Optional[int] = dataclasses.field(
      default=42,
      metadata={
          'help': 'Random seed that will be set at the beginning of training.'
      },
  )
  per_device_train_batch_size: Optional[int] = dataclasses.field(
      default=8,
      metadata={'help': 'The batch size per GPU/TPU core/CPU for training.'},
  )
  per_device_eval_batch_size: Optional[int] = dataclasses.field(
      default=8,
      metadata={'help': 'The batch size per GPU/TPU core/CPU for evaluation.'},
  )
  weight_decay: Optional[float] = dataclasses.field(
      default=0.0,
      metadata={
          'help':
              'The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`] optimizer.'
      },
  )
  learning_rate: Optional[float] = dataclasses.field(
      default=5e-5,
      metadata={'help': 'The initial learning rate for [`AdamW`] optimizer.'},
  )
  gradient_accumulation_steps: Optional[int] = dataclasses.field(
      default=1,
      metadata={
          'help':
              'Number of updates steps to accumulate the gradients for, before performing a backward/update pass.'
      },
  )
  max_steps: Optional[int] = dataclasses.field(
      default=-1,
      metadata={
          'help':
              'If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.'
      },
  )
  lr_scheduler_type: Optional[str] = dataclasses.field(
      default='linear', metadata={'help': 'The scheduler type to use.'})
  warmup_steps: Optional[int] = dataclasses.field(
      default=1,
      metadata={
          'help':
              'Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.'
      },
  )
  # Compared against the `IntervalStrategy` values ('no', 'steps', 'epoch').
  evaluation_strategy: Optional[str] = dataclasses.field(
      default='no',
      metadata={
          'help':
              'The evaluation strategy to adopt during training. Possible values are: ["no", "steps", "epoch"]'
      })
  eval_steps: Optional[int] = dataclasses.field(
      default=1,
      metadata={
          'help':
              'Number of update steps between two evaluations if `evaluation_strategy="steps"`.'
      },
  )
  eval_metric: Optional[str] = dataclasses.field(
      default='accuracy',
      metadata={'help': 'The evaluation metric used for the task.'})
  keep_checkpoint_max: Optional[int] = dataclasses.field(
      default=1,
      metadata={'help': 'The maximum number of best checkpoint files to keep.'},
  )
  early_stopping_patience: Optional[int] = dataclasses.field(
      default=10,
      metadata={
          'help':
              'Number of evaluation calls with no improvement after which training will be stopped.'
      },
  )
  early_stopping_threshold: Optional[float] = dataclasses.field(
      default=0.0,
      metadata={
          'help':
              'How much the specified evaluation metric must improve to satisfy early stopping conditions.'
      },
  )
  # `train` and `finetune` read `args.num_train_epochs`; declaring it here
  # gives it a default and lets callers override it through `finetune`'s
  # kwargs (which only accepts keys that already exist on `args`). It is the
  # last field so that positional construction keeps working. It must be an
  # int because `train` passes it to `range`.
  num_train_epochs: Optional[int] = dataclasses.field(
      default=3,
      metadata={
          'help':
              'Total number of training epochs to perform when `max_steps` is not set to a positive number.'
      },
  )
236

237

238
def _save_model_checkpoint(accelerator, model, tokenizer,
                           checkpoint_output_dir):
  """Save the unwrapped model and the tokenizer into `checkpoint_output_dir`."""
  # Only the main process creates the directory; everybody waits for it to
  # exist before saving.
  if accelerator.is_main_process:
    os.makedirs(checkpoint_output_dir, exist_ok=True)
  accelerator.wait_for_everyone()
  unwrapped_model = accelerator.unwrap_model(model)
  unwrapped_model.save_pretrained(
      checkpoint_output_dir, save_function=accelerator.save)
  if accelerator.is_main_process:
    tokenizer.save_pretrained(checkpoint_output_dir)
    logger.info('Saving model checkpoint to %s', checkpoint_output_dir)


def train(args,
          accelerator,
          model,
          tokenizer,
          train_dataloader,
          optimizer,
          lr_scheduler,
          eval_dataloader=None):
  """Train a model on the given training data.

  When `eval_dataloader` is given and `args.evaluation_strategy` is 'steps' or
  'epoch', the model is evaluated periodically, at most
  `args.keep_checkpoint_max` of the highest-scoring checkpoints are kept on
  disk, and training stops early after `args.early_stopping_patience`
  evaluations without an improvement larger than
  `args.early_stopping_threshold`. At the end, the best checkpoint (or, when
  no evaluation took place, the final model) is stored under
  `<args.output_dir>/best-checkpoint`.

  Args:
    args: Namespace holding the training configuration (batch sizes,
      `max_steps`, `num_train_epochs`, `num_examples`, evaluation and early
      stopping settings, `output_dir`, ...).
    accelerator: The accelerator the other objects were prepared with.
    model: The model to train; it is called as `model(**batch)` and its
      output must expose `.loss`.
    tokenizer: Tokenizer saved next to every model checkpoint.
    train_dataloader: Iterable of training batches supporting `len()`.
    optimizer: Optimizer stepped once per accumulated batch.
    lr_scheduler: Learning rate scheduler stepped together with the optimizer.
    eval_dataloader: Optional evaluation batches used for model selection.

  Returns:
    A tuple `(completed_steps, average_loss)` where `average_loss` is the
    accumulated training loss divided by the number of optimizer updates
    (0.0 when no update was performed).
  """

  total_batch_size = (
      args.per_device_train_batch_size * accelerator.num_processes *
      args.gradient_accumulation_steps)

  logger.info('***** Running training *****')
  logger.info('  Num examples = %d', args.num_examples[Split.TRAIN.value])
  logger.info('  Instantaneous batch size per device = %d',
              args.per_device_train_batch_size)
  logger.info(
      '  Total train batch size (w. parallel, distributed & accumulation) = %d',
      total_batch_size)
  logger.info('  Gradient Accumulation steps = %d',
              args.gradient_accumulation_steps)
  logger.info('  Total optimization steps = %d', args.max_steps)

  # Only show the progress bar once on each machine.
  progress_bar = tqdm(
      range(args.max_steps), disable=not accelerator.is_local_main_process)

  # (eval_result, checkpoint_name) pairs of the checkpoints currently kept,
  # sorted from worst to best. The ranking treats a higher metric value as
  # better: the lowest-scoring entry is the one that gets removed.
  kept_checkpoints = []
  best_checkpoint = None
  best_eval_result = None
  early_stopping_patience_counter = 0
  should_training_stop = False
  completed_steps = 0
  train_loss = 0.0
  model.zero_grad()

  def evaluate_and_checkpoint(new_checkpoint, interval_name, interval_index):
    """Evaluate the current model and update best/kept checkpoints."""
    nonlocal best_checkpoint, best_eval_result
    nonlocal early_stopping_patience_counter, should_training_stop

    accelerator.wait_for_everyone()
    new_eval_result = evaluate(args, accelerator, eval_dataloader, 'eval',
                               model, new_checkpoint)[args.eval_metric]
    logger.info('Evaluation result at %s %d: %s = %f', interval_name,
                interval_index, args.eval_metric, new_eval_result)

    if best_eval_result is None:
      best_checkpoint, best_eval_result = new_checkpoint, new_eval_result
    else:
      # Patience is only reset by an improvement larger than the threshold...
      if new_eval_result - best_eval_result > args.early_stopping_threshold:
        early_stopping_patience_counter = 0
      else:
        early_stopping_patience_counter += 1
      # ...but the best checkpoint always follows the highest score (newest
      # wins ties), so it can never be the entry removed from the ranking
      # below.
      if new_eval_result >= best_eval_result:
        best_checkpoint, best_eval_result = new_checkpoint, new_eval_result
      if early_stopping_patience_counter >= args.early_stopping_patience:
        should_training_stop = True

    # `list.sort` is stable: on ties the older checkpoint stays in front and
    # is therefore removed before the newer one.
    kept_checkpoints.append((new_eval_result, new_checkpoint))
    kept_checkpoints.sort(key=lambda entry: entry[0])

    if len(kept_checkpoints) > args.keep_checkpoint_max:
      # Delete the current worst checkpoint. If it is the new one, nothing
      # has been written to disk for it yet.
      _, checkpoint_to_remove = kept_checkpoints.pop(0)
      if checkpoint_to_remove != new_checkpoint:
        if accelerator.is_main_process:
          shutil.rmtree(
              os.path.join(args.output_dir, checkpoint_to_remove),
              ignore_errors=True)
        accelerator.wait_for_everyone()

    if any(name == new_checkpoint for _, name in kept_checkpoints):
      _save_model_checkpoint(accelerator, model, tokenizer,
                             os.path.join(args.output_dir, new_checkpoint))

  evaluate_on_steps = (
      eval_dataloader is not None and
      args.evaluation_strategy == IntervalStrategy.STEPS.value)
  evaluate_on_epochs = (
      eval_dataloader is not None and
      args.evaluation_strategy == IntervalStrategy.EPOCH.value)

  for epoch in range(1, args.num_train_epochs + 1):
    model.train()
    last_step = len(train_dataloader) - 1
    for step, batch in enumerate(train_dataloader):
      outputs = model(**batch)
      loss = outputs.loss
      loss = loss / args.gradient_accumulation_steps
      accelerator.backward(loss)
      train_loss += loss.item()

      # Update once `gradient_accumulation_steps` micro-batches have been
      # accumulated (hence `step + 1`), or at the end of the epoch. This
      # yields ceil(len(train_dataloader) / gradient_accumulation_steps)
      # updates per epoch.
      if (step + 1) % args.gradient_accumulation_steps == 0 or step == last_step:
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        completed_steps += 1

        # Evaluate during training
        if (evaluate_on_steps and args.eval_steps > 0 and
            completed_steps % args.eval_steps == 0):
          evaluate_and_checkpoint(
              f'checkpoint-{IntervalStrategy.STEPS.value}-{completed_steps}',
              'step', completed_steps)
          # `evaluate` switches the model to eval mode; switch back before
          # the remaining training steps of this epoch.
          model.train()

      if completed_steps >= args.max_steps:
        break

      if should_training_stop:
        break

    # Evaluate during training
    if evaluate_on_epochs:
      evaluate_and_checkpoint(
          f'checkpoint-{IntervalStrategy.EPOCH.value}-{epoch}', 'epoch', epoch)

    if completed_steps >= args.max_steps:
      break

    if should_training_stop:
      break

  final_output_dir = os.path.join(args.output_dir, 'best-checkpoint')
  if best_checkpoint is not None:
    # Save the best checkpoint
    logger.info('Best checkpoint: %s', best_checkpoint)
    logger.info('Best evaluation result: %s = %f', args.eval_metric,
                best_eval_result)
    best_checkpoint_output_dir = os.path.join(args.output_dir, best_checkpoint)
    if accelerator.is_main_process:
      if os.path.isdir(best_checkpoint_output_dir):
        # `shutil.move` would move the checkpoint *inside* an already
        # existing destination directory, so drop a stale one first.
        shutil.rmtree(final_output_dir, ignore_errors=True)
      shutil.move(best_checkpoint_output_dir, final_output_dir)
      shutil.rmtree(best_checkpoint_output_dir, ignore_errors=True)
    accelerator.wait_for_everyone()
  else:
    # Assume that the last checkpoint is the best checkpoint and save it
    _save_model_checkpoint(accelerator, model, tokenizer, final_output_dir)

  average_loss = train_loss / completed_steps if completed_steps else 0.0
  return completed_steps, average_loss
451

452

453
def evaluate(args,
             accelerator,
             dataloader,
             eval_set,
             model,
             checkpoint,
             has_labels=True,
             write_to_file=True):
  """Evaluate a model checkpoint on the given evaluation data.

  Args:
    args: Namespace providing `num_examples` (mapping from split name to its
      number of examples), `eval_metric`, `is_regression` and `output_dir`.
    accelerator: The accelerator `model` and `dataloader` were prepared with;
      used to gather tensors and to synchronize processes.
    dataloader: Iterable of batches; each batch is passed as `model(**batch)`
      and must contain a 'labels' entry when `has_labels` is True.
    eval_set: Name of the evaluated split; key into `args.num_examples` and
      prefix of the written file names.
    model: The model to evaluate; its output must expose `.logits` (and
      `.loss` when labels are available).
    checkpoint: Checkpoint name used as suffix of the written file names.
    has_labels: Whether the batches contain labels. If False, no metric is
      computed and the returned dictionary is empty.
    write_to_file: Whether the main process writes
      `<eval_set>_results_<checkpoint>.json` (only with labels) and
      `<eval_set>_output_<checkpoint>.csv` into `args.output_dir`.

  Returns:
    A dictionary with the metric values, 'completed_steps' and
    'avg_eval_loss' when `has_labels` is True, an empty dictionary otherwise.
  """

  num_examples = args.num_examples[eval_set]
  # Get the metric function
  eval_metric = load_metric(args.eval_metric) if has_labels else None

  completed_steps = 0
  samples_seen = 0
  eval_loss = 0.0
  prediction_chunks = []
  probability_chunks = []

  model.eval()
  for batch in dataloader:
    with torch.no_grad():
      outputs = model(**batch)

    # Without labels (inference) the model does not return a loss.
    loss = getattr(outputs, 'loss', None)
    if loss is not None:
      eval_loss += loss.item()

    logits = outputs.logits
    if args.is_regression:
      # Only drop the trailing dimension so that a batch holding a single
      # example keeps its batch dimension.
      predictions = logits.squeeze(-1)
    else:
      predictions = logits.argmax(dim=-1)

    # Gathered batches can contain extra samples at the end of the last batch
    # (added to make it evenly divisible across processes); keep only the
    # samples that belong to the dataset.
    remaining = max(num_examples - samples_seen, 0)
    predictions = accelerator.gather(predictions)[:remaining]
    prediction_chunks.append(predictions.detach().cpu().numpy())

    if not args.is_regression:
      probabilities = logits.softmax(dim=-1).max(dim=-1).values
      probabilities = accelerator.gather(probabilities)[:remaining]
      probability_chunks.append(probabilities.detach().cpu().numpy())

    if has_labels:
      references = accelerator.gather(batch['labels'])[:remaining]
      if predictions.shape[0] > 0:
        eval_metric.add_batch(
            predictions=predictions,
            references=references,
        )

    samples_seen += predictions.shape[0]
    completed_steps += 1

  all_predictions = (
      np.concatenate(prediction_chunks, axis=0)
      if prediction_chunks else np.array([]))
  all_probabilities = (
      np.concatenate(probability_chunks, axis=0)
      if probability_chunks else np.array([]))

  eval_results = {}
  if has_labels:
    eval_results.update(eval_metric.compute())
    eval_results['completed_steps'] = completed_steps
    eval_results['avg_eval_loss'] = (
        eval_loss / completed_steps if completed_steps else 0.0)

  if write_to_file:
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
      if has_labels:
        results_file = os.path.join(args.output_dir,
                                    f'{eval_set}_results_{checkpoint}.json')
        with open(results_file, 'w') as f:
          json.dump(eval_results, f, indent=4, sort_keys=True)

      output_file = os.path.join(args.output_dir,
                                 f'{eval_set}_output_{checkpoint}.csv')
      if not args.is_regression:
        assert len(all_predictions) == len(all_probabilities)
        df = pd.DataFrame(
            list(zip(all_predictions, all_probabilities)),
            columns=['prediction', 'probability'])
      else:
        df = pd.DataFrame(all_predictions, columns=['prediction'])
      df = df.head(num_examples)
      df.to_csv(output_file, header=True, index=False)
  return eval_results
545

546

547
def load_from_pretrained(args, pretrained_model_name_or_path):
  """Load the pretrained config, tokenizer and model.

  Args:
    args: Namespace providing `cache_dir` and `use_fast_tokenizer`, and
      optionally `num_labels` and `task_name`.
    pretrained_model_name_or_path: Path or model identifier to load the
      config, the tokenizer and the model from.

  Returns:
    A tuple `(config, tokenizer, model)`.
  """

  # In distributed training, the .from_pretrained methods guarantee that only
  # one local process can concurrently perform this procedure.

  # Only forward the optional settings that are actually available:
  # `task_name` defaults to None (which has no `.lower()`), and `num_labels`
  # may not have been set on `args`.
  config_kwargs = {'cache_dir': args.cache_dir}
  if getattr(args, 'num_labels', None) is not None:
    config_kwargs['num_labels'] = args.num_labels
  task_name = getattr(args, 'task_name', None)
  if task_name:
    config_kwargs['finetuning_task'] = task_name.lower()

  config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
                                      **config_kwargs)
  tokenizer = AutoTokenizer.from_pretrained(
      pretrained_model_name_or_path,
      use_fast=args.use_fast_tokenizer,
      cache_dir=args.cache_dir)
  model = AutoModelForSequenceClassification.from_pretrained(
      pretrained_model_name_or_path,
      # Decide on the path that is actually being loaded, which is not
      # necessarily `args.model_name_or_path`.
      from_tf='.ckpt' in str(pretrained_model_name_or_path),
      config=config,
      ignore_mismatched_sizes=True,
      cache_dir=args.cache_dir,
  )
  return config, tokenizer, model
573

574

575
def finetune(accelerator,
576
             model_name_or_path, train_file, output_dir,
577
             **kwargs):
578
  """Fine-tuning a pre-trained model on a downstream task.
579

580
  Args:
581
    accelerator: An instance of an accelerator for distributed training (on
582
      multi-GPU, TPU) or mixed precision training.
583
    model_name_or_path: Path to pretrained model or model identifier from
584
      huggingface.co/models.
585
    train_file: A csv or a json file containing the training data.
586
    output_dir: The output directory where the model predictions and checkpoints
587
      will be written.
588
    **kwargs: Dictionary of key/value pairs with which to update the
589
      configuration object after loading. The values in kwargs of any keys which
590
      are configuration attributes will be used to override the loaded values.
591
  """
592
  # Make one log on every process with the configuration for debugging.
593
  logging.basicConfig(
594
      format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
595
      datefmt='%m/%d/%Y %H:%M:%S',
596
      level=logging.INFO,
597
  )
598
  logger.info(accelerator.state)
599

600
  # Setup logging, we only want one process per machine to log things on the
601
  # screen. accelerator.is_local_main_process is only True for one process per
602
  # machine.
603
  logger.setLevel(
604
      logging.INFO if accelerator.is_local_main_process else logging.ERROR)
605

606
  model_args = FTModelArguments(model_name_or_path=model_name_or_path)
607
  data_args = FTDataArguments(train_file=train_file)
608
  training_args = FTTrainingArguments(output_dir=output_dir)
609
  args = argparse.Namespace()
610

611
  for arg_class in (model_args, data_args, training_args):
612
    for key, value in vars(arg_class).items():
613
      setattr(args, key, value)
614

615
  for key, value in kwargs.items():
616
    if hasattr(args, key):
617
      setattr(args, key, value)
618

619
  # Sanity checks
620
  data_files = {}
621
  args.data_file_extension = None
622

623
  # You need to provide the training data as we always run training
624
  args.do_train = True
625
  assert args.train_file is not None
626
  data_files[Split.TRAIN.value] = args.train_file
627

628
  if args.do_eval or args.evaluation_strategy != IntervalStrategy.NO.value:
629
    assert args.eval_file is not None
630
    data_files[Split.EVAL.value] = args.eval_file
631

632
  if args.do_eval and args.test_file is not None:
633
    data_files[Split.TEST.value] = args.test_file
634

635
  if args.do_predict:
636
    assert args.infer_file is not None
637
    data_files[Split.INFER.value] = args.infer_file
638

639
  for key in data_files:
640
    extension = data_files[key].split('.')[-1]
641
    assert extension in ['csv', 'json'
642
                        ], f'`{key}_file` should be a csv or a json file.'
643
    if args.data_file_extension is None:
644
      args.data_file_extension = extension
645
    else:
646
      assert (extension == args.data_file_extension
647
             ), f'`{key}_file` should be a {args.data_file_extension} file`.'
648

649
  assert (
650
      args.eval_metric in datasets.list_metrics()
651
  ), f'{args.eval_metric} not in the list of supported metrics {datasets.list_metrics()}.'
652

653
  # Handle the output directory creation
654
  if accelerator.is_main_process:
655
    if args.output_dir is not None:
656
      os.makedirs(args.output_dir, exist_ok=True)
657
  accelerator.wait_for_everyone()
658

659
  # If passed along, set the training seed now.
660
  if args.seed is not None:
661
    set_seed(args.seed)
662

663
  # You need to provide your CSV/JSON data files.
664
  #
665
  # For CSV/JSON files, this script will use as labels the column called 'label'
666
  # and as pair of sentences the sentences in columns called 'sentence1' and
667
  # 'sentence2' if these columns exist or the first two columns not named
668
  # 'label' if at least two columns are provided.
669
  #
670
  # If the CSVs/JSONs contain only one non-label column, the script does single
671
  # sentence classification on this single column.
672
  #
673
  # In distributed training, the load_dataset function guarantees that only one
674
  # local process can download the dataset.
675

676
  # Loading the dataset from local csv or json files.
677
  raw_datasets = load_dataset(args.data_file_extension, data_files=data_files)
678

679
  # Labels
680
  is_regression = raw_datasets[Split.TRAIN.value].features['label'].dtype in [
681
      'float32', 'float64'
682
  ]
683
  args.is_regression = is_regression
684

685
  if args.is_regression:
686
    label_list = None
687
    num_labels = 1
688
  else:
689
    label_list = args.label_list
690
    assert label_list is not None
691
    label_list.sort()  # Let's sort it for determinism
692
    num_labels = len(label_list)
693
  args.num_labels = num_labels
694

695
  # Load pre-trained model
696
  config, tokenizer, model = load_from_pretrained(args, args.model_name_or_path)
697

698
  # Preprocessing the datasets
699
  non_label_column_names = [
700
      name for name in raw_datasets[Split.TRAIN.value].column_names
701
      if name != 'label'
702
  ]
703
  if 'sentence1' in non_label_column_names and 'sentence2' in non_label_column_names:
704
    sentence1_key, sentence2_key = 'sentence1', 'sentence2'
705
  else:
706
    if len(non_label_column_names) >= 2:
707
      sentence1_key, sentence2_key = non_label_column_names[:2]
708
    else:
709
      sentence1_key, sentence2_key = non_label_column_names[0], None
710

711
  label_to_id = {v: i for i, v in enumerate(label_list)}
712
  config.label2id = label_to_id
713
  config.id2label = {id: label for label, id in config.label2id.items()}
714
  padding = 'max_length' if args.pad_to_max_length else False
715

716
  def preprocess_function(examples):
717
    # Tokenize the texts
718
    texts = ((examples[sentence1_key],) if sentence2_key is None else
719
             (examples[sentence1_key], examples[sentence2_key]))
720
    result = tokenizer(
721
        *texts, padding=padding, max_length=args.max_length, truncation=True)
722

723
    if 'label' in examples:
724
      if label_to_id is not None:
725
        # Map labels to IDs (not necessary for GLUE tasks)
726
        result['labels'] = [label_to_id[l] for l in examples['label']]
727
      else:
728
        # In all cases, rename the column to labels because the model will
729
        # expect that.
730
        result['labels'] = examples['label']
731
    return result
732

733
  with accelerator.main_process_first():
734
    processed_datasets = raw_datasets.map(
735
        preprocess_function,
736
        batched=True,
737
        remove_columns=raw_datasets[Split.TRAIN.value].column_names,
738
        desc='Running tokenizer on dataset',
739
    )
740

741
  num_examples = {}
742
  splits = [s.value for s in Split]
743
  for split in splits:
744
    if split in processed_datasets:
745
      num_examples[split] = len(processed_datasets[split])
746
  args.num_examples = num_examples
747

748
  train_dataset = processed_datasets[Split.TRAIN.value]
749
  eval_dataset = processed_datasets[
750
      Split.EVAL.value] if Split.EVAL.value in processed_datasets else None
751
  test_dataset = processed_datasets[
752
      Split.TEST.value] if Split.TEST.value in processed_datasets else None
753
  infer_dataset = processed_datasets[
754
      Split.INFER.value] if Split.INFER.value in processed_datasets else None
755

756
  # Log a few random samples from the training set:
757
  for index in random.sample(range(len(train_dataset)), 3):
758
    logger.info('Sample %d of the training set: %s.', index,
759
                train_dataset[index])
760

761
  # DataLoaders creation:
762
  if args.pad_to_max_length:
763
    # If padding was already done to max length, we use the default data
764
    # collator that will just convert everything to tensors.
765
    data_collator = default_data_collator
766
  else:
767
    # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by
768
    # padding to the maximum length of the samples passed). When using mixed
769
    # precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple of
770
    # 8s, which will enable the use of Tensor Cores on NVIDIA hardware with
771
    # compute capability >= 7.5 (Volta).
772
    data_collator = DataCollatorWithPadding(
773
        tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None))
774

775
  train_dataloader = DataLoader(
776
      train_dataset,
777
      batch_size=args.per_device_train_batch_size,
778
      shuffle=True,
779
      collate_fn=data_collator,
780
  )
781
  eval_dataloader, test_dataloader, infer_dataloader = None, None, None
782

783
  if eval_dataset is not None:
784
    eval_dataloader = DataLoader(
785
        eval_dataset,
786
        batch_size=args.per_device_eval_batch_size,
787
        collate_fn=data_collator)
788

789
  if test_dataset is not None:
790
    test_dataloader = DataLoader(
791
        test_dataset,
792
        batch_size=args.per_device_eval_batch_size,
793
        collate_fn=data_collator)
794

795
  if infer_dataset is not None:
796
    infer_dataloader = DataLoader(
797
        infer_dataset,
798
        batch_size=args.per_device_eval_batch_size,
799
        collate_fn=data_collator)
800

801
  # Optimizer
802
  # Split weights in two groups, one with weight decay and the other not.
803
  no_decay = ['bias', 'LayerNorm.weight']
804
  optimizer_grouped_parameters = [
805
      {
806
          'params': [
807
              p for n, p in model.named_parameters()
808
              if not any(nd in n for nd in no_decay)
809
          ],
810
          'weight_decay': args.weight_decay,
811
      },
812
      {
813
          'params': [
814
              p for n, p in model.named_parameters()
815
              if any(nd in n for nd in no_decay)
816
          ],
817
          'weight_decay': 0.0,
818
      },
819
  ]
820
  optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
821

822
  # Prepare everything with our `accelerator`.
823
  model, optimizer, train_dataloader, eval_dataloader, test_dataloader, infer_dataloader = accelerator.prepare(
824
      model, optimizer, train_dataloader, eval_dataloader, test_dataloader,
825
      infer_dataloader)
826

827
  # Note -> the training dataloader needs to be prepared before we grab its
828
  # length below (because its length will be shorter in multiprocess)
829

830
  # Scheduler and math around the number of training steps.
831
  num_update_steps_per_epoch = math.ceil(
832
      len(train_dataloader) / args.gradient_accumulation_steps)
833
  if args.max_steps == -1:
834
    args.max_steps = args.num_train_epochs * num_update_steps_per_epoch
835
  else:
836
    args.num_train_epochs = math.ceil(args.max_steps /
837
                                      num_update_steps_per_epoch)
838

839
  lr_scheduler = get_scheduler(
840
      name=args.lr_scheduler_type,
841
      optimizer=optimizer,
842
      num_warmup_steps=args.warmup_steps,
843
      num_training_steps=args.max_steps,
844
  )
845

846
  # Train
847
  completed_steps, avg_train_loss = train(args, accelerator, model, tokenizer,
848
                                          train_dataloader, optimizer,
849
                                          lr_scheduler, eval_dataloader)
850
  accelerator.wait_for_everyone()
851
  logger.info(
852
      'Training job completed: completed_steps = %d, avg_train_loss = %f',
853
      completed_steps, avg_train_loss)
854

855
  args.model_name_or_path = os.path.join(args.output_dir, 'best-checkpoint')
856
  logger.info('Loading the best checkpoint: %s', args.model_name_or_path)
857
  config, tokenizer, model = load_from_pretrained(args, args.model_name_or_path)
858
  model = accelerator.prepare(model)
859

860
  if args.do_eval:
861
    # Evaluate
862
    if eval_dataloader is not None:
863
      logger.info(
864
          '***** Running evaluation on the eval data using the best checkpoint *****'
865
      )
866
      eval_results = evaluate(args, accelerator, eval_dataloader,
867
                              Split.EVAL.value, model, 'best-checkpoint')
868
      avg_eval_loss = eval_results['avg_eval_loss']
869
      eval_metric = eval_results[args.eval_metric]
870
      logger.info('Evaluation job completed: avg_eval_loss = %f', avg_eval_loss)
871
      logger.info('Evaluation result for the best checkpoint: %s = %f',
872
                  args.eval_metric, eval_metric)
873

874
    if test_dataloader is not None:
875
      logger.info(
876
          '***** Running evaluation on the test data using the best checkpoint *****'
877
      )
878
      eval_results = evaluate(args, accelerator, test_dataloader,
879
                              Split.TEST.value, model, 'best-checkpoint')
880
      avg_eval_loss = eval_results['avg_eval_loss']
881
      eval_metric = eval_results[args.eval_metric]
882
      logger.info('Test job completed: avg_test_loss = %f', avg_eval_loss)
883
      logger.info('Test result for the best checkpoint: %s = %f',
884
                  args.eval_metric, eval_metric)
885

886
  if args.do_predict:
887
    # Predict
888
    if infer_dataloader is not None:
889
      logger.info('***** Running inference using the best checkpoint *****')
890
      evaluate(
891
          args,
892
          accelerator,
893
          infer_dataloader,
894
          Split.INFER.value,
895
          model,
896
          'best-checkpoint',
897
          has_labels=False)
898
      logger.info('Inference job completed.')
899

900
  # Release all references to the internal objects stored and call the garbage
901
  # collector. You should call this method between two trainings with different
902
  # models/optimizers.
903
  accelerator.free_memory()
904

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.