# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 CTRL model."""


import logging

import numpy as np
import tensorflow as tf

from .configuration_ctrl import CTRLConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFPreTrainedModel,
    TFSharedEmbeddings,
    cast_bool_to_primitive,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding


logger = logging.getLogger(__name__)

_TOKENIZER_FOR_DOC = "CtrlTokenizer"

TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "ctrl"
    # See all CTRL models at https://huggingface.co/models?filter=ctrl
]
def angle_defn(pos, i, d_model_size):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model_size))
    return pos * angle_rates


def positional_encoding(position, d_model_size):
    # create the sinusoidal pattern for the positional encoding
    angle_rads = angle_defn(np.arange(position)[:, np.newaxis], np.arange(d_model_size)[np.newaxis, :], d_model_size)

    sines = np.sin(angle_rads[:, 0::2])
    cosines = np.cos(angle_rads[:, 1::2])

    # pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1)[np.newaxis, ...], dtype=tf.float32)
    pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1), dtype=tf.float32)
    return pos_encoding
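# Illustrative note (not part of the upstream file): the table built above is later
# indexed with tf.gather(self.pos_encoding, position_ids) in TFCTRLMainLayer. A small
# worked example, assuming position=4 and d_model_size=8:
#
#     pe = positional_encoding(4, 8)   # tf.Tensor of shape (4, 8), dtype=float32
#     # The first d/2 columns hold sines of the even-indexed angles, the last d/2 hold
#     # cosines of the odd-indexed angles, so row 0 is [0, 0, 0, 0, 1, 1, 1, 1]
#     # (sin(0) = 0, cos(0) = 1).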
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
    # calculate attention
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    dk = tf.cast(shape_list(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    if mask is not None:
        scaled_attention_logits += mask * -1e4

    if attention_mask is not None:
        # Apply the attention mask
        scaled_attention_logits = scaled_attention_logits + attention_mask

    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

    # Mask heads if we want to
    if head_mask is not None:
        attention_weights = attention_weights * head_mask

    output = tf.matmul(attention_weights, v)

    return output, attention_weights
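# Illustrative note (not part of the upstream file): `mask` here is the causal mask
# built in TFCTRLMainLayer.call as 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0),
# i.e. 1.0 strictly above the diagonal and 0.0 on or below it. For seq_len = 3:
#
#     [[0., 1., 1.],
#      [0., 0., 1.],
#      [0., 0., 0.]]
#
# Adding mask * -1e4 to the logits pushes the scores of future positions far below the
# rest before the softmax, so each position can only attend to itself and earlier tokens.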
class TFMultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model_size, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model_size = d_model_size

        self.depth = int(d_model_size / self.num_heads)

        self.Wq = tf.keras.layers.Dense(d_model_size, name="Wq")
        self.Wk = tf.keras.layers.Dense(d_model_size, name="Wk")
        self.Wv = tf.keras.layers.Dense(d_model_size, name="Wv")

        self.dense = tf.keras.layers.Dense(d_model_size, name="dense")

    def split_into_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, inputs, training=False):
        v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
        batch_size = shape_list(q)[0]

        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        q = self.split_into_heads(q, batch_size)
        k = self.split_into_heads(k, batch_size)
        v = self.split_into_heads(v, batch_size)

        if layer_past is not None:
            past_key, past_value = tf.unstack(layer_past, axis=0)
            k = tf.concat((past_key, k), axis=-2)
            v = tf.concat((past_value, v), axis=-2)

        # to cope with keras serialization
        use_cache = cast_bool_to_primitive(use_cache, True)

        if use_cache is True:
            present = tf.stack((k, v), axis=0)
        else:
            present = (None,)

        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
        scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
        attn = output[1]
        original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
        output = self.dense(original_size_attention)

        outputs = (output, present)
        if cast_bool_to_primitive(output_attentions) is True:
            outputs = outputs + (attn,)
        return outputs
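# Shape walkthrough (illustrative, not part of the upstream file) for batch size B,
# sequence length S, model size D and H heads (depth = D // H):
#     q, k, v after the Dense projections:   (B, S, D)
#     after split_into_heads:                (B, H, S, depth)
#     scaled_dot_product_attention output:   (B, H, S, depth)
#     after transpose + reshape + dense:     (B, S, D)
# When layer_past is provided, the cached keys/values are concatenated along the
# sequence axis, so k and v have sequence length past_length + S inside the attention.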
class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
    def __init__(self, d_model_size, dff, **kwargs):
        super().__init__(**kwargs)

        self.dense_0 = tf.keras.layers.Dense(dff, activation="relu", name="0")
        self.dense_2 = tf.keras.layers.Dense(d_model_size, name="2")

    def call(self, inputs, trainable=False):
        dense_0_output = self.dense_0(inputs)
        dense_2_output = self.dense_2(dense_0_output)

        return dense_2_output
class TFEncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)

        self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
        self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm2")

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, inputs, training=False):
        x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
        normed = self.layernorm1(x)
        attn_outputs = self.multi_head_attention(
            [normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions],
            training=training,
        )
        attn_output = attn_outputs[0]
        attn_output = self.dropout1(attn_output, training=training)
        out1 = x + attn_output

        out2 = self.layernorm2(out1)
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = out1 + ffn_output

        outputs = (out2,) + attn_outputs[1:]
        return outputs
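# Structure note (illustrative, not part of the upstream file): this is a pre-norm
# Transformer block; LayerNorm runs before each sub-layer and the sub-layer output is
# added back through a residual connection:
#     out1 = x + Dropout(MultiHeadAttention(LayerNorm1(x)))
#     out2 = out1 + Dropout(FFN(LayerNorm2(out1)))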
@keras_serializable
class TFCTRLMainLayer(tf.keras.layers.Layer):
    config_class = CTRLConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.use_cache = config.use_cache

        self.d_model_size = config.n_embd
        self.num_layers = config.n_layer

        self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)

        self.w = TFSharedEmbeddings(
            config.vocab_size, config.n_embd, initializer_range=config.initializer_range, name="w"
        )

        self.dropout = tf.keras.layers.Dropout(config.embd_pdrop)
        self.h = [
            TFEncoderLayer(
                config.n_embd,
                config.n_head,
                config.dff,
                config.resid_pdrop,
                config.layer_norm_epsilon,
                name="h_._{}".format(i),
            )
            for i in range(config.n_layer)
        ]
        self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="layernorm")

    def get_input_embeddings(self):
        return self.w

    def set_input_embeddings(self, value):
        self.w.weight = value
        self.w.vocab_size = value.shape[0]

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        raise NotImplementedError

    def call(
        self,
        inputs,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):

        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            use_cache = inputs[7] if len(inputs) > 7 else use_cache
            output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
            output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
            assert len(inputs) <= 10, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            past = inputs.get("past", past)
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            use_cache = inputs.get("use_cache", use_cache)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            assert len(inputs) <= 10, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.use_cache

        # If using past key value states, only the last tokens
        # should be given as an input
        if past is not None:
            if input_ids is not None:
                input_ids = input_ids[:, -1:]
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1:]

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = shape_list(past[0][0])[-2]
        if position_ids is None:
            position_ids = tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32)[tf.newaxis, :]
            position_ids = tf.tile(position_ids, [input_shape[0], 1])

        # Attention mask.
        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.

            attention_mask = tf.cast(attention_mask, tf.float32)
            attention_mask = (1.0 - attention_mask) * -10000.0
        else:
            attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_layers

        if token_type_ids is not None:
            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            token_type_embeds = self.w(token_type_ids, mode="embedding")
            token_type_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))
        else:
            token_type_embeds = 0
        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])

        if inputs_embeds is None:
            inputs_embeds = self.w(input_ids, mode="embedding")
        seq_len = input_shape[-1]
        mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

        inputs_embeds *= tf.math.sqrt(tf.cast(self.d_model_size, tf.float32))

        pos_embeds = tf.gather(self.pos_encoding, position_ids)

        hidden_states = inputs_embeds + pos_embeds + token_type_embeds

        hidden_states = self.dropout(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]
        presents = ()
        all_hidden_states = ()
        all_attentions = []
        for i, (h, layer_past) in enumerate(zip(self.h, past)):
            if cast_bool_to_primitive(output_hidden_states) is True:
                all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
            outputs = h(
                [hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
                training=training,
            )
            hidden_states, present = outputs[:2]

            if use_cache is True:
                presents = presents + (present,)

            if cast_bool_to_primitive(output_attentions) is True:
                all_attentions.append(outputs[2])

        hidden_states = self.layernorm(hidden_states)
        hidden_states = tf.reshape(hidden_states, output_shape)
        if cast_bool_to_primitive(output_hidden_states) is True:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if use_cache is True:
            outputs = outputs + (presents,)
        if cast_bool_to_primitive(output_hidden_states) is True:
            outputs = outputs + (all_hidden_states,)
        if cast_bool_to_primitive(output_attentions) is True:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
            all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs
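# Output layout note (illustrative, not part of the upstream file): the tuple assembled
# above is (hidden_states, presents?, all_hidden_states?, all_attentions?), where the
# optional entries appear only when use_cache, output_hidden_states and
# output_attentions, respectively, evaluate to True.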
class TFCTRLPreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = CTRLConfig
    base_model_prefix = "transformer"
CTRL_START_DOCSTRING = r"""

    .. note::
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional arguments.

        This second option is useful when using the :obj:`tf.keras.Model.fit()` method which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.CTRLConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
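# Illustrative sketch of the three input formats described above (hypothetical tensors;
# not part of the upstream file). With `model` a TFCTRLModel and `input_ids`,
# `attention_mask` tf.Tensors of shape (batch_size, sequence_length):
#
#     outputs = model(input_ids)                                     # single tensor
#     outputs = model([input_ids, None, attention_mask])             # positional list: input_ids, past, attention_mask, ...
#     outputs = model({"input_ids": input_ids, "attention_mask": attention_mask})  # dict keyed by input name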
CTRL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, input_ids_length)`):
            :obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).

            Indices of input sequence tokens in the vocabulary.

            If `past` is used, only input_ids that do not have their past calculated should be passed as input_ids (see `past`).

            Indices can be obtained using :class:`transformers.CTRLTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `past` output below). Can be used to speed up sequential decoding.
            The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        use_cache (:obj:`bool`):
            If `use_cache` is True, `past` key value states are returned and
            can be used to speed up decoding (see `past`). Defaults to `True`.
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
"""
@add_start_docstrings(
    "The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
    CTRL_START_DOCSTRING,
)
class TFCTRLModel(TFCTRLPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFCTRLMainLayer(config, name="transformer")

    @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
    def call(self, inputs, **kwargs):
        r"""
        Return:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
            last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
                Sequence of hidden-states at the last layer of the model.
            past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
                Contains pre-computed hidden-states (key and values in the attention blocks).
                Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
                should not be passed as input ids as they have already been computed.
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
                heads.
        """
        outputs = self.transformer(inputs, **kwargs)
        return outputs
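# Minimal usage sketch (illustrative, not part of the upstream file), assuming the
# public "ctrl" checkpoint and its tokenizer are available:
#
#     from transformers import CTRLTokenizer, TFCTRLModel
#
#     tokenizer = CTRLTokenizer.from_pretrained("ctrl")
#     model = TFCTRLModel.from_pretrained("ctrl")
#     inputs = tokenizer("Links Hello, my dog is cute", return_tensors="tf")
#     last_hidden_state = model(inputs)[0]   # (batch_size, sequence_length, hidden_size)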
class TFCTRLLMHead(tf.keras.layers.Layer):
    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = config.vocab_size

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def call(self, hidden_states):
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        hidden_states = hidden_states + self.bias
        return hidden_states
@add_start_docstrings(
    """The CTRL Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    CTRL_START_DOCSTRING,
)
class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFCTRLMainLayer(config, name="transformer")

        self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")

    def get_output_embeddings(self):
        return self.lm_head.input_embeddings

    def prepare_inputs_for_generation(self, inputs, past, **kwargs):
        # only last token for input_ids if past is defined in kwargs
        if past:
            inputs = tf.expand_dims(inputs[:, -1], -1)

        return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}

    @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
    def call(
        self,
        inputs,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the cross entropy classification loss.
            Indices should be in ``[0, ..., config.vocab_size - 1]``.

        Return:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
            prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
                Contains pre-computed hidden-states (key and values in the attention blocks).
                Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
                should not be passed as input ids as they have already been computed.
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
                heads.
        """
        if isinstance(inputs, (tuple, list)):
            labels = inputs[10] if len(inputs) > 10 else labels
            if len(inputs) > 10:
                inputs = inputs[:10]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)

        transformer_outputs = self.transformer(
            inputs,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            training=training,
        )

        hidden_states = transformer_outputs[0]

        logits = self.lm_head(hidden_states)

        outputs = (logits,) + transformer_outputs[1:]
        if labels is not None:
            # shift labels to the left and cut last logit token
            logits = logits[:, :-1]
            labels = labels[:, 1:]
            loss = self.compute_loss(labels, logits)
            outputs = (loss,) + outputs

        return outputs  # lm_logits, presents, (all hidden_states), (attentions)
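# Minimal usage sketch (illustrative, not part of the upstream file): when `labels` is
# given, the loss is prepended, so the outputs become (loss, lm_logits, presents, ...).
# Assuming the public "ctrl" checkpoint:
#
#     from transformers import CTRLTokenizer, TFCTRLLMHeadModel
#
#     tokenizer = CTRLTokenizer.from_pretrained("ctrl")
#     model = TFCTRLLMHeadModel.from_pretrained("ctrl")
#     inputs = tokenizer("Links Hello, my dog is cute", return_tensors="tf")
#     outputs = model(inputs["input_ids"], labels=inputs["input_ids"])
#     loss, lm_logits = outputs[:2]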