# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch VITS model. """

import copy
import os
import tempfile
import unittest
from typing import Dict, List, Tuple

import numpy as np

from transformers import PretrainedConfig, VitsConfig
from transformers.testing_utils import (
    is_flaky,
    is_torch_available,
    require_torch,
    require_torch_multi_gpu,
    slow,
    torch_device,
)
from transformers.trainer_utils import set_seed

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    global_rng,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import VitsModel, VitsTokenizer


CONFIG_NAME = "config.json"
GENERATION_CONFIG_NAME = "generation_config.json"

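# Deep-copies `config` (recursing into any sub-configs) and sets every initializer
# range/std to ~0 (1e-10), so that `test_initialization` can tell parameters with a
# dedicated init scheme apart from ones left at the default initialization.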
def _config_zero_init(config):
    configs_no_init = copy.deepcopy(config)
    for key in configs_no_init.__dict__.keys():
        if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
            setattr(configs_no_init, key, 1e-10)
        if isinstance(getattr(configs_no_init, key, None), PretrainedConfig):
            no_init_subconfig = _config_zero_init(getattr(configs_no_init, key))
            setattr(configs_no_init, key, no_init_subconfig)
    return configs_no_init


@require_torch
class VitsModelTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=7,
        is_training=False,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        flow_size=16,
        vocab_size=38,
        spectrogram_bins=8,
        duration_predictor_num_flows=2,
        duration_predictor_filter_channels=16,
        prior_encoder_num_flows=2,
        upsample_initial_channel=16,
        upsample_rates=[8, 2],
        upsample_kernel_sizes=[16, 4],
        resblock_kernel_sizes=[3, 7],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.flow_size = flow_size
        self.vocab_size = vocab_size
        self.spectrogram_bins = spectrogram_bins
        self.duration_predictor_num_flows = duration_predictor_num_flows
        self.duration_predictor_filter_channels = duration_predictor_filter_channels
        self.prior_encoder_num_flows = prior_encoder_num_flows
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes

    def prepare_config_and_inputs(self):
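        # clamp the ids to a minimum of 2, presumably to keep special token ids (e.g. padding) out of the inputs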
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(2)
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        config = self.get_config()
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_config(self):
        return VitsConfig(
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            ffn_dim=self.intermediate_size,
            flow_size=self.flow_size,
            vocab_size=self.vocab_size,
            spectrogram_bins=self.spectrogram_bins,
            duration_predictor_num_flows=self.duration_predictor_num_flows,
            prior_encoder_num_flows=self.prior_encoder_num_flows,
            duration_predictor_filter_channels=self.duration_predictor_filter_channels,
            posterior_encoder_num_wavenet_layers=self.num_hidden_layers,
            upsample_initial_channel=self.upsample_initial_channel,
            upsample_rates=self.upsample_rates,
            upsample_kernel_sizes=self.upsample_kernel_sizes,
            resblock_kernel_sizes=self.resblock_kernel_sizes,
            resblock_dilation_sizes=self.resblock_dilation_sizes,
        )

    def create_and_check_model_forward(self, config, inputs_dict):
        model = VitsModel(config=config).to(torch_device).eval()

        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]

        result = model(input_ids, attention_mask=attention_mask)
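        # with this tiny config, every batch item should come out as a 624-sample waveform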
        self.parent.assertEqual((self.batch_size, 624), result.waveform.shape)


@require_torch
class VitsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (VitsModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": VitsModel, "text-to-audio": VitsModel} if is_torch_available() else {}
    )
    is_encoder_decoder = False
    test_pruning = False
    test_headmasking = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torchscript = False
    has_attentions = False

    input_name = "input_ids"

    def setUp(self):
        self.model_tester = VitsModelTester(self)
        self.config_tester = ConfigTester(self, config_class=VitsConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    # TODO: @ydshieh
    @is_flaky(description="torch 2.2.0 gives `Timeout >120.0s`")
    def test_pipeline_feature_extraction(self):
        super().test_pipeline_feature_extraction()

    @unittest.skip("Need to fix this after #26538")
    def test_model_forward(self):
        set_seed(12345)
        global_rng.seed(12345)
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    @require_torch_multi_gpu
    # override to force all elements of the batch to have the same sequence length across GPUs
    def test_multi_gpu_data_parallel_forward(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.use_stochastic_duration_prediction = False

        # move input tensors to cuda:0
        for key, value in inputs_dict.items():
            if torch.is_tensor(value):
                # make all elements of the batch the same -> ensures the output seq lengths are the same for DP
                value[1:] = value[0]
                inputs_dict[key] = value.to(0)

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            model.to(0)
            model.eval()

            # Wrap model in nn.DataParallel
            model = torch.nn.DataParallel(model)
            set_seed(555)
            with torch.no_grad():
                _ = model(**self._prepare_for_class(inputs_dict, model_class)).waveform

    @unittest.skip("VITS is not deterministic")
    def test_determinism(self):
        pass

    @is_flaky(
        max_attempts=3,
        description="Weight initialisation for the VITS conv layers sometimes exceeds the kaiming normal range",
    )
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

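        # name fragments of the submodules that use uniform weight init; for those, the
        # (rounded) parameter mean only has to land in [-1, 1] rather than be exactly 0 or 1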
        uniform_init_parms = [
            "emb_rel_k",
            "emb_rel_v",
            "conv_1",
            "conv_2",
            "conv_pre",
            "conv_post",
            "conv_proj",
            "conv_dds",
            "project",
            "wavenet.in_layers",
            "wavenet.res_skip_layers",
            "upsampler",
            "resblocks",
        ]

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    if any(x in name for x in uniform_init_parms):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    @unittest.skip("VITS has no inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip("VITS has no input embeddings")
    def test_model_common_attributes(self):
        pass

    # override since the model is not deterministic, so we need to set the seed for each forward pass
    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

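        # `torch.allclose` never treats NaN as equal, so zero out NaNs before comparing outputs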
        def set_nan_tensor_to_zero(t):
            t[t != t] = 0
            return t

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            with torch.no_grad():
                set_seed(0)
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                set_seed(0)
                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

                def recursive_check(tuple_object, dict_object):
                    if isinstance(tuple_object, (List, Tuple)):
                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif isinstance(tuple_object, Dict):
                        for tuple_iterable_value, dict_iterable_value in zip(
                            tuple_object.values(), dict_object.values()
                        ):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif tuple_object is None:
                        return
                    else:
                        self.assertTrue(
                            torch.allclose(
                                set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                            ),
                            msg=(
                                "Tuple and dict output are not equal. Difference:"
                                f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                                f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
302
                                f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
                            ),
                        )

                recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            if self.has_attentions:
                tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class)
                check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})

                tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})

                tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                check_equivalence(
                    model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
                )

    # override since the model is not deterministic, so we need to set the seed for each forward pass
    def test_save_load(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_save_load(out1, out2):
            # make sure we don't have nans
            out_2 = out2.cpu().numpy()
            out_2[np.isnan(out_2)] = 0

            out_1 = out1.cpu().numpy()
            out_1[np.isnan(out_1)] = 0
            max_diff = np.amax(np.abs(out_1 - out_2))
            self.assertLessEqual(max_diff, 1e-5)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                set_seed(0)
                first = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                # the config file (and the generation config file, if it can generate) should be saved
                self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
                self.assertEqual(
                    model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
                )

                model = model_class.from_pretrained(tmpdirname)
                model.to(torch_device)
                with torch.no_grad():
                    set_seed(0)
                    second = model(**self._prepare_for_class(inputs_dict, model_class))[0]

            if isinstance(first, tuple) and isinstance(second, tuple):
                for tensor1, tensor2 in zip(first, second):
                    check_save_load(tensor1, tensor2)
            else:
                check_save_load(first, second)

    # overwrite from test_modeling_common
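    # VITS applies weight norm to its convolutions, which re-parametrizes `weight` into a
    # magnitude (`weight_g`) / direction (`weight_v`) pair, so those tensors need filling too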
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)


@require_torch
@slow
class VitsModelIntegrationTests(unittest.TestCase):
    def test_forward(self):
        # GPU gives different results than CPU
        torch_device = "cpu"

        model = VitsModel.from_pretrained("facebook/mms-tts-eng")
        model.to(torch_device)

        tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")

        set_seed(555)  # make deterministic

        input_text = "Mister quilter is the apostle of the middle classes and we are glad to welcome his gospel!"
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(torch_device)

        with torch.no_grad():
            outputs = model(input_ids)

        self.assertEqual(outputs.waveform.shape, (1, 87040))
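        # spot-check a 30-sample slice of the generated waveform against reference values recorded on CPU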
        # fmt: off
        EXPECTED_LOGITS = torch.tensor(
            [
                -0.0042,  0.0176,  0.0354,  0.0504,  0.0621,  0.0777,  0.0980,  0.1224,
                 0.1475,  0.1679,  0.1817,  0.1832,  0.1713,  0.1542,  0.1384,  0.1256,
                 0.1147,  0.1066,  0.1026,  0.0958,  0.0823,  0.0610,  0.0340,  0.0022,
                -0.0337, -0.0677, -0.0969, -0.1178, -0.1311, -0.1363
            ]
        )
        # fmt: on
        self.assertTrue(torch.allclose(outputs.waveform[0, 10000:10030].cpu(), EXPECTED_LOGITS, atol=1e-4))
