import logging
import os
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union

import torch
from filelock import FileLock
from torch.utils.data.dataset import Dataset

from ...modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from ...tokenization_utils import PreTrainedTokenizer
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features


logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
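# ^ MODEL_TYPES is derived from the question-answering auto-mapping, so its
#   exact contents (e.g. "bert", "xlnet", "xlm", "distilbert", ...) depend on
#   the transformers version this fork was taken from.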


@dataclass
class SquadDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    model_type: str = field(
        default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
    )
    data_dir: str = field(
        default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    doc_stride: int = field(
        default=128,
        metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
    )
    max_query_length: int = field(
        default=64,
        metadata={
            "help": "The maximum number of tokens for the question. Questions longer than this will "
            "be truncated to this length."
        },
    )
    max_answer_length: int = field(
        default=30,
        metadata={
            "help": "The maximum length of an answer that can be generated. This is needed because the start "
            "and end predictions are not conditioned on one another."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    version_2_with_negative: bool = field(
        default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
    )
    null_score_diff_threshold: float = field(
        default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
    )
    n_best_size: int = field(
        default=20,
        metadata={
            "help": "The total number of n-best predictions to generate in the "
            "nbest_predictions.json output file."
        },
    )
    lang_id: int = field(
        default=0,
        metadata={
            "help": "language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
        },
    )
    threads: int = field(default=1, metadata={"help": "multiple threads for converting example to features"})
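
# Usage sketch (illustrative, not part of the original module): the fields above
# can be set directly, or parsed from the command line with
# transformers.HfArgumentParser, e.g.:
#
#     data_args = SquadDataTrainingArguments(
#         model_type="bert",
#         data_dir="/path/to/squad",
#         version_2_with_negative=True,
#     )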


class Split(Enum):
    train = "train"
    dev = "dev"


class SquadDataset(Dataset):
    """
    This will be superseded by a framework-agnostic approach
    soon.
    """

    args: SquadDataTrainingArguments
    features: List[SquadFeatures]
    mode: Split
    is_language_sensitive: bool

    def __init__(
        self,
        args: SquadDataTrainingArguments,
        tokenizer: PreTrainedTokenizer,
        limit_length: Optional[int] = None,
        mode: Union[str, Split] = Split.train,
        is_language_sensitive: Optional[bool] = False,
        cache_dir: Optional[str] = None,
    ):
        self.args = args
        self.is_language_sensitive = is_language_sensitive
        self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
        if isinstance(mode, str):
            try:
                mode = Split[mode]
            except KeyError:
                raise KeyError("mode is not a valid split name")
        self.mode = mode
        # Load data features from cache or dataset file
        version_tag = "v2" if args.version_2_with_negative else "v1"
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else args.data_dir,
            "cached_{}_{}_{}_{}".format(
                mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), version_tag,
            ),
        )
        )
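        # Note: the cache key above includes the split, tokenizer class, max_seq_length
        # and SQuAD version, but not doc_stride or max_query_length -- changing those
        # alone will not invalidate an existing cache (pass overwrite_cache instead).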

        # Make sure only the first process in distributed training processes the dataset,
        # and the others will use the cache.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):
            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                start = time.time()
                self.features = torch.load(cached_features_file)
                logger.info(
                    "Loading features from cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )
            else:
                if mode == Split.dev:
                    examples = self.processor.get_dev_examples(args.data_dir)
                else:
                    examples = self.processor.get_train_examples(args.data_dir)

                self.features = squad_convert_examples_to_features(
                    examples=examples,
                    tokenizer=tokenizer,
                    max_seq_length=args.max_seq_length,
                    doc_stride=args.doc_stride,
                    max_query_length=args.max_query_length,
                    is_training=mode == Split.train,
                    threads=args.threads,
                )

                start = time.time()
                torch.save(self.features, cached_features_file)
                # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
                logger.info(
                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
                )

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # Convert to Tensors and build dataset
        feature = self.features[i]

        input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
        attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
        token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
        cls_index = torch.tensor(feature.cls_index, dtype=torch.long)
        p_mask = torch.tensor(feature.p_mask, dtype=torch.float)
        is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float)

        inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

        # XLM, RoBERTa, DistilBERT and CamemBERT don't use segment (token type) ids
        if self.args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
            del inputs["token_type_ids"]

        # XLNet and XLM take extra inputs for their question-answering heads
        if self.args.model_type in ["xlnet", "xlm"]:
            inputs.update({"cls_index": cls_index, "p_mask": p_mask})
            if self.args.version_2_with_negative:
                inputs.update({"is_impossible": is_impossible})
            if self.is_language_sensitive:
                inputs.update({"langs": (torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id)})

        if self.mode == Split.train:
            start_positions = torch.tensor(feature.start_position, dtype=torch.long)
            end_positions = torch.tensor(feature.end_position, dtype=torch.long)
            inputs.update({"start_positions": start_positions, "end_positions": end_positions})

        return inputs
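
# Usage sketch (illustrative, assuming a local SQuAD-formatted data_dir and a
# compatible tokenizer; the names below are examples, not part of this module):
#
#     from torch.utils.data import DataLoader
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     data_args = SquadDataTrainingArguments(model_type="bert", data_dir="/path/to/squad")
#     train_dataset = SquadDataset(data_args, tokenizer, mode=Split.train)
#     batch = next(iter(DataLoader(train_dataset, batch_size=8)))
#     # batch is a dict of tensors: input_ids, attention_mask, token_type_ids,
#     # start_positions and end_positions.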