CSS-LM

modeling_distilbert.py 
924 lines · 37.5 KB
1
# coding=utf-8
2
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
""" PyTorch DistilBERT model
16
    adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
17
    and in part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
18
"""
19

20

21
import copy
22
import logging
23
import math
24
import warnings
25

26
import numpy as np
27
import torch
28
import torch.nn as nn
29
from torch.nn import CrossEntropyLoss
30

31
from .activations import gelu
32
from .configuration_distilbert import DistilBertConfig
33
from .file_utils import (
34
    add_code_sample_docstrings,
35
    add_start_docstrings,
36
    add_start_docstrings_to_callable,
37
    replace_return_docstrings,
38
)
39
from .modeling_outputs import (
40
    BaseModelOutput,
41
    MaskedLMOutput,
42
    MultipleChoiceModelOutput,
43
    QuestionAnsweringModelOutput,
44
    SequenceClassifierOutput,
45
    TokenClassifierOutput,
46
)
47
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
48

49

50
logger = logging.getLogger(__name__)
51

52
_CONFIG_FOR_DOC = "DistilBertConfig"
53
_TOKENIZER_FOR_DOC = "DistilBertTokenizer"
54

55
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
56
    "distilbert-base-uncased",
57
    "distilbert-base-uncased-distilled-squad",
58
    "distilbert-base-cased",
59
    "distilbert-base-cased-distilled-squad",
60
    "distilbert-base-german-cased",
61
    "distilbert-base-multilingual-cased",
62
    "distilbert-base-uncased-finetuned-sst-2-english",
63
    # See all DistilBERT models at https://huggingface.co/models?filter=distilbert
64
]
65

66

67
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
68

69

70
def create_sinusoidal_embeddings(n_pos, dim, out):
71
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
72
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
73
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
74
    out.detach_()
75
    out.requires_grad = False
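
# Usage sketch (illustration only, not part of the original file): the helper fills an
# existing tensor in place with fixed sinusoidal position encodings and freezes it.
#
#     >>> weight = torch.empty(512, 768)
#     >>> create_sinusoidal_embeddings(n_pos=512, dim=768, out=weight)
#     >>> weight.shape, weight.requires_grad
#     (torch.Size([512, 768]), False)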
76

77

78
class Embeddings(nn.Module):
79
    def __init__(self, config):
80
        super().__init__()
81
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
82
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
83
        if config.sinusoidal_pos_embds:
84
            create_sinusoidal_embeddings(
85
                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
86
            )
87

88
        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
89
        self.dropout = nn.Dropout(config.dropout)
90

91
    def forward(self, input_ids):
92
        """
93
        Parameters
94
        ----------
95
        input_ids: torch.tensor(bs, max_seq_length)
96
            The token ids to embed.
97

98
        Outputs
99
        -------
100
        embeddings: torch.tensor(bs, max_seq_length, dim)
101
            The embedded tokens (plus position embeddings, no token_type embeddings)
102
        """
103
        seq_length = input_ids.size(1)
104
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
105
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
106

107
        word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
108
        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
109

110
        embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
111
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
112
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
113
        return embeddings
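
# Shape sketch (illustration only, assuming the default DistilBertConfig with dim=768):
# token ids of shape (bs, max_seq_length) come out as (bs, max_seq_length, dim), with
# position embeddings added and no token_type embeddings.
#
#     >>> config = DistilBertConfig()
#     >>> embeddings = Embeddings(config)
#     >>> input_ids = torch.zeros(2, 16, dtype=torch.long)
#     >>> embeddings(input_ids).shape
#     torch.Size([2, 16, 768])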
114

115

116
class MultiHeadSelfAttention(nn.Module):
117
    def __init__(self, config):
118
        super().__init__()
119

120
        self.n_heads = config.n_heads
121
        self.dim = config.dim
122
        self.dropout = nn.Dropout(p=config.attention_dropout)
123

124
        assert self.dim % self.n_heads == 0
125

126
        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
127
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
128
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
129
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
130

131
        self.pruned_heads = set()
132

133
    def prune_heads(self, heads):
134
        attention_head_size = self.dim // self.n_heads
135
        if len(heads) == 0:
136
            return
137
        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
138
        # Prune linear layers
139
        self.q_lin = prune_linear_layer(self.q_lin, index)
140
        self.k_lin = prune_linear_layer(self.k_lin, index)
141
        self.v_lin = prune_linear_layer(self.v_lin, index)
142
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
143
        # Update hyper params
144
        self.n_heads = self.n_heads - len(heads)
145
        self.dim = attention_head_size * self.n_heads
146
        self.pruned_heads = self.pruned_heads.union(heads)
147

148
    def forward(self, query, key, value, mask, head_mask=None, output_attentions=False):
149
        """
150
        Parameters
151
        ----------
152
        query: torch.tensor(bs, seq_length, dim)
153
        key: torch.tensor(bs, seq_length, dim)
154
        value: torch.tensor(bs, seq_length, dim)
155
        mask: torch.tensor(bs, seq_length)
156

157
        Outputs
158
        -------
159
        weights: torch.tensor(bs, n_heads, seq_length, seq_length)
            Attention weights. Optional: only returned if `output_attentions=True`
        context: torch.tensor(bs, seq_length, dim)
            Contextualized layer
        """
164
        bs, q_length, dim = query.size()
165
        k_length = key.size(1)
166
        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
167
        # assert key.size() == value.size()
168

169
        dim_per_head = self.dim // self.n_heads
170

171
        mask_reshp = (bs, 1, 1, k_length)
172

173
        def shape(x):
174
            """ separate heads """
175
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
176

177
        def unshape(x):
178
            """ group heads """
179
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
180

181
        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
182
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
183
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
184

185
        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
186
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
187
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
188
        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, q_length, k_length)
189

190
        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)
191
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)
192

193
        # Mask heads if we want to
194
        if head_mask is not None:
195
            weights = weights * head_mask
196

197
        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
198
        context = unshape(context)  # (bs, q_length, dim)
199
        context = self.out_lin(context)  # (bs, q_length, dim)
200

201
        if output_attentions:
202
            return (context, weights)
203
        else:
204
            return (context,)
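
# Call sketch (illustration only): self-attention is called with the same tensor as
# query/key/value plus a (bs, seq_length) padding mask of ones; with
# output_attentions=True it returns (context, weights).
#
#     >>> config = DistilBertConfig()
#     >>> attn = MultiHeadSelfAttention(config)
#     >>> x = torch.rand(2, 16, config.dim)
#     >>> mask = torch.ones(2, 16)
#     >>> context, weights = attn(query=x, key=x, value=x, mask=mask, output_attentions=True)
#     >>> context.shape, weights.shape
#     (torch.Size([2, 16, 768]), torch.Size([2, 12, 16, 16]))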
205

206

207
class FFN(nn.Module):
208
    def __init__(self, config):
209
        super().__init__()
210
        self.dropout = nn.Dropout(p=config.dropout)
211
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
212
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
213
        assert config.activation in ["relu", "gelu"], "activation ({}) must be in ['relu', 'gelu']".format(
214
            config.activation
215
        )
216
        self.activation = gelu if config.activation == "gelu" else nn.ReLU()
217

218
    def forward(self, input):
219
        x = self.lin1(input)
220
        x = self.activation(x)
221
        x = self.lin2(x)
222
        x = self.dropout(x)
223
        return x
224

225

226
class TransformerBlock(nn.Module):
227
    def __init__(self, config):
228
        super().__init__()
229

230
        assert config.dim % config.n_heads == 0
231

232
        self.attention = MultiHeadSelfAttention(config)
233
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
234

235
        self.ffn = FFN(config)
236
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)
237

238
    def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
239
        """
240
        Parameters
241
        ----------
242
        x: torch.tensor(bs, seq_length, dim)
243
        attn_mask: torch.tensor(bs, seq_length)
244

245
        Outputs
246
        -------
247
        sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length)
248
            The attention weights
249
        ffn_output: torch.tensor(bs, seq_length, dim)
250
            The output of the transformer block contextualization.
251
        """
252
        # Self-Attention
253
        sa_output = self.attention(
254
            query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask, output_attentions=output_attentions,
255
        )
256
        if output_attentions:
257
            sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
258
        else:  # when output_attentions is False, the attention module returns a one-element tuple
259
            assert type(sa_output) == tuple
260
            sa_output = sa_output[0]
261
        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)
262

263
        # Feed Forward Network
264
        ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
265
        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)
266

267
        output = (ffn_output,)
268
        if output_attentions:
269
            output = (sa_weights,) + output
270
        return output
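
# Output-ordering sketch (illustration only): with output_attentions=True the block
# returns (sa_weights, ffn_output); otherwise a one-element tuple (ffn_output,).
#
#     >>> block = TransformerBlock(DistilBertConfig())
#     >>> x = torch.rand(2, 16, 768)
#     >>> out = block(x, attn_mask=torch.ones(2, 16), output_attentions=True)
#     >>> [t.shape for t in out]
#     [torch.Size([2, 12, 16, 16]), torch.Size([2, 16, 768])]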
271

272

273
class Transformer(nn.Module):
274
    def __init__(self, config):
275
        super().__init__()
276
        self.n_layers = config.n_layers
277

278
        layer = TransformerBlock(config)
279
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])
280

281
    def forward(
282
        self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=None
283
    ):
284
        """
285
        Parameters
286
        ----------
287
        x: torch.tensor(bs, seq_length, dim)
288
            Input sequence embedded.
289
        attn_mask: torch.tensor(bs, seq_length)
290
            Attention mask on the sequence.
291

292
        Outputs
293
        -------
294
        hidden_state: torch.tensor(bs, seq_length, dim)
295
            Sequence of hidden states in the last (top) layer
296
        all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
297
            Tuple of length n_layers with the hidden states from each layer.
298
            Optional: only if output_hidden_states=True
299
        all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
300
            Tuple of length n_layers with the attention weights from each layer
301
            Optional: only if output_attentions=True
302
        """
303
        all_hidden_states = () if output_hidden_states else None
304
        all_attentions = () if output_attentions else None
305

306
        hidden_state = x
307
        for i, layer_module in enumerate(self.layer):
308
            if output_hidden_states:
309
                all_hidden_states = all_hidden_states + (hidden_state,)
310

311
            layer_outputs = layer_module(
312
                x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
313
            )
314
            hidden_state = layer_outputs[-1]
315

316
            if output_attentions:
317
                assert len(layer_outputs) == 2
318
                attentions = layer_outputs[0]
319
                all_attentions = all_attentions + (attentions,)
320
            else:
321
                assert len(layer_outputs) == 1
322

323
        # Add last layer
324
        if output_hidden_states:
325
            all_hidden_states = all_hidden_states + (hidden_state,)
326

327
        if not return_dict:
328
            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
329
        return BaseModelOutput(
330
            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
331
        )
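
# Note (illustration only): Transformer.forward indexes head_mask per layer
# (head_mask[i]), so a per-layer sequence must be supplied even when no heads are
# masked; DistilBertModel.forward prepares this via get_head_mask.
#
#     >>> config = DistilBertConfig()
#     >>> encoder = Transformer(config)
#     >>> x = torch.rand(2, 16, config.dim)
#     >>> out = encoder(x, attn_mask=torch.ones(2, 16), head_mask=[None] * config.n_layers,
#     ...               output_hidden_states=True, return_dict=True)
#     >>> out.last_hidden_state.shape, len(out.hidden_states)
#     (torch.Size([2, 16, 768]), 7)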
332

333

334
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
335
class DistilBertPreTrainedModel(PreTrainedModel):
336
    """ An abstract class to handle weights initialization and
337
        a simple interface for downloading and loading pretrained models.
338
    """
339

340
    config_class = DistilBertConfig
341
    load_tf_weights = None
342
    base_model_prefix = "distilbert"
343

344
    def _init_weights(self, module):
345
        """ Initialize the weights.
346
        """
347
        if isinstance(module, nn.Embedding):
348
            if module.weight.requires_grad:
349
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
350
        if isinstance(module, nn.Linear):
351
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
352
        elif isinstance(module, nn.LayerNorm):
353
            module.bias.data.zero_()
354
            module.weight.data.fill_(1.0)
355
        if isinstance(module, nn.Linear) and module.bias is not None:
356
            module.bias.data.zero_()
357

358

359
DISTILBERT_START_DOCSTRING = r"""
360

361
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
362
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
363
    usage and behavior.
364

365
    Parameters:
366
        config (:class:`~transformers.DistilBertConfig`): Model configuration class with all the parameters of the model.
367
            Initializing with a config file does not load the weights associated with the model, only the configuration.
368
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
369
"""
370

371
DISTILBERT_INPUTS_DOCSTRING = r"""
372
    Args:
373
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
374
            Indices of input sequence tokens in the vocabulary.
375

376
            Indices can be obtained using :class:`transformers.DistilBertTokenizer`.
377
            See :func:`transformers.PreTrainedTokenizer.encode` and
378
            :func:`transformers.PreTrainedTokenizer.__call__` for details.
379

380
            `What are input IDs? <../glossary.html#input-ids>`__
381
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
382
            Mask to avoid performing attention on padding token indices.
383
            Mask values selected in ``[0, 1]``:
384
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
385

386
            `What are attention masks? <../glossary.html#attention-mask>`__
387
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
388
            Mask to nullify selected heads of the self-attention modules.
389
            Mask values selected in ``[0, 1]``:
390
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
391
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
392
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
393
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
394
            than the model's internal embedding lookup matrix.
395
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
396
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
397
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
398
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
399
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
400
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
401
            plain tuple.
402
"""
403

404

405
@add_start_docstrings(
406
    "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
407
    DISTILBERT_START_DOCSTRING,
408
)
409
class DistilBertModel(DistilBertPreTrainedModel):
410
    def __init__(self, config):
411
        super().__init__(config)
412

413
        self.embeddings = Embeddings(config)  # Embeddings
414
        self.transformer = Transformer(config)  # Encoder
415

416
        self.init_weights()
417

418
    def get_input_embeddings(self):
419
        return self.embeddings.word_embeddings
420

421
    def set_input_embeddings(self, new_embeddings):
422
        self.embeddings.word_embeddings = new_embeddings
423

424
    def _prune_heads(self, heads_to_prune):
425
        """ Prunes heads of the model.
426
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
427
            See base class PreTrainedModel
428
        """
429
        for layer, heads in heads_to_prune.items():
430
            self.transformer.layer[layer].attention.prune_heads(heads)
431

432
    @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
433
    @add_code_sample_docstrings(
434
        tokenizer_class=_TOKENIZER_FOR_DOC,
435
        checkpoint="distilbert-base-uncased",
436
        output_type=BaseModelOutput,
437
        config_class=_CONFIG_FOR_DOC,
438
    )
440
    def forward(
441
        self,
442
        input_ids=None,
443
        attention_mask=None,
444
        head_mask=None,
445
        inputs_embeds=None,
446
        output_attentions=None,
447
        output_hidden_states=None,
448
        return_dict=None,
449
    ):
450
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
451
        output_hidden_states = (
452
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
453
        )
454
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
455

456
        if input_ids is not None and inputs_embeds is not None:
457
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
458
        elif input_ids is not None:
459
            input_shape = input_ids.size()
460
        elif inputs_embeds is not None:
461
            input_shape = inputs_embeds.size()[:-1]
462
        else:
463
            raise ValueError("You have to specify either input_ids or inputs_embeds")
464

465
        device = input_ids.device if input_ids is not None else inputs_embeds.device
466

467
        if attention_mask is None:
468
            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)
469

470
        # Prepare head mask if needed
471
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
472

473
        if inputs_embeds is None:
474
            inputs_embeds = self.embeddings(input_ids)  # (bs, seq_length, dim)
475
        return self.transformer(
476
            x=inputs_embeds,
477
            attn_mask=attention_mask,
478
            head_mask=head_mask,
479
            output_attentions=output_attentions,
480
            output_hidden_states=output_hidden_states,
481
            return_dict=return_dict,
482
        )
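
# Usage sketch (illustration only; assumes the standard "distilbert-base-uncased"
# checkpoint and that DistilBertTokenizer is available from the transformers package):
#
#     >>> from transformers import DistilBertTokenizer
#     >>> tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
#     >>> model = DistilBertModel.from_pretrained("distilbert-base-uncased", return_dict=True)
#     >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     >>> outputs = model(**inputs)
#     >>> outputs.last_hidden_state.shape
#     torch.Size([1, 8, 768])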
483

484

485
@add_start_docstrings(
486
    """DistilBert Model with a `masked language modeling` head on top. """, DISTILBERT_START_DOCSTRING,
487
)
488
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
489
    def __init__(self, config):
490
        super().__init__(config)
491

492
        self.distilbert = DistilBertModel(config)
493
        self.vocab_transform = nn.Linear(config.dim, config.dim)
494
        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
495
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
496

497
        self.init_weights()
498

499
        self.mlm_loss_fct = nn.CrossEntropyLoss()
500

501
    def get_output_embeddings(self):
502
        return self.vocab_projector
503

504
    @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
505
    @add_code_sample_docstrings(
506
        tokenizer_class=_TOKENIZER_FOR_DOC,
507
        checkpoint="distilbert-base-uncased",
508
        output_type=MaskedLMOutput,
509
        config_class=_CONFIG_FOR_DOC,
510
    )
511
    def forward(
512
        self,
513
        input_ids=None,
514
        attention_mask=None,
515
        head_mask=None,
516
        inputs_embeds=None,
517
        labels=None,
518
        output_attentions=None,
519
        output_hidden_states=None,
520
        return_dict=None,
521
        **kwargs
522
    ):
523
        r"""
524
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
525
            Labels for computing the masked language modeling loss.
526
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
527
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
528
            in ``[0, ..., config.vocab_size]``
529
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
530
            Used to hide legacy arguments that have been deprecated.
531
        """
532
        if "masked_lm_labels" in kwargs:
533
            warnings.warn(
534
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
535
                FutureWarning,
536
            )
537
            labels = kwargs.pop("masked_lm_labels")
538
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
539
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
540

541
        dlbrt_output = self.distilbert(
542
            input_ids=input_ids,
543
            attention_mask=attention_mask,
544
            head_mask=head_mask,
545
            inputs_embeds=inputs_embeds,
546
            output_attentions=output_attentions,
547
            output_hidden_states=output_hidden_states,
548
            return_dict=return_dict,
549
        )
550
        hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
551
        prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
552
        prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
553
        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
554
        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)
555

556
        mlm_loss = None
557
        if labels is not None:
558
            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
559

560
        if not return_dict:
561
            output = (prediction_logits,) + dlbrt_output[1:]
562
            return ((mlm_loss,) + output) if mlm_loss is not None else output
563

564
        return MaskedLMOutput(
565
            loss=mlm_loss,
566
            logits=prediction_logits,
567
            hidden_states=dlbrt_output.hidden_states,
568
            attentions=dlbrt_output.attentions,
569
        )
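
# Usage sketch (illustration only): the MLM head scores every position over the full
# vocabulary; label positions set to -100 are ignored by the cross-entropy loss.
#
#     >>> from transformers import DistilBertTokenizer
#     >>> tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
#     >>> model = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased", return_dict=True)
#     >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
#     >>> outputs = model(**inputs, labels=inputs["input_ids"])
#     >>> outputs.logits.shape[-1] == model.config.vocab_size
#     True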
570

571

572
@add_start_docstrings(
573
    """DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of
574
    the pooled output) e.g. for GLUE tasks. """,
575
    DISTILBERT_START_DOCSTRING,
576
)
577
class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
578
    def __init__(self, config):
579
        super().__init__(config)
580
        self.num_labels = config.num_labels
581

582
        self.distilbert = DistilBertModel(config)
583
        self.pre_classifier = nn.Linear(config.dim, config.dim)
584
        self.classifier = nn.Linear(config.dim, config.num_labels)
585
        self.dropout = nn.Dropout(config.seq_classif_dropout)
586

587
        self.init_weights()
588

589
    @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
590
    @add_code_sample_docstrings(
591
        tokenizer_class=_TOKENIZER_FOR_DOC,
592
        checkpoint="distilbert-base-uncased",
593
        output_type=SequenceClassifierOutput,
594
        config_class=_CONFIG_FOR_DOC,
595
    )
596
    def forward(
597
        self,
598
        input_ids=None,
599
        attention_mask=None,
600
        head_mask=None,
601
        inputs_embeds=None,
602
        labels=None,
603
        output_attentions=None,
604
        output_hidden_states=None,
605
        return_dict=None,
606
    ):
607
        r"""
608
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
609
            Labels for computing the sequence classification/regression loss.
610
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
611
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
612
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
613
        """
614
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
615

616
        distilbert_output = self.distilbert(
617
            input_ids=input_ids,
618
            attention_mask=attention_mask,
619
            head_mask=head_mask,
620
            inputs_embeds=inputs_embeds,
621
            output_attentions=output_attentions,
622
            output_hidden_states=output_hidden_states,
623
            return_dict=return_dict,
624
        )
625
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
626
        pooled_output = hidden_state[:, 0]  # (bs, dim)
627
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
628
        pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
629
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
630
        logits = self.classifier(pooled_output)  # (bs, num_labels)
631

632
        loss = None
633
        if labels is not None:
634
            if self.num_labels == 1:
635
                loss_fct = nn.MSELoss()
636
                loss = loss_fct(logits.view(-1), labels.view(-1))
637
            else:
638
                loss_fct = nn.CrossEntropyLoss()
639
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
640

641
        if not return_dict:
642
            output = (logits,) + distilbert_output[1:]
643
            return ((loss,) + output) if loss is not None else output
644

645
        return SequenceClassifierOutput(
646
            loss=loss,
647
            logits=logits,
648
            hidden_states=distilbert_output.hidden_states,
649
            attentions=distilbert_output.attentions,
650
        )
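
# Usage sketch (illustration only; the SST-2 fine-tuned checkpoint is one of the names
# in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST above):
#
#     >>> from transformers import DistilBertTokenizer
#     >>> name = "distilbert-base-uncased-finetuned-sst-2-english"
#     >>> tokenizer = DistilBertTokenizer.from_pretrained(name)
#     >>> model = DistilBertForSequenceClassification.from_pretrained(name, return_dict=True)
#     >>> inputs = tokenizer("A delightfully fun movie", return_tensors="pt")
#     >>> probs = model(**inputs).logits.softmax(dim=-1)  # (bs, num_labels)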
651

652

653
@add_start_docstrings(
654
    """DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
655
    the hidden-states output to compute `span start logits` and `span end logits`). """,
656
    DISTILBERT_START_DOCSTRING,
657
)
658
class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
659
    def __init__(self, config):
660
        super().__init__(config)
661

662
        self.distilbert = DistilBertModel(config)
663
        self.qa_outputs = nn.Linear(config.dim, config.num_labels)
664
        assert config.num_labels == 2
665
        self.dropout = nn.Dropout(config.qa_dropout)
666

667
        self.init_weights()
668

669
    @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
670
    @add_code_sample_docstrings(
671
        tokenizer_class=_TOKENIZER_FOR_DOC,
672
        checkpoint="distilbert-base-uncased",
673
        output_type=QuestionAnsweringModelOutput,
674
        config_class=_CONFIG_FOR_DOC,
675
    )
676
    def forward(
677
        self,
678
        input_ids=None,
679
        attention_mask=None,
680
        head_mask=None,
681
        inputs_embeds=None,
682
        start_positions=None,
683
        end_positions=None,
684
        output_attentions=None,
685
        output_hidden_states=None,
686
        return_dict=None,
687
    ):
688
        r"""
689
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
690
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
691
            Positions are clamped to the length of the sequence (`sequence_length`).
692
            Positions outside of the sequence are not taken into account for computing the loss.
693
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
694
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
695
            Positions are clamped to the length of the sequence (`sequence_length`).
696
            Positions outside of the sequence are not taken into account for computing the loss.
697
        """
698
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
699

700
        distilbert_output = self.distilbert(
701
            input_ids=input_ids,
702
            attention_mask=attention_mask,
703
            head_mask=head_mask,
704
            inputs_embeds=inputs_embeds,
705
            output_attentions=output_attentions,
706
            output_hidden_states=output_hidden_states,
707
            return_dict=return_dict,
708
        )
709
        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)
710

711
        hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)
712
        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)
713
        start_logits, end_logits = logits.split(1, dim=-1)
714
        start_logits = start_logits.squeeze(-1)  # (bs, max_query_len)
715
        end_logits = end_logits.squeeze(-1)  # (bs, max_query_len)
716

717
        total_loss = None
718
        if start_positions is not None and end_positions is not None:
719
            # If we are on multi-GPU, the split adds an extra dimension; squeeze it
720
            if len(start_positions.size()) > 1:
721
                start_positions = start_positions.squeeze(-1)
722
            if len(end_positions.size()) > 1:
723
                end_positions = end_positions.squeeze(-1)
724
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
725
            ignored_index = start_logits.size(1)
726
            start_positions.clamp_(0, ignored_index)
727
            end_positions.clamp_(0, ignored_index)
728

729
            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
730
            start_loss = loss_fct(start_logits, start_positions)
731
            end_loss = loss_fct(end_logits, end_positions)
732
            total_loss = (start_loss + end_loss) / 2
733

734
        if not return_dict:
735
            output = (start_logits, end_logits) + distilbert_output[1:]
736
            return ((total_loss,) + output) if total_loss is not None else output
737

738
        return QuestionAnsweringModelOutput(
739
            loss=total_loss,
740
            start_logits=start_logits,
741
            end_logits=end_logits,
742
            hidden_states=distilbert_output.hidden_states,
743
            attentions=distilbert_output.attentions,
744
        )
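
# Usage sketch (illustration only; the SQuAD-distilled checkpoint is one of the names
# in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST above):
#
#     >>> from transformers import DistilBertTokenizer
#     >>> name = "distilbert-base-uncased-distilled-squad"
#     >>> tokenizer = DistilBertTokenizer.from_pretrained(name)
#     >>> model = DistilBertForQuestionAnswering.from_pretrained(name, return_dict=True)
#     >>> question, context = "Who developed DistilBERT?", "DistilBERT was developed by HuggingFace."
#     >>> inputs = tokenizer(question, context, return_tensors="pt")
#     >>> out = model(**inputs)
#     >>> start, end = out.start_logits.argmax(-1), out.end_logits.argmax(-1)  # answer span indices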
745

746

747
@add_start_docstrings(
748
    """DistilBert Model with a token classification head on top (a linear layer on top of
749
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
750
    DISTILBERT_START_DOCSTRING,
751
)
752
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
753
    def __init__(self, config):
754
        super().__init__(config)
755
        self.num_labels = config.num_labels
756

757
        self.distilbert = DistilBertModel(config)
758
        self.dropout = nn.Dropout(config.dropout)
759
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
760

761
        self.init_weights()
762

763
    @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
764
    @add_code_sample_docstrings(
765
        tokenizer_class=_TOKENIZER_FOR_DOC,
766
        checkpoint="distilbert-base-uncased",
767
        output_type=TokenClassifierOutput,
768
        config_class=_CONFIG_FOR_DOC,
769
    )
770
    def forward(
771
        self,
772
        input_ids=None,
773
        attention_mask=None,
774
        head_mask=None,
775
        inputs_embeds=None,
776
        labels=None,
777
        output_attentions=None,
778
        output_hidden_states=None,
779
        return_dict=None,
780
    ):
781
        r"""
782
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
783
            Labels for computing the token classification loss.
784
            Indices should be in ``[0, ..., config.num_labels - 1]``.
785
        """
786
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
787

788
        outputs = self.distilbert(
789
            input_ids,
790
            attention_mask=attention_mask,
791
            head_mask=head_mask,
792
            inputs_embeds=inputs_embeds,
793
            output_attentions=output_attentions,
794
            output_hidden_states=output_hidden_states,
795
            return_dict=return_dict,
796
        )
797

798
        sequence_output = outputs[0]
799

800
        sequence_output = self.dropout(sequence_output)
801
        logits = self.classifier(sequence_output)
802

803
        loss = None
804
        if labels is not None:
805
            loss_fct = CrossEntropyLoss()
806
            # Only keep active parts of the loss
807
            if attention_mask is not None:
808
                active_loss = attention_mask.view(-1) == 1
809
                active_logits = logits.view(-1, self.num_labels)
810
                active_labels = torch.where(
811
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
812
                )
813
                loss = loss_fct(active_logits, active_labels)
814
            else:
815
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
816

817
        if not return_dict:
818
            output = (logits,) + outputs[1:]
819
            return ((loss,) + output) if loss is not None else output
820

821
        return TokenClassifierOutput(
822
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
823
        )
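
# Usage sketch (illustration only; num_labels=5 is an assumed value and the labels
# below are dummies, only the shapes matter):
#
#     >>> from transformers import DistilBertTokenizer
#     >>> tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
#     >>> model = DistilBertForTokenClassification.from_pretrained(
#     ...     "distilbert-base-uncased", num_labels=5, return_dict=True)
#     >>> inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
#     >>> labels = torch.zeros_like(inputs["input_ids"])  # dummy per-token labels
#     >>> out = model(**inputs, labels=labels)
#     >>> out.logits.shape[-1]
#     5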
824

825

826
@add_start_docstrings(
827
    """DistilBert Model with a multiple choice classification head on top (a linear layer on top of
828
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
829
    DISTILBERT_START_DOCSTRING,
830
)
831
class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
832
    def __init__(self, config):
833
        super().__init__(config)
834

835
        self.distilbert = DistilBertModel(config)
836
        self.pre_classifier = nn.Linear(config.dim, config.dim)
837
        self.classifier = nn.Linear(config.dim, 1)
838
        self.dropout = nn.Dropout(config.seq_classif_dropout)
839

840
        self.init_weights()
841

842
    @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
843
    @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
844
    def forward(
845
        self,
846
        input_ids=None,
847
        attention_mask=None,
848
        head_mask=None,
849
        inputs_embeds=None,
850
        labels=None,
851
        output_attentions=None,
852
        output_hidden_states=None,
853
        return_dict=None,
854
    ):
855
        r"""
856
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
857
            Labels for computing the multiple choice classification loss.
858
            Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
859
            of the input tensors. (see `input_ids` above)
860

861
    Returns:
862

863
    Examples::
864

865
        >>> from transformers import DistilBertTokenizer, DistilBertForMultipleChoice
866
        >>> import torch
867

868
        >>> tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
869
        >>> model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased', return_dict=True)
870

871
        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
872
        >>> choice0 = "It is eaten with a fork and a knife."
873
        >>> choice1 = "It is eaten while held in the hand."
874
        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1
875

876
        >>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors='pt', padding=True)
877
        >>> outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1
878

879
        >>> # the linear classifier still needs to be trained
880
        >>> loss = outputs.loss
881
        >>> logits = outputs.logits
882
        """
883
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
884
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
885

886
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
887
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
888
        inputs_embeds = (
889
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
890
            if inputs_embeds is not None
891
            else None
892
        )
893

894
        outputs = self.distilbert(
895
            input_ids,
896
            attention_mask=attention_mask,
897
            head_mask=head_mask,
898
            inputs_embeds=inputs_embeds,
899
            output_attentions=output_attentions,
900
            output_hidden_states=output_hidden_states,
901
            return_dict=return_dict,
902
        )
903

904
        hidden_state = outputs[0]  # (bs * num_choices, seq_len, dim)
905
        pooled_output = hidden_state[:, 0]  # (bs * num_choices, dim)
906
        pooled_output = self.pre_classifier(pooled_output)  # (bs * num_choices, dim)
907
        pooled_output = nn.ReLU()(pooled_output)  # (bs * num_choices, dim)
908
        pooled_output = self.dropout(pooled_output)  # (bs * num_choices, dim)
909
        logits = self.classifier(pooled_output)  # (bs * num_choices, 1)
910

911
        reshaped_logits = logits.view(-1, num_choices)  # (bs, num_choices)
912

913
        loss = None
914
        if labels is not None:
915
            loss_fct = CrossEntropyLoss()
916
            loss = loss_fct(reshaped_logits, labels)
917

918
        if not return_dict:
919
            output = (reshaped_logits,) + outputs[1:]
920
            return ((loss,) + output) if loss is not None else output
921

922
        return MultipleChoiceModelOutput(
923
            loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
924
        )
925
