transformers
430 lines · 17.3 KB
1# coding=utf-8
2# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15""" Testing suite for the PyTorch VITS model. """
16
17import copy
18import os
19import tempfile
20import unittest
21from typing import Dict, List, Tuple
22
23import numpy as np
24
25from transformers import PretrainedConfig, VitsConfig
26from transformers.testing_utils import (
27is_flaky,
28is_torch_available,
29require_torch,
30require_torch_multi_gpu,
31slow,
32torch_device,
33)
34from transformers.trainer_utils import set_seed
35
36from ...test_configuration_common import ConfigTester
37from ...test_modeling_common import (
38ModelTesterMixin,
39global_rng,
40ids_tensor,
41random_attention_mask,
42)
43from ...test_pipeline_mixin import PipelineTesterMixin
44
45
46if is_torch_available():
47import torch
48
49from transformers import VitsModel, VitsTokenizer
50
51
52CONFIG_NAME = "config.json"
53GENERATION_CONFIG_NAME = "generation_config.json"
54
55
def _config_zero_init(config):
    """Return a deep copy of ``config`` with all init-scale attributes near zero.

    Any attribute whose name contains ``_range``, ``_std``, ``initializer_factor``
    or ``layer_scale`` is set to ``1e-10`` so weight initialisation becomes
    (almost) a no-op; nested ``PretrainedConfig`` sub-configs are handled
    recursively.
    """
    zeroed_config = copy.deepcopy(config)
    init_markers = ("_range", "_std", "initializer_factor", "layer_scale")
    for attr_name in zeroed_config.__dict__:
        if any(marker in attr_name for marker in init_markers):
            setattr(zeroed_config, attr_name, 1e-10)
        # Recurse into composite configs (e.g. encoder/decoder sub-configs).
        sub_value = getattr(zeroed_config, attr_name, None)
        if isinstance(sub_value, PretrainedConfig):
            setattr(zeroed_config, attr_name, _config_zero_init(sub_value))
    return zeroed_config
65
66
@require_torch
class VitsModelTester:
    """Builds a tiny ``VitsConfig`` plus matching dummy inputs for the VITS tests.

    All sizes are deliberately small so that model construction and a forward
    pass stay fast; ``parent`` is the ``unittest.TestCase`` that owns the asserts.
    """

    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=7,
        is_training=False,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        flow_size=16,
        vocab_size=38,
        spectrogram_bins=8,
        duration_predictor_num_flows=2,
        duration_predictor_filter_channels=16,
        prior_encoder_num_flows=2,
        upsample_initial_channel=16,
        upsample_rates=None,
        upsample_kernel_sizes=None,
        resblock_kernel_sizes=None,
        resblock_dilation_sizes=None,
    ):
        # The list parameters default to None instead of list literals: mutable
        # default arguments are evaluated once and shared across every
        # instance, so each tester gets its own fresh list here instead.
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.flow_size = flow_size
        self.vocab_size = vocab_size
        self.spectrogram_bins = spectrogram_bins
        self.duration_predictor_num_flows = duration_predictor_num_flows
        self.duration_predictor_filter_channels = duration_predictor_filter_channels
        self.prior_encoder_num_flows = prior_encoder_num_flows
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_rates = [8, 2] if upsample_rates is None else upsample_rates
        self.upsample_kernel_sizes = [16, 4] if upsample_kernel_sizes is None else upsample_kernel_sizes
        self.resblock_kernel_sizes = [3, 7] if resblock_kernel_sizes is None else resblock_kernel_sizes
        self.resblock_dilation_sizes = (
            [[1, 3, 5], [1, 3, 5]] if resblock_dilation_sizes is None else resblock_dilation_sizes
        )

    def prepare_config_and_inputs(self):
        """Return a ``(config, inputs_dict)`` pair ready for a forward pass."""
        # clamp(2) keeps every id >= 2 so special/pad token ids are never sampled
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(2)
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = self.get_config()
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        """Alias used by the common ``ModelTesterMixin`` machinery."""
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_config(self):
        """Build the reduced-size ``VitsConfig`` described by this tester."""
        return VitsConfig(
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            ffn_dim=self.intermediate_size,
            flow_size=self.flow_size,
            vocab_size=self.vocab_size,
            spectrogram_bins=self.spectrogram_bins,
            duration_predictor_num_flows=self.duration_predictor_num_flows,
            prior_encoder_num_flows=self.prior_encoder_num_flows,
            duration_predictor_filter_channels=self.duration_predictor_filter_channels,
            posterior_encoder_num_wavenet_layers=self.num_hidden_layers,
            upsample_initial_channel=self.upsample_initial_channel,
            upsample_rates=self.upsample_rates,
            upsample_kernel_sizes=self.upsample_kernel_sizes,
            resblock_kernel_sizes=self.resblock_kernel_sizes,
            resblock_dilation_sizes=self.resblock_dilation_sizes,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        """Run a forward pass and assert the waveform has the expected shape."""
        model = VitsModel(config=config).to(torch_device).eval()

        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]

        result = model(input_ids, attention_mask=attention_mask)
        # 624 = seq_length upsampled by the configured rates for this tiny config
        self.parent.assertEqual((self.batch_size, 624), result.waveform.shape)
154
155
@require_torch
class VitsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common-behaviour tests for ``VitsModel`` (config, forward, init, save/load).

    Several common tests are overridden because VITS is non-deterministic: the
    seed must be re-set before every forward pass that is compared.
    """

    all_model_classes = (VitsModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": VitsModel, "text-to-audio": VitsModel} if is_torch_available() else {}
    )
    is_encoder_decoder = False
    test_pruning = False
    test_headmasking = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torchscript = False
    has_attentions = False

    input_name = "input_ids"

    def setUp(self):
        self.model_tester = VitsModelTester(self)
        self.config_tester = ConfigTester(self, config_class=VitsConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    # TODO: @ydshieh
    @is_flaky(description="torch 2.2.0 gives `Timeout >120.0s`")
    def test_pipeline_feature_extraction(self):
        super().test_pipeline_feature_extraction()

    @unittest.skip("Need to fix this after #26538")
    def test_model_forward(self):
        set_seed(12345)
        global_rng.seed(12345)
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    # override to force all elements of the batch to have the same sequence length across GPUs
    @require_torch_multi_gpu
    def test_multi_gpu_data_parallel_forward(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # stochastic duration prediction would give each replica a different length
        config.use_stochastic_duration_prediction = False

        # move input tensors to cuda:0
        for key, value in inputs_dict.items():
            if torch.is_tensor(value):
                # make all elements of the batch the same -> ensures the output seq lengths are the same for DP
                value[1:] = value[0]
                inputs_dict[key] = value.to(0)

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            model.to(0)
            model.eval()

            # Wrap model in nn.DataParallel
            model = torch.nn.DataParallel(model)
            set_seed(555)
            with torch.no_grad():
                _ = model(**self._prepare_for_class(inputs_dict, model_class)).waveform

    @unittest.skip("VITS is not deterministic")
    def test_determinism(self):
        pass

    @is_flaky(
        max_attempts=3,
        description="Weight initialisation for the VITS conv layers sometimes exceeds the kaiming normal range",
    )
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # parameters initialised from a uniform distribution (mean may be anywhere in [-1, 1])
        uniform_init_parms = [
            "emb_rel_k",
            "emb_rel_v",
            "conv_1",
            "conv_2",
            "conv_pre",
            "conv_post",
            "conv_proj",
            "conv_dds",
            "project",
            "wavenet.in_layers",
            "wavenet.res_skip_layers",
            "upsampler",
            "resblocks",
        ]

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    if any(x in name for x in uniform_init_parms):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        # with zeroed init scales, all other params must be exactly 0 or 1
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    @unittest.skip("VITS has no inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip("VITS has no input embeddings")
    def test_model_common_attributes(self):
        pass

    # override since the model is not deterministic, so we need to set the seed for each forward pass
    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            t[t != t] = 0
            return t

        # NOTE: `additional_kwargs` defaults to None (not `{}`) to avoid the
        # shared-mutable-default-argument pitfall.
        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs=None):
            additional_kwargs = additional_kwargs if additional_kwargs is not None else {}
            with torch.no_grad():
                set_seed(0)
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                set_seed(0)
                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

                def recursive_check(tuple_object, dict_object):
                    if isinstance(tuple_object, (List, Tuple)):
                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif isinstance(tuple_object, Dict):
                        for tuple_iterable_value, dict_iterable_value in zip(
                            tuple_object.values(), dict_object.values()
                        ):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif tuple_object is None:
                        return
                    else:
                        self.assertTrue(
                            torch.allclose(
                                set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                            ),
                            msg=(
                                "Tuple and dict output are not equal. Difference:"
                                f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                                # .any() so the message reports a boolean, not a whole tensor
                                f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
                                f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}."
                            ),
                        )

                recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            if self.has_attentions:
                tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class)
                check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})

                tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})

                tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                check_equivalence(
                    model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
                )

    # override since the model is not deterministic, so we need to set the seed for each forward pass
    def test_save_load(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_save_load(out1, out2):
            # make sure we don't have nans
            out_2 = out2.cpu().numpy()
            out_2[np.isnan(out_2)] = 0

            out_1 = out1.cpu().numpy()
            out_1[np.isnan(out_1)] = 0
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                set_seed(0)
                first = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                # the config file (and the generation config file, if it can generate) should be saved
                self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
                self.assertEqual(
                    model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
                )

                model = model_class.from_pretrained(tmpdirname)
                model.to(torch_device)
                with torch.no_grad():
                    set_seed(0)
                    second = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            if isinstance(first, tuple) and isinstance(second, tuple):
                for tensor1, tensor2 in zip(first, second):
                    check_save_load(tensor1, tensor2)
            else:
                check_save_load(first, second)

    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
397
398
@require_torch
@slow
class VitsModelIntegrationTests(unittest.TestCase):
    """Slow integration check against the pretrained `facebook/mms-tts-eng` checkpoint."""

    def test_forward(self):
        # GPU gives different results than CPU
        torch_device = "cpu"

        tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")

        model = VitsModel.from_pretrained("facebook/mms-tts-eng")
        model.to(torch_device)

        set_seed(555)  # make deterministic

        input_text = "Mister quilter is the apostle of the middle classes and we are glad to welcome his gospel!"
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(torch_device)

        with torch.no_grad():
            outputs = model(input_ids)

        self.assertEqual(outputs.waveform.shape, (1, 87040))

        # reference values for waveform samples 10000..10029
        # fmt: off
        expected_waveform_slice = torch.tensor([
            -0.0042, 0.0176, 0.0354, 0.0504, 0.0621, 0.0777, 0.0980, 0.1224,
            0.1475, 0.1679, 0.1817, 0.1832, 0.1713, 0.1542, 0.1384, 0.1256,
            0.1147, 0.1066, 0.1026, 0.0958, 0.0823, 0.0610, 0.0340, 0.0022,
            -0.0337, -0.0677, -0.0969, -0.1178, -0.1311, -0.1363,
        ])
        # fmt: on

        actual_slice = outputs.waveform[0, 10000:10030].cpu()
        self.assertTrue(torch.allclose(actual_slice, expected_waveform_slice, atol=1e-4))
431