#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for permutation language modeling.
"""
# You can also adapt this script on your own permutation language modeling task. Pointers for this are left as comments.
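# Example invocation (illustrative only; the model and dataset names below are common choices and an
# assumption of this comment, not something the script mandates):
#
#   python run_plm.py \
#       --model_name_or_path xlnet/xlnet-base-cased \
#       --dataset_name wikitext \
#       --dataset_config_name wikitext-2-raw-v1 \
#       --do_train \
#       --do_eval \
#       --output_dir /tmp/test-plm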
import logging
import math
import os
import sys
import warnings
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

import datasets
from datasets import load_dataset

import transformers
from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorForPermutationLanguageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    XLNetConfig,
    XLNetLMHeadModel,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.39.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

logger = logging.getLogger(__name__)

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "Setting it to True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    max_seq_length: int = field(
        default=512,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    plm_probability: float = field(
        default=1 / 6,
        metadata={
            "help": (
                "Ratio of length of a span of masked tokens to surrounding context length for "
                "permutation language modeling."
            )
        },
    )
    max_span_length: int = field(
        default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )

    def __post_init__(self):
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."

def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

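    # Illustrative note (an assumption of this comment, not part of the upstream script): a JSON file passed
    # this way holds the same fields as the CLI flags, e.g.
    # {"model_name_or_path": "xlnet/xlnet-base-cased", "output_dir": "/tmp/test-plm", "do_train": true}.
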
    if model_args.use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_plm", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a small summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overwrite it."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = XLNetConfig()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = XLNetLMHeadModel.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        )
    else:
        logger.info("Training new model from scratch")
        model = XLNetLMHeadModel(config)

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    else:
        column_names = raw_datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    if data_args.line_by_line:
        # When using line_by_line, we just tokenize each nonempty line.
        padding = "max_length" if data_args.pad_to_max_length else False

        def tokenize_function(examples):
            # Remove empty lines
            examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
            return tokenizer(examples["text"], padding=padding, truncation=True, max_length=max_seq_length)

        with training_args.main_process_first(desc="dataset map tokenization"):
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=[text_column_name],
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset line_by_line",
            )
    else:
        # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
        def tokenize_function(examples):
            return tokenizer(examples[text_column_name])

        with training_args.main_process_first(desc="dataset map tokenization"):
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on every text in dataset",
            )

        # Main data processing function that will concatenate all texts from our dataset and generate chunks of
        # max_seq_length.
        def group_texts(examples):
            # Concatenate all texts.
            concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            # We drop the small remainder; if total_length < max_seq_length, we exclude this batch and return an empty dict.
            # We could add padding instead of this drop if the model supported it; you can customize this part to your needs.
            total_length = (total_length // max_seq_length) * max_seq_length
            # Split by chunks of max_seq_length.
            result = {
                k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
                for k, t in concatenated_examples.items()
            }
            return result

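        # Illustrative arithmetic (numbers are assumptions, not taken from the script): with max_seq_length=512
        # and a mapped batch whose concatenated columns hold 1250 tokens, total_length is truncated to
        # (1250 // 512) * 512 = 1024, so two chunks of 512 tokens are produced and the last 226 tokens are dropped.
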
        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
        # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
        # might be slower to preprocess.
        #
        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
        # https://huggingface.co/docs/datasets/process#map

        with training_args.main_process_first(desc="grouping texts together"):
            tokenized_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
                desc=f"Grouping texts in chunks of {max_seq_length}",
            )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = tokenized_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = tokenized_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    # Data collator
    data_collator = DataCollatorForPermutationLanguageModeling(
        tokenizer=tokenizer,
        plm_probability=data_args.plm_probability,
        max_span_length=data_args.max_span_length,
    )

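    # Rough intuition (paraphrasing the collator's documented behavior; the numbers are this comment's
    # assumptions): each masked span of length L (sampled up to max_span_length) sits inside a context of
    # about L / plm_probability tokens, so the defaults (plm_probability=1/6, max_span_length=5) mask roughly
    # one sixth of each sequence.
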
    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload
        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

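        # Illustrative numbers (assumed, not produced by this script): an eval_loss of 3.0 corresponds to a
        # perplexity of exp(3.0) ≈ 20.1, so a lower eval_loss directly means a lower (better) perplexity.
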
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "language-modeling"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()