# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow Whisper model. """

from __future__ import annotations

import inspect
import tempfile
import traceback
import unittest

import numpy as np

from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, run_test_in_subprocess, slow
from transformers.utils import cached_property
from transformers.utils.import_utils import is_datasets_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_datasets_available():
    import datasets
    from datasets import load_dataset


if is_tf_available():
    import tensorflow as tf

    from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed
    from transformers.models.whisper.modeling_tf_whisper import (
        TFWhisperDecoder,
        TFWhisperEncoder,
        sinusoidal_embedding_init,
    )


def prepare_whisper_inputs_dict(
    config,
    input_features,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    if decoder_attention_mask is None:
        decoder_attention_mask = tf.where(decoder_input_ids != config.pad_token_id, 1, 0)
    if head_mask is None:
        head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
    if decoder_head_mask is None:
        decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
    if cross_attn_head_mask is None:
        cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
    return {
        "input_features": input_features,
        "decoder_input_ids": decoder_input_ids,
        "decoder_attention_mask": decoder_attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
    }
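
# Note (illustrative, not part of the upstream file): with the tester defaults below,
# `input_features` is a float tensor of shape (batch_size, num_mel_bins, seq_length)
# = (13, 80, 60) standing in for a log-mel spectrogram, and `decoder_input_ids` has
# shape (13, 60); the helper above fills in the attention and head masks that the
# model's call signature expects.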


@require_tf
class TFWhisperModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=60,
        is_training=True,
        use_labels=False,
        vocab_size=200,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        input_channels=1,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        max_source_positions=30,
        max_target_positions=60,
        bos_token_id=98,
        eos_token_id=98,
        pad_token_id=0,
        num_mel_bins=80,
        decoder_start_token_id=85,
        num_conv_layers=1,
        suppress_tokens=None,
        begin_suppress_tokens=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_mel_bins = num_mel_bins
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.num_conv_layers = num_conv_layers
        self.suppress_tokens = suppress_tokens
        self.begin_suppress_tokens = begin_suppress_tokens

    def prepare_config_and_inputs(self):
        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)

        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = self.get_config()
        inputs_dict = prepare_whisper_inputs_dict(
            config,
            attention_mask=None,
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
        )
        return config, inputs_dict

    def get_config(self):
        return WhisperConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            input_channels=self.input_channels,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            max_source_positions=self.max_source_positions,
            max_target_positions=self.max_target_positions,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_ffn_dim=self.hidden_size,
            encoder_ffn_dim=self.hidden_size,
            decoder_start_token_id=self.decoder_start_token_id,
            suppress_tokens=self.suppress_tokens,
            begin_suppress_tokens=self.begin_suppress_tokens,
        )

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_subsampled_output_lengths(self, input_lengths):
        """
        Computes the output length of the convolutional layers
        """

        for i in range(self.num_conv_layers):
            input_lengths = (input_lengths - 1) // 2 + 1

        return input_lengths
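
    # Worked example (illustrative): each counted conv layer is assumed to halve the
    # time axis with a stride-2 convolution, so with num_conv_layers=1 and
    # seq_length=60 the encoder sees (60 - 1) // 2 + 1 = 30 frames, which matches
    # max_source_positions=30 above.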

    def create_and_check_model_forward(self, config, inputs_dict):
        model = TFWhisperModel(config=config)

        input_features = inputs_dict["input_features"]
        decoder_input_ids = inputs_dict["decoder_input_ids"]

        # first forward pass
        last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state

        # (batch_size, target_seq_length, d_model)
        self.parent.assertEqual(tuple(last_hidden_state.shape), (13, 60, 16))

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        model = TFWhisperModel(config=config).get_decoder()
        # take a slice so we're shorter than the sequence length and can append later
        input_ids = inputs_dict["decoder_input_ids"][:, :-10]
        attention_mask = inputs_dict["decoder_attention_mask"][:, :-10]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical next tokens to extend next_input_ids with
        next_token = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_tokens = tf.where(next_token <= 2, 2, next_token)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append to next input_ids and attention mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = np.random.randint(0, output_from_past.shape[-1])
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(np.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
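
        # Note (illustrative): this is a KV-cache consistency check: feeding only the
        # 3 new tokens together with `past_key_values` must reproduce the last 3
        # positions of a full forward pass over the concatenated sequence.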

    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
        model = TFWhisperModel(config=config)
        outputs = model(**inputs_dict)

        encoder_last_hidden_state = outputs.encoder_last_hidden_state
        last_hidden_state = outputs.last_hidden_state

        with tempfile.TemporaryDirectory() as tmpdirname:
            encoder = model.get_encoder()
            encoder.save_pretrained(tmpdirname)
            encoder = TFWhisperEncoder.from_pretrained(tmpdirname)

        encoder_last_hidden_state_2 = encoder(inputs_dict["input_features"])[0]

        # maximum absolute difference between the standalone and the full-model activations
        self.parent.assertTrue(
            float(tf.reduce_max(tf.abs(encoder_last_hidden_state_2 - encoder_last_hidden_state))) < 1e-3
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            decoder = model.get_decoder()
            decoder.save_pretrained(tmpdirname)
            decoder = TFWhisperDecoder.from_pretrained(tmpdirname)

        last_hidden_state_2 = decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=encoder_last_hidden_state,
        )[0]

        self.parent.assertTrue(float(tf.reduce_max(tf.abs(last_hidden_state_2 - last_hidden_state))) < 1e-3)


@require_tf
class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (TFWhisperModel, TFWhisperForConditionalGeneration) if is_tf_available() else ()
    all_generative_model_classes = (TFWhisperForConditionalGeneration,) if is_tf_available() else ()
    pipeline_model_mapping = {"feature-extraction": TFWhisperModel} if is_tf_available() else {}
    is_encoder_decoder = True
    fx_compatible = False
    test_pruning = False
    test_missing_keys = False
    test_onnx = False

    input_name = "input_features"

    # TODO (ydshieh): undo skip once a fix is done on TF side.
    @unittest.skip("Skip for now as TF 2.13 breaks it on GPU")
    def test_xla_generate_slow(self):
        super().test_xla_generate_slow()

    def setUp(self):
        self.model_tester = TFWhisperModelTester(self)
        self.config_tester = ConfigTester(self, config_class=WhisperConfig)
        self.maxDiff = 3000

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            model.build_in_name_scope()

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname, saved_model=False)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    def test_requires_grad_encoder_embed_positions(self):
        config = self.model_tester.get_config()
        for model_class in self.all_model_classes:
            model = model_class(config)
            encoder = model.get_encoder()
            self.assertFalse(encoder.embed_positions.trainable)

    def test_encoder_sinusoidal_embed_positions(self):
        config = self.model_tester.get_config()
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.build_in_name_scope()

            embeds = model.get_encoder().embed_positions.get_weights()[0]
            sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
            self.assertTrue(np.allclose(embeds, sinusoids))
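
    # Note (illustrative): Whisper's encoder positional embeddings are fixed sinusoids
    # rather than learned weights; the two tests above assert that the table is frozen
    # (non-trainable) and that, once built, it matches `sinusoidal_embedding_init`.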

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def _get_input_ids_and_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict[self.input_name]

        # take a max batch_size of 3
        max_batch_size = 3
        input_ids = input_ids[:max_batch_size, :, :]

        # generate max 3 tokens
        max_length = 4
        if config.eos_token_id is not None and config.pad_token_id is None:
            # hack to allow generate for models such as GPT2 as is done in `generate()`
            config.pad_token_id = config.eos_token_id

        return config, input_ids, None, max_length
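
    # Note (illustrative): in `_get_input_ids_and_config` above, what the common
    # generation tests call "input_ids" is really Whisper's 3-D `input_features`
    # tensor of shape (batch_size, num_mel_bins, seq_length), hence the three-axis
    # slice when truncating the batch.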

    # not implemented currently
    def test_inputs_embeds(self):
        pass

    @unittest.skip("Training is not yet supported")
    def test_training(self):
        pass

    def test_generate_with_head_masking(self):
        pass

    @unittest.skip("fp16 is not yet supported for TF models")
    def test_generate_fp16(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        config.max_target_positions = 400
        input_features = input_dict["input_features"]
        model = TFWhisperForConditionalGeneration(config)
        model.generate(input_features)
        model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.call)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_features",
                "decoder_input_ids",
                "decoder_attention_mask",
            ]
            expected_arg_names.extend(
                ["decoder_position_ids", "head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
                if all(mask in arg_names for mask in ("head_mask", "decoder_head_mask", "cross_attn_head_mask"))
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
            else:
                seq_length = self.model_tester.seq_length

            subsampled_seq_length = model._get_feat_extract_output_lengths(seq_length)

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [subsampled_seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)

                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
        # We override with a slightly higher tol value, as this test recently became flaky
        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", encoder_key_length)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)

            subsampled_encoder_seq_length = model._get_feat_extract_output_lengths(encoder_seq_length)
            subsampled_encoder_key_length = model._get_feat_extract_output_lengths(encoder_key_length)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )
            out_len = len(outputs)

            correct_outlen = 5

            # loss is at first position
            if "labels" in inputs_dict:
                correct_outlen += 1  # loss is added to beginning
            if "past_key_values" in outputs:
                correct_outlen += 1  # past_key_values have been returned

            self.assertEqual(out_len, correct_outlen)

            # decoder attentions
            decoder_attentions = outputs.decoder_attentions
            self.assertIsInstance(decoder_attentions, (list, tuple))
            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(decoder_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
            )

            # cross attentions
            cross_attentions = outputs.cross_attentions
            self.assertIsInstance(cross_attentions, (list, tuple))
            self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(cross_attentions[0].shape[-3:]),
                [
                    self.model_tester.num_attention_heads,
                    decoder_seq_length,
                    subsampled_encoder_key_length,
                ],
            )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )

    def test_generate_without_input_ids(self):
        pass

    @staticmethod
    def _get_encoder_outputs(
        model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
    ):
        encoder = model.get_encoder()
        encoder_outputs = encoder(
            input_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # expand the encoder states along the batch axis when several sequences are generated per sample
        encoder_outputs["last_hidden_state"] = tf.repeat(encoder_outputs.last_hidden_state, num_interleave, axis=0)
        input_ids = input_ids[:, :, 0]
        input_ids = tf.zeros_like(input_ids[:, :1], dtype=tf.int64) + tf.convert_to_tensor(
            [model._get_decoder_start_token_id()]
        )
        attention_mask = None
        return encoder_outputs, input_ids, attention_mask

    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
        batch_size, mel, seq_length = input_ids.shape
        subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
        num_sequences_in_output = batch_size * num_return_sequences
        gen_len = (
            output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
        )

        # scores
        self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)

        # Attentions
        # encoder
        self._check_encoder_attention_for_generate(
            output.encoder_attentions, batch_size, config, subsampled_seq_length
        )
        # decoder
        self._check_attentions_for_generate(
            num_sequences_in_output,
            output.decoder_attentions,
            min_length=1,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )

        # Hidden States
        # encoder
        self._check_encoder_hidden_states_for_generate(
            output.encoder_hidden_states, batch_size, config, subsampled_seq_length
        )

        # decoder
        self._check_hidden_states_for_generate(
            num_sequences_in_output,
            output.decoder_hidden_states,
            min_length=1,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )

    # overwritten from parent: the generic test requires text `input_ids`, while this model's only
    # input is `input_features`
    def test_lm_head_model_random_no_beam_search_generate(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_features = inputs_dict.get("input_features", None)

        # iterate over all generative models
        for model_class in self.all_generative_model_classes:
            model = model_class(config)

            if config.bos_token_id is None:
                # if bos token id is not defined model needs input_features
                with self.assertRaises(AssertionError):
                    model.generate(do_sample=True, max_length=5)
                # num_return_sequences = 1
                self._check_generated_ids(model.generate(input_features, do_sample=True))

            with self.assertRaises(ValueError):
                # generating multiple sequences without beam search (greedy)
                # is not allowed, as it would always generate the same sequences
                model.generate(input_features, do_sample=False, num_return_sequences=2)

            # num_return_sequences > 1, sample
            self._check_generated_ids(model.generate(input_features, do_sample=True, num_return_sequences=2))

            # check bad words tokens language generation
            # create list of 1-seq bad token and list of 2-seq of bad tokens
            bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
            output_tokens = model.generate(
                input_features, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
            )
            # only count generated tokens
            generated_ids = output_tokens[:, input_features.shape[-1] :]
            self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))

    # overwritten from parent: the generic test requires text `input_ids`, while this model's only
    # input is `input_features`
    def test_lm_head_model_random_beam_search_generate(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_features = inputs_dict.get("input_features", None)

        for model_class in self.all_generative_model_classes:
            model = model_class(config)

            if config.bos_token_id is None:
                # if bos token id is not defined the model needs input_features, num_return_sequences = 1
                self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2))

            with self.assertRaises(ValueError):
                # generating more return sequences than there are beams is not possible
                model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2)

            # num_return_sequences > 1, sample
            self._check_generated_ids(
                model.generate(
                    input_features,
                    do_sample=True,
                    num_beams=2,
                    num_return_sequences=2,
                )
            )
            # num_return_sequences > 1, greedy
            self._check_generated_ids(
                model.generate(input_features, do_sample=False, num_beams=2, num_return_sequences=2)
            )

            # check bad words tokens language generation
            # create list of 1-seq bad token and list of 2-seq of bad tokens
            bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
            output_tokens = model.generate(
                input_features, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
            )
            # only count generated tokens
            generated_ids = output_tokens[:, input_features.shape[-1] :]
            self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))

    def test_generate_with_prompt_ids_and_task_and_language(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = TFWhisperForConditionalGeneration(config)
        input_features = input_dict["input_features"]
        prompt_ids = np.arange(5)
        language = "<|de|>"
        task = "translate"
        lang_id = 6
        task_id = 7
        model.generation_config.lang_to_id = {language: lang_id}
        model.generation_config.task_to_id = {task: task_id}

        output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)

        expected_output_start = [
            *prompt_ids.tolist(),
            model.generation_config.decoder_start_token_id,
            lang_id,
            task_id,
        ]
        for row in output.numpy().tolist():
            self.assertListEqual(row[: len(expected_output_start)], expected_output_start)

    def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = TFWhisperForConditionalGeneration(config)
        input_features = input_dict["input_features"]
        prompt_ids = np.asarray(range(5))
        forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]

        output = model.generate(
            input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
        )

        expected_output_start = [
            *prompt_ids.tolist(),
            model.generation_config.decoder_start_token_id,
            *[token for _rank, token in forced_decoder_ids],
        ]
        for row in output.numpy().tolist():
            self.assertListEqual(row[: len(expected_output_start)], expected_output_start)


def _load_datasamples(num_samples):
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    # automatic decoding with librispeech
    speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

    return [x["array"] for x in speech_samples]
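
# Note (illustrative): the dummy LibriSpeech split used above yields 16 kHz mono
# clips, matching the default `sampling_rate` of `WhisperFeatureExtractor`, so the
# integration tests below can feed the raw arrays straight to the processor.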


def _test_large_logits_librispeech(in_queue, out_queue, timeout):
    error = None
    try:
        _ = in_queue.get(timeout=timeout)

        set_seed(0)

        model = TFWhisperModel.from_pretrained("openai/whisper-large")

        input_speech = _load_datasamples(1)

        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        processed_inputs = processor(
            audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="tf"
        )
        input_features = processed_inputs.input_features
        decoder_input_ids = processed_inputs.labels

        logits = model(
            input_features,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=False,
            output_attentions=False,
            use_cache=False,
        )

        logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0])

        # fmt: off
        EXPECTED_LOGITS = tf.convert_to_tensor(
            [
                2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
                1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
                1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
                1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
            ]
        )
        # fmt: on

        unittest.TestCase().assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
    except Exception:
        error = f"{traceback.format_exc()}"

    results = {"error": error}
    out_queue.put(results, timeout=timeout)
    out_queue.join()


def _test_large_generation(in_queue, out_queue, timeout):
    error = None
    try:
        _ = in_queue.get(timeout=timeout)

        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")

        input_speech = _load_datasamples(1)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
    except Exception:
        error = f"{traceback.format_exc()}"

    results = {"error": error}
    out_queue.put(results, timeout=timeout)
    out_queue.join()


def _test_large_generation_multilingual(in_queue, out_queue, timeout):
    error = None
    try:
        _ = in_queue.get(timeout=timeout)

        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")

        ds = load_dataset("common_voice", "ja", split="test", streaming=True)
        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
        input_speech = next(iter(ds))["audio"]["array"]
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " Kimura-san called me."
        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
    except Exception:
        error = f"{traceback.format_exc()}"

    results = {"error": error}
    out_queue.put(results, timeout=timeout)
    out_queue.join()


def _test_large_batched_generation(in_queue, out_queue, timeout):
    error = None
    try:
        _ = in_queue.get(timeout=timeout)

        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")

        input_speech = _load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
        generated_ids_1 = model.generate(input_features[0:2], max_length=20)
        generated_ids_2 = model.generate(input_features[2:4], max_length=20)
        generated_ids = np.concatenate([generated_ids_1, generated_ids_2])

        # fmt: off
        EXPECTED_IDS = [
            [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
            [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
            [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
            [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
        ]
        # fmt: on

        unittest.TestCase().assertEqual(generated_ids.tolist(), EXPECTED_IDS)

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes and we are glad to",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast beef",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        unittest.TestCase().assertListEqual(transcript, EXPECTED_TRANSCRIPT)
    except Exception:
        error = f"{traceback.format_exc()}"

    results = {"error": error}
    out_queue.put(results, timeout=timeout)
    out_queue.join()


@require_tf
@require_tokenizers
class TFWhisperModelIntegrationTests(unittest.TestCase):
    @cached_property
    def default_processor(self):
        return WhisperProcessor.from_pretrained("openai/whisper-base")

    def _load_datasamples(self, num_samples):
        return _load_datasamples(num_samples)

    @slow
    def test_tiny_logits_librispeech(self):
        set_seed(0)
        model = TFWhisperModel.from_pretrained("openai/whisper-tiny")
        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="tf").input_features

        logits = model(
            input_features,
            decoder_input_ids=tf.convert_to_tensor([[50258, 50259, 50359]]),
            output_hidden_states=False,
            output_attentions=False,
            return_dict=False,
            use_cache=False,
        )

        # fmt: off
        EXPECTED_LOGITS = tf.convert_to_tensor(
            [
                2.9892, -6.7607, 5.7348, 3.6096, 0.2152, -5.7321, 4.8855, -1.6407,
                0.2823, -1.5718, 10.4269, 3.4427, 0.0219, -8.0612, 3.4784, 8.4246,
                4.0575, -2.2864, 11.1084, 0.9963, 0.9884, -8.5154, -3.5469, -9.3713,
                0.9786, 3.5435, 7.4850, -5.2579, -1.4366, 10.4841
            ]
        )
        # fmt: on
        self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4))

        # fmt: off
        EXPECTED_GENERATION = tf.convert_to_tensor(
            [
                -1.4651, -2.6944, 2.7821, 2.3793, 4.0738, 0.0188, -3.3203, 1.9836,
                0.0520, 0.7095, 1.1063, 0.2952, -3.6786, -0.5249, 0.3105, 4.7691,
                1.1562, 1.3046, 0.5810, -0.3624, 1.7006, 1.3424, 0.9817, 2.1958,
                1.8775, -5.7046, -0.7679, 4.0113, 2.6848, 2.8609
            ]
        )
        # fmt: on

        head_logits = logits[0] @ tf.transpose(model.model.decoder.embed_tokens.weights[0])
        self.assertTrue(np.allclose(head_logits[0, 0, :30], EXPECTED_GENERATION, atol=1e-4))

    @slow
    def test_small_en_logits_librispeech(self):
        set_seed(0)
        model = TFWhisperModel.from_pretrained("openai/whisper-small.en")

        input_speech = self._load_datasamples(1)

        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="tf").input_features

        logits = model(
            input_features,
            decoder_input_ids=tf.convert_to_tensor([[model.config.decoder_start_token_id]]),
            output_hidden_states=False,
            output_attentions=False,
            use_cache=False,
        )

        logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0])

        # fmt: off
        EXPECTED_LOGITS = tf.convert_to_tensor(
            [
                -3.6784, -7.7211, -9.5070, -11.9286, -7.6489, -9.7026, -5.6188,
                -8.0104, -4.6238, -5.1833, -9.0485, -3.4079, -5.4874, -2.6935,
                -6.3479, -7.3398, -6.9558, -7.6867, -7.4748, -8.3463, -9.9781,
                -10.8389, -10.3105, -11.7201, -9.7261, -7.1590, -5.9272, -12.4509,
                -11.1146, -8.1918
            ]
        )
        # fmt: on
        self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))

    @slow
    def test_large_logits_librispeech(self):
        run_test_in_subprocess(test_case=self, target_func=_test_large_logits_librispeech, inputs=None)

    @slow
    def test_tiny_en_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        model.config.decoder_start_token_id = 50257

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        generated_ids = model.generate(input_features, num_beams=5, max_length=20)
        transcript = processor.tokenizer.batch_decode(generated_ids)[0]

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes, and we are glad to"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_tiny_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        generated_ids = model.generate(input_features, num_beams=5, max_length=20)
        transcript = processor.tokenizer.decode(generated_ids[0])

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes and we are glad"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_tiny_xla_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        xla_generate = tf.function(model.generate, jit_compile=True)

        generated_ids = model.generate(input_features, num_beams=5, max_length=20)
        generated_ids_xla = xla_generate(input_features, num_beams=5, max_length=20)

        transcript = processor.tokenizer.decode(generated_ids[0])
        transcript_xla = processor.tokenizer.decode(generated_ids_xla[0])

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes and we are glad"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
        self.assertEqual(transcript_xla, EXPECTED_TRANSCRIPT)

    @slow
    def test_large_generation(self):
        run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None)

    @slow
    def test_large_generation_multilingual(self):
        run_test_in_subprocess(test_case=self, target_func=_test_large_generation_multilingual, inputs=None)

    @slow
    def test_large_batched_generation(self):
        run_test_in_subprocess(test_case=self, target_func=_test_large_batched_generation, inputs=None)

    @slow
    def test_tiny_en_batched_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
        generated_ids = model.generate(input_features, max_length=20)

        # fmt: off
        EXPECTED_LOGITS = tf.convert_to_tensor(
            [
                [50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
                [50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
                [50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
                [50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
            ]
        )
        # fmt: on

        self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes, and we are glad to",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast beef looming",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_tiny_en_batched_xla_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        xla_generate = tf.function(model.generate, jit_compile=True)

        generated_ids = model.generate(input_features, max_length=20)
        generated_ids_xla = xla_generate(input_features, max_length=20)

        # fmt: off
        EXPECTED_LOGITS = tf.convert_to_tensor(
            [
                [50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
                [50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
                [50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
                [50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
            ]
        )
        # fmt: on

        self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
        self.assertTrue(np.allclose(generated_ids_xla, EXPECTED_LOGITS))

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes, and we are glad to",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast beef looming",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        transcript_xla = processor.batch_decode(generated_ids_xla, skip_special_tokens=True)
        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
        self.assertListEqual(transcript_xla, EXPECTED_TRANSCRIPT)
