CSS-LM
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Tuple, Union

import torch
from torch.nn.utils.rnn import pad_sequence

from ..tokenization_utils import PreTrainedTokenizer
from ..tokenization_utils_base import BatchEncoding


InputDataClass = NewType("InputDataClass", Any)

"""
A DataCollator is a function that takes a list of samples from a Dataset and
collates them into a batch, as a dictionary of Tensors.
"""
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])


def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
    """
    Very simple data collator that:
    - simply collates batches of dict-like objects
    - performs special handling for potential keys named:
        - ``label``: handles a single value (int or float) per object
        - ``label_ids``: handles a list of values per object
    - does not do any additional preprocessing

    i.e., property names of the input object will be used as corresponding inputs to the model.
    See glue and ner for examples of how it's useful.
    """

    # In this function we'll make the assumption that all `features` in the batch
    # have the same attributes.
    # So we will look at the first element as a proxy for what attributes exist
    # on the whole batch.
    if not isinstance(features[0], (dict, BatchEncoding)):
        features = [vars(f) for f in features]

    first = features[0]
    batch = {}

    # Special handling for labels.
    # Ensure that tensor is created with the correct type
    # (it should be automatically the case, but let's make sure of it.)
    if "label" in first and first["label"] is not None:
        label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
        dtype = torch.long if isinstance(label, int) else torch.float
        batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
    elif "label_ids" in first and first["label_ids"] is not None:
        if isinstance(first["label_ids"], torch.Tensor):
            batch["labels"] = torch.stack([f["label_ids"] for f in features])
        else:
            dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
            batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)

    # Handling of all other possible keys.
    # Again, we will use the first element to figure out which key/values are not None for this model.
    for k, v in first.items():
        if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features])
            else:
                batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)

    return batch
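
# A minimal usage sketch (illustrative; not part of the original module):
#
#     features = [
#         {"input_ids": torch.tensor([0, 5, 7, 2]), "attention_mask": torch.tensor([1, 1, 1, 1]), "label": 1},
#         {"input_ids": torch.tensor([0, 9, 4, 2]), "attention_mask": torch.tensor([1, 1, 1, 1]), "label": 0},
#     ]
#     batch = default_data_collator(features)
#     # batch["input_ids"].shape == (2, 4); batch["labels"] equals tensor([1, 0]) with dtype torch.long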


@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling.
    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for masked language modeling
    """

    tokenizer: PreTrainedTokenizer
    mlm: bool = True
    mlm_probability: float = 0.15

    def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        batch = self._tensorize_batch(examples)
        if self.mlm:
            inputs, labels = self.mask_tokens(batch)
            return {"input_ids": inputs, "labels": labels}
        else:
            labels = batch.clone().detach()
            labels[labels == self.tokenizer.pad_token_id] = -100
            return {"input_ids": batch, "labels": labels}

    def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
        length_of_first = examples[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length:
            return torch.stack(examples, dim=0)
        else:
            if self.tokenizer._pad_token is None:
                raise ValueError(
                    "You are attempting to pad samples but the tokenizer you are using"
                    f" ({self.tokenizer.__class__.__name__}) does not have a padding token."
                )
            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)

    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
            )

        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training
        # (with probability ``mlm_probability``, which defaults to 0.15 as in BERT/RoBERTa).
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens
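        # (-100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``, so these positions
        # contribute nothing to the loss.)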

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with a random word
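        # (Bernoulli(0.5) over the ~20% of masked positions not already replaced, i.e. roughly
        # 10% of masked positions overall; the rest stay unchanged.)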
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
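
# A minimal usage sketch (illustrative; not part of the original module). It assumes a tokenizer
# with both a mask token and a pad token, e.g. a BERT tokenizer:
#
#     collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
#     examples = [torch.tensor([101, 2023, 2003, 102]), torch.tensor([101, 7592, 102])]
#     batch = collator(examples)  # pads to the longest example, then applies 80/10/10 masking
#     # batch["input_ids"] and batch["labels"] share the same shape; unmasked label positions are -100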


@dataclass
class DataCollatorForPermutationLanguageModeling:
    """
    Data collator used for permutation language modeling.
    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for permutation language modeling with procedures specific to XLNet
    """

    tokenizer: PreTrainedTokenizer
    plm_probability: float = 1 / 6
    max_span_length: int = 5  # maximum length of a span of masked tokens

    def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        batch = self._tensorize_batch(examples)
        inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}

    def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
        length_of_first = examples[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length:
            return torch.stack(examples, dim=0)
        else:
            if self.tokenizer._pad_token is None:
                raise ValueError(
                    "You are attempting to pad samples but the tokenizer you are using"
                    f" ({self.tokenizer.__class__.__name__}) does not have a padding token."
                )
            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)

    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:

        0. Start from the beginning of the sequence by setting ``cur_len = 0`` (number of tokens processed so far).
        1. Sample a ``span_length`` from the interval ``[1, max_span_length]`` (length of the span of tokens to be masked).
        2. Reserve a context of length ``context_length = span_length / plm_probability`` to surround the span to be masked.
        3. Sample a starting point ``start_index`` from the interval ``[cur_len, cur_len + context_length - span_length]`` and mask tokens ``start_index:start_index + span_length``.
        4. Set ``cur_len = cur_len + context_length``. If ``cur_len < max_len`` (i.e. there are tokens remaining in the sequence to be processed), repeat from Step 1.
        """

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for permutation language modeling. Please add a mask token if you want to use this tokenizer."
            )

        if inputs.size(1) % 2 != 0:
            raise ValueError(
                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see relevant comments in source code for details."
            )

        labels = inputs.clone()
        # Creating the mask and target_mapping tensors
        masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
        target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)

        for i in range(labels.size(0)):
            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
            cur_len = 0
            max_len = labels.size(1)

            while cur_len < max_len:
                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
                span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
                context_length = int(span_length / self.plm_probability)
                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
                start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
                masked_indices[i, start_index : start_index + span_length] = 1
                # Set `cur_len = cur_len + context_length`
                cur_len += context_length

            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
            # the i-th prediction corresponds to the i-th token.
            target_mapping[i] = torch.eye(labels.size(1))

        special_tokens_mask = torch.tensor(
            [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
            dtype=torch.bool,
        )
        masked_indices.masked_fill_(special_tokens_mask, value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            masked_indices.masked_fill_(padding_mask, value=0.0)

        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
        non_func_mask = ~(padding_mask & special_tokens_mask)

        inputs[masked_indices] = self.tokenizer.mask_token_id
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)

        for i in range(labels.size(0)):
            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
            # determine which tokens a given token can attend to (encoded in `perm_mask`).
            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
            # This requires that the sequence length be even.

            # Create a linear factorisation order
            perm_index = torch.arange(labels.size(1))
            # Split this into two halves, assuming that half the sequence is reused each time
            perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
            # Permute the two halves such that they do not cross over
            perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
            # Flatten this out into the desired permuted factorisation order
            perm_index = torch.flatten(perm_index.transpose(0, 1))
            # Set the permutation indices of non-masked (non-functional) tokens to the
            # smallest index (-1) so that:
            # (1) They can be seen by all other positions
            # (2) They cannot see masked positions, so there won't be information leak
            perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
            # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
            perm_mask[i] = (
                perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
            ) & masked_indices[i]

        return inputs, perm_mask, target_mapping, labels
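
# A minimal usage sketch (illustrative; not part of the original module). It assumes an XLNet-style
# tokenizer with mask and pad tokens, and even-length sequences as required above:
#
#     collator = DataCollatorForPermutationLanguageModeling(tokenizer=tokenizer, plm_probability=1 / 6)
#     examples = [torch.tensor([35, 71, 4, 9, 22, 3]), torch.tensor([35, 18, 44, 60, 12, 3])]
#     batch = collator(examples)
#     # batch has keys "input_ids", "perm_mask", "target_mapping", "labels";
#     # perm_mask and target_mapping each have shape (batch_size, seq_len, seq_len)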