#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning XLNet for question answering with beam search using a slightly adapted version of the 🤗 Trainer.
"""
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
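#
# A typical invocation might look like the following (illustrative only; the model name, output path and
# hyperparameters below are placeholders to adapt to your own setup):
#
#   python run_qa_beam_search.py \
#     --model_name_or_path xlnet-large-cased \
#     --dataset_name squad \
#     --do_train \
#     --do_eval \
#     --max_seq_length 384 \
#     --doc_stride 128 \
#     --output_dir /tmp/xlnet_squad_beam_search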
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

import datasets
import evaluate
from datasets import load_dataset
from trainer_qa import QuestionAnsweringTrainer
from utils_qa import postprocess_qa_predictions_with_beam_search

import transformers
from transformers import (
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    TrainingArguments,
    XLNetConfig,
    XLNetForQuestionAnswering,
    XLNetTokenizerFast,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.39.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a csv or json file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate on (a csv or json file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to predict on (a csv or json file)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_seq_length: int = field(
        default=384,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
                " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    version_2_with_negative: bool = field(
        default=False, metadata={"help": "If true, some of the examples do not have an answer."}
    )
    null_score_diff_threshold: float = field(
        default=0.0,
        metadata={
            "help": (
                "The threshold used to select the null answer: if the best answer has a score that is less than "
                "the score of the null answer minus this threshold, the null answer is selected for this example. "
                "Only useful when `version_2_with_negative=True`."
            )
        },
    )
    doc_stride: int = field(
        default=128,
        metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
    )
    n_best_size: int = field(
        default=20,
        metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
    )
    max_answer_length: int = field(
        default=30,
        metadata={
            "help": (
                "The maximum length of an answer that can be generated. This is needed because the start "
                "and end predictions are not conditioned on one another."
            )
        },
    )

    def __post_init__(self):
        if (
            self.dataset_name is None
            and self.train_file is None
            and self.validation_file is None
            and self.test_file is None
        ):
            raise ValueError("Need either a dataset name or a training/validation/test file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
            if self.test_file is not None:
                extension = self.test_file.split(".")[-1]
                assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if model_args.use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_qa_beam_search", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a small summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script expects columns called 'question', 'context' and 'answers' (falling back to
    # the first three columns if those are not found). You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            field="data",
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.
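    # Illustrative note (an assumption about your files, not something the script enforces beyond the extension
    # checks above): since `load_dataset` is called with `field="data"`, a custom JSON file is expected to look
    # roughly like
    #   {"data": [{"id": "0", "question": "...", "context": "...",
    #              "answers": {"text": ["..."], "answer_start": [0]}}, ...]}
    # i.e. SQuAD-style records with an "id", a "question", a "context" and an "answers" dict.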

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = XLNetConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
    )
    tokenizer = XLNetTokenizerFast.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
    )
    model = XLNetForQuestionAnswering.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
    )

    # Preprocessing the datasets.
    # Preprocessing is slightly different for training and evaluation.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    else:
        column_names = raw_datasets["test"].column_names
    question_column_name = "question" if "question" in column_names else column_names[0]
    context_column_name = "context" if "context" in column_names else column_names[1]
    answer_column_name = "answers" if "answers" in column_names else column_names[2]

    # Padding side determines if we do (question|context) or (context|question).
    pad_on_right = tokenizer.padding_side == "right"

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Training preprocessing
    def prepare_train_features(examples):
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take up a lot of space). So we remove that
        # left whitespace.
        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            return_special_tokens_mask=True,
            return_token_type_ids=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # The offset mappings will give us a map from token to character position in the original context. This will
        # help us compute the start_positions and end_positions.
        offset_mapping = tokenized_examples.pop("offset_mapping")
        # The special tokens will help us build the p_mask (which indicates the tokens that can't be in answers).
        special_tokens = tokenized_examples.pop("special_tokens_mask")

        # Let's label those examples!
        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []
        tokenized_examples["is_impossible"] = []
        tokenized_examples["cls_index"] = []
        tokenized_examples["p_mask"] = []

        for i, offsets in enumerate(offset_mapping):
            # We will label impossible answers with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)
            tokenized_examples["cls_index"].append(cls_index)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples["token_type_ids"][i]
            for k, s in enumerate(special_tokens[i]):
                if s:
                    sequence_ids[k] = 3
            context_idx = 1 if pad_on_right else 0

            # Build the p_mask: non-special context tokens get 0.0, the others get 1.0.
            # The CLS token gets 0.0 too, so the model can predict it for empty answers.
            tokenized_examples["p_mask"].append(
                [
                    0.0 if (not special_tokens[i][k] and s == context_idx) or k == cls_index else 1.0
                    for k, s in enumerate(sequence_ids)
                ]
            )

            # One example can give several spans; this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]
            # If no answers are given, set the cls_index as answer.
            if len(answers["answer_start"]) == 0:
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
                tokenized_examples["is_impossible"].append(1.0)
            else:
                # Start/end character index of the answer in the text.
                start_char = answers["answer_start"][0]
                end_char = start_char + len(answers["text"][0])

                # Start token index of the current span in the text.
                token_start_index = 0
                while sequence_ids[token_start_index] != context_idx:
                    token_start_index += 1

                # End token index of the current span in the text.
                token_end_index = len(input_ids) - 1
                while sequence_ids[token_end_index] != context_idx:
                    token_end_index -= 1
                # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
                if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                    tokenized_examples["start_positions"].append(cls_index)
                    tokenized_examples["end_positions"].append(cls_index)
                    tokenized_examples["is_impossible"].append(1.0)
                else:
                    # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                    # Note: we could go after the last offset if the answer is the last word (edge case).
                    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                        token_start_index += 1
                    tokenized_examples["start_positions"].append(token_start_index - 1)
                    while offsets[token_end_index][1] >= end_char:
                        token_end_index -= 1
                    tokenized_examples["end_positions"].append(token_end_index + 1)
                    tokenized_examples["is_impossible"].append(0.0)

        return tokenized_examples
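    # To make the labeling above concrete: a feature whose context window does not contain the gold answer ends
    # up with start_positions == end_positions == cls_index and is_impossible == 1.0, while a feature that does
    # contain it gets the token indices bracketing the answer characters and is_impossible == 0.0.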

    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            # Select a subset of the dataset; this helps decrease processing time.
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        # Create training features
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                prepare_train_features,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )
        if data_args.max_train_samples is not None:
            # Select samples again since feature creation might have increased the number of features.
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    # Validation preprocessing
    def prepare_validation_features(examples):
        # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
        # in one example possibly giving several features when a context is long, each of those features having a
        # context that overlaps a bit the context of the previous feature.
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            return_special_tokens_mask=True,
            return_token_type_ids=True,
            padding="max_length",
        )

        # Since one example might give us several features if it has a long context, we need a map from a feature to
        # its corresponding example. This key gives us just that.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

        # The special tokens will help us build the p_mask (which indicates the tokens that can't be in answers).
        special_tokens = tokenized_examples.pop("special_tokens_mask")

        # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
        # corresponding example_id and we will store the offset mappings.
        tokenized_examples["example_id"] = []

        # We still provide the index of the CLS token and the p_mask to the model, but not the is_impossible label.
        tokenized_examples["cls_index"] = []
        tokenized_examples["p_mask"] = []

        for i, input_ids in enumerate(tokenized_examples["input_ids"]):
            # Find the CLS token in the input ids.
            cls_index = input_ids.index(tokenizer.cls_token_id)
            tokenized_examples["cls_index"].append(cls_index)

            # Grab the sequence corresponding to that example (to know what is the context and what is the question).
            sequence_ids = tokenized_examples["token_type_ids"][i]
            for k, s in enumerate(special_tokens[i]):
                if s:
                    sequence_ids[k] = 3
            context_idx = 1 if pad_on_right else 0

            # Build the p_mask: non-special context tokens get 0.0, the others get 1.0.
            tokenized_examples["p_mask"].append(
                [
                    0.0 if (not special_tokens[i][k] and s == context_idx) or k == cls_index else 1.0
                    for k, s in enumerate(sequence_ids)
                ]
            )

            # One example can give several spans; this is the index of the example containing this span of text.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(examples["id"][sample_index])

            # Set to None the offset_mapping entries that are not part of the context, so it's easy to determine
            # whether a token position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_idx else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples
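    # Note: the `example_id` and the None-masked `offset_mapping` kept above are what the beam-search
    # post-processing below uses to map token-level predictions back to character spans of the original context.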

    if training_args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_examples = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            # Select a subset of the evaluation examples.
            max_eval_samples = min(len(eval_examples), data_args.max_eval_samples)
            eval_examples = eval_examples.select(range(max_eval_samples))
        # Create features from the evaluation dataset
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_examples.map(
                prepare_validation_features,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )
        if data_args.max_eval_samples is not None:
            # Select samples again since feature creation might have increased the number of features.
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    if training_args.do_predict:
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_examples = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            # Select a subset of the prediction examples.
            predict_examples = predict_examples.select(range(data_args.max_predict_samples))
        # Create features from the test dataset
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_examples.map(
                prepare_validation_features,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )
        if data_args.max_predict_samples is not None:
            # Feature creation might have increased the number of samples; select the required number again.
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))

    # Data collator
    # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
    # collator.
    data_collator = (
        default_data_collator
        if data_args.pad_to_max_length
        else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
    )
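    # Note: padding to a multiple of 8 under fp16 is purely a performance hint (it tends to suit mixed-precision
    # kernels better); it does not change the model outputs.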

    # Post-processing:
    def post_processing_function(examples, features, predictions, stage="eval"):
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions, scores_diff_json = postprocess_qa_predictions_with_beam_search(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=data_args.version_2_with_negative,
            n_best_size=data_args.n_best_size,
            max_answer_length=data_args.max_answer_length,
            start_n_top=model.config.start_n_top,
            end_n_top=model.config.end_n_top,
            output_dir=training_args.output_dir,
            log_level=log_level,
            prefix=stage,
        )
        # Format the result to the format the metric expects.
        if data_args.version_2_with_negative:
            formatted_predictions = [
                {"id": k, "prediction_text": v, "no_answer_probability": scores_diff_json[k]}
                for k, v in predictions.items()
            ]
        else:
            formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]

        references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
        return EvalPrediction(predictions=formatted_predictions, label_ids=references)
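    # `start_n_top` and `end_n_top` above control how many start/end candidates the beam-search post-processing
    # considers per feature; they come from the model config (both default to 5 in `XLNetConfig` unless the
    # checkpoint overrides them).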

    metric = evaluate.load(
        "squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
    )
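    # The loaded metric returns SQuAD-style scores: roughly `exact_match`/`f1` for "squad", and `exact`/`f1`
    # plus no-answer-aware variants (e.g. `best_exact`, `best_f1`) for "squad_v2".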

    def compute_metrics(p: EvalPrediction):
        return metric.compute(predictions=p.predictions, references=p.label_ids)

    # Initialize our Trainer
    trainer = QuestionAnsweringTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        eval_examples=eval_examples if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        post_process_function=post_processing_function,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
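        # `save_metrics`/`save_state` typically leave train_results.json, all_results.json and trainer_state.json
        # in `--output_dir`, alongside the model and tokenizer files saved above.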

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Prediction
    if training_args.do_predict:
        logger.info("*** Predict ***")
        results = trainer.predict(predict_dataset, predict_examples)
        metrics = results.metrics

        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()