CSS-LM (fork)
modeling_tf_t5.py · 1320 lines · 62.1 KB
# coding=utf-8
2
# Copyright 2018 T5 Authors and The HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
""" TF 2.0 T5 model. """
17

18

19
import copy
20
import itertools
21
import logging
22
import math
23
import warnings
24

25
import tensorflow as tf
26

27
from .configuration_t5 import T5Config
28
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
29
from .modeling_tf_utils import (
30
    TFCausalLanguageModelingLoss,
31
    TFPreTrainedModel,
32
    TFSharedEmbeddings,
33
    cast_bool_to_primitive,
34
    keras_serializable,
35
    shape_list,
36
)
37
from .tokenization_utils import BatchEncoding
38

39

40
logger = logging.getLogger(__name__)
41

42
_TOKENIZER_FOR_DOC = "T5Tokenizer"
43

44
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
45
    "t5-small",
46
    "t5-base",
47
    "t5-large",
48
    "t5-3b",
49
    "t5-11b",
50
    # See all T5 models at https://huggingface.co/models?filter=t5
51
]
52

53
####################################################
54
# TF 2.0 models are constructed using the Keras imperative API by sub-classing
# - tf.keras.layers.Layer for the layers and
# - TFPreTrainedModel for the models (itself a sub-class of tf.keras.Model)
57
####################################################
58

59

60
class TFT5LayerNorm(tf.keras.layers.Layer):
61
    def __init__(self, epsilon=1e-6, **kwargs):
62
        """ Construct a layernorm module in the T5 style
63
            No bias and no substraction of mean.
64
        """
65
        super().__init__(**kwargs)
66
        self.variance_epsilon = epsilon
67

68
    def build(self, input_shape):
69
        """Build shared word embedding layer """
70
        self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
71
        super().build(input_shape)
72

73
    def call(self, x):
74
        variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
75
        x = x * tf.math.rsqrt(variance + self.variance_epsilon)
76
        return self.weight * x
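

# Editorial sketch (not part of the original file): TFT5LayerNorm above is an RMS-style
# norm: it only rescales by the root mean square of the features, with no mean
# subtraction and no bias. Roughly, for a single feature vector x:
#
#     import numpy as np
#     x = np.array([1.0, 2.0, 3.0])
#     variance = np.mean(x ** 2)              # 14/3
#     normed = x / np.sqrt(variance + 1e-6)   # then scaled element-wise by `weight`
#
# `weight` is initialized to ones, so right after construction the layer is a pure
# RMS rescaling of its input.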


class TFT5DenseReluDense(tf.keras.layers.Layer):
80
    def __init__(self, config, **kwargs):
81
        super().__init__(**kwargs)
82
        self.wi = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi")
83
        self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
84
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
85
        self.act = tf.keras.activations.relu
86

87
    def call(self, hidden_states, training=False):
88
        h = self.wi(hidden_states)
89
        h = self.act(h)
90
        h = self.dropout(h, training=training)
91
        h = self.wo(h)
92
        return h
93

94

95
class TFT5LayerFF(tf.keras.layers.Layer):
96
    def __init__(self, config, **kwargs):
97
        super().__init__(**kwargs)
98
        self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense")
99
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
100
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
101

102
    def call(self, hidden_states, training=False):
103
        norm_x = self.layer_norm(hidden_states)
104
        y = self.DenseReluDense(norm_x, training=training)
105
        layer_output = hidden_states + self.dropout(y, training=training)
106
        return layer_output
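

# Editorial note (added comment, not original): TFT5LayerFF and the attention wrappers
# below all follow the same pre-norm residual pattern:
#
#     y = sublayer(layer_norm(x))
#     out = x + dropout(y)
#
# i.e. layer norm is applied to the sublayer input, and the residual branch carries the
# raw, un-normalized hidden states through unchanged.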


class TFT5Attention(tf.keras.layers.Layer):
110
    NEW_ID = itertools.count()
111

112
    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
113
        super().__init__(**kwargs)
114
        self.layer_id = next(TFT5Attention.NEW_ID)
115
        self.is_decoder = config.is_decoder
116
        self.use_cache = config.use_cache
117
        self.has_relative_attention_bias = has_relative_attention_bias
118

119
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
120
        self.d_model = config.d_model
121
        self.d_kv = config.d_kv
122
        self.n_heads = config.num_heads
123
        self.inner_dim = self.n_heads * self.d_kv
124

125
        # Mesh TensorFlow initialization to avoid scaling before softmax
126
        self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q")
127
        self.k = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="k")
128
        self.v = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="v")
129
        self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name="o")
130
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
131

132
        if self.has_relative_attention_bias:
133
            self.relative_attention_bias = tf.keras.layers.Embedding(
134
                self.relative_attention_num_buckets, self.n_heads, name="relative_attention_bias",
135
            )
136
        self.pruned_heads = set()
137

138
    def prune_heads(self, heads):
139
        raise NotImplementedError
140

141
    @staticmethod
142
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
143
        """
144
        Adapted from Mesh Tensorflow:
145
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
146

147
        Translate relative position to a bucket number for relative attention.
148
        The relative position is defined as memory_position - query_position, i.e.
149
        the distance in tokens from the attending position to the attended-to
150
        position.  If bidirectional=False, then positive relative positions are
151
        invalid.
152
        We use smaller buckets for small absolute relative_position and larger buckets
153
        for larger absolute relative_positions.  All relative positions >=max_distance
154
        map to the same bucket.  All relative positions <=-max_distance map to the
155
        same bucket.  This should allow for more graceful generalization to longer
156
        sequences than the model has been trained on.
157
        Args:
158
            relative_position: an int32 Tensor
159
            bidirectional: a boolean - whether the attention is bidirectional
160
            num_buckets: an integer
161
            max_distance: an integer
162
        Returns:
163
            a Tensor with the same shape as relative_position, containing int32
164
            values in the range [0, num_buckets)
165
        """
166
        ret = 0
167
        n = -relative_position
168
        if bidirectional:
169
            num_buckets //= 2
170
            ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
171
            n = tf.math.abs(n)
172
        else:
173
            n = tf.math.maximum(n, 0)
174
        # now n is in the range [0, inf)
175
        max_exact = num_buckets // 2
176
        is_small = tf.math.less(n, max_exact)
177
        val_if_large = max_exact + tf.dtypes.cast(
178
            tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
179
            / math.log(max_distance / max_exact)
180
            * (num_buckets - max_exact),
181
            tf.int32,
182
        )
183
        val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
184
        ret += tf.where(is_small, n, val_if_large)
185
        return ret
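
    # Editorial example (added comment, not in the original file): with the defaults
    # num_buckets=32 and max_distance=128, and bidirectional=True as used by the encoder,
    # small offsets get their own buckets while large offsets share coarse ones, e.g.
    #
    #     rel = tf.constant([[0, 1, 2, 100, 200]])
    #     buckets = TFT5Attention._relative_position_bucket(rel, bidirectional=True)
    #     # 0, 1 and 2 land in distinct buckets; 100 and 200 both fall into one coarse bucket
    #
    # Offsets at or beyond max_distance on the same side all collapse into a single bucket,
    # which is what lets the learned bias generalize to sequences longer than those seen
    # in training.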
    def compute_bias(self, qlen, klen):
188
        """ Compute binned relative position bias """
189
        context_position = tf.range(qlen)[:, None]
190
        memory_position = tf.range(klen)[None, :]
191
        relative_position = memory_position - context_position  # shape (qlen, klen)
192
        rp_bucket = self._relative_position_bucket(
193
            relative_position, bidirectional=not self.is_decoder, num_buckets=self.relative_attention_num_buckets,
194
        )
195
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
196
        values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)  # shape (1, num_heads, qlen, klen)
197
        return values
198

199
    def call(
200
        self,
201
        input,
202
        mask=None,
203
        kv=None,
204
        position_bias=None,
205
        cache=None,
206
        past_key_value_state=None,
207
        head_mask=None,
208
        query_length=None,
209
        use_cache=False,
210
        training=False,
211
        output_attentions=False,
212
    ):
213
        """
214
        Self-attention (if kv is None) or attention over source sentence (provided by kv).
215
        """
216
        # Input is (bs, qlen, dim)
217
        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
218
        # past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head)
219
        bs, qlen, dim = shape_list(input)
220

221
        if past_key_value_state is not None:
222
            assert self.is_decoder is True, "Encoder cannot cache past key value states"
223
            assert (
224
                len(past_key_value_state) == 2
225
            ), "past_key_value_state should have 2 past states: keys and values. Got {} past states".format(
226
                len(past_key_value_state)
227
            )
228
            real_qlen = qlen + shape_list(past_key_value_state[0])[2] if query_length is None else query_length
229
        else:
230
            real_qlen = qlen
231

232
        if kv is None:
233
            klen = real_qlen
234
        else:
235
            klen = shape_list(kv)[1]
236

237
        def shape(x):
238
            """  projection """
239
            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, self.d_kv)), perm=(0, 2, 1, 3))
240

241
        def unshape(x):
242
            """  compute context """
243
            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.inner_dim))
244

245
        q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)
246

247
        if kv is None:
248
            k = shape(self.k(input))  # (bs, n_heads, qlen, dim_per_head)
249
            v = shape(self.v(input))  # (bs, n_heads, qlen, dim_per_head)
250
        elif past_key_value_state is None:
251
            k = v = kv
252
            k = shape(self.k(k))  # (bs, n_heads, qlen, dim_per_head)
253
            v = shape(self.v(v))  # (bs, n_heads, qlen, dim_per_head)
254

255
        if past_key_value_state is not None:
256
            if kv is None:
257
                k_, v_ = past_key_value_state
258
                k = tf.concat([k_, k], axis=2)  # (bs, n_heads, klen, dim_per_head)
259
                v = tf.concat([v_, v], axis=2)  # (bs, n_heads, klen, dim_per_head)
260
            else:
261
                k, v = past_key_value_state
262

263
        # to cope with keras serialization
264
        if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True:
265
            present_key_value_state = ((k, v),)
266
        else:
267
            present_key_value_state = (None,)
268

269
        scores = tf.einsum("bnqd,bnkd->bnqk", q, k)  # (bs, n_heads, qlen, klen)
270

271
        if position_bias is None:
272
            if not self.has_relative_attention_bias:
273
                raise ValueError("No position_bias provided and no weights to compute position_bias")
274
            position_bias = self.compute_bias(real_qlen, klen)
275

276
            # if key and values are already calculated
277
            # we want only the last query position bias
278
            if past_key_value_state is not None:
279
                position_bias = position_bias[:, :, -1:, :]
280

281
            if mask is not None:
282
                position_bias = position_bias + mask  # (bs, n_heads, qlen, klen)
283

284
        scores += position_bias
285
        weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
286
        weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)
287

288
        # Mask heads if we want to
289
        if head_mask is not None:
290
            weights = weights * head_mask
291

292
        context = tf.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
293
        context = unshape(context)  # (bs, qlen, dim)
294

295
        context = self.o(context)
296

297
        outputs = (context,) + present_key_value_state
298

299
        if cast_bool_to_primitive(output_attentions, True) is True:
300
            outputs = outputs + (weights,)
301
        if self.has_relative_attention_bias:
302
            outputs = outputs + (position_bias,)
303
        return outputs
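

# Editorial note (added comment): during incremental decoding the cached key/value states
# returned above grow by one step per call. After t generated tokens the decoder
# self-attention cache entries have shape (batch_size, n_heads, t, d_kv); the next call
# passes a single new token (qlen == 1), whose projected key/value is concatenated onto
# the cache, so klen becomes t + 1 while qlen stays 1. Cross-attention caches keep the
# fixed encoder length instead of growing.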


class TFT5LayerSelfAttention(tf.keras.layers.Layer):
307
    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
308
        super().__init__(**kwargs)
309
        self.SelfAttention = TFT5Attention(
310
            config, has_relative_attention_bias=has_relative_attention_bias, name="SelfAttention",
311
        )
312
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
313
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
314

315
    def call(
316
        self,
317
        hidden_states,
318
        attention_mask=None,
319
        position_bias=None,
320
        head_mask=None,
321
        past_key_value_state=None,
322
        use_cache=False,
323
        output_attentions=False,
324
        training=False,
325
    ):
326
        norm_x = self.layer_norm(hidden_states)
327
        attention_output = self.SelfAttention(
328
            norm_x,
329
            mask=attention_mask,
330
            position_bias=position_bias,
331
            head_mask=head_mask,
332
            past_key_value_state=past_key_value_state,
333
            use_cache=use_cache,
334
            output_attentions=output_attentions,
335
            training=training,
336
        )
337
        y = attention_output[0]
338
        layer_output = hidden_states + self.dropout(y, training=training)
339
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
340
        return outputs
341

342

343
class TFT5LayerCrossAttention(tf.keras.layers.Layer):
344
    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
345
        super().__init__(**kwargs)
346
        self.EncDecAttention = TFT5Attention(
347
            config, has_relative_attention_bias=has_relative_attention_bias, name="EncDecAttention",
348
        )
349
        self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
350
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
351

352
    def call(
353
        self,
354
        hidden_states,
355
        kv,
356
        attention_mask=None,
357
        position_bias=None,
358
        head_mask=None,
359
        past_key_value_state=None,
360
        query_length=None,
361
        use_cache=False,
362
        output_attentions=False,
363
        training=False,
364
    ):
365
        norm_x = self.layer_norm(hidden_states)
366
        attention_output = self.EncDecAttention(
367
            norm_x,
368
            mask=attention_mask,
369
            kv=kv,
370
            position_bias=position_bias,
371
            head_mask=head_mask,
372
            past_key_value_state=past_key_value_state,
373
            query_length=query_length,
374
            use_cache=use_cache,
375
            output_attentions=output_attentions,
376
            training=training,
377
        )
378
        y = attention_output[0]
379
        layer_output = hidden_states + self.dropout(y, training=training)
380
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
381
        return outputs
382

383

384
class TFT5Block(tf.keras.layers.Layer):
385
    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
386
        super().__init__(**kwargs)
387
        self.is_decoder = config.is_decoder
388
        self.layer = []
389
        self.layer.append(
390
            TFT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._0",)
391
        )
392
        if self.is_decoder:
393
            self.layer.append(
394
                TFT5LayerCrossAttention(
395
                    config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._1",
396
                )
397
            )
398

399
        self.layer.append(TFT5LayerFF(config, name="layer_._{}".format(len(self.layer))))
400

401
    def call(
402
        self,
403
        hidden_states,
404
        attention_mask=None,
405
        position_bias=None,
406
        encoder_hidden_states=None,
407
        encoder_attention_mask=None,
408
        encoder_decoder_position_bias=None,
409
        head_mask=None,
410
        past_key_value_state=None,
411
        use_cache=False,
412
        output_attentions=False,
413
        training=False,
414
    ):
415

416
        if past_key_value_state is not None:
417
            assert self.is_decoder, "Only decoder can use `past_key_value_states`"
418
            expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4
419

420
            error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
421
                expected_num_past_key_value_states,
422
                "2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "",
423
                len(past_key_value_state),
424
            )
425
            assert len(past_key_value_state) == expected_num_past_key_value_states, error_message
426

427
            self_attn_past_key_value_state = past_key_value_state[:2]
428
            cross_attn_past_key_value_state = past_key_value_state[2:]
429
        else:
430
            self_attn_past_key_value_state, cross_attn_past_key_value_state = None, None
431

432
        self_attention_outputs = self.layer[0](
433
            hidden_states,
434
            attention_mask=attention_mask,
435
            position_bias=position_bias,
436
            head_mask=head_mask,
437
            past_key_value_state=self_attn_past_key_value_state,
438
            use_cache=use_cache,
439
            output_attentions=output_attentions,
440
            training=training,
441
        )
442
        hidden_states, present_key_value_state = self_attention_outputs[:2]
443
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights
444

445
        if self.is_decoder and encoder_hidden_states is not None:
446
            # the actual query length is unknown for cross attention
447
            # if using past key value states. Need to inject it here
448
            if present_key_value_state is not None:
449
                query_length = shape_list(present_key_value_state[0])[2]
450
            else:
451
                query_length = None
452

453
            cross_attention_outputs = self.layer[1](
454
                hidden_states,
455
                kv=encoder_hidden_states,
456
                attention_mask=encoder_attention_mask,
457
                position_bias=encoder_decoder_position_bias,
458
                head_mask=head_mask,
459
                past_key_value_state=cross_attn_past_key_value_state,
460
                query_length=query_length,
461
                use_cache=use_cache,
462
                output_attentions=output_attentions,
463
                training=training,
464
            )
465
            hidden_states = cross_attention_outputs[0]
466
            # Combine self attn and cross attn key value states
467
            if present_key_value_state is not None:
468
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]
469

470
            # Keep cross-attention outputs and relative position weights
471
            attention_outputs = attention_outputs + cross_attention_outputs[2:]
472

473
        # Apply Feed Forward layer
474
        hidden_states = self.layer[-1](hidden_states, training=training)
475
        outputs = (hidden_states,)
476

477
        # Add attentions if we output them
478
        outputs = outputs + (present_key_value_state,) + attention_outputs
479
        return outputs  # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
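

# Editorial sketch (not part of the original file): for the first block (the only one
# built with has_relative_attention_bias=True) called with output_attentions=True, the
# tuple returned above unpacks as
#
#     hidden_states, present_key_value_state = layer_outputs[:2]
#     self_attn_weights = layer_outputs[2]
#     self_attn_position_bias = layer_outputs[3]
#     cross_attn_weights = layer_outputs[4]          # decoder blocks only
#     cross_attn_position_bias = layer_outputs[5]    # decoder blocks only
#
# TFT5MainLayer.call below indexes the outputs of block 0 in exactly this way in order
# to re-use the position biases across the remaining blocks.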


class _NoLayerEmbedTokens:
483
    """
484
     This class wraps the TFSharedEmbeddings layer in a plain (non-Keras-layer) Python
     class to avoid problems with weight restoring. It also makes sure that the layer is
     called from the correct scope so that the correct weights are saved and restored.
487
    """
488

489
    def __init__(self, layer, abs_scope_name=None):
490
        self._layer = layer
491
        self._abs_scope_name = abs_scope_name
492

493
    def call(self, inputs, mode="embedding"):
494
        if self._abs_scope_name is None:
495
            return self._layer.call(inputs, mode)
496

497
        # if an abs scope name is given to the embedding variable, call variable from absolute scope
498
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
499
            with tf.name_scope(abs_scope_name.original_name_scope):
500
                return self._layer.call(inputs, mode)
501

502
    def __call__(self, inputs, mode="embedding"):
503
        if self._abs_scope_name is None:
504
            return self._layer(inputs, mode)
505

506
        # if an abs scope name is given to the embedding variable, call variable from absolute scope
507
        with tf.compat.v1.variable_scope(self._abs_scope_name, auxiliary_name_scope=False) as abs_scope_name:
508
            with tf.name_scope(abs_scope_name.original_name_scope):
509
                return self._layer(inputs, mode)
510

511

512
####################################################
513
# The full model without a specific pretrained or finetuning head is
514
# provided as a tf.keras.layers.Layer usually called "TFT5MainLayer"
515
####################################################
516
@keras_serializable
517
class TFT5MainLayer(tf.keras.layers.Layer):
518
    config_class = T5Config
519

520
    def __init__(self, config, embed_tokens=None, **kwargs):
521
        super().__init__(**kwargs)
522
        self.output_hidden_states = config.output_hidden_states
523
        self.output_attentions = config.output_attentions
524
        self.use_cache = config.use_cache
525

526
        self.embed_tokens = embed_tokens
527
        self.is_decoder = config.is_decoder
528

529
        self.config = config
530
        self.num_hidden_layers = config.num_layers
531

532
        self.block = [
533
            TFT5Block(config, has_relative_attention_bias=bool(i == 0), name="block_._{}".format(i),)
534
            for i in range(config.num_layers)
535
        ]
536
        self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
537
        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
538

539
    def get_input_embeddings(self):
540
        return self.embed_tokens
541

542
    def get_output_embeddings(self):
543
        return self.embed_tokens
544

545
    def set_embed_tokens(self, embed_tokens):
546
        self.embed_tokens = embed_tokens
547

548
    def _resize_token_embeddings(self, new_num_tokens):
549
        raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models
553

554
    def call(
555
        self,
556
        inputs,
557
        attention_mask=None,
558
        encoder_hidden_states=None,
559
        encoder_attention_mask=None,
560
        inputs_embeds=None,
561
        head_mask=None,
562
        past_key_value_states=None,
563
        use_cache=None,
564
        output_attentions=None,
565
        output_hidden_states=None,
566
        training=False,
567
    ):
568
        if isinstance(inputs, (tuple, list)):
569
            input_ids = inputs[0]
570
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
571
            encoder_hidden_states = inputs[2] if len(inputs) > 2 else encoder_hidden_states
572
            encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask
573
            inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
574
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
575
            past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
576
            use_cache = inputs[7] if len(inputs) > 7 else use_cache
577
            output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
            output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
579
            assert len(inputs) <= 10, "Too many inputs."
580
        elif isinstance(inputs, (dict, BatchEncoding)):
581
            input_ids = inputs.get("input_ids")
582
            attention_mask = inputs.get("attention_mask", attention_mask)
583
            encoder_hidden_states = inputs.get("encoder_hidden_states", encoder_hidden_states)
584
            encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
585
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
586
            head_mask = inputs.get("head_mask", head_mask)
587
            past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
588
            use_cache = inputs.get("use_cache", use_cache)
589
            output_attentions = inputs.get("output_attentions", output_attentions)
590
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
591
            assert len(inputs) <= 10, "Too many inputs."
592
        else:
593
            input_ids = inputs
594

595
        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
596
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
597
        use_cache = use_cache if use_cache is not None else self.use_cache
598

599
        if input_ids is not None and inputs_embeds is not None:
600
            raise ValueError("You cannot specify both inputs and inputs_embeds at the same time")
601
        elif input_ids is not None:
602
            input_shape = shape_list(input_ids)
603
            input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
604
        elif inputs_embeds is not None:
605
            input_shape = shape_list(inputs_embeds)[:-1]
606
        else:
607
            raise ValueError("You have to specify either inputs or inputs_embeds")
608

609
        if inputs_embeds is None:
610
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
611
            inputs_embeds = self.embed_tokens(input_ids)
612

613
        batch_size, seq_length = input_shape
614

615
        if past_key_value_states is not None:
616
            assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_states".format(
617
                input_shape, (batch_size, 1)
618
            )
619
            # required mask seq length can be calculated via length of past
620
            # key value states and seq_length = 1 for the last token
621
            mask_seq_length = shape_list(past_key_value_states[0][0])[2] + seq_length
622
        else:
623
            mask_seq_length = seq_length
624

625
        if attention_mask is None:
626
            attention_mask = tf.fill((batch_size, mask_seq_length), 1)
627
        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
628
            encoder_seq_length = shape_list(encoder_hidden_states)[1]
629
            encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)
630

631
        # initialize past_key_value_states with `None` if past does not exist
632
        if past_key_value_states is None:
633
            past_key_value_states = [None] * len(self.block)
634

635
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
636
        # ourselves in which case we just need to make it broadcastable to all heads.
637
        attention_mask = tf.cast(attention_mask, dtype=tf.float32)
638
        num_dims_attention_mask = len(shape_list(attention_mask))
639
        if num_dims_attention_mask == 3:
640
            extended_attention_mask = attention_mask[:, None, :, :]
641
        elif num_dims_attention_mask == 2:
642
            # Provided a padding mask of dimensions [batch_size, mask_seq_length]
643
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
644
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
645
            if self.is_decoder:
646
                seq_ids = tf.range(mask_seq_length)
647
                causal_mask = tf.less_equal(
648
                    tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None],
649
                )
650
                causal_mask = tf.cast(causal_mask, dtype=tf.float32)
651
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
652
                if past_key_value_states[0] is not None:
653
                    extended_attention_mask = extended_attention_mask[:, :, -1:, :]
654
            else:
655
                extended_attention_mask = attention_mask[:, None, None, :]
656

657
        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
658
        # masked positions, this operation will create a tensor which is 0.0 for
659
        # positions we want to attend and -10000.0 for masked positions.
660
        # Since we are adding it to the raw scores before the softmax, this is
661
        # effectively the same as removing these entirely.
662

663
        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
664
        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
665
        # extended_attention_mask = tf.math.equal(extended_attention_mask,
666
        #                                         tf.transpose(extended_attention_mask, perm=(-1, -2)))
667

668
        extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
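        # Editorial example (added comment): a padding mask of [1, 1, 1, 0] becomes
        # [0.0, 0.0, 0.0, -1e9] here, so adding it to the raw attention scores drives the
        # softmax probability of the padded position to (numerically) zero.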
        if self.is_decoder and encoder_attention_mask is not None:
671
            # If a 2D or 3D attention mask is provided for the cross-attention,
            # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
674
            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=tf.float32)
675
            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
676
            if num_dims_encoder_attention_mask == 3:
677
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
678
            if num_dims_encoder_attention_mask == 2:
679
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
680

681
            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
682
            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
683
            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
684
            #                                         tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
685

686
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
687
        else:
688
            encoder_extended_attention_mask = None
689

690
        assert head_mask is None, "Head mask not supported"
691
        head_mask = [None] * self.num_hidden_layers
692

693
        present_key_value_states = ()
694
        all_hidden_states = ()
695
        all_attentions = ()
696
        position_bias = None
697
        encoder_decoder_position_bias = None
698

699
        hidden_states = self.dropout(inputs_embeds, training=training)
700

701
        for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
702
            if cast_bool_to_primitive(output_hidden_states) is True:
703
                all_hidden_states = all_hidden_states + (hidden_states,)
704

705
            layer_outputs = layer_module(
706
                hidden_states,
707
                attention_mask=extended_attention_mask,
708
                position_bias=position_bias,
709
                encoder_hidden_states=encoder_hidden_states,
710
                encoder_attention_mask=encoder_extended_attention_mask,
711
                encoder_decoder_position_bias=encoder_decoder_position_bias,
712
                head_mask=head_mask[i],
713
                past_key_value_state=past_key_value_state,
714
                use_cache=use_cache,
715
                output_attentions=output_attentions,
716
                training=training,
717
            )
718
            # layer_outputs is a tuple with:
719
            # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
720
            hidden_states, present_key_value_state = layer_outputs[:2]
721
            if i == 0:
722
                # We share the position biases between the layers - the first layer stores them
723
                # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
724
                position_bias = layer_outputs[3 if output_attentions else 2]
725
                if self.is_decoder and encoder_hidden_states is not None:
726
                    encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
727
            # append next layer key value states
728
            present_key_value_states = present_key_value_states + (present_key_value_state,)
729

730
            if cast_bool_to_primitive(output_attentions) is True:
731
                all_attentions = all_attentions + (layer_outputs[2],)
732

733
        hidden_states = self.final_layer_norm(hidden_states)
734
        hidden_states = self.dropout(hidden_states, training=training)
735

736
        # Add last layer
737
        if cast_bool_to_primitive(output_hidden_states) is True:
738
            all_hidden_states = all_hidden_states + (hidden_states,)
739

740
        outputs = (hidden_states,)
741
        # need to check whether this is a decoder here as well, for special cases when using keras compile
742
        if cast_bool_to_primitive(use_cache, self.use_cache) is True and self.is_decoder:
743
            outputs = outputs + (present_key_value_states,)
744
        if cast_bool_to_primitive(output_hidden_states) is True:
745
            outputs = outputs + (all_hidden_states,)
746
        if cast_bool_to_primitive(output_attentions) is True:
747
            outputs = outputs + (all_attentions,)
748
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
749

750

751
####################################################
752
# TFT5PreTrainedModel is a sub-class of tf.keras.Model
753
# which take care of loading and saving pretrained weights
754
# and various common utilities.
755
# Here you just need to specify a few (self-explanatory)
756
# pointers for your model.
757
####################################################
758
class TFT5PreTrainedModel(TFPreTrainedModel):
759
    """ An abstract class to handle weights initialization and
760
        a simple interface for downloading and loading pretrained models.
761
    """
762

763
    config_class = T5Config
764
    base_model_prefix = "transformer"
765

766
    @property
767
    def dummy_inputs(self):
768
        inputs = tf.constant(DUMMY_INPUTS)
769
        input_mask = tf.constant(DUMMY_MASK)
770
        dummy_inputs = {
771
            "input_ids": inputs,
772
            "decoder_input_ids": inputs,
773
            "decoder_attention_mask": input_mask,
774
        }
775
        return dummy_inputs
776

777
    def _shift_right(self, input_ids):
778
        decoder_start_token_id = self.config.decoder_start_token_id
779
        pad_token_id = self.config.pad_token_id
780

781
        assert (
782
            decoder_start_token_id is not None
783
        ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
784

785
        # shift inputs to the right
786
        shifted_input_ids = tf.cast(input_ids, tf.int32)
        shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
788
        start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
789
        shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
790

791
        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
792
        # replace possible -100 values in labels by `pad_token_id`
793
        shifted_input_ids = tf.where(
794
            shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
795
        )
796

797
        assert tf.math.reduce_any(
798
            shifted_input_ids >= 0
799
        ).numpy(), "Verify that `labels` has only positive values and -100"
800

801
        return shifted_input_ids
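

# Editorial sketch (not part of the original file): _shift_right above turns labels into
# decoder inputs by prepending the decoder start token and dropping the last label, e.g.
#
#     labels            = [[42, 7, 19, 1]]
#     decoder_input_ids = [[<start>, 42, 7, 19]]
#
# with any -100 ignore-index values in the result replaced by pad_token_id.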


T5_START_DOCSTRING = r"""
805
    The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
806
    <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
807
    Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
808
    It's an encoder decoder transformer pre-trained in a text-to-text denoising generative setting.
809

810
    This model is a `tf.keras.Model <https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model>`__
811
    sub-class. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to
812
    general usage and behavior.
813

814
    Note on the model inputs:
815
        TF 2.0 models accept two formats as inputs:

            - having all inputs as keyword arguments (like PyTorch models), or
            - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the `tf.keras.Model.fit()` method, which currently requires having all the tensors in the first argument of the model call function: `model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument (see the commented usage sketch after this docstring):

        - a single Tensor with inputs only and nothing else: `model(inputs)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
            `model([inputs, attention_mask])` or `model([inputs, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated with the input names given in the docstring:
            `model({'inputs': inputs, 'token_type_ids': token_type_ids})`
829

830
    Parameters:
831
        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
832
            Initializing with a config file does not load the weights associated with the model, only the configuration.
833
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
834
"""
T5_INPUTS_DOCSTRING = r"""
837
    Args:
838
        inputs are usually used as a `dict` (see T5 description above for more information) containing all the following.
839

840
        inputs (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
841
            Indices of input sequence tokens in the vocabulary.
842
            T5 is a model with relative position embeddings so you should be able to pad the inputs on
843
            the right or the left.
844
            Indices can be obtained using :class:`transformers.T5Tokenizer`.
845
            To know more on how to prepare :obj:`inputs` for pre-training take a look at
846
            `T5 Training <./t5.html#training>`__.
847
            See :func:`transformers.PreTrainedTokenizer.encode` and
848
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
849
        decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
850
            Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
851
            If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
852
        attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
853
            Mask to avoid performing attention on padding token indices.
854
            Mask values selected in ``[0, 1]``:
855
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
856
        encoder_outputs (:obj:`tuple(tuple(tf.FloatTensor)`, `optional`, defaults to :obj:`None`):
857
            Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
858
            `last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
859
            Used in the cross-attention of the decoder.
860
        decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
861
            Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
862
        decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
863
            Contains pre-computed key and value hidden-states of the attention blocks.
864
            Can be used to speed up decoding.
865
            If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
866
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
867
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
868
            If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
869
        inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
870
            Optionally, instead of passing :obj:`inputs` you can choose to directly pass an embedded representation.
871
            This is useful if you want more control over how to convert `inputs` indices into associated vectors
872
            than the model's internal embedding lookup matrix.
873
        decoder_inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
874
            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
875
            This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
876
            than the model's internal embedding lookup matrix.
877
            To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
878
            `T5 Training <./t5.html#training>`__.
879
        head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
880
            Mask to nullify selected heads of the self-attention modules.
881
            Mask values selected in ``[0, 1]``:
882
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
883
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
884
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
885
"""
886

887

888
@add_start_docstrings(
889
    "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
890
    T5_START_DOCSTRING,
891
)
892
class TFT5Model(TFT5PreTrainedModel):
893
    def __init__(self, config, *inputs, **kwargs):
894
        super().__init__(config, *inputs, **kwargs)
895
        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
896

897
        # retrieve correct absolute scope for embed token wrapper
898
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
899
            pass
900

901
        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
902

903
        encoder_config = copy.deepcopy(config)
904
        encoder_config.use_cache = False
905
        self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
906

907
        decoder_config = copy.deepcopy(config)
908
        decoder_config.is_decoder = True
909
        self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
910

911
    def get_input_embeddings(self):
912
        return self.shared
913

914
    def get_output_embeddings(self):
915
        return self.shared
916

917
    def set_input_embeddings(self, new_embeddings):
918
        self.shared.weight = new_embeddings
919
        self.shared.vocab_size = self.shared.weight.shape[0]
920
        # retrieve correct absolute scope for embed token wrapper
921
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
922
            pass
923
        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
924
        self.encoder.set_embed_tokens(embed_tokens)
925
        self.decoder.set_embed_tokens(embed_tokens)
926

927
    def get_encoder(self):
928
        return self.encoder
929

930
    def get_decoder(self):
931
        return self.decoder
932

933
    @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
934
    def call(
935
        self,
936
        inputs,
937
        attention_mask=None,
938
        encoder_outputs=None,
939
        inputs_embeds=None,
940
        head_mask=None,
941
        decoder_past_key_value_states=None,
942
        decoder_input_ids=None,
943
        decoder_attention_mask=None,
944
        decoder_inputs_embeds=None,
945
        use_cache=None,
946
        output_attentions=None,
947
        output_hidden_states=None,
948
        training=False,
949
    ):
950
        r"""
951
    Returns:
952
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
953
        last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
954
            Sequence of hidden-states at the output of the last layer of the model.
955
            If `decoder_past_key_value_states` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
956
        decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
957
            Contains pre-computed key and value hidden-states of the attention blocks.
958
            Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
959
            Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
960
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
961
            tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
962
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
963

964
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
965
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
966
            tuple of :obj:`tf.Tensor` (one for each layer) of shape
967
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
968

969
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
970
            heads.
971

972
    Examples::
973

974
        >>> from transformers import T5Tokenizer, TFT5Model
975

976
        >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
977
        >>> model = TFT5Model.from_pretrained('t5-small')
978
        >>> inputs = tokenizer.encode("Hello, my dog is cute", return_tensors="tf")  # Batch size 1
979
        >>> outputs = model(inputs, decoder_input_ids=inputs)
980
        >>> last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
981

982
        """
983
        if isinstance(inputs, (tuple, list)):
984
            input_ids = inputs[0]
985
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
986
            encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
987
            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
988
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
989
            decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
990
            decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
991
            decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
992
            decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
993
            use_cache = inputs[9] if len(inputs) > 9 else use_cache
994
            output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
995
            output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
996
            assert len(inputs) <= 12, "Too many inputs."
997
        elif isinstance(inputs, (dict, BatchEncoding)):
998
            if "inputs" in inputs:
999
                warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
1000
                input_ids = inputs.get("inputs")
1001
            input_ids = inputs.get("input_ids")
1002
            attention_mask = inputs.get("attention_mask", attention_mask)
1003
            encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
1004
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
1005
            head_mask = inputs.get("head_mask", head_mask)
1006
            decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
1007
            decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
1008
            decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
1009
            decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
1010
            use_cache = inputs.get("use_cache", use_cache)
1011
            output_attentions = inputs.get("output_attentions", output_attentions)
1012
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
1013
            assert len(inputs) <= 12, "Too many inputs."
1014
        else:
1015
            input_ids = inputs
1016

1017
        use_cache = use_cache if use_cache is not None else self.config.use_cache
1018

1019
        # Encode if needed (training, first prediction pass)
1020
        if encoder_outputs is None:
1021
            encoder_outputs = self.encoder(
1022
                [
1023
                    input_ids,
1024
                    attention_mask,
1025
                    None,
1026
                    None,
1027
                    inputs_embeds,
1028
                    head_mask,
1029
                    None,
1030
                    False,
1031
                    output_attentions,
1032
                    output_hidden_states,
1033
                ],
1034
                training=training,
1035
            )
1036

1037
        hidden_states = encoder_outputs[0]
1038

1039
        # If decoding with past key value states, only the last tokens
1040
        # should be given as an input
1041
        if decoder_past_key_value_states is not None:
1042
            if decoder_input_ids is not None:
1043
                decoder_input_ids = decoder_input_ids[:, -1:]
1044
            if decoder_inputs_embeds is not None:
1045
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
1046

1047
        # Decode
1048
        decoder_outputs = self.decoder(
1049
            [
1050
                decoder_input_ids,
1051
                decoder_attention_mask,
1052
                hidden_states,
1053
                attention_mask,
1054
                decoder_inputs_embeds,
1055
                head_mask,
1056
                decoder_past_key_value_states,
1057
                use_cache,
1058
                output_attentions,
1059
                output_hidden_states,
1060
            ],
1061
            training=training,
1062
        )
1063

1064
        if cast_bool_to_primitive(use_cache, self.config.use_cache) is True:
1065
            past = ((encoder_outputs, decoder_outputs[1]),)
1066
            decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
1067

1068
        return decoder_outputs + encoder_outputs
1069

1070

1071
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
1072
class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
1073
    def __init__(self, config, *inputs, **kwargs):
1074
        super().__init__(config, *inputs, **kwargs)
1075
        self.model_dim = config.d_model
1076

1077
        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
1078

1079
        # retrieve correct absolute scope for embed token wrapper
1080
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
1081
            pass
1082

1083
        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
1084

1085
        encoder_config = copy.deepcopy(config)
1086
        encoder_config.use_cache = False
1087
        self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder")
1088

1089
        decoder_config = copy.deepcopy(config)
1090
        decoder_config.is_decoder = True
1091
        self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
1092

1093
    def get_input_embeddings(self):
        return self.shared

    def get_output_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared.weight = new_embeddings
        # retrieve correct absolute scope for embed token wrapper
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
            pass
        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
        self.encoder.set_embed_tokens(embed_tokens)
        self.decoder.set_embed_tokens(embed_tokens)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
    def call(
        self,
        inputs,
        attention_mask=None,
        encoder_outputs=None,
        inputs_embeds=None,
        head_mask=None,
        decoder_past_key_value_states=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the cross entropy language modeling loss.
            Indices should be in ``[0, ..., config.vocab_size - 1]``.

    Returns:
        :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
        loss (:obj:`tf.Tensor`, `optional`, returned when ``labels`` is provided):
            Language modeling loss.
        prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
            Contains pre-computed key and value hidden-states of the attention blocks.
            Can be used to speed up sequential decoding (see the `decoder_past_key_value_states` input).
            Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        >>> from transformers import T5Tokenizer, TFT5ForConditionalGeneration

        >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
        >>> model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
        >>> inputs = tokenizer.encode("Hello, my dog is cute", return_tensors="tf")  # Batch size 1
        >>> outputs = model(inputs, decoder_input_ids=inputs)
        >>> prediction_scores = outputs[0]

        >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
        >>> model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
        >>> inputs = tokenizer.encode("summarize: Hello, my dog is cute", return_tensors="tf")  # Batch size 1
        >>> result = model.generate(inputs)
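
        Computing the language modeling loss by passing ``labels`` (the decoder inputs are
        then created internally with ``_shift_right(labels)`` and the loss is returned
        first) -- a minimal sketch with arbitrary example strings:

        >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
        >>> model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
        >>> inputs = tokenizer.encode("translate English to German: The house is wonderful.", return_tensors="tf")
        >>> labels = tokenizer.encode("Das Haus ist wunderbar.", return_tensors="tf")
        >>> outputs = model(inputs, labels=labels)  # (loss, lm_logits, ...) when labels are given
        >>> loss, lm_logits = outputs[0], outputs[1]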

        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
            decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
            decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
            decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
            use_cache = inputs[9] if len(inputs) > 9 else use_cache
            output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
            output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
            labels = inputs[12] if len(inputs) > 12 else labels
            assert len(inputs) <= 13, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            if "inputs" in inputs:
                warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
            input_ids = inputs.get("input_ids", inputs.get("inputs"))
            attention_mask = inputs.get("attention_mask", attention_mask)
            encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            head_mask = inputs.get("head_mask", head_mask)
            decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
            decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
            decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
            decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
            use_cache = inputs.get("use_cache", use_cache)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            labels = inputs.get("labels", labels)
            assert len(inputs) <= 13, "Too many inputs."
        else:
            input_ids = inputs

        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs into embeddings if needed
            encoder_outputs = self.encoder(
                [
                    input_ids,
                    attention_mask,
                    None,
                    None,
                    inputs_embeds,
                    head_mask,
                    None,
                    False,
                    output_attentions,
                    output_hidden_states,
                ],
                training=training,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs by shifting the lm labels to the right
            decoder_input_ids = self._shift_right(labels)
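            # (`_shift_right` shifts the labels one position to the right and prepends
            # `config.decoder_start_token_id`, producing standard teacher-forcing inputs.)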

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if decoder_past_key_value_states is not None:
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            [
                decoder_input_ids,
                decoder_attention_mask,
                hidden_states,
                attention_mask,
                decoder_inputs_embeds,
                head_mask,
                decoder_past_key_value_states,
                use_cache,
                output_attentions,
                output_hidden_states,
            ],
            training=training,
        )

        # insert the decoder past at the right place to speed up decoding
        if cast_bool_to_primitive(use_cache, self.config.use_cache) is True:
            past = ((encoder_outputs, decoder_outputs[1]),)
            decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
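
        # T5 ties the LM head to the shared input embeddings: rescale the decoder output
        # by d_model ** -0.5 and project it onto the vocabulary with the embedding matrix.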
        sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
        embed_tokens = self.get_output_embeddings()
        logits = embed_tokens(sequence_output, mode="linear")
        decoder_outputs = (logits,) + decoder_outputs[1:]

        if labels is not None:
            loss = self.compute_loss(labels, logits)
            decoder_outputs = (loss,) + decoder_outputs

        return decoder_outputs + encoder_outputs

    def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs):
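        # Called by `generate()` before each forward pass to turn the running state
        # (`past`, attention mask, current decoder tokens) into keyword arguments for `call`.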
        assert past is not None, "past has to be defined for encoder_outputs"

        # first step: `past` only holds the encoder outputs, there is no decoder cache yet
        if len(past) < 2:
            encoder_outputs, decoder_past_key_value_states = past, None
        else:
            encoder_outputs, decoder_past_key_value_states = past[0], past[1]

        return {
            "inputs": None,  # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
            "decoder_input_ids": inputs,  # inputs are the decoder_input_ids
            "decoder_past_key_value_states": decoder_past_key_value_states,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }

    def _reorder_cache(self, past, beam_idx):
        # if the decoder past is not included in the output,
        # fast decoding is disabled and there is no need to reorder
        if len(past) < 2:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past

        decoder_past = past[1]
        past = (past[0],)
        reordered_decoder_past = ()

        for layer_past_states in decoder_past:
            # gather the correct batch indices per beam; the batch dimension is the first
            # dimension of each cached key / value tensor
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # reorder each of the four key / value states (self-attention and
                # cross-attention keys and values)
                reordered_layer_past_states = reordered_layer_past_states + (tf.gather(layer_past_state, beam_idx),)

            assert shape_list(reordered_layer_past_states[0]) == shape_list(layer_past_states[0])
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return past + (reordered_decoder_past,)