# transformers — test file (1119 lines · 46.7 KB)
1# coding=utf-8
2# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15""" Testing suite for the TensorFlow Whisper model. """
16
from __future__ import annotations

import inspect
import tempfile
import traceback
import unittest

import numpy as np

from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, run_test_in_subprocess, slow
from transformers.utils import cached_property
from transformers.utils.import_utils import is_datasets_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_datasets_available():
    import datasets
    from datasets import load_dataset


if is_tf_available():
    import tensorflow as tf

    from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed
    from transformers.models.whisper.modeling_tf_whisper import (
        TFWhisperDecoder,
        TFWhisperEncoder,
        sinusoidal_embedding_init,
    )
51
def prepare_whisper_inputs_dict(
    config,
    input_features,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    """Assemble the keyword-argument dict fed to TFWhisper models, filling in default masks.

    Any mask left as `None` is replaced by a permissive default: the decoder attention mask
    attends everywhere except padding, and the head masks keep every attention head enabled.
    """
    if decoder_attention_mask is None:
        # attend to every decoder token that is not padding
        decoder_attention_mask = tf.where(decoder_input_ids != config.pad_token_id, 1, 0)
    if head_mask is None:
        head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
    if decoder_head_mask is None:
        decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
    if cross_attn_head_mask is None:
        cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
    inputs = {
        "input_features": input_features,
        "decoder_input_ids": decoder_input_ids,
        "decoder_attention_mask": decoder_attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
    }
    return inputs
79
@require_tf
class TFWhisperModelTester:
    """Builds tiny Whisper configs/inputs and the model-level checks shared by the TF test suite."""

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=60,
        is_training=True,
        use_labels=False,
        vocab_size=200,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        input_channels=1,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        max_source_positions=30,
        max_target_positions=60,
        bos_token_id=98,
        eos_token_id=98,
        pad_token_id=0,
        num_mel_bins=80,
        decoder_start_token_id=85,
        num_conv_layers=1,
        suppress_tokens=None,
        begin_suppress_tokens=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_mel_bins = num_mel_bins
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.num_conv_layers = num_conv_layers
        self.suppress_tokens = suppress_tokens
        self.begin_suppress_tokens = begin_suppress_tokens

    def prepare_config_and_inputs(self):
        """Return a tiny config plus a matching inputs dict (mel features + decoder ids)."""
        input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)

        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = self.get_config()
        inputs_dict = prepare_whisper_inputs_dict(
            config,
            attention_mask=None,
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
        )
        return config, inputs_dict

    def get_config(self):
        """Build a minimal `WhisperConfig` from the tester's hyperparameters."""
        return WhisperConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            input_channels=self.input_channels,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            max_source_positions=self.max_source_positions,
            max_target_positions=self.max_target_positions,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_ffn_dim=self.hidden_size,
            encoder_ffn_dim=self.hidden_size,
            decoder_start_token_id=self.decoder_start_token_id,
            suppress_tokens=self.suppress_tokens,
            begin_suppress_tokens=self.begin_suppress_tokens,
        )

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_subsampled_output_lengths(self, input_lengths):
        """Computes the output length of the convolutional layers"""

        # each conv layer halves the sequence length (stride 2, "same"-style rounding)
        for i in range(self.num_conv_layers):
            input_lengths = (input_lengths - 1) // 2 + 1

        return input_lengths

    def create_and_check_model_forward(self, config, inputs_dict):
        model = TFWhisperModel(config=config)

        input_features = inputs_dict["input_features"]
        decoder_input_ids = inputs_dict["decoder_input_ids"]

        # first forward pass
        last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state

        # Fix: the previous `assertTrue(last_hidden_state.shape, (13, 7, 16))` was vacuous — the
        # tuple was interpreted as the assertion *message*, so any truthy shape passed. Assert the
        # actual decoder output shape (batch, decoder seq length, hidden size) instead.
        self.parent.assertEqual(
            tuple(last_hidden_state.shape), (self.batch_size, self.seq_length, self.hidden_size)
        )

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        """Check that cached (`past_key_values`) decoding matches a full forward pass."""
        model = TFWhisperModel(config=config).get_decoder()
        # take a slice so we're shorter than the sequence length and can append later
        input_ids = inputs_dict["decoder_input_ids"][:, :-10]
        attention_mask = inputs_dict["decoder_attention_mask"][:, :-10]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next token and extent to next_input_ids
        next_token = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_tokens = tf.where(next_token <= 2, 2, next_token)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append to next input_ids and
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = np.random.randint(0, output_from_past.shape[-1])
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(np.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))

    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
        """Check that the encoder/decoder round-trip through save_pretrained/from_pretrained."""
        model = TFWhisperModel(config=config)
        outputs = model(**inputs_dict)

        encoder_last_hidden_state = outputs.encoder_last_hidden_state
        last_hidden_state = outputs.last_hidden_state

        with tempfile.TemporaryDirectory() as tmpdirname:
            encoder = model.get_encoder()
            encoder.save_pretrained(tmpdirname)
            encoder = TFWhisperEncoder.from_pretrained(tmpdirname)

        encoder_last_hidden_state_2 = encoder(inputs_dict["input_features"])[0]

        # Fix: `tf.Tensor` has no `.abs()` method (that idiom was copied from the PyTorch test);
        # compute the max absolute difference with TF ops instead.
        self.parent.assertTrue(
            tf.reduce_max(tf.abs(encoder_last_hidden_state_2 - encoder_last_hidden_state)) < 1e-3
        )

        with tempfile.TemporaryDirectory() as tmpdirname:
            decoder = model.get_decoder()
            decoder.save_pretrained(tmpdirname)
            decoder = TFWhisperDecoder.from_pretrained(tmpdirname)

        last_hidden_state_2 = decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=encoder_last_hidden_state,
        )[0]

        self.parent.assertTrue(tf.reduce_max(tf.abs(last_hidden_state_2 - last_hidden_state)) < 1e-3)
261
@require_tf
class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common-suite tests for the TF Whisper models, specialised for `input_features` inputs."""

    all_model_classes = (TFWhisperModel, TFWhisperForConditionalGeneration) if is_tf_available() else ()
    all_generative_model_classes = (TFWhisperForConditionalGeneration,) if is_tf_available() else ()
    pipeline_model_mapping = {"feature-extraction": TFWhisperModel} if is_tf_available() else {}
    is_encoder_decoder = True
    fx_compatible = False
    test_pruning = False
    test_missing_keys = False
    test_onnx = False

    input_name = "input_features"

    # TODO (ydshieh): undo skip once a fix is done on TF side.
    @unittest.skip("Skip for now as TF 2.13 breaks it on GPU")
    def test_xla_generate_slow(self):
        super().test_xla_generate_slow()

    def setUp(self):
        self.model_tester = TFWhisperModelTester(self)
        self.config_tester = ConfigTester(self, config_class=WhisperConfig)
        self.maxDiff = 3000

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            model.build_in_name_scope()

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname, saved_model=False)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
                self.assertEqual(info["missing_keys"], [])

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    def test_requires_grad_encoder_embed_positions(self):
        # encoder positions are fixed sinusoids and must never be trainable
        config = self.model_tester.get_config()
        for model_class in self.all_model_classes:
            model = model_class(config)
            encoder = model.get_encoder()
            self.assertFalse(encoder.embed_positions.trainable)

    def test_encoder_sinusoidal_embed_positions(self):
        config = self.model_tester.get_config()
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.build_in_name_scope()

            embeds = model.get_encoder().embed_positions.get_weights()[0]
            sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
            self.assertTrue(np.allclose(embeds, sinusoids))

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def _get_input_ids_and_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict[self.input_name]

        # cut to half length & take max batch_size 3
        max_batch_size = 3
        input_ids = input_ids[:max_batch_size, :, :]

        # generate max 3 tokens
        max_length = 4
        if config.eos_token_id is not None and config.pad_token_id is None:
            # hack to allow generate for models such as GPT2 as is done in `generate()`
            config.pad_token_id = config.eos_token_id

        return config, input_ids, None, max_length

    # not implemented currently
    def test_inputs_embeds(self):
        pass

    @unittest.skip("Training is not yet supported")
    def test_training(self):
        pass

    def test_generate_with_head_masking(self):
        pass

    @unittest.skip("fp16 is not yet supported for TF models")
    def test_generate_fp16(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        config.max_target_positions = 400
        input_features = input_dict["input_features"]
        model = TFWhisperForConditionalGeneration(config)
        model.generate(input_features)
        model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.call)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_features",
                "decoder_input_ids",
                "decoder_attention_mask",
            ]
            # Fix: the original condition chained truthy string literals
            # (`"head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names`),
            # which only tested the last membership. Test all three explicitly.
            expected_arg_names.extend(
                ["decoder_position_ids", "head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
                if all(name in arg_names for name in ("head_mask", "decoder_head_mask", "cross_attn_head_mask"))
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
            else:
                seq_length = self.model_tester.seq_length

            # the encoder conv front-end subsamples the time axis
            subsampled_seq_length = model._get_feat_extract_output_lengths(seq_length)

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [subsampled_seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)

                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
        # We override with a slightly higher tol value, as test recently became flaky
        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", encoder_key_length)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)

            # encoder attention shapes follow the conv-subsampled sequence length
            subsampled_encoder_seq_length = model._get_feat_extract_output_lengths(encoder_seq_length)
            subsampled_encoder_key_length = model._get_feat_extract_output_lengths(encoder_key_length)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )
            out_len = len(outputs)

            correct_outlen = 5

            # loss is at first position
            if "labels" in inputs_dict:
                correct_outlen += 1  # loss is added to beginning
            if "past_key_values" in outputs:
                correct_outlen += 1  # past_key_values have been returned

            self.assertEqual(out_len, correct_outlen)

            # decoder attentions
            decoder_attentions = outputs.decoder_attentions
            self.assertIsInstance(decoder_attentions, (list, tuple))
            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(decoder_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
            )

            # cross attentions
            cross_attentions = outputs.cross_attentions
            self.assertIsInstance(cross_attentions, (list, tuple))
            self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(cross_attentions[0].shape[-3:]),
                [
                    self.model_tester.num_attention_heads,
                    decoder_seq_length,
                    subsampled_encoder_key_length,
                ],
            )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )

    def test_generate_without_input_ids(self):
        pass

    @staticmethod
    def _get_encoder_outputs(
        model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
    ):
        """Run the encoder once and build the matching decoder start tokens for generation checks."""
        encoder = model.get_encoder()
        encoder_outputs = encoder(
            input_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # Fix: `repeat_interleave` is a PyTorch tensor method that does not exist on `tf.Tensor`;
        # `tf.repeat(..., axis=0)` performs the same per-row interleaved repetition.
        encoder_outputs["last_hidden_state"] = tf.repeat(
            encoder_outputs.last_hidden_state, repeats=num_interleave, axis=0
        )
        input_ids = input_ids[:, :, 0]
        input_ids = tf.zeros_like(input_ids[:, :1], dtype=tf.int64) + tf.convert_to_tensor(
            [model._get_decoder_start_token_id()]
        )
        attention_mask = None
        return encoder_outputs, input_ids, attention_mask

    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
        """Validate scores / attentions / hidden states of a `generate` output structure."""
        batch_size, mel, seq_length = input_ids.shape
        subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
        num_sequences_in_output = batch_size * num_return_sequences
        gen_len = (
            output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
        )

        # scores
        self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)

        # Attentions
        # encoder
        self._check_encoder_attention_for_generate(
            output.encoder_attentions, batch_size, config, subsampled_seq_length
        )
        # decoder
        self._check_attentions_for_generate(
            num_sequences_in_output,
            output.decoder_attentions,
            min_length=1,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )

        # Hidden States
        # encoder
        self._check_encoder_hidden_states_for_generate(
            output.encoder_hidden_states, batch_size, config, subsampled_seq_length
        )

        # decoder
        self._check_hidden_states_for_generate(
            num_sequences_in_output,
            output.decoder_hidden_states,
            min_length=1,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )

    # overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
    # `input_features`
    def test_lm_head_model_random_no_beam_search_generate(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_features = inputs_dict.get("input_features", None)

        # iterate over all generative models
        for model_class in self.all_generative_model_classes:
            model = model_class(config)

            if config.bos_token_id is None:
                # if bos token id is not defined model needs input_features
                with self.assertRaises(AssertionError):
                    model.generate(do_sample=True, max_length=5)
                # num_return_sequences = 1
                self._check_generated_ids(model.generate(input_features, do_sample=True))

            with self.assertRaises(ValueError):
                # generating multiple sequences when no beam search generation
                # is not allowed as it would always generate the same sequences
                model.generate(input_features, do_sample=False, num_return_sequences=2)

            # num_return_sequences > 1, sample
            self._check_generated_ids(model.generate(input_features, do_sample=True, num_return_sequences=2))

            # check bad words tokens language generation
            # create list of 1-seq bad token and list of 2-seq of bad tokens
            bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
            output_tokens = model.generate(
                input_features, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
            )
            # only count generated tokens
            generated_ids = output_tokens[:, input_features.shape[-1] :]
            self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))

    # overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
    # `input_features`
    def test_lm_head_model_random_beam_search_generate(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_features = inputs_dict.get("input_features", None)

        for model_class in self.all_generative_model_classes:
            model = model_class(config)

            if config.bos_token_id is None:
                # if bos token id is not defined model needs input_ids, num_return_sequences = 1
                self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2))

            with self.assertRaises(ValueError):
                # generating more sequences than having beams leads is not possible
                model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2)

            # num_return_sequences > 1, sample
            self._check_generated_ids(
                model.generate(
                    input_features,
                    do_sample=True,
                    num_beams=2,
                    num_return_sequences=2,
                )
            )
            # num_return_sequences > 1, greedy
            self._check_generated_ids(
                model.generate(input_features, do_sample=False, num_beams=2, num_return_sequences=2)
            )

            # check bad words tokens language generation
            # create list of 1-seq bad token and list of 2-seq of bad tokens
            bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
            output_tokens = model.generate(
                input_features, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
            )
            # only count generated tokens
            generated_ids = output_tokens[:, input_features.shape[-1] :]
            self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))

    def test_generate_with_prompt_ids_and_task_and_language(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = TFWhisperForConditionalGeneration(config)
        input_features = input_dict["input_features"]
        prompt_ids = np.arange(5)
        language = "<|de|>"
        task = "translate"
        lang_id = 6
        task_id = 7
        # plain attribute assignment instead of explicit `__setattr__` calls
        model.generation_config.lang_to_id = {language: lang_id}
        model.generation_config.task_to_id = {task: task_id}

        output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)

        expected_output_start = [
            *prompt_ids.tolist(),
            model.generation_config.decoder_start_token_id,
            lang_id,
            task_id,
        ]
        for row in output.numpy().tolist():
            self.assertListEqual(row[: len(expected_output_start)], expected_output_start)

    def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = TFWhisperForConditionalGeneration(config)
        input_features = input_dict["input_features"]
        prompt_ids = np.asarray(range(5))
        forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]

        output = model.generate(
            input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
        )

        expected_output_start = [
            *prompt_ids.tolist(),
            model.generation_config.decoder_start_token_id,
            *[token for _rank, token in forced_decoder_ids],
        ]
        for row in output.numpy().tolist():
            self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
705
def _load_datasamples(num_samples):
    """Return the raw audio arrays of the first `num_samples` clips of the dummy LibriSpeech split."""
    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    # sorting by id keeps the selection deterministic; `datasets` decodes the audio automatically
    samples = dataset.sort("id").select(range(num_samples))[:num_samples]["audio"]
    return [sample["array"] for sample in samples]
713
def _test_large_logits_librispeech(in_queue, out_queue, timeout):
    """Subprocess target: compare whisper-large vocabulary logits on one LibriSpeech sample
    against golden values.

    Intended to be run via `run_test_in_subprocess`: `in_queue` carries the start signal,
    `out_queue` receives a dict with the traceback string (or None) under "error".
    """
    error = None
    try:
        _ = in_queue.get(timeout=timeout)

        set_seed(0)

        model = TFWhisperModel.from_pretrained("openai/whisper-large")

        input_speech = _load_datasamples(1)

        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        processed_inputs = processor(
            audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="tf"
        )
        input_features = processed_inputs.input_features
        decoder_input_ids = processed_inputs.labels

        logits = model(
            input_features,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=False,
            output_attentions=False,
            use_cache=False,
        )

        # project the decoder hidden states onto the (tied) token embedding matrix to get logits
        logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0])

        # fmt: off
        EXPECTED_LOGITS = tf.convert_to_tensor(
            [
                2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
                1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
                1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
                1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
            ]
        )
        # fmt: on

        unittest.TestCase().assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
    except Exception:
        error = f"{traceback.format_exc()}"

    results = {"error": error}
    out_queue.put(results, timeout=timeout)
    out_queue.join()
761
def _test_large_generation(in_queue, out_queue, timeout):
    """Subprocess target: greedy-decode one LibriSpeech sample with whisper-large and compare
    the English transcript against the expected string.

    Intended to be run via `run_test_in_subprocess`: `in_queue` carries the start signal,
    `out_queue` receives a dict with the traceback string (or None) under "error".
    """
    error = None
    try:
        _ = in_queue.get(timeout=timeout)

        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")

        input_speech = _load_datasamples(1)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
    except Exception:
        error = f"{traceback.format_exc()}"

    results = {"error": error}
    out_queue.put(results, timeout=timeout)
    out_queue.join()
788
def _test_large_generation_multilingual(in_queue, out_queue, timeout):
    """Subprocess target: check whisper-large on a Japanese Common Voice sample in three modes —
    Japanese transcription, English transcription, and Japanese→English translation.

    Intended to be run via `run_test_in_subprocess`: `in_queue` carries the start signal,
    `out_queue` receives a dict with the traceback string (or None) under "error".
    """
    error = None
    try:
        _ = in_queue.get(timeout=timeout)

        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")

        # stream the dataset so only the first sample is downloaded; resample to Whisper's 16 kHz
        ds = load_dataset("common_voice", "ja", split="test", streaming=True)
        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
        input_speech = next(iter(ds))["audio"]["array"]
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " Kimura-san called me."
        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

        generated_ids = model.generate(
            input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
        )
        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
    except Exception:
        error = f"{traceback.format_exc()}"

    results = {"error": error}
    out_queue.put(results, timeout=timeout)
    out_queue.join()
833
def _test_large_batched_generation(in_queue, out_queue, timeout):
    """Subprocess target: batched greedy decoding of four LibriSpeech samples with whisper-large,
    checking both the raw token ids and the decoded transcripts.

    Intended to be run via `run_test_in_subprocess`: `in_queue` carries the start signal,
    `out_queue` receives a dict with the traceback string (or None) under "error".
    """
    error = None
    try:
        _ = in_queue.get(timeout=timeout)

        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")

        input_speech = _load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
        # generate in two batches of two — presumably to bound memory use; TODO confirm
        generated_ids_1 = model.generate(input_features[0:2], max_length=20)
        generated_ids_2 = model.generate(input_features[2:4], max_length=20)
        generated_ids = np.concatenate([generated_ids_1, generated_ids_2])

        # fmt: off
        EXPECTED_IDS = [
            [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
            [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
            [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
            [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
        ]
        # fmt: on

        unittest.TestCase().assertEqual(generated_ids.tolist(), EXPECTED_IDS)

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes and we are glad to",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast beef",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        unittest.TestCase().assertListEqual(transcript, EXPECTED_TRANSCRIPT)
    except Exception:
        error = f"{traceback.format_exc()}"

    results = {"error": error}
    out_queue.put(results, timeout=timeout)
    out_queue.join()
878
@require_tf
@require_tokenizers
class TFWhisperModelIntegrationTests(unittest.TestCase):
    """End-to-end integration tests running real Whisper checkpoints.

    Every test is ``@slow`` because it downloads a pretrained checkpoint; the
    ``large`` checkpoint tests are delegated to a subprocess via
    ``run_test_in_subprocess`` (module-level ``_test_large_*`` targets).
    """

    @cached_property
    def default_processor(self):
        # Lazily created, cached processor for tests that need a default one.
        return WhisperProcessor.from_pretrained("openai/whisper-base")

    def _load_datasamples(self, num_samples):
        # Delegate to the module-level helper so tests can call it as a method.
        return _load_datasamples(num_samples)

    @slow
    def test_tiny_logits_librispeech(self):
        # Compare whisper-tiny decoder outputs on a LibriSpeech sample against
        # precomputed reference values.
        set_seed(0)
        model = TFWhisperModel.from_pretrained("openai/whisper-tiny")
        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="tf").input_features

        logits = model(
            input_features,
            # Hard-coded prompt token ids (presumably the special-token prefix
            # for an English transcription task — TODO confirm against tokenizer).
            decoder_input_ids=tf.convert_to_tensor([[50258, 50259, 50359]]),
            output_hidden_states=False,
            output_attentions=False,
            return_dict=False,
            use_cache=False,
        )

        # fmt: off
        EXPECTED_LOGITS = tf.convert_to_tensor(
            [
                2.9892, -6.7607, 5.7348, 3.6096, 0.2152, -5.7321, 4.8855, -1.6407,
                0.2823, -1.5718, 10.4269, 3.4427, 0.0219, -8.0612, 3.4784, 8.4246,
                4.0575, -2.2864, 11.1084, 0.9963, 0.9884, -8.5154, -3.5469, -9.3713,
                0.9786, 3.5435, 7.4850, -5.2579, -1.4366, 10.4841
            ]
        )
        # fmt: on
        self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4))

        # fmt: off
        EXPECTED_GENERATION = tf.convert_to_tensor(
            [
                -1.4651, -2.6944, 2.7821, 2.3793, 4.0738, 0.0188, -3.3203, 1.9836,
                0.0520, 0.7095, 1.1063, 0.2952, -3.6786, -0.5249, 0.3105, 4.7691,
                1.1562, 1.3046, 0.5810, -0.3624, 1.7006, 1.3424, 0.9817, 2.1958,
                1.8775, -5.7046, -0.7679, 4.0113, 2.6848, 2.8609
            ]
        )
        # fmt: on

        # Project the decoder output onto the (tied) token embedding matrix to
        # obtain vocabulary logits, then check the first 30 entries.
        head_logits = logits[0] @ tf.transpose(model.model.decoder.embed_tokens.weights[0])
        self.assertTrue(np.allclose(head_logits[0, 0, :30], EXPECTED_GENERATION, atol=1e-4))

    @slow
    def test_small_en_logits_librispeech(self):
        # Compare whisper-small.en vocabulary logits on a LibriSpeech sample
        # against precomputed reference values.
        set_seed(0)
        model = TFWhisperModel.from_pretrained("openai/whisper-small.en")

        input_speech = self._load_datasamples(1)

        # Fixed typo: was `feaure_extractor`.
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="tf").input_features

        logits = model(
            input_features,
            decoder_input_ids=tf.convert_to_tensor([[model.config.decoder_start_token_id]]),
            output_hidden_states=False,
            output_attentions=False,
            use_cache=False,
        )

        # Project onto the token embedding matrix to get vocabulary logits.
        logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0])

        # fmt: off
        EXPECTED_LOGITS = tf.convert_to_tensor(
            [
                -3.6784, -7.7211, -9.5070, -11.9286, -7.6489, -9.7026, -5.6188,
                -8.0104, -4.6238, -5.1833, -9.0485, -3.4079, -5.4874, -2.6935,
                -6.3479, -7.3398, -6.9558, -7.6867, -7.4748, -8.3463, -9.9781,
                -10.8389, -10.3105, -11.7201, -9.7261, -7.1590, -5.9272, -12.4509,
                -11.1146, -8.1918
            ]
        )
        # fmt: on
        self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))

    @slow
    def test_large_logits_librispeech(self):
        # whisper-large is memory heavy; run in an isolated subprocess.
        run_test_in_subprocess(test_case=self, target_func=_test_large_logits_librispeech, inputs=None)

    @slow
    def test_tiny_en_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        model.config.decoder_start_token_id = 50257

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        generated_ids = model.generate(input_features, num_beams=5, max_length=20)
        transcript = processor.tokenizer.batch_decode(generated_ids)[0]

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes, and we are glad to"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_tiny_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        generated_ids = model.generate(input_features, num_beams=5, max_length=20)
        transcript = processor.tokenizer.decode(generated_ids[0])

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes and we are glad"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_tiny_xla_generation(self):
        # XLA-compiled generate must produce the same transcript as eager mode.
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        xla_generate = tf.function(model.generate, jit_compile=True)

        generated_ids = model.generate(input_features, num_beams=5, max_length=20)
        generated_ids_xla = xla_generate(input_features, num_beams=5, max_length=20)

        transcript = processor.tokenizer.decode(generated_ids[0])
        transcript_xla = processor.tokenizer.decode(generated_ids_xla[0])

        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes and we are glad"
        )
        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
        self.assertEqual(transcript_xla, EXPECTED_TRANSCRIPT)

    @slow
    def test_large_generation(self):
        run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None)

    @slow
    def test_large_generation_multilingual(self):
        run_test_in_subprocess(test_case=self, target_func=_test_large_generation_multilingual, inputs=None)

    @slow
    def test_large_batched_generation(self):
        run_test_in_subprocess(test_case=self, target_func=_test_large_batched_generation, inputs=None)

    @slow
    def test_tiny_en_batched_generation(self):
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
        generated_ids = model.generate(input_features, max_length=20)

        # fmt: off
        # Renamed from EXPECTED_LOGITS: these are generated token ids, not logits.
        EXPECTED_IDS = tf.convert_to_tensor(
            [
                [50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
                [50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
                [50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
                [50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
            ]
        )
        # fmt: on

        self.assertTrue(np.allclose(generated_ids, EXPECTED_IDS))

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes, and we are glad to",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast beef looming",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)

    @slow
    def test_tiny_en_batched_xla_generation(self):
        # Batched XLA generate must match both the reference ids and eager mode.
        set_seed(0)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

        xla_generate = tf.function(model.generate, jit_compile=True)

        generated_ids = model.generate(input_features, max_length=20)
        generated_ids_xla = xla_generate(input_features, max_length=20)

        # fmt: off
        # Renamed from EXPECTED_LOGITS: these are generated token ids, not logits.
        EXPECTED_IDS = tf.convert_to_tensor(
            [
                [50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
                [50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
                [50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
                [50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
            ]
        )
        # fmt: on

        self.assertTrue(np.allclose(generated_ids, EXPECTED_IDS))
        self.assertTrue(np.allclose(generated_ids_xla, EXPECTED_IDS))

        # fmt: off
        EXPECTED_TRANSCRIPT = [
            " Mr. Quilter is the apostle of the middle classes, and we are glad to",
            " Nor is Mr. Quilter's manner less interesting than his matter.",
            " He tells us that at this festive season of the year, with Christmas and roast beef looming",
            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
        ]
        # fmt: on

        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
        transcript_xla = processor.batch_decode(generated_ids_xla, skip_special_tokens=True)
        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
        self.assertListEqual(transcript_xla, EXPECTED_TRANSCRIPT)