# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Whisper model. """

import copy
import inspect
import os
import random
import re
import tempfile
import time
import unittest

import numpy as np
import pytest
from huggingface_hub import hf_hub_download

import transformers
from transformers import WhisperConfig
from transformers.testing_utils import (
    is_pt_flax_cross_test,
    require_flash_attn,
    require_torch,
    require_torch_fp16,
    require_torch_gpu,
    require_torchaudio,
    slow,
    torch_device,
)
from transformers.utils import cached_property, is_flax_available, is_torch_available, is_torchaudio_available
from transformers.utils.import_utils import is_datasets_available

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_datasets_available():
    import datasets
    from datasets import Audio, load_dataset

if is_torch_available():
    import torch

    from transformers import (
        WhisperFeatureExtractor,
        WhisperForAudioClassification,
        WhisperForCausalLM,
        WhisperForConditionalGeneration,
        WhisperModel,
        WhisperProcessor,
        set_seed,
    )
    from transformers.generation.logits_process import LogitsProcessor
    from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids

    class DummyTimestampLogitProcessor(LogitsProcessor):
        """This processor fakes the correct timestamp token pattern [TOK_1] [TOK_2] ... [TOK_N] [TIME_STAMP_TOK_1] [TIME_STAMP_TOK_2] [TOK_N+1] ..."""

        def __init__(
            self, timestamp_begin, vocab_size, batch_size, max_length, min_space=3, seed=0, is_length_ascending=True
        ):
            self.timestamp_begin = timestamp_begin
            self.vocab_size = vocab_size

            self.min_space_between_timestamps = min_space
            self.timestamp_tokens = torch.arange(self.timestamp_begin, self.vocab_size)
            self.timestamp_tokens = self.timestamp_tokens.to(torch_device)  # `Tensor.to` is not in-place
            self.is_length_ascending = is_length_ascending

            self.no_time_stamp_counter = batch_size * [0]
            self.prev_highest_timestamp = batch_size * [0]
            self.batch_size = batch_size
            self.max_length = max_length
            self.count = 0
            self.begin_index = 0

            self.let_pass = [[] for _ in range(batch_size)]
            for k in range(batch_size):
                random.seed(seed + k)
                for _ in range(10000):
                    self.let_pass[k].append(random.randint(1, 10) <= 3)

        def set_begin_index(self, begin_index: int):
            self.begin_index = begin_index

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            # we don't want to randomly sample timestamp tokens
            if input_ids.shape[-1] != self.begin_index:
                scores[:, self.timestamp_begin :] = -float("inf")

            self.no_time_stamp_counter = [x + 1 for x in self.no_time_stamp_counter]
            for k in range(input_ids.shape[0]):
                # make sure to use the correct index if a batch was removed
                if self.is_length_ascending and input_ids.shape[0] < self.batch_size:
                    prev_k = k + self.batch_size - input_ids.shape[0]
                else:
                    prev_k = k

                if input_ids[k, -1] == self.timestamp_begin:
                    self.no_time_stamp_counter[prev_k] = 0

                can_produce = self.no_time_stamp_counter[prev_k] > self.min_space_between_timestamps
                must_produce = (
                    input_ids[k][2:].le(self.timestamp_begin).all() and input_ids.shape[-1] == self.max_length - 1
                )
                # produce a timestamp with ~30% probability
                if (can_produce and self.let_pass[prev_k][self.count]) or must_produce:
                    self.no_time_stamp_counter[prev_k] = 0
                    self.prev_highest_timestamp[prev_k] = max(input_ids[k].max() + 1, self.timestamp_tokens[0].item())

                    # force a timestamp
                    scores[k, :] = -float("inf")
                    scores[k, self.prev_highest_timestamp[prev_k]] = 10.0

                if (
                    input_ids.shape[-1] > 3
                    and input_ids[k, -1].item() in self.timestamp_tokens
                    and input_ids[k, -2].item() not in self.timestamp_tokens
                ):
                    # force the same timestamp as before
                    scores[k, :] = -float("inf")
                    scores[k, input_ids[k, -1].item()] = 10.0

            self.count += 1

            if torch.isinf(scores).all():
                raise ValueError("Dummy logit processor is incorrectly set up. Scores should not be all inf.")

            return scores
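
    # Illustrative usage (a minimal sketch for readers, not executed by the test
    # suite; the constructor values below are hypothetical):
    #
    #     processor = DummyTimestampLogitProcessor(
    #         timestamp_begin=timestamp_begin, vocab_size=vocab_size,
    #         batch_size=2, max_length=20,
    #     )
    #     model.generate(input_features, logits_processor=[processor], max_length=20)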


if is_torchaudio_available():
    import torchaudio


if is_flax_available():
    import jax.numpy as jnp

    from transformers.modeling_flax_pytorch_utils import (
        convert_pytorch_state_dict_to_flax,
        load_flax_weights_in_pytorch_model,
    )


def prepare_whisper_inputs_dict(
    config,
    input_features,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    if decoder_attention_mask is None:
        decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
    if head_mask is None:
        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
    if decoder_head_mask is None:
        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
    if cross_attn_head_mask is None:
        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
    return {
        # "input_ids": input_features,
        "input_features": input_features,
        "decoder_input_ids": decoder_input_ids,
        "decoder_attention_mask": decoder_attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
    }
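
# When masks are not provided, the helper above defaults to a pad-token-based
# decoder attention mask and all-ones head masks (i.e. no heads masked); the
# returned keys mirror the keyword arguments of `WhisperModel.forward`.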


@require_torch
class WhisperModelTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=60,
        is_training=True,
        use_labels=False,
        vocab_size=200,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        input_channels=1,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        max_source_positions=30,
        max_target_positions=40,
        bos_token_id=98,
        eos_token_id=98,
        pad_token_id=0,
        num_mel_bins=80,
        decoder_start_token_id=85,
        num_conv_layers=1,
        suppress_tokens=None,
        begin_suppress_tokens=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_mel_bins = num_mel_bins
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.num_conv_layers = num_conv_layers
        self.suppress_tokens = suppress_tokens
        self.begin_suppress_tokens = begin_suppress_tokens

    def prepare_config_and_inputs(self):
        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)

        decoder_input_ids = torch.tensor(self.batch_size * [[self.decoder_start_token_id]], device=torch_device)

        config = self.get_config()
        inputs_dict = prepare_whisper_inputs_dict(
            config,
            attention_mask=None,
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
        )
        return config, inputs_dict

    def get_config(self):
        return WhisperConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            input_channels=self.input_channels,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            max_source_positions=self.max_source_positions,
            max_target_positions=self.max_target_positions,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_ffn_dim=self.hidden_size,
            encoder_ffn_dim=self.hidden_size,
            decoder_start_token_id=self.decoder_start_token_id,
            suppress_tokens=self.suppress_tokens,
            begin_suppress_tokens=self.begin_suppress_tokens,
        )

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_subsampled_output_lengths(self, input_lengths):
        """
        Computes the output length of the convolutional layers
        """

        for _ in range(self.num_conv_layers):
            input_lengths = (input_lengths - 1) // 2 + 1

        return input_lengths
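
    # Worked example: each conv layer maps a length L to (L - 1) // 2 + 1, so with
    # the default num_conv_layers=1 and seq_length=60 the encoder operates on
    # (60 - 1) // 2 + 1 = 30 frames, which matches max_source_positions above.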

    def create_and_check_model_forward(self, config, inputs_dict, freeze_encoder=False):
        model = WhisperModel(config=config).to(torch_device).eval()

        if freeze_encoder:
            model.freeze_encoder()

        input_features = inputs_dict["input_features"]
        decoder_input_ids = inputs_dict["decoder_input_ids"]

        # first forward pass
        last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state

        expected_shape = (self.batch_size, decoder_input_ids.shape[1], self.hidden_size)
        self.parent.assertEqual(last_hidden_state.shape, expected_shape)

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        model = WhisperModel(config=config).get_decoder().to(torch_device).eval()
        input_ids = inputs_dict["decoder_input_ids"]
        attention_mask = inputs_dict["decoder_attention_mask"]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size).clamp(2)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append to next input_ids and attention mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))

    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
        model = WhisperModel(config=config).to(torch_device).eval()
        outputs = model(**inputs_dict)

        encoder_last_hidden_state = outputs.encoder_last_hidden_state
        last_hidden_state = outputs.last_hidden_state

        with tempfile.TemporaryDirectory() as tmpdirname:
            encoder = model.get_encoder()
            encoder.save_pretrained(tmpdirname)
            encoder = WhisperEncoder.from_pretrained(tmpdirname).to(torch_device)

        encoder_last_hidden_state_2 = encoder(inputs_dict["input_features"])[0]

        self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)

        with tempfile.TemporaryDirectory() as tmpdirname:
            decoder = model.get_decoder()
            decoder.save_pretrained(tmpdirname)
            decoder = WhisperDecoder.from_pretrained(tmpdirname).to(torch_device)

        last_hidden_state_2 = decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=encoder_last_hidden_state,
        )[0]

        self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)


@require_torch
class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (WhisperModel, WhisperForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (WhisperForConditionalGeneration,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "audio-classification": WhisperForAudioClassification,
            "automatic-speech-recognition": WhisperForConditionalGeneration,
            "feature-extraction": WhisperModel,
            "text-generation": WhisperForCausalLM,
        }
        if is_torch_available()
        else {}
    )
    is_encoder_decoder = True
    fx_compatible = False
    test_pruning = False
    test_missing_keys = False
    # Needs higher percentages after model tester's vocab_size is changed to 200 (PR #21222)
    # `0.5` is for `test_disk_offload` (which also works for `test_model_parallelism`)
    model_split_percents = [0.5, 0.8, 0.9]

    input_name = "input_features"

    # TODO: fix the failing tests
    def is_pipeline_test_to_skip(
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
    ):
        if pipeline_test_casse_name in [
            "AutomaticSpeechRecognitionPipelineTests",
            "AudioClassificationPipelineTests",
        ]:
            # RuntimeError: The size of tensor a (1500) must match the size of tensor b (30) at non-singleton
            # dimension 1
            return True

        return False

    def setUp(self):
        self.model_tester = WhisperModelTester(self)
        self.config_tester = ConfigTester(self, config_class=WhisperConfig)
        self.maxDiff = 3000

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    def test_model_forward_with_frozen_encoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs, freeze_encoder=True)

    def test_requires_grad_with_frozen_encoder(self):
        config = self.model_tester.get_config()
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.freeze_encoder()

            try:
                encoder_grads = [param.requires_grad for param in model.encoder.parameters()]
                decoder_grads = [param.requires_grad for param in model.decoder.parameters()]
            except AttributeError:
                encoder_grads = [param.requires_grad for param in model.model.encoder.parameters()]
                decoder_grads = [param.requires_grad for param in model.model.decoder.parameters()]

            self.assertFalse(all(encoder_grads))
            self.assertTrue(all(decoder_grads))

    def test_requires_grad_encoder_embed_positions(self):
        config = self.model_tester.get_config()
        for model_class in self.all_model_classes:
            model = model_class(config)
            encoder = model.get_encoder()
            self.assertFalse(encoder.embed_positions.weight.requires_grad)

    def test_encoder_sinusoidal_embed_positions(self):
        config = self.model_tester.get_config()
        for model_class in self.all_model_classes:
            model = model_class(config)
            embeds = model.get_encoder().embed_positions.weight
            self.assertTrue(torch.allclose(embeds, sinusoids(*embeds.shape)))

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_encoder_decoder_model_standalone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)

    def _get_input_ids_and_config(self, batch_size=3):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict[self.input_name]

        # take at most `batch_size` samples
        input_ids = input_ids[:batch_size, :, :]

        # generate max 3 tokens
        max_length = 4
        if config.eos_token_id is not None and config.pad_token_id is None:
            # hack to allow generate for models such as GPT2 as is done in `generate()`
            config.pad_token_id = config.eos_token_id

        return config, input_ids, None, max_length
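
    # The attention-mask slot in the returned tuple is None: Whisper consumes
    # fixed-size log-mel features rather than padded token ids, so these
    # generation tests do not pass an encoder attention mask (an assumption
    # about why None is returned here).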

    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            decoder_input_ids = inputs.pop("decoder_input_ids", None)
            inputs.pop("decoder_attention_mask", None)

            wte = model.get_input_embeddings()
            inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)

            with torch.no_grad():
                model(**inputs)[0]

    # training is not supported yet
    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    def test_generate_with_head_masking(self):
        pass

    @require_torch_fp16
    def test_generate_fp16(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        config.max_target_positions = 400
        input_features = input_dict["input_features"]
        model = WhisperForConditionalGeneration(config).eval().to(torch_device)
        input_features = input_features.half()
        model.half()
        model.generate(input_features)
        model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

    def test_generate_language(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        input_features = input_dict["input_features"]
        model = WhisperForConditionalGeneration(config).to(torch_device)
        # Hack to keep the test fast and not require downloading a model with a generation_config
        model.generation_config.__setattr__("lang_to_id", {"<|en|>": 1})
        model.generation_config.__setattr__("task_to_id", {"transcribe": 2})

        # test language code
        model.generate(input_features, language="en")
        # test tokenizer code
        model.generate(input_features, language="<|en|>")
        # test language name
        model.generate(input_features, language="English")

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_features",
                "attention_mask",
                "decoder_input_ids",
                "decoder_attention_mask",
            ]
            expected_arg_names.extend(
                ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
                if all(mask in arg_names for mask in ["head_mask", "decoder_head_mask", "cross_attn_head_mask"])
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
            else:
                seq_length = self.model_tester.seq_length

            subsampled_seq_length = model._get_feat_extract_output_lengths(seq_length)

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [subsampled_seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)

                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", 1)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also works using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", 1)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", 1)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            subsampled_encoder_seq_length = model._get_feat_extract_output_lengths(encoder_seq_length)
            subsampled_encoder_key_length = model._get_feat_extract_output_lengths(encoder_key_length)

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also works using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )
            out_len = len(outputs)

            correct_outlen = 5

            # loss is at first position
            if "labels" in inputs_dict:
                correct_outlen += 1  # loss is added to beginning
            if "past_key_values" in outputs:
                correct_outlen += 1  # past_key_values have been returned

            self.assertEqual(out_len, correct_outlen)

            # decoder attentions
            decoder_attentions = outputs.decoder_attentions
            self.assertIsInstance(decoder_attentions, (list, tuple))
            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(decoder_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
            )

            # cross attentions
            cross_attentions = outputs.cross_attentions
            self.assertIsInstance(cross_attentions, (list, tuple))
            self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(cross_attentions[0].shape[-3:]),
                [
                    self.model_tester.num_attention_heads,
                    decoder_seq_length,
                    subsampled_encoder_key_length,
                ],
            )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )

    def test_resize_tokens_embeddings(self):
        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            model_vocab_size = config.vocab_size
            # Retrieve the embeddings and clone them
            model_embed = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)

            # make sure that decoder_input_ids are resized
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = True
            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_resize_embeddings_untied(self):
        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        original_config.tie_word_embeddings = False

        # if the model cannot untie embeddings -> leave test
        if original_config.tie_word_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config).to(torch_device)

            # if no output embeddings -> leave test
            if model.get_output_embeddings() is None:
                continue

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_vocab_size = config.vocab_size
            model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
            # make sure that decoder_input_ids are resized
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

    def test_generate_without_input_ids(self):
        pass

    @staticmethod
    def _get_encoder_outputs(
        model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
    ):
        encoder = model.get_encoder()
        encoder_outputs = encoder(
            input_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
            num_interleave, dim=0
        )
        input_ids = input_ids[:, :, 0]
        input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + torch.tensor(
            [model._get_decoder_start_token_id()], device=input_ids.device
        )
        attention_mask = None
        return encoder_outputs, input_ids, attention_mask
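
    # In short: the helper precomputes encoder states (repeated `num_interleave`
    # times, e.g. once per beam) and swaps the mel inputs for a single
    # decoder-start token per sample, so generation can continue decoder-side only.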

    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
        batch_size, mel, seq_length = input_ids.shape
        subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
        num_sequences_in_output = batch_size * num_return_sequences
        gen_len = (
            output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
        )

        # scores
        self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)

        # Attentions
        # encoder
        self._check_encoder_attention_for_generate(
            output.encoder_attentions, batch_size, config, subsampled_seq_length
        )
        # decoder
        self._check_attentions_for_generate(
            num_sequences_in_output,
            output.decoder_attentions,
            min_length=1,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )

        # Hidden States
        # encoder
        self._check_encoder_hidden_states_for_generate(
            output.encoder_hidden_states, batch_size, config, subsampled_seq_length
        )

        # decoder
        self._check_hidden_states_for_generate(
            num_sequences_in_output,
            output.decoder_hidden_states,
            min_length=1,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference(self):
        import torch

        for model_class in self.all_model_classes:
            if not model_class._supports_flash_attn_2:
                return

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.bfloat16,
                )
                model.to(torch_device)

                dummy_input = inputs_dict[model.main_input_name][:1]
                if dummy_input.dtype in [torch.float32, torch.float16]:
                    dummy_input = dummy_input.to(torch.bfloat16)

                decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]

                outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
                outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)

                logits = outputs.decoder_hidden_states[-1]
                logits_fa = outputs_fa.decoder_hidden_states[-1]

                # whisper FA2 needs very high tolerance
                assert torch.allclose(logits_fa, logits, atol=4e-1)

                # check with inference + dropout (train mode enables dropout in the FA2 model run below)
                model_fa.train()
                _ = model_fa(dummy_input, decoder_input_ids=decoder_input_ids)

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_padding_right(self):
        import torch

        for model_class in self.all_model_classes:
            if not model_class._supports_flash_attn_2:
                return

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)
                model.to(torch_device)

                dummy_input = inputs_dict[model.main_input_name][:1]
                dummy_input = dummy_input.to(torch.float16)

                decoder_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=dummy_input.device, dtype=torch.long)
                decoder_attention_mask = torch.tensor(
                    [[0, 0, 0, 1, 1, 1]], device=dummy_input.device, dtype=torch.long
                )

                outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
                outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)

                logits = outputs.decoder_hidden_states[-1]
                logits_fa = outputs_fa.decoder_hidden_states[-1]

                # whisper FA2 needs very high tolerance
                assert torch.allclose(logits_fa, logits, atol=4e-1)

                other_inputs = {
                    "decoder_input_ids": decoder_input_ids,
                    "decoder_attention_mask": decoder_attention_mask,
                    "output_hidden_states": True,
                }

                outputs = model(dummy_input, **other_inputs)
                outputs_fa = model_fa(dummy_input, **other_inputs)

                logits = outputs.decoder_hidden_states[-1]
                logits_fa = outputs_fa.decoder_hidden_states[-1]

                # whisper FA2 needs very high tolerance
                assert torch.allclose(logits_fa[:, -2:], logits[:, -2:], atol=4e-1)

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no NaNs
        configs_no_init.torchscript = True
        configs_no_init._attn_implementation = "eager"
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class)

            try:
                model.config.use_cache = False  # FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
                input_features = inputs["input_features"]
                decoder_input_ids = inputs["decoder_input_ids"]
                decoder_attention_mask = inputs["decoder_attention_mask"]
                # prepare `attention_mask` with shape (batch_size, sequence_length)
                attention_mask = torch.ones(
                    input_features.shape[0],
                    input_features.shape[-1],
                    device=input_features.device,
                    dtype=input_features.dtype,
                )
                traced_model = torch.jit.trace(
                    model, (input_features, attention_mask, decoder_input_ids, decoder_attention_mask)
                )

            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
        # We override with a slightly higher tol value, as this test recently became flaky
        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)

    def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
        # We override with a slightly higher tol value, as this test recently became flaky
        super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)

    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        init_shape = (1,) + inputs_dict["input_features"].shape[1:]

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                fx_model_class_name = "Flax" + model_class.__name__

                if not hasattr(transformers, fx_model_class_name):
                    # no flax model exists for this class
                    return

                # Output all for aggressive testing
                config.output_hidden_states = True
                config.output_attentions = self.has_attentions

                fx_model_class = getattr(transformers, fx_model_class_name)

                # load PyTorch class
                pt_model = model_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned by default.
                # So we disable `use_cache` here for the PyTorch model.
                pt_model.config.use_cache = False

                # load Flax class
                fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)

                # make sure only Flax inputs that actually exist in the function args are forwarded
                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

                # prepare inputs
                pt_inputs = self._prepare_for_class(inputs_dict, model_class)

                # remove function args that don't exist in Flax
                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

                # send pytorch inputs to the correct device
                pt_inputs = {
                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
                }

                # convert inputs to Flax
                fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}

                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                fx_model.params = fx_state

                # send pytorch model to the correct device
                pt_model.to(torch_device)

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs)
                fx_outputs = fx_model(**fx_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    pt_model.save_pretrained(tmpdirname)
                    fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)

                fx_outputs_loaded = fx_model_loaded(**fx_inputs)

                fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)

    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        init_shape = (1,) + inputs_dict["input_features"].shape[1:]

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                fx_model_class_name = "Flax" + model_class.__name__

                if not hasattr(transformers, fx_model_class_name):
                    # no flax model exists for this class
                    return

                # Output all for aggressive testing
                config.output_hidden_states = True
                config.output_attentions = self.has_attentions

                fx_model_class = getattr(transformers, fx_model_class_name)

                # load PyTorch class
                pt_model = model_class(config).eval()
                # Flax models don't use the `use_cache` option and cache is not returned by default.
                # So we disable `use_cache` here for the PyTorch model.
                pt_model.config.use_cache = False

                # load Flax class
                fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)

                # make sure only Flax inputs that actually exist in the function args are forwarded
                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

                # prepare inputs
                pt_inputs = self._prepare_for_class(inputs_dict, model_class)

                # remove function args that don't exist in Flax
                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

                # send pytorch inputs to the correct device
                pt_inputs = {
                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
                }

                # convert inputs to Flax
                fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                # send pytorch model to the correct device
                pt_model.to(torch_device)

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs)
                fx_outputs = fx_model(**fx_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)

                # send pytorch model to the correct device
                pt_model_loaded.to(torch_device)
                pt_model_loaded.eval()

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)

1234
    def test_mask_feature_prob(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.mask_feature_prob = 0.2
        config.mask_feature_length = 2

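        # SpecAugment-style masking: with prob 0.2, spans of 2 consecutive mel feature
        # bins are masked out while the model is in training mode (see forward pass below)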
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.train()

            # forward pass
            encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
            self.assertEqual(encoder_last_hidden_state.shape, (13, 30, 16))

    def test_mask_time_prob(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.mask_time_prob = 0.2
        config.mask_time_length = 2

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.train()

            # forward pass
            encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
            self.assertEqual(encoder_last_hidden_state.shape, (13, 30, 16))

    def test_generate_with_prompt_ids_and_task_and_language(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = WhisperForConditionalGeneration(config).eval().to(torch_device)
        input_features = input_dict["input_features"]
        prompt_ids = torch.arange(5).to(torch_device)
        language = "<|de|>"
        task = "translate"
        lang_id = 6
        task_id = 7
        model.generation_config.__setattr__("lang_to_id", {language: lang_id})
        model.generation_config.__setattr__("task_to_id", {task: task_id})

        output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)

        expected_output_start = [
            *prompt_ids.tolist(),
            model.generation_config.decoder_start_token_id,
            lang_id,
            task_id,
        ]
        for row in output.tolist():
            self.assertListEqual(row[: len(expected_output_start)], expected_output_start)

    def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = WhisperForConditionalGeneration(config).eval().to(torch_device)
        input_features = input_dict["input_features"]
        prompt_ids = torch.arange(5).to(torch_device)
        forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]

        output = model.generate(
            input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
        )

        expected_output_start = [
            *prompt_ids.tolist(),
            model.generation_config.decoder_start_token_id,
            *[token for _rank, token in forced_decoder_ids],
        ]
        for row in output.tolist():
            self.assertListEqual(row[: len(expected_output_start)], expected_output_start)

    def test_generate_with_prompt_ids_max_length(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.max_target_positions = 7

        model = WhisperForConditionalGeneration(config).eval().to(torch_device)
        input_features = input_dict["input_features"]
        decoder_input_ids = torch.arange(5).to(torch_device)
        prompt_ids = decoder_input_ids[:4]
        max_new_tokens = 8

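        # 4 prompt tokens plus the special start tokens give `decoder_input_ids` a length
        # of 5; 5 + 8 requested new tokens = 13 > max_target_positions (7), so `generate`
        # is expected to raise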
        with self.assertRaisesRegex(
            ValueError,
            f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
            f"is {max_new_tokens}. Thus, the combined length of "
            f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
            f"`max_target_positions` of the Whisper model: {config.max_target_positions}. "
            "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
            f"so that their combined length is less than {config.max_target_positions}.",
        ):
            model.generate(input_features, max_new_tokens=max_new_tokens, prompt_ids=prompt_ids)

        model.generate(input_features, max_new_tokens=1, prompt_ids=prompt_ids)

    def test_generate_longform_with_prompt_ids(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = WhisperForConditionalGeneration(config).eval().to(torch_device)

        prompt_ids = torch.arange(5).to(torch_device)
        model.generation_config.no_timestamps_token_id = 11
        model.generation_config.pad_token_id = 10

        # make sure prompt token ids [0-9] can't be generated
        model.generation_config.suppress_tokens = list(range(10))

        input_features = input_dict["input_features"]

        language = "<|de|>"
        lang_id = 6

        input_features = input_features.repeat(1, 1, 50)
        attention_mask = torch.ones_like(input_features, dtype=torch.long)[:, 0]

        for prompt_type in ["first-segment", "all-segments"]:
            for task_id, task in enumerate(["translate", "transcribe"]):
                task_id = 7 + task_id

                model.generation_config.__setattr__("lang_to_id", {language: lang_id})
                model.generation_config.__setattr__("task_to_id", {task: task_id})

                output = model.generate(
                    input_features,
                    attention_mask=attention_mask,
                    prompt_condition_type=prompt_type,
                    max_new_tokens=5,
                    task=task,
                    language=language,
                    prompt_ids=prompt_ids,
                    condition_on_prev_tokens=True,
                )
                for row in output.tolist():
                    # make sure no token below 10 appears in the generated output,
                    # i.e. for long-form generation the prompt ids must NOT be returned
                    assert not any(i in row for i in model.generation_config.suppress_tokens)

    def _check_longform_generate_single_batch(self, condition_on_prev_tokens):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()

        model = WhisperForConditionalGeneration(config).eval().to(torch_device)
        input_features = input_dict["input_features"]

        # len = 250 with num_input_frames = 60
        long_input_features = torch.cat([input_features.repeat(1, 1, 4), input_features[:, :, :10]], dim=-1)

        # force bsz=1
        long_input_features = long_input_features[:1]
        vocab_size = model.config.vocab_size

        batch_size = 1
        num_timestamp_tokens = 20
        max_length = 16
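        # the dummy processor (defined at the top of this file) forces a synthetic but
        # well-formed timestamp-token pattern, so the long-form segmentation logic can
        # be exercised deterministically with an untrained model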
        logits_processor = [
            DummyTimestampLogitProcessor(
                vocab_size - num_timestamp_tokens,
                vocab_size,
                batch_size=batch_size,
                max_length=max_length,
                min_space=4,
            )
        ]

        # each chunk should not be longer than 10
        model.generation_config.max_length = max_length

        # if the input features are long, `return_timestamps` can't be set to False
        with self.assertRaises(ValueError):
            _ = model.generate(long_input_features, logits_processor=logits_processor, return_timestamps=False)

        # if the input features are long, the generation config needs to be properly set
        with self.assertRaises(ValueError):
            _ = model.generate(long_input_features, logits_processor=logits_processor)

        timestamp_begin = vocab_size - num_timestamp_tokens
        model.generation_config.no_timestamps_token_id = timestamp_begin - 1
        model.generation_config.eos_token_id = None
        model.config.eos_token_id = None
        model.generation_config._detect_timestamp_from_logprob = False
        # make sure that we only have the same begin token
        model.generation_config.max_initial_timestamp_index = 0
        model.generation_config.prev_bos_token_id = timestamp_begin - 3

        gen_kwargs = {
            "logits_processor": logits_processor,
            "return_segments": True,
            "condition_on_prev_tokens": condition_on_prev_tokens,
        }

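        # when conditioning on previous tokens, also exercise the temperature-fallback
        # heuristics (no-speech, compression-ratio and logprob thresholds)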
        if condition_on_prev_tokens:
            gen_kwargs["no_speech_threshold"] = 0.6
            gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
            gen_kwargs["compression_ratio_threshold"] = 2.4
            gen_kwargs["logprob_threshold"] = -1.0

        outputs = model.generate(long_input_features, **gen_kwargs)

        segments = outputs["segments"][0]

        for segment in segments:
            assert segment["start"] <= segment["end"], "start has to be smaller or equal to end"
            assert any(
                s > timestamp_begin for s in segment["tokens"][1:]
            ), f"At least one segment token (other than the first) should be a timestamp token: {segment['tokens']}"
            assert (
                segment["tokens"].shape[-1] <= max_length
            ), "make sure that no segment is larger than max generation length"

    def test_longform_generate_single_batch(self):
        self._check_longform_generate_single_batch(condition_on_prev_tokens=False)

    def test_longform_generate_single_batch_cond_prev(self):
        self._check_longform_generate_single_batch(condition_on_prev_tokens=True)

    def _check_longform_generate_multi_batch(self, condition_on_prev_tokens):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()

        model = WhisperForConditionalGeneration(config).eval().to(torch_device)
        input_features = input_dict["input_features"].to(torch_device)

        # len = 250 with num_input_frames = 60
        long_input_features = torch.cat([input_features.repeat(1, 1, 4), input_features[:, :, :10]], dim=-1)
        input_features_2 = long_input_features[1:]
        attention_mask = torch.ones(
            (2, long_input_features.shape[-1]), dtype=input_features.dtype, device=input_features.device
        )
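        # zero out the tail of the first sample's mask to emulate shorter, padded audio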
        attention_mask[0, 200:] = 0

        # single-sample reference run (bsz=1)
        vocab_size = model.config.vocab_size

        batch_size = 1
        num_timestamp_tokens = 20
        max_new_tokens = 16
        timestamp_begin = vocab_size - num_timestamp_tokens
        model.generation_config.no_timestamps_token_id = timestamp_begin - 1
        model.generation_config.eos_token_id = None
        model.config.eos_token_id = None
        model.generation_config._detect_timestamp_from_logprob = False
        # make sure that we only have the same begin token
        model.generation_config.max_initial_timestamp_index = 0
        model.generation_config.max_new_tokens = max_new_tokens
        model.generation_config.prev_bos_token_id = timestamp_begin - 3

        logits_processor = [
            DummyTimestampLogitProcessor(
                vocab_size - num_timestamp_tokens,
                vocab_size,
                batch_size=batch_size,
                max_length=max_new_tokens,
                min_space=4,
                seed=1,
            )
        ]
        outputs_2 = model.generate(
            input_features_2,
            max_new_tokens=max_new_tokens,
            logits_processor=logits_processor,
            condition_on_prev_tokens=condition_on_prev_tokens,
            return_segments=True,
        )
        tokens_2 = outputs_2["sequences"][0]
        segments_2 = outputs_2["segments"][0]

        batch_size = 2
        logits_processor = [
            DummyTimestampLogitProcessor(
                vocab_size - num_timestamp_tokens,
                vocab_size,
                batch_size=batch_size,
                max_length=max_new_tokens,
                min_space=4,
                seed=0,
            )
        ]
        gen_kwargs = {
            "logits_processor": logits_processor,
            "return_segments": True,
            "condition_on_prev_tokens": condition_on_prev_tokens,
            "attention_mask": attention_mask,
            "max_new_tokens": max_new_tokens,
        }

        outputs = model.generate(long_input_features, **gen_kwargs)
        tokens = outputs["sequences"][1]
        segments = outputs["segments"][1]

        # make sure batched and non-batched generation match
        assert tokens_2.tolist() == tokens[: tokens_2.shape[-1]].tolist()

        for seg1, seg2 in zip(segments_2, segments):
            assert seg1["start"] == seg2["start"]
            assert seg1["end"] == seg2["end"]
            assert seg1["tokens"].tolist() == seg2["tokens"].tolist()

    def test_longform_generate_multi_batch(self):
        self._check_longform_generate_multi_batch(condition_on_prev_tokens=False)

    def test_longform_generate_multi_batch_cond_prev(self):
        self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)


@require_torch
@require_torchaudio
class WhisperModelIntegrationTests(unittest.TestCase):
    @cached_property
    def default_processor(self):
        return WhisperProcessor.from_pretrained("openai/whisper-base")

    def _load_datasamples(self, num_samples):
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]

    @slow
    def test_tiny_logits_librispeech(self):
        torch_device = "cpu"
        set_seed(0)
        model = WhisperModel.from_pretrained("openai/whisper-tiny")
        model.to(torch_device)
        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="pt").input_features

        with torch.no_grad():
            logits = model(
                input_features,
                decoder_input_ids=torch.tensor([[50258, 50259, 50359]]),
                output_hidden_states=False,
                output_attentions=False,
                return_dict=False,
                use_cache=False,
            )

        # fmt: off
        EXPECTED_LOGITS = torch.tensor(
            [
                2.9892, -6.7607, 5.7348, 3.6096, 0.2152, -5.7321, 4.8855, -1.6407,
                0.2823, -1.5718, 10.4269, 3.4427, 0.0219, -8.0612, 3.4784, 8.4246,
                4.0575, -2.2864, 11.1084, 0.9963, 0.9884, -8.5154, -3.5469, -9.3713,
                0.9786, 3.5435, 7.4850, -5.2579, -1.4366, 10.4841
            ]
        )
        # fmt: on
        self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))

        # fmt: off
        EXPECTED_GENERATION = torch.tensor(
            [
                -1.4651, -2.6944, 2.7821, 2.3793, 4.0738, 0.0188, -3.3203, 1.9836,
                0.0520, 0.7095, 1.1063, 0.2952, -3.6786, -0.5249, 0.3105, 4.7691,
                1.1562, 1.3046, 0.5810, -0.3624, 1.7006, 1.3424, 0.9817, 2.1958,
                1.8775, -5.7046, -0.7679, 4.0113, 2.6848, 2.8609
            ]
        )
        # fmt: on

        head_logits = logits[0] @ model.decoder.embed_tokens.weight.T
        self.assertTrue(torch.allclose(head_logits[0, 0, :30].cpu(), EXPECTED_GENERATION, atol=1e-4))

    @slow
    def test_small_en_logits_librispeech(self):
        set_seed(0)
        torch_device = "cpu"
        model = WhisperModel.from_pretrained("openai/whisper-small.en")
        model.to(torch_device)

        input_speech = self._load_datasamples(1)

        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="pt").input_features.to(torch_device)

        logits = model(
            input_features,
            decoder_input_ids=torch.tensor([[model.config.decoder_start_token_id]]),
            output_hidden_states=False,
            output_attentions=False,
            use_cache=False,
        )

        logits = logits.last_hidden_state @ model.decoder.embed_tokens.weight.T

        # fmt: off
        EXPECTED_LOGITS = torch.tensor(
            [
                -3.6784, -7.7211, -9.5070, -11.9286, -7.6489, -9.7026, -5.6188,
                -8.0104, -4.6238, -5.1833, -9.0485, -3.4079, -5.4874, -2.6935,
                -6.3479, -7.3398, -6.9558, -7.6867, -7.4748, -8.3463, -9.9781,
                -10.8389, -10.3105, -11.7201, -9.7261, -7.1590, -5.9272, -12.4509,
                -11.1146, -8.1918
            ]
        )
        # fmt: on
        self.assertTrue(torch.allclose(logits[0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))

    @slow
    def test_large_logits_librispeech(self):
        set_seed(0)

        torch_device = "cpu"
        model = WhisperModel.from_pretrained("openai/whisper-large")
        model.to(torch_device)

        input_speech = self._load_datasamples(1)

        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        processed_inputs = processor(
            audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="pt"
        )
        input_features = processed_inputs.input_features.to(torch_device)
        decoder_input_ids = processed_inputs.labels.to(torch_device)

        logits = model(
            input_features,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=False,
            output_attentions=False,
            use_cache=False,
        )

        logits = logits.last_hidden_state @ model.decoder.embed_tokens.weight.T

        # fmt: off
        EXPECTED_LOGITS = torch.tensor(
            [
                2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
                1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
                1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
                1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
            ]
        )
        # fmt: on

        self.assertTrue(torch.allclose(logits[0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))

    @slow
    def test_tiny_en_generation(self):
        torch_device = "cpu"
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        model.to(torch_device)
        model.config.decoder_start_token_id = 50257

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
            torch_device
        )

        generated_ids = model.generate(input_features, num_beams=5, max_length=20)
        transcript = processor.tokenizer.batch_decode(generated_ids)[0]

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes, and we are glad to"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_tiny_generation(self):
        torch_device = "cpu"
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        model.to(torch_device)

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
            torch_device
        )

        generated_ids = model.generate(input_features, num_beams=5, max_length=20)
        transcript = processor.tokenizer.decode(generated_ids[0])

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes and we are glad"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_large_generation(self):
        torch_device = "cpu"
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
        model.to(torch_device)

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
            torch_device
        )

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_large_generation_multilingual(self):
        torch_device = "cpu"
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
        model.to(torch_device)

        token = os.getenv("HF_HUB_READ_TOKEN", True)
        ds = load_dataset("mozilla-foundation/common_voice_6_1", "ja", split="test", streaming=True, token=token)
        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))

        input_speech = next(iter(ds))["audio"]["array"]
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
            torch_device
        )

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " Kimura-san called me."
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_large_batched_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")

        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
        generated_ids = model.generate(input_features, max_length=20, task="translate")

        # fmt: off
        EXPECTED_LOGITS = torch.tensor(
            [
                [50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
                [50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
                [50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
                [50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
            ]
        )
        # fmt: on

        self.assertTrue(torch.allclose(generated_ids, EXPECTED_LOGITS))

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes and we are glad",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all",
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_tiny_en_batched_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        model.to(torch_device)

        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
            torch_device
        )
        generated_ids = model.generate(input_features, max_length=20).to("cpu")

        # fmt: off
        EXPECTED_LOGITS = torch.tensor(
            [
                [50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
                [50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
                [50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
                [50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
            ]
        )
        # fmt: on

        self.assertTrue(torch.allclose(generated_ids, EXPECTED_LOGITS))

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes, and we are glad to",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast beef looming",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_tiny_timestamp_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        model.to(torch_device)

        input_speech = np.concatenate(self._load_datasamples(4))
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
            torch_device
        )

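        # `return_timestamps=True` lets the model emit timestamp tokens, which the
        # tokenizer below converts into the `offsets` entries checked further down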
        generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")

        EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257])  # fmt: skip

        self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT))

        EXPECTED_TRANSCRIPT = [
            {
                "text": (
                    " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is"
                    " Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season"
                    " of the year, with Christmas and roast beef looming before us, similarly drawn from eating and"
                    " its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins'"
                    " work is really Greek after all, and"
                ),
                "offsets": [
                    {
                        "text": (
                            " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
                        ),
                        "timestamp": (0.0, 6.5600000000000005),
                    },
                    {
                        "text": " Nor is Mr. Quilter's manner less interesting than his matter.",
                        "timestamp": (6.5600000000000005, 11.24),
                    },
                    {
                        "text": (
                            " He tells us that at this festive season of the year, with Christmas and roast beef"
                            " looming"
                        ),
                        "timestamp": (11.24, 16.88),
                    },
                    {
                        "text": (
                            " before us, similarly drawn from eating and its results occur most readily to the mind."
                        ),
                        "timestamp": (16.88, 23.76),
                    },
                    {
                        "text": (
                            " He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and"
                        ),
                        "timestamp": (23.76, 29.44),
                    },
                ],
            }
        ]

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_tiny_token_timestamp_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        model.to(torch_device)
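        # `alignment_heads` selects the cross-attention heads whose attention maps
        # are used (via dynamic time warping) to derive per-token timestamps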
        model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
            torch_device
        )

        generate_outputs = model.generate(
            input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
        )

        self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)

        # fmt: off
        EXPECTED_OUTPUT = torch.tensor([
            [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400, 29.8400 ],
            [ 0.0000, 0.0000, 0.0000, 0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 28.0000 ],
            [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800, 15.6800],
            [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600, 29.7600]
        ])
        # fmt: on

        self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))

    @slow
    def test_tiny_token_timestamp_batch_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        model.to(torch_device)
        model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

        num_samples = 4
        num_return_sequences = 2

        input_speech = self._load_datasamples(num_samples)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
            torch_device
        )

        generate_outputs = model.generate(
            input_features,
            max_length=448,
            return_timestamps=True,
            return_token_timestamps=True,
            num_beams=3,
            num_return_sequences=num_return_sequences,
        )

        # task id and lang id prompts should not have timestamp tokens
        self.assertEqual(generate_outputs.sequences.shape[-1] - 2, generate_outputs.token_timestamps.shape[-1])

        self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)

    @slow
    def test_tiny_token_timestamp_generation_longform(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        model.to(torch_device)
        model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

        input_speech = self._load_datasamples(5)
        long_input_speech = np.concatenate(input_speech, dtype=np.float32)
        inputs = processor.feature_extractor(
            raw_speech=long_input_speech,
            return_tensors="pt",
            truncation=False,  # False so the audio isn't truncated and whole audio is sent to the model
            return_attention_mask=True,
            padding=True,
        )

        inputs = inputs.to(torch_device)
        generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)

        token_timestamps_shape = [
            [segment["token_timestamps"].shape for segment in segment_list]
            for segment_list in generate_outputs["segments"]
        ]
        tokens_shape = [
            [segment["tokens"].shape for segment in segment_list] for segment_list in generate_outputs["segments"]
        ]
        self.assertListEqual(tokens_shape, token_timestamps_shape)

        # fmt: off
        EXPECTED_OUTPUT = [
            torch.tensor([0.0000, 0.4200, 0.8200, 0.9400, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0400, 2.3400, 2.5200, 2.6600, 3.2000, 3.4400, 3.5600, 3.6800, 3.8200, 4.1000, 4.3000, 4.5800, 4.9400, 5.4000, 6.3600]),
            torch.tensor([ 6.5400,  6.5400,  6.7400,  6.9600,  7.2600,  7.3400,  7.5800,  7.5800, 7.6400,  7.8400,  8.1000,  8.5000,  9.0000,  9.4800,  9.7200, 10.2600, 11.1000]),
            torch.tensor([11.2200, 11.2200, 11.4200, 11.6600, 12.0800, 12.4400, 12.5800, 12.8400, 13.1800, 13.6800, 14.0000, 14.2200, 14.6200, 14.9800, 15.2200, 15.6000, 15.9400, 16.2000, 16.5600, 16.8400, 16.9800]),
            torch.tensor([16.9800, 16.9800, 17.3200, 18.1600, 18.6400, 18.8600, 19.2800, 19.5600, 19.8800, 20.1800, 20.3800, 20.7200, 21.1600, 21.5400, 21.9000, 22.2000, 22.4200, 22.8600, 23.7000]),
            torch.tensor([23.7000, 23.7000, 23.9400, 24.1800, 24.3800, 24.8400, 25.2800, 25.6600, 25.9200, 26.2600, 26.4000, 26.5800, 26.7600, 27.1400, 27.3800, 28.0400, 28.3800, 28.8200, 29.3400, 29.5200]),
            torch.tensor([29.4400, 29.4400, 29.7000, 30.0800, 30.3800, 30.5400, 30.8200, 31.0600, 31.6600, 31.9200, 32.3000, 32.4800, 32.6200, 33.6800]),
            torch.tensor([33.8000, 33.8000, 33.9800, 33.9800, 34.1800, 34.4400, 34.6200, 35.0000, 35.2200, 35.3200, 35.5600, 35.9200, 36.3800, 36.6200, 36.6600, 36.9600, 37.3400, 37.9800, 38.5800, 38.7200, 38.9800, 39.4400, 39.5800, 39.8000, 40.1200, 40.2600]),
            torch.tensor([40.5200, 40.5200, 40.6200, 41.1000, 41.5400, 41.9200, 42.1000, 42.3200, 42.3200, 43.0600, 44.6000]),
            torch.tensor([44.7000, 44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400]),
            torch.tensor([50.5400, 50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400]),
            torch.tensor([52.9600, 52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1800, 58.5000, 58.6400, 58.8200]),
            torch.tensor([58.6800, 58.6800, 59.1400, 59.5400, 59.9200, 60.1600, 60.3800, 60.8200, 61.6200, 62.2600, 75.2000]),
        ]
        # fmt: on

        for segment, exp_segment in zip(generate_outputs["segments"][0], EXPECTED_OUTPUT):
            self.assertTrue(torch.allclose(segment["token_timestamps"], exp_segment))

    @slow
    def test_tiny_specaugment_librispeech(self):
        torch_device = "cpu"
        set_seed(0)
        # Apply SpecAugment
        model = WhisperModel.from_pretrained("openai/whisper-tiny", apply_spec_augment=True)
        # Set model to training mode to enable SpecAugment
        model.train()
        model.to(torch_device)
        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="pt").input_features

        with torch.no_grad():
            logits = model(
                input_features,
                decoder_input_ids=torch.tensor([[50258, 50259, 50359]]),
                output_hidden_states=False,
                output_attentions=False,
                return_dict=False,
                use_cache=False,
            )

        # fmt: off
        EXPECTED_LOGITS = torch.tensor(
            [
                0.9362, -4.7105, 5.0879, 3.9642, 1.0013, -6.0096, 4.7285, -3.1847,
                -0.8648, 1.9631, 6.2653, 3.6936, 0.3575, -4.5818, 3.0564, 7.8712,
                2.9951, 0.6848, 9.9497, -2.6638, 1.1571, -6.8546, -1.4333, -7.7584,
                1.1200, 3.9030, 4.4655, -4.4919, -1.1703, 9.6241
            ]
        )
        # fmt: on
        self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))

    @slow
    def test_generate_with_prompt_ids(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        model.to(torch_device)
        input_speech = self._load_datasamples(4)[-1:]
        input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)

        output_without_prompt = model.generate(input_features)
        prompt_ids = processor.get_prompt_ids("Leighton", return_tensors="pt").to(torch_device)
        output_with_prompt = model.generate(input_features, prompt_ids=prompt_ids)

        expected_without_prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
        expected_with_prompt = "<|startofprev|> Leighton<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Leighton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"

        output_without_prompt = processor.decode(output_without_prompt[0])
        output_with_prompt = processor.decode(output_with_prompt[0])

        self.assertEqual(output_without_prompt, expected_without_prompt)
        self.assertEqual(output_with_prompt, expected_with_prompt)

    @slow
    def test_language_detection(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        model.to(torch_device)
        input_speech = self._load_datasamples(4)[-1:]
        input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)

        lang_id = model.detect_language(input_features)[0].item()

        ids_to_lang = {v: k for k, v in model.generation_config.lang_to_id.items()}

        assert ids_to_lang[lang_id] == "<|en|>"

        audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")

        raw_audio, sr = torchaudio.load(audio)
        input_speech = torchaudio.transforms.Resample(sr, 16_000)(raw_audio).numpy()

        input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)

        lang_id = model.detect_language(input_features)[0].item()

        assert ids_to_lang[lang_id] == "<|hi|>"

    @slow
    def test_default_multilingual_transcription_short_form(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        model.to(torch_device)

        audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")

        raw_audio, sr = torchaudio.load(audio)
        input_speech = torchaudio.transforms.Resample(sr, 16_000)(raw_audio).numpy()

        input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)

        # model.generation_config.forced_decoder_ids defaults to [1, null] for lang_token
        sequences = model.generate(input_features)

        transcription = processor.batch_decode(sequences, skip_special_tokens=False)[0]

        assert (
            transcription
            == "<|startoftranscript|><|hi|><|transcribe|><|notimestamps|> Mirchi mein ki tene vibinda prajatiya hai<|endoftext|>"
        )

        # set forced_decoder_ids to English
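        # (50259 is the token id of "<|en|>" in the multilingual vocabulary, as the
        # expected "<|en|>" prefix below confirms)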
        model.generation_config.forced_decoder_ids[0][-1] = 50259

        sequences = model.generate(input_features)
        transcription = processor.batch_decode(sequences, skip_special_tokens=False)[0]

        assert (
            transcription
            == "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> MIRCHI MET, which is the name of the Bible.<|endoftext|>"
        )

    @slow
    def test_default_multilingual_transcription_long_form(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
        model.to(torch_device)

        audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")

        raw_audio, sr = torchaudio.load(audio)
        input_speech = torchaudio.transforms.Resample(sr, 16_000)(raw_audio)

        input_speech = input_speech.repeat(1, 10).numpy()
        input_features = processor(
            input_speech, return_tensors="pt", padding="longest", truncation=False
        ).input_features.to(torch_device)

        # model.generation_config.forced_decoder_ids defaults to [1, null] for lang_token
        sequences = model.generate(input_features)

        transcription = processor.batch_decode(sequences)[0]

        assert transcription == " मिर्ची में कितने विबिन्द प्रजातियां हैं? मिर्ची में कितने विबिन्द प्रजातियां हैं?"

        # set forced_decoder_ids to English
        model.generation_config.forced_decoder_ids[0][-1] = 50259

        sequences = model.generate(input_features)
        transcription = processor.batch_decode(sequences)[0]

        assert (
            transcription
            == " How many different species are there in the chilli? How many different species are there in the chili?"
        )

    @slow
    def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        model.to(torch_device)
        input_speech = self._load_datasamples(1)
        input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
        task = "translate"
        language = "de"
        expected_tokens = [f"<|{task}|>", f"<|{language}|>"]
        prompt = "test prompt"
        prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt").to(torch_device)

        output = model.generate(input_features, task=task, language=language, prompt_ids=prompt_ids)
        text = processor.decode(output[0])

        self.assertTrue(prompt in text)
        self.assertTrue(all(token in text for token in expected_tokens))

    @slow
    def test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        model.to(torch_device)
        input_speech = self._load_datasamples(1)
        input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
        prompt = "test prompt"
        prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt").to(torch_device)

        model.generation_config.forced_decoder_ids = None
        model.config.forced_decoder_ids = None

        output = model.generate(input_features, prompt_ids=prompt_ids, return_timestamps=True)
        text = processor.decode(output[0])

        self.assertTrue(prompt in text)

    @slow
    @require_torch_gpu
    def test_speculative_decoding_distil(self):
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model_id = "openai/whisper-large-v2"
        model = WhisperForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
        )
        model.to(torch_device)

        processor = WhisperProcessor.from_pretrained(model_id)

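        # Distil-Whisper reuses the teacher's encoder, so the assistant only needs
        # the decoder stack and is loaded as a `WhisperForCausalLM`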
        assistant_model_id = "distil-whisper/distil-large-v2"
2221
        assistant_model = WhisperForCausalLM.from_pretrained(
2222
            assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
2223
        )
2224
        assistant_model.to(torch_device)
2225

2226
        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
2227
        sample = dataset[0]["audio"]
2228

2229
        input_features = (
2230
            processor(sample["array"], return_tensors="pt").input_features.to(torch_device).to(torch.float16)
2231
        )
2232

2233
        # warm up assisted decoding
2234
        _ = model.generate(input_features, assistant_model=assistant_model)
2235
        # warm up non-assisted decoding
2236
        _ = model.generate(input_features)
2237

2238
        # assisted decoding
2239
        start_time = time.time()
2240
        tokens = model.generate(input_features, assistant_model=assistant_model)
2241
        total_time_assist = time.time() - start_time
2242

2243
        transcription_ass = processor.batch_decode(tokens, skip_special_tokens=True)
2244

2245
        # non-assisted decoding
2246
        start_time = time.time()
2247
        tokens = model.generate(input_features)
2248
        total_time_non_assist = time.time() - start_time
2249

2250
        transcription_non_ass = processor.batch_decode(tokens, skip_special_tokens=True)
2251

2252
        assert transcription_ass == transcription_non_ass
2253
        assert transcription_ass == [
2254
            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
2255
        ]
2256
        assert total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster"
2257

2258
    @slow
    @require_torch_gpu
    def test_speculative_decoding_non_distil(self):
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model_id = "openai/whisper-large-v2"
        model = WhisperForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
        )
        model.to(torch_device)

        processor = WhisperProcessor.from_pretrained(model_id)

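        # whisper-tiny has its own encoder, so the assistant is a full
        # encoder-decoder `WhisperForConditionalGeneration` checkpoint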
        assistant_model_id = "openai/whisper-tiny"
2271
        assistant_model = WhisperForConditionalGeneration.from_pretrained(
2272
            assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
2273
        )
2274
        assistant_model.to(torch_device)
2275

2276
        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
2277
        sample = dataset[0]["audio"]
2278

2279
        input_features = (
2280
            processor(sample["array"], return_tensors="pt").input_features.to(torch_device).to(torch.float16)
2281
        )
2282

2283
        # warm up assisted decoding
2284
        _ = model.generate(input_features, assistant_model=assistant_model)
2285
        # warm up non-assisted decoding
2286
        _ = model.generate(input_features)
2287

2288
        # assisted decoding
2289
        start_time = time.time()
2290
        tokens = model.generate(input_features, assistant_model=assistant_model)
2291
        total_time_assist = time.time() - start_time
2292

2293
        transcription_ass = processor.batch_decode(tokens, skip_special_tokens=True)
2294

2295
        # non-assisted decoding
2296
        start_time = time.time()
2297
        tokens = model.generate(input_features)
2298
        total_time_non_assist = time.time() - start_time
2299

2300
        transcription_non_ass = processor.batch_decode(tokens, skip_special_tokens=True)
2301

2302
        assert transcription_ass == transcription_non_ass
2303
        assert transcription_ass == [
2304
            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
2305
        ]
2306
        assert total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster"
2307

2308
    @slow
    def test_whisper_longform_single_batch(self):
        # fmt: off
        EXPECTED_TEXT = [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter\'s manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton\'s work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell\'s pictures are a sort of up-gards and atom paintings, and Mason\'s exquisite idles are as national as a jingo poem. Mr. Birk at Foster\'s landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampoo or a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Makes the customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire, any ornaments Fred brought home from India on the mantelboard. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the Tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon\'s body trickling into the tight-lowing cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered his muscles into complete relaxation. Oli\'s heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the twenties needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, The thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I\'m here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenties, he must have drawn his gun because the intruder said quickly, but that away you\'re being a fool. out, through his silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry, and victory to the stronger. man who entered the twenties had his own training tricks. They were appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. had died before during the 20s and death during the last round was in some ways easier than defeat. Breathing deeply, Breon\'s softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent\'s face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that\'s rested aside, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accooing dove. He has gone, and gone for good," answered Polychrom, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with says he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn\'t work too hard, said Shaggy. He doesn\'t work at all. In fact, there\'s nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we\'ve turned Calico. Where is my brother now, inquired Shaggy. In the metal forest. Where is that? The middle forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I\'m quite sure he didn\'t. That\'s funny, remarked Betsy thoughtfully. I don\'t believe Anne knew any magic, or she\'d have worked it before. I do not know, confess Shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as Virgato used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the throne, wearing Virgato\'s discarded ruby crown and holding in his hand to scepter which reggative head so often thrown at his head.']
        # fmt: on

        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        model = model.to(torch_device)

        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
        one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32)

        input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")[
            "input_features"
        ]
        input_features = input_features.to(device=torch_device)

        result = model.generate(input_features, return_timestamps=True)
        decoded = processor.batch_decode(result, skip_special_tokens=True)

        assert decoded == EXPECTED_TEXT

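        # Decoding with timestamps must round-trip: stripping the <|t.tt|>
        # timestamp tokens should reproduce the plain transcription, and the
        # timestamps themselves should never decrease.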
        decoded_with_timestamps = processor.batch_decode(result, skip_special_tokens=True, decode_with_timestamps=True)

        no_timestamp_matches = re.split(r"<\|[\d\.]+\|>", decoded_with_timestamps[0])

        assert ["".join(no_timestamp_matches)] == EXPECTED_TEXT

        timestamp_matches = re.findall(r"<\|[\d\.]+\|>", decoded_with_timestamps[0])

        timestamp_floats = [float(t[2:-2]) for t in timestamp_matches]

        is_increasing = all(timestamp_floats[i] <= timestamp_floats[i + 1] for i in range(len(timestamp_floats) - 1))

        assert is_increasing

    @slow
    def test_whisper_longform_prompt_ids(self):
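        # The prompt biases decoding towards the (misspelled) names it contains.
        # With prompt_condition_type="first-segment" only the first chunk is
        # conditioned on the prompt; with "all-segments" every chunk is.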
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        model = model.to(torch_device)

        prompt = "Mr. Kilter, Ruggedo."  # let's force Mr. Quilter -> Mr. Kilter
        prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt").to(torch_device)

        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
        one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32)

        first_text = ds["validation"][0]["text"].lower()
        last_text = ds["validation"][-1]["text"].lower()

        input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")[
            "input_features"
        ]
        input_features = input_features.to(device=torch_device)

        result = model.generate(
            input_features,
            prompt_ids=prompt_ids,
            return_timestamps=True,
            prompt_condition_type="first-segment",
            condition_on_prev_tokens=True,
        )
        decoded_first_segment = processor.batch_decode(result, skip_special_tokens=True)

        result = model.generate(
            input_features,
            prompt_ids=prompt_ids,
            return_timestamps=True,
            prompt_condition_type="all-segments",
            condition_on_prev_tokens=True,
        )
        decoded_all_segments = processor.batch_decode(result, skip_special_tokens=True)

        # show that the first segment contains "quilter" and the last segment contains "ruggedo"
        assert "quilter" in first_text
        assert "ruggedo" in last_text

        # conditioning on the first segment correctly changes "quilter" to "kilter" there, but does not transcribe "ruggedo" correctly
        assert "kilter" in decoded_first_segment[0][: len(first_text)].lower()
        assert "ruggedo" not in decoded_first_segment[0][-len(last_text) :].lower()

        # conditioning on all segments correctly changes "quilter" to "kilter" and correctly transcribes "ruggedo"
        assert "kilter" in decoded_all_segments[0][: len(first_text)].lower()
        assert "ruggedo" in decoded_all_segments[0][-len(last_text) :].lower()

    @slow
    def test_whisper_longform_single_batch_prev_cond(self):
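        # Same audio as above, but decoded with condition_on_prev_tokens=True
        # and the temperature-fallback heuristics (no-speech, log-probability
        # and compression-ratio thresholds), so the expected text differs
        # slightly from the unconditioned run.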
        # fmt: off
        EXPECTED_TEXT = [""" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grieved doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite itals are as national as a jingo poem. Mr. Birk at Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. When Mr. John Collier gives his sitter a cheerful slap in the back, before he says like a shampooer and a Turkish bath, next man it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. He tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man, and remarks was pleasing courtesy in felicitous grace that many faces are feeling. Unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M. A. A man said to the universe, Sir, I exist. Sweat covered Breon's body trickling into the tight-lowing cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retroveilities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the twenties needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenties, he must have drawn his gun because the intruder said quickly, but that away you're being a fool. But there was silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. 
Your man who entered the twenties had his own training tricks. They were appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Breon's death was in some ways easier than defeat. Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that's rested aside, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggido long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now, inquired Shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked Betsy thoughtfully. I don't believe Anne knew any magic, or she'd have worked it before. I do not know, confessed Shaggy. True, agreed Calico. Calico went to the big gong and pounded on it, just as we're good to be used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing Regidos discarded Ruby crown, and holding in his hand to scepter which Regidos had so often thrown at his head."""]
        # fmt: on

        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        model = model.to(torch_device)

        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
        one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32)

        input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")[
            "input_features"
        ]
        input_features = input_features.to(device=torch_device)

        gen_kwargs = {
            "return_timestamps": True,
            "no_speech_threshold": 0.6,
            "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            "compression_ratio_threshold": 1.35,
            "condition_on_prev_tokens": True,
            "logprob_threshold": -1.0,
        }

        torch.manual_seed(0)
        result = model.generate(input_features, **gen_kwargs)
        decoded = processor.batch_decode(result, skip_special_tokens=True)

        assert decoded == EXPECTED_TEXT

    @slow
    def test_whisper_longform_multi_batch(self):
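        # Long-form decoding should be batch-invariant: four slices of the same
        # audio decoded in one padded batch must match what each slice yields
        # when decoded on its own.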
        # fmt: off
        EXPECTED_TEXT_1 = [" Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Birkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Mix a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing a poster or near the fire, and the ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the Tupper of painting. a Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the tight-wing cloth that was the only germany war. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were, triggered his muscles into complete relaxation. Oily his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenty's he must have drawn his gun, because the intruder said quickly, but that away you're being a fool. Out there was silence then, and still wondering, Breon was once more asleep. Ten seconds he asked the handler who was needing his aching muscles. a red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the twenties had his own training tricks. 
There appeared to be an immediate association with the death trauma as if the two were andextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the twenties and death during the last round was, in some ways, easier than defeat. Breeding deeply, Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled the mazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue, pre-inscented and new to fifth point was his. Then the powerful twist that's rest of the side, in and under the guard, because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, a cooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, return Calico. Where is my brother now? choir-dshaggy, in the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh, no, I'm quite sure he didn't. That's funny, remarked Betsy thoughtfully. I don't believe and knew any magic, or she'd have worked it before. I do not know, confess shaggy. True, a great calico. Calico went to the big gong and pounded on it, just as Virgado used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the throne, wearing Virgados discarded Ruby Crown, and holding in his hand to scepter, which Virgado had so often thrown at his head. head."]
        EXPECTED_TEXT_2 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Burkett Foster's landscapes smile at one much in the same way that Mr. Carker."]
        EXPECTED_TEXT_3 = [" possible. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grieved doubts whether Sir Frederick Layton's work is really greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-guards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Birk at Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampooer and a Turkish bath, next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. Under general principles of art, Mr. Quilter writes with equal lucidity. Painting, he tells us, is of a different quality to mathematics and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Mix a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire. any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man, and remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the titling cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes. Even to soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered as muscles into complete relaxation. Oily his heart and lungs worked on at a strong measured rate. He was in In reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenty's he must have drawn his gun, because the intruder said quickly, but that away you're being a fool. Out there was silence then, and still wondering, Breon was once more asleep. Ten seconds he asked the handler who was needing his aching muscles. a red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the twenties had his own training tricks. 
There appeared to be an immediate association with the death trauma as if the two were andextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the twenties and death during the last round was, in some ways, easier than defeat. Breeding deeply, Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue, re-insunced it and knew the fifth point was his. Then the powerful twist that's rest of the side, in and under the guard, because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, a cooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled and disgraced, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now? quared shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. And that's funny, remarked Betsy thoughtfully. I don't believe Anne knew any magic, or she'd have worked it before. I do not know, confess Shaggy. True, a great calico. Calico went to the big gong and pounded on it, just as we're good to have used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the thrown wearing ruggedos discarded ruby crown and holding in his hand to septor which ruggedo had so often thrown at his head."]
        EXPECTED_TEXT_4 = [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter\'s manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton\'s work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell\'s pictures are a sort of up-gards and atom paintings, and Mason\'s exquisite idles are as national as a jingo poem. Mr. Birk at Foster\'s landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampoo or a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Makes the customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire, any ornaments Fred brought home from India on the mantelboard. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the Tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon\'s body trickling into the tight-lowing cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered his muscles into complete relaxation. Oli\'s heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the twenties needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, The thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I\'m here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenties, he must have drawn his gun because the intruder said quickly, but that away you\'re being a fool. out, through his silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry, and victory to the stronger. 
man who entered the twenties had his own training tricks. They were appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. had died before during the 20s and death during the last round was in some ways easier than defeat. Breathing deeply, Breon\'s softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent\'s face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that\'s rested aside, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accooing dove. He has gone, and gone for good," answered Polychrom, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with says he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn\'t work too hard, said Shaggy. He doesn\'t work at all. In fact, there\'s nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we\'ve turned Calico. Where is my brother now, inquired Shaggy. In the metal forest. Where is that? The middle forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I\'m quite sure he didn\'t. That\'s funny, remarked Betsy thoughtfully. I don\'t believe Anne knew any magic, or she\'d have worked it before. I do not know, confess Shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as Virgato used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the throne, wearing Virgato\'s discarded ruby crown and holding in his hand to scepter which reggative head so often thrown at his head.']
        # fmt: on

        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        model = model.to(torch_device)

        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
        one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32)
        audios = []
        audios.append(one_audio[110000:])
        audios.append(one_audio[:800000])
        audios.append(one_audio[80000:])
        audios.append(one_audio[:])

        decoded_single = []
        for audio in audios:
            inputs = processor(audio, return_tensors="pt", truncation=False)
            inputs = inputs.to(device=torch_device)

            result = model.generate(**inputs, return_timestamps=True)
            decoded_single.append(processor.batch_decode(result, skip_special_tokens=True))

        inputs = processor(
            audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
        )
        inputs = inputs.to(device=torch_device)

        result = model.generate(**inputs, return_timestamps=True)
        decoded_all = processor.batch_decode(result, skip_special_tokens=True)

        # make sure single & batched results are exactly the same
        assert decoded_all[0:1] == decoded_single[0]
        assert decoded_all[1:2] == decoded_single[1]
        assert decoded_all[2:3] == decoded_single[2]
        assert decoded_all[3:4] == decoded_single[3]

        # exact match against the reference transcriptions
        assert decoded_all[0:1] == EXPECTED_TEXT_1
        assert decoded_all[1:2] == EXPECTED_TEXT_2
        assert decoded_all[2:3] == EXPECTED_TEXT_3
        assert decoded_all[3:4] == EXPECTED_TEXT_4

    @slow
    def test_whisper_longform_multi_batch_prev_cond(self):
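        # Per-sample long-form decoding with previous-token conditioning and
        # greedy sampling (temperature 0.0), this time with the multilingual
        # openai/whisper-tiny checkpoint instead of the English-only one.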
        # fmt: off
        EXPECTED_TEXT_1 = [" Mr. Quilters manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca. The Nils, pictures are sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. On the general principles of art, Mr. Quilters writes with equal lucidity. Painting he tells us is of a different quality to mathematics and finish in art is adding more effect. As for etchings, there are of two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostorer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many phases of feeling only, unfortunately, his own work never does get good. Mr. Quilters has missed his chance, for he has failed even to make himself the tougher of painting. My hair equal to M.A. A man said to the universe, Sir, I exist. Sweat covered Breon's body, trickling into the tight-wing cloth that was the only garment he wore. The cut on his chest still dripping blood. The ache of his overstrain dyes. Even the soaring arena around him with thousands of spectators, retrievalidies not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance. And brand is the one I must see. Now stand aside. To 20s, he must have drawn his gun because the intruder said quickly. But that away, he'd be no fool. Out, the resoundance then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. 
There appeared to be an immediate association with the death trauma as if the two were inexplicably linked into one. This strength that enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the other hypnotic phrases that triggered the process. In the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our role. Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to the side, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stoutchanges as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they're asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help you run into escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico, whereas my brother now inquired shaggy in the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked to Bedsey thoughtfully. I don't believe Anne knew any magic or she'd have worked before. I do not know, confessed shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as Ruggano used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing Ruggano's discarded ruby crown. And holding in his hand the scepter which Ruggano had so often thrown at his head."]
        EXPECTED_TEXT_2 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and can discover in it but little of rocky Ithaca. Lennials, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker"]
        EXPECTED_TEXT_3 = [" gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating in its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of rocky ithaka. Lennils, pictures, are a sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. Under general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics and finish in art is adding more effect. As for etchings, thereof two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostoror. Near the fire, any ornaments spread brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many faces are feeling, only unfortunately his own work never does get good. Mr. Quilter has missed his chance. For he has failed even to make himself the tougher of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat covered Brienne's body trickling into the tight-wing cloth that was the only garment you wore. The cut on his chest still dripping blood. The ache of his overstrained eyes. Even the soaring arena around him with thousands of spectators, retrievalidies not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered his muscles into complete relaxation. Only his heart and lungs worked on at a strong measured rate. He was in reverie, sliding out on the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The 20s, he must have drawn his gun because the intruder said quickly, but that away here being a fool. Out, there is silence then, and still wondering, Brienne was once more asleep. 10 seconds, he asked the handler who was needing his aching muscles. I've read here at Mountain of a Man with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. 
There appeared to be an immediate association with the death trauma as if the two were anextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne's softly spoke the odd hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled up the maze at the sudden fury of the attack, then smiled. He said it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our ol' Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to decide, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they're asking for you. I begged Brienne to long ago to send him away, but he would not do so. I also offered to help you brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico, whereas my brother now inquired Shaggy in the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked to bed see you thoughtfully. I don't believe Anne knew any magic or she'd have worked it before. I do not know, confessed Shaggy. True, agreed Calico. Calico went to the big gone and pounded on it, just as we're good or used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gone and then sat in the throne, wearing reggos, discarded ruby crown, and holding in his hand to scepter which reggos hand so often thrown at his head."]
        EXPECTED_TEXT_4 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and can discover in it but little of rocky Ithaca. Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says, like a shampooer in a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. On the general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, thereof two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostorer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many phases of feeling only, unfortunately, his own work never does, get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tougher of painting. By Harry Quilter, M.A. A man said to the universe, Sir, I exist. Sweat covered Breon's body, trickling into the tight-wing cloth that was the only garment you wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators were trivialities not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance. And brand is the one I must see. Now stand aside. To 20s, he must have drawn his gun because the intruder said quickly, but that away, he could be no fool. Out, there was silence then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. I've read here at Mountain of a Man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. 
Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the other hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from Irohog. Brienne sensed it and knew the fifth point was his. Then the powerful twist that's for us to decide, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stoutchanges as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they are asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help you brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there is nothing he can do in these dominions, as well as our nooms, whose numbers are so great that it worries us to keep them all busy. And exactly we've turned Calico, where is my brother now in Quaragejji, in the metal forest? Where is that? The metal forest is in the great donned cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked to Bedzeeth thoughtfully. I don't believe Anne knew any magic or she'd have worked before. I do not know, confessed shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as we're good to have used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing reggos, discarded ruby crown. And holding in his hand to scepter which reggos had so often thrown at his head."]
        # fmt: on
2485

2486
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
2487
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
2488
        model = model.to(torch_device)
2489

2490
        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
2491
        one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32)
2492
        audios = []
2493
        audios.append(one_audio[110000:])
2494
        audios.append(one_audio[:800000])
2495
        audios.append(one_audio[80000:])
2496
        audios.append(one_audio[:])
2497

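        # Long-form generation heuristics, mirroring OpenAI's sequential decoding:
        # `no_speech_threshold` skips segments that are likely silence, `temperature=0.0`
        # means pure greedy decoding, `compression_ratio_threshold` flags overly repetitive
        # (highly compressible) output, `condition_on_prev_tokens` prompts each segment with
        # the previously decoded text, and `logprob_threshold` flags low-confidence segments.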
        gen_kwargs = {
            "return_timestamps": True,
            "no_speech_threshold": 0.6,
            "temperature": 0.0,
            "compression_ratio_threshold": 1.35,
            "condition_on_prev_tokens": True,
            "logprob_threshold": -1.0,
        }

        decoded_single = []
        for audio in audios:
            inputs = processor(audio, return_tensors="pt", truncation=False)
            inputs = inputs.to(device=torch_device)

            result = model.generate(**inputs, **gen_kwargs)
            decoded_single.append(processor.batch_decode(result, skip_special_tokens=True))

        # exact match
        assert decoded_single[0] == EXPECTED_TEXT_1
        assert decoded_single[1] == EXPECTED_TEXT_2
        assert decoded_single[2] == EXPECTED_TEXT_3
        assert decoded_single[3] == EXPECTED_TEXT_4

    @slow
    def test_whisper_longform_multi_batch_hard(self):
        # fmt: off
        EXPECTED_TEXT = [
            " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile.",
            " Folks, I spend a lot of time right over there, night after night after night, actually. Carefully selecting for you the day's noosiest, most aerodynamic headlines, stress testing, and those topical anti-lock breaks and power steering, painstakingly stitching, leather seating so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school and slap myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen moon, render a gas tank out of an empty big gulp, fill with white claw and denatured alcohol, then light a match and let her rip and the demented one man soapbox derby of news that is my segment. Me, Guadalupe! No!",
            " Ladies and gentlemen, you know, I spent a lot of time right over there Raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their jokes swollen teats Churning the daily stories into the decadent proven-style style triple cream breed that is my nightly monologue But sometimes sometimes folks I stagger home hungry after being released by the police and Root around in the neighbor's trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rod I won from a rat in a pre-donned street fight. Put it in a discarded paint can to leave it to ferment next to a trash fire then hunker down and hallucinate while eating the listeria laden demon custard of news that is my segment. You mean one of them.",
            " Folks, if you watch this show, you know I spend most of my time right over there carefully sorting through the day's biggest stories and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ichol Gregoire Ferrandi, who carefully dye them in a palette of bright zesty shades and adorn them in the finest and most topical inlay work using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddles stitching. In line it with bees, wax, coated linen, finely attached a mallet, hammered strap, pearled hardware, and close-shit to create for you the one-of-a-kind hoke couture, Erme's Birkin bag that is my monologue. But sometimes, sometimes folks, sometimes. Sometimes I wake up in the last car of an abandoned roller coaster at Coney Island where I'm I'm hiding from the triads. I have some engine lubricants out of a safe way bag and stagger down the shore to tear the sail off a beach schooner. Then I rip the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel lovely folks. And use it to stitch the sail into a loose pouch like a rock sack. And I stow away in the back of a garbage truck to the junkyard where I pick through to the debris for only the broken toys that make me the saddest until I have loaded for you. The Hobo Fugitives bug out, bindle of news that is my segment. Me one!",
            " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's biggest stories right over there. Meticulously selecting the most topical chakra affirming scented candles, and using Feng Shui to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue. But sometimes just sometimes I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself, and used fry oil, wrap my hands with some double-duct tape by stole from the broken car window. Pound a six-pack of blueberry hard-seltzer and a sack of pills I stole from a parked ambulance. Then arm wrestle a raccoon in the back alley vision quest of news that is my segment. Meanwhile!",
            " You know, folks, I spend most of my time right over there. Mining the day's biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels. Then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press-black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards in a faceplate and, finally, using fluted strips of white alloyed molding, I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating, Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes folks. Sometimes, just sometimes, I come into my sense as fully naked on the deck of a pirate besieged melee container ship that picked me up floating on the detached door of a portapotty in the Indian Ocean. Then after a sunstroke-induced realization of the crew of this ship plans to sell me an exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe at a pool chain that accepting my new role as Captain and declaring myself king of the windarc seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create the sopping wet pirate crown of news that is my segment. Meanwhile!",
            " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's Newsiest most topical flower eggs milk and butter and Stranding into a fine batter to make delicate and informative comedy pancakes Then I glaze them in the juice and zest of the most relevant midnight Valencia oranges and douse it all and a fine Dela main de voyage cognac Before prom baying and basting them tables. I deserve for you the James Beard award worthy crepe suzzette That is my nightly monologue, but sometimes just sometimes folks. I wake up in the baggage hold of Greyhound bus. It's being hoisted by the scrap yard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps and busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strain, pair of sweatpants and as oven mitts to extract and serve the demented transience poundcake of news that is my segment. Me, Guadalupe!",
            " Folks, if you watched the show and I hope you do, I spent a lot of time right over there. Tiredlessly studying the lineage of the days most important thoroughbred stories and whole-stiner headlines, working with the best trainers, money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen. That is my nightly monologue, but sometimes, sometimes, folks, I break into an unincorporated veterinary genetics lab and grab whatever test tubes I can find and then under a grow light I got from a discarded chia pet. I mixed the pilfered DNA of a horse and whatever was in a tube labeled Keith Colan extra. Slurrying the concoction with caffeine pills and a microwave red bull, I screamed, sang a prayer to Janice, initiator of human life and God of transformation as a half horse, half man, freak. Seizes to life before me and the hideous collection of loose animal parts and corrupted man tissue that is my segment. Meanwhile!"
        ]
        # fmt: on

        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        model = model.to(torch_device)

        ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
        ds = ds.cast_column("audio", Audio(sampling_rate=16000))

        num_samples = 8

        audio = ds[:num_samples]["audio"]
        audios = [x["array"] for x in audio]

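        # First, transcribe every sample one by one.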
        decoded_single = []
        for audio in audios:
            inputs = processor(audio, return_tensors="pt", truncation=False, sampling_rate=16_000)
            inputs = inputs.to(device=torch_device)

            result = model.generate(**inputs, return_timestamps=True)
            decoded_single += processor.batch_decode(result, skip_special_tokens=True)

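        # Then transcribe the same samples as a single padded batch; batched long-form
        # generation must match the sample-by-sample results exactly.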
        inputs = processor(
            audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
        )
        inputs = inputs.to(device=torch_device)

        result = model.generate(**inputs, return_timestamps=True)
        decoded_all = processor.batch_decode(result, skip_special_tokens=True)

        for i in range(num_samples):
            assert decoded_all[i] == decoded_single[i]
            assert decoded_all[i] == EXPECTED_TEXT[i]

    @slow
    def test_whisper_longform_multi_batch_hard_prev_cond(self):
        # fmt: off
        EXPECTED_TEXT = [
            " Folks, if you watch the show, you know I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories, developing the central headline pawns, definitely maneuvering an oh-so-topical night to F6, faming of classic Sicilian, named or variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a Fisher show's in lip-nitsky attack that culminates in the elegant lethal slow played all pass on checkmate that is my nightly monologue, but sometimes sometimes folks I sometimes I start a little wake-up side down in the monkey bars of a condemned playground on a super fun site, get all hepped up on goofballs, rummage that would discard a tag bag of defective toys, yank out a fistball of disembodied doll limbs, toss them on a stain kid's place mad from a defunked denies, set up a table inside a rusty cargo container down by the warf and challenge toothless drifters to the godless bughouse blitz of tournament that is my segment.",
            " Folks, I spent a lot of time right over there night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing on those topical anti-lock breaks and power steering, painstakingly stitching, leather seating, so soft, it would make JD power and her associates blush. To create the luxury sedan that is my nightly monologue, but sometimes I just sometimes focus. I lurched to consciousness in the back of an abandoned school bus and slapped myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen-moon render a gas tank out of an empty big gulp, filled with white claw and de-natured alcohol, then light a match, letter-ripping the dis-mented one-man soapbox derby of news that is my segment.",
            " Ladies and gentlemen, you know, I spent a lot of time right over there, raising the finest hosting news cattle firmly, yet tenderly milking the latest headlines from their jokes, swollen teats, churning the daily stories into the decadent Provincil style triple cream-breed. It is my nightly monologue, but sometimes sometimes I stagger home hungry after being released by the police and root around in the neighbors trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rind I won from a rat and a pre-drawn street fight. Put it into discarded paint can to leave it to ferment next to a trash fire than a hunker down in hallucinate while eating the lusteria latent demon custard of news that is my segment.",
            " Folks, you watched this show, you know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle, and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ickel Greg Waferandi, who carefully died them in a pallet of bright, zesty shades, and adorn them in the finest most topical inlay work, using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddle stitching, and line it with bees, wax, coated linen, and finally attach a mallet hammered strap, perled hardware, and close-shet to create for you the one of a kind hope, kutur, earn-may is burkin bag that is my monologue, but sometimes, sometimes, sometimes. Sometimes, sometimes I wake up in the last car of an abandoned roller coaster at Kony Island, where I'm hiding from the triads, I have some engine lubricants out of a safe way bag and staggered down the shore to tear the sail off a beach sooner than I ripped the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel Lovelyfokes, and use it to stitch the sail into a loose pouch like rock sack, and I stole a bag of a garbage truck to the junkyard, where I picked through to the debris for only the broken toys that make me the saddest, until I have loaded for you. The hobo fugitives bug out Bindle of news that is my segment.",
            " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's big stories right over there. meticulously selecting the most topical chakra affirming scented candles, using Feng Shui, to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue, but sometimes just sometimes, I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself and use fry oil, wrap my hands and some old duct tape I stole from a broken car window, pound a six pack of blueberry hard-seller and a second pill, as I stole from a park damsel, and it's then arm wrestle a raccoon in the back alley vision quest of news that is my segment.",
            " You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press, black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards and a face plate, and finally using fluted strips of white alloyed molding I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes, folks. Sometimes, just sometimes, I come to my senses fully naked on the deck of a pirate, beceived, melee, container ship that picked me up floating on the detainees. Then after I sunstroke in juice, realization of the crew of this ship plans to sell me and exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe in a pool chain that accepting my new role as captain and declaring myself king of the wind arc seas. I grab a dirty muck bucket covered in barnacles and a dornet with the teeth of the vanquished to create the softening wet pirate crown of news that is my segment. I'm going to use the white paper to create the softened white paper to create the softened white paper to create the softened white pirate crown of news that is my segment. Meanwhile.",
            " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most relevant midnight valencio oranges. And doubts at all, and I find delimane de voyage cognac, before from bang and basting them tables, I deserve you the James Beard Award worthy creeps to ZET. That is my nightly monologue, but sometimes sometimes folks I wake up in the baggage hole of Greyhound bus, it's being hoisted by the scrapyard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps, busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strained pair of sweatpants and as ovenmets to extract and serve the demented transients pound cake of news that is my segment.",
            " Folks, if you watch the show and I hope you do, I spend a lot of time right over there. Tirelessly studying the lineage of the day's most important thoroughbred stories and whole-stiner headlines, working with the best trainers money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen that is my nightly monologue. But sometimes sometimes folks I break into an unincorporated veterinary genetics lab. And grab whatever test tubes I can find and then under a grow light I got from it a discarded chia pet. I mixed the pill for DNA of a horse and whatever was in a tube labeled Keith Cole and extra. Sloering the concoction with caffeine pills and a microwave bread bowl, I screamed sing a prayer to Janice initiator of human life and God of transformation as a half horse, half man freak, seasons to life before me. And the hideous collection of loose animal parts and corrupted men tissue that is my segment.",
        ]
        # fmt: on

        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
        model = model.to(torch_device)

        ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
        ds = ds.cast_column("audio", Audio(sampling_rate=16000))

        num_samples = 8

        audio = ds[:num_samples]["audio"]
        audios = [x["array"] for x in audio]

        inputs = processor(
            audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
        )
        inputs = inputs.to(device=torch_device)

        gen_kwargs = {
            "return_timestamps": True,
            "no_speech_threshold": 0.6,
            "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            "compression_ratio_threshold": 1.35,
            "condition_on_prev_tokens": True,
            "logprob_threshold": -1.0,
            "num_beams": 5,
        }

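        # With a tuple of temperatures, segments that fail the quality thresholds are retried
        # at the next higher temperature; seed the RNG so this fallback sampling is reproducible.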
        torch.manual_seed(0)
        result = model.generate(**inputs, **gen_kwargs)
        decoded_all = processor.batch_decode(result, skip_special_tokens=True)

        for i in range(num_samples):
            assert decoded_all[i] == EXPECTED_TEXT[i]


def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
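    # A head_mask of all ones keeps every encoder attention head active (no head is masked out).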
    if head_mask is None:
        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
    return {"input_features": input_features, "head_mask": head_mask}


@require_torch
class WhisperEncoderModelTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=60,
        is_training=True,
        use_labels=True,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        input_channels=1,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        max_source_positions=30,
        num_mel_bins=80,
        num_conv_layers=1,
        suppress_tokens=None,
        begin_suppress_tokens=None,
        classifier_proj_size=4,
        num_labels=2,
        is_encoder_decoder=False,
        is_decoder=False,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_mel_bins = num_mel_bins
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.num_conv_layers = num_conv_layers
        self.suppress_tokens = suppress_tokens
        self.begin_suppress_tokens = begin_suppress_tokens
        self.classifier_proj_size = classifier_proj_size
        self.num_labels = num_labels
        self.is_encoder_decoder = is_encoder_decoder
        self.is_decoder = is_decoder

    def get_config(self):
        return WhisperConfig(
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            input_channels=self.input_channels,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            max_source_positions=self.max_source_positions,
            decoder_ffn_dim=self.hidden_size,
            encoder_ffn_dim=self.hidden_size,
            suppress_tokens=self.suppress_tokens,
            begin_suppress_tokens=self.begin_suppress_tokens,
            classifier_proj_size=self.classifier_proj_size,
            num_labels=self.num_labels,
            is_encoder_decoder=self.is_encoder_decoder,
            is_decoder=self.is_decoder,
        )

    def prepare_config_and_inputs(self):
        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length])

        config = self.get_config()
        inputs_dict = prepare_whisper_encoder_inputs_dict(
            config,
            input_features=input_features,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_subsampled_output_lengths(self, input_lengths):
        """
        Computes the output length of the convolutional layers
        """

        for _ in range(self.num_conv_layers):
            input_lengths = (input_lengths - 1) // 2 + 1

        return input_lengths

    @property
    def encoder_seq_length(self):
        return self.get_subsampled_output_lengths(self.seq_length)

    def create_and_check_model_forward(self, config, inputs_dict, use_weighted_layer_sum=False):
        config.use_weighted_layer_sum = use_weighted_layer_sum
        model = WhisperForAudioClassification(config=config)
        model.to(torch_device).eval()

        input_features = inputs_dict["input_features"]

        with torch.no_grad():
            last_hidden_state = model(input_features).logits

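        # WhisperForAudioClassification returns logits of shape (batch_size, num_labels).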
        self.parent.assertEqual(last_hidden_state.shape, (self.batch_size, self.num_labels))


@require_torch
class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (WhisperForAudioClassification,) if is_torch_available() else ()
    is_encoder_decoder = False
    fx_compatible = False
    test_pruning = False
    test_missing_keys = False

    input_name = "input_features"

    def setUp(self):
        self.model_tester = WhisperEncoderModelTester(self)
        self.config_tester = ConfigTester(self, config_class=WhisperConfig)
        self.maxDiff = 3000

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_features", "head_mask", "encoder_outputs"]
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_forward_pass(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    def test_forward_pass_weighted_layer_sum(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs, use_weighted_layer_sum=True)

    @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
    def test_cpu_offload(self):
        pass

    @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
    def test_disk_offload_bin(self):
        pass

    @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
    def test_disk_offload_safetensors(self):
        pass

    @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
    def test_model_parallelism(self):
        pass

    # `inputs_embeds` is meaningless for an encoder-only acoustic model
    def test_inputs_embeds(self):
        pass

    # the equivalent test is passing the encoder outputs directly to the model
    def test_encoder_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            with torch.no_grad():
                outputs = model(**inputs)[0]

            encoder = model.encoder

            encoder_inputs = {"input_features": inputs["input_features"]}
            del inputs["input_features"]

            if "head_mask" in inputs:
                encoder_inputs["head_mask"] = inputs["head_mask"]
            if "attention_mask" in inputs:
                encoder_inputs["attention_mask"] = inputs["attention_mask"]
            if "output_attentions" in inputs:
                encoder_inputs["output_attentions"] = inputs["output_attentions"]

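            # Passing the precomputed encoder outputs back in must reproduce the exact same output.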
            with torch.no_grad():
                inputs["encoder_outputs"] = encoder(**encoder_inputs)
                outputs_embeds = model(**inputs)[0]

            self.assertTrue((outputs_embeds == outputs).all())

    # Needs to be overridden as the encoder input embedding is a Conv1d
    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Conv1d))
            model.set_input_embeddings(torch.nn.Conv1d(10, 10, 3))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, torch.nn.Conv1d))

    # WhisperEncoder cannot resize token embeddings since it has no token embeddings
    def test_resize_tokens_embeddings(self):
        pass

    @is_pt_flax_cross_test
    def test_equivalence_pt_to_flax(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        init_shape = (1,) + inputs_dict["input_features"].shape[1:]

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                fx_model_class_name = "Flax" + model_class.__name__

                if not hasattr(transformers, fx_model_class_name):
                    # no flax model exists for this class
                    return

                # Output all for aggressive testing
                config.output_hidden_states = True
                config.output_attentions = self.has_attentions

                fx_model_class = getattr(transformers, fx_model_class_name)

                # load PyTorch class
                pt_model = model_class(config).eval()
                # Flax models don't use the `use_cache` option, and the cache is not returned
                # by default, so we disable `use_cache` for the PyTorch model.
                pt_model.config.use_cache = False

                # load Flax class
                fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)

                # make sure that only inputs that actually exist in the Flax function args are forwarded
                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

                # prepare inputs
                pt_inputs = self._prepare_for_class(inputs_dict, model_class)

                # remove function args that don't exist in Flax
                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

                # send pytorch inputs to the correct device
                pt_inputs = {
                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
                }

                # convert inputs to Flax
                fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}

                fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
                fx_model.params = fx_state

                # send pytorch model to the correct device
                pt_model.to(torch_device)

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs)
                fx_outputs = fx_model(**fx_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    pt_model.save_pretrained(tmpdirname)
                    fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True)

                fx_outputs_loaded = fx_model_loaded(**fx_inputs)

                fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class)

    @is_pt_flax_cross_test
    def test_equivalence_flax_to_pt(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        init_shape = (1,) + inputs_dict["input_features"].shape[1:]

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                fx_model_class_name = "Flax" + model_class.__name__

                if not hasattr(transformers, fx_model_class_name):
                    # no flax model exists for this class
                    return

                # Output all for aggressive testing
                config.output_hidden_states = True
                config.output_attentions = self.has_attentions

                fx_model_class = getattr(transformers, fx_model_class_name)

                # load PyTorch class
                pt_model = model_class(config).eval()
                # Flax models don't use the `use_cache` option, and the cache is not returned
                # by default, so we disable `use_cache` for the PyTorch model.
                pt_model.config.use_cache = False

                # load Flax class
                fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)

                # make sure that only inputs that actually exist in the Flax function args are forwarded
                fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()

                # prepare inputs
                pt_inputs = self._prepare_for_class(inputs_dict, model_class)

                # remove function args that don't exist in Flax
                pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys}

                # send pytorch inputs to the correct device
                pt_inputs = {
                    k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items()
                }

                # convert inputs to Flax
                fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)}

                pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

                # make sure weights are tied in PyTorch
                pt_model.tie_weights()

                # send pytorch model to the correct device
                pt_model.to(torch_device)

                with torch.no_grad():
                    pt_outputs = pt_model(**pt_inputs)
                fx_outputs = fx_model(**fx_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    fx_model.save_pretrained(tmpdirname)
                    pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)

                # send pytorch model to the correct device
                pt_model_loaded.to(torch_device)
                pt_model_loaded.eval()

                with torch.no_grad():
                    pt_outputs_loaded = pt_model_loaded(**pt_inputs)

                fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
                pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])

                self.assertEqual(fx_keys, pt_keys)
                self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)


class WhisperStandaloneDecoderModelTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        is_training=True,
        use_labels=False,
        vocab_size=200,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        input_channels=1,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        max_source_positions=30,
        max_target_positions=40,
        bos_token_id=98,
        eos_token_id=98,
        pad_token_id=0,
        num_mel_bins=80,
        decoder_start_token_id=85,
        num_conv_layers=1,
        suppress_tokens=None,
        begin_suppress_tokens=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_mel_bins = num_mel_bins
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.num_conv_layers = num_conv_layers
        self.suppress_tokens = suppress_tokens
        self.begin_suppress_tokens = begin_suppress_tokens

    def prepare_config_and_inputs(self):
        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)

        decoder_input_ids = torch.tensor(
            self.batch_size * [[self.decoder_start_token_id, 3, 3, 7, 2]], device=torch_device
        )

        config = self.get_config()
        config.is_encoder_decoder = False
        inputs_dict = prepare_whisper_inputs_dict(
            config,
            attention_mask=None,
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
        )

        inputs_dict.pop("input_features")
        inputs_dict.pop("head_mask")
        inputs_dict.pop("decoder_head_mask")
        inputs_dict.pop("cross_attn_head_mask")

        inputs_dict["attention_mask"] = inputs_dict.pop("decoder_attention_mask")
        inputs_dict["input_ids"] = inputs_dict.pop("decoder_input_ids")
        return config, inputs_dict

    @property
    def encoder_seq_length(self):
        return 5

    @property
    def seq_length(self):
        return 5

    def get_config(self):
        return WhisperConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            input_channels=self.input_channels,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            max_source_positions=self.max_source_positions,
            max_target_positions=self.max_target_positions,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_ffn_dim=self.hidden_size,
            encoder_ffn_dim=self.hidden_size,
            decoder_start_token_id=self.decoder_start_token_id,
            suppress_tokens=self.suppress_tokens,
            begin_suppress_tokens=self.begin_suppress_tokens,
        )

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()

        inputs_dict["input_ids"][:, -1] = self.pad_token_id

        return config, inputs_dict

    def prepare_config_and_inputs_for_decoder(self):
        config, input_features = self.prepare_config_and_inputs()
        input_ids = input_features["input_ids"]
        encoder_hidden_states = floats_tensor([self.batch_size, self.decoder_seq_length, self.hidden_size])

        return (config, input_ids, encoder_hidden_states)

    def create_and_check_decoder_model_past(self, config, input_ids):
        config.use_cache = True
        model = WhisperDecoder(config=config).to(torch_device).eval()
        # first forward pass
        outputs = model(input_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids)
        outputs_no_past = model(input_ids, use_cache=False)

        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)

        past_key_values = outputs["past_key_values"]

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # append next tokens to input_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

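        # Feeding only the new token together with the cached past must reproduce the same
        # hidden state as re-running the full sequence.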
        output_from_no_past = model(next_input_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)

    def create_and_check_decoder_model_attention_mask_past(self, config, input_ids):
        model = WhisperDecoder(config=config).to(torch_device).eval()

        # create attention mask
        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)

        half_seq_length = input_ids.shape[-1] // 2
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
        past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens

        # append next tokens to input_ids and attn_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        attn_mask = torch.cat(
            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
            dim=1,
        )

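        # The edited token sits in a masked-out position, so cached decoding with the proper
        # attention mask must still agree with decoding the full sequence.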
        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)


@require_torch
class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (WhisperDecoder, WhisperForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (WhisperForCausalLM,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    is_encoder_decoder = False
    test_missing_keys = False

    def setUp(self):
        self.model_tester = WhisperStandaloneDecoderModelTester(self, is_training=False)
        self.config_tester = ConfigTester(self, config_class=WhisperConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_decoder_model_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        config, inputs_dict = config_and_inputs

        self.model_tester.create_and_check_decoder_model_past(config=config, input_ids=inputs_dict["input_ids"])

    def test_decoder_model_attn_mask_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        config, inputs_dict = config_and_inputs

        self.model_tester.create_and_check_decoder_model_attention_mask_past(
            config=config, input_ids=inputs_dict["input_ids"]
        )

    @unittest.skip("Generate needs input ids")
    def test_generate_without_input_ids(self):
        # generate only works with input ids for whisper
        pass

    @unittest.skip("Decoder can't keep attention grads")
    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        return

    @unittest.skip("The model doesn't support fast init from base")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip("The model doesn't support left padding")  # and it's not used enough to be worth fixing :)
    def test_left_padding_compatibility(self):
        pass
