# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model. """


import logging
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from .activations import gelu, gelu_new, swish
from .configuration_bert import BertConfig
from .file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    CausalLMOutput,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer


logger = logging.getLogger(__name__)

_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"

BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bert-base-uncased",
    "bert-large-uncased",
    "bert-base-cased",
    "bert-large-cased",
    "bert-base-multilingual-uncased",
    "bert-base-multilingual-cased",
    "bert-base-chinese",
    "bert-base-german-cased",
    "bert-large-uncased-whole-word-masking",
    "bert-large-cased-whole-word-masking",
    "bert-large-uncased-whole-word-masking-finetuned-squad",
    "bert-large-cased-whole-word-masking-finetuned-squad",
    "bert-base-cased-finetuned-mrpc",
    "bert-base-german-dbmdz-cased",
    "bert-base-german-dbmdz-uncased",
    "cl-tohoku/bert-base-japanese",
    "cl-tohoku/bert-base-japanese-whole-word-masking",
    "cl-tohoku/bert-base-japanese-char",
    "cl-tohoku/bert-base-japanese-char-whole-word-masking",
    "TurkuNLP/bert-base-finnish-cased-v1",
    "TurkuNLP/bert-base-finnish-uncased-v1",
    "wietsedv/bert-base-dutch-cased",
    # See all BERT models at https://huggingface.co/models?filter=bert
]

def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model

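# Illustrative sketch of how `load_tf_weights_in_bert` is typically driven (the file
# paths below are placeholders, not shipped with this module):
#
#     config = BertConfig.from_json_file("/path/to/bert_config.json")
#     model = BertForPreTraining(config)
#     load_tf_weights_in_bert(model, config, "/path/to/bert_model.ckpt")
#     torch.save(model.state_dict(), "/path/to/pytorch_model.bin")
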
def mish(x):
    return x * torch.tanh(nn.functional.softplus(x))


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}


BertLayerNorm = torch.nn.LayerNorm

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

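# A minimal sketch of what BertEmbeddings produces (values are placeholders; assumes a
# default BertConfig with hidden_size=768):
#
#     config = BertConfig()
#     emb = BertEmbeddings(config)
#     input_ids = torch.tensor([[101, 7592, 102]])       # (batch_size=1, seq_len=3)
#     out = emb(input_ids=input_ids)                     # word + position + token_type, then LayerNorm + dropout
#     assert out.shape == (1, 3, config.hidden_size)     # (1, 3, 768)
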
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(hidden_states)
            mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs

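# Shape walkthrough for the self-attention above, assuming the default 12 heads and
# hidden_size=768 (head_size = 768 / 12 = 64); tensor names mirror the code:
#
#     hidden_states:    (batch, seq_len, 768)
#     query/key/value:  (batch, seq_len, 768) -> transpose_for_scores -> (batch, 12, seq_len, 64)
#     attention_scores: (batch, 12, seq_len, seq_len) = Q @ K^T / sqrt(64)
#     context_layer:    (batch, 12, seq_len, 64) -> permute + view -> (batch, seq_len, 768)
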
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

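# Illustrative use of head pruning at the model level (layer/head indices are arbitrary
# examples); BertModel._prune_heads further below dispatches to BertAttention.prune_heads:
#
#     model = BertModel(BertConfig())
#     model.prune_heads({0: [0, 2], 5: [7]})   # drop heads 0 and 2 of layer 0, head 7 of layer 5
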
class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        if self.is_decoder:
            self.crossattention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        self_attention_outputs = self.attention(
            hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + outputs
        return outputs

class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if getattr(self.config, "gradient_checkpointing", False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

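# Gradient checkpointing above is toggled purely through the config attribute read by
# getattr; a minimal sketch of enabling it (trades recomputation for activation memory):
#
#     config = BertConfig()
#     config.gradient_checkpointing = True
#     encoder = BertEncoder(config)
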
class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

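# The `decoder` above is what get_output_embeddings() returns in the *ForMaskedLM /
# *ForPreTraining classes below; PreTrainedModel ties it to the input word embeddings
# (when config.tie_word_embeddings is left at its default), so the output projection
# reuses the embedding matrix. A rough sketch of the effect (illustrative, not executed here):
#
#     model = BertForMaskedLM(BertConfig())
#     assert model.get_output_embeddings().weight.data_ptr() == \
#         model.get_input_embeddings().weight.data_ptr()
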
class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score

class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score

class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    authorized_missing_keys = [r"position_ids"]

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

@dataclass
class BertForPretrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.BertForPreTraining`.

    Args:
        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False
            continuation before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    seq_relationship_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

BERT_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
            plain tuple.
"""

@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class BertModel(BertPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well
    as a decoder, in which case a layer of cross-attention is added between
    the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the
    :obj:`is_decoder` argument of the configuration set to :obj:`True`; an
    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.

    .. _`Attention is all you need`:
        https://arxiv.org/abs/1706.03762

    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

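# Minimal end-to-end sketch for the bare BertModel (checkpoint name mirrors the one used
# in the docstring decorators above):
#
#     from transformers import BertTokenizer
#     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#     model = BertModel.from_pretrained("bert-base-uncased", return_dict=True)
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     outputs = model(**inputs)
#     sequence_output = outputs.last_hidden_state   # (1, seq_len, 768)
#     pooled_output = outputs.pooler_output         # (1, 768)
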
@add_start_docstrings(
    """Domain-Task Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
    a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
)
class BertForPreTrainingDomainTask(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        ###
        self.in_domain_layer = torch.nn.Linear(768, 768, bias=False)
        torch.nn.init.xavier_uniform_(self.in_domain_layer.weight)
        self.out_domain_layer = torch.nn.Linear(768, 768, bias=False)
        torch.nn.init.xavier_uniform_(self.out_domain_layer.weight)
        self.act = nn.ReLU()
        self.layer_out = nn.Linear(768, 2)  # num_class
        ###
        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=BertForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        tail_idxs=None,
        # in_domain_rep_batch=None,
        in_domain_rep=None,
        out_domain_rep=None,
        func=None,
        **kwargs
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with
            labels in ``[0, ..., config.vocab_size]``.

        next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see :obj:`input_ids` docstring). Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

        Returns:

        Examples::

            >>> from transformers import BertTokenizer, BertForPreTraining
            >>> import torch

            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            >>> model = BertForPreTraining.from_pretrained('bert-base-uncased', return_dict=True)

            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)

            >>> prediction_logits = outputs.prediction_logits
            >>> seq_relationship_logits = outputs.seq_relationship_logits
        """

        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if func == "in_domain_rep":
            # return outputs.hidden_states[0]
            '''
            in_domain_rep = list()
            for id, idx in enumerate(tail_idxs):
                in_domain_rep.append(outputs.hidden_states[0][id, idx, :])
            in_domain_rep = torch.stack(in_domain_rep)
            in_domain_rep = self.in_domain_layer(in_domain_rep)
            '''
            in_domain_rep = self.in_domain_layer(outputs.hidden_states[0][:, 0, :])
            return in_domain_rep

        elif func == "domain_class":
            # in_domain_rep = in_domain_rep_batch.squeeze(0)
            # print(in_domain_rep.shape)
            # exit()
            loss_fct = CrossEntropyLoss()
            '''
            out_domain_rep = list()
            for id, idx in enumerate(tail_idxs):
                out_domain_rep.append(outputs.hidden_states[0][id, idx, :])
            out_domain_rep = torch.stack(out_domain_rep)
            out_domain_rep = self.out_domain_layer(out_domain_rep)
            '''
            out_domain_rep = self.out_domain_layer(outputs.hidden_states[0][:, 0, :])
            pos_rep = self.layer_out(self.act(in_domain_rep))
            pos_target = torch.tensor([1] * pos_rep.shape[0]).to("cuda")
            neg_rep = self.layer_out(self.act(out_domain_rep))
            neg_target = torch.tensor([0] * neg_rep.shape[0]).to("cuda")
            rep = torch.cat([pos_rep, neg_rep], 0)
            target = torch.cat([pos_target, neg_target], 0)
            domain_loss = loss_fct(rep, target)
            return domain_loss
        else:
            pass

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
        elif labels is not None:
            # loss_fct = CrossEntropyLoss()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            total_loss = masked_lm_loss

        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return BertForPretrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

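# Rough sketch of how the `func` switch above is exercised during CSS-LM style training
# (tensor names are placeholders; `batch` stands for an already-tokenized batch on GPU):
#
#     model = BertForPreTrainingDomainTask.from_pretrained("bert-base-uncased")
#     in_rep = model(input_ids=batch["in_domain_ids"], output_hidden_states=True,
#                    return_dict=True, func="in_domain_rep")
#     loss = model(input_ids=batch["open_domain_ids"], output_hidden_states=True,
#                  return_dict=True, in_domain_rep=in_rep, func="domain_class")
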
@add_start_docstrings(
    """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
    a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=BertForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        next_sentence_label=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
        r"""
        labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with
            labels in ``[0, ..., config.vocab_size]``.

        next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see :obj:`input_ids` docstring). Indices should be in ``[0, 1]``.
            ``0`` indicates sequence B is a continuation of sequence A,
            ``1`` indicates sequence B is a random sequence.

        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

        Returns:

        Examples::

            >>> from transformers import BertTokenizer, BertForPreTraining
            >>> import torch

            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            >>> model = BertForPreTraining.from_pretrained('bert-base-uncased', return_dict=True)

            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)

            >>> prediction_logits = outputs.prediction_logits
            >>> seq_relationship_logits = outputs.seq_relationship_logits
        """

        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
        elif labels is not None:
            # loss_fct = CrossEntropyLoss()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            total_loss = masked_lm_loss

        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return BertForPretrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

@add_start_docstrings(
    """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."

        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the left-to-right language modeling loss (next word prediction).
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with
            labels in ``[0, ..., config.vocab_size]``.
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

        Returns:

        Example::

            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch

            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> config.is_decoder = True
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config, return_dict=True)

            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)

            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutput(
            loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        assert (
            not config.is_decoder
        ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."

        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
        """
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`).
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked); the loss
            is only computed for the labels set in [0, ..., vocab_size].
        `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
            with indices selected in [0, 1].
            0 => next sentence is the continuation, 1 => next sentence is a random sentence.
        """

        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with
            labels in ``[0, ..., config.vocab_size]``.
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.
        """
        if "masked_lm_labels" in kwargs:
            warnings.warn(
                "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("masked_lm_labels")
        assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # add a dummy token
        assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

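# Illustrative masked-LM round trip with the head above (mask position handling and the
# decoded token are examples, not guaranteed outputs):
#
#     from transformers import BertTokenizer
#     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#     model = BertForMaskedLM.from_pretrained("bert-base-uncased", return_dict=True)
#     inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
#     logits = model(**inputs).logits                        # (1, seq_len, vocab_size)
#     mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0, 1]
#     predicted_id = logits[0, mask_pos].argmax(-1).item()
#     print(tokenizer.decode([predicted_id]))                # e.g. "paris"
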
class BertClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # x = features[input_ids == 2]  # take </s> token (equiv. to the last token)
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

class BertClassificationTail(nn.Module):
    """Tail ([SEP]-token) head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, input_ids, **kwargs):
        # x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # x = features[input_ids == 2]  # take </s> token (equiv. to the last token)
        x = features[input_ids == 102]  # take the [SEP] token (id 102, equiv. to the last token)
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

class BertClassificationHeadandTail(nn.Module):
    """Combined head ([CLS]) and tail ([SEP]) head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size * 2, config.hidden_size * 2)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size * 2, config.num_labels)
        self.num_labels = config.num_labels

    def forward(self, features, input_ids, **kwargs):
        head = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        # tail = features[input_ids == 2]  # take </s> token (equiv. to the last token)
        tail = features[input_ids == 102]  # take the [SEP] token (id 102, equiv. to the last token)
        x = torch.cat((head, tail), -1)  # [batch_size, 768 * 2]
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

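# Shape sketch for the head-and-tail classifier above (assumes hidden_size=768 and exactly
# one [SEP] (id 102) per sequence, so `head` and `tail` are both (batch_size, 768)):
#
#     features = torch.randn(4, 128, 768)                     # encoder output
#     input_ids = torch.zeros(4, 128, dtype=torch.long)
#     input_ids[:, -1] = 102
#     clf = BertClassificationHeadandTail(BertConfig(num_labels=2))
#     logits = clf(features, input_ids)                        # (4, 2) after the 768*2 concat
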
1432@add_start_docstrings("""BERT Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
1433class BertForMaskedLMDomainTask(BertPreTrainedModel):
1434#config_class = RobertaConfig
1435#base_model_prefix = "roberta"
1436
1437def __init__(self, config):
1438super().__init__(config)
1439
1440self.bert = BertModel(config)
1441self.cls = BertOnlyMLMHead(config)
1442self.classifier_Task = BertClassificationHead(config)
1443self.classifier_Domain = BertClassificationTail(config)
1444self.classifier_DomainandTask = BertClassificationHeadandTail(config)
1445self.num_labels = config.num_labels
1446
1447self.init_weights()
1448self.LeakyReLU = torch.nn.LeakyReLU()
1449self.domain_binary_classifier = nn.Linear(768*2,2,bias=True) #num_class
1450self.task_binary_classifier = nn.Linear(768*2,2,bias=True) #num_class
1451self.domain_task_binary_classifier = nn.Linear(768*4,2,bias=True) #num_class
1452
1453def get_output_embeddings(self):
1454return self.cls.predictions.decoder
1455
1456@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1457@add_code_sample_docstrings(
1458tokenizer_class=_TOKENIZER_FOR_DOC,
1459checkpoint="bert-base-uncased",
1460output_type=MaskedLMOutput,
1461config_class=_CONFIG_FOR_DOC,
1462)
1463
1464def forward(
1465self,
1466input_ids=None,
1467input_ids_org=None,
1468attention_mask=None,
1469token_type_ids=None,
1470position_ids=None,
1471head_mask=None,
1472inputs_embeds=None,
1473labels=None,
1474output_attentions=None,
1475output_hidden_states=None,
1476return_dict=None,
1477func=None,
1478tail_idxs=None,
1479in_domain_rep=None,
1480out_domain_rep=None,
1481sentence_label=None,
1482lm_label=None,
1483batch_size=None,
1484all_in_task_rep_comb=None,
1485all_sentence_binary_label=None,
1486from_query=False,
1487task_loss_org=None,
1488task_loss_cotrain=None,
1489domain_id=None,
1490use_detach=False,
1491**kwargs
1492):
1493r"""
1494labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1495Labels for computing the masked language modeling loss.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``.
1499kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
1500Used to hide legacy arguments that have been deprecated.
1501"""
1502if "masked_lm_labels" in kwargs:
1503warnings.warn(
1504"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
1505FutureWarning,
1506)
1507labels = kwargs.pop("masked_lm_labels")
1508assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
1509return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1510
1511
1512if func == "in_domain_task_rep":
1513#######
1514outputs = self.bert(
1515input_ids=input_ids_org,
1516attention_mask=attention_mask,
1517token_type_ids=token_type_ids,
1518position_ids=position_ids,
1519head_mask=head_mask,
1520inputs_embeds=inputs_embeds,
1521output_attentions=output_attentions,
1522output_hidden_states=output_hidden_states,
1523return_dict=return_dict,
1524)
1525#######
rep_head = outputs.last_hidden_state[:, 0, :]              # [CLS] representation
#rep_tail = outputs.last_hidden_state[input_ids_org==2]    # RoBERTa variant: </s> token (id 2)
rep_tail = outputs.last_hidden_state[input_ids_org==102]   # [SEP] (id 102) representation
if rep_tail.shape[0] != input_ids_org.shape[0]:
# Some sequence contains no [SEP] (id 102); fall back to the indices supplied in tail_idxs.
logger.warning("No [SEP] (id 102) token found in some sequences; falling back to tail_idxs.")
rep_tail = outputs.last_hidden_state[input_ids_org==tail_idxs]
1540
1541
1542
1543#detach
1544#rep = rep.detach()
1545'''
1546in_domain_rep = self.domain_layer(rep)
1547in_task_rep = self.task_layer(rep)
1548return in_domain_rep, in_task_rep
1549'''
1550return rep_tail, rep_head
1551
1552
1553elif func == "return_task_binary_classifier":
1554return self.task_binary_classifier.weight.data, self.task_binary_classifier.bias.data
1555
1556elif func == "return_domain_binary_classifier":
1557return self.domain_binary_classifier.weight.data, self.domain_binary_classifier.bias.data
1558
1559elif func == "return_domain_task_binary_classifier":
1560return self.domain_task_binary_classifier.weight.data, self.domain_task_binary_classifier.bias.data
1561
1562#if func == "task_binary_classifier":
1563
1564elif func == "domain_binary_classifier":
# in-domain examples get label 1; out-of-domain examples use the labels carried in domain_id
1566#Need to fix
1567#######
1568
1569loss_fct = CrossEntropyLoss()
1570domain_rep = torch.cat([in_domain_rep, out_domain_rep], 0)
1571if use_detach==True:
1572domain_rep = domain_rep.detach()
1573logit = self.domain_binary_classifier(domain_rep)
pos_target = torch.tensor([1]*in_domain_rep.shape[0]).to("cuda")
unknown_target = domain_id.to("cuda")
target = torch.cat([pos_target, unknown_target], 0)
1577domain_loss = loss_fct(logit, target)
1578
1579
1580#return domain_loss, logit, out_domain_rep_head, out_domain_rep_tail
1581return domain_loss, logit
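# Shape sketch for the branch above (assumption: in_domain_rep / out_domain_rep are
# (N_in, 2*hidden) and (N_out, 2*hidden) representations built by the caller):
#   domain_rep: (N_in + N_out, 2*hidden)  ->  logit: (N_in + N_out, 2)
#   target:     N_in ones (in-domain) followed by the labels carried in domain_id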
1582
1583
1584elif func == "task_binary_classifier":
# The query rep is not included yet, so in_domain_rep needs to be added here
1586loss_fct = CrossEntropyLoss()
1587#detach
1588#all_in_task_rep_comb = all_in_task_rep_comb.detach()
1589if use_detach==True:
1590all_in_task_rep_comb = all_in_task_rep_comb.detach()
1591logit = self.task_binary_classifier(all_in_task_rep_comb)
1592#logit = self.LeakyReLU(logit)
1593all_sentence_binary_label = all_sentence_binary_label.reshape(all_sentence_binary_label.shape[0]*all_sentence_binary_label.shape[1])
1594logit = logit.reshape(logit.shape[0]*logit.shape[1],logit.shape[2])
1595task_binary_loss = loss_fct(logit.view(-1,2), all_sentence_binary_label.view(-1))
1596return task_binary_loss, logit
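# Shape sketch for the branch above (assumption: all_in_task_rep_comb is
# (batch, n_pairs, 2*hidden) and all_sentence_binary_label is (batch, n_pairs)):
#   logit:  (batch, n_pairs, 2)  ->  flattened to (batch * n_pairs, 2)
#   labels: flattened to (batch * n_pairs,) before the cross-entropy loss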
1597
1598
1599elif func == "domain_task_binary_classifier":
# The query rep is not included yet, so in_domain_rep needs to be added here
1601loss_fct = CrossEntropyLoss()
1602#detach
1603#all_in_task_rep_comb = all_in_task_rep_comb.detach()
1604logit = self.domain_task_binary_classifier(all_in_task_rep_comb)
1605#logit = self.LeakyReLU(logit)
1606all_sentence_binary_label = all_sentence_binary_label.reshape(all_sentence_binary_label.shape[0]*all_sentence_binary_label.shape[1])
1607logit = logit.reshape(logit.shape[0]*logit.shape[1],logit.shape[2])
domain_task_binary_loss = loss_fct(logit.view(-1,2), all_sentence_binary_label.view(-1))
return domain_task_binary_loss, logit
1610
1611
1612elif func == "task_class":
1613#######
1614outputs = self.bert(
1615input_ids=input_ids_org,
1616attention_mask=attention_mask,
1617token_type_ids=token_type_ids,
1618position_ids=position_ids,
1619head_mask=head_mask,
1620inputs_embeds=inputs_embeds,
1621output_attentions=output_attentions,
1622output_hidden_states=output_hidden_states,
1623return_dict=return_dict,
1624)
1625#######
1626#Already including query rep
1627loss_fct = CrossEntropyLoss()
1628###
1629class_logit = self.classifier_DomainandTask(outputs.last_hidden_state, input_ids_org)
1630task_loss = loss_fct(class_logit.view(-1, self.num_labels), sentence_label.view(-1))
1631
if from_query:
query_rep_head = outputs.last_hidden_state[:,0,:]              # [CLS] representation
query_rep_tail = outputs.last_hidden_state[input_ids_org==102] # [SEP] (id 102) representation

if query_rep_tail.shape[0] != input_ids_org.shape[0]:
# Some sequence contains no [SEP] (id 102); fall back to the indices supplied in tail_idxs.
logger.warning("No [SEP] (id 102) token found in some sequences; falling back to tail_idxs.")
query_rep_tail = outputs.last_hidden_state[input_ids_org==tail_idxs]
1643return task_loss, class_logit, query_rep_head, query_rep_tail
1644else:
1645return task_loss, class_logit
1646
1647
1648elif func == "task_class_domain":
1649#######
1650outputs = self.bert(
1651input_ids=input_ids_org,
1652attention_mask=attention_mask,
1653token_type_ids=token_type_ids,
1654position_ids=position_ids,
1655head_mask=head_mask,
1656inputs_embeds=inputs_embeds,
1657output_attentions=output_attentions,
1658output_hidden_states=output_hidden_states,
1659return_dict=return_dict,
1660)
1661#######
1662#Already including query rep
1663loss_fct = CrossEntropyLoss()
1664###
1665class_logit = self.classifier_Domain(outputs.last_hidden_state, input_ids_org)
1666task_loss = loss_fct(class_logit.view(-1, self.num_labels), sentence_label.view(-1))
1667
if from_query:
query_rep_head = outputs.last_hidden_state[:,0,:]              # [CLS] representation
query_rep_tail = outputs.last_hidden_state[input_ids_org==102] # [SEP] (id 102) representation
if query_rep_tail.shape[0] != input_ids_org.shape[0]:
# Some sequence contains no [SEP] (id 102); fall back to the indices supplied in tail_idxs.
logger.warning("No [SEP] (id 102) token found in some sequences; falling back to tail_idxs.")
query_rep_tail = outputs.last_hidden_state[input_ids_org==tail_idxs]
1678return task_loss, class_logit, query_rep_head, query_rep_tail
1679else:
1680return task_loss, class_logit
1681
1682
1683elif func == "task_class_nodomain":
1684#######
1685outputs = self.bert(
1686input_ids=input_ids_org,
1687attention_mask=attention_mask,
1688token_type_ids=token_type_ids,
1689position_ids=position_ids,
1690head_mask=head_mask,
1691inputs_embeds=inputs_embeds,
1692output_attentions=output_attentions,
1693output_hidden_states=output_hidden_states,
1694return_dict=return_dict,
1695)
1696#######
1697#Already including query rep
1698loss_fct = CrossEntropyLoss()
1699###
1700class_logit = self.classifier_Task(outputs.last_hidden_state)
1701task_loss = loss_fct(class_logit.view(-1, self.num_labels), sentence_label.view(-1))
1702
if from_query:
query_rep_head = outputs.last_hidden_state[:,0,:]              # [CLS] representation
query_rep_tail = outputs.last_hidden_state[input_ids_org==102] # [SEP] (id 102) representation
if query_rep_tail.shape[0] != input_ids_org.shape[0]:
# Some sequence contains no [SEP] (id 102); fall back to the indices supplied in tail_idxs.
logger.warning("No [SEP] (id 102) token found in some sequences; falling back to tail_idxs.")
query_rep_tail = outputs.last_hidden_state[input_ids_org==tail_idxs]
1713return task_loss, class_logit, query_rep_head, query_rep_tail
1714else:
1715return task_loss, class_logit
1716
1717
1718elif func == "mlm":
1719outputs_mlm = self.bert(
1720input_ids=input_ids,
1721attention_mask=attention_mask,
1722token_type_ids=token_type_ids,
1723position_ids=position_ids,
1724head_mask=head_mask,
1725inputs_embeds=inputs_embeds,
1726output_attentions=output_attentions,
1727output_hidden_states=output_hidden_states,
1728return_dict=return_dict,
1729)
1730
sequence_output = outputs_mlm.last_hidden_state
prediction_scores = self.cls(sequence_output)
# Positions where lm_label == -1 are ignored when computing the MLM loss.
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_label.view(-1))
1738return masked_lm_loss
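# Sketch of an lm_label tensor compatible with the branch above (an assumption about the
# data pipeline, not taken from this repository):
#   lm_label = input_ids_org.clone()
#   lm_label[~masked_positions] = -1   # positions with -1 are ignored by the loss above
# where masked_positions marks the tokens that were replaced by [MASK] in input_ids.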
1739
1740
1741elif func == "task_class and mlm":
1742#######
1743outputs = self.bert(
1744input_ids=input_ids_org,
1745attention_mask=attention_mask,
1746token_type_ids=token_type_ids,
1747position_ids=position_ids,
1748head_mask=head_mask,
1749inputs_embeds=inputs_embeds,
1750output_attentions=output_attentions,
1751output_hidden_states=output_hidden_states,
1752return_dict=return_dict,
1753)
1754#######
1755#######
1756outputs_mlm = self.bert(
1757input_ids=input_ids,
1758attention_mask=attention_mask,
1759token_type_ids=token_type_ids,
1760position_ids=position_ids,
1761head_mask=head_mask,
1762inputs_embeds=inputs_embeds,
1763output_attentions=output_attentions,
1764output_hidden_states=output_hidden_states,
1765return_dict=return_dict,
1766)
1767#######
1768#Already including query rep
1769#task loss
1770loss_fct = CrossEntropyLoss()
1771###
1772'''
1773#rep = outputs.last_hidden_state[input_ids==2]
1774rep = outputs.last_hidden_state[:, 0, :]
1775#rep = rep.detach()
1776task_rep = self.task_layer(rep)
1777class_logit = self.layer_out_taskClass((self.act(task_rep)))
1778'''
# NOTE: this model defines classifier_Task / classifier_Domain / classifier_DomainandTask,
# but no self.classifier; use the task classification head here.
class_logit = self.classifier_Task(outputs.last_hidden_state)
###
task_loss = loss_fct(class_logit.view(-1, 8), sentence_label.view(-1))  # NOTE: the number of classes is hard-coded to 8 in this branch
1782
1783#mlm loss
1784sequence_output = outputs_mlm.last_hidden_state
1785#prediction_scores = self.lm_head(sequence_output)
1786prediction_scores = self.cls(sequence_output)
1787loss_fct = CrossEntropyLoss(ignore_index=-1)
1788masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_label.view(-1))
1789return task_loss, masked_lm_loss
1790
1791elif func == "gen_rep":
1792outputs = self.bert(
1793input_ids=input_ids_org,
1794attention_mask=attention_mask,
1795token_type_ids=token_type_ids,
1796position_ids=position_ids,
1797head_mask=head_mask,
1798inputs_embeds=inputs_embeds,
1799output_attentions=output_attentions,
1800output_hidden_states=output_hidden_states,
1801return_dict=return_dict,
1802)
1803return outputs
1804
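# Illustrative usage sketch (not called anywhere in this file): forward() above is a
# multi-purpose entry point dispatched on `func`. The `model` (a BertForMaskedLMDomainTask)
# and `tokenizer` (a BertTokenizer) arguments are assumptions supplied by the caller.
def _example_domain_task_forward(model, tokenizer, device="cpu"):
    enc = tokenizer(["an example sentence"], return_tensors="pt", padding=True).to(device)
    # [SEP]/[CLS] representations, e.g. for retrieval:
    rep_tail, rep_head = model(input_ids_org=enc["input_ids"],
                               attention_mask=enc["attention_mask"],
                               func="in_domain_task_rep")
    # Supervised task loss with the head+tail classifier:
    task_loss, class_logit = model(input_ids_org=enc["input_ids"],
                                   attention_mask=enc["attention_mask"],
                                   sentence_label=torch.tensor([0]).to(device),
                                   func="task_class")
    return rep_head, rep_tail, task_loss, class_logit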
1805
1806
1807
1808
1809
1810
1811@add_start_docstrings(
1812"""Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
1813)
1814class BertForNextSentencePrediction(BertPreTrainedModel):
1815def __init__(self, config):
1816super().__init__(config)
1817
1818self.bert = BertModel(config)
1819self.cls = BertOnlyNSPHead(config)
1820
1821self.init_weights()
1822
1823@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1824@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1825def forward(
1826self,
1827input_ids=None,
1828attention_mask=None,
1829token_type_ids=None,
1830position_ids=None,
1831head_mask=None,
1832inputs_embeds=None,
1833next_sentence_label=None,
1834output_attentions=None,
1835output_hidden_states=None,
1836return_dict=None,
1837):
1838r"""
1839next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring).
1841Indices should be in ``[0, 1]``.
1842``0`` indicates sequence B is a continuation of sequence A,
1843``1`` indicates sequence B is a random sequence.
1844
1845Returns:
1846
1847Example::
1848
1849>>> from transformers import BertTokenizer, BertForNextSentencePrediction
1850>>> import torch
1851
1852>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1853>>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased', return_dict=True)
1854
1855>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1856>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1857>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
1858
1859>>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
1860>>> logits = outputs.logits
1861>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1862"""
1863return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1864
1865outputs = self.bert(
1866input_ids,
1867attention_mask=attention_mask,
1868token_type_ids=token_type_ids,
1869position_ids=position_ids,
1870head_mask=head_mask,
1871inputs_embeds=inputs_embeds,
1872output_attentions=output_attentions,
1873output_hidden_states=output_hidden_states,
1874return_dict=return_dict,
1875)
1876
1877pooled_output = outputs[1]
1878
1879seq_relationship_scores = self.cls(pooled_output)
1880
1881next_sentence_loss = None
1882if next_sentence_label is not None:
1883loss_fct = CrossEntropyLoss()
1884next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), next_sentence_label.view(-1))
1885
1886if not return_dict:
1887output = (seq_relationship_scores,) + outputs[2:]
1888return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1889
1890return NextSentencePredictorOutput(
1891loss=next_sentence_loss,
1892logits=seq_relationship_scores,
1893hidden_states=outputs.hidden_states,
1894attentions=outputs.attentions,
1895)
1896
1897
1898@add_start_docstrings(
1899"""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
1900the pooled output) e.g. for GLUE tasks. """,
1901BERT_START_DOCSTRING,
1902)
1903class BertForSequenceClassification(BertPreTrainedModel):
1904def __init__(self, config):
1905super().__init__(config)
1906self.num_labels = config.num_labels
1907
1908self.bert = BertModel(config)
1909self.dropout = nn.Dropout(config.hidden_dropout_prob)
1910self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1911
1912self.init_weights()
1913
1914
1915@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1916@add_code_sample_docstrings(
1917tokenizer_class=_TOKENIZER_FOR_DOC,
1918checkpoint="bert-base-uncased",
1919output_type=SequenceClassifierOutput,
1920config_class=_CONFIG_FOR_DOC,
1921)
1922def forward(
1923self,
1924input_ids=None,
1925attention_mask=None,
1926token_type_ids=None, #segment_ids
1927position_ids=None, #
1928head_mask=None,
1929inputs_embeds=None,
1930labels=None,
1931output_attentions=None,
1932output_hidden_states=None,
1933return_dict=None,
1934):
1935r"""
1936labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1937Labels for computing the sequence classification/regression loss.
1938Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
1939If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1940If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1941"""
1942return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1943
1944outputs = self.bert(
1945input_ids,
1946attention_mask=attention_mask,
1947token_type_ids=token_type_ids,
1948position_ids=position_ids,
1949head_mask=head_mask,
1950inputs_embeds=inputs_embeds,
1951output_attentions=output_attentions,
1952output_hidden_states=output_hidden_states,
1953return_dict=return_dict,
1954)
1955
1956pooled_output = outputs[1]
1957
1958#pooled_output = pooled_output.detach()
1959pooled_output = self.dropout(pooled_output)
1960logits = self.classifier(pooled_output)
1961
1962loss = None
1963if labels is not None:
1964if self.num_labels == 1:
1965# We are doing regression
1966loss_fct = MSELoss()
1967loss = loss_fct(logits.view(-1), labels.view(-1))
1968else:
1969loss_fct = CrossEntropyLoss()
1970loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1971
1972
1973if not return_dict:
1974output = (logits,) + outputs[2:]
1975return ((loss,) + output) if loss is not None else output
1976
1977return SequenceClassifierOutput(
1978loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
1979)
1980
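# Usage note for BertForSequenceClassification above: with config.num_labels == 1 the labels
# are treated as regression targets (MSELoss); with num_labels > 1 they are class indices
# (CrossEntropyLoss). A minimal sketch, assuming a BertTokenizer instance `tokenizer`:
#   model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3, return_dict=True)
#   outputs = model(**tokenizer("some text", return_tensors="pt"), labels=torch.tensor([2]))
#   outputs.loss, outputs.logits   # cross-entropy loss and logits of shape (1, 3)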
1981
1982@add_start_docstrings(
1983"""Bert Model with a multiple choice classification head on top (a linear layer on top of
1984the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
1985BERT_START_DOCSTRING,
1986)
1987class BertForMultipleChoice(BertPreTrainedModel):
1988def __init__(self, config):
1989super().__init__(config)
1990
1991self.bert = BertModel(config)
1992self.dropout = nn.Dropout(config.hidden_dropout_prob)
1993self.classifier = nn.Linear(config.hidden_size, 1)
1994
1995self.init_weights()
1996
1997@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
1998@add_code_sample_docstrings(
1999tokenizer_class=_TOKENIZER_FOR_DOC,
2000checkpoint="bert-base-uncased",
2001output_type=MultipleChoiceModelOutput,
2002config_class=_CONFIG_FOR_DOC,
2003)
2004def forward(
2005self,
2006input_ids=None,
2007attention_mask=None,
2008token_type_ids=None,
2009position_ids=None,
2010head_mask=None,
2011inputs_embeds=None,
2012labels=None,
2013output_attentions=None,
2014output_hidden_states=None,
2015return_dict=None,
2016):
2017r"""
2018labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
2019Labels for computing the multiple choice classification loss.
2020Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
2021of the input tensors. (see `input_ids` above)
2022"""
2023return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2024num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
2025
2026input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
2027attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
2028token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
2029position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
2030inputs_embeds = (
2031inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
2032if inputs_embeds is not None
2033else None
2034)
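# Shape bookkeeping: each (batch_size, num_choices, seq_len) input has just been flattened
# to (batch_size * num_choices, seq_len) so BERT scores every choice as its own sequence;
# the per-choice logits are reshaped back to (batch_size, num_choices) further below.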
2035
2036outputs = self.bert(
2037input_ids,
2038attention_mask=attention_mask,
2039token_type_ids=token_type_ids,
2040position_ids=position_ids,
2041head_mask=head_mask,
2042inputs_embeds=inputs_embeds,
2043output_attentions=output_attentions,
2044output_hidden_states=output_hidden_states,
2045return_dict=return_dict,
2046)
2047
2048pooled_output = outputs[1]
2049
2050pooled_output = self.dropout(pooled_output)
2051logits = self.classifier(pooled_output)
2052reshaped_logits = logits.view(-1, num_choices)
2053
2054loss = None
2055if labels is not None:
2056loss_fct = CrossEntropyLoss()
2057loss = loss_fct(reshaped_logits, labels)
2058
2059if not return_dict:
2060output = (reshaped_logits,) + outputs[2:]
2061return ((loss,) + output) if loss is not None else output
2062
2063return MultipleChoiceModelOutput(
2064loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
2065)
2066
2067
2068@add_start_docstrings(
2069"""Bert Model with a token classification head on top (a linear layer on top of
2070the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
2071BERT_START_DOCSTRING,
2072)
2073class BertForTokenClassification(BertPreTrainedModel):
2074def __init__(self, config):
2075super().__init__(config)
2076self.num_labels = config.num_labels
2077
2078self.bert = BertModel(config)
2079self.dropout = nn.Dropout(config.hidden_dropout_prob)
2080self.classifier = nn.Linear(config.hidden_size, config.num_labels)
2081
2082self.init_weights()
2083
2084@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
2085@add_code_sample_docstrings(
2086tokenizer_class=_TOKENIZER_FOR_DOC,
2087checkpoint="bert-base-uncased",
2088output_type=TokenClassifierOutput,
2089config_class=_CONFIG_FOR_DOC,
2090)
2091def forward(
2092self,
2093input_ids=None,
2094attention_mask=None,
2095token_type_ids=None,
2096position_ids=None,
2097head_mask=None,
2098inputs_embeds=None,
2099labels=None,
2100output_attentions=None,
2101output_hidden_states=None,
2102return_dict=None,
2103):
2104r"""
2105labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
2106Labels for computing the token classification loss.
2107Indices should be in ``[0, ..., config.num_labels - 1]``.
2108"""
2109return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2110
2111outputs = self.bert(
2112input_ids,
2113attention_mask=attention_mask,
2114token_type_ids=token_type_ids,
2115position_ids=position_ids,
2116head_mask=head_mask,
2117inputs_embeds=inputs_embeds,
2118output_attentions=output_attentions,
2119output_hidden_states=output_hidden_states,
2120return_dict=return_dict,
2121)
2122
2123sequence_output = outputs[0]
2124
2125sequence_output = self.dropout(sequence_output)
2126logits = self.classifier(sequence_output)
2127
2128loss = None
2129if labels is not None:
2130loss_fct = CrossEntropyLoss()
2131# Only keep active parts of the loss
2132if attention_mask is not None:
2133active_loss = attention_mask.view(-1) == 1
2134active_logits = logits.view(-1, self.num_labels)
2135active_labels = torch.where(
2136active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
2137)
2138loss = loss_fct(active_logits, active_labels)
2139else:
2140loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
2141
2142if not return_dict:
2143output = (logits,) + outputs[2:]
2144return ((loss,) + output) if loss is not None else output
2145
2146return TokenClassifierOutput(
2147loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
2148)
2149
2150
2151@add_start_docstrings(
2152"""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
2153layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
2154BERT_START_DOCSTRING,
2155)
2156class BertForQuestionAnswering(BertPreTrainedModel):
2157def __init__(self, config):
2158super().__init__(config)
2159self.num_labels = config.num_labels
2160
2161self.bert = BertModel(config)
2162self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
2163
2164self.init_weights()
2165
2166@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
2167@add_code_sample_docstrings(
2168tokenizer_class=_TOKENIZER_FOR_DOC,
2169checkpoint="bert-base-uncased",
2170output_type=QuestionAnsweringModelOutput,
2171config_class=_CONFIG_FOR_DOC,
2172)
2173def forward(
2174self,
2175input_ids=None,
2176attention_mask=None,
2177token_type_ids=None,
2178position_ids=None,
2179head_mask=None,
2180inputs_embeds=None,
2181start_positions=None,
2182end_positions=None,
2183output_attentions=None,
2184output_hidden_states=None,
2185return_dict=None,
2186):
2187r"""
2188start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
2189Labels for position (index) of the start of the labelled span for computing the token classification loss.
2190Positions are clamped to the length of the sequence (`sequence_length`).
2191Position outside of the sequence are not taken into account for computing the loss.
2192end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
2193Labels for position (index) of the end of the labelled span for computing the token classification loss.
2194Positions are clamped to the length of the sequence (`sequence_length`).
2195Position outside of the sequence are not taken into account for computing the loss.
2196"""
2197return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2198
2199outputs = self.bert(
2200input_ids,
2201attention_mask=attention_mask,
2202token_type_ids=token_type_ids,
2203position_ids=position_ids,
2204head_mask=head_mask,
2205inputs_embeds=inputs_embeds,
2206output_attentions=output_attentions,
2207output_hidden_states=output_hidden_states,
2208return_dict=return_dict,
2209)
2210
2211sequence_output = outputs[0]
2212
2213logits = self.qa_outputs(sequence_output)
2214start_logits, end_logits = logits.split(1, dim=-1)
2215start_logits = start_logits.squeeze(-1)
2216end_logits = end_logits.squeeze(-1)
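# Shape bookkeeping: qa_outputs maps (batch_size, seq_len, hidden) to (batch_size, seq_len, 2)
# when num_labels == 2; split(1, dim=-1) yields two (batch_size, seq_len, 1) tensors and
# squeeze(-1) gives the (batch_size, seq_len) start/end logits used below.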
2217
2218total_loss = None
2219if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the position tensors may carry an extra dimension; squeeze it away
2221if len(start_positions.size()) > 1:
2222start_positions = start_positions.squeeze(-1)
2223if len(end_positions.size()) > 1:
2224end_positions = end_positions.squeeze(-1)
# Sometimes the start/end positions fall outside our model inputs; we ignore these terms
2226ignored_index = start_logits.size(1)
2227start_positions.clamp_(0, ignored_index)
2228end_positions.clamp_(0, ignored_index)
2229
2230loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
2231start_loss = loss_fct(start_logits, start_positions)
2232end_loss = loss_fct(end_logits, end_positions)
2233total_loss = (start_loss + end_loss) / 2
2234
2235if not return_dict:
2236output = (start_logits, end_logits) + outputs[2:]
2237return ((total_loss,) + output) if total_loss is not None else output
2238
2239return QuestionAnsweringModelOutput(
2240loss=total_loss,
2241start_logits=start_logits,
2242end_logits=end_logits,
2243hidden_states=outputs.hidden_states,
2244attentions=outputs.attentions,
2245)
2246