# coding=utf-8
# Copyright 2018 T5 Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 T5 model. """


import copy
import itertools
import logging
import math
import warnings

import tensorflow as tf

from .configuration_t5 import T5Config
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFPreTrainedModel,
    TFSharedEmbeddings,
    cast_bool_to_primitive,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding


logger = logging.getLogger(__name__)

_TOKENIZER_FOR_DOC = "T5Tokenizer"

TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]

####################################################
# TF 2.0 Models are constructed using Keras imperative API by sub-classing
# - tf.keras.layers.Layer for the layers and
# - TFPreTrainedModel for the models (itself a sub-class of tf.keras.Model)
####################################################

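# Illustrative sketch (added for exposition, not part of the original file):
# every layer below follows the same Keras subclassing pattern, roughly
#
#     class MyLayer(tf.keras.layers.Layer):
#         def __init__(self, units, **kwargs):
#             super().__init__(**kwargs)
#             self.dense = tf.keras.layers.Dense(units, name="dense")
#
#         def call(self, x, training=False):
#             return self.dense(x)
#
# sub-layers are created in __init__ and the forward pass lives in call().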


class TFT5LayerNorm(tf.keras.layers.Layer):
    def __init__(self, epsilon=1e-6, **kwargs):
        """ Construct a layernorm module in the T5 style.
            No bias and no subtraction of mean.
        """
        super().__init__(**kwargs)
        self.variance_epsilon = epsilon

    def build(self, input_shape):
        """Build the layer-norm scale weight."""
        self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
        super().build(input_shape)

    def call(self, x):
        variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
        x = x * tf.math.rsqrt(variance + self.variance_epsilon)
        return self.weight * x

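# Usage sketch (hypothetical, for illustration only): T5's layer norm is an
# RMS norm, y = weight * x / sqrt(mean(x**2, axis=-1) + eps), with no mean
# subtraction and no bias. Shapes are preserved:
#
#     layer = TFT5LayerNorm(epsilon=1e-6)
#     x = tf.random.normal((2, 5, 8))   # (batch, seq_len, d_model)
#     y = layer(x)                      # (2, 5, 8), normalized over d_model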


class TFT5DenseReluDense(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.wi = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi")
        self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
        self.act = tf.keras.activations.relu

    def call(self, hidden_states, training=False):
        h = self.wi(hidden_states)
        h = self.act(h)
        h = self.dropout(h, training=training)
        h = self.wo(h)
        return h


class TFT5LayerFF(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense")
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def call(self, hidden_states, training=False):
        norm_x = self.layer_norm(hidden_states)
        y = self.DenseReluDense(norm_x, training=training)
        layer_output = hidden_states + self.dropout(y, training=training)
        return layer_output

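# A minimal sketch of the pre-norm residual pattern used throughout this file
# (added for exposition): each sub-layer computes
#
#     output = hidden_states + dropout(sublayer(layer_norm(hidden_states)))
#
# i.e. normalization is applied *before* the sub-layer and the residual is
# added afterwards, as in the Mesh TensorFlow T5 implementation.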


class TFT5Attention(tf.keras.layers.Layer):
    NEW_ID = itertools.count()

    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.layer_id = next(TFT5Attention.NEW_ID)
        self.is_decoder = config.is_decoder
        self.use_cache = config.use_cache
        self.has_relative_attention_bias = has_relative_attention_bias

        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.d_model = config.d_model
        self.d_kv = config.d_kv
        self.n_heads = config.num_heads
        self.inner_dim = self.n_heads * self.d_kv

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q")
        self.k = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="k")
        self.v = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="v")
        self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = tf.keras.layers.Embedding(
                self.relative_attention_num_buckets, self.n_heads, name="relative_attention_bias",
            )
        self.pruned_heads = set()

    def prune_heads(self, heads):
        raise NotImplementedError

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are
        invalid.
        We use smaller buckets for small absolute relative_position and larger buckets
        for larger absolute relative_positions. All relative positions >=max_distance
        map to the same bucket. All relative positions <=-max_distance map to the
        same bucket. This should allow for more graceful generalization to longer
        sequences than the model has been trained on.
        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
            n = tf.math.abs(n)
        else:
            n = tf.math.maximum(n, 0)
        # now n is in the range [0, inf)
        max_exact = num_buckets // 2
        is_small = tf.math.less(n, max_exact)
        val_if_large = max_exact + tf.dtypes.cast(
            tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact),
            tf.int32,
        )
        val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
        ret += tf.where(is_small, n, val_if_large)
        return ret

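    # Worked example (added for exposition, not part of the original file): with
    # the defaults num_buckets=32 and max_distance=128 in the bidirectional case,
    # relative positions 0, 1, ..., 7 each get their own bucket, positions up to
    # ~max_distance share logarithmically sized buckets, and everything at or
    # beyond max_distance falls into the last bucket. Negative (future) and
    # positive (past) offsets use disjoint halves of the bucket range.
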
    def compute_bias(self, qlen, klen):
        """ Compute binned relative position bias """
        context_position = tf.range(qlen)[:, None]
        memory_position = tf.range(klen)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)
        rp_bucket = self._relative_position_bucket(
            relative_position, bidirectional=not self.is_decoder, num_buckets=self.relative_attention_num_buckets,
        )
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
        values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)  # shape (1, num_heads, qlen, klen)
        return values

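    # Note (added for exposition): this bias is computed only in the first block
    # of the stack (has_relative_attention_bias=True for block 0); the resulting
    # (1, num_heads, qlen, klen) tensor is then passed along and reused by all
    # later blocks, so position information is shared across layers.
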
    def call(
        self,
        input,
        mask=None,
        kv=None,
        position_bias=None,
        cache=None,
        past_key_value_state=None,
        head_mask=None,
        query_length=None,
        use_cache=False,
        training=False,
        output_attentions=False,
    ):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).
        """
        # Input is (bs, qlen, dim)
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
        # past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head)
        bs, qlen, dim = shape_list(input)

        if past_key_value_state is not None:
            assert self.is_decoder is True, "Encoder cannot cache past key value states"
            assert (
                len(past_key_value_state) == 2
            ), "past_key_value_state should have 2 past states: keys and values. Got {} past states".format(
                len(past_key_value_state)
            )
            real_qlen = qlen + shape_list(past_key_value_state[0])[2] if query_length is None else query_length
        else:
            real_qlen = qlen

        if kv is None:
            klen = real_qlen
        else:
            klen = shape_list(kv)[1]

        def shape(x):
            """ projection """
            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, self.d_kv)), perm=(0, 2, 1, 3))

        def unshape(x):
            """ compute context """
            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.inner_dim))

        q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)

        if kv is None:
            k = shape(self.k(input))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v(input))  # (bs, n_heads, qlen, dim_per_head)
        elif past_key_value_state is None:
            k = v = kv
            k = shape(self.k(k))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v(v))  # (bs, n_heads, qlen, dim_per_head)

        if past_key_value_state is not None:
            if kv is None:
                k_, v_ = past_key_value_state
                k = tf.concat([k_, k], axis=2)  # (bs, n_heads, klen, dim_per_head)
                v = tf.concat([v_, v], axis=2)  # (bs, n_heads, klen, dim_per_head)
            else:
                k, v = past_key_value_state

        # to cope with keras serialization
        if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True:
            present_key_value_state = ((k, v),)
        else:
            present_key_value_state = (None,)

        scores = tf.einsum("bnqd,bnkd->bnqk", q, k)  # (bs, n_heads, qlen, klen)

        if position_bias is None:
            if not self.has_relative_attention_bias:
                raise ValueError("No position_bias provided and no weights to compute position_bias")
            position_bias = self.compute_bias(real_qlen, klen)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value_state is not None:
                position_bias = position_bias[:, :, -1:, :]

            if mask is not None:
                position_bias = position_bias + mask  # (bs, n_heads, qlen, klen)

        scores += position_bias
        weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
        weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = tf.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)  # (bs, qlen, dim)

        context = self.o(context)

        outputs = (context,) + present_key_value_state

        if cast_bool_to_primitive(output_attentions, True) is True:
            outputs = outputs + (weights,)
        if self.has_relative_attention_bias:
            outputs = outputs + (position_bias,)
        return outputs

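# Shape sketch for the cache (added for exposition, not part of the original
# file): during incremental decoding each self-attention layer receives
# qlen == 1 and concatenates the new key/value onto the cached ones:
#
#     k: (bs, n_heads, klen - 1, d_kv) + (bs, n_heads, 1, d_kv)
#        -> (bs, n_heads, klen, d_kv)
#
# Cross-attention keys/values depend only on the encoder output, so they are
# computed once and then reused unchanged from the cache.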

class TFT5LayerSelfAttention(tf.keras.layers.Layer):
    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.SelfAttention = TFT5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias, name="SelfAttention",
        )
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def call(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        head_mask=None,
        past_key_value_state=None,
        use_cache=False,
        output_attentions=False,
        training=False,
    ):
        norm_x = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            norm_x,
            mask=attention_mask,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value_state=past_key_value_state,
            use_cache=use_cache,
            output_attentions=output_attentions,
            training=training,
        )
        y = attention_output[0]
        layer_output = hidden_states + self.dropout(y, training=training)
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class TFT5LayerCrossAttention(tf.keras.layers.Layer):
    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.EncDecAttention = TFT5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias, name="EncDecAttention",
        )
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def call(
        self,
        hidden_states,
        kv,
        attention_mask=None,
        position_bias=None,
        head_mask=None,
        past_key_value_state=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        training=False,
    ):
        norm_x = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            norm_x,
            mask=attention_mask,
            kv=kv,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value_state=past_key_value_state,
            query_length=query_length,
            use_cache=use_cache,
            output_attentions=output_attentions,
            training=training,
        )
        y = attention_output[0]
        layer_output = hidden_states + self.dropout(y, training=training)
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class TFT5Block(tf.keras.layers.Layer):
    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.is_decoder = config.is_decoder
        self.layer = []
        self.layer.append(
            TFT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._0",)
        )
        if self.is_decoder:
            self.layer.append(
                TFT5LayerCrossAttention(
                    config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._1",
                )
            )

        self.layer.append(TFT5LayerFF(config, name="layer_._{}".format(len(self.layer))))

    def call(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        head_mask=None,
        past_key_value_state=None,
        use_cache=False,
        output_attentions=False,
        training=False,
    ):

        if past_key_value_state is not None:
            assert self.is_decoder, "Only decoder can use `past_key_value_states`"
            expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4

            error_message = "There should be {} past states. 2 (key / value) for self attention.{} Got {} past key / value states".format(
                expected_num_past_key_value_states,
                " 2 (key / value) for cross attention." if expected_num_past_key_value_states == 4 else "",
                len(past_key_value_state),
            )
            assert len(past_key_value_state) == expected_num_past_key_value_states, error_message

            self_attn_past_key_value_state = past_key_value_state[:2]
            cross_attn_past_key_value_state = past_key_value_state[2:]
        else:
            self_attn_past_key_value_state, cross_attn_past_key_value_state = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value_state=self_attn_past_key_value_state,
            use_cache=use_cache,
            output_attentions=output_attentions,
            training=training,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        if self.is_decoder and encoder_hidden_states is not None:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = shape_list(present_key_value_state[0])[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                kv=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                head_mask=head_mask,
                past_key_value_state=cross_attn_past_key_value_state,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = cross_attention_outputs[0]
            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states, training=training)
        outputs = (hidden_states,)

        # Add attentions if we output them
        outputs = outputs + (present_key_value_state,) + attention_outputs
        return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)


class _NoLayerEmbedTokens:
    """
    This class wraps the TFSharedEmbeddings layer in a plain Python (non-Keras-layer)
    class to avoid problems with weight restoring. It also makes sure that the layer
    is called from the correct scope so that the correct weights are saved/restored.
    """

    def __init__(self, layer, abs_scope_name=None):
        self._layer = layer
        self._abs_scope_name = abs_scope_name

    def call(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer.call(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer.call(inputs, mode)

    def __call__(self, inputs, mode="embedding"):
        if self._abs_scope_name is None:
            return self._layer(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
            with tf.name_scope(abs_scope_name.original_name_scope):
                return self._layer(inputs, mode)

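# Note (added for exposition): _NoLayerEmbedTokens lets the encoder and the
# decoder share one embedding matrix. Because it is not itself a Keras layer,
# Keras does not track (and duplicate) the shared weight in both sub-models;
# the variable_scope/name_scope dance above ensures the single "shared" weight
# is resolved from its absolute scope wherever it is called.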

####################################################
# The full model without a specific pretrained or finetuning head is
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
####################################################
@keras_serializable
class TFT5MainLayer(tf.keras.layers.Layer):
    config_class = T5Config

    def __init__(self, config, embed_tokens=None, **kwargs):
        super().__init__(**kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.use_cache = config.use_cache

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.config = config
        self.num_hidden_layers = config.num_layers

        self.block = [
            TFT5Block(config, has_relative_attention_bias=bool(i == 0), name="block_._{}".format(i),)
            for i in range(config.num_layers)
        ]
        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

    def get_input_embeddings(self):
        return self.embed_tokens

    def get_output_embeddings(self):
        return self.embed_tokens

    def set_embed_tokens(self, embed_tokens):
        self.embed_tokens = embed_tokens

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models

    def call(
        self,
        inputs,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        past_key_value_states=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            encoder_hidden_states = inputs[2] if len(inputs) > 2 else encoder_hidden_states
            encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask
            inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
            use_cache = inputs[7] if len(inputs) > 7 else use_cache
            output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
            output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
            assert len(inputs) <= 10, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            encoder_hidden_states = inputs.get("encoder_hidden_states", encoder_hidden_states)
            encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            head_mask = inputs.get("head_mask", head_mask)
            past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
            use_cache = inputs.get("use_cache", use_cache)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            assert len(inputs) <= 10, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.use_cache

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both inputs and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either inputs or inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        if past_key_value_states is not None:
            assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_states".format(
                input_shape, (batch_size, 1)
            )
            # required mask seq length can be calculated via length of past
            # key value states and seq_length = 1 for the last token
            mask_seq_length = shape_list(past_key_value_states[0][0])[2] + seq_length
        else:
            mask_seq_length = seq_length

        if attention_mask is None:
            attention_mask = tf.fill((batch_size, mask_seq_length), 1)
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = shape_list(encoder_hidden_states)[1]
            encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)

        # initialize past_key_value_states with `None` if past does not exist
        if past_key_value_states is None:
            past_key_value_states = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        attention_mask = tf.cast(attention_mask, dtype=tf.float32)
        num_dims_attention_mask = len(shape_list(attention_mask))
        if num_dims_attention_mask == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif num_dims_attention_mask == 2:
            # Provided a padding mask of dimensions [batch_size, mask_seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
            if self.is_decoder:
                seq_ids = tf.range(mask_seq_length)
                causal_mask = tf.less_equal(
                    tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None],
                )
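                # Illustrative example (added for exposition): for
                # mask_seq_length == 3 each row i of causal_mask allows
                # attention only to positions j <= i:
                #
                #     [[True, False, False],
                #      [True, True,  False],
                #      [True, True,  True ]]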
                causal_mask = tf.cast(causal_mask, dtype=tf.float32)
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
                if past_key_value_states[0] is not None:
                    extended_attention_mask = extended_attention_mask[:, :, -1:, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -1e9 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.

        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
        # extended_attention_mask = tf.math.equal(extended_attention_mask,
        #                                         tf.transpose(extended_attention_mask, perm=(-1, -2)))

        extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
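
        # Example of the additive mask (added for exposition): a padding mask
        # [1, 1, 0] becomes [0.0, 0.0, -1e9]; adding it to the attention scores
        # before the softmax effectively zeroes out attention to padded tokens.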

        if self.is_decoder and encoder_attention_mask is not None:
            # If a 2D or 3D attention mask is provided for the cross-attention,
            # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=tf.float32)
            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
            if num_dims_encoder_attention_mask == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            if num_dims_encoder_attention_mask == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
            #                                                 tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))

            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
        else:
            encoder_extended_attention_mask = None


        assert head_mask is None, "Head mask not supported"
        head_mask = [None] * self.num_hidden_layers

        present_key_value_states = ()
        all_hidden_states = ()
        all_attentions = ()
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds, training=training)

        for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
            if cast_bool_to_primitive(output_hidden_states) is True:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                head_mask=head_mask[i],
                past_key_value_state=past_key_value_state,
                use_cache=use_cache,
                output_attentions=output_attentions,
                training=training,
            )
            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
            hidden_states, present_key_value_state = layer_outputs[:2]
            if i == 0:
                # We share the position biases between the layers - the first layer stores them
                # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                position_bias = layer_outputs[3 if output_attentions else 2]
                if self.is_decoder and encoder_hidden_states is not None:
                    encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
            # append next layer key value states
            present_key_value_states = present_key_value_states + (present_key_value_state,)

            if cast_bool_to_primitive(output_attentions) is True:
                all_attentions = all_attentions + (layer_outputs[2],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states, training=training)

        # Add last layer
        if cast_bool_to_primitive(output_hidden_states) is True:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        # need to check if is decoder here as well for special cases when using keras compile
        if cast_bool_to_primitive(use_cache, self.use_cache) is True and self.is_decoder:
            outputs = outputs + (present_key_value_states,)
        if cast_bool_to_primitive(output_hidden_states) is True:
            outputs = outputs + (all_hidden_states,)
        if cast_bool_to_primitive(output_attentions) is True:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (present key value states), (all hidden states), (all attentions)


####################################################
# TFT5PreTrainedModel is a sub-class of tf.keras.Model
# which takes care of loading and saving pretrained weights
# and various common utilities.
# Here you just need to specify a few (self-explanatory)
# pointers for your model.
####################################################
class TFT5PreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = T5Config
    base_model_prefix = "transformer"

    @property
    def dummy_inputs(self):
        inputs = tf.constant(DUMMY_INPUTS)
        input_mask = tf.constant(DUMMY_MASK)
        dummy_inputs = {
            "input_ids": inputs,
            "decoder_input_ids": inputs,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert (
            decoder_start_token_id is not None
        ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"

        # shift inputs to the right: prepend the start token and drop the last token
        # (note: building on zeros here would discard the labels entirely, so the
        # shift has to operate on the actual input ids)
        shifted_input_ids = tf.cast(input_ids, tf.int32)
        shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
        start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
        shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids = tf.where(
            shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
        )

        assert tf.math.reduce_all(
            shifted_input_ids >= 0
        ).numpy(), "Verify that `labels` has only non-negative values and -100"

        return shifted_input_ids

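# Worked example (added for exposition): with decoder_start_token_id == 0
# (T5's pad token), _shift_right maps
#
#     labels:            [[ a, b, c ]]
#     decoder_input_ids: [[ 0, a, b ]]
#
# so the decoder predicts token t from tokens < t, and any -100 ignore-index
# entries in the labels are replaced by pad_token_id after the shift.
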

T5_START_DOCSTRING = r"""
    The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
    <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
    Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
    It's an encoder-decoder transformer pre-trained in a text-to-text denoising generative setting.

    This model is a `tf.keras.Model <https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model>`__
    sub-class. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to
    general usage and behavior.

    Note on the model inputs:
        TF 2.0 models accept two formats as inputs:

        - having all inputs as keyword arguments (like PyTorch models), or
        - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the `tf.keras.Model.fit()` method which currently requires having all
        the tensors in the first argument of the model call function: `model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors in
        the first positional argument:

        - a single Tensor with inputs only and nothing else: `model(inputs_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          `model([inputs, attention_mask])` or `model([inputs, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          `model({'inputs': inputs, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Args:
        inputs are usually used as a `dict` (see the T5 description above for more information) containing all the following.

        inputs (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            T5 is a model with relative position embeddings so you should be able to pad the inputs on
            the right or the left.
            Indices can be obtained using :class:`transformers.T5Tokenizer`.
            To know more on how to prepare :obj:`inputs` for pre-training take a look at
            `T5 Training <./t5.html#training>`__.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
            Provided for sequence-to-sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
            If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
        attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        encoder_outputs (:obj:`tuple(tuple(tf.FloatTensor)`, `optional`, defaults to :obj:`None`):
            Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`).
            `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the encoder.
            Used in the cross-attention of the decoder.
        decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
            Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. A causal mask will also be used by default.
        decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains pre-computed key and value hidden-states of the attention blocks.
            Can be used to speed up decoding.
            If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
        inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`inputs` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `inputs` indices into associated vectors
            than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
            To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
            `T5 Training <./t5.html#training>`__.
        head_mask (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
"""


@add_start_docstrings(
    "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
    T5_START_DOCSTRING,
)
class TFT5Model(TFT5PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")

        # retrieve correct absolute scope for embed token wrapper
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
            pass

        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")

    def get_input_embeddings(self):
        return self.shared

    def get_output_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared.weight = new_embeddings
        self.shared.vocab_size = self.shared.weight.shape[0]
        # retrieve correct absolute scope for embed token wrapper
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
            pass
        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
        self.encoder.set_embed_tokens(embed_tokens)
        self.decoder.set_embed_tokens(embed_tokens)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
    def call(
        self,
        inputs,
        attention_mask=None,
        encoder_outputs=None,
        inputs_embeds=None,
        head_mask=None,
        decoder_past_key_value_states=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        r"""
        Returns:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
            last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
                Sequence of hidden-states at the output of the last layer of the model.
                If `decoder_past_key_value_states` is used, only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
            decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
                Contains pre-computed key and value hidden-states of the attention blocks.
                Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
                Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                Tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
                heads.

        Examples::

            >>> from transformers import T5Tokenizer, TFT5Model

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = TFT5Model.from_pretrained('t5-small')
            >>> inputs = tokenizer.encode("Hello, my dog is cute", return_tensors="tf")  # Batch size 1
            >>> outputs = model(inputs, decoder_input_ids=inputs)
            >>> last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
            decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
            decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
            decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
            use_cache = inputs[9] if len(inputs) > 9 else use_cache
            output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
            output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
            assert len(inputs) <= 12, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            if "inputs" in inputs:
                warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
                input_ids = inputs.get("inputs")
            else:
                input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            head_mask = inputs.get("head_mask", head_mask)
            decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
            decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
            decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
            decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
            use_cache = inputs.get("use_cache", use_cache)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            assert len(inputs) <= 12, "Too many inputs."
        else:
            input_ids = inputs

        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                [
                    input_ids,
                    attention_mask,
                    None,
                    None,
                    inputs_embeds,
                    head_mask,
                    None,
                    False,
                    output_attentions,
                    output_hidden_states,
                ],
                training=training,
            )

        hidden_states = encoder_outputs[0]

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if decoder_past_key_value_states is not None:
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            [
                decoder_input_ids,
                decoder_attention_mask,
                hidden_states,
                attention_mask,
                decoder_inputs_embeds,
                head_mask,
                decoder_past_key_value_states,
                use_cache,
                output_attentions,
                output_hidden_states,
            ],
            training=training,
        )

        if cast_bool_to_primitive(use_cache, self.config.use_cache) is True:
            past = ((encoder_outputs, decoder_outputs[1]),)
            decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]

        return decoder_outputs + encoder_outputs



@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.model_dim = config.d_model

        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")

        # retrieve correct absolute scope for embed token wrapper
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
            pass

        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")

    def get_input_embeddings(self):
        return self.shared

    def get_output_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared.weight = new_embeddings
        # retrieve correct absolute scope for embed token wrapper
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
            pass
        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
        self.encoder.set_embed_tokens(embed_tokens)
        self.decoder.set_embed_tokens(embed_tokens)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
    def call(
        self,
        inputs,
        attention_mask=None,
        encoder_outputs=None,
        inputs_embeds=None,
        head_mask=None,
        decoder_past_key_value_states=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the cross entropy classification loss.
            Indices should be in ``[0, ..., config.vocab_size - 1]``.

        Returns:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
            prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
                Contains pre-computed key and value hidden-states of the attention blocks.
                Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
                Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                Tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
                heads.

        Examples::

            >>> from transformers import T5Tokenizer, TFT5ForConditionalGeneration

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
            >>> inputs = tokenizer.encode("Hello, my dog is cute", return_tensors="tf")  # Batch size 1
            >>> outputs = model(inputs, decoder_input_ids=inputs)
            >>> prediction_scores = outputs[0]

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
            >>> inputs = tokenizer.encode("summarize: Hello, my dog is cute", return_tensors="tf")  # Batch size 1
            >>> result = model.generate(inputs)

        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
            decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
            decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
            decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
            use_cache = inputs[9] if len(inputs) > 9 else use_cache
            output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
            output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
            labels = inputs[12] if len(inputs) > 12 else labels
            assert len(inputs) <= 13, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            if "inputs" in inputs:
                warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
                input_ids = inputs.get("inputs")
            else:
                input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            head_mask = inputs.get("head_mask", head_mask)
            decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
            decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
            decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
            decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
            use_cache = inputs.get("use_cache", use_cache)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            labels = inputs.get("labels", labels)
            assert len(inputs) <= 13, "Too many inputs."
        else:
            input_ids = inputs

        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                [
                    input_ids,
                    attention_mask,
                    None,
                    None,
                    inputs_embeds,
                    head_mask,
                    None,
                    False,
                    output_attentions,
                    output_hidden_states,
                ],
                training=training,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if decoder_past_key_value_states is not None:
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            [
                decoder_input_ids,
                decoder_attention_mask,
                hidden_states,
                attention_mask,
                decoder_inputs_embeds,
                head_mask,
                decoder_past_key_value_states,
                use_cache,
                output_attentions,
                output_hidden_states,
            ],
            training=training,
        )

        # insert decoder past at right place
        # to speed up decoding
        if cast_bool_to_primitive(use_cache, self.config.use_cache) is True:
            past = ((encoder_outputs, decoder_outputs[1]),)
            decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
        sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
        embed_tokens = self.get_output_embeddings()
        logits = embed_tokens(sequence_output, mode="linear")
        decoder_outputs = (logits,) + decoder_outputs[1:]

        if labels is not None:
            loss = self.compute_loss(labels, logits)
            decoder_outputs = (loss,) + decoder_outputs

        return decoder_outputs + encoder_outputs

    def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs):
        assert past is not None, "past has to be defined for encoder_outputs"

        # first step
        if len(past) < 2:
            encoder_outputs, decoder_past_key_value_states = past, None
        else:
            encoder_outputs, decoder_past_key_value_states = past[0], past[1]

        return {
            "inputs": None,  # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
            "decoder_input_ids": inputs,  # inputs are the decoder_input_ids
            "decoder_past_key_value_states": decoder_past_key_value_states,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

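    # Generation sketch (hypothetical usage, added for illustration only):
    #
    #     model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
    #     ids = tokenizer.encode("summarize: ...", return_tensors="tf")
    #     out = model.generate(ids, use_cache=True)
    #
    # generate() calls prepare_inputs_for_generation() at every step: on the
    # first step `past` holds only the encoder outputs; afterwards it also
    # carries the decoder key/value cache, so each step feeds a single token.
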
    def _reorder_cache(self, past, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and there is no need to reorder
        if len(past) < 2:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        decoder_past = past[1]
        past = (past[0],)
        reordered_decoder_past = ()

        for layer_past_states in decoder_past:
            # get the correct batch idx from layer past batch dim;
            # the batch dim of `past` is the first axis here, so the default
            # axis=0 of tf.gather selects the right beams
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (tf.gather(layer_past_state, beam_idx),)

            assert shape_list(reordered_layer_past_states[0]) == shape_list(layer_past_states[0])
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return past + (reordered_decoder_past,)