# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15"""
16RetriBERT model
17"""
18
19
import logging
import math

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from .configuration_retribert import RetriBertConfig
from .file_utils import add_start_docstrings
from .modeling_bert import BertLayerNorm, BertModel
from .modeling_utils import PreTrainedModel


logger = logging.getLogger(__name__)

RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "yjernite/retribert-base-uncased",
    # See all RetriBert models at https://huggingface.co/models?filter=retribert
]


# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class RetriBertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = RetriBertConfig
    load_tf_weights = None
    base_model_prefix = "retribert"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

RETRIBERT_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.RetriBertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

@add_start_docstrings(
    """Bert-based model to embed queries or documents for document retrieval. """, RETRIBERT_START_DOCSTRING,
)
class RetriBertModel(RetriBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.projection_dim = config.projection_dim

        self.bert_query = BertModel(config)
        # when config.share_encoders is True, the query encoder is reused for documents
        self.bert_doc = None if config.share_encoders else BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.project_query = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
        self.project_doc = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        self.ce_loss = nn.CrossEntropyLoss(reduction="mean")

        self.init_weights()

    def embed_sentences_checkpointed(
        self, input_ids, attention_mask, sent_encoder, checkpoint_batch_size=-1,
    ):
        # reproduces BERT forward pass with checkpointing
        if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
            return sent_encoder(input_ids, attention_mask=attention_mask)[1]
        else:
            # prepare implicit variables
            device = input_ids.device
            input_shape = input_ids.size()
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            head_mask = [None] * sent_encoder.config.num_hidden_layers
            extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
                attention_mask, input_shape, device
            )

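            # NOTE: the mask is extended by hand here because the checkpointed function below calls
            # sent_encoder.encoder directly, skipping the BertModel.forward entry point that would
            # normally build the extended mask.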
            # define function for checkpointing
            def partial_encode(*inputs):
                encoder_outputs = sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
                sequence_output = encoder_outputs[0]
                pooled_output = sent_encoder.pooler(sequence_output)
                return pooled_output

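            # torch.utils.checkpoint.checkpoint (used in the loop below) does not keep the
            # activations computed inside partial_encode; they are recomputed during the backward
            # pass, trading extra compute for a large reduction in peak memory.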
            # run embedding layer on everything at once
            embedding_output = sent_encoder.embeddings(
                input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
            )
            # run encoding and pooling on one mini-batch at a time
            pooled_output_list = []
            for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
                b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
                b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
                pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
                pooled_output_list.append(pooled_output)
            return torch.cat(pooled_output_list, dim=0)

    def embed_questions(
        self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
    ):
        q_reps = self.embed_sentences_checkpointed(input_ids, attention_mask, self.bert_query, checkpoint_batch_size,)
        return self.project_query(q_reps)

    def embed_answers(
        self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
    ):
        a_reps = self.embed_sentences_checkpointed(
            input_ids,
            attention_mask,
            self.bert_query if self.bert_doc is None else self.bert_doc,
            checkpoint_batch_size,
        )
        return self.project_doc(a_reps)

    def forward(
        self, input_ids_query, attention_mask_query, input_ids_doc, attention_mask_doc, checkpoint_batch_size=-1
    ):
        r"""
        Args:
            input_ids_query (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary for the queries in a batch.

                Indices can be obtained using :class:`transformers.RetriBertTokenizer`.
                See :func:`transformers.PreTrainedTokenizer.encode` and
                :func:`transformers.PreTrainedTokenizer.__call__` for details.

                `What are input IDs? <../glossary.html#input-ids>`__
            attention_mask_query (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
                Mask to avoid performing attention on padding token indices of the queries.
                Mask values selected in ``[0, 1]``:
                ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

                `What are attention masks? <../glossary.html#attention-mask>`__
            input_ids_doc (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary for the documents in a batch.
            attention_mask_doc (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
                Mask to avoid performing attention on padding token indices of the documents.
            checkpoint_batch_size (:obj:`int`, `optional`, defaults to :obj:`-1`):
                If greater than 0, uses gradient checkpointing to only compute sequence representations on
                :obj:`checkpoint_batch_size` examples at a time on the GPU. All query representations are still
                compared to all document representations in the batch.

        Return:
            :obj:`torch.FloatTensor`: The bidirectional cross-entropy loss obtained while trying to match each query
            to its corresponding document and each document to its corresponding query in the batch.
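
        Example (a minimal sketch; assumes :class:`~transformers.RetriBertTokenizer` and the checkpoint
        from the archive list above are available)::

            from transformers import RetriBertModel, RetriBertTokenizer

            tokenizer = RetriBertTokenizer.from_pretrained("yjernite/retribert-base-uncased")
            model = RetriBertModel.from_pretrained("yjernite/retribert-base-uncased")

            queries = tokenizer(["when was the eiffel tower built?"], padding=True, return_tensors="pt")
            docs = tokenizer(["the eiffel tower was built in 1889"], padding=True, return_tensors="pt")

            # training: in-batch contrastive loss matching queries to documents and vice versa
            loss = model(
                queries["input_ids"], queries["attention_mask"], docs["input_ids"], docs["attention_mask"]
            )

            # inference: embed queries and documents separately and score by dot product
            q_reps = model.embed_questions(queries["input_ids"], queries["attention_mask"])
            a_reps = model.embed_answers(docs["input_ids"], docs["attention_mask"])
            scores = q_reps @ a_reps.t()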
177"""
        device = input_ids_query.device
        q_reps = self.embed_questions(input_ids_query, attention_mask_query, checkpoint_batch_size)
        a_reps = self.embed_answers(input_ids_doc, attention_mask_doc, checkpoint_batch_size)
        # (n_queries, n_docs) matrix of dot-product scores between projected representations
        compare_scores = torch.mm(q_reps, a_reps.t())
        # each query should match the document at the same position in the batch, and vice versa
        loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
        loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
        loss = (loss_qa + loss_aq) / 2
        return loss