# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch SpeechT5 model. """

import copy
import inspect
import tempfile
import unittest

from transformers import SpeechT5Config, SpeechT5HifiGanConfig
from transformers.testing_utils import (
    is_torch_available,
    require_sentencepiece,
    require_tokenizers,
    require_torch,
    slow,
    torch_device,
)
from transformers.trainer_utils import set_seed
from transformers.utils import cached_property

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        SpeechT5ForSpeechToSpeech,
        SpeechT5ForSpeechToText,
        SpeechT5ForTextToSpeech,
        SpeechT5HifiGan,
        SpeechT5Model,
        SpeechT5Processor,
    )


def prepare_inputs_dict(
    config,
    input_ids=None,
    input_values=None,
    decoder_input_ids=None,
    decoder_input_values=None,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
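    """Build a kwargs dict for an encoder-decoder forward pass.

    Depending on the calling tester, the encoder side is fed either text
    (`input_ids`) or speech (`input_values`), and likewise the decoder side is
    fed `decoder_input_ids` or `decoder_input_values`. Any head masks that are
    not supplied are filled with all-ones tensors, i.e. no heads are masked.
    """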
    if input_ids is not None:
        encoder_dict = {"input_ids": input_ids}
    else:
        encoder_dict = {"input_values": input_values}

    if decoder_input_ids is not None:
        decoder_dict = {"decoder_input_ids": decoder_input_ids}
    else:
        decoder_dict = {"decoder_input_values": decoder_input_values}

    if head_mask is None:
        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
    if decoder_head_mask is None:
        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
    if cross_attn_head_mask is None:
        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)

    return {
        **encoder_dict,
        **decoder_dict,
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
    }


@require_torch
class SpeechT5ModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=False,
        vocab_size=81,
        hidden_size=24,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=4,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

    def prepare_config_and_inputs(self):
        input_values = floats_tensor([self.batch_size, self.seq_length, self.hidden_size], scale=1.0)
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        decoder_input_values = floats_tensor([self.batch_size, self.seq_length, self.hidden_size], scale=1.0)
        decoder_attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = self.get_config()
        inputs_dict = prepare_inputs_dict(
            config,
            input_values=input_values,
            decoder_input_values=decoder_input_values,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_config(self):
        return SpeechT5Config(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        model = SpeechT5Model(config=config).to(torch_device).eval()

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]
        decoder_input_values = inputs_dict["decoder_input_values"]

        result = model(input_values, attention_mask=attention_mask, decoder_input_values=decoder_input_values)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))


@require_torch
class SpeechT5ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (SpeechT5Model,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"automatic-speech-recognition": SpeechT5ForSpeechToText, "feature-extraction": SpeechT5Model}
        if is_torch_available()
        else {}
    )
    is_encoder_decoder = True
    test_pruning = False
    test_headmasking = False
    test_resize_embeddings = False

    input_name = "input_values"

    def setUp(self):
        self.model_tester = SpeechT5ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_values",
                "attention_mask",
                "decoder_input_values",
                "decoder_attention_mask",
            ]
            expected_arg_names.extend(
                ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
                if all(mask in arg_names for mask in ("head_mask", "decoder_head_mask", "cross_attn_head_mask"))
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    # this model has no inputs_embeds
    def test_inputs_embeds(self):
        pass

    # this model has no input embeddings
    def test_model_common_attributes(self):
        pass

    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        pass

    @slow
    def test_torchscript_output_attentions(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_output_hidden_state(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_simple(self):
        # disabled because this model doesn't have decoder_input_ids
        pass


@require_torch
class SpeechT5ForSpeechToTextTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        encoder_seq_length=1024,  # speech is longer
        decoder_seq_length=7,
        is_training=False,
        hidden_size=24,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=4,
        conv_dim=(32, 32, 32),
        conv_stride=(4, 4, 4),
        conv_kernel=(8, 8, 8),
        conv_bias=False,
        num_conv_pos_embeddings=16,
        num_conv_pos_embedding_groups=2,
        vocab_size=81,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.encoder_seq_length = encoder_seq_length
        self.decoder_seq_length = decoder_seq_length
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.conv_dim = conv_dim
        self.conv_stride = conv_stride
        self.conv_kernel = conv_kernel
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.vocab_size = vocab_size

    def prepare_config_and_inputs(self):
        input_values = floats_tensor([self.batch_size, self.encoder_seq_length], scale=1.0)
        attention_mask = random_attention_mask([self.batch_size, self.encoder_seq_length])

        decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size).clamp(2)
        decoder_attention_mask = random_attention_mask([self.batch_size, self.decoder_seq_length])

        config = self.get_config()
        inputs_dict = prepare_inputs_dict(
            config,
            input_values=input_values,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_config(self):
        return SpeechT5Config(
            hidden_size=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            conv_dim=self.conv_dim,
            conv_stride=self.conv_stride,
            conv_kernel=self.conv_kernel,
            conv_bias=self.conv_bias,
            num_conv_pos_embeddings=self.num_conv_pos_embeddings,
            num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
            vocab_size=self.vocab_size,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        model = SpeechT5ForSpeechToText(config=config).to(torch_device).eval()

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]
        decoder_input_ids = inputs_dict["decoder_input_ids"]

        result = model(input_values, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.decoder_seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        model = SpeechT5ForSpeechToText(config=config).get_decoder().to(torch_device).eval()
        input_ids = inputs_dict["decoder_input_ids"]
        attention_mask = inputs_dict["decoder_attention_mask"]
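
        # Sanity check for incremental decoding: a forward pass that reuses the
        # cached past_key_values for the already-seen tokens must reproduce the
        # hidden states of a full forward pass over the concatenated sequence.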
        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create a few hypothetical next tokens and extend next_input_ids with them
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size).clamp(2)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append the new tokens to input_ids and the attention mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))


@require_torch
class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
    all_generative_model_classes = (SpeechT5ForSpeechToText,) if is_torch_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_headmasking = False

    input_name = "input_values"

    def setUp(self):
        self.model_tester = SpeechT5ForSpeechToTextTester(self)
        self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            subsampled_encoder_seq_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(
                encoder_seq_length
            )
            subsampled_encoder_key_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(
                encoder_key_length
            )
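            # The speech encoder prenet downsamples the raw waveform with its conv
            # feature extractor, so attention shapes are checked against the reduced
            # length. Assuming the usual Wav2Vec2-style formula per conv layer,
            # L_out = (L_in - kernel) // stride + 1, the tester defaults here
            # (kernel=(8, 8, 8), stride=(4, 4, 4)) map 1024 -> 255 -> 62 -> 14 frames.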

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )
            out_len = len(outputs)

            correct_outlen = 5

            # loss is at first position
            if "labels" in inputs_dict:
                correct_outlen += 1  # loss is added to beginning
            if "past_key_values" in outputs:
                correct_outlen += 1  # past_key_values have been returned

            self.assertEqual(out_len, correct_outlen)

            # decoder attentions
            decoder_attentions = outputs.decoder_attentions
            self.assertIsInstance(decoder_attentions, (list, tuple))
            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(decoder_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
            )

            # cross attentions
            cross_attentions = outputs.cross_attentions
            self.assertIsInstance(cross_attentions, (list, tuple))
            self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(cross_attentions[0].shape[-3:]),
                [
                    self.model_tester.num_attention_heads,
                    decoder_seq_length,
                    subsampled_encoder_key_length,
                ],
            )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_values",
                "attention_mask",
                "decoder_input_ids",
                "decoder_attention_mask",
            ]
            expected_arg_names.extend(
                ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
                if all(mask in arg_names for mask in ("head_mask", "decoder_head_mask", "cross_attn_head_mask"))
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
            else:
                seq_length = self.model_tester.seq_length

            subsampled_seq_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(seq_length)

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [subsampled_seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)
                seq_len = getattr(self.model_tester, "seq_length", None)
                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                uniform_init_parms = [
                    "conv.weight",
                    "conv.parametrizations.weight",
                    "masked_spec_embed",
                    "feature_projection.projection.weight",
                    "feature_projection.projection.bias",
                ]
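                # These parameters come from uniform/Kaiming-style initializers that
                # _config_zero_init does not control, so only their mean is bounded;
                # everything else should come out exactly 0.0 (or 1.0, e.g. LayerNorm
                # weights) once initializer_range is zeroed out.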
                if param.requires_grad:
                    if any(x in name for x in uniform_init_parms):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    # this model has no inputs_embeds
    def test_inputs_embeds(self):
        pass

    def test_resize_embeddings_untied(self):
        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        original_config.tie_word_embeddings = False

        # if the model cannot untie its embeddings -> skip the test
        if original_config.tie_word_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config).to(torch_device)

            # if the model has no output embeddings -> skip the test
            if model.get_output_embeddings() is None:
                continue

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_vocab_size = config.vocab_size
            model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
            # Clamp decoder ids to the new, smaller vocab so the forward pass stays valid
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

    def test_resize_tokens_embeddings(self):
        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            model_vocab_size = config.vocab_size
            # Retrieve the embeddings and clone them
            model_embed = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)

            # make sure that decoder_input_ids are resized
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = True
            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        pass

    # training is not supported yet
    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    # overwrite from test_modeling_common
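    # (Filling every parameter with a known constant presumably lets the common
    # save/load fast-init tests detect weights that were not properly copied.)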
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
        if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
            module.masked_spec_embed.data.fill_(3)


@require_torch
@require_sentencepiece
@require_tokenizers
@slow
class SpeechT5ForSpeechToTextIntegrationTests(unittest.TestCase):
    @cached_property
    def default_processor(self):
        return SpeechT5Processor.from_pretrained("microsoft/speecht5_asr")

    def _load_datasamples(self, num_samples):
        from datasets import load_dataset

        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]

    def test_generation_librispeech(self):
        model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr")
        model.to(torch_device)
        processor = self.default_processor

        input_speech = self._load_datasamples(1)

        input_values = processor(audio=input_speech, return_tensors="pt").input_values.to(torch_device)

        generated_ids = model.generate(input_values)
        generated_transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)

        EXPECTED_TRANSCRIPTIONS = [
            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
        ]
        self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS)

    def test_generation_librispeech_batched(self):
        model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr")
        model.to(torch_device)
        processor = self.default_processor

        input_speech = self._load_datasamples(4)

        inputs = processor(audio=input_speech, return_tensors="pt", padding=True)

        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        generated_ids = model.generate(input_values, attention_mask=attention_mask)
        generated_transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True)

        EXPECTED_TRANSCRIPTIONS = [
            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
            "nor is mister quilter's manner less interesting than his matter",
            "he tells us that at this festive season of the year with christmas and rosebeaf looming before us"
            " similars drawn from eating and its results occur most readily to the mind",
            "he has grave doubts whether sir frederick latin's work is really greek after all and can discover in it"
            " but little of rocky ithica",
        ]
        self.assertListEqual(generated_transcripts, EXPECTED_TRANSCRIPTIONS)


@require_torch
class SpeechT5ForTextToSpeechTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        encoder_seq_length=7,
        decoder_seq_length=1024,  # speech is longer
        is_training=False,
        hidden_size=24,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=4,
        vocab_size=81,
        num_mel_bins=20,
        reduction_factor=2,
        speech_decoder_postnet_layers=2,
        speech_decoder_postnet_units=32,
        speech_decoder_prenet_units=32,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.encoder_seq_length = encoder_seq_length
        self.decoder_seq_length = decoder_seq_length
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.vocab_size = vocab_size
        self.num_mel_bins = num_mel_bins
        self.reduction_factor = reduction_factor
        self.speech_decoder_postnet_layers = speech_decoder_postnet_layers
        self.speech_decoder_postnet_units = speech_decoder_postnet_units
        self.speech_decoder_prenet_units = speech_decoder_prenet_units

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size).clamp(2)
        attention_mask = random_attention_mask([self.batch_size, self.encoder_seq_length])

        decoder_input_values = floats_tensor([self.batch_size, self.decoder_seq_length, self.num_mel_bins], scale=1.0)
        decoder_attention_mask = random_attention_mask([self.batch_size, self.decoder_seq_length])

        config = self.get_config()
        inputs_dict = prepare_inputs_dict(
            config,
            input_ids=input_ids,
            decoder_input_values=decoder_input_values,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_config(self):
        return SpeechT5Config(
            hidden_size=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            vocab_size=self.vocab_size,
            num_mel_bins=self.num_mel_bins,
            reduction_factor=self.reduction_factor,
            speech_decoder_postnet_layers=self.speech_decoder_postnet_layers,
            speech_decoder_postnet_units=self.speech_decoder_postnet_units,
            speech_decoder_prenet_units=self.speech_decoder_prenet_units,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        model = SpeechT5ForTextToSpeech(config=config).to(torch_device).eval()

        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]
        decoder_input_values = inputs_dict["decoder_input_values"]

        result = model(input_ids, attention_mask=attention_mask, decoder_input_values=decoder_input_values)
        self.parent.assertEqual(
            result.spectrogram.shape,
            (self.batch_size, self.decoder_seq_length * self.reduction_factor, self.num_mel_bins),
        )
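        # The speech decoder postnet emits `reduction_factor` mel frames per decoder
        # position, which is why the expected time dimension above is
        # decoder_seq_length * reduction_factor.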


@require_torch
class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (SpeechT5ForTextToSpeech,) if is_torch_available() else ()
    all_generative_model_classes = (SpeechT5ForTextToSpeech,) if is_torch_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_headmasking = False

    input_name = "input_ids"

    def setUp(self):
        self.model_tester = SpeechT5ForTextToSpeechTester(self)
        self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_decoder_model_past_with_large_inputs(self):
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_determinism(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_ids",
                "attention_mask",
                "decoder_input_values",
                "decoder_attention_mask",
            ]
            expected_arg_names.extend(
                ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
                if all(mask in arg_names for mask in ("head_mask", "decoder_head_mask", "cross_attn_head_mask"))
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                uniform_init_parms = [
                    "conv.weight",
                ]
                if param.requires_grad:
                    if any(x in name for x in uniform_init_parms):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    # this model has no inputs_embeds
    def test_inputs_embeds(self):
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_model_outputs_equivalence(self):
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_save_load(self):
        pass

    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        pass

    @slow
    def test_torchscript_output_attentions(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_output_hidden_state(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_simple(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    # training is not supported yet
    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)


@require_torch
@require_sentencepiece
@require_tokenizers
class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
    @cached_property
    def default_model(self):
        return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(torch_device)

    @cached_property
    def default_processor(self):
        return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")

    @cached_property
    def default_vocoder(self):
        return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(torch_device)

    def test_generation(self):
        model = self.default_model
        processor = self.default_processor

        input_text = "Mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
        input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
        speaker_embeddings = torch.zeros((1, 512), device=torch_device)
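        # SpeechT5 TTS conditions on a 512-dim speaker embedding (an x-vector for
        # the pretrained checkpoints); an all-zeros vector acts as a dummy speaker.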

        # Generate speech and validate output dimensions
        set_seed(555)  # Ensure deterministic behavior
        generated_speech = model.generate_speech(input_ids, speaker_embeddings=speaker_embeddings)
        num_mel_bins = model.config.num_mel_bins
        self.assertEqual(
            generated_speech.shape[1], num_mel_bins, "Generated speech output has an unexpected number of mel bins."
        )

        # Validate generation with additional kwargs using model.generate;
        # same method as generate_speech
        set_seed(555)  # Reset seed for consistent results
        generated_speech_with_generate = model.generate(
            input_ids, attention_mask=None, speaker_embeddings=speaker_embeddings
        )
        self.assertEqual(
            generated_speech_with_generate.shape,
            generated_speech.shape,
            "Shape mismatch between generate_speech and generate methods.",
        )

    def test_one_to_many_generation(self):
        model = self.default_model
        processor = self.default_processor
        vocoder = self.default_vocoder

        input_text = [
            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
            "nor is mister quilter's manner less interesting than his matter",
            "he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
        ]
        inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
        speaker_embeddings = torch.zeros((1, 512), device=torch_device)

        # Generate spectrograms
        set_seed(555)  # Ensure deterministic behavior
        spectrograms, spectrogram_lengths = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            return_output_lengths=True,
        )

        # Validate generated spectrogram dimensions
        expected_batch_size = len(input_text)
        num_mel_bins = model.config.num_mel_bins
        actual_batch_size, _, actual_num_mel_bins = spectrograms.shape
        self.assertEqual(actual_batch_size, expected_batch_size, "Batch size of generated spectrograms is incorrect.")
        self.assertEqual(
            actual_num_mel_bins, num_mel_bins, "Number of mel bins in batch generated spectrograms is incorrect."
        )

        # Generate waveforms using the vocoder
        waveforms = vocoder(spectrograms)
        waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
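        # (The vocoder upsamples every mel frame by the same fixed factor, so each
        # item's true waveform length is proportional to its spectrogram length.)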
1104

1105
        # Validate generation with integrated vocoder
1106
        set_seed(555)  # Reset seed for consistent results
1107
        waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
1108
            input_ids=inputs["input_ids"],
1109
            speaker_embeddings=speaker_embeddings,
1110
            attention_mask=inputs["attention_mask"],
1111
            vocoder=vocoder,
1112
            return_output_lengths=True,
1113
        )
1114

1115
        # Check consistency between waveforms generated with and without standalone vocoder
1116
        self.assertTrue(
1117
            torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8),
1118
            "Mismatch in waveforms generated with and without the standalone vocoder.",
1119
        )
1120
        self.assertEqual(
1121
            waveform_lengths,
1122
            waveform_lengths_with_vocoder,
1123
            "Waveform lengths differ between standalone and integrated vocoder generation.",
1124
        )
1125

1126
        # Test generation consistency without returning lengths
1127
        set_seed(555)  # Reset seed for consistent results
1128
        waveforms_with_vocoder_no_lengths = model.generate_speech(
1129
            input_ids=inputs["input_ids"],
1130
            speaker_embeddings=speaker_embeddings,
1131
            attention_mask=inputs["attention_mask"],
1132
            vocoder=vocoder,
1133
            return_output_lengths=False,
1134
        )
1135

1136
        # Validate waveform consistency without length information
1137
        self.assertTrue(
1138
            torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8),
1139
            "Waveforms differ when generated with and without length information.",
1140
        )
1141

1142
        # Validate batch vs. single instance generation consistency
1143
        for i, text in enumerate(input_text):
1144
            inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
1145
            set_seed(555)  # Reset seed for consistent results
1146
            spectrogram = model.generate_speech(
1147
                input_ids=inputs["input_ids"],
1148
                speaker_embeddings=speaker_embeddings,
1149
            )
1150

1151
            # Check spectrogram shape consistency
1152
            self.assertEqual(
1153
                spectrogram.shape,
1154
                spectrograms[i][: spectrogram_lengths[i]].shape,
1155
                "Mismatch in spectrogram shape between batch and single instance generation.",
1156
            )
1157

1158
            # Generate and validate waveform for single instance
1159
            waveform = vocoder(spectrogram)
1160
            self.assertEqual(
1161
                waveform.shape,
1162
                waveforms[i][: waveform_lengths[i]].shape,
1163
                "Mismatch in waveform shape between batch and single instance generation.",
1164
            )
1165

1166
            # Check waveform consistency with integrated vocoder
1167
            set_seed(555)  # Reset seed for consistent results
1168
            waveform_with_integrated_vocoder = model.generate_speech(
1169
                input_ids=inputs["input_ids"],
1170
                speaker_embeddings=speaker_embeddings,
1171
                vocoder=vocoder,
1172
            )
1173
            self.assertTrue(
1174
                torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8),
1175
                "Mismatch in waveform between standalone and integrated vocoder for single instance generation.",
1176
            )
1177

1178
    def test_batch_generation(self):
1179
        model = self.default_model
1180
        processor = self.default_processor
1181
        vocoder = self.default_vocoder
1182

1183
        input_text = [
1184
            "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
1185
            "nor is mister quilter's manner less interesting than his matter",
1186
            "he tells us that at this festive season of the year with christmas and rosebeaf looming before us",
1187
        ]
1188
        inputs = processor(text=input_text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
1189
        set_seed(555)  # Ensure deterministic behavior
1190
        speaker_embeddings = torch.randn((len(input_text), 512), device=torch_device)
1191

1192
        # Generate spectrograms
1193
        set_seed(555)  # Reset seed for consistent results
1194
        spectrograms, spectrogram_lengths = model.generate_speech(
1195
            input_ids=inputs["input_ids"],
1196
            speaker_embeddings=speaker_embeddings,
1197
            attention_mask=inputs["attention_mask"],
1198
            return_output_lengths=True,
1199
        )
1200

1201
        # Validate generated spectrogram dimensions
1202
        expected_batch_size = len(input_text)
1203
        num_mel_bins = model.config.num_mel_bins
1204
        actual_batch_size, _, actual_num_mel_bins = spectrograms.shape
1205
        self.assertEqual(
1206
            actual_batch_size,
1207
            expected_batch_size,
1208
            "Batch size of generated spectrograms is incorrect.",
1209
        )
1210
        self.assertEqual(
1211
            actual_num_mel_bins,
1212
            num_mel_bins,
1213
            "Number of mel bins in batch generated spectrograms is incorrect.",
1214
        )
1215

1216
        # Generate waveforms using the vocoder
1217
        waveforms = vocoder(spectrograms)
1218
        waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]

        # Validate generation with integrated vocoder
        set_seed(555)  # Reset seed for consistent results
        waveforms_with_vocoder, waveform_lengths_with_vocoder = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            vocoder=vocoder,
            return_output_lengths=True,
        )

        # Check consistency between waveforms generated with and without standalone vocoder
        self.assertTrue(
            torch.allclose(waveforms, waveforms_with_vocoder, atol=1e-8),
            "Mismatch in waveforms generated with and without the standalone vocoder.",
        )
        self.assertEqual(
            waveform_lengths,
            waveform_lengths_with_vocoder,
            "Waveform lengths differ between standalone and integrated vocoder generation.",
        )

        # Test generation consistency without returning lengths
        set_seed(555)  # Reset seed for consistent results
        waveforms_with_vocoder_no_lengths = model.generate_speech(
            input_ids=inputs["input_ids"],
            speaker_embeddings=speaker_embeddings,
            attention_mask=inputs["attention_mask"],
            vocoder=vocoder,
            return_output_lengths=False,
        )

        # Validate waveform consistency without length information
        self.assertTrue(
            torch.allclose(waveforms_with_vocoder_no_lengths, waveforms_with_vocoder, atol=1e-8),
            "Waveforms differ when generated with and without length information.",
        )

        # Validate batch vs. single instance generation consistency
        for i, text in enumerate(input_text):
            inputs = processor(text=text, padding="max_length", max_length=128, return_tensors="pt").to(torch_device)
            current_speaker_embedding = speaker_embeddings[i].unsqueeze(0)
            set_seed(555)  # Reset seed for consistent results
            spectrogram = model.generate_speech(
                input_ids=inputs["input_ids"],
                speaker_embeddings=current_speaker_embedding,
            )

            # Check spectrogram shape consistency
            self.assertEqual(
                spectrogram.shape,
                spectrograms[i][: spectrogram_lengths[i]].shape,
                "Mismatch in spectrogram shape between batch and single instance generation.",
            )

            # Generate and validate waveform for single instance
            waveform = vocoder(spectrogram)
            self.assertEqual(
                waveform.shape,
                waveforms[i][: waveform_lengths[i]].shape,
                "Mismatch in waveform shape between batch and single instance generation.",
            )

            # Check waveform consistency with integrated vocoder
            set_seed(555)  # Reset seed for consistent results
            waveform_with_integrated_vocoder = model.generate_speech(
                input_ids=inputs["input_ids"],
                speaker_embeddings=current_speaker_embedding,
                vocoder=vocoder,
            )
            self.assertTrue(
                torch.allclose(waveform, waveform_with_integrated_vocoder, atol=1e-8),
                "Mismatch in waveform between standalone and integrated vocoder for single instance generation.",
            )


@require_torch
class SpeechT5ForSpeechToSpeechTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        encoder_seq_length=1024,  # speech is longer
        decoder_seq_length=1024,
        is_training=False,
        hidden_size=24,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=4,
        conv_dim=(32, 32, 32),
        conv_stride=(4, 4, 4),
        conv_kernel=(8, 8, 8),
        conv_bias=False,
        num_conv_pos_embeddings=16,
        num_conv_pos_embedding_groups=2,
        vocab_size=81,
        num_mel_bins=20,
        reduction_factor=2,
        speech_decoder_postnet_layers=2,
        speech_decoder_postnet_units=32,
        speech_decoder_prenet_units=32,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.encoder_seq_length = encoder_seq_length
        self.decoder_seq_length = decoder_seq_length
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.conv_dim = conv_dim
        self.conv_stride = conv_stride
        self.conv_kernel = conv_kernel
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.vocab_size = vocab_size
        self.num_mel_bins = num_mel_bins
        self.reduction_factor = reduction_factor
        self.speech_decoder_postnet_layers = speech_decoder_postnet_layers
        self.speech_decoder_postnet_units = speech_decoder_postnet_units
        self.speech_decoder_prenet_units = speech_decoder_prenet_units

    def prepare_config_and_inputs(self):
        input_values = floats_tensor([self.batch_size, self.encoder_seq_length], scale=1.0)
        attention_mask = random_attention_mask([self.batch_size, self.encoder_seq_length])

        decoder_input_values = floats_tensor([self.batch_size, self.decoder_seq_length, self.num_mel_bins], scale=1.0)
        decoder_attention_mask = random_attention_mask([self.batch_size, self.decoder_seq_length])

        config = self.get_config()
        inputs_dict = prepare_inputs_dict(
            config,
            input_values=input_values,
            decoder_input_values=decoder_input_values,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_config(self):
        return SpeechT5Config(
            hidden_size=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            conv_dim=self.conv_dim,
            conv_stride=self.conv_stride,
            conv_kernel=self.conv_kernel,
            conv_bias=self.conv_bias,
            num_conv_pos_embeddings=self.num_conv_pos_embeddings,
            num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
            vocab_size=self.vocab_size,
            num_mel_bins=self.num_mel_bins,
            reduction_factor=self.reduction_factor,
            speech_decoder_postnet_layers=self.speech_decoder_postnet_layers,
            speech_decoder_postnet_units=self.speech_decoder_postnet_units,
            speech_decoder_prenet_units=self.speech_decoder_prenet_units,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        model = SpeechT5ForSpeechToSpeech(config=config).to(torch_device).eval()

        input_values = inputs_dict["input_values"]
        attention_mask = inputs_dict["attention_mask"]
        decoder_input_values = inputs_dict["decoder_input_values"]

        result = model(input_values, attention_mask=attention_mask, decoder_input_values=decoder_input_values)
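        # The speech decoder postnet projects each decoder state to reduction_factor
        # frames of num_mel_bins features, so the output spectrogram should be
        # reduction_factor times longer than the decoder input sequence.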
        self.parent.assertEqual(
            result.spectrogram.shape,
            (self.batch_size, self.decoder_seq_length * self.reduction_factor, self.num_mel_bins),
        )


@require_torch
class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (SpeechT5ForSpeechToSpeech,) if is_torch_available() else ()
    all_generative_model_classes = (SpeechT5ForSpeechToSpeech,) if is_torch_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_headmasking = False
    test_resize_embeddings = False

    input_name = "input_values"

    def setUp(self):
        self.model_tester = SpeechT5ForSpeechToSpeechTester(self)
        self.config_tester = ConfigTester(self, config_class=SpeechT5Config, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_decoder_model_past_with_large_inputs(self):
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_determinism(self):
        pass

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

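            # The speech encoder prenet's convolutional feature extractor downsamples the
            # input, so attention shapes are checked against the subsampled length. With
            # this tester's settings (kernel 8, stride 4, three layers) and the usual conv
            # output-length formula (length - kernel) // stride + 1, the 1024-frame input
            # shrinks to 1024 -> 255 -> 62 -> 14 positions.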
            subsampled_encoder_seq_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(
                encoder_seq_length
            )
            subsampled_encoder_key_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(
                encoder_key_length
            )

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )
            out_len = len(outputs)

            correct_outlen = 5

            # loss is at first position
            if "labels" in inputs_dict:
                correct_outlen += 1  # loss is added to beginning
            if "past_key_values" in outputs:
                correct_outlen += 1  # past_key_values have been returned

            self.assertEqual(out_len, correct_outlen)

            # decoder attentions
            decoder_attentions = outputs.decoder_attentions
            self.assertIsInstance(decoder_attentions, (list, tuple))
            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(decoder_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
            )

            # cross attentions
            cross_attentions = outputs.cross_attentions
            self.assertIsInstance(cross_attentions, (list, tuple))
            self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(cross_attentions[0].shape[-3:]),
                [
                    self.model_tester.num_attention_heads,
                    decoder_seq_length,
                    subsampled_encoder_key_length,
                ],
            )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_values",
                "attention_mask",
                "decoder_input_values",
                "decoder_attention_mask",
            ]
            expected_arg_names.extend(
                ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
                if all(mask in arg_names for mask in ["head_mask", "decoder_head_mask", "cross_attn_head_mask"])
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
            else:
                seq_length = self.model_tester.seq_length

            subsampled_seq_length = model.speecht5.encoder.prenet._get_feat_extract_output_lengths(seq_length)

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [subsampled_seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)
                seq_len = getattr(self.model_tester, "seq_length", None)
                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
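                # The parameters below use a uniform initializer, so under _config_zero_init
                # their mean is only expected to land in [-1, 1], unlike the remaining
                # weights, which must come out exactly zero- or one-valued.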
                uniform_init_parms = [
                    "conv.weight",
                    "conv.parametrizations.weight",
                    "masked_spec_embed",
                    "feature_projection.projection.weight",
                    "feature_projection.projection.bias",
                ]
                if param.requires_grad:
                    if any(x in name for x in uniform_init_parms):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    # this model has no inputs_embeds
    def test_inputs_embeds(self):
        pass

    # this model has no input embeddings
    def test_model_common_attributes(self):
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_model_outputs_equivalence(self):
        pass

    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        pass

    # skipped because there is always dropout in SpeechT5SpeechDecoderPrenet
    def test_save_load(self):
        pass

    @slow
    def test_torchscript_output_attentions(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_output_hidden_state(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    @slow
    def test_torchscript_simple(self):
        # disabled because this model doesn't have decoder_input_ids
        pass

    # training is not supported yet
    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
1686
    )
1687
    def test_training_gradient_checkpointing_use_reentrant(self):
1688
        pass
1689

1690
    @unittest.skip(
1691
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
1692
    )
1693
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    # overwrite from test_modeling_common
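    # Filling every weight (including the weight-norm pair weight_g/weight_v and the
    # masked_spec_embed) with a constant lets the common save/load fast-init tests
    # tell loaded weights apart from freshly initialized ones.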
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
        if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
            module.masked_spec_embed.data.fill_(3)


@require_torch
@require_sentencepiece
@require_tokenizers
@slow
class SpeechT5ForSpeechToSpeechIntegrationTests(unittest.TestCase):
    @cached_property
    def default_processor(self):
        return SpeechT5Processor.from_pretrained("microsoft/speecht5_vc")

    def _load_datasamples(self, num_samples):
        from datasets import load_dataset

        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]

    def test_generation_librispeech(self):
        model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc")
        model.to(torch_device)
        processor = self.default_processor

        input_speech = self._load_datasamples(1)
        input_values = processor(audio=input_speech, return_tensors="pt").input_values.to(torch_device)

        speaker_embeddings = torch.zeros((1, 512), device=torch_device)
        generated_speech = model.generate_speech(input_values, speaker_embeddings=speaker_embeddings)

        self.assertEqual(generated_speech.shape[1], model.config.num_mel_bins)
        self.assertGreaterEqual(generated_speech.shape[0], 300)
        self.assertLessEqual(generated_speech.shape[0], 310)


class SpeechT5HifiGanTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=False,
        num_mel_bins=20,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.num_mel_bins = num_mel_bins

    def prepare_config_and_inputs(self):
        input_values = floats_tensor([self.seq_length, self.num_mel_bins], scale=1.0)
        config = self.get_config()
        return config, input_values

    def get_config(self):
        return SpeechT5HifiGanConfig(
            model_in_dim=self.num_mel_bins,
            upsample_initial_channel=32,
        )

    def create_and_check_model(self, config, input_values):
        model = SpeechT5HifiGan(config=config).to(torch_device).eval()
        result = model(input_values)
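        # The tester keeps the default upsample rates (4, 4, 4, 4), so each input frame
        # should be expanded to 4**4 = 256 audio samples.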
        self.parent.assertEqual(result.shape, (self.seq_length * 256,))

    def prepare_config_and_inputs_for_common(self):
        config, input_values = self.prepare_config_and_inputs()
        inputs_dict = {"spectrogram": input_values}
        return config, inputs_dict


@require_torch
class SpeechT5HifiGanTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (SpeechT5HifiGan,) if is_torch_available() else ()
    test_torchscript = False
    test_pruning = False
    test_resize_embeddings = False
    test_resize_position_embeddings = False
    test_head_masking = False
    test_mismatched_shapes = False
    test_missing_keys = False
    test_model_parallel = False
    is_encoder_decoder = False
    has_attentions = False

    input_name = "spectrogram"

    def setUp(self):
        self.model_tester = SpeechT5HifiGanTester(self)
        self.config_tester = ConfigTester(self, config_class=SpeechT5HifiGanConfig)

    def test_config(self):
        self.config_tester.create_and_test_config_to_json_string()
        self.config_tester.create_and_test_config_to_json_file()
        self.config_tester.create_and_test_config_from_and_save_pretrained()
        self.config_tester.create_and_test_config_from_and_save_pretrained_subfolder()
        self.config_tester.create_and_test_config_with_num_labels()
        self.config_tester.check_config_can_be_init_without_params()
        self.config_tester.check_config_arguments_init()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "spectrogram",
            ]
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    # this model does not output hidden states
    def test_hidden_states_output(self):
        pass

    # skip
    def test_initialization(self):
        pass

    # this model has no inputs_embeds
    def test_inputs_embeds(self):
        pass

    # this model has no input embeddings
    def test_model_common_attributes(self):
        pass

    # skip as this model doesn't support all arguments tested
    def test_model_outputs_equivalence(self):
        pass

    # this model does not output hidden states
    def test_retain_grad_hidden_states_attentions(self):
        pass

    # skip because it fails on automapping of SpeechT5HifiGanConfig
    def test_save_load_fast_init_from_base(self):
        pass

    # skip because it fails on automapping of SpeechT5HifiGanConfig
    def test_save_load_fast_init_to_base(self):
        pass

    def test_batched_inputs_outputs(self):
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

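            # The vocoder accepts both un-batched (frames, mel) and batched
            # (batch, frames, mel) spectrograms; duplicating the un-batched input
            # builds a batch of two identical examples.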
            batched_inputs = inputs["spectrogram"].unsqueeze(0).repeat(2, 1, 1)
            with torch.no_grad():
                batched_outputs = model(batched_inputs.to(torch_device))

            self.assertEqual(
                batched_inputs.shape[0], batched_outputs.shape[0], msg="Got different batch dims for input and output"
            )

    def test_unbatched_inputs_outputs(self):
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(inputs["spectrogram"].to(torch_device))
            self.assertTrue(outputs.dim() == 1, msg="Got un-batched inputs but batched output")