transformers

Форк
0
737 строк · 33.1 Кб
1
#!/usr/bin/env python
2
# coding=utf-8
3
# Copyright 2020 The HuggingFace Team All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
"""
Fine-tuning XLNet for question answering with beam search using a slightly adapted version of the 🤗 Trainer.
"""
19
# You can also adapt this script for your own question answering task. Pointers for this are left as comments.
20

21
import logging
22
import os
23
import sys
24
import warnings
25
from dataclasses import dataclass, field
26
from typing import Optional
27

28
import datasets
29
import evaluate
30
from datasets import load_dataset
31
from trainer_qa import QuestionAnsweringTrainer
32
from utils_qa import postprocess_qa_predictions_with_beam_search
33

34
import transformers
35
from transformers import (
36
    DataCollatorWithPadding,
37
    EvalPrediction,
38
    HfArgumentParser,
39
    TrainingArguments,
40
    XLNetConfig,
41
    XLNetForQuestionAnswering,
42
    XLNetTokenizerFast,
43
    default_data_collator,
44
    set_seed,
45
)
46
from transformers.trainer_utils import get_last_checkpoint
47
from transformers.utils import check_min_version, send_example_telemetry
48
from transformers.utils.versions import require_version
49

50

51
# Fail early when the installed Transformers is older than the version this example targets.
# Remove this check at your own risk.
check_min_version("4.39.0.dev0")

# The `datasets` library has a minimum version requirement as well.
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

# Module-level logger; `main` sets its level once the training arguments have been parsed.
logger = logging.getLogger(__name__)
57

58

59
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Parsed from the command line (or from a single JSON file) by ``HfArgumentParser`` in ``main``.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # When unset, `main` falls back to `model_name_or_path` for the config and the tokenizer.
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Optional[str]: the default is None, meaning "no explicit token was given".
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # Deprecated alias of `token`: `main` emits a FutureWarning and copies it into `token` when it is set.
    # NOTE(review): intentionally left as plain `bool` rather than Optional[bool]; HfArgumentParser may treat
    # the two annotations differently for a bare `--use_auth_token` flag. The help text announces removal in
    # v4.34, while this script already requires >= 4.39.0.dev0 -- confirm whether the option should be dropped.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
97

98

99
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Either ``dataset_name`` or at least one of ``train_file`` / ``validation_file`` / ``test_file`` must be
    given; local files must have a ``.csv`` or ``.json`` extension. Violations raise ``ValueError``.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a csv or json file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate on (a csv or json file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to predict on (a csv or json file)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_seq_length: int = field(
        default=384,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    # NOTE(review): the feature-preparation functions in `main` always tokenize with padding="max_length";
    # this flag only selects the data collator there. Confirm the help text below still matches intent.
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. If False, will pad the samples dynamically when"
                " batching to the maximum length in the batch (which can be faster on GPU but will be slower on TPU)."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    version_2_with_negative: bool = field(
        default=False, metadata={"help": "If true, some of the examples do not have an answer."}
    )
    # NOTE(review): not forwarded to the beam-search post-processing in the visible part of `main`.
    null_score_diff_threshold: float = field(
        default=0.0,
        metadata={
            "help": (
                "The threshold used to select the null answer: if the best answer has a score that is less than "
                "the score of the null answer minus this threshold, the null answer is selected for this example. "
                "Only useful when `version_2_with_negative=True`."
            )
        },
    )
    doc_stride: int = field(
        default=128,
        metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
    )
    n_best_size: int = field(
        default=20,
        metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
    )
    max_answer_length: int = field(
        default=30,
        metadata={
            "help": (
                "The maximum length of an answer that can be generated. This is needed because the start "
                "and end predictions are not conditioned on one another."
            )
        },
    )

    def __post_init__(self):
        """Validate the data source arguments.

        Raises:
            ValueError: if no dataset name and no data file is given, or if a data file does not end in
                ``.csv`` / ``.json``.
        """
        if (
            self.dataset_name is None
            and self.train_file is None
            and self.validation_file is None
            and self.test_file is None
        ):
            raise ValueError("Need either a dataset name or a training/validation/test file.")
        # Raise explicitly instead of using `assert`, so the validation still runs under `python -O`.
        for arg_name in ("train_file", "validation_file", "test_file"):
            path = getattr(self, arg_name)
            if path is None:
                continue
            extension = path.split(".")[-1]
            if extension not in ["csv", "json"]:
                raise ValueError(f"`{arg_name}` should be a csv or a json file.")
221

222

223
def main():
224
    # See all possible arguments in src/transformers/training_args.py
225
    # or by passing the --help flag to this script.
226
    # We now keep distinct sets of args, for a cleaner separation of concerns.
227

228
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
229
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
230
        # If we pass only one argument to the script and it's the path to a json file,
231
        # let's parse it to get our arguments.
232
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
233
    else:
234
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
235

236
    if model_args.use_auth_token is not None:
237
        warnings.warn(
238
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
239
            FutureWarning,
240
        )
241
        if model_args.token is not None:
242
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
243
        model_args.token = model_args.use_auth_token
244

245
    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
246
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
247
    send_example_telemetry("run_qa_beam_search", model_args, data_args)
248

249
    # Setup logging
250
    logging.basicConfig(
251
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
252
        datefmt="%m/%d/%Y %H:%M:%S",
253
        handlers=[logging.StreamHandler(sys.stdout)],
254
    )
255

256
    if training_args.should_log:
257
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
258
        transformers.utils.logging.set_verbosity_info()
259

260
    log_level = training_args.get_process_log_level()
261
    logger.setLevel(log_level)
262
    datasets.utils.logging.set_verbosity(log_level)
263
    transformers.utils.logging.set_verbosity(log_level)
264
    transformers.utils.logging.enable_default_handler()
265
    transformers.utils.logging.enable_explicit_format()
266

267
    # Log on each process the small summary:
268
    logger.warning(
269
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
270
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
271
    )
272
    logger.info(f"Training/evaluation parameters {training_args}")
273

274
    # Detecting last checkpoint.
275
    last_checkpoint = None
276
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
277
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
278
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
279
            raise ValueError(
280
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
281
                "Use --overwrite_output_dir to overcome."
282
            )
283
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
284
            logger.info(
285
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
286
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
287
            )
288

289
    # Set seed before initializing model.
290
    set_seed(training_args.seed)
291

292
    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
293
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
294
    # (the dataset will be downloaded automatically from the datasets Hub).
295
    #
296
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
297
    # 'text' is found. You can easily tweak this behavior (see below).
298
    #
299
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
300
    # download the dataset.
301
    if data_args.dataset_name is not None:
302
        # Downloading and loading a dataset from the hub.
303
        raw_datasets = load_dataset(
304
            data_args.dataset_name,
305
            data_args.dataset_config_name,
306
            cache_dir=model_args.cache_dir,
307
            token=model_args.token,
308
        )
309
    else:
310
        data_files = {}
311
        if data_args.train_file is not None:
312
            data_files["train"] = data_args.train_file
313
            extension = data_args.train_file.split(".")[-1]
314
        if data_args.validation_file is not None:
315
            data_files["validation"] = data_args.validation_file
316
            extension = data_args.validation_file.split(".")[-1]
317
        if data_args.test_file is not None:
318
            data_files["test"] = data_args.test_file
319
            extension = data_args.test_file.split(".")[-1]
320
        raw_datasets = load_dataset(
321
            extension,
322
            data_files=data_files,
323
            field="data",
324
            cache_dir=model_args.cache_dir,
325
            token=model_args.token,
326
        )
327
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
328
    # https://huggingface.co/docs/datasets/loading_datasets.
329

330
    # Load pretrained model and tokenizer
331
    #
332
    # Distributed training:
333
    # The .from_pretrained methods guarantee that only one local process can concurrently
334
    # download model & vocab.
335
    config = XLNetConfig.from_pretrained(
336
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
337
        cache_dir=model_args.cache_dir,
338
        revision=model_args.model_revision,
339
        token=model_args.token,
340
    )
341
    tokenizer = XLNetTokenizerFast.from_pretrained(
342
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
343
        cache_dir=model_args.cache_dir,
344
        revision=model_args.model_revision,
345
        token=model_args.token,
346
    )
347
    model = XLNetForQuestionAnswering.from_pretrained(
348
        model_args.model_name_or_path,
349
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
350
        config=config,
351
        cache_dir=model_args.cache_dir,
352
        revision=model_args.model_revision,
353
        token=model_args.token,
354
    )
355

356
    # Preprocessing the datasets.
357
    # Preprocessing is slightly different for training and evaluation.
358
    if training_args.do_train:
359
        column_names = raw_datasets["train"].column_names
360
    elif training_args.do_eval:
361
        column_names = raw_datasets["validation"].column_names
362
    else:
363
        column_names = raw_datasets["test"].column_names
364
    question_column_name = "question" if "question" in column_names else column_names[0]
365
    context_column_name = "context" if "context" in column_names else column_names[1]
366
    answer_column_name = "answers" if "answers" in column_names else column_names[2]
367

368
    # Padding side determines if we do (question|context) or (context|question).
369
    pad_on_right = tokenizer.padding_side == "right"
370

371
    if data_args.max_seq_length > tokenizer.model_max_length:
372
        logger.warning(
373
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
374
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
375
        )
376
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
377

378
    # Training preprocessing
379
    def prepare_train_features(examples):
        """Tokenize a batch of training examples into XLNet features with answer labels.

        A long context is split into several overlapping features (using ``doc_stride``). Every feature gets
        ``start_positions`` / ``end_positions`` (token indices of the answer, or the CLS index when the answer
        is missing or not fully inside the feature), ``is_impossible`` (1.0 / 0.0), ``cls_index`` and ``p_mask``
        (0.0 for tokens that may be predicted, 1.0 otherwise).
        """
        # Some of the questions have lots of whitespace on the left, which is not useful and will make the
        # truncation of the context fail (the tokenized question will take a lot of space). So we remove that
        # left whitespace.
        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

        # Tokenize with truncation and padding, but keep the overflows using a stride: one example may give
        # several features when its context is long, each overlapping the previous feature's context a bit.
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            return_special_tokens_mask=True,
            return_token_type_ids=True,
            padding="max_length",
        )

        # Map from each feature to the example it was produced from.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # Token -> character span in the original text; used to locate the answer tokens.
        offset_mapping = tokenized_examples.pop("offset_mapping")
        # Special tokens can never be part of an answer; used to build p_mask.
        special_tokens = tokenized_examples.pop("special_tokens_mask")

        # The context is the second sequence when padding on the right, the first one otherwise.
        context_idx = 1 if pad_on_right else 0

        tokenized_examples["start_positions"] = []
        tokenized_examples["end_positions"] = []
        tokenized_examples["is_impossible"] = []
        tokenized_examples["cls_index"] = []
        tokenized_examples["p_mask"] = []

        for i, offsets in enumerate(offset_mapping):
            # Impossible answers are labeled with the index of the CLS token.
            input_ids = tokenized_examples["input_ids"][i]
            cls_index = input_ids.index(tokenizer.cls_token_id)
            tokenized_examples["cls_index"].append(cls_index)

            # Segment id of every token (tells the question from the context), with special tokens relabeled
            # as 3 so they never match `context_idx`.
            # NOTE(review): this aliases -- it does not copy -- tokenized_examples["token_type_ids"][i], so the
            # relabeling is also written into the token_type_ids fed to the model. The validation features are
            # built the same way; confirm this is intended before changing it.
            sequence_ids = tokenized_examples["token_type_ids"][i]
            for k, s in enumerate(special_tokens[i]):
                if s:
                    sequence_ids[k] = 3

            # p_mask: non-special context tokens get 0.0, and so does the CLS token (so that an empty answer
            # can be predicted); every other token gets 1.0.
            tokenized_examples["p_mask"].append(
                [
                    0.0 if (not special_tokens[i][k] and s == context_idx) or k == cls_index else 1.0
                    for k, s in enumerate(sequence_ids)
                ]
            )

            # One example can give several spans; this is the index of the example containing this span.
            sample_index = sample_mapping[i]
            answers = examples[answer_column_name][sample_index]
            if len(answers["answer_start"]) == 0:
                # No answer given: label the CLS token.
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
                tokenized_examples["is_impossible"].append(1.0)
                continue

            # Start/end character index of the (first) answer in the context.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # First and last token of the context inside this feature.
            context_start = 0
            while sequence_ids[context_start] != context_idx:
                context_start += 1
            context_end = len(input_ids) - 1
            while sequence_ids[context_end] != context_idx:
                context_end -= 1

            # If the answer is not fully inside this span, label the feature with the CLS index.
            if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
                tokenized_examples["is_impossible"].append(1.0)
                continue

            # Otherwise move inwards to the two ends of the answer. Both scans are bounded by the context span:
            # offsets of tokens outside the context (special tokens, question tokens) do not index into the
            # context, so an unbounded scan could run past the context when the answer starts in its last token.
            token_start_index = context_start
            while token_start_index <= context_end and offsets[token_start_index][0] <= start_char:
                token_start_index += 1
            tokenized_examples["start_positions"].append(token_start_index - 1)

            token_end_index = context_end
            while token_end_index >= context_start and offsets[token_end_index][1] >= end_char:
                token_end_index -= 1
            tokenized_examples["end_positions"].append(token_end_index + 1)
            tokenized_examples["is_impossible"].append(0.0)

        return tokenized_examples
478

479
    if training_args.do_train:
480
        if "train" not in raw_datasets:
481
            raise ValueError("--do_train requires a train dataset")
482
        train_dataset = raw_datasets["train"]
483
        if data_args.max_train_samples is not None:
484
            # Select samples from Dataset, This will help to decrease processing time
485
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
486
            train_dataset = train_dataset.select(range(max_train_samples))
487
        # Create Training Features
488
        with training_args.main_process_first(desc="train dataset map pre-processing"):
489
            train_dataset = train_dataset.map(
490
                prepare_train_features,
491
                batched=True,
492
                num_proc=data_args.preprocessing_num_workers,
493
                remove_columns=column_names,
494
                load_from_cache_file=not data_args.overwrite_cache,
495
                desc="Running tokenizer on train dataset",
496
            )
497
        if data_args.max_train_samples is not None:
498
            # Select samples from dataset again since Feature Creation might increase number of features
499
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
500
            train_dataset = train_dataset.select(range(max_train_samples))
501

502
    # Validation preprocessing
503
    def prepare_validation_features(examples):
        """Tokenize a batch of evaluation/prediction examples into XLNet features.

        Every feature keeps the ``example_id`` it came from, its ``cls_index`` and ``p_mask``, and an
        ``offset_mapping`` in which each token outside the context is ``None``. Post-processing can then map
        predictions back to substrings of the context. No answer labels are produced.
        """
        # Strip leading whitespace from the questions exactly as prepare_train_features does, so evaluation
        # features are built from the same text as training features (long left whitespace would otherwise take
        # a lot of space in the tokenized question and hurt truncation of the context).
        examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]]

        # Tokenize with truncation and padding, but keep the overflows using a stride: one example may give
        # several features when its context is long, each overlapping the previous feature's context a bit.
        tokenized_examples = tokenizer(
            examples[question_column_name if pad_on_right else context_column_name],
            examples[context_column_name if pad_on_right else question_column_name],
            truncation="only_second" if pad_on_right else "only_first",
            max_length=max_seq_length,
            stride=data_args.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            return_special_tokens_mask=True,
            return_token_type_ids=True,
            padding="max_length",
        )

        # Map from each feature to the example it was produced from.
        sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
        # Special tokens can never be part of an answer; used to build p_mask.
        special_tokens = tokenized_examples.pop("special_tokens_mask")

        # The context is the second sequence when padding on the right, the first one otherwise.
        context_idx = 1 if pad_on_right else 0

        # example_id lets post-processing map features back to examples; cls_index and p_mask are still given
        # to the model, but (unlike training) there is no is_impossible label.
        tokenized_examples["example_id"] = []
        tokenized_examples["cls_index"] = []
        tokenized_examples["p_mask"] = []

        for i, input_ids in enumerate(tokenized_examples["input_ids"]):
            # Find the CLS token in the input ids.
            cls_index = input_ids.index(tokenizer.cls_token_id)
            tokenized_examples["cls_index"].append(cls_index)

            # Segment id of every token (tells the question from the context), with special tokens relabeled
            # as 3 so they never match `context_idx`.
            # NOTE(review): this aliases -- it does not copy -- tokenized_examples["token_type_ids"][i], so the
            # relabeling is also written into the token_type_ids fed to the model. The training features are
            # built the same way; confirm this is intended before changing it.
            sequence_ids = tokenized_examples["token_type_ids"][i]
            for k, s in enumerate(special_tokens[i]):
                if s:
                    sequence_ids[k] = 3

            # p_mask: non-special context tokens and the CLS token get 0.0, every other token gets 1.0.
            tokenized_examples["p_mask"].append(
                [
                    0.0 if (not special_tokens[i][k] and s == context_idx) or k == cls_index else 1.0
                    for k, s in enumerate(sequence_ids)
                ]
            )

            # One example can give several spans; this is the index of the example containing this span.
            sample_index = sample_mapping[i]
            tokenized_examples["example_id"].append(examples["id"][sample_index])

            # Set to None the offsets of tokens outside the context, so it is easy to tell whether a token
            # position is part of the context or not.
            tokenized_examples["offset_mapping"][i] = [
                (o if sequence_ids[k] == context_idx else None)
                for k, o in enumerate(tokenized_examples["offset_mapping"][i])
            ]

        return tokenized_examples
567

568
    if training_args.do_eval:
569
        if "validation" not in raw_datasets:
570
            raise ValueError("--do_eval requires a validation dataset")
571
        eval_examples = raw_datasets["validation"]
572
        if data_args.max_eval_samples is not None:
573
            # Selecting Eval Samples from Dataset
574
            max_eval_samples = min(len(eval_examples), data_args.max_eval_samples)
575
            eval_examples = eval_examples.select(range(max_eval_samples))
576
        # Create Features from Eval Dataset
577
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
578
            eval_dataset = eval_examples.map(
579
                prepare_validation_features,
580
                batched=True,
581
                num_proc=data_args.preprocessing_num_workers,
582
                remove_columns=column_names,
583
                load_from_cache_file=not data_args.overwrite_cache,
584
                desc="Running tokenizer on validation dataset",
585
            )
586
        if data_args.max_eval_samples is not None:
587
            # Selecting Samples from Dataset again since Feature Creation might increase samples size
588
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
589
            eval_dataset = eval_dataset.select(range(max_eval_samples))
590

591
    if training_args.do_predict:
592
        if "test" not in raw_datasets:
593
            raise ValueError("--do_predict requires a test dataset")
594
        predict_examples = raw_datasets["test"]
595
        if data_args.max_predict_samples is not None:
596
            # We will select sample from whole data
597
            predict_examples = predict_examples.select(range(data_args.max_predict_samples))
598
        # Test Feature Creation
599
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
600
            predict_dataset = predict_examples.map(
601
                prepare_validation_features,
602
                batched=True,
603
                num_proc=data_args.preprocessing_num_workers,
604
                remove_columns=column_names,
605
                load_from_cache_file=not data_args.overwrite_cache,
606
                desc="Running tokenizer on prediction dataset",
607
            )
608
        if data_args.max_predict_samples is not None:
609
            # During Feature creation dataset samples might increase, we will select required samples again
610
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
611
            predict_dataset = predict_dataset.select(range(max_predict_samples))
612

613
    # Data collator
614
    # We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
615
    # collator.
616
    data_collator = (
617
        default_data_collator
618
        if data_args.pad_to_max_length
619
        else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
620
    )
621

622
    # Post-processing:
623
    def post_processing_function(examples, features, predictions, stage="eval"):
        """Turn raw beam-search outputs into text answers paired with references for the metric.

        Args:
            examples: The un-tokenized examples; each needs an ``id`` and the answer column.
            features: The tokenized features built from ``examples``.
            predictions: Raw model outputs, forwarded to ``postprocess_qa_predictions_with_beam_search``.
            stage: Passed as ``prefix`` to the post-processing helper (together with ``output_dir``).

        Returns:
            An ``EvalPrediction`` whose ``predictions`` and ``label_ids`` are lists of dicts in the format the
            ``squad`` / ``squad_v2`` metric expects.
        """
        # Match the start and end logits to answers in the original context.
        text_predictions, scores_diff_json = postprocess_qa_predictions_with_beam_search(
            examples=examples,
            features=features,
            predictions=predictions,
            version_2_with_negative=data_args.version_2_with_negative,
            n_best_size=data_args.n_best_size,
            max_answer_length=data_args.max_answer_length,
            start_n_top=model.config.start_n_top,
            end_n_top=model.config.end_n_top,
            output_dir=training_args.output_dir,
            log_level=log_level,
            prefix=stage,
        )
        # Format the result the way the metric expects; squad_v2 also needs a no-answer probability.
        if data_args.version_2_with_negative:
            formatted_predictions = [
                {"id": k, "prediction_text": v, "no_answer_probability": scores_diff_json[k]}
                for k, v in text_predictions.items()
            ]
        else:
            formatted_predictions = [{"id": k, "prediction_text": v} for k, v in text_predictions.items()]

        references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
        return EvalPrediction(predictions=formatted_predictions, label_ids=references)
649

650
    metric = evaluate.load(
651
        "squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
652
    )
653

654
    def compute_metrics(p: EvalPrediction):
655
        return metric.compute(predictions=p.predictions, references=p.label_ids)
656

657
    # Initialize our Trainer
658
    trainer = QuestionAnsweringTrainer(
659
        model=model,
660
        args=training_args,
661
        train_dataset=train_dataset if training_args.do_train else None,
662
        eval_dataset=eval_dataset if training_args.do_eval else None,
663
        eval_examples=eval_examples if training_args.do_eval else None,
664
        tokenizer=tokenizer,
665
        data_collator=data_collator,
666
        post_process_function=post_processing_function,
667
        compute_metrics=compute_metrics,
668
    )
669

670
    # Training
671
    if training_args.do_train:
672
        checkpoint = None
673
        if training_args.resume_from_checkpoint is not None:
674
            checkpoint = training_args.resume_from_checkpoint
675
        elif last_checkpoint is not None:
676
            checkpoint = last_checkpoint
677
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
678
        trainer.save_model()  # Saves the tokenizer too for easy upload
679

680
        metrics = train_result.metrics
681

682
        max_train_samples = (
683
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
684
        )
685
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
686

687
        trainer.log_metrics("train", metrics)
688
        trainer.save_metrics("train", metrics)
689
        trainer.save_state()
690

691
    # Evaluation
692
    if training_args.do_eval:
693
        logger.info("*** Evaluate ***")
694
        metrics = trainer.evaluate()
695

696
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
697
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
698

699
        trainer.log_metrics("eval", metrics)
700
        trainer.save_metrics("eval", metrics)
701

702
    # Prediction
703
    if training_args.do_predict:
704
        logger.info("*** Predict ***")
705
        results = trainer.predict(predict_dataset, predict_examples)
706
        metrics = results.metrics
707

708
        max_predict_samples = (
709
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
710
        )
711
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
712

713
        trainer.log_metrics("predict", metrics)
714
        trainer.save_metrics("predict", metrics)
715

716
    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "question-answering"}
717
    if data_args.dataset_name is not None:
718
        kwargs["dataset_tags"] = data_args.dataset_name
719
        if data_args.dataset_config_name is not None:
720
            kwargs["dataset_args"] = data_args.dataset_config_name
721
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
722
        else:
723
            kwargs["dataset"] = data_args.dataset_name
724

725
    if training_args.push_to_hub:
726
        trainer.push_to_hub(**kwargs)
727
    else:
728
        trainer.create_model_card(**kwargs)
729

730

731
def _mp_fn(index):
    """Entry point used when this script is launched via ``xla_spawn`` (TPUs).

    Every spawned process simply runs ``main()``.

    Args:
        index: Per-process argument supplied by the launcher (presumably the
            process ordinal — confirm against the xla_spawn launcher). It is
            accepted only to satisfy the launcher's calling convention and is
            not used here.
    """
    # For xla_spawn (TPUs)
    main()
734

735

736
if __name__ == "__main__":
    # Run the pipeline only when this file is executed directly; importing
    # the module (e.g. to reach _mp_fn) must not trigger a run.
    main()
738

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.