# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import tempfile
import unittest

import transformers
from transformers import WhisperConfig, is_flax_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from transformers.utils import cached_property
from transformers.utils.import_utils import is_datasets_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor

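# `datasets`- and flax-dependent imports are guarded so this module can be imported
# even when those optional dependencies are missing; the `@require_flax` decorators
# below then skip the tests instead of erroring at import time.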
if is_datasets_available():
    import datasets
    from datasets import load_dataset

if is_flax_available():
    import jax
    import numpy as np
    from flax.core.frozen_dict import unfreeze
    from flax.traverse_util import flatten_dict

    from transformers import (
        FLAX_MODEL_MAPPING,
        FlaxWhisperForAudioClassification,
        FlaxWhisperForConditionalGeneration,
        FlaxWhisperModel,
        WhisperFeatureExtractor,
        WhisperProcessor,
    )
    from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
    from transformers.models.whisper.modeling_flax_whisper import sinusoidal_embedding_init

@require_flax
class FlaxWhisperModelTester:
    config_cls = WhisperConfig
    config_updates = {}
    hidden_act = "gelu"

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=60,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        d_model=16,
        decoder_attention_heads=4,
        decoder_ffn_dim=16,
        decoder_layers=2,
        encoder_attention_heads=4,
        encoder_ffn_dim=16,
        encoder_layers=2,
        input_channels=1,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=70,
        max_source_positions=30,
        max_target_positions=40,
        bos_token_id=98,
        eos_token_id=98,
        pad_token_id=0,
        num_mel_bins=80,
        decoder_start_token_id=85,
        num_conv_layers=1,
        suppress_tokens=None,
        begin_suppress_tokens=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.hidden_size = d_model
        self.num_hidden_layers = encoder_layers
        self.num_attention_heads = encoder_attention_heads
        self.decoder_attention_heads = decoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_seq_length = seq_length // 2
        self.decoder_seq_length = 1
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_mel_bins = num_mel_bins
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.num_conv_layers = num_conv_layers
        self.suppress_tokens = suppress_tokens
        self.begin_suppress_tokens = begin_suppress_tokens

    def prepare_config_and_inputs_for_common(self):
        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)

        decoder_input_ids = np.array(self.batch_size * [[self.decoder_start_token_id]])

        config = WhisperConfig(
            vocab_size=self.vocab_size,
            num_mel_bins=self.num_mel_bins,
            decoder_start_token_id=self.decoder_start_token_id,
            is_encoder_decoder=True,
            activation_function=self.hidden_act,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_source_positions=self.max_source_positions,
            max_target_positions=self.max_target_positions,
            pad_token_id=self.pad_token_id,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            tie_word_embeddings=True,
            d_model=self.d_model,
            decoder_attention_heads=self.decoder_attention_heads,
            decoder_ffn_dim=self.decoder_ffn_dim,
            decoder_layers=self.decoder_layers,
            encoder_attention_heads=self.encoder_attention_heads,
            encoder_ffn_dim=self.encoder_ffn_dim,
            encoder_layers=self.encoder_layers,
            suppress_tokens=self.suppress_tokens,
            begin_suppress_tokens=self.begin_suppress_tokens,
        )
        inputs_dict = prepare_whisper_inputs_dict(config, input_features, decoder_input_ids)
        return config, inputs_dict

def prepare_whisper_inputs_dict(
    config,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
):
    if decoder_attention_mask is None:
        decoder_attention_mask = np.concatenate(
            [
                np.ones(decoder_input_ids[:, :1].shape, dtype=np.int8),
                np.not_equal(decoder_input_ids[:, 1:], config.pad_token_id).astype(np.int8),
            ],
            axis=-1,
        )
    return {
        "input_features": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "decoder_attention_mask": decoder_attention_mask,
    }


def partialclass(cls, *args, **kwargs):
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return NewCls


def make_partial_class(full_class, *args, **kwargs):
    partial_class = partialclass(full_class, *args, **kwargs)
    partial_class.__name__ = full_class.__name__
    partial_class.__module__ = full_class.__module__

    return partial_class

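# Flax model classes need `input_shape` at construction time to initialize their
# parameters. `make_partial_class` pre-binds it while keeping `__name__`/`__module__`
# intact, so the name-based base-vs-head comparisons in the save/load tests below
# still work. Illustrative sketch (the shape matches this tester's defaults):
#
#     BoundModel = make_partial_class(FlaxWhisperModel, input_shape=(1, 80, 60))
#     model = BoundModel(config)  # behaves like FlaxWhisperModel(config, input_shape=(1, 80, 60))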
@require_flax
class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (FlaxWhisperForConditionalGeneration, FlaxWhisperModel) if is_flax_available() else ()
    all_generative_model_classes = (FlaxWhisperForConditionalGeneration,) if is_flax_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = FlaxWhisperModelTester(self)
        _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self.init_shape = (1,) + inputs_dict["input_features"].shape[1:]

        # materialize as a tuple (not a generator) so the class list can be iterated
        # more than once within a single test
        self.all_model_classes = tuple(
            make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes
        )
        self.config_tester = ConfigTester(self, config_class=WhisperConfig)
    def test_config(self):
        self.config_tester.run_common_tests()

    # overwrite because of `input_features`
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_features", "decoder_input_ids"]
            self.assertListEqual(arg_names[:2], expected_arg_names)

    # overwrite because of `input_features`
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_features, decoder_input_ids, **kwargs):
                    return model(input_features=input_features, decoder_input_ids=decoder_input_ids, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)
    def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
        # We override with a slightly higher tol value, as the test recently became flaky
        super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)

    # overwrite because of `input_features`
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = model_class(config)
            model.params = model.to_bf16(model.params)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite because of `input_features`
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                # save pt model
                pt_model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

                base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite because of `input_features`
    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # convert Flax model to PyTorch model
            pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
            pt_model = pt_model_class(config).eval()
            pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite because of `input_features`
    def test_save_load_from_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = base_class(config)
            base_params = flatten_dict(unfreeze(model.params))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                head_model = model_class.from_pretrained(tmpdirname)

                base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))

                for key in base_param_from_head.keys():
                    max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    # overwrite because of `input_features`
    def test_save_load_to_base(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape)

        for model_class in self.all_model_classes:
            if model_class.__name__ == base_class.__name__:
                continue

            model = model_class(config)
            base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

            # check that all base model weights are loaded correctly
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                base_model = base_class.from_pretrained(tmpdirname)

                base_params = flatten_dict(unfreeze(base_model.params))

                for key in base_params_from_head.keys():
                    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
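    # Whisper's encoder position embeddings are fixed sinusoids rather than learned
    # weights. The test below rebuilds the table with `sinusoidal_embedding_init` and
    # checks that the model was initialized with it. As a rough sketch (assuming the
    # usual half-sin / half-cos layout; the exact layout is defined by the init fn):
    #
    #     half = channels // 2
    #     freqs = np.exp(-np.log(10000) / (half - 1) * np.arange(half))
    #     angles = np.arange(length)[:, None] * freqs[None, :]
    #     table = np.concatenate([np.sin(angles), np.cos(angles)], axis=1)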
    def test_encoder_sinusoidal_embed_positions(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            params = model.params
            if model.base_model_prefix in params:
                params = model.params[model.base_model_prefix]

            embeds = params["encoder"]["embed_positions"]["embedding"]
            sinusoids = sinusoidal_embedding_init(None, embeds.shape)
            self.assertTrue(jax.numpy.allclose(embeds, sinusoids))

@slow
@require_flax
class FlaxWhisperModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_processor(self):
        return WhisperProcessor.from_pretrained("openai/whisper-base")

    def _load_datasamples(self, num_samples):
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]
    def test_tiny_logits_librispeech(self):
        model = FlaxWhisperModel.from_pretrained("openai/whisper-tiny", from_pt=True)
        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="np").input_features

        logits = model(
            input_features,
            decoder_input_ids=np.array([[50258, 50259, 50359]]),
            output_hidden_states=False,
            output_attentions=False,
            return_dict=False,
        )

        # fmt: off
        EXPECTED_LOGITS = np.array(
            [
                2.9892, -6.7607, 5.7348, 3.6096, 0.2152, -5.7321, 4.8855, -1.6407,
                0.2823, -1.5718, 10.4269, 3.4427, 0.0219, -8.0612, 3.4784, 8.4246,
                4.0575, -2.2864, 11.1084, 0.9963, 0.9884, -8.5154, -3.5469, -9.3713,
                0.9786, 3.5435, 7.4850, -5.2579, -1.4366, 10.4841
            ]
        )
        # fmt: on
        self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
    def test_small_en_logits_librispeech(self):
        model = FlaxWhisperModel.from_pretrained("openai/whisper-small.en", from_pt=True)
        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="np").input_features

        logits = model(
            input_features,
            decoder_input_ids=np.array([model.config.decoder_start_token_id]),
            output_hidden_states=False,
            output_attentions=False,
            return_dict=False,
        )
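        # `FlaxWhisperModel` has no LM head, so project the decoder hidden states onto
        # the (tied) token embedding matrix to recover vocabulary logits.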
        logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T

        # fmt: off
        EXPECTED_LOGITS = np.array(
            [
                -3.6784, -7.7211, -9.5070, -11.9286, -7.6489, -9.7026, -5.6188,
                -8.0104, -4.6238, -5.1833, -9.0485, -3.4079, -5.4874, -2.6935,
                -6.3479, -7.3398, -6.9558, -7.6867, -7.4748, -8.3463, -9.9781,
                -10.8389, -10.3105, -11.7201, -9.7261, -7.1590, -5.9272, -12.4509,
                -11.1146, -8.1918
            ]
        )
        # fmt: on
        self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
    def test_large_logits_librispeech(self):
        model = FlaxWhisperModel.from_pretrained("openai/whisper-large", from_pt=True)
        input_speech = self._load_datasamples(1)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        processed_inputs = processor(
            audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="np"
        )
        input_features = processed_inputs.input_features
        decoder_input_ids = processed_inputs.labels

        logits = model(
            input_features,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=False,
            output_attentions=False,
            return_dict=False,
        )

        logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T

        # fmt: off
        EXPECTED_LOGITS = np.array(
            [
                2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
                1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
                1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
                1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
            ]
        )
        # fmt: on
        self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
    def test_tiny_en_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
        model.config.decoder_start_token_id = 50257

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(
            raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
        ).input_features

        generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
        transcript = processor.tokenizer.decode(generated_ids[0])

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes and we are glad to"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
    def test_tiny_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True)

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(
            raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
        ).input_features

        generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
        transcript = processor.tokenizer.decode(generated_ids[0])

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes and we are glad"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
    def test_large_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(
            raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
        ).input_features
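        # `get_decoder_prompt_ids` returns (position, token_id) pairs that force the
        # language/task prompt at the start of decoding, e.g. roughly
        # [(1, <|en|>), (2, <|transcribe|>), (3, <|notimestamps|>)].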
        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")

        generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
        transcript = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
    def test_large_generation_multilingual(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)

        ds = load_dataset("common_voice", "ja", split="test", streaming=True)
        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
        input_speech = next(iter(ds))["audio"]["array"]
        # take the features array (not the whole BatchFeature) so it can be passed to `generate`
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features

        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
        generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
        generated_ids = model.generate(
            input_features,
            do_sample=False,
            max_length=20,
        ).sequences
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " Kimura-san called me."
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
        generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
    def test_large_batched_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)

        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features
        generated_ids = model.generate(input_features, max_length=20).sequences

        # fmt: off
        EXPECTED_LOGITS = np.array(
            [
                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
            ]
        )
        # fmt: on

        self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes and we are glad to",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast beef",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
    def test_tiny_en_batched_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)

        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features
        generated_ids = model.generate(input_features, max_length=20).sequences

        # fmt: off
        EXPECTED_LOGITS = np.array(
            [
                [50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
                [50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
                [50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
                [50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
            ]
        )
        # fmt: on

        self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes, and we are glad to",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast beef looming",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
    @slow
    def test_tiny_timestamp_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

        input_speech = np.concatenate(self._load_datasamples(4))
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="jax").input_features
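        # Bind the non-array generation kwargs with `functools.partial` so that the
        # jitted function is traced only over the audio features.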
        generate_fn = jax.jit(functools.partial(model.generate, max_length=448, return_timestamps=True))

        generated_ids = generate_fn(input_features)
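        # In the expected output, ids from 50364 (<|0.00|>) upward are timestamp tokens,
        # each step corresponding to 0.02 s of audio; they delimit the offsets below.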
        EXPECTED_OUTPUT = np.array([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257])  # fmt: skip

        self.assertTrue(np.allclose(generated_ids, EXPECTED_OUTPUT))

        EXPECTED_TRANSCRIPT = [
            {
                "text": (
                    " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is"
                    " Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season"
                    " of the year, with Christmas and roast beef looming before us, similarly drawn from eating and"
                    " its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins'"
                    " work is really Greek after all, and"
                ),
                "offsets": [
                    {
                        "text": (
                            " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
                        ),
                        "timestamp": (0.0, 6.5600000000000005),
                    },
                    {
                        "text": " Nor is Mr. Quilter's manner less interesting than his matter.",
                        "timestamp": (6.5600000000000005, 11.24),
                    },
                    {
                        "text": (
                            " He tells us that at this festive season of the year, with Christmas and roast beef"
                            " looming"
                        ),
                        "timestamp": (11.24, 16.88),
                    },
                    {
                        "text": (
                            " before us, similarly drawn from eating and its results occur most readily to the mind."
                        ),
                        "timestamp": (16.88, 23.76),
                    },
                    {
                        "text": (
                            " He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and"
                        ),
                        "timestamp": (23.76, 29.44),
                    },
                ],
            }
        ]

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

class FlaxWhisperEncoderModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=60,
        is_training=True,
        use_labels=True,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        input_channels=1,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        max_source_positions=30,
        num_mel_bins=80,
        num_conv_layers=1,
        suppress_tokens=None,
        begin_suppress_tokens=None,
        classifier_proj_size=4,
        num_labels=2,
        is_encoder_decoder=False,
        is_decoder=False,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_mel_bins = num_mel_bins
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.num_conv_layers = num_conv_layers
        self.suppress_tokens = suppress_tokens
        self.begin_suppress_tokens = begin_suppress_tokens
        self.classifier_proj_size = classifier_proj_size
        self.num_labels = num_labels
        self.is_encoder_decoder = is_encoder_decoder
        self.is_decoder = is_decoder

    def get_config(self):
        return WhisperConfig(
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            input_channels=self.input_channels,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            max_source_positions=self.max_source_positions,
            decoder_ffn_dim=self.hidden_size,
            encoder_ffn_dim=self.hidden_size,
            suppress_tokens=self.suppress_tokens,
            begin_suppress_tokens=self.begin_suppress_tokens,
            classifier_proj_size=self.classifier_proj_size,
            num_labels=self.num_labels,
            is_encoder_decoder=self.is_encoder_decoder,
            is_decoder=self.is_decoder,
        )

    def prepare_whisper_encoder_inputs_dict(
        self,
        input_features,
    ):
        return {
            "input_features": input_features,
        }

    def prepare_config_and_inputs(self):
        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length])

        config = self.get_config()
        inputs_dict = self.prepare_whisper_encoder_inputs_dict(
            input_features=input_features,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict
    def get_subsampled_output_lengths(self, input_lengths):
        """
        Computes the output length of the convolutional layers
        """

        for _ in range(self.num_conv_layers):
            input_lengths = (input_lengths - 1) // 2 + 1

        return input_lengths
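    # e.g. with the default seq_length=60 and a single stride-2 conv layer:
    # (60 - 1) // 2 + 1 = 30 output frames.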
    @property
    def encoder_seq_length(self):
        return self.get_subsampled_output_lengths(self.seq_length)

@require_flax
class WhisperEncoderModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (FlaxWhisperForAudioClassification,) if is_flax_available() else ()
    is_encoder_decoder = False
    fx_compatible = False
    test_pruning = False
    test_missing_keys = False

    input_name = "input_features"

    def setUp(self):
        self.model_tester = FlaxWhisperEncoderModelTester(self)
        _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        self.init_shape = (1,) + inputs_dict["input_features"].shape[1:]

        # materialize as a tuple (not a generator) so the class list can be iterated
        # more than once within a single test
        self.all_model_classes = tuple(
            make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes
        )
        self.config_tester = ConfigTester(self, config_class=WhisperConfig)

    def test_config(self):
        self.config_tester.run_common_tests()
    # overwrite because of `input_features`
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_features, **kwargs):
                    return model(input_features=input_features, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    # overwrite because of `input_features`
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_features", "attention_mask", "output_attentions"]
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
    def test_inputs_embeds(self):
        pass

    # WhisperEncoder has no inputs_embeds and thus the `get_input_embeddings` fn is not implemented
    def test_model_common_attributes(self):
        pass

    # WhisperEncoder cannot resize token embeddings since it has no token embeddings
    def test_resize_tokens_embeddings(self):
        pass

    # WhisperEncoder does not have any base model
    def test_save_load_to_base(self):
        pass

    # WhisperEncoder does not have any base model
    def test_save_load_from_base(self):
        pass

    # WhisperEncoder does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_from_base_pt(self):
        pass

    # WhisperEncoder does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_to_base_pt(self):
        pass

    # WhisperEncoder does not have any base model
    @is_pt_flax_cross_test
    def test_save_load_bf16_to_base_pt(self):
        pass