CSS-LM / modeling_retribert.py (185 lines, 8.5 KB)
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RetriBERT model
"""


import logging
import math

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from .configuration_retribert import RetriBertConfig
from .file_utils import add_start_docstrings
from .modeling_bert import BertLayerNorm, BertModel
from .modeling_utils import PreTrainedModel


logger = logging.getLogger(__name__)

RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "yjernite/retribert-base-uncased",
    # See all RetriBert models at https://huggingface.co/models?filter=retribert
]


# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class RetriBertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = RetriBertConfig
    load_tf_weights = None
    base_model_prefix = "retribert"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


RETRIBERT_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.RetriBertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""


@add_start_docstrings(
    """Bert-based model to embed queries or documents for document retrieval. """, RETRIBERT_START_DOCSTRING,
)
class RetriBertModel(RetriBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.projection_dim = config.projection_dim

        self.bert_query = BertModel(config)
        self.bert_doc = None if config.share_encoders else BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.project_query = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
        self.project_doc = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        self.ce_loss = nn.CrossEntropyLoss(reduction="mean")

        self.init_weights()

    def embed_sentences_checkpointed(
        self, input_ids, attention_mask, sent_encoder, checkpoint_batch_size=-1,
    ):
        # reproduces BERT forward pass with checkpointing
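        # (gradient checkpointing trades compute for memory: activations inside each
        # checkpointed segment are not kept for the backward pass and are recomputed
        # when gradients are needed)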
        if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
            return sent_encoder(input_ids, attention_mask=attention_mask)[1]
        else:
            # prepare implicit variables
            device = input_ids.device
            input_shape = input_ids.size()
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            head_mask = [None] * sent_encoder.config.num_hidden_layers
            extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
                attention_mask, input_shape, device
            )

            # define function for checkpointing
            def partial_encode(*inputs):
                encoder_outputs = sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
                sequence_output = encoder_outputs[0]
                pooled_output = sent_encoder.pooler(sequence_output)
                return pooled_output

            # run embedding layer on everything at once
            embedding_output = sent_encoder.embeddings(
                input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
            )
            # run encoding and pooling on one mini-batch at a time
            pooled_output_list = []
            for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
                b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
                b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
                pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
                pooled_output_list.append(pooled_output)
            return torch.cat(pooled_output_list, dim=0)

    def embed_questions(
        self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
    ):
        q_reps = self.embed_sentences_checkpointed(input_ids, attention_mask, self.bert_query, checkpoint_batch_size,)
        return self.project_query(q_reps)

    def embed_answers(
        self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
    ):
        a_reps = self.embed_sentences_checkpointed(
            input_ids,
            attention_mask,
            self.bert_query if self.bert_doc is None else self.bert_doc,
            checkpoint_batch_size,
        )
        return self.project_doc(a_reps)

    def forward(
        self, input_ids_query, attention_mask_query, input_ids_doc, attention_mask_doc, checkpoint_batch_size=-1
    ):
        r"""
    Args:
        input_ids_query (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary for the queries in a batch.

            Indices can be obtained using :class:`transformers.RetriBertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask_query (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices of the queries.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        input_ids_doc (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary for the documents in a batch.
        attention_mask_doc (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices of the documents.

        checkpoint_batch_size (:obj:`int`, `optional`, defaults to :obj:`-1`):
            If greater than 0, uses gradient checkpointing to only compute sequence representations on
            checkpoint_batch_size examples at a time on the GPU. All query representations are still
            compared to all document representations in the batch.

    Return:
        :obj:`torch.FloatTensor`: The bi-directional cross-entropy loss obtained while trying to match each query to
        its corresponding document and each document to its corresponding query in the batch.
        """
        device = input_ids_query.device
        q_reps = self.embed_questions(input_ids_query, attention_mask_query, checkpoint_batch_size)
        a_reps = self.embed_answers(input_ids_doc, attention_mask_doc, checkpoint_batch_size)
        compare_scores = torch.mm(q_reps, a_reps.t())
        loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
        loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
        loss = (loss_qa + loss_aq) / 2
        return loss
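

# Illustrative usage sketch (not part of the original module), assuming the public
# `transformers` package exports RetriBertModel and RetriBertTokenizer and that the
# "yjernite/retribert-base-uncased" checkpoint listed above is available:
#
#     import torch
#     from transformers import RetriBertModel, RetriBertTokenizer
#
#     tokenizer = RetriBertTokenizer.from_pretrained("yjernite/retribert-base-uncased")
#     model = RetriBertModel.from_pretrained("yjernite/retribert-base-uncased")
#
#     queries = ["what is a neural network?", "how do planes fly?"]
#     docs = ["A neural network is a model built from layered linear maps and nonlinearities.",
#             "Wings generate lift by deflecting air downward."]
#     q = tokenizer(queries, padding=True, return_tensors="pt")
#     d = tokenizer(docs, padding=True, return_tensors="pt")
#
#     with torch.no_grad():
#         loss = model(q["input_ids"], q["attention_mask"], d["input_ids"], d["attention_mask"])
#     # `loss` is the scalar bi-directional in-batch cross-entropy described in forward()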