import logging
import os
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Union

import torch
from filelock import FileLock
from torch.utils.data.dataset import Dataset

from ...tokenization_bart import BartTokenizer, BartTokenizerFast
from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_xlm_roberta import XLMRobertaTokenizer
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
from ..processors.utils import InputFeatures


logger = logging.getLogger(__name__)


@dataclass
class GlueDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
    data_dir: str = field(
        metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )

    def __post_init__(self):
        self.task_name = self.task_name.lower()
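

# A minimal usage sketch (not part of the original file): as the docstring above notes,
# `HfArgumentParser` can expose the fields of this dataclass as command-line flags. The
# import path is shown as in stock transformers and may differ inside this fork; the
# task name and data dir are illustrative:
#
#     from transformers import HfArgumentParser
#
#     parser = HfArgumentParser(GlueDataTrainingArguments)
#     (data_args,) = parser.parse_args_into_dataclasses(
#         ["--task_name", "mrpc", "--data_dir", "./glue_data/MRPC"]
#     )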


class Split(Enum):
    train = "train"
    dev = "dev"
    test = "test"


class GlueDataset(Dataset):
    """
    This will be superseded by a framework-agnostic approach
    soon.
    """

    args: GlueDataTrainingArguments
    output_mode: str
    features: List[InputFeatures]

    def __init__(
        self,
        args: GlueDataTrainingArguments,
        tokenizer: PreTrainedTokenizer,
        limit_length: Optional[int] = None,
        mode: Union[str, Split] = Split.train,
        cache_dir: Optional[str] = None,
    ):
        self.args = args
        self.processor = glue_processors[args.task_name]()
        self.output_mode = glue_output_modes[args.task_name]
        if isinstance(mode, str):
            try:
                mode = Split[mode]
            except KeyError:
                raise KeyError("mode is not a valid split name")
        # Load data features from cache or dataset file
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else args.data_dir,
            "cached_{}_{}_{}_{}".format(
                mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name,
            ),
        )
        label_list = self.processor.get_labels()
        if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
            RobertaTokenizer,
            RobertaTokenizerFast,
            XLMRobertaTokenizer,
            BartTokenizer,
            BartTokenizerFast,
        ):
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        self.label_list = label_list

        # Make sure only the first process in distributed training processes the dataset,
        # and the others will use the cache.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):

            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                start = time.time()
                self.features = torch.load(cached_features_file)
                logger.info(
                    "Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )
            else:
                logger.info("Creating features from dataset file at %s", args.data_dir)

                if mode == Split.dev:
                    examples = self.processor.get_dev_examples(args.data_dir)
                elif mode == Split.test:
                    examples = self.processor.get_test_examples(args.data_dir)
                else:
                    examples = self.processor.get_train_examples(args.data_dir)
                if limit_length is not None:
                    examples = examples[:limit_length]
                self.features = glue_convert_examples_to_features(
                    examples,
                    tokenizer,
                    max_length=args.max_seq_length,
                    label_list=label_list,
                    output_mode=self.output_mode,
                )
                start = time.time()
                torch.save(self.features, cached_features_file)
                # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
                logger.info(
                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i) -> InputFeatures:
        return self.features[i]

    def get_labels(self):
        return self.label_list
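

# A minimal end-to-end sketch (not part of the original file): building the dataset and
# batching it for training. It assumes a local MRPC data dir, the standard transformers
# tokenizer API, and `default_data_collator`, the library helper that batches
# `InputFeatures`; the checkpoint, paths, and batch size are illustrative:
#
#     from torch.utils.data import DataLoader
#     from transformers import AutoTokenizer, default_data_collator
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     args = GlueDataTrainingArguments(task_name="mrpc", data_dir="./glue_data/MRPC")
#     train_dataset = GlueDataset(args, tokenizer=tokenizer, mode="train")
#     loader = DataLoader(train_dataset, batch_size=32, collate_fn=default_data_collator)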