CSS-LM

modeling_tf_flaubert.py · 385 lines · 18.3 KB
# coding=utf-8
# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 Flaubert model.
"""

import logging
import random

import tensorflow as tf

from .configuration_flaubert import FlaubertConfig
from .file_utils import add_start_docstrings
from .modeling_tf_utils import cast_bool_to_primitive, keras_serializable, shape_list
from .modeling_tf_xlm import (
    TFXLMForMultipleChoice,
    TFXLMForQuestionAnsweringSimple,
    TFXLMForSequenceClassification,
    TFXLMForTokenClassification,
    TFXLMMainLayer,
    TFXLMModel,
    TFXLMPredLayer,
    TFXLMWithLMHeadModel,
    get_masks,
)
from .tokenization_utils import BatchEncoding


logger = logging.getLogger(__name__)

TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    # See all Flaubert models at https://huggingface.co/models?filter=flaubert
]

FLAUBERT_START_DOCSTRING = r"""

    This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
    Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    Parameters:
        config (:class:`~transformers.FlaubertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

FLAUBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`transformers.FlaubertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.
            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
            `What are attention masks? <../glossary.html#attention-mask>`__
        langs (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            A parallel sequence of tokens to be used to indicate the language of each token in the input.
            Indices are language ids which can be obtained from the language names by using two conversion mappings
            provided in the configuration of the model (only provided for multilingual models).
            More precisely, the `language name -> language id` mapping is in `model.config.lang2id` (dict str -> int) and
            the `language id -> language name` mapping is `model.config.id2lang` (dict int -> str).
            See usage examples detailed in the `multilingual documentation <https://huggingface.co/transformers/multilingual.html>`__.
        token_type_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.
            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
            `What are position IDs? <../glossary.html#position-ids>`_
        lengths (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Length of each sentence that can be used to avoid performing attention on padding token indices.
            You can also use `attention_mask` for the same result (see above), kept here for compatibility.
            Indices selected in ``[0, ..., input_ids.size(-1)]``.
        cache (:obj:`Dict[str, tf.Tensor]`, `optional`, defaults to :obj:`None`):
            Dictionary with ``tf.Tensor`` that contains pre-computed
            hidden-states (keys and values in the attention blocks) as computed by the model
            (see `cache` output below). Can be used to speed up sequential decoding.
            The dictionary object will be modified in-place during the forward pass to add newly computed hidden-states.
        head_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attention tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
"""


@add_start_docstrings(
    "The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.",
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertModel(TFXLMModel):
    config_class = FlaubertConfig

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")


@keras_serializable
class TFFlaubertMainLayer(TFXLMMainLayer):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.layerdrop = getattr(config, "layerdrop", 0.0)
        self.pre_norm = getattr(config, "pre_norm", False)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
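
    # Note (illustrative, not part of the original file): on top of the XLM main layer, Flaubert
    # adds stochastic layer dropping ("LayerDrop") and optional pre-layer-norm, both read from the
    # config. A hedged sketch of enabling them, assuming the usual FlaubertConfig arguments:
    #
    #     from transformers import FlaubertConfig
    #
    #     config = FlaubertConfig(layerdrop=0.1, pre_norm=True)  # example values only
    #     model = TFFlaubertModel(config)                        # randomly initialized weights
    #
    # `layerdrop` only takes effect when `call` is invoked with `training=True` (see the loop below).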

    def call(
        self,
        inputs,
        attention_mask=None,
        langs=None,
        token_type_ids=None,
        position_ids=None,
        lengths=None,
        cache=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        # removed: src_enc=None, src_len=None
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            langs = inputs[2] if len(inputs) > 2 else langs
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            lengths = inputs[5] if len(inputs) > 5 else lengths
            cache = inputs[6] if len(inputs) > 6 else cache
            head_mask = inputs[7] if len(inputs) > 7 else head_mask
            inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
            output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
            output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
            assert len(inputs) <= 11, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            langs = inputs.get("langs", langs)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            lengths = inputs.get("lengths", lengths)
            cache = inputs.get("cache", cache)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            assert len(inputs) <= 11, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            bs, slen = shape_list(input_ids)
        elif inputs_embeds is not None:
            bs, slen = shape_list(inputs_embeds)[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if lengths is None:
            if input_ids is not None:
                lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
            else:
                lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
        # mask = input_ids != self.pad_index

        # check inputs
        # assert shape_list(lengths)[0] == bs
        tf.debugging.assert_equal(shape_list(lengths)[0], bs)
        # assert lengths.max().item() <= slen
        # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
        # assert (src_enc is None) == (src_len is None)
        # if src_enc is not None:
        #     assert self.is_decoder
        #     assert src_enc.size(0) == bs

        # generate masks
        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
        # if self.is_decoder and src_enc is not None:
        #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

        # position_ids
        if position_ids is None:
            position_ids = tf.expand_dims(tf.range(slen), axis=0)
        else:
            # assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
            tf.debugging.assert_equal(shape_list(position_ids), [bs, slen])
            # position_ids = position_ids.transpose(0, 1)

        # langs
        if langs is not None:
            # assert shape_list(langs) == [bs, slen]  # (slen, bs)
            tf.debugging.assert_equal(shape_list(langs), [bs, slen])
            # langs = langs.transpose(0, 1)

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.n_layers

        # do not recompute cached elements
        if cache is not None and input_ids is not None:
            _slen = slen - cache["slen"]
            input_ids = input_ids[:, -_slen:]
            position_ids = position_ids[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        tensor = inputs_embeds + self.position_embeddings(position_ids)
        if langs is not None and self.use_lang_emb:
            tensor = tensor + self.lang_embeddings(langs)
        if token_type_ids is not None:
            tensor = tensor + self.embeddings(token_type_ids)
        tensor = self.layer_norm_emb(tensor)
        tensor = self.dropout(tensor, training=training)
        tensor = tensor * mask[..., tf.newaxis]

        # transformer layers
        hidden_states = ()
        attentions = ()
        for i in range(self.n_layers):
            # LayerDrop
            dropout_probability = random.uniform(0, 1)
            if training and (dropout_probability < self.layerdrop):
                continue

            if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
                hidden_states = hidden_states + (tensor,)

            # self attention
            if not self.pre_norm:
                attn_outputs = self.attentions[i](
                    [tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training
                )
                attn = attn_outputs[0]
                if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
                    attentions = attentions + (attn_outputs[1],)
                attn = self.dropout(attn, training=training)
                tensor = tensor + attn
                tensor = self.layer_norm1[i](tensor)
            else:
                tensor_normalized = self.layer_norm1[i](tensor)
                attn_outputs = self.attentions[i](
                    [tensor_normalized, attn_mask, None, cache, head_mask[i], output_attentions], training=training
                )
                attn = attn_outputs[0]
                if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
                    attentions = attentions + (attn_outputs[1],)
                attn = self.dropout(attn, training=training)
                tensor = tensor + attn

            # encoder attention (for decoder only)
            # if self.is_decoder and src_enc is not None:
            #     attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
            #     attn = F.dropout(attn, p=self.dropout, training=self.training)
            #     tensor = tensor + attn
            #     tensor = self.layer_norm15[i](tensor)

            # FFN
            if not self.pre_norm:
                tensor = tensor + self.ffns[i](tensor)
                tensor = self.layer_norm2[i](tensor)
            else:
                tensor_normalized = self.layer_norm2[i](tensor)
                tensor = tensor + self.ffns[i](tensor_normalized)

            tensor = tensor * mask[..., tf.newaxis]

        # Add last hidden state
        if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
            hidden_states = hidden_states + (tensor,)

        # update cache length
        if cache is not None:
            cache["slen"] += shape_list(tensor)[1]

        # move back sequence length to dimension 0
        # tensor = tensor.transpose(0, 1)

        outputs = (tensor,)
        if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
            outputs = outputs + (hidden_states,)
        if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
            outputs = outputs + (attentions,)
        return outputs  # outputs, (hidden_states), (attentions)


@add_start_docstrings(
    """The Flaubert Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
    config_class = FlaubertConfig

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")
        self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
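
# Illustrative sketch (not part of the original file): token-level prediction scores from the LM
# head. The checkpoint name is an assumption.
#
#     from transformers import FlaubertTokenizer, TFFlaubertWithLMHeadModel
#
#     tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")  # assumed checkpoint
#     model = TFFlaubertWithLMHeadModel.from_pretrained("flaubert/flaubert_base_cased")
#
#     inputs = tokenizer("Le camembert est délicieux.", return_tensors="tf")
#     logits = model(inputs)[0]  # shape (batch_size, sequence_length, vocab_size)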


@add_start_docstrings(
    """Flaubert Model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForSequenceClassification(TFXLMForSequenceClassification):
    config_class = FlaubertConfig

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")
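
# Illustrative sketch (not part of the original file): the sequence-classification head produces
# one logit vector per example. `num_labels` and the checkpoint name are assumptions.
#
#     from transformers import FlaubertTokenizer, TFFlaubertForSequenceClassification
#
#     tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")  # assumed checkpoint
#     model = TFFlaubertForSequenceClassification.from_pretrained("flaubert/flaubert_base_cased", num_labels=2)
#
#     inputs = tokenizer("Ce film est excellent.", return_tensors="tf")
#     logits = model(inputs)[0]  # shape (batch_size, num_labels)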


@add_start_docstrings(
    """Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForQuestionAnsweringSimple(TFXLMForQuestionAnsweringSimple):
    config_class = FlaubertConfig

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")
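
# Illustrative sketch (not part of the original file): the simple QA head returns start and end
# logits over the input sequence. The checkpoint name is an assumption.
#
#     from transformers import FlaubertTokenizer, TFFlaubertForQuestionAnsweringSimple
#
#     tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")  # assumed checkpoint
#     model = TFFlaubertForQuestionAnsweringSimple.from_pretrained("flaubert/flaubert_base_cased")
#
#     question, context = "Qui a écrit Les Misérables ?", "Victor Hugo a écrit Les Misérables."
#     inputs = tokenizer(question, context, return_tensors="tf")
#     start_logits, end_logits = model(inputs)[:2]  # each of shape (batch_size, sequence_length)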


@add_start_docstrings(
    """Flaubert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForTokenClassification(TFXLMForTokenClassification):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")


@add_start_docstrings(
    """Flaubert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForMultipleChoice(TFXLMForMultipleChoice):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")
