#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for permutation language modeling.
"""
# You can also adapt this script to your own permutation language modeling task. Pointers for this are left as comments.

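# Example invocation (illustrative only; the model and dataset names below are
# common choices from the transformers examples, not requirements of this script):
#
#   python run_plm.py \
#       --model_name_or_path xlnet/xlnet-base-cased \
#       --dataset_name wikitext \
#       --dataset_config_name wikitext-2-raw-v1 \
#       --do_train --do_eval \
#       --output_dir /tmp/test-plm
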
import logging
import math
import os
import sys
import warnings
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

import datasets
from datasets import load_dataset

import transformers
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorForPermutationLanguageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    XLNetConfig,
    XLNetLMHeadModel,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.39.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    use_auth_token: Optional[bool] = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to create the model as an empty shell and only materialize its parameters when the pretrained "
                "weights are loaded. Setting this to True benefits loading time and RAM consumption for large models."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    max_seq_length: int = field(
        default=512,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    plm_probability: float = field(
        default=1 / 6,
        metadata={
            "help": (
                "Ratio of length of a span of masked tokens to surrounding context length for "
                "permutation language modeling."
            )
        },
    )
    max_span_length: int = field(
        default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
    )
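    # Note: DataCollatorForPermutationLanguageModeling samples each masked span's
    # length from [1, max_span_length] and surrounds it with a context of
    # span_length / plm_probability tokens, so with the defaults (1/6 and 5)
    # roughly one token in six ends up masked.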
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if model_args.use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_plm", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a small summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overwrite it."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        raw_datasets = load_dataset(
            extension, data_files=data_files, cache_dir=model_args.cache_dir, token=model_args.token
        )
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = XLNetConfig()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = XLNetLMHeadModel.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        )
    else:
        logger.info("Training new model from scratch")
        model = XLNetLMHeadModel(config)

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))
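        # (len(tokenizer) can exceed the checkpoint's embedding size when tokens
        # were added to the tokenizer, e.g. extra special tokens.)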

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    else:
        column_names = raw_datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    if data_args.line_by_line:
        # When using line_by_line, we just tokenize each nonempty line.
        padding = "max_length" if data_args.pad_to_max_length else False

        def tokenize_function(examples):
            # Remove empty lines
            examples[text_column_name] = [
                line for line in examples[text_column_name] if len(line) > 0 and not line.isspace()
            ]
            return tokenizer(examples[text_column_name], padding=padding, truncation=True, max_length=max_seq_length)

        with training_args.main_process_first(desc="dataset map tokenization"):
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=[text_column_name],
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset line_by_line",
            )
    else:
        # Otherwise, we tokenize every text, then concatenate them together before splitting them into smaller parts.
        def tokenize_function(examples):
            return tokenizer(examples[text_column_name])

        with training_args.main_process_first(desc="dataset map tokenization"):
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on every text in dataset",
            )

        # Main data processing function that will concatenate all texts from our dataset and generate chunks of
        # max_seq_length.
        def group_texts(examples):
            # Concatenate all texts.
            concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            # We drop the small remainder; if total_length < max_seq_length, this batch is excluded and an empty
            # dict is returned. We could add padding instead of this drop if the model supported it; you can
            # customize this part to your needs.
            total_length = (total_length // max_seq_length) * max_seq_length
            # Split by chunks of max_len.
            result = {
                k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
                for k, t in concatenated_examples.items()
            }
            return result

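        # For example, with max_seq_length=512, a batch whose texts concatenate to
        # 1,300 tokens yields two chunks of 512 tokens; the trailing 276 are dropped.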
        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
        # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
        # might be slower to preprocess.
        #
        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
        # https://huggingface.co/docs/datasets/process#map

        with training_args.main_process_first(desc="grouping texts together"):
            tokenized_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
                desc=f"Grouping texts in chunks of {max_seq_length}",
            )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = tokenized_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = tokenized_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Data collator
    data_collator = DataCollatorForPermutationLanguageModeling(
        tokenizer=tokenizer,
        plm_probability=data_args.plm_probability,
        max_span_length=data_args.max_span_length,
    )
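    # The collator builds the XLNet-specific batch inputs: alongside `input_ids` and
    # `labels`, it produces a `perm_mask` encoding the factorization order and a
    # `target_mapping` for the tokens to be predicted.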

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload
        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
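        # Perplexity is the exponential of the mean eval cross-entropy loss; math.exp
        # overflows for very large losses, in which case we report infinity.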
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "language-modeling"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
