# (File scraped from the google-research repository; original length 903 lines / 32.8 KB.)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16#!/usr/bin/env python
17# coding=utf-8
18"""Fine-tuning the library models for sequence classification."""
19
20import argparse21import dataclasses22import json23import logging24import math25import os26import random27import shutil28from typing import Any, Dict, List, Optional, Tuple29
30import accelerate31import datasets32from datasets import load_dataset33from datasets import load_metric34import numpy as np35import pandas as pd36import torch37from torch.utils.data import DataLoader38from tqdm.auto import tqdm39from transformers import AdamW40from transformers import AutoConfig41from transformers import AutoModelForSequenceClassification42from transformers import AutoTokenizer43from transformers import DataCollatorWithPadding44from transformers import default_data_collator45from transformers import get_scheduler46from transformers import set_seed47from transformers.configuration_utils import PretrainedConfig48from transformers.file_utils import ExplicitEnum49from transformers.modeling_utils import PreTrainedModel50from transformers.tokenization_utils_base import PreTrainedTokenizerBase51from transformers.trainer_utils import IntervalStrategy52
# Module-level logger; handlers/format are configured in finetune() via
# logging.basicConfig, and the level is narrowed per-process there.
logger = logging.getLogger(__name__)
55
class Split(ExplicitEnum):
  """Names of the dataset splits handled by this script.

  The string values are used as keys into `data_files`, the processed
  datasets dict, and `args.num_examples`.
  """
  TRAIN = 'train'
  EVAL = 'eval'
  TEST = 'test'
  INFER = 'infer'
62
@dataclasses.dataclass
class FTModelArguments:
  """Arguments pertaining to which config/tokenizer/model we are going to fine-tune from."""
  # Required: local path or hub identifier of the base model.
  model_name_or_path: str = dataclasses.field(
      metadata={
          'help':
              'Path to pretrained model or model identifier from huggingface.co/models.'
      })
  # Fast (Rust-backed) tokenizers are the default; set False to force the
  # slow Python implementation.
  use_fast_tokenizer: Optional[bool] = dataclasses.field(
      default=True,
      metadata={
          'help':
              'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'
      },
  )
  # Passed straight through to the `from_pretrained` calls.
  cache_dir: Optional[str] = dataclasses.field(
      default=None,
      metadata={
          'help':
              'Where do you want to store the pretrained models downloaded from huggingface.co.'
      },
  )
86
@dataclasses.dataclass
class FTDataArguments:
  """Arguments pertaining to what data we are going to input our model for training and evaluation."""
  # Annotated Optional to match default=None; finetune() asserts it is set
  # before use, since training data is always required.
  train_file: Optional[str] = dataclasses.field(
      default=None,
      metadata={'help': 'A csv or a json file containing the training data.'})
  # Required when do_eval is set or evaluation_strategy != 'no'.
  eval_file: Optional[str] = dataclasses.field(
      default=None,
      metadata={'help': 'A csv or a json file containing the validation data.'})
  # Only used when do_eval is set.
  test_file: Optional[str] = dataclasses.field(
      default=None,
      metadata={'help': 'A csv or a json file containing the test data.'})
  # Required when do_predict is set.
  infer_file: Optional[str] = dataclasses.field(
      default=None,
      metadata={
          'help': 'A csv or a json file containing the data to predict on.'
      })
  # Lower-cased and forwarded to AutoConfig as `finetuning_task`.
  task_name: Optional[str] = dataclasses.field(
      default=None,
      metadata={'help': 'The name of the task to train on.'},
  )
  # Class names for classification; sorted in finetune() for determinism.
  # Must be provided for non-regression tasks.
  label_list: Optional[List[str]] = dataclasses.field(
      default=None, metadata={'help': 'The list of labels for the task.'})

  max_length: Optional[int] = dataclasses.field(
      default=128,
      metadata={
          'help':
              'The maximum total input sequence length after tokenization. Sequences longer '
              'than this will be truncated, sequences shorter will be padded.'
      },
  )
  # If False, DataCollatorWithPadding pads dynamically per batch instead.
  pad_to_max_length: Optional[bool] = dataclasses.field(
      default=False,
      metadata={
          'help':
              'Whether to pad all samples to `max_seq_length`. '
              'If False, will pad the samples dynamically when batching to the maximum length in the batch.'
      },
  )
128
@dataclasses.dataclass
class FTTrainingArguments:
  """Training arguments pertaining to the training loop itself.

  Every field becomes an attribute on the `args` namespace consumed by
  `finetune()` / `train()`, and can be overridden through `finetune`'s
  **kwargs.
  """
  # Required: where checkpoints, eval results and prediction CSVs go.
  output_dir: str = dataclasses.field(
      metadata={
          'help':
              'The output directory where the model predictions and checkpoints will be written.'
      })
  do_train: Optional[bool] = dataclasses.field(
      default=False,
      metadata={'help': 'Whether to run training or not.'},
  )
  do_eval: Optional[bool] = dataclasses.field(
      default=False,
      metadata={
          'help': 'Whether to run evaluation on the validation set or not.'
      },
  )
  do_predict: Optional[bool] = dataclasses.field(
      default=False,
      metadata={
          'help': 'Whether to run inference on the inference set or not.'
      },
  )
  seed: Optional[int] = dataclasses.field(
      default=42,
      metadata={
          'help': 'Random seed that will be set at the beginning of training.'
      },
  )
  per_device_train_batch_size: Optional[int] = dataclasses.field(
      default=8,
      metadata={'help': 'The batch size per GPU/TPU core/CPU for training.'},
  )
  per_device_eval_batch_size: Optional[int] = dataclasses.field(
      default=8,
      metadata={'help': 'The batch size per GPU/TPU core/CPU for evaluation.'},
  )
  weight_decay: Optional[float] = dataclasses.field(
      default=0.0,
      metadata={
          'help':
              'The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`] optimizer.'
      },
  )
  learning_rate: Optional[float] = dataclasses.field(
      default=5e-5,
      metadata={'help': 'The initial learning rate for [`AdamW`] optimizer.'},
  )
  gradient_accumulation_steps: Optional[int] = dataclasses.field(
      default=1,
      metadata={
          'help':
              'Number of updates steps to accumulate the gradients for, before performing a backward/update pass.'
      },
  )
  max_steps: Optional[int] = dataclasses.field(
      default=-1,
      metadata={
          'help':
              'If set to a positive number, the total number of training steps to perform. '
              'Overrides `num_train_epochs`.'
      },
  )
  lr_scheduler_type: Optional[str] = dataclasses.field(
      default='linear', metadata={'help': 'The scheduler type to use.'})
  warmup_steps: Optional[int] = dataclasses.field(
      default=1,
      metadata={
          'help':
              'Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.'
      },
  )
  # BUG FIX: the help text previously read '["no", "step", "epoch]' — the
  # middle value is actually 'steps' (IntervalStrategy.STEPS.value) and the
  # closing quote was missing.
  evaluation_strategy: Optional[str] = dataclasses.field(
      default='no',
      metadata={
          'help':
              'The evaluation strategy to adopt during training. Possible values are: ["no", "steps", "epoch"]'
      })
  eval_steps: Optional[int] = dataclasses.field(
      default=1,
      metadata={
          'help':
              'Number of update steps between two evaluations if `evaluation_strategy="steps"`.'
      },
  )
  eval_metric: Optional[str] = dataclasses.field(
      default='accuracy',
      metadata={'help': 'The evaluation metric used for the task.'})
  keep_checkpoint_max: Optional[int] = dataclasses.field(
      default=1,
      metadata={'help': 'The maximum number of best checkpoint files to keep.'},
  )
  early_stopping_patience: Optional[int] = dataclasses.field(
      default=10,
      metadata={
          'help':
              'Number of evaluation calls with no improvement after which training will be stopped.'
      },
  )
  early_stopping_threshold: Optional[float] = dataclasses.field(
      default=0.0,
      metadata={
          'help':
              'How much the specified evaluation metric must improve to satisfy early stopping conditions.'
      },
  )
  # BUG FIX: `finetune()`/`train()` read `args.num_train_epochs`, but no field
  # previously defined it and the kwargs override loop is hasattr-gated, so
  # the default configuration (max_steps == -1) raised AttributeError. The
  # field is appended LAST so the positional order of all pre-existing fields
  # is unchanged for any positional callers.
  num_train_epochs: Optional[int] = dataclasses.field(
      default=3,
      metadata={
          'help':
              'Total number of training epochs to perform. Only used (and then possibly recomputed) when `max_steps` is not set to a positive number.'
      },
  )
237
def train(args,
          accelerator,
          model,
          tokenizer,
          train_dataloader,
          optimizer,
          lr_scheduler,
          eval_dataloader = None):
  """Train a model on the given training data.

  Runs the optimization loop with gradient accumulation. Depending on
  `args.evaluation_strategy`, evaluates and checkpoints either every
  `args.eval_steps` optimizer steps or at the end of every epoch, keeping at
  most `args.keep_checkpoint_max` checkpoints ranked by `args.eval_metric`
  and applying early stopping. The winning checkpoint ends up in
  `<args.output_dir>/best-checkpoint`.

  Args:
    args: Namespace of hyperparameters (see FT*Arguments dataclasses), plus
      `num_examples` and `max_steps` filled in by `finetune()`.
    accelerator: `accelerate.Accelerator` driving distributed execution.
    model: The (accelerator-prepared) model to train.
    tokenizer: Tokenizer, saved alongside every model checkpoint.
    train_dataloader: Prepared training dataloader.
    optimizer: Prepared optimizer.
    lr_scheduler: Learning-rate scheduler, stepped once per optimizer step.
    eval_dataloader: Optional prepared dataloader for in-training evaluation.

  Returns:
    A `(completed_steps, average_train_loss)` tuple.
  """

  total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

  logger.info('***** Running training *****')
  logger.info(' Num examples = %d', args.num_examples[Split.TRAIN.value])
  logger.info(' Instantaneous batch size per device = %d',
              args.per_device_train_batch_size)
  logger.info(
      ' Total train batch size (w. parallel, distributed & accumulation) = %d',
      total_batch_size)
  logger.info(' Gradient Accumulation steps = %d',
              args.gradient_accumulation_steps)
  logger.info(' Total optimization steps = %d', args.max_steps)

  # Only show the progress bar once on each machine.
  progress_bar = tqdm(
      range(args.max_steps), disable=not accelerator.is_local_main_process)

  # `checkpoints`/`eval_results` are parallel arrays kept sorted by metric
  # (ascending), so index 0 is always the current worst checkpoint.
  checkpoints = None
  eval_results = None
  best_checkpoint = None
  best_eval_result = None
  early_stopping_patience_counter = 0
  should_training_stop = False
  epoch = 0
  completed_steps = 0
  train_loss = 0.0
  model.zero_grad()

  for _ in range(args.num_train_epochs):
    epoch += 1
    model.train()
    for step, batch in enumerate(train_dataloader):
      outputs = model(**batch)
      loss = outputs.loss
      # Scale so the accumulated gradient matches a full-batch gradient.
      loss = loss / args.gradient_accumulation_steps
      accelerator.backward(loss)
      train_loss += loss.item()

      # NOTE(review): `step % ... == 0` is true at step 0, so the very first
      # optimizer step happens after a single micro-batch; `(step + 1) % ...`
      # would accumulate a full group first — confirm which is intended.
      if step % args.gradient_accumulation_steps == 0 or step == len(
          train_dataloader) - 1:
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        completed_steps += 1

        # Evaluate during training, every `eval_steps` optimizer steps.
        if eval_dataloader is not None and args.evaluation_strategy == IntervalStrategy.STEPS.value and args.eval_steps > 0 and completed_steps % args.eval_steps == 0:
          accelerator.wait_for_everyone()
          new_checkpoint = f'checkpoint-{IntervalStrategy.STEPS.value}-{completed_steps}'
          new_eval_result = evaluate(args, accelerator, eval_dataloader, 'eval',
                                     model, new_checkpoint)[args.eval_metric]
          logger.info('Evaluation result at step %d: %s = %f', completed_steps,
                      args.eval_metric, new_eval_result)
          if checkpoints is None:
            # First evaluation: bootstrap the tracking arrays.
            checkpoints = np.array([new_checkpoint])
            eval_results = np.array([new_eval_result])
            best_checkpoint = new_checkpoint
            best_eval_result = new_eval_result
          else:
            if new_eval_result - best_eval_result > args.early_stopping_threshold:
              # Strict improvement beyond the threshold resets patience.
              best_checkpoint = new_checkpoint
              best_eval_result = new_eval_result
              early_stopping_patience_counter = 0
            else:
              # A tie still prefers the newer checkpoint, but any
              # non-improvement consumes one unit of patience.
              if new_eval_result == best_eval_result:
                best_checkpoint = new_checkpoint
                best_eval_result = new_eval_result
              early_stopping_patience_counter += 1

            if early_stopping_patience_counter >= args.early_stopping_patience:
              should_training_stop = True

            # Keep the arrays sorted by metric so the worst is at index 0.
            checkpoints = np.append(checkpoints, [new_checkpoint], axis=0)
            eval_results = np.append(eval_results, [new_eval_result], axis=0)
            sorted_ids = np.argsort(eval_results)
            eval_results = eval_results[sorted_ids]
            checkpoints = checkpoints[sorted_ids]

          if len(checkpoints) > args.keep_checkpoint_max:
            # Delete the current worst checkpoint
            # NOTE(review): star-unpacking turns `checkpoints` into a Python
            # list here; subsequent np.append/np.argsort still accept it.
            checkpoint_to_remove, *checkpoints = checkpoints
            eval_results = eval_results[1:]
            if checkpoint_to_remove != new_checkpoint:
              if accelerator.is_main_process:
                shutil.rmtree(
                    os.path.join(args.output_dir, checkpoint_to_remove),
                    ignore_errors=True)
              accelerator.wait_for_everyone()

          if new_checkpoint in checkpoints:
            # Save model checkpoint
            checkpoint_output_dir = os.path.join(args.output_dir,
                                                 new_checkpoint)
            if accelerator.is_main_process:
              if not os.path.exists(checkpoint_output_dir):
                os.makedirs(checkpoint_output_dir)
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                checkpoint_output_dir, save_function=accelerator.save)
            if accelerator.is_main_process:
              tokenizer.save_pretrained(checkpoint_output_dir)
              logger.info('Saving model checkpoint to %s',
                          checkpoint_output_dir)

      if completed_steps >= args.max_steps:
        break

      if should_training_stop:
        break

    # Evaluate during training, once at the end of every epoch.
    if eval_dataloader is not None and args.evaluation_strategy == IntervalStrategy.EPOCH.value:
      accelerator.wait_for_everyone()
      new_checkpoint = f'checkpoint-{IntervalStrategy.EPOCH.value}-{epoch}'
      new_eval_result = evaluate(args, accelerator, eval_dataloader, 'eval',
                                 model, new_checkpoint)[args.eval_metric]
      logger.info('Evaluation result at epoch %d: %s = %f', epoch,
                  args.eval_metric, new_eval_result)

      # Same best-tracking / pruning logic as the step-level branch above.
      if checkpoints is None:
        checkpoints = np.array([new_checkpoint])
        eval_results = np.array([new_eval_result])
        best_checkpoint = new_checkpoint
        best_eval_result = new_eval_result
      else:
        if new_eval_result - best_eval_result > args.early_stopping_threshold:
          best_checkpoint = new_checkpoint
          best_eval_result = new_eval_result
          early_stopping_patience_counter = 0
        else:
          if new_eval_result == best_eval_result:
            best_checkpoint = new_checkpoint
            best_eval_result = new_eval_result
          early_stopping_patience_counter += 1

        if early_stopping_patience_counter >= args.early_stopping_patience:
          should_training_stop = True

        checkpoints = np.append(checkpoints, [new_checkpoint], axis=0)
        eval_results = np.append(eval_results, [new_eval_result], axis=0)
        sorted_ids = np.argsort(eval_results)
        eval_results = eval_results[sorted_ids]
        checkpoints = checkpoints[sorted_ids]

      if len(checkpoints) > args.keep_checkpoint_max:
        # Delete the current worst checkpoint
        checkpoint_to_remove, *checkpoints = checkpoints
        eval_results = eval_results[1:]
        if checkpoint_to_remove != new_checkpoint:
          if accelerator.is_main_process:
            shutil.rmtree(
                os.path.join(args.output_dir, checkpoint_to_remove),
                ignore_errors=True)
          accelerator.wait_for_everyone()

      if new_checkpoint in checkpoints:
        # Save model checkpoint
        checkpoint_output_dir = os.path.join(args.output_dir, new_checkpoint)
        if accelerator.is_main_process:
          if not os.path.exists(checkpoint_output_dir):
            os.makedirs(checkpoint_output_dir)
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            checkpoint_output_dir, save_function=accelerator.save)
        if accelerator.is_main_process:
          tokenizer.save_pretrained(checkpoint_output_dir)
          logger.info('Saving model checkpoint to %s', checkpoint_output_dir)

    if completed_steps >= args.max_steps:
      break

    if should_training_stop:
      break

  if best_checkpoint is not None:
    # Save the best checkpoint
    logger.info('Best checkpoint: %s', best_checkpoint)
    logger.info('Best evaluation result: %s = %f', args.eval_metric,
                best_eval_result)
    best_checkpoint_output_dir = os.path.join(args.output_dir, best_checkpoint)
    if accelerator.is_main_process:
      # Promote the winner to the canonical 'best-checkpoint' directory.
      shutil.move(best_checkpoint_output_dir,
                  os.path.join(args.output_dir, 'best-checkpoint'))
      shutil.rmtree(best_checkpoint_output_dir, ignore_errors=True)
    accelerator.wait_for_everyone()

  else:
    # Assume that the last checkpoint is the best checkpoint and save it
    # (this is the path taken when no in-training evaluation ever ran).
    checkpoint_output_dir = os.path.join(args.output_dir, 'best-checkpoint')
    if not os.path.exists(checkpoint_output_dir):
      os.makedirs(checkpoint_output_dir)

    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        checkpoint_output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
      tokenizer.save_pretrained(checkpoint_output_dir)
      logger.info('Saving model checkpoint to %s', checkpoint_output_dir)
  # `train_loss` accumulated per-micro-batch (already scaled by the
  # accumulation factor), averaged over optimizer steps.
  return completed_steps, train_loss / completed_steps
452
def evaluate(args,
             accelerator,
             dataloader,
             eval_set,
             model,
             checkpoint,
             has_labels: bool = True,
             write_to_file: bool = True):
  """Evaluate a model checkpoint on the given evaluation data.

  Args:
    args: Namespace of hyperparameters; reads `num_examples`, `eval_metric`,
      `is_regression` and `output_dir`.
    accelerator: `accelerate.Accelerator` used to gather results across
      processes.
    dataloader: Prepared dataloader over the split to evaluate.
    eval_set: Split name (a `Split` value) used for bookkeeping and filenames.
    model: The (prepared) model to run.
    checkpoint: Checkpoint name embedded in the output filenames.
    has_labels: Whether batches contain a 'labels' key to score against.
    write_to_file: Whether to dump metric JSON and a predictions CSV into
      `args.output_dir`.

  Returns:
    Dict of metric results (empty when `has_labels` is False), plus
    'completed_steps' and 'avg_eval_loss' when labels are available.
  """

  num_examples = args.num_examples[eval_set]
  eval_metric = None
  completed_steps = 0
  eval_loss = 0.0
  all_predictions = None
  all_references = None
  all_probabilities = None

  if has_labels:
    # Get the metric function
    eval_metric = load_metric(args.eval_metric)

  eval_results = {}
  model.eval()
  for _, batch in enumerate(dataloader):
    with torch.no_grad():
      outputs = model(**batch)

    # NOTE(review): assumes the model returned a loss, i.e. the batch carried
    # 'labels' — confirm inference batches also do, otherwise `outputs.loss`
    # is None here even with has_labels=False.
    eval_loss += outputs.loss.item()
    logits = outputs.logits
    # Classification: argmax over classes; regression: raw scalar output.
    predictions = logits.argmax(
        dim=-1) if not args.is_regression else logits.squeeze()
    predictions = accelerator.gather(predictions)

    if all_predictions is None:
      all_predictions = predictions.detach().cpu().numpy()
    else:
      all_predictions = np.append(
          all_predictions, predictions.detach().cpu().numpy(), axis=0)

    if not args.is_regression:
      # Confidence of the predicted class (max softmax probability).
      probabilities = logits.softmax(dim=-1).max(dim=-1).values
      probabilities = accelerator.gather(probabilities)
      if all_probabilities is None:
        all_probabilities = probabilities.detach().cpu().numpy()
      else:
        all_probabilities = np.append(
            all_probabilities, probabilities.detach().cpu().numpy(), axis=0)

    if has_labels:
      references = batch['labels']
      references = accelerator.gather(references)
      if all_references is None:
        all_references = references.detach().cpu().numpy()
      else:
        all_references = np.append(
            all_references, references.detach().cpu().numpy(), axis=0)

      eval_metric.add_batch(
          predictions=predictions,
          references=references,
      )
    completed_steps += 1

  if has_labels:
    eval_results.update(eval_metric.compute())
    eval_results['completed_steps'] = completed_steps
    eval_results['avg_eval_loss'] = eval_loss / completed_steps

    # Metric results are only written when labels exist to score against.
    if write_to_file:
      accelerator.wait_for_everyone()
      if accelerator.is_main_process:
        results_file = os.path.join(args.output_dir,
                                    f'{eval_set}_results_{checkpoint}.json')
        with open(results_file, 'w') as f:
          json.dump(eval_results, f, indent=4, sort_keys=True)

  # Predictions (and confidences for classification) are written regardless
  # of whether labels were available.
  if write_to_file:
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
      output_file = os.path.join(args.output_dir,
                                 f'{eval_set}_output_{checkpoint}.csv')
      if not args.is_regression:
        assert len(all_predictions) == len(all_probabilities)
        df = pd.DataFrame(
            list(zip(all_predictions, all_probabilities)),
            columns=['prediction', 'probability'])
      else:
        df = pd.DataFrame(all_predictions, columns=['prediction'])
      # Gathering across processes may have padded to a multiple of the
      # batch size; truncate back to the true example count.
      df = df.head(num_examples)
      df.to_csv(output_file, header=True, index=False)
  return eval_results
546
def load_from_pretrained(
    args, pretrained_model_name_or_path
):
  """Load the pretrained model and tokenizer.

  Args:
    args: Namespace carrying `task_name`, `cache_dir`, `use_fast_tokenizer`,
      `model_name_or_path` and (optionally) `num_labels`.
    pretrained_model_name_or_path: Local path or hub identifier to load from.

  Returns:
    A `(config, tokenizer, model)` tuple.
  """

  # In distributed training, the .from_pretrained methods guarantee that only
  # one local process can concurrently perform this procedure.

  cache_dir = args.cache_dir
  label_count = args.num_labels if hasattr(args, 'num_labels') else None

  tokenizer = AutoTokenizer.from_pretrained(
      pretrained_model_name_or_path,
      use_fast=args.use_fast_tokenizer,
      cache_dir=cache_dir)

  config = AutoConfig.from_pretrained(
      pretrained_model_name_or_path,
      num_labels=label_count,
      finetuning_task=args.task_name.lower(),
      cache_dir=cache_dir,
  )

  model = AutoModelForSequenceClassification.from_pretrained(
      pretrained_model_name_or_path,
      from_tf='.ckpt' in args.model_name_or_path,
      config=config,
      ignore_mismatched_sizes=True,
      cache_dir=cache_dir,
  )

  return config, tokenizer, model
574
def finetune(accelerator,
             model_name_or_path, train_file, output_dir,
             **kwargs):
  """Fine-tuning a pre-trained model on a downstream task.

  Args:
    accelerator: An instance of an accelerator for distributed training (on
      multi-GPU, TPU) or mixed precision training.
    model_name_or_path: Path to pretrained model or model identifier from
      huggingface.co/models.
    train_file: A csv or a json file containing the training data.
    output_dir: The output directory where the model predictions and checkpoints
      will be written.
    **kwargs: Dictionary of key/value pairs with which to update the
      configuration object after loading. The values in kwargs of any keys which
      are configuration attributes will be used to override the loaded values.
  """
  # Make one log on every process with the configuration for debugging.
  logging.basicConfig(
      format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
      datefmt='%m/%d/%Y %H:%M:%S',
      level=logging.INFO,
  )
  logger.info(accelerator.state)

  # Setup logging, we only want one process per machine to log things on the
  # screen. accelerator.is_local_main_process is only True for one process per
  # machine.
  logger.setLevel(
      logging.INFO if accelerator.is_local_main_process else logging.ERROR)

  # Flatten the three argument dataclasses into a single namespace.
  model_args = FTModelArguments(model_name_or_path=model_name_or_path)
  data_args = FTDataArguments(train_file=train_file)
  training_args = FTTrainingArguments(output_dir=output_dir)
  args = argparse.Namespace()

  for arg_class in (model_args, data_args, training_args):
    for key, value in vars(arg_class).items():
      setattr(args, key, value)

  # NOTE(review): the hasattr gate silently drops any kwarg that does not
  # correspond to a declared dataclass field — typos are not reported.
  for key, value in kwargs.items():
    if hasattr(args, key):
      setattr(args, key, value)

  # Sanity checks
  data_files = {}
  args.data_file_extension = None

  # You need to provide the training data as we always run training
  args.do_train = True
  assert args.train_file is not None
  data_files[Split.TRAIN.value] = args.train_file

  if args.do_eval or args.evaluation_strategy != IntervalStrategy.NO.value:
    assert args.eval_file is not None
    data_files[Split.EVAL.value] = args.eval_file

  if args.do_eval and args.test_file is not None:
    data_files[Split.TEST.value] = args.test_file

  if args.do_predict:
    assert args.infer_file is not None
    data_files[Split.INFER.value] = args.infer_file

  # All provided data files must share a single extension (csv or json),
  # which is also the `load_dataset` builder name.
  for key in data_files:
    extension = data_files[key].split('.')[-1]
    assert extension in ['csv', 'json'
                        ], f'`{key}_file` should be a csv or a json file.'
    if args.data_file_extension is None:
      args.data_file_extension = extension
    else:
      assert (extension == args.data_file_extension
             ), f'`{key}_file` should be a {args.data_file_extension} file`.'

  assert (
      args.eval_metric in datasets.list_metrics()
  ), f'{args.eval_metric} not in the list of supported metrics {datasets.list_metrics()}.'

  # Handle the output directory creation
  if accelerator.is_main_process:
    if args.output_dir is not None:
      os.makedirs(args.output_dir, exist_ok=True)
  accelerator.wait_for_everyone()

  # If passed along, set the training seed now.
  if args.seed is not None:
    set_seed(args.seed)

  # You need to provide your CSV/JSON data files.
  #
  # For CSV/JSON files, this script will use as labels the column called 'label'
  # and as pair of sentences the sentences in columns called 'sentence1' and
  # 'sentence2' if these columns exist or the first two columns not named
  # 'label' if at least two columns are provided.
  #
  # If the CSVs/JSONs contain only one non-label column, the script does single
  # sentence classification on this single column.
  #
  # In distributed training, the load_dataset function guarantees that only one
  # local process can download the dataset.

  # Loading the dataset from local csv or json files.
  raw_datasets = load_dataset(args.data_file_extension, data_files=data_files)

  # Labels
  is_regression = raw_datasets[Split.TRAIN.value].features['label'].dtype in [
      'float32', 'float64'
  ]
  args.is_regression = is_regression

  if args.is_regression:
    label_list = None
    num_labels = 1
  else:
    label_list = args.label_list
    assert label_list is not None
    label_list.sort()  # Let's sort it for determinism
    num_labels = len(label_list)
  args.num_labels = num_labels

  # Load pre-trained model
  config, tokenizer, model = load_from_pretrained(args, args.model_name_or_path)

  # Preprocessing the datasets
  non_label_column_names = [
      name for name in raw_datasets[Split.TRAIN.value].column_names
      if name != 'label'
  ]
  if 'sentence1' in non_label_column_names and 'sentence2' in non_label_column_names:
    sentence1_key, sentence2_key = 'sentence1', 'sentence2'
  else:
    if len(non_label_column_names) >= 2:
      sentence1_key, sentence2_key = non_label_column_names[:2]
    else:
      sentence1_key, sentence2_key = non_label_column_names[0], None

  # NOTE(review): `label_list` is None on the regression path, so this
  # comprehension would raise TypeError for regression data — confirm the
  # regression path is actually exercised.
  label_to_id = {v: i for i, v in enumerate(label_list)}
  config.label2id = label_to_id
  config.id2label = {id: label for label, id in config.label2id.items()}
  padding = 'max_length' if args.pad_to_max_length else False

  def preprocess_function(examples):
    # Tokenize the texts
    texts = ((examples[sentence1_key],) if sentence2_key is None else
             (examples[sentence1_key], examples[sentence2_key]))
    result = tokenizer(
        *texts, padding=padding, max_length=args.max_length, truncation=True)

    if 'label' in examples:
      if label_to_id is not None:
        # Map labels to IDs (not necessary for GLUE tasks)
        result['labels'] = [label_to_id[l] for l in examples['label']]
      else:
        # In all cases, rename the column to labels because the model will
        # expect that.
        # NOTE(review): `label_to_id` is always a dict at this point, so this
        # else-branch is unreachable as written.
        result['labels'] = examples['label']
    return result

  with accelerator.main_process_first():
    processed_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_datasets[Split.TRAIN.value].column_names,
        desc='Running tokenizer on dataset',
    )

  # Record per-split example counts for logging and CSV truncation.
  num_examples = {}
  splits = [s.value for s in Split]
  for split in splits:
    if split in processed_datasets:
      num_examples[split] = len(processed_datasets[split])
  args.num_examples = num_examples

  train_dataset = processed_datasets[Split.TRAIN.value]
  eval_dataset = processed_datasets[
      Split.EVAL.value] if Split.EVAL.value in processed_datasets else None
  test_dataset = processed_datasets[
      Split.TEST.value] if Split.TEST.value in processed_datasets else None
  infer_dataset = processed_datasets[
      Split.INFER.value] if Split.INFER.value in processed_datasets else None

  # Log a few random samples from the training set:
  for index in random.sample(range(len(train_dataset)), 3):
    logger.info('Sample %d of the training set: %s.', index,
                train_dataset[index])

  # DataLoaders creation:
  if args.pad_to_max_length:
    # If padding was already done ot max length, we use the default data
    # collator that will just convert everything to tensors.
    data_collator = default_data_collator
  else:
    # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by
    # padding to the maximum length of the samples passed). When using mixed
    # precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple of
    # 8s, which will enable the use of Tensor Cores on NVIDIA hardware with
    # compute capability >= 7.5 (Volta).
    # NOTE(review): `accelerator.use_fp16` is deprecated in newer accelerate
    # releases in favor of `mixed_precision` — verify against the pinned
    # accelerate version.
    data_collator = DataCollatorWithPadding(
        tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None))

  train_dataloader = DataLoader(
      train_dataset,
      batch_size=args.per_device_train_batch_size,
      shuffle=True,
      collate_fn=data_collator,
  )
  eval_dataloader, test_dataloader, infer_dataloader = None, None, None

  if eval_dataset is not None:
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=args.per_device_eval_batch_size,
        collate_fn=data_collator)

  if test_dataset is not None:
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.per_device_eval_batch_size,
        collate_fn=data_collator)

  if infer_dataset is not None:
    infer_dataloader = DataLoader(
        infer_dataset,
        batch_size=args.per_device_eval_batch_size,
        collate_fn=data_collator)

  # Optimizer
  # Split weights in two groups, one with weight decay and the other not.
  no_decay = ['bias', 'LayerNorm.weight']
  optimizer_grouped_parameters = [
      {
          'params': [
              p for n, p in model.named_parameters()
              if not any(nd in n for nd in no_decay)
          ],
          'weight_decay': args.weight_decay,
      },
      {
          'params': [
              p for n, p in model.named_parameters()
              if any(nd in n for nd in no_decay)
          ],
          'weight_decay': 0.0,
      },
  ]
  optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

  # Prepare everything with our `accelerator`.
  model, optimizer, train_dataloader, eval_dataloader, test_dataloader, infer_dataloader = accelerator.prepare(
      model, optimizer, train_dataloader, eval_dataloader, test_dataloader,
      infer_dataloader)

  # Note -> the training dataloader needs to be prepared before we grab its
  # length below (cause its length will be shorter in multiprocess)

  # Scheduler and math around the number of training steps.
  num_update_steps_per_epoch = math.ceil(
      len(train_dataloader) / args.gradient_accumulation_steps)
  # NOTE(review): the max_steps == -1 branch reads `args.num_train_epochs`;
  # make sure that attribute is defined (via FTTrainingArguments or kwargs)
  # in the default configuration, or this line raises AttributeError.
  if args.max_steps == -1:
    args.max_steps = args.num_train_epochs * num_update_steps_per_epoch
  else:
    args.num_train_epochs = math.ceil(args.max_steps /
                                      num_update_steps_per_epoch)

  lr_scheduler = get_scheduler(
      name=args.lr_scheduler_type,
      optimizer=optimizer,
      num_warmup_steps=args.warmup_steps,
      num_training_steps=args.max_steps,
  )

  # Train
  completed_steps, avg_train_loss = train(args, accelerator, model, tokenizer,
                                          train_dataloader, optimizer,
                                          lr_scheduler, eval_dataloader)
  accelerator.wait_for_everyone()
  logger.info(
      'Training job completed: completed_steps = %d, avg_train_loss = %f',
      completed_steps, avg_train_loss)

  # Reload the best checkpoint produced by train() for final eval/predict.
  args.model_name_or_path = os.path.join(args.output_dir, 'best-checkpoint')
  logger.info('Loading the best checkpoint: %s', args.model_name_or_path)
  config, tokenizer, model = load_from_pretrained(args, args.model_name_or_path)
  model = accelerator.prepare(model)

  if args.do_eval:
    # Evaluate
    if eval_dataloader is not None:
      logger.info(
          '***** Running evaluation on the eval data using the best checkpoint *****'
      )
      eval_results = evaluate(args, accelerator, eval_dataloader,
                              Split.EVAL.value, model, 'best-checkpoint')
      avg_eval_loss = eval_results['avg_eval_loss']
      eval_metric = eval_results[args.eval_metric]
      logger.info('Evaluation job completed: avg_eval_loss = %f', avg_eval_loss)
      logger.info('Evaluation result for the best checkpoint: %s = %f',
                  args.eval_metric, eval_metric)

    if test_dataloader is not None:
      logger.info(
          '***** Running evaluation on the test data using the best checkpoint *****'
      )
      eval_results = evaluate(args, accelerator, test_dataloader,
                              Split.TEST.value, model, 'best-checkpoint')
      avg_eval_loss = eval_results['avg_eval_loss']
      eval_metric = eval_results[args.eval_metric]
      logger.info('Test job completed: avg_test_loss = %f', avg_eval_loss)
      logger.info('Test result for the best checkpoint: %s = %f',
                  args.eval_metric, eval_metric)

  if args.do_predict:
    # Predict
    if infer_dataloader is not None:
      logger.info('***** Running inference using the best checkpoint *****')
      evaluate(
          args,
          accelerator,
          infer_dataloader,
          Split.INFER.value,
          model,
          'best-checkpoint',
          has_labels=False)
      logger.info('Inference job completed.')

  # Release all references to the internal objects stored and call the garbage
  # collector. You should call this method between two trainings with different
  # models/optimizers.
  accelerator.free_memory()