# coding=utf-8
# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 Flaubert model.
"""

import logging
import random

import tensorflow as tf

from .configuration_flaubert import FlaubertConfig
from .file_utils import add_start_docstrings
from .modeling_tf_utils import cast_bool_to_primitive, keras_serializable, shape_list
from .modeling_tf_xlm import (
    TFXLMForMultipleChoice,
    TFXLMForQuestionAnsweringSimple,
    TFXLMForSequenceClassification,
    TFXLMForTokenClassification,
    TFXLMMainLayer,
    TFXLMModel,
    TFXLMPredLayer,
    TFXLMWithLMHeadModel,
    get_masks,
)
from .tokenization_utils import BatchEncoding


logger = logging.getLogger(__name__)

TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    # See all Flaubert models at https://huggingface.co/models?filter=flaubert
]

FLAUBERT_START_DOCSTRING = r"""

    This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
    Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    Parameters:
        config (:class:`~transformers.FlaubertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

FLAUBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.
            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
            `What are attention masks? <../glossary.html#attention-mask>`__
        langs (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            A parallel sequence of tokens to be used to indicate the language of each token in the input.
            Indices are language ids which can be obtained from the language names by using two conversion mappings
            provided in the configuration of the model (only provided for multilingual models).
            More precisely, the `language name -> language id` mapping is in `model.config.lang2id` (dict str -> int) and
            the `language id -> language name` mapping is in `model.config.id2lang` (dict int -> str).
            See usage examples detailed in the `multilingual documentation <https://huggingface.co/transformers/multilingual.html>`__.
        token_type_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.
            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
            `What are position IDs? <../glossary.html#position-ids>`_
        lengths (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Length of each sentence that can be used to avoid performing attention on padding token indices.
            You can also use `attention_mask` for the same result (see above), kept here for compatibility.
            Indices selected in ``[0, ..., input_ids.size(-1)]``.
        cache (:obj:`Dict[str, tf.Tensor]`, `optional`, defaults to :obj:`None`):
            Dictionary with ``tf.Tensor`` that contains pre-computed
            hidden-states (key and values in the attention blocks) as computed by the model
            (see `cache` output below). Can be used to speed up sequential decoding.
            The dictionary object will be modified in-place during the forward pass to add newly computed hidden-states.
        head_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attention tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
"""


@add_start_docstrings(
    "The bare Flaubert Model transformer outputting raw hidden-states without any specific head on top.",
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertModel(TFXLMModel):
    config_class = FlaubertConfig

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")
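

# A minimal usage sketch for the bare model. The checkpoint name below is an
# assumption (any TF-compatible Flaubert checkpoint from the model hub works):
#
#     from transformers import FlaubertTokenizer, TFFlaubertModel
#
#     tokenizer = FlaubertTokenizer.from_pretrained("jplu/tf-flaubert-base-cased")
#     model = TFFlaubertModel.from_pretrained("jplu/tf-flaubert-base-cased")
#     inputs = tokenizer("Le chat mange une pomme.", return_tensors="tf")
#     last_hidden_state = model(inputs)[0]  # (batch_size, sequence_length, hidden_size)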


@keras_serializable
class TFFlaubertMainLayer(TFXLMMainLayer):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # Probability of skipping an entire transformer layer during training (LayerDrop).
        self.layerdrop = getattr(config, "layerdrop", 0.0)
        # Flaubert can apply layer norm before (pre-norm) instead of after each sub-layer.
        self.pre_norm = getattr(config, "pre_norm", False)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

    def call(
        self,
        inputs,
        attention_mask=None,
        langs=None,
        token_type_ids=None,
        position_ids=None,
        lengths=None,
        cache=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        # removed: src_enc=None, src_len=None
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            langs = inputs[2] if len(inputs) > 2 else langs
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            lengths = inputs[5] if len(inputs) > 5 else lengths
            cache = inputs[6] if len(inputs) > 6 else cache
            head_mask = inputs[7] if len(inputs) > 7 else head_mask
            inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
            output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
            output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
            assert len(inputs) <= 11, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            langs = inputs.get("langs", langs)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            lengths = inputs.get("lengths", lengths)
            cache = inputs.get("cache", cache)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            assert len(inputs) <= 11, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            bs, slen = shape_list(input_ids)
        elif inputs_embeds is not None:
            bs, slen = shape_list(inputs_embeds)[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if lengths is None:
            if input_ids is not None:
                lengths = tf.reduce_sum(tf.cast(tf.not_equal(input_ids, self.pad_index), dtype=tf.int32), axis=1)
            else:
                lengths = tf.convert_to_tensor([slen] * bs, tf.int32)
        # mask = input_ids != self.pad_index

        # check inputs
        # assert shape_list(lengths)[0] == bs
        tf.debugging.assert_equal(shape_list(lengths)[0], bs)
        # assert lengths.max().item() <= slen
        # input_ids = input_ids.transpose(0, 1)  # batch size as dimension 0
        # assert (src_enc is None) == (src_len is None)
        # if src_enc is not None:
        #     assert self.is_decoder
        #     assert src_enc.size(0) == bs

        # generate masks
        mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
        # if self.is_decoder and src_enc is not None:
        #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

        # position_ids
        if position_ids is None:
            position_ids = tf.expand_dims(tf.range(slen), axis=0)
        else:
            # assert shape_list(position_ids) == [bs, slen]  # (slen, bs)
            tf.debugging.assert_equal(shape_list(position_ids), [bs, slen])
            # position_ids = position_ids.transpose(0, 1)

        # langs
        if langs is not None:
            # assert shape_list(langs) == [bs, slen]  # (slen, bs)
            tf.debugging.assert_equal(shape_list(langs), [bs, slen])
            # langs = langs.transpose(0, 1)

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x qlen x klen]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.n_layers

        # do not recompute cached elements
        if cache is not None and input_ids is not None:
            _slen = slen - cache["slen"]
            input_ids = input_ids[:, -_slen:]
            position_ids = position_ids[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        tensor = inputs_embeds + self.position_embeddings(position_ids)
        if langs is not None and self.use_lang_emb:
            tensor = tensor + self.lang_embeddings(langs)
        if token_type_ids is not None:
            tensor = tensor + self.embeddings(token_type_ids)
        tensor = self.layer_norm_emb(tensor)
        tensor = self.dropout(tensor, training=training)
        tensor = tensor * mask[..., tf.newaxis]

        # transformer layers
        hidden_states = ()
        attentions = ()
        for i in range(self.n_layers):
            # LayerDrop: during training, skip the entire layer with probability self.layerdrop
            dropout_probability = random.uniform(0, 1)
            if training and (dropout_probability < self.layerdrop):
                continue

            if output_hidden_states:
                hidden_states = hidden_states + (tensor,)

            # self attention
            if not self.pre_norm:
                attn_outputs = self.attentions[i](
                    [tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training
                )
                attn = attn_outputs[0]
                if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
                    attentions = attentions + (attn_outputs[1],)
                attn = self.dropout(attn, training=training)
                tensor = tensor + attn
                tensor = self.layer_norm1[i](tensor)
            else:
                tensor_normalized = self.layer_norm1[i](tensor)
                # pass output_attentions here too, so both branches feed the
                # attention layer the same input list
                attn_outputs = self.attentions[i](
                    [tensor_normalized, attn_mask, None, cache, head_mask[i], output_attentions], training=training
                )
                attn = attn_outputs[0]
                if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
                    attentions = attentions + (attn_outputs[1],)
                attn = self.dropout(attn, training=training)
                tensor = tensor + attn

            # encoder attention (for decoder only)
            # if self.is_decoder and src_enc is not None:
            #     attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
            #     attn = F.dropout(attn, p=self.dropout, training=self.training)
            #     tensor = tensor + attn
            #     tensor = self.layer_norm15[i](tensor)

            # FFN
            if not self.pre_norm:
                tensor = tensor + self.ffns[i](tensor)
                tensor = self.layer_norm2[i](tensor)
            else:
                tensor_normalized = self.layer_norm2[i](tensor)
                tensor = tensor + self.ffns[i](tensor_normalized)

            tensor = tensor * mask[..., tf.newaxis]

        # Add last hidden state
        if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
            hidden_states = hidden_states + (tensor,)

        # update cache length (`.size(1)` is PyTorch API; use shape_list for TF tensors)
        if cache is not None:
            cache["slen"] += shape_list(tensor)[1]

        # move back sequence length to dimension 0
        # tensor = tensor.transpose(0, 1)

        outputs = (tensor,)
        if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
            outputs = outputs + (hidden_states,)
        if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
            outputs = outputs + (attentions,)
        return outputs  # outputs, (hidden_states), (attentions)
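

# A hedged sketch of how the in-place `cache` above is meant to be reused across
# incremental decoding steps (illustration only; initializing the dict with
# {"slen": 0} is an assumption about the expected cache layout):
#
#     cache = {"slen": 0}
#     outputs = model([input_ids, None, None, None, None, None, cache])
#     # on the next call, pass the full extended sequence: the layer slices off
#     # the already-cached prefix itself and only recomputes the new positions
#     outputs = model([extended_input_ids, None, None, None, None, None, cache])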


@add_start_docstrings(
    """The Flaubert Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
    config_class = FlaubertConfig

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")
        self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj")
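

# A minimal usage sketch for the LM-head model (the checkpoint name is an
# assumption, as above):
#
#     from transformers import FlaubertTokenizer, TFFlaubertWithLMHeadModel
#
#     tokenizer = FlaubertTokenizer.from_pretrained("jplu/tf-flaubert-base-cased")
#     model = TFFlaubertWithLMHeadModel.from_pretrained("jplu/tf-flaubert-base-cased")
#     inputs = tokenizer("Le chat mange une pomme.", return_tensors="tf")
#     logits = model(inputs)[0]  # (batch_size, sequence_length, vocab_size)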


@add_start_docstrings(
    """Flaubert Model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForSequenceClassification(TFXLMForSequenceClassification):
    config_class = FlaubertConfig

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")
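

# The task-specific heads in this file only swap in the Flaubert main layer;
# usage mirrors their XLM parents. A hedged sketch for sequence classification
# (checkpoint name and label count are assumptions):
#
#     from transformers import FlaubertConfig, TFFlaubertForSequenceClassification
#
#     config = FlaubertConfig.from_pretrained("jplu/tf-flaubert-base-cased", num_labels=2)
#     model = TFFlaubertForSequenceClassification.from_pretrained(
#         "jplu/tf-flaubert-base-cased", config=config
#     )
#     logits = model(inputs)[0]  # (batch_size, num_labels)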


@add_start_docstrings(
    """Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForQuestionAnsweringSimple(TFXLMForQuestionAnsweringSimple):
    config_class = FlaubertConfig

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")


@add_start_docstrings(
    """Flaubert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForTokenClassification(TFXLMForTokenClassification):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")


@add_start_docstrings(
    """Flaubert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    FLAUBERT_START_DOCSTRING,
)
class TFFlaubertForMultipleChoice(TFXLMForMultipleChoice):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFFlaubertMainLayer(config, name="transformer")