# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import tempfile
import unittest

import transformers
from transformers import WhisperConfig, is_flax_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers.utils import cached_property
from transformers.utils.import_utils import is_datasets_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor


if is_datasets_available():
    import datasets
    from datasets import load_dataset

if is_flax_available():
    import jax
    import numpy as np
    from flax.core.frozen_dict import unfreeze
    from flax.traverse_util import flatten_dict

    from transformers import (
        FLAX_MODEL_MAPPING,
        FlaxWhisperForAudioClassification,
        FlaxWhisperForConditionalGeneration,
        FlaxWhisperModel,
        WhisperFeatureExtractor,
        WhisperProcessor,
    )
    from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
    from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init


@require_flax
class FlaxWhisperModelTester:
    config_cls = WhisperConfig
    config_updates = {}
    hidden_act = "gelu"

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=60,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        d_model=16,
        decoder_attention_heads=4,
        decoder_ffn_dim=16,
        decoder_layers=2,
        encoder_attention_heads=4,
        encoder_ffn_dim=16,
        encoder_layers=2,
        input_channels=1,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=70,
        max_source_positions=30,
        max_target_positions=40,
        bos_token_id=98,
        eos_token_id=98,
        pad_token_id=0,
        num_mel_bins=80,
        decoder_start_token_id=85,
        num_conv_layers=1,
        suppress_tokens=None,
        begin_suppress_tokens=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.hidden_size = d_model
        self.num_hidden_layers = encoder_layers
        self.num_attention_heads = encoder_attention_heads
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_seq_length = seq_length // 2
        self.decoder_seq_length = 1
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_mel_bins = num_mel_bins
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.num_conv_layers = num_conv_layers
        self.suppress_tokens = suppress_tokens
        self.begin_suppress_tokens = begin_suppress_tokens

    def prepare_config_and_inputs_for_common(self):
        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)

        decoder_input_ids = np.array(self.batch_size * [[self.decoder_start_token_id]])

        config = WhisperConfig(
            vocab_size=self.vocab_size,
            num_mel_bins=self.num_mel_bins,
            decoder_start_token_id=self.decoder_start_token_id,
            is_encoder_decoder=True,
            activation_function=self.hidden_act,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_source_positions=self.max_source_positions,
            max_target_positions=self.max_target_positions,
            pad_token_id=self.pad_token_id,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            tie_word_embeddings=True,
            d_model=self.d_model,
            decoder_attention_heads=self.decoder_attention_heads,
            decoder_ffn_dim=self.decoder_ffn_dim,
            decoder_layers=self.decoder_layers,
            encoder_attention_heads=self.encoder_attention_heads,
            encoder_ffn_dim=self.encoder_ffn_dim,
            encoder_layers=self.encoder_layers,
            suppress_tokens=self.suppress_tokens,
            begin_suppress_tokens=self.begin_suppress_tokens,
        )
        inputs_dict = prepare_whisper_inputs_dict(config, input_features, decoder_input_ids)
        return config, inputs_dict


def prepare_whisper_inputs_dict(
    config,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
):
    if decoder_attention_mask is None:
        decoder_attention_mask = np.concatenate(
            [
                np.ones(decoder_input_ids[:, :1].shape, dtype=np.int8),
                np.not_equal(decoder_input_ids[:, 1:], config.pad_token_id).astype(np.int8),
            ],
            axis=-1,
        )
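        # Position 0 holds `decoder_start_token_id` and is always marked as
        # attended (it may coincide with `pad_token_id`); every later position
        # is masked out wherever it contains padding.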
    return {
        "input_features": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "decoder_attention_mask": decoder_attention_mask,
    }


def partialclass(cls, *args, **kwargs):
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return NewCls


def make_partial_class(full_class, *args, **kwargs):
    partial_class = partialclass(full_class, *args, **kwargs)
    partial_class.__name__ = full_class.__name__
    partial_class.__module__ = full_class.__module__

    return partial_class
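# Illustrative sketch (not executed by the test suite): `make_partial_class`
# bakes constructor arguments such as `input_shape` into a subclass via
# `functools.partialmethod`, so the common tester mixin can instantiate models
# with `model_class(config)` alone. Roughly:
#
#   FixedShapeWhisper = make_partial_class(FlaxWhisperModel, input_shape=(1, 80, 60))
#   model = FixedShapeWhisper(config)
#   # equivalent to FlaxWhisperModel(config, input_shape=(1, 80, 60)),
#   # while keeping the original __name__/__module__ for test reporting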


@require_flax
class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (FlaxWhisperForConditionalGeneration, FlaxWhisperModel) if is_flax_available() else ()
    all_generative_model_classes = (FlaxWhisperForConditionalGeneration,) if is_flax_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = FlaxWhisperModelTester(self)
        _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self.init_shape = (1,) + inputs_dict["input_features"].shape[1:]
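        # `init_shape` is (batch=1, num_mel_bins, seq_length); Flax uses it
        # once up front to initialize the module's weights.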

        self.all_model_classes = (
            make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes
        )
        self.config_tester = ConfigTester(self, config_class=WhisperConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    # overwrite because of `input_features`
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_features", "decoder_input_ids"]
            self.assertListEqual(arg_names[:2], expected_arg_names)

    # overwrite because of `input_features`
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_features, decoder_input_ids, **kwargs):
                    return model(input_features=input_features, decoder_input_ids=decoder_input_ids, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
        # We override with a slightly higher tol value, as the test recently became flaky
        super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)

    # overwrite because of `input_features`
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = model_class(config)
            model.params = model.to_bf16(model.params)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite because of `input_features`
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                # save pt model
                pt_model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

                base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite because of `input_features`
    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite because of `input_features`
    def test_save_load_from_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname)

                base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite because of `input_features`
    def test_save_load_to_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    def test_encoder_sinusoidal_embed_positions(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            params = model.params
            if model.base_model_prefix in params:
                params = model.params[model.base_model_prefix]

            embeds = params["encoder"]["embed_positions"]["embedding"]
            sinusoids = sinusoidal_embedding_init(None, embeds.shape)
            self.assertTrue(jax.numpy.allclose(embeds, sinusoids))


@slow
@require_flax
class FlaxWhisperModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_processor(self):
        return WhisperProcessor.from_pretrained("openai/whisper-base")

    def _load_datasamples(self, num_samples):
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]

    def test_tiny_logits_librispeech(self):
        model = FlaxWhisperModel.from_pretrained("openai/whisper-tiny", from_pt=True)
        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="np").input_features

        logits = model(
            input_features,
            decoder_input_ids=np.array([[50258, 50259, 50359]]),
            output_hidden_states=False,
            output_attentions=False,
            return_dict=False,
        )

        # fmt: off
        EXPECTED_LOGITS = np.array(
            [
                2.9892, -6.7607, 5.7348, 3.6096, 0.2152, -5.7321, 4.8855, -1.6407,
                0.2823, -1.5718, 10.4269, 3.4427, 0.0219, -8.0612, 3.4784, 8.4246,
                4.0575, -2.2864, 11.1084, 0.9963, 0.9884, -8.5154, -3.5469, -9.3713,
                0.9786, 3.5435, 7.4850, -5.2579, -1.4366, 10.4841
            ]
        )
        # fmt: on
        self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4))

    def test_small_en_logits_librispeech(self):
        model = FlaxWhisperModel.from_pretrained("openai/whisper-small.en", from_pt=True)
        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="np").input_features

        logits = model(
            input_features,
            decoder_input_ids=np.array([model.config.decoder_start_token_id]),
            output_hidden_states=False,
            output_attentions=False,
            return_dict=False,
        )

        logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T

        # fmt: off
        EXPECTED_LOGITS = np.array(
            [
                -3.6784, -7.7211, -9.5070, -11.9286, -7.6489, -9.7026, -5.6188,
                -8.0104, -4.6238, -5.1833, -9.0485, -3.4079, -5.4874, -2.6935,
                -6.3479, -7.3398, -6.9558, -7.6867, -7.4748, -8.3463, -9.9781,
                -10.8389, -10.3105, -11.7201, -9.7261, -7.1590, -5.9272, -12.4509,
                -11.1146, -8.1918
            ]
        )
        # fmt: on
        self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))

    def test_large_logits_librispeech(self):
        model = FlaxWhisperModel.from_pretrained("openai/whisper-large", from_pt=True)
        input_speech = self._load_datasamples(1)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        processed_inputs = processor(
            audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="np"
        )
        input_features = processed_inputs.input_features
        decoder_input_ids = processed_inputs.labels

        logits = model(
            input_features,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=False,
            output_attentions=False,
            return_dict=False,
        )

        logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T

        # fmt: off
        EXPECTED_LOGITS = np.array(
            [
                2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
                1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
                1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
                1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
            ]
        )
        # fmt: on
        self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))

    def test_tiny_en_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
        model.config.decoder_start_token_id = 50257

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(
            raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
        ).input_features

        generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
        transcript = processor.tokenizer.decode(generated_ids[0])

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes and we are glad to"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    def test_tiny_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True)

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(
            raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
        ).input_features

        generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
        transcript = processor.tokenizer.decode(generated_ids[0])

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes and we are glad"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    def test_large_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(
            raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
        ).input_features

        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")

        generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
        transcript = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    def test_large_generation_multilingual(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)

        ds = load_dataset("common_voice", "ja", split="test", streaming=True)
        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
        input_speech = next(iter(ds))["audio"]["array"]
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np")

        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
        generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
        generated_ids = model.generate(
            input_features,
            do_sample=False,
            max_length=20,
        ).sequences
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " Kimura-san called me."
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
        generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    def test_large_batched_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)

        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features
        generated_ids = model.generate(input_features, max_length=20).sequences

        # fmt: off
        EXPECTED_LOGITS = np.array(
            [
                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
            ]
        )
        # fmt: on

        self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes and we are glad to",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast beef",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)

    def test_tiny_en_batched_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)

        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features
        generated_ids = model.generate(input_features, max_length=20).sequences

        # fmt: off
        EXPECTED_LOGITS = np.array(
            [
                [50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
                [50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
                [50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
                [50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
            ]
        )
        # fmt: on

        self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes, and we are glad to",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast beef looming",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_tiny_timestamp_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

        input_speech = np.concatenate(self._load_datasamples(4))
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="jax").input_features

        generate_fn = jax.jit(functools.partial(model.generate, max_length=448, return_timestamps=True))
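        # `functools.partial` bakes `max_length` and `return_timestamps` into
        # the traced function, so `jax.jit` compiles one fixed configuration
        # and reuses it on subsequent calls.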

        generated_ids = generate_fn(input_features)

        EXPECTED_OUTPUT = np.array([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257])  # fmt: skip

        self.assertTrue(np.allclose(generated_ids, EXPECTED_OUTPUT))

        EXPECTED_TRANSCRIPT = [
            {
                "text": (
                    " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is"
                    " Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season"
                    " of the year, with Christmas and roast beef looming before us, similarly drawn from eating and"
                    " its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins'"
                    " work is really Greek after all, and"
                ),
                "offsets": [
                    {
                        "text": (
                            " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
                        ),
                        "timestamp": (0.0, 6.5600000000000005),
                    },
                    {
                        "text": " Nor is Mr. Quilter's manner less interesting than his matter.",
                        "timestamp": (6.5600000000000005, 11.24),
                    },
                    {
                        "text": (
                            " He tells us that at this festive season of the year, with Christmas and roast beef"
                            " looming"
                        ),
                        "timestamp": (11.24, 16.88),
                    },
                    {
                        "text": (
                            " before us, similarly drawn from eating and its results occur most readily to the mind."
                        ),
                        "timestamp": (16.88, 23.76),
                    },
                    {
                        "text": (
                            " He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and"
                        ),
                        "timestamp": (23.76, 29.44),
                    },
                ],
            }
        ]

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)


class FlaxWhisperEncoderModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=60,
        is_training=True,
        use_labels=True,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        input_channels=1,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        max_source_positions=30,
        num_mel_bins=80,
        num_conv_layers=1,
        suppress_tokens=None,
        begin_suppress_tokens=None,
        classifier_proj_size=4,
        num_labels=2,
        is_encoder_decoder=False,
        is_decoder=False,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_mel_bins = num_mel_bins
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.num_conv_layers = num_conv_layers
        self.suppress_tokens = suppress_tokens
        self.begin_suppress_tokens = begin_suppress_tokens
        self.classifier_proj_size = classifier_proj_size
        self.num_labels = num_labels
        self.is_encoder_decoder = is_encoder_decoder
        self.is_decoder = is_decoder

    def get_config(self):
        return WhisperConfig(
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            input_channels=self.input_channels,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            max_source_positions=self.max_source_positions,
            decoder_ffn_dim=self.hidden_size,
            encoder_ffn_dim=self.hidden_size,
            suppress_tokens=self.suppress_tokens,
            begin_suppress_tokens=self.begin_suppress_tokens,
            classifier_proj_size=self.classifier_proj_size,
            num_labels=self.num_labels,
            is_encoder_decoder=self.is_encoder_decoder,
            is_decoder=self.is_decoder,
        )

    def prepare_whisper_encoder_inputs_dict(
        self,
        input_features,
    ):
        return {
            "input_features": input_features,
        }

    def prepare_config_and_inputs(self):
        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length])

        config = self.get_config()
        inputs_dict = self.prepare_whisper_encoder_inputs_dict(
            input_features=input_features,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_subsampled_output_lengths(self, input_lengths):
        """
        Computes the output length of the convolutional layers
        """

        for i in range(self.num_conv_layers):
            input_lengths = (input_lengths - 1) // 2 + 1
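            # e.g. one stride-2 conv layer maps a 60-frame input to
            # (60 - 1) // 2 + 1 = 30 output frames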

        return input_lengths

    @property
    def encoder_seq_length(self):
        return self.get_subsampled_output_lengths(self.seq_length)


@require_flax
class WhisperEncoderModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (FlaxWhisperForAudioClassification,) if is_flax_available() else ()
    is_encoder_decoder = False
    fx_compatible = False
    test_pruning = False
    test_missing_keys = False

    input_name = "input_features"

    def setUp(self):
        self.model_tester = FlaxWhisperEncoderModelTester(self)
        _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self.init_shape = (1,) + inputs_dict["input_features"].shape[1:]

        self.all_model_classes = (
            make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes
        )
        self.config_tester = ConfigTester(self, config_class=WhisperConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    # overwrite because of `input_features`
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_features, **kwargs):
                    return model(input_features=input_features, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    # overwrite because of `input_features`
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_features", "attention_mask", "output_attentions"]
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    # WhisperEncoder has no inputs_embeds, so this common test does not apply
    def test_inputs_embeds(self):
        pass

    # WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented
    def test_model_common_attributes(self):
        pass

    # WhisperEncoder cannot resize token embeddings since it has no token embeddings
    def test_resize_tokens_embeddings(self):
        pass

    # WhisperEncoder does not have any base model
    def test_save_load_to_base(self):
        pass

    # WhisperEncoder does not have any base model
    def test_save_load_from_base(self):
        pass

    # WhisperEncoder does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        pass

    # WhisperEncoder does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        pass

    # WhisperEncoder does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        pass