# transformers — SpeechT5 modeling tests (original listing: 1886 lines, 76.9 KB)
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch SpeechT5 model. """

import copy
import inspect
import tempfile
import unittest

from transformers import SpeechT5Config, SpeechT5HifiGanConfig
from transformers.testing_utils import (
    is_torch_available,
    require_sentencepiece,
    require_tokenizers,
    require_torch,
    slow,
    torch_device,
)
from transformers.trainer_utils import set_seed
from transformers.utils import cached_property

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        SpeechT5ForSpeechToSpeech,
        SpeechT5ForSpeechToText,
        SpeechT5ForTextToSpeech,
        SpeechT5HifiGan,
        SpeechT5Model,
        SpeechT5Processor,
    )
57
def prepare_inputs_dict(
    config,
    input_ids=None,
    input_values=None,
    decoder_input_ids=None,
    decoder_input_values=None,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    """Assemble a model-forward kwargs dict for the SpeechT5 testers.

    Exactly one of `input_ids`/`input_values` (and of `decoder_input_ids`/
    `decoder_input_values`) ends up in the dict: the ids variant wins when it is
    not None. Any head mask left as None is filled with an all-ones tensor
    shaped from `config` (encoder layers/heads for `head_mask`, decoder
    layers/heads for the two decoder-side masks) on `torch_device`.
    """
    # Text-style inputs take precedence over speech-style inputs.
    if input_ids is not None:
        encoder_dict = {"input_ids": input_ids}
    else:
        encoder_dict = {"input_values": input_values}

    if decoder_input_ids is not None:
        decoder_dict = {"decoder_input_ids": decoder_input_ids}
    else:
        decoder_dict = {"decoder_input_values": decoder_input_values}

    # Default head masks: keep every attention head enabled (all ones).
    if head_mask is None:
        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
    if decoder_head_mask is None:
        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
    if cross_attn_head_mask is None:
        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)

    return {
        **encoder_dict,
        **decoder_dict,
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
    }
97
@require_torch
class SpeechT5ModelTester:
    """Builds tiny configs/inputs for the bare `SpeechT5Model` encoder-decoder tests."""

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=False,
        vocab_size=81,
        hidden_size=24,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=4,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

    def prepare_config_and_inputs(self):
        """Return `(config, inputs_dict)` with random float encoder/decoder values."""
        input_values = floats_tensor([self.batch_size, self.seq_length, self.hidden_size], scale=1.0)
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        decoder_input_values = floats_tensor([self.batch_size, self.seq_length, self.hidden_size], scale=1.0)
        decoder_attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = self.get_config()
        inputs_dict = prepare_inputs_dict(
            config,
            input_values=input_values,
            decoder_input_values=decoder_input_values,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_config(self):
        return SpeechT5Config(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        """Forward the model and check the last hidden state keeps (batch, seq, hidden)."""
        model = SpeechT5Model(config=config).to(torch_device).eval()

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]
        decoder_input_values = inputs_dict["decoder_input_values"]

        result = model(input_values, attention_mask=attention_mask, decoder_input_values=decoder_input_values)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
165
@require_torch
class SpeechT5ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common model tests for the bare `SpeechT5Model`."""

    all_model_classes = (SpeechT5Model,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"automatic-speech-recognition": SpeechT5ForSpeechToText, "feature-extraction": SpeechT5Model}
        if is_torch_available()
        else {}
    )
    is_encoder_decoder = True
    test_pruning = False
    test_headmasking = False
    test_resize_embeddings = False

    input_name = "input_values"

    def setUp(self):
        self.model_tester = SpeechT5ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_values",
                "attention_mask",
                "decoder_input_values",
                "decoder_attention_mask",
            ]
            # Fixed: `"a" and "b" and "c" in arg_names` only tested the last
            # membership (non-empty string literals are always truthy); check
            # all three head-mask arguments explicitly.
            expected_arg_names.extend(
                ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
                if all(mask in arg_names for mask in ("head_mask", "decoder_head_mask", "cross_attn_head_mask"))
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    # this model has no inputs_embeds
    def test_inputs_embeds(self):
        pass

    # this model has no input embeddings
    def test_model_common_attributes(self):
        pass

    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        pass

    @slow
    def test_torchscript_output_attentions(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_output_hidden_state(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_simple(self):
        # disabled because this model doesn't have decoder_input_ids
        pass
241
@require_torch
class SpeechT5ForSpeechToTextTester:
    """Builds tiny configs/inputs for `SpeechT5ForSpeechToText` (ASR) tests."""

    def __init__(
        self,
        parent,
        batch_size=13,
        encoder_seq_length=1024,  # speech is longer
        decoder_seq_length=7,
        is_training=False,
        hidden_size=24,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=4,
        conv_dim=(32, 32, 32),
        conv_stride=(4, 4, 4),
        conv_kernel=(8, 8, 8),
        conv_bias=False,
        num_conv_pos_embeddings=16,
        num_conv_pos_embedding_groups=2,
        vocab_size=81,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.encoder_seq_length = encoder_seq_length
        self.decoder_seq_length = decoder_seq_length
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.conv_dim = conv_dim
        self.conv_stride = conv_stride
        self.conv_kernel = conv_kernel
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.vocab_size = vocab_size

    def prepare_config_and_inputs(self):
        """Return `(config, inputs_dict)`: raw speech in, token ids on the decoder side."""
        input_values = floats_tensor([self.batch_size, self.encoder_seq_length], scale=1.0)
        attention_mask = random_attention_mask([self.batch_size, self.encoder_seq_length])

        # clamp(2) keeps decoder ids away from the reserved pad/bos ids 0 and 1
        decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size).clamp(2)
        decoder_attention_mask = random_attention_mask([self.batch_size, self.decoder_seq_length])

        config = self.get_config()
        inputs_dict = prepare_inputs_dict(
            config,
            input_values=input_values,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_config(self):
        return SpeechT5Config(
            hidden_size=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            conv_dim=self.conv_dim,
            conv_stride=self.conv_stride,
            conv_kernel=self.conv_kernel,
            conv_bias=self.conv_bias,
            num_conv_pos_embeddings=self.num_conv_pos_embeddings,
            num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
            vocab_size=self.vocab_size,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        """Forward the ASR model and check logits are (batch, decoder_seq, vocab)."""
        model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval()

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]
        decoder_input_ids = inputs_dict["decoder_input_ids"]

        result = model(input_values, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.decoder_seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        """Check that decoding with cached `past_key_values` matches a full re-decode."""
        model = SpeechT5ForSpeechToText(config=config).get_decoder().to(torch_device).eval()
        input_ids = inputs_dict["decoder_input_ids"]
        attention_mask = inputs_dict["decoder_attention_mask"]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next tokens and extend next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size).clamp(2)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append the new tokens to input_ids and the attention mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
362
@require_torch
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
    """Common model tests for `SpeechT5ForSpeechToText`.

    Several common tests are overridden here because the speech encoder prenet
    subsamples the input sequence (conv feature extractor), so attention /
    hidden-state shapes must use the subsampled encoder length.
    """

    all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
    all_generative_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_headmasking = False

    input_name = "input_values"

    def setUp(self):
        self.model_tester = SpeechT5ForSpeechToTextTester(self)
        self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            # the conv prenet subsamples the speech input, so encoder attention
            # shapes are computed from the subsampled lengths
            subsampled_encoder_seq_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(
                encoder_seq_length
            )
            subsampled_encoder_key_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(
                encoder_key_length
            )

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )
            out_len = len(outputs)

            correct_outlen = 5

            # loss is at first position
            if "labels" in inputs_dict:
                correct_outlen += 1  # loss is added to beginning
            if "past_key_values" in outputs:
                correct_outlen += 1  # past_key_values have been returned

            self.assertEqual(out_len, correct_outlen)

            # decoder attentions
            decoder_attentions = outputs.decoder_attentions
            self.assertIsInstance(decoder_attentions, (list, tuple))
            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(decoder_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
            )

            # cross attentions
            cross_attentions = outputs.cross_attentions
            self.assertIsInstance(cross_attentions, (list, tuple))
            self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(cross_attentions[0].shape[-3:]),
                [
                    self.model_tester.num_attention_heads,
                    decoder_seq_length,
                    subsampled_encoder_key_length,
                ],
            )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_values",
                "attention_mask",
                "decoder_input_ids",
                "decoder_attention_mask",
            ]
            # Fixed: `"a" and "b" and "c" in arg_names` only tested the last
            # membership (non-empty string literals are always truthy); check
            # all three head-mask arguments explicitly.
            expected_arg_names.extend(
                ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
                if all(mask in arg_names for mask in ("head_mask", "decoder_head_mask", "cross_attn_head_mask"))
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
            else:
                seq_length = self.model_tester.seq_length

            # encoder hidden states use the conv-subsampled sequence length
            subsampled_seq_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(seq_length)

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [subsampled_seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)
                seq_len = getattr(self.model_tester, "seq_length", None)
                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                # these parameters use uniform (not zero-centered) initialization
                uniform_init_parms = [
                    "conv.weight",
                    "conv.parametrizations.weight",
                    "masked_spec_embed",
                    "feature_projection.projection.weight",
                    "feature_projection.projection.bias",
                ]
                if param.requires_grad:
                    if any(x in name for x in uniform_init_parms):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    # this model has no inputs_embeds
    def test_inputs_embeds(self):
        pass

    def test_resize_embeddings_untied(self):
        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        original_config.tie_word_embeddings = False

        # if model cannot untied embeddings -> leave test
        if original_config.tie_word_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config).to(torch_device)

            # if no output embeddings -> leave test
            if model.get_output_embeddings() is None:
                continue

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_vocab_size = config.vocab_size
            model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

    def test_resize_tokens_embeddings(self):
        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            model_vocab_size = config.vocab_size
            # Retrieve the embeddings and clone them
            model_embed = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)

            # make sure that decoder_input_ids are resized
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = True
            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        pass

    # training is not supported yet
    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
        if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
            module.masked_spec_embed.data.fill_(3)
730
@require_torch
@require_sentencepiece
@require_tokenizers
@slow
class SpeechT5ForSpeechToTextIntegrationTests(unittest.TestCase):
    """Slow integration tests running the pretrained `microsoft/speecht5_asr` checkpoint
    against LibriSpeech dummy samples and comparing full transcripts."""

    @cached_property
    def default_processor(self):
        return SpeechT5Processor.from_pretrained("microsoft/speecht5_asr")

    def _load_datasamples(self, num_samples):
        """Load the first `num_samples` audio arrays from the LibriSpeech dummy split."""
        from datasets import load_dataset

        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]

    def test_generation_librispeech(self):
        model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr")
        model.to(torch_device)
        processor = self.default_processor

        input_speech = self._load_datasamples(1)

        input_values = processor(audio=input_speech, return_tensors="pt").input_values.to(torch_device)

        generated_ids = model.generate(input_values)
        generated_transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)

        EXPECTED_TRANSCRIPTIONS = [
            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
        ]
        self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS)

    def test_generation_librispeech_batched(self):
        model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr")
        model.to(torch_device)
        processor = self.default_processor

        input_speech = self._load_datasamples(4)

        inputs = processor(audio=input_speech, return_tensors="pt", padding=True)

        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        generated_ids = model.generate(input_values, attention_mask=attention_mask)
        generated_transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True)

        # NOTE: adjacent string literals below are intentionally concatenated
        # into single long expected transcripts.
        EXPECTED_TRANSCRIPTIONS = [
            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
            "nor is mister quilter's manner less interesting than his matter",
            "he tells us that at this festive season of the year with christmas and rosebeaf looming before us"
            " similars drawn from eating and its results occur most readily to the mind",
            "he has grave doubts whether sir frederick latin's work is really greek after all and can discover in it"
            " but little of rocky ithica",
        ]
        self.assertListEqual(generated_transcripts, EXPECTED_TRANSCRIPTIONS)
791
@require_torch
class SpeechT5ForTextToSpeechTester:
    """Builds tiny configs/inputs for `SpeechT5ForTextToSpeech` (TTS) tests."""

    def __init__(
        self,
        parent,
        batch_size=13,
        encoder_seq_length=7,
        decoder_seq_length=1024,  # speech is longer
        is_training=False,
        hidden_size=24,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=4,
        vocab_size=81,
        num_mel_bins=20,
        reduction_factor=2,
        speech_decoder_postnet_layers=2,
        speech_decoder_postnet_units=32,
        speech_decoder_prenet_units=32,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.encoder_seq_length = encoder_seq_length
        self.decoder_seq_length = decoder_seq_length
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.vocab_size = vocab_size
        self.num_mel_bins = num_mel_bins
        self.reduction_factor = reduction_factor
        self.speech_decoder_postnet_layers = speech_decoder_postnet_layers
        self.speech_decoder_postnet_units = speech_decoder_postnet_units
        self.speech_decoder_prenet_units = speech_decoder_prenet_units

    def prepare_config_and_inputs(self):
        """Return `(config, inputs_dict)`: token ids in, mel spectrogram frames on the decoder side."""
        # clamp(2) keeps input ids away from the reserved pad/bos ids 0 and 1
        input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size).clamp(2)
        attention_mask = random_attention_mask([self.batch_size, self.encoder_seq_length])

        decoder_input_values = floats_tensor([self.batch_size, self.decoder_seq_length, self.num_mel_bins], scale=1.0)
        decoder_attention_mask = random_attention_mask([self.batch_size, self.decoder_seq_length])

        config = self.get_config()
        inputs_dict = prepare_inputs_dict(
            config,
            input_ids=input_ids,
            decoder_input_values=decoder_input_values,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_config(self):
        return SpeechT5Config(
            hidden_size=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            vocab_size=self.vocab_size,
            num_mel_bins=self.num_mel_bins,
            reduction_factor=self.reduction_factor,
            speech_decoder_postnet_layers=self.speech_decoder_postnet_layers,
            speech_decoder_postnet_units=self.speech_decoder_postnet_units,
            speech_decoder_prenet_units=self.speech_decoder_prenet_units,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        """Forward the TTS model and check the spectrogram shape.

        The decoder emits `reduction_factor` frames per step, so the output
        spectrogram has `decoder_seq_length * reduction_factor` frames.
        """
        model = SpeechT5ForTextToSpeech(config=config).to(torch_device).eval()

        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]
        decoder_input_values = inputs_dict["decoder_input_values"]

        result = model(input_ids, attention_mask=attention_mask, decoder_input_values=decoder_input_values)
        self.parent.assertEqual(
            result.spectrogram.shape,
            (self.batch_size, self.decoder_seq_length * self.reduction_factor, self.num_mel_bins),
        )
879
@require_torch
class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
    """Common model tests for `SpeechT5ForTextToSpeech` (text encoder -> speech decoder)."""

    all_model_classes = (SpeechT5ForTextToSpeech,) if is_torch_available() else ()
    all_generative_model_classes = (SpeechT5ForTextToSpeech,) if is_torch_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_headmasking = False

    input_name = "input_ids"

    def setUp(self):
        self.model_tester = SpeechT5ForTextToSpeechTester(self)
        self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        """Round-trip save/load must not leave any missing keys."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_decoder_model_past_with_large_inputs(self):
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_determinism(self):
        pass

    def test_forward_signature(self):
        """Check that `forward` exposes the expected leading argument names."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_ids",
                "attention_mask",
                "decoder_input_values",
                "decoder_attention_mask",
            ]
            # Fix: the previous condition `"head_mask" and "decoder_head_mask" and
            # "cross_attn_head_mask" in arg_names` only tested the last membership,
            # because non-empty string literals are always truthy. Test all three names.
            head_mask_names = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
            expected_arg_names.extend(
                head_mask_names + ["encoder_outputs"]
                if all(name in arg_names for name in head_mask_names)
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_initialization(self):
        """Parameters must be zero-initialized under `_config_zero_init`, except uniform-init convs."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                uniform_init_parms = [
                    "conv.weight",
                ]
                if param.requires_grad:
                    if any(x in name for x in uniform_init_parms):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    # this model has no inputs_embeds
    def test_inputs_embeds(self):
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_model_outputs_equivalence(self):
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_save_load(self):
        pass

    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        pass

    @slow
    def test_torchscript_output_attentions(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_output_hidden_state(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_simple(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    # training is not supported yet
    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
1025
@require_torch
@require_sentencepiece
@require_tokenizers
class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
    """Integration tests running the pretrained `microsoft/speecht5_tts` checkpoint end to end.

    All generation calls reseed with `set_seed(555)` first, because the speech decoder
    prenet applies dropout even in eval mode, so outputs are only reproducible per seed.
    """

    @cached_property
    def default_model(self):
        # Pretrained TTS model, moved to the test device once and cached.
        return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(torch_device)

    @cached_property
    def default_processor(self):
        # Tokenizer/feature-extractor pair matching the TTS checkpoint.
        return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")

    @cached_property
    def default_vocoder(self):
        # HiFi-GAN vocoder used to turn spectrograms into waveforms.
        return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(torch_device)

    def test_generation(self):
        """Single-sentence generation: check mel-bin count and `generate` vs `generate_speech` parity."""
        model = self.default_model
        processor = self.default_processor

        input_text = "Mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
        input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
        speaker_embeddings = torch.zeros((1, 512), device=torch_device)

        # Generate speech and validate output dimensions
        set_seed(555)  # Ensure deterministic behavior
        generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
        num_mel_bins = model.config.num_mel_bins
        self.assertEqual(
            generated_speech.shape[1], num_mel_bins, "Generated speech output has an unexpected number of mel bins."
        )

        # Validate generation with additional kwargs using model.generate;
        # same method as `generate_speech`
        set_seed(555)  # Reset seed for consistent results
        generated_speech_with_generate = model.generate(
            input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
        )
        self.assertEqual(
            generated_speech_with_generate.shape,
            generated_speech.shape,
            "Shape mismatch between generate_speech and generate methods.",
        )

    def test_one_to_many_generation(self):
        """Batched generation with a single shared speaker embedding, checked against per-sample generation."""
        model = self.default_model
        processor = self.default_processor
        vocoder = self.default_vocoder

        input_text = [
            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
            "nor is mister quilter's manner less interesting than his matter",
            "he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
        ]
        inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
        speaker_embeddings = torch.zeros((1, 512), device=torch_device)

        # Generate spectrograms
        set_seed(555)  # Ensure deterministic behavior
        spectrograms, spectrogram_lengths = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            return_output_lengths=True,
        )

        # Validate generated spectrogram dimensions
        expected_batch_size = len(input_text)
        num_mel_bins = model.config.num_mel_bins
        actual_batch_size, _, actual_num_mel_bins = spectrograms.shape
        self.assertEqual(actual_batch_size, expected_batch_size, "Batch size of generated spectrograms is incorrect.")
        self.assertEqual(
            actual_num_mel_bins, num_mel_bins, "Number of mel bins in batch generated spectrograms is incorrect."
        )

        # Generate waveforms using the vocoder
        waveforms = vocoder(spectrograms)
        # Per-sample waveform length scales linearly with spectrogram length:
        # (samples per spectrogram frame of the longest item) * item length.
        waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]

        # Validate generation with integrated vocoder
        set_seed(555)  # Reset seed for consistent results
        waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            vocoder=vocoder,
            return_output_lengths=True,
        )

        # Check consistency between waveforms generated with and without standalone vocoder
        self.assertTrue(
            torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8),
            "Mismatch in waveforms generated with and without the standalone vocoder.",
        )
        self.assertEqual(
            waveform_lengths,
            waveform_lengths_with_vocoder,
            "Waveform lengths differ between standalone and integrated vocoder generation.",
        )

        # Test generation consistency without returning lengths
        set_seed(555)  # Reset seed for consistent results
        waveforms_with_vocoder_no_lengths = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            vocoder=vocoder,
            return_output_lengths=False,
        )

        # Validate waveform consistency without length information
        self.assertTrue(
            torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8),
            "Waveforms differ when generated with and without length information.",
        )

        # Validate batch vs. single instance generation consistency
        for i, text in enumerate(input_text):
            inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
            set_seed(555)  # Reset seed for consistent results
            spectrogram = model.generate_speech(
                input_ids=inputs["input_ids"],
                speaker_embeddings=speaker_embeddings,
            )

            # Check spectrogram shape consistency
            self.assertEqual(
                spectrogram.shape,
                spectrograms[i][: spectrogram_lengths[i]].shape,
                "Mismatch in spectrogram shape between batch and single instance generation.",
            )

            # Generate and validate waveform for single instance
            waveform = vocoder(spectrogram)
            self.assertEqual(
                waveform.shape,
                waveforms[i][: waveform_lengths[i]].shape,
                "Mismatch in waveform shape between batch and single instance generation.",
            )

            # Check waveform consistency with integrated vocoder
            set_seed(555)  # Reset seed for consistent results
            waveform_with_integrated_vocoder = model.generate_speech(
                input_ids=inputs["input_ids"],
                speaker_embeddings=speaker_embeddings,
                vocoder=vocoder,
            )
            self.assertTrue(
                torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8),
                "Mismatch in waveform between standalone and integrated vocoder for single instance generation.",
            )

    def test_batch_generation(self):
        """Batched generation with one distinct (random) speaker embedding per input sentence."""
        model = self.default_model
        processor = self.default_processor
        vocoder = self.default_vocoder

        input_text = [
            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
            "nor is mister quilter's manner less interesting than his matter",
            "he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
        ]
        inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
        set_seed(555)  # Ensure deterministic behavior
        speaker_embeddings = torch.randn((len(input_text), 512), device=torch_device)

        # Generate spectrograms
        set_seed(555)  # Reset seed for consistent results
        spectrograms, spectrogram_lengths = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            return_output_lengths=True,
        )

        # Validate generated spectrogram dimensions
        expected_batch_size = len(input_text)
        num_mel_bins = model.config.num_mel_bins
        actual_batch_size, _, actual_num_mel_bins = spectrograms.shape
        self.assertEqual(
            actual_batch_size,
            expected_batch_size,
            "Batch size of generated spectrograms is incorrect.",
        )
        self.assertEqual(
            actual_num_mel_bins,
            num_mel_bins,
            "Number of mel bins in batch generated spectrograms is incorrect.",
        )

        # Generate waveforms using the vocoder
        waveforms = vocoder(spectrograms)
        # Per-sample waveform length scales linearly with spectrogram length (see above).
        waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]

        # Validate generation with integrated vocoder
        set_seed(555)  # Reset seed for consistent results
        waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            vocoder=vocoder,
            return_output_lengths=True,
        )

        # Check consistency between waveforms generated with and without standalone vocoder
        self.assertTrue(
            torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8),
            "Mismatch in waveforms generated with and without the standalone vocoder.",
        )
        self.assertEqual(
            waveform_lengths,
            waveform_lengths_with_vocoder,
            "Waveform lengths differ between standalone and integrated vocoder generation.",
        )

        # Test generation consistency without returning lengths
        set_seed(555)  # Reset seed for consistent results
        waveforms_with_vocoder_no_lengths = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            vocoder=vocoder,
            return_output_lengths=False,
        )

        # Validate waveform consistency without length information
        self.assertTrue(
            torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8),
            "Waveforms differ when generated with and without length information.",
        )

        # Validate batch vs. single instance generation consistency
        for i, text in enumerate(input_text):
            inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
            current_speaker_embedding = speaker_embeddings[i].unsqueeze(0)
            set_seed(555)  # Reset seed for consistent results
            spectrogram = model.generate_speech(
                input_ids=inputs["input_ids"],
                speaker_embeddings=current_speaker_embedding,
            )

            # Check spectrogram shape consistency
            self.assertEqual(
                spectrogram.shape,
                spectrograms[i][: spectrogram_lengths[i]].shape,
                "Mismatch in spectrogram shape between batch and single instance generation.",
            )

            # Generate and validate waveform for single instance
            waveform = vocoder(spectrogram)
            self.assertEqual(
                waveform.shape,
                waveforms[i][: waveform_lengths[i]].shape,
                "Mismatch in waveform shape between batch and single instance generation.",
            )

            # Check waveform consistency with integrated vocoder
            set_seed(555)  # Reset seed for consistent results
            waveform_with_integrated_vocoder = model.generate_speech(
                input_ids=inputs["input_ids"],
                speaker_embeddings=current_speaker_embedding,
                vocoder=vocoder,
            )
            self.assertTrue(
                torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8),
                "Mismatch in waveform between standalone and integrated vocoder for single instance generation.",
            )
1294
@require_torch
class SpeechT5ForSpeechToSpeechTester:
    """Generates small configs and random speech inputs for `SpeechT5ForSpeechToSpeech` tests."""

    def __init__(
        self,
        parent,
        batch_size=13,
        encoder_seq_length=1024,  # speech is longer
        decoder_seq_length=1024,
        is_training=False,
        hidden_size=24,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=4,
        conv_dim=(32, 32, 32),
        conv_stride=(4, 4, 4),
        conv_kernel=(8, 8, 8),
        conv_bias=False,
        num_conv_pos_embeddings=16,
        num_conv_pos_embedding_groups=2,
        vocab_size=81,
        num_mel_bins=20,
        reduction_factor=2,
        speech_decoder_postnet_layers=2,
        speech_decoder_postnet_units=32,
        speech_decoder_prenet_units=32,
    ):
        # `parent` is the unittest.TestCase that owns this tester (used for its assertions).
        self.parent = parent
        self.batch_size = batch_size
        self.encoder_seq_length = encoder_seq_length
        self.decoder_seq_length = decoder_seq_length
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.conv_dim = conv_dim
        self.conv_stride = conv_stride
        self.conv_kernel = conv_kernel
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.vocab_size = vocab_size
        self.num_mel_bins = num_mel_bins
        self.reduction_factor = reduction_factor
        self.speech_decoder_postnet_layers = speech_decoder_postnet_layers
        self.speech_decoder_postnet_units = speech_decoder_postnet_units
        self.speech_decoder_prenet_units = speech_decoder_prenet_units

    def prepare_config_and_inputs(self):
        """Return a config and an inputs dict with random waveforms and mel decoder targets."""
        input_values = floats_tensor([self.batch_size, self.encoder_seq_length], scale=1.0)
        attention_mask = random_attention_mask([self.batch_size, self.encoder_seq_length])

        decoder_input_values = floats_tensor([self.batch_size, self.decoder_seq_length, self.num_mel_bins], scale=1.0)
        decoder_attention_mask = random_attention_mask([self.batch_size, self.decoder_seq_length])

        config = self.get_config()
        inputs_dict = prepare_inputs_dict(
            config,
            input_values=input_values,
            decoder_input_values=decoder_input_values,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        """Alias required by the common model test mixin."""
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_config(self):
        """Build a small `SpeechT5Config` from the tester's hyper-parameters."""
        return SpeechT5Config(
            hidden_size=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            conv_dim=self.conv_dim,
            conv_stride=self.conv_stride,
            conv_kernel=self.conv_kernel,
            conv_bias=self.conv_bias,
            num_conv_pos_embeddings=self.num_conv_pos_embeddings,
            num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
            vocab_size=self.vocab_size,
            num_mel_bins=self.num_mel_bins,
            reduction_factor=self.reduction_factor,
            speech_decoder_postnet_layers=self.speech_decoder_postnet_layers,
            speech_decoder_postnet_units=self.speech_decoder_postnet_units,
            speech_decoder_prenet_units=self.speech_decoder_prenet_units,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        """Forward pass; spectrogram length is decoder_seq_length * reduction_factor."""
        model = SpeechT5ForSpeechToSpeech(config=config).to(torch_device).eval()

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]
        decoder_input_values = inputs_dict["decoder_input_values"]

        result = model(input_values, attention_mask=attention_mask, decoder_input_values=decoder_input_values)
        self.parent.assertEqual(
            result.spectrogram.shape,
            (self.batch_size, self.decoder_seq_length * self.reduction_factor, self.num_mel_bins),
        )
1400
@require_torch
class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
    """Common model tests for `SpeechT5ForSpeechToSpeech` (speech encoder -> speech decoder)."""

    all_model_classes = (SpeechT5ForSpeechToSpeech,) if is_torch_available() else ()
    all_generative_model_classes = (SpeechT5ForSpeechToSpeech,) if is_torch_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_headmasking = False
    test_resize_embeddings = False

    input_name = "input_values"

    def setUp(self):
        self.model_tester = SpeechT5ForSpeechToSpeechTester(self)
        self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        """Round-trip save/load must not leave any missing keys."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_decoder_model_past_with_large_inputs(self):
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_determinism(self):
        pass

    def test_attention_outputs(self):
        """Check attention tensor counts/shapes; encoder lengths are subsampled by the conv prenet."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            # The feature-extractor prenet downsamples the raw waveform length.
            subsampled_encoder_seq_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(
                encoder_seq_length
            )
            subsampled_encoder_key_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(
                encoder_key_length
            )

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )
            out_len = len(outputs)

            correct_outlen = 5

            # loss is at first position
            if "labels" in inputs_dict:
                correct_outlen += 1  # loss is added to beginning
            if "past_key_values" in outputs:
                correct_outlen += 1  # past_key_values have been returned

            self.assertEqual(out_len, correct_outlen)

            # decoder attentions
            decoder_attentions = outputs.decoder_attentions
            self.assertIsInstance(decoder_attentions, (list, tuple))
            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(decoder_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
            )

            # cross attentions
            cross_attentions = outputs.cross_attentions
            self.assertIsInstance(cross_attentions, (list, tuple))
            self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(cross_attentions[0].shape[-3:]),
                [
                    self.model_tester.num_attention_heads,
                    decoder_seq_length,
                    subsampled_encoder_key_length,
                ],
            )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )

    def test_forward_signature(self):
        """Check that `forward` exposes the expected leading argument names."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_values",
                "attention_mask",
                "decoder_input_values",
                "decoder_attention_mask",
            ]
            # Fix: the previous condition `"head_mask" and "decoder_head_mask" and
            # "cross_attn_head_mask" in arg_names` only tested the last membership,
            # because non-empty string literals are always truthy. Test all three names.
            head_mask_names = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
            expected_arg_names.extend(
                head_mask_names + ["encoder_outputs"]
                if all(name in arg_names for name in head_mask_names)
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_hidden_states_output(self):
        """Check hidden-state counts/shapes for encoder (subsampled) and decoder."""

        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
            else:
                seq_length = self.model_tester.seq_length

            # The conv feature extractor shortens the encoder sequence length.
            subsampled_seq_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(seq_length)

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [subsampled_seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)
                seq_len = getattr(self.model_tester, "seq_length", None)
                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_initialization(self):
        """Parameters must be zero-initialized under `_config_zero_init`, except uniform-init ones."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                uniform_init_parms = [
                    "conv.weight",
                    "conv.parametrizations.weight",
                    "masked_spec_embed",
                    "feature_projection.projection.weight",
                    "feature_projection.projection.bias",
                ]
                if param.requires_grad:
                    if any(x in name for x in uniform_init_parms):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    # this model has no inputs_embeds
    def test_inputs_embeds(self):
        pass

    # this model has no input embeddings
    def test_model_common_attributes(self):
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_model_outputs_equivalence(self):
        pass

    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_save_load(self):
        pass

    @slow
    def test_torchscript_output_attentions(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_output_hidden_state(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_simple(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    # training is not supported yet
    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
        if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
            module.masked_spec_embed.data.fill_(3)
1709
@require_torch
@require_sentencepiece
@require_tokenizers
@slow
class SpeechT5ForSpeechToSpeechIntegrationTests(unittest.TestCase):
    """Slow integration tests running the pretrained voice-conversion checkpoint end to end."""

    @cached_property
    def default_processor(self):
        return SpeechT5Processor.from_pretrained("microsoft/speecht5_vc")

    def _load_datasamples(self, num_samples):
        """Return the first `num_samples` decoded audio arrays from the dummy LibriSpeech split."""
        from datasets import load_dataset

        dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        audio_samples = dataset.sort("id").select(range(num_samples))[:num_samples]["audio"]
        return [sample["array"] for sample in audio_samples]

    def test_generation_librispeech(self):
        """Convert one LibriSpeech sample and check the generated spectrogram dimensions."""
        processor = self.default_processor
        model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc")
        model.to(torch_device)

        audio = self._load_datasamples(1)
        input_values = processor(audio=audio, return_tensors="pt").input_values.to(torch_device)

        # All-zero speaker embedding of the expected (1, 512) size.
        speaker_embeddings = torch.zeros((1, 512), device=torch_device)
        spectrogram = model.generate_speech(input_values, speaker_embeddings=speaker_embeddings)

        self.assertEqual(spectrogram.shape[1], model.config.num_mel_bins)
        self.assertGreaterEqual(spectrogram.shape[0], 300)
        self.assertLessEqual(spectrogram.shape[0], 310)
1743
class SpeechT5HifiGanTester:
    """Produces tiny configs and random spectrogram inputs for the SpeechT5
    HiFi-GAN vocoder unit tests."""

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=False,
        num_mel_bins=20,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.num_mel_bins = num_mel_bins

    def prepare_config_and_inputs(self):
        # A single un-batched spectrogram of shape (seq_length, num_mel_bins).
        spectrogram = floats_tensor([self.seq_length, self.num_mel_bins], scale=1.0)
        return self.get_config(), spectrogram

    def get_config(self):
        return SpeechT5HifiGanConfig(
            model_in_dim=self.num_mel_bins,
            upsample_initial_channel=32,
        )

    def create_and_check_model(self, config, input_values):
        vocoder = SpeechT5HifiGan(config=config).to(torch_device).eval()
        waveform = vocoder(input_values)
        # The vocoder upsamples each spectrogram frame into 256 audio samples.
        self.parent.assertEqual(waveform.shape, (self.seq_length * 256,))

    def prepare_config_and_inputs_for_common(self):
        config, spectrogram = self.prepare_config_and_inputs()
        return config, {"spectrogram": spectrogram}
1780
@require_torch
class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
    """Unit tests for the SpeechT5 HiFi-GAN vocoder, with the common-test
    features it does not support switched off."""

    all_model_classes = (SpeechT5HifiGan,) if is_torch_available() else ()
    test_torchscript = False
    test_pruning = False
    test_resize_embeddings = False
    test_resize_position_embeddings = False
    test_head_masking = False
    test_mismatched_shapes = False
    test_missing_keys = False
    test_model_parallel = False
    is_encoder_decoder = False
    has_attentions = False

    input_name = "spectrogram"

    def setUp(self):
        self.config_tester = ConfigTester(self, config_class=SpeechT5HifiGanConfig)
        self.model_tester = SpeechT5HifiGanTester(self)

    def test_config(self):
        # Run every applicable config round-trip/initialization check in order.
        checks = (
            self.config_tester.create_and_test_config_to_json_string,
            self.config_tester.create_and_test_config_to_json_file,
            self.config_tester.create_and_test_config_from_and_save_pretrained,
            self.config_tester.create_and_test_config_from_and_save_pretrained_subfolder,
            self.config_tester.create_and_test_config_with_num_labels,
            self.config_tester.check_config_can_be_init_without_params,
            self.config_tester.check_config_arguments_init,
        )
        for check in checks:
            check()

    def test_model(self):
        config, input_values = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(config, input_values)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            params = inspect.signature(model.forward).parameters
            # signature.parameters is ordered, so the leading arg names are deterministic
            arg_names = list(params)

            expected_arg_names = [
                "spectrogram",
            ]
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    # the vocoder does not expose hidden states
    def test_hidden_states_output(self):
        pass

    # skip: custom init scheme is not covered by the common initialization test
    def test_initialization(self):
        pass

    # the vocoder takes spectrograms, not embedded inputs
    def test_inputs_embeds(self):
        pass

    # the vocoder has no input embedding layer to query
    def test_model_common_attributes(self):
        pass

    # skip as this model doesn't support all arguments tested
    def test_model_outputs_equivalence(self):
        pass

    # the vocoder does not expose hidden states
    def test_retain_grad_hidden_states_attentions(self):
        pass

    # skip because it fails on automapping of SpeechT5HifiGanConfig
    def test_save_load_fast_init_from_base(self):
        pass

    # skip because it fails on automapping of SpeechT5HifiGanConfig
    def test_save_load_fast_init_to_base(self):
        pass

    def test_batched_inputs_outputs(self):
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            # Stack the single spectrogram into a batch of two identical items.
            stacked = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1)
            with torch.no_grad():
                stacked_out = model(stacked.to(torch_device))

            self.assertEqual(
                stacked.shape[0], stacked_out.shape[0], msg="Got different batch dims for input and output"
            )

    def test_unbatched_inputs_outputs(self):
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                waveform = model(inputs["spectrogram"].to(torch_device))
            self.assertTrue(waveform.dim() == 1, msg="Got un-batched inputs but batched output")