import logging

import tensorflow as tf

from .configuration_electra import ElectraConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
from .modeling_tf_utils import (
    TFMaskedLanguageModelingLoss,
    TFQuestionAnsweringLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding


logger = logging.getLogger(__name__)

_TOKENIZER_FOR_DOC = "ElectraTokenizer"

TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/electra-small-generator",
    "google/electra-base-generator",
    "google/electra-large-generator",
    "google/electra-small-discriminator",
    "google/electra-base-discriminator",
    "google/electra-large-discriminator",
    # See all ELECTRA models at https://huggingface.co/models?filter=electra
]


class TFElectraEmbeddings(tf.keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        self.initializer_range = config.initializer_range

        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            config.embedding_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="position_embeddings",
        )
        self.token_type_embeddings = tf.keras.layers.Embedding(
            config.type_vocab_size,
            config.embedding_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="token_type_embeddings",
        )

        # self.LayerNorm is not snake-cased to stick with the TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def build(self, input_shape):
        """Build the shared word embedding layer."""
        with tf.name_scope("word_embeddings"):
            # Create and initialize weights. The random normal initializer was chosen
            # arbitrarily, and works well.
            self.word_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )
        super().build(input_shape)

    def call(self, inputs, mode="embedding", training=False):
        """Get token embeddings of inputs.

        Args:
            inputs: list of four tensors with shape [batch_size, length]:
                (input_ids, position_ids, token_type_ids, inputs_embeds)
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) if mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(inputs, training=training)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, inputs, training=False):
        """Applies embedding based on inputs tensor."""
        input_ids, position_ids, token_type_ids, inputs_embeds = inputs

        if input_ids is not None:
            input_shape = shape_list(input_ids)
        else:
            input_shape = shape_list(inputs_embeds)[:-1]

        seq_length = input_shape[1]
        if position_ids is None:
            position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        if inputs_embeds is None:
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)
        return embeddings

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.

        Args:
            inputs: A float32 tensor with shape [batch_size, length, embedding_size]
        Returns:
            float32 tensor with shape [batch_size, length, vocab_size].
        """
        batch_size = shape_list(inputs)[0]
        length = shape_list(inputs)[1]

        x = tf.reshape(inputs, [-1, self.embedding_size])
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, [batch_size, length, self.vocab_size])


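# Usage sketch (illustrative; `config` stands for any ElectraConfig instance): the same layer is
# used in two modes. With mode="embedding" it maps token ids to vectors; with mode="linear" it
# reuses the word embedding matrix (weight tying) to project hidden states back onto the vocabulary,
# mirroring how TFElectraMainLayer and TFElectraMaskedLMHead call it below.
#
#     emb = TFElectraEmbeddings(config, name="embeddings")
#     ids = tf.constant([[31, 51, 99]])                          # (batch_size=1, length=3)
#     vectors = emb([ids, None, None, None], mode="embedding")   # (1, 3, embedding_size)
#     logits = emb(vectors, mode="linear")                       # (1, 3, vocab_size)

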
class TFElectraDiscriminatorPredictions(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(config.hidden_size, name="dense")
        self.dense_prediction = tf.keras.layers.Dense(1, name="dense_prediction")
        self.config = config

    def call(self, discriminator_hidden_states, training=False):
        hidden_states = self.dense(discriminator_hidden_states)
        hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
        logits = tf.squeeze(self.dense_prediction(hidden_states))

        return logits


class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dense = tf.keras.layers.Dense(config.embedding_size, name="dense")

    def call(self, generator_hidden_states, training=False):
        hidden_states = self.dense(generator_hidden_states)
        hidden_states = ACT2FN["gelu"](hidden_states)
        hidden_states = self.LayerNorm(hidden_states)

        return hidden_states


class TFElectraPreTrainedModel(TFBertPreTrainedModel):

    config_class = ElectraConfig
    base_model_prefix = "electra"

    def get_extended_attention_mask(self, attention_mask, input_shape):
        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        return extended_attention_mask

    def get_head_mask(self, head_mask):
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        return head_mask


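# Worked example (illustrative) of what get_extended_attention_mask computes: a padding mask of
# shape (batch_size, seq_len) becomes an additive bias of shape (batch_size, 1, 1, seq_len) that
# broadcasts over heads and query positions.
#
#     mask = tf.constant([[1, 1, 1, 0]])  # last position is padding
#     bias = (1.0 - tf.cast(mask[:, tf.newaxis, tf.newaxis, :], tf.float32)) * -10000.0
#     # bias is 0.0 for kept positions and -10000.0 for the padded one; added to the raw attention
#     # scores before the softmax, it effectively removes the padded position from attention.

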
@keras_serializable
class TFElectraMainLayer(TFElectraPreTrainedModel):

    config_class = ElectraConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.embeddings = TFElectraEmbeddings(config, name="embeddings")

        if config.embedding_size != config.hidden_size:
            self.embeddings_project = tf.keras.layers.Dense(config.hidden_size, name="embeddings_project")
        self.encoder = TFBertEncoder(config, name="encoder")
        self.config = config

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
        self.embeddings.vocab_size = value.shape[0]

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.

        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
            assert len(inputs) <= 8, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            assert len(inputs) <= 8, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        head_mask = self.get_head_mask(head_mask)

        hidden_states = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)

        if hasattr(self, "embeddings_project"):
            hidden_states = self.embeddings_project(hidden_states, training=training)

        hidden_states = self.encoder(
            [hidden_states, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
            training=training,
        )

        return hidden_states


ELECTRA_START_DOCSTRING = r"""
    This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
    Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    .. note::

        TF 2.0 models accept two formats as inputs:

        - having all inputs as keyword arguments (like PyTorch models), or
        - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires
        having all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.ElectraConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

ELECTRA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.ElectraTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, embedding_dim)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate the dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attention tensors of all attention layers are returned. See ``attentions`` under
            returned tensors for more detail.
"""


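# Usage sketch (illustrative) of the three input formats described in ELECTRA_START_DOCSTRING,
# shown with TFElectraModel and the "google/electra-small-discriminator" checkpoint used in the
# docstrings below:
#
#     from transformers import ElectraTokenizer, TFElectraModel
#
#     tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
#     model = TFElectraModel.from_pretrained("google/electra-small-discriminator")
#     enc = tokenizer("ELECTRA detects replaced tokens", return_tensors="tf")
#     input_ids, attention_mask = enc["input_ids"], enc["attention_mask"]
#
#     outputs = model(input_ids)                                                    # single tensor
#     outputs = model([input_ids, attention_mask])                                  # list, docstring order
#     outputs = model({"input_ids": input_ids, "attention_mask": attention_mask})   # dict
#     last_hidden_state = outputs[0]

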
@add_start_docstrings(
    "The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to "
    "the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the "
    "hidden size and embedding size are different. "
    "Both the generator and discriminator checkpoints may be loaded into this model.",
    ELECTRA_START_DOCSTRING,
)
class TFElectraModel(TFElectraPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.electra = TFElectraMainLayer(config, name="electra")

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
    def call(self, inputs, **kwargs):
        r"""
        Returns:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
            last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
                Sequence of hidden-states at the output of the last layer of the model.
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

                Attention weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.
        """
        outputs = self.electra(inputs, **kwargs)
        return outputs


@add_start_docstrings(
    """Electra model with a binary classification head on top as used during pre-training for identifying generated
    tokens.

    Even though both the discriminator and generator may be loaded into this model, the discriminator is
    the only model of the two to have the correct classification head to be used for this model.""",
    ELECTRA_START_DOCSTRING,
)
class TFElectraForPreTraining(TFElectraPreTrainedModel):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.electra = TFElectraMainLayer(config, name="electra")
        self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        r"""
        Returns:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
            scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
                Prediction scores of the head (scores for each token before SoftMax).
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

                Attention weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import ElectraTokenizer, TFElectraForPreTraining

            tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
            model = TFElectraForPreTraining.from_pretrained('google/electra-small-discriminator')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
            outputs = model(input_ids)
            scores = outputs[0]
        """

        discriminator_hidden_states = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions,
            output_hidden_states,
            training=training,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]
        logits = self.discriminator_predictions(discriminator_sequence_output)
        output = (logits,)
        output += discriminator_hidden_states[1:]

        return output  # scores, (hidden_states), (attentions)


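# Usage sketch (illustrative; continues the Examples:: block above): the per-token `scores`
# returned by TFElectraForPreTraining are raw logits from the replaced-token-detection head, so
# probabilities and hard predictions can be derived with a sigmoid.
#
#     probs = tf.nn.sigmoid(scores)                       # probability that each token was replaced
#     predictions = tf.cast(tf.round(probs), tf.int32)    # 1 = "replaced", 0 = "original"
#     for token, pred in zip(tokenizer.convert_ids_to_tokens(input_ids[0]), predictions[0].numpy()):
#         print(token, int(pred))

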
class TFElectraMaskedLMHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states, training=False):
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias
        return hidden_states


@add_start_docstrings(
    """Electra model with a language modeling head on top.

    Even though both the discriminator and generator may be loaded into this model, the generator is
    the only model of the two to have been trained for the masked language modeling task.""",
    ELECTRA_START_DOCSTRING,
)
class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLoss):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.vocab_size = config.vocab_size
        self.electra = TFElectraMainLayer(config, name="electra")
        self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions")
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act
        self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")

    def get_output_embeddings(self):
        return self.generator_lm_head

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-generator")
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring).
            Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with
            labels in ``[0, ..., config.vocab_size]``.

        Returns:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
            prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

                Attention weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.
        """
        if isinstance(input_ids, (tuple, list)):
            labels = input_ids[8] if len(input_ids) > 8 else labels
            if len(input_ids) > 8:
                input_ids = input_ids[:8]
        elif isinstance(input_ids, (dict, BatchEncoding)):
            labels = input_ids.pop("labels", labels)

        generator_hidden_states = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            training=training,
        )
        generator_sequence_output = generator_hidden_states[0]
        prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
        prediction_scores = self.generator_lm_head(prediction_scores, training=training)
        output = (prediction_scores,)
        output += generator_hidden_states[1:]

        if labels is not None:
            loss = self.compute_loss(labels, prediction_scores)
            output = (loss,) + output

        return output  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)


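# Usage sketch (illustrative; assumes the "google/electra-small-generator" checkpoint named in
# @add_code_sample_docstrings above): masked language modeling with labels, showing the output
# order noted in the return comment above. For a real loss, label positions that are not masked
# would normally be set to -100 as described in the `labels` docstring.
#
#     import tensorflow as tf
#     from transformers import ElectraTokenizer, TFElectraForMaskedLM
#
#     tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-generator")
#     model = TFElectraForMaskedLM.from_pretrained("google/electra-small-generator")
#
#     inputs = tokenizer("The capital of France is [MASK].", return_tensors="tf")
#     labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
#     outputs = model(inputs, labels=labels)
#     loss, prediction_scores = outputs[0], outputs[1]   # (masked_lm_loss), prediction_scores, ...

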
@add_start_docstrings(
    """Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.""",
    ELECTRA_START_DOCSTRING,
)
class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.electra = TFElectraMainLayer(config, name="electra")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

        Returns:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
            scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
                Classification scores (before SoftMax).
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

                Attention weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.
        """
        if isinstance(inputs, (tuple, list)):
            labels = inputs[8] if len(inputs) > 8 else labels
            if len(inputs) > 8:
                inputs = inputs[:8]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        discriminator_hidden_states = self.electra(
            inputs,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions,
            output_hidden_states,
            training=training,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]
        discriminator_sequence_output = self.dropout(discriminator_sequence_output, training=training)
        logits = self.classifier(discriminator_sequence_output)

        outputs = (logits,) + discriminator_hidden_states[1:]

        if labels is not None:
            loss = self.compute_loss(labels, logits)
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


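# Usage sketch (illustrative; `num_labels=2` is an arbitrary value chosen for the example): token
# classification with the discriminator checkpoint named in @add_code_sample_docstrings above.
#
#     import tensorflow as tf
#     from transformers import ElectraTokenizer, TFElectraForTokenClassification
#
#     tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
#     model = TFElectraForTokenClassification.from_pretrained(
#         "google/electra-small-discriminator", num_labels=2
#     )
#
#     inputs = tokenizer("HuggingFace is based in New York City", return_tensors="tf")
#     logits = model(inputs)[0]                    # (batch_size, sequence_length, num_labels)
#     predictions = tf.argmax(logits, axis=-1)     # predicted label id per token

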
@add_start_docstrings(
    """Electra Model with a span classification head on top for extractive question-answering tasks like SQuAD
    (linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).""",
    ELECTRA_START_DOCSTRING,
)
class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.electra = TFElectraMainLayer(config, name="electra")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        start_positions=None,
        end_positions=None,
        training=False,
    ):
        r"""
        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.

        Returns:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
            start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
                Span-start scores (before SoftMax).
            end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
                Span-end scores (before SoftMax).
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

                Attention weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.
        """
        if isinstance(inputs, (tuple, list)):
            start_positions = inputs[8] if len(inputs) > 8 else start_positions
            end_positions = inputs[9] if len(inputs) > 9 else end_positions
            if len(inputs) > 8:
                inputs = inputs[:8]
        elif isinstance(inputs, (dict, BatchEncoding)):
            start_positions = inputs.pop("start_positions", start_positions)
            end_positions = inputs.pop("end_positions", end_positions)

        discriminator_hidden_states = self.electra(
            inputs,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions,
            output_hidden_states,
            training=training,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]

        logits = self.qa_outputs(discriminator_sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        outputs = (start_logits, end_logits,) + discriminator_hidden_states[1:]

        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.compute_loss(labels, outputs[:2])
            outputs = (loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
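

# Usage sketch (illustrative): extracting an answer span from the start/end logits. A checkpoint
# fine-tuned on SQuAD would normally be used; the discriminator checkpoint name is reused here
# only to match the docstrings above, and the question/context strings are made up for the example.
#
#     import tensorflow as tf
#     from transformers import ElectraTokenizer, TFElectraForQuestionAnswering
#
#     tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
#     model = TFElectraForQuestionAnswering.from_pretrained("google/electra-small-discriminator")
#
#     question = "Where is the Eiffel Tower?"
#     context = "The Eiffel Tower is located in Paris."
#     inputs = tokenizer(question, context, return_tensors="tf")
#     start_logits, end_logits = model(inputs)[:2]
#     start = int(tf.argmax(start_logits, axis=-1)[0])
#     end = int(tf.argmax(end_logits, axis=-1)[0])
#     answer_ids = inputs["input_ids"][0, start : end + 1]
#     print(tokenizer.decode(answer_ids))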