transformers

Форк
0
/
test_modeling_tf_t5.py 
1040 строк · 55.0 Кб
1
# coding=utf-8
2
# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
from __future__ import annotations
17

18
import unittest
19

20
from transformers import T5Config, is_tf_available
21
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
22
from transformers.utils import cached_property
23

24
from ...test_configuration_common import ConfigTester
25
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
26
from ...test_pipeline_mixin import PipelineTesterMixin
27

28

29
if is_tf_available():
30
    import tensorflow as tf
31

32
    from transformers import ByT5Tokenizer, T5Tokenizer, TFT5EncoderModel, TFT5ForConditionalGeneration, TFT5Model
33

34

35
class TFT5ModelTester:
    """Builds a small ``T5Config`` with random inputs and runs checks on the TF T5 models.

    Every assertion is delegated to ``parent``, the test case that owns this tester.
    """

    def __init__(
        self,
        parent,
    ):
        self.parent = parent
        self.batch_size = 13
        self.seq_length = 7
        self.is_training = True
        self.use_input_mask = True
        self.use_labels = True
        self.vocab_size = 99
        self.n_positions = 14
        self.hidden_size = 32
        self.num_hidden_layers = 2
        self.num_attention_heads = 4
        self.d_ff = 37
        self.relative_attention_num_buckets = 8
        self.dropout_rate = 0.1
        self.initializer_factor = 0.002
        self.eos_token_id = 1
        self.pad_token_id = 0
        self.scope = None

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, input_mask, token_labels)`` built from the tester's settings.

        ``input_mask`` is ``None`` when ``use_input_mask`` is false and ``token_labels`` is ``None``
        when ``use_labels`` is false.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_labels = None
        if self.use_labels:
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = T5Config(
            vocab_size=self.vocab_size,
            n_positions=self.n_positions,
            d_model=self.hidden_size,
            d_ff=self.d_ff,
            d_kv=self.hidden_size // self.num_attention_heads,
            num_layers=self.num_hidden_layers,
            num_heads=self.num_attention_heads,
            relative_attention_num_buckets=self.relative_attention_num_buckets,
            dropout_rate=self.dropout_rate,
            initializer_factor=self.initializer_factor,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.pad_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.pad_token_id,
        )

        return (config, input_ids, input_mask, token_labels)

    def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
        """Call ``TFT5Model`` with dict and keyword inputs, then check output shapes and cache layout."""
        model = TFT5Model(config=config)
        inputs = {
            "input_ids": input_ids,
            "decoder_input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        # Exercise the dict calling convention; only the keyword-call result is inspected below.
        model(inputs)

        result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids)
        decoder_output = result.last_hidden_state
        decoder_past = result.past_key_values
        encoder_output = result.encoder_last_hidden_state
        self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
        self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
        # There should be one cache entry per decoder layer
        self.parent.assertEqual(len(decoder_past), config.num_layers)
        # Each entry should hold a self attn key, a self attn value, a cross attn key and a cross attn value
        self.parent.assertEqual(len(decoder_past[0]), 4)

    def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
        """Check that ``TFT5ForConditionalGeneration`` returns logits of shape (batch, seq, vocab)."""
        model = TFT5ForConditionalGeneration(config=config)
        inputs_dict = {
            "input_ids": input_ids,
            "decoder_input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }

        result = model(inputs_dict)

        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask):
        """Check that decoding one extra token from the cache matches a full forward pass.

        Only ``config`` and ``input_ids`` are used; the remaining parameters are ignored.
        """
        model = TFT5Model(config=config).get_decoder()

        input_ids = input_ids[:1, :]
        # NOTE: this permanently overwrites the tester's batch size with 1.
        self.batch_size = 1

        # first forward pass
        outputs = model(input_ids, use_cache=True)

        outputs_use_cache_conf = model(input_ids)
        outputs_no_past = model(input_ids, use_cache=False)

        # caching adds exactly one output field compared to `use_cache=False`
        self.parent.assertEqual(len(outputs), len(outputs_use_cache_conf))
        self.parent.assertEqual(len(outputs), len(outputs_no_past) + 1)

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # append to next input_ids
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)

        output_from_no_past = model(next_input_ids)[0]
        output_from_past = model(next_tokens, past_key_values=outputs.past_key_values)[0]

        # select random slice; `.numpy().item()` extracts the scalar explicitly from the size-1 tensor
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
        output_from_past_slice = output_from_past[:, 0, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

    def create_and_check_t5_decoder_model_attention_mask_past(
        self, config, input_ids, decoder_input_ids, attention_mask
    ):
        """Check cached decoding against a full pass when the second half of the input is masked.

        Only ``config`` and ``input_ids`` are used; the remaining parameters are ignored.
        """
        model = TFT5Model(config=config).get_decoder()

        # create attention mask: ones for the first half of the sequence, zeros for the rest
        half_seq_length = self.seq_length // 2
        attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
        attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
        attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)

        # first forward pass
        outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
        vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
        condition = tf.transpose(
            tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size))
        )
        input_ids = tf.where(condition, random_other_next_tokens, input_ids)

        # append to next input_ids and attn_mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        attn_mask = tf.concat(
            [attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)],
            axis=1,
        )

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
        output_from_past = model(next_tokens, past_key_values=outputs.past_key_values, attention_mask=attn_mask)[0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
        output_from_past_slice = output_from_past[:, 0, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

    def create_and_check_t5_decoder_model_past_large_inputs(
        self, config, input_ids, decoder_input_ids, attention_mask
    ):
        """Check cached decoding of three extra tokens at once against a full forward pass.

        ``decoder_input_ids`` is ignored; ``attention_mask`` must be sliceable (not ``None``).
        """
        model = TFT5Model(config=config).get_decoder()

        input_ids = input_ids[:1, :]
        attention_mask = attention_mask[:1, :]
        # NOTE: this permanently overwrites the tester's batch size with 1.
        self.batch_size = 1

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        # create hypothetical next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append to next input_ids and attention mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
        output_from_past = model(
            next_tokens, attention_mask=next_attention_mask, past_key_values=outputs.past_key_values
        )[0]

        self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])

        # select random slice; `.numpy().item()` extracts the scalar explicitly from the size-1 tensor
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)``; the input ids double as decoder input ids."""
        config_and_inputs = self.prepare_config_and_inputs()
        (config, input_ids, input_mask, token_labels) = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "decoder_input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return config, inputs_dict
242

243

244
@require_tf
class TFT5ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    # Test suite for the TF encoder-decoder T5 models: shared mixin tests plus T5-specific checks.
    is_encoder_decoder = True
    all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
    all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
    # The TF model classes only exist when TensorFlow is importable, hence the guard.
    pipeline_model_mapping = (
        {}
        if not is_tf_available()
        else {
            "conversational": TFT5ForConditionalGeneration,
            "feature-extraction": TFT5Model,
            "summarization": TFT5ForConditionalGeneration,
            "text2text-generation": TFT5ForConditionalGeneration,
            "translation": TFT5ForConditionalGeneration,
        }
    )
    test_onnx = False

    def setUp(self):
        # A fresh tester pair is created for every test.
        self.model_tester = TFT5ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_t5_model(self):
        tester = self.model_tester
        tester.create_and_check_t5_model(*tester.prepare_config_and_inputs())

    def test_t5_model_v1_1(self):
        # Same check as `test_t5_model`, with untied embeddings and a gated-GELU feed-forward.
        config, *model_inputs = self.model_tester.prepare_config_and_inputs()
        config.tie_word_embeddings = False
        config.feed_forward_proj = "gated-gelu"
        self.model_tester.create_and_check_t5_model(config, *model_inputs)

    def test_with_lm_head(self):
        tester = self.model_tester
        tester.create_and_check_t5_with_lm_head(*tester.prepare_config_and_inputs())

    def test_t5_decoder_model_past(self):
        tester = self.model_tester
        tester.create_and_check_t5_decoder_model_past(*tester.prepare_config_and_inputs())

    def test_t5_decoder_model_past_with_attn_mask(self):
        tester = self.model_tester
        tester.create_and_check_t5_decoder_model_attention_mask_past(*tester.prepare_config_and_inputs())

    def test_t5_decoder_model_past_large_inputs(self):
        tester = self.model_tester
        config, input_ids, input_mask, _ = tester.prepare_config_and_inputs()

        # The checker expects (config, input_ids, decoder_input_ids, attention_mask), so the
        # prepared tuple is rearranged: no decoder input ids, and the input mask as attention mask.
        tester.create_and_check_t5_decoder_model_past_large_inputs(config, input_ids, None, input_mask)

    @slow
    def test_model_from_pretrained(self):
        self.assertIsNotNone(TFT5Model.from_pretrained("google-t5/t5-small"))

    def test_generate_with_headmasking(self):
        # TODO: Fix head-masking according to PyTorch T5 model
        return None

    # This test is run in `TFT5EncoderOnlyModelTest`, where the main layer has the same inputs as the model
    @unittest.skip(reason="The inputs of the Main Layer are different.")
    def test_keras_save_load(self):
        return None

    @unittest.skip("Does not support conversations.")
    def test_pipeline_conversational(self):
        return None
320

321

322
class TFT5EncoderOnlyModelTester:
    """Builds a small encoder-only ``T5Config`` with random inputs and checks ``TFT5EncoderModel``.

    Every assertion is delegated to ``parent``, the test case that owns this tester.
    """

    def __init__(
        self,
        parent,
        vocab_size=99,
        batch_size=13,
        encoder_seq_length=7,
        # For common tests
        use_attention_mask=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        d_ff=37,
        relative_attention_num_buckets=8,
        is_training=False,
        dropout_rate=0.1,
        initializer_factor=0.002,
        is_encoder_decoder=False,
        eos_token_id=1,
        pad_token_id=0,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.encoder_seq_length = encoder_seq_length
        # For common tests
        self.seq_length = self.encoder_seq_length
        self.use_attention_mask = use_attention_mask
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.d_ff = d_ff
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.initializer_factor = initializer_factor
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.is_encoder_decoder = is_encoder_decoder
        # Store the caller-supplied scope (the default is still None).
        self.scope = scope
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, attention_mask)``.

        ``attention_mask`` is a random 0/1 tensor, or ``None`` when ``use_attention_mask`` is false.
        """
        input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)

        attention_mask = None
        if self.use_attention_mask:
            attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)

        config = T5Config(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            d_ff=self.d_ff,
            d_kv=self.hidden_size // self.num_attention_heads,
            num_layers=self.num_hidden_layers,
            num_heads=self.num_attention_heads,
            relative_attention_num_buckets=self.relative_attention_num_buckets,
            dropout_rate=self.dropout_rate,
            initializer_factor=self.initializer_factor,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.pad_token_id,
            pad_token_id=self.pad_token_id,
            is_encoder_decoder=self.is_encoder_decoder,
        )

        return (
            config,
            input_ids,
            attention_mask,
        )

    def create_and_check_model(
        self,
        config,
        input_ids,
        attention_mask,
    ):
        """Run ``TFT5EncoderModel`` with and without a mask and check the hidden-state shape."""
        model = TFT5EncoderModel(config=config)
        # Exercise the masked call; only the unmasked result is inspected below.
        model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        result = model(input_ids=input_ids)
        encoder_output = result.last_hidden_state

        self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.encoder_seq_length, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` with ``input_ids`` and ``attention_mask`` entries."""
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            attention_mask,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict
422

423

424
# `require_tf` matches the sibling `TFT5ModelTest`: `test_model` needs TensorFlow objects.
@require_tf
class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
    # Test suite for the encoder-only TF T5 model.
    is_encoder_decoder = False
    all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
    test_onnx = False

    def setUp(self):
        # A fresh tester pair is created for every test.
        self.model_tester = TFT5EncoderOnlyModelTester(self)
        self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    # is not able to be part of a pipeline
    def test_train_pipeline_custom_model(self):
        pass
443

444

445
@require_tf
@require_sentencepiece
@require_tokenizers
class TFT5GenerationIntegrationTests(unittest.TestCase):
    # Generation tests against the pretrained "google-t5/t5-small" checkpoint.

    @slow
    def test_greedy_xla_generate_simple(self):
        checkpoint = "google-t5/t5-small"
        model = TFT5ForConditionalGeneration.from_pretrained(checkpoint)
        tokenizer = T5Tokenizer.from_pretrained(checkpoint)

        # prompts of unequal length require padding, which confirms attention masks work in XLA
        prompts = [
            "Translate English to German: Today is a beautiful day.",
            "Translate English to German: I have four cats, three dogs, two birds, and a horse.",
        ]
        input_ids = tokenizer(prompts, return_tensors="tf", padding=True).input_ids

        compiled_generate = tf.function(model.generate, jit_compile=True)

        eager_ids = model.generate(input_ids)
        xla_ids = compiled_generate(input_ids)

        eager_text = tokenizer.batch_decode(eager_ids, skip_special_tokens=True)
        xla_text = tokenizer.batch_decode(xla_ids, skip_special_tokens=True)

        expected = [
            "Heute ist ein schöner Tag.",
            "Ich habe vier Katzen, drei Hunde, zwei Vögel und ein Pferd.",
        ]

        self.assertListEqual(expected, eager_text)
        self.assertListEqual(expected, xla_text)

    @slow
    def test_greedy_generate(self):
        checkpoint = "google-t5/t5-small"
        model = TFT5ForConditionalGeneration.from_pretrained(checkpoint)
        tokenizer = T5Tokenizer.from_pretrained(checkpoint)

        prompts = ["Yesterday, my name was", "Today is a beautiful day and"]
        input_ids = tokenizer(prompts, return_tensors="tf", padding=True).input_ids

        banned_ids = [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids]
        output_ids = model.generate(
            input_ids,
            bad_words_ids=banned_ids,
            no_repeat_ngram_size=3,
            do_sample=False,
            repetition_penalty=2.2,
        )

        decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        expected = ["Yesterday, my name was", "Heute ist ein schöne Tag und"]

        self.assertListEqual(expected, decoded)

    @slow
    def test_sample_xla_generate_simple(self):
        # NOTE: XLA compilation introduces small numerical differences, so sampling with the same seed is far
        # from guaranteed to give the same output as eager mode. What can be confirmed is that both results
        # are sensible and that both versions can be seeded.

        # pin generation to the CPU to avoid GPU-related quirks
        with tf.device(":/CPU:0"):
            checkpoint = "google-t5/t5-small"
            model = TFT5ForConditionalGeneration.from_pretrained(checkpoint)
            tokenizer = T5Tokenizer.from_pretrained(checkpoint)

            prompt = "Translate English to German: I have two bananas"
            input_ids = tokenizer(prompt, return_tensors="tf", padding=True).input_ids
            expected_eager = ["Ich habe zwei Bananen"]
            expected_xla = ["Ich habe 2 Bananen"]

            # fixed seed -> deterministic sampling sequence -> deterministic generation
            eager_ids = model.generate(input_ids, do_sample=True, seed=[42, 0])
            eager_text = tokenizer.batch_decode(eager_ids, skip_special_tokens=True)
            self.assertListEqual(expected_eager, eager_text)

            compiled_generate = tf.function(model.generate, jit_compile=True)
            # fixed seed -> deterministic sampling sequence -> deterministic generation
            xla_ids = compiled_generate(input_ids, do_sample=True, seed=[42, 0])
            xla_text = tokenizer.batch_decode(xla_ids, skip_special_tokens=True)
            self.assertListEqual(expected_xla, xla_text)

    @slow
    def test_sample_generate(self):
        checkpoint = "google-t5/t5-small"
        model = TFT5ForConditionalGeneration.from_pretrained(checkpoint)
        tokenizer = T5Tokenizer.from_pretrained(checkpoint)

        prompts = ["I really love my", "Translate English to German: the transformers are truly amazing"]
        input_ids = tokenizer(prompts, return_tensors="tf", padding=True).input_ids

        banned_ids = [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids]

        # pin generation to the CPU to avoid GPU-related quirks
        with tf.device(":/CPU:0"):
            output_ids = model.generate(
                input_ids,
                do_sample=True,
                bad_words_ids=banned_ids,
                no_repeat_ngram_size=3,
                repetition_penalty=2.2,
                temperature=0.8,
                top_k=500,
                top_p=0.9,
                seed=[20, 0],  # fixed seed -> deterministic sampling sequence -> deterministic generation
            )

        decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        expected = ["- I really love my way of this.", "die Transformatoren sind wirklich erstaunlich"]

        self.assertListEqual(expected, decoded)

    # TODO (ydshieh): undo skip once a fix is done on TF side.
    @unittest.skip("Skip for now as TF 2.13 breaks it on GPU")
    @slow
    def test_beam_search_xla_generate_simple(self):
        checkpoint = "google-t5/t5-small"
        model = TFT5ForConditionalGeneration.from_pretrained(checkpoint)
        tokenizer = T5Tokenizer.from_pretrained(checkpoint)

        # tests XLA with task specific arguments
        task_params = getattr(model.config, "task_specific_params", {})
        model.config.update(task_params.get("translation_en_to_fr", {}))

        # prompts of unequal length require padding, which confirms attention masks work in XLA
        prefix = model.config.prefix
        prompts = [
            prefix + "Today is a beautiful day.",
            prefix + "I have four cats, three dogs, two birds, and a horse.",
        ]
        input_ids = tokenizer(prompts, return_tensors="tf", padding=True).input_ids

        compiled_generate = tf.function(model.generate, jit_compile=True)

        eager_ids = model.generate(input_ids, num_beams=2)
        xla_ids = compiled_generate(input_ids, num_beams=2)

        eager_text = tokenizer.batch_decode(eager_ids, skip_special_tokens=True)
        xla_text = tokenizer.batch_decode(xla_ids, skip_special_tokens=True)

        expected = [
            "Aujourd'hui est une belle journée.",
            "J'ai quatre chats, trois chiens, deux oiseaux et un cheval.",
        ]

        self.assertListEqual(expected, eager_text)
        self.assertListEqual(expected, xla_text)

    @slow
    def test_beam_search_generate(self):
        checkpoint = "google-t5/t5-small"
        model = TFT5ForConditionalGeneration.from_pretrained(checkpoint)
        tokenizer = T5Tokenizer.from_pretrained(checkpoint)

        prompts = ["I really love my", "Translate English to German: the transformers are truly amazing"]
        input_ids = tokenizer(prompts, return_tensors="tf", padding=True).input_ids

        banned_ids = [tokenizer("my").input_ids, tokenizer("ein schöner").input_ids]
        output_ids = model.generate(
            input_ids,
            bad_words_ids=banned_ids,
            no_repeat_ngram_size=3,
            do_sample=False,
            repetition_penalty=2.2,
            num_beams=4,
        )

        decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        expected = ["Ich liebe es so sehr!", "die Transformatoren sind wirklich erstaunlich"]
        self.assertListEqual(expected, decoded)

    @unittest.skip("Does not support conversations.")
    def test_pipeline_conversational(self):
        return None
617

618

619
@require_tf
620
@require_sentencepiece
621
@require_tokenizers
622
class TFT5ModelIntegrationTests(unittest.TestCase):
623
    @cached_property
    def model(self):
        """Return the ``google-t5/t5-base`` conditional-generation model.

        Exposed through ``cached_property`` (presumably so the checkpoint is
        loaded only once per test instance -- confirm against
        ``transformers.utils.cached_property``).
        """
        return TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-base")
626

627
    @slow
    def test_small_integration_test(self):
        """
        Compare the negated mean loss of ``google-t5/t5-small`` with a stored reference score.

        For comparison run:
        >>> import t5  # pip install t5==0.7.1
        >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

        >>> path_to_mtf_small_t5_checkpoint = '<fill_in>'
        >>> path_to_mtf_small_spm_model_path = '<fill_in>'
        >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None)
        >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
        >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
        """

        model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
        tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")

        input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
        labels = tokenizer("Hi I am", return_tensors="tf").input_ids

        loss = model(input_ids, labels=labels).loss
        # The reference value is a score, so the mean loss is negated before comparing.
        mtf_score = -tf.math.reduce_mean(loss).numpy()

        EXPECTED_SCORE = -4.771147
        # assertAlmostEqual reports both values on failure, unlike a bare assertTrue.
        self.assertAlmostEqual(mtf_score, EXPECTED_SCORE, delta=1e-4)
652

653
    @slow
    def test_small_v1_1_integration_test(self):
        """
        Compare the negated mean loss of ``google/t5-v1_1-small`` with a stored reference score.

        For comparison run:
        >>> import t5  # pip install t5==0.7.1
        >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

        >>> path_to_mtf_small_t5_v1_1_checkpoint = '<fill_in>'
        >>> path_to_mtf_small_spm_model_path = '<fill_in>'
        >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_v1_1_checkpoint, batch_size=1, tpu=None)
        >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
        >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
        """

        model = TFT5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small")
        tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")

        input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
        labels = tokenizer("Hi I am", return_tensors="tf").input_ids

        loss = model(input_ids, labels=labels).loss
        # The reference value is a score, so the mean loss is negated before comparing.
        mtf_score = -tf.math.reduce_mean(loss).numpy()

        EXPECTED_SCORE = -14.757326
        # assertAlmostEqual reports both values on failure, unlike a bare assertTrue.
        self.assertAlmostEqual(mtf_score, EXPECTED_SCORE, delta=1e-4)
678

679
    @slow
    def test_small_byt5_integration_test(self):
        """
        Compare the negated mean loss of ``google/byt5-small`` with a stored reference score.

        For comparison run:
        >>> import t5  # pip install t5==0.9.1

        >>> path_to_byt5_small_checkpoint = '<fill_in>'
        >>> t5_model = t5.models.MtfModel(model_dir=path_to_byt5_small_checkpoint, batch_size=1, tpu=None)
        >>> vocab = t5.data.ByteVocabulary()
        >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
        """

        model = TFT5ForConditionalGeneration.from_pretrained("google/byt5-small")
        tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")

        input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
        labels = tokenizer("Hi I am", return_tensors="tf").input_ids

        loss = model(input_ids, labels=labels).loss
        # The reference value is a score, so the mean loss is negated before comparing.
        mtf_score = -tf.math.reduce_mean(loss).numpy()

        EXPECTED_SCORE = -7.592465
        # assertAlmostEqual reports both values on failure, unlike a bare assertTrue.
        self.assertAlmostEqual(mtf_score, EXPECTED_SCORE, delta=1e-4)
702

703
    @slow
    def test_summarization(self):
        """Summarize four news articles with beam search and compare to reference summaries.

        The model config is first updated with the ``"summarization"`` entry of
        its ``task_specific_params``; the resulting ``config.prefix`` is
        prepended to each article. Inputs are padded/truncated to 512 tokens,
        generated with four beams, and the decoded summaries must equal
        ``expected_summaries`` exactly.
        """
        model = self.model
        tok = T5Tokenizer.from_pretrained("google-t5/t5-base")

        FRANCE_ARTICLE = (  # @noqa
            "Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings"
            " Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane."
            ' Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation."'
            ' He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s'
            " comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video"
            " showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French"
            " Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a"
            " phone at the wreckage site. The two publications described the supposed video, but did not post it on"
            " their websites. The publications said that they watched the video, which was found by a source close to"
            " the investigation. \"One can hear cries of 'My God' in several languages,\" Paris Match reported."
            ' "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the'
            " cockpit door with a heavy object.  Towards the end, after a heavy shake, stronger than the others, the"
            ' screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt,'
            " editor-in-chief of Bild online. An official with France's accident investigation agency, the BEA, said"
            " the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman"
            " in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the"
            ' reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said,'
            ' but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be'
            " sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by"
            " specialized technicians working hand-in-hand with investigators. But none of the cell phones found so"
            " far have been sent to the institute, Menichini said. Asked whether staff involved in the search could"
            ' have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin'
            ' Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match'
            ' are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered'
            ' cell phones from the crash site after Bild and Paris Match published their reports. "That is something'
            " we did not know before. ... Overall we can say many things of the investigation weren't revealed by the"
            ' investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline'
            " Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the"
            " controls of Germanwings Flight 9525, which he's accused of deliberately crashing last week in the"
            ' French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of'
            ' severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school'
            " discovered in an internal investigation, Lufthansa said, included medical documents he submitted in"
            " connection with resuming his flight training. The announcement indicates that Lufthansa, the parent"
            " company of Germanwings, knew of Lubitz's battle with depression, allowed him to continue training and"
            " ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100%"
            ' fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was'
            " sharing the information and documents -- including training and medical records -- with public"
            " prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the"
            " past week to recover human remains and plane debris scattered across a steep mountainside. He saw the"
            " crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash"
            " site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late"
            " Tuesday that no visible human remains were left at the site but recovery teams would keep searching."
            " French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all"
            " the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested."
            " In the meantime, the recovery of the victims' personal belongings will start Wednesday, Menichini said."
            " Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew"
            " on board. Check out the latest from our correspondents . The details about Lubitz's correspondence with"
            " the flight school during his training were among several developments as investigators continued to"
            " delve into what caused the crash and Lubitz's possible motive for downing the jet. A Lufthansa"
            " spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his"
            ' examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in'
            " Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at"
            " some point before his aviation career and underwent psychotherapy before he got his pilot's license."
            " Kumpa emphasized there's no evidence suggesting Lubitz was suicidal or acting aggressively before the"
            " crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to"
            " lose his pilot's license, a European government official briefed on the investigation told CNN on"
            ' Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being'
            " considered. Another source, a law enforcement official briefed on the investigation, also told CNN that"
            " authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would"
            " not be allowed to fly because of his medical problems. Lubitz's girlfriend told investigators he had"
            " seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded"
            " he had psychological issues, the European government official said. But no matter what details emerge"
            " about his previous mental health struggles, there's more to the story, said Brian Russell, a forensic"
            ' psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact'
            " that maybe they weren't going to keep doing their job and they're upset about that and so they're"
            ' suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to'
            " also take that rage and turn it outward on 149 other people who had nothing to do with the person's"
            ' problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight'
            " 9525? CNN's Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura"
            " Smith-Spark wrote from London. CNN's Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine"
            " Amiel and Anna-Maja Rappard contributed to this report."
        )

        SHORTER_ARTICLE = (
            "(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
            " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
            " formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based."
            " The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its"
            ' jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East'
            ' Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the'
            " situation in Palestinian territories, paving the way for possible war crimes investigations against"
            " Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and"
            " the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the"
            " body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a"
            ' move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the'
            ' world is also a step closer to ending a long era of impunity and injustice," he said, according to an'
            ' ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge'
            " Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the"
            ' Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine'
            " acquires all the rights as well as responsibilities that come with being a State Party to the Statute."
            ' These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights'
            ' Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should'
            " immediately end their pressure, and countries that support universal acceptance of the court's treaty"
            ' should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the'
            " group. \"What's objectionable is the attempts to undermine international justice, not Palestine's"
            ' decision to join a treaty to which over 100 countries around the world are members." In January, when'
            " the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an"
            ' outrage, saying the court was overstepping its boundaries. The United States also said it "strongly"'
            " disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a"
            ' state and therefore we do not believe that it is eligible to join the ICC," the State Department said in'
            ' a statement. It urged the warring sides to resolve their differences through direct negotiations. "We'
            ' will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,"'
            " it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the"
            ' territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the'
            " court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou"
            ' Bensouda said her office would "conduct its analysis in full independence and impartiality." The war'
            " between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry"
            " will include alleged war crimes committed since June. The International Criminal Court was set up in"
            " 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder"
            " and Faith Karimi contributed to this report."
        )

        IRAN_ARTICLE = (
            "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran"
            " in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively"
            " block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger."
            " Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli"
            " Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a"
            " letter to the Iranian leadership warning them away from a deal. The debate that has already begun since"
            " the announcement of the new framework will likely result in more heat than light. It will not be helped"
            " by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: ."
            " The most misleading assertion, despite universal rejection by experts, is that the negotiations'"
            " objective at the outset was the total elimination of any nuclear program in Iran. That is the position"
            " of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it"
            " had been, there would have been no Iranian team at the negotiating table. Rather, the objective has"
            " always been to structure an agreement or series of agreements so that Iran could not covertly develop a"
            " nuclear arsenal before the United States and its allies could respond. The new framework has exceeded"
            " expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by"
            " two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another"
            " dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite"
            " sharp accusations by some in the United States and its allies, Iran denies having such a program, and"
            " U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's"
            " continued cooperation with International Atomic Energy Agency inspections is further evidence on this"
            " point, and we'll know even more about Iran's program in the coming months and years because of the deal."
            " In fact, the inspections provisions that are part of this agreement are designed to protect against any"
            " covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that"
            " the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter"
            " warning that a deal might be killed by Congress or a future president). This of course is not the case."
            " The talks were between Iran and the five permanent members of the U.N. Security Council (United States,"
            " United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has"
            " played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement"
            " reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran"
            " and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement"
            " contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the"
            " case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased"
            " or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes"
            " Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear"
            " sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going"
            " forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such"
            " a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the"
            ' agreement should be a formal treaty requiring the Senate to "advise and consent." But the issue is not'
            " suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New"
            " START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement"
            " with Iran will not be so balanced.  The restrictions and obligations in the final framework agreement"
            " will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove"
            " most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally"
            " some insist that any agreement must address Iranian missile programs, human rights violations or support"
            " for Hamas or Hezbollah.  As important as these issues are, and they must indeed be addressed, they are"
            " unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran.  To include them in"
            " the negotiations would be a poison pill. This agreement should be judged on its merits and on how it"
            " affects the security of our negotiating partners and allies, including Israel. Those judgments should be"
            " fact-based, not based on questionable assertions or dubious assumptions."
        )

        ARTICLE_SUBWAY = (
            "New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A"
            " year later, she got married again in Westchester County, but to a different man and without divorcing"
            " her first husband.  Only 18 days after that marriage, she got hitched yet again. Then, Barrientos"
            ' declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married'
            " once more, this time in the Bronx. In an application for a marriage license, she stated it was her"
            ' "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false'
            ' instrument for filing in the first degree," referring to her false statements on the 2010 marriage'
            " license application, according to court documents. Prosecutors said the marriages were part of an"
            " immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to"
            " her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was"
            " arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New"
            " York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total,"
            " Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002.  All"
            " occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be"
            " married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors"
            " said the immigration scam involved some of her husbands, who filed for permanent residence status"
            " shortly after the marriages.  Any divorces happened only after such filings were approved. It was"
            " unclear whether any of the men will be prosecuted. The case was referred to the Bronx District"
            " Attorney's Office by Immigration and Customs Enforcement and the Department of Homeland Security's"
            ' Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt,'
            " Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his"
            " native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces"
            " up to four years in prison.  Her next court appearance is scheduled for May 18."
        )

        # Reference summaries, in the same order as the articles passed to the tokenizer below.
        expected_summaries = [
            'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
            " cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
            " magazine says .",
            "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
            " preliminary examination into the situation in the occupied Palestinian territory . as members of the"
            " court, Palestinians may be subject to counter-charges as well .",
            "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
            " the debate that has already begun since the announcement of the new framework will likely result in more"
            " heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
            " implement a rigorous inspection regime .",
            "prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
            ' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
            " times, with nine of her marriages occurring between 1999 and 2002 .",
        ]

        # Apply the checkpoint's summarization settings; `model.config.prefix` is read right after.
        task_specific_config = getattr(model.config, "task_specific_params", {})
        summarization_config = task_specific_config.get("summarization", {})
        model.config.update(summarization_config)

        dct = tok(
            [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="tf",
        )
        # Sanity check: every input was padded/truncated to exactly 512 tokens.
        self.assertEqual(512, dct["input_ids"].shape[1])

        hypotheses_batch = model.generate(
            input_ids=dct["input_ids"],
            attention_mask=dct["attention_mask"],
            num_beams=4,
            length_penalty=2.0,
            max_length=142,
            min_length=56,
            no_repeat_ngram_size=3,
            do_sample=False,
            early_stopping=True,
        )

        decoded = [
            tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
        ]

        self.assertListEqual(
            expected_summaries,
            decoded,
        )
948

949
    @slow
    def test_translation_en_to_de(self):
        """Translate one English sentence to German and compare with the reference.

        The model config is updated with the ``"translation_en_to_de"`` entry of
        its ``task_specific_params``; the resulting ``config.prefix`` is
        prepended to the input before beam-search generation.
        """
        tok = T5Tokenizer.from_pretrained("google-t5/t5-base")
        model = self.model

        task_specific_config = getattr(model.config, "task_specific_params", {})
        translation_config = task_specific_config.get("translation_en_to_de", {})
        # Use the local handle, as the sibling translation tests do.
        model.config.update(translation_config)

        original_input = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.'
        expected_translation = (
            '"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.'
        )

        input_ids = tok.encode(model.config.prefix + original_input, return_tensors="tf")

        output = model.generate(
            input_ids=input_ids,
            num_beams=4,
            length_penalty=2.0,
            max_length=50,
            no_repeat_ngram_size=3,
            do_sample=False,
            early_stopping=True,
        )
        translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

        self.assertEqual(translation, expected_translation)
977

978
    @slow
    def test_translation_en_to_fr(self):
        """Translate one English sentence to French and compare with the reference.

        The model config is updated with the ``"translation_en_to_fr"`` entry of
        its ``task_specific_params``; the resulting ``config.prefix`` is
        prepended to the input before beam-search generation (``max_length=100``).
        """
        model = self.model
        tok = T5Tokenizer.from_pretrained("google-t5/t5-base")

        task_specific_config = getattr(model.config, "task_specific_params", {})
        translation_config = task_specific_config.get("translation_en_to_fr", {})
        model.config.update(translation_config)

        en_text = (
            ' This image section from an infrared recording by the Spitzer telescope shows a "family portrait" of'
            " countless generations of stars: the oldest stars are seen as blue dots. "
        )

        new_truncated_translation = (
            "Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre "
            "un "
            "« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées "
            "sous forme "
            "de points bleus."
        )

        input_ids = tok(model.config.prefix + en_text, return_tensors="tf").input_ids

        output = model.generate(
            input_ids=input_ids,
            num_beams=4,
            length_penalty=2.0,
            max_length=100,
            no_repeat_ngram_size=3,
            do_sample=False,
            early_stopping=True,
        )
        translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

        self.assertEqual(translation, new_truncated_translation)
1014

1015
    @slow
    def test_translation_en_to_ro(self):
        """Translate one English sentence to Romanian and compare with the reference.

        The model config is updated with the ``"translation_en_to_ro"`` entry of
        its ``task_specific_params``; the resulting ``config.prefix`` is
        prepended to the input before beam-search generation.
        """
        model = self.model
        tok = T5Tokenizer.from_pretrained("google-t5/t5-base")

        task_specific_config = getattr(model.config, "task_specific_params", {})
        translation_config = task_specific_config.get("translation_en_to_ro", {})
        model.config.update(translation_config)

        original_input = "Taco Bell said it plans to add 2,000 locations in the US by 2022."
        expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022."

        input_ids = tok.encode(model.config.prefix + original_input, return_tensors="tf")

        output = model.generate(
            input_ids=input_ids,
            num_beams=4,
            length_penalty=2.0,
            max_length=50,
            no_repeat_ngram_size=3,
            do_sample=False,
            early_stopping=True,
        )
        translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

        self.assertEqual(translation, expected_translation)
1041

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.