transformers

Форк
0
/
test_feature_extraction_speecht5.py 
421 строка · 18.7 Кб
1
# coding=utf-8
2
# Copyright 2021-2023 HuggingFace Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""Tests for the SpeechT5 feature extractors."""
16

17
import itertools
18
import random
19
import unittest
20

21
import numpy as np
22

23
from transformers import BatchFeature, SpeechT5FeatureExtractor
24
from transformers.testing_utils import require_torch
25
from transformers.utils.import_utils import is_torch_available
26

27
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
28

29

30
# `torch` is an optional dependency; import it only when present so the module
# can still be collected without it (torch-requiring tests are skipped).
if is_torch_available():
    import torch


# Shared RNG used by `floats_list` when the caller does not supply one.
global_rng = random.Random()
35

36

37
# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
def floats_list(shape, scale=1.0, rng=None, name=None):
    """Create a `shape[0]` x `shape[1]` nested list of random floats in [0, scale).

    Args:
        shape: Pair `(num_rows, num_cols)` giving the size of the nested list.
        scale: Multiplier applied to each raw `rng.random()` sample.
        rng: Optional `random.Random`; falls back to the module-level `global_rng`.
        name: Unused; kept for signature compatibility with callers.

    Returns:
        A list of `shape[0]` lists, each containing `shape[1]` floats.
    """
    if rng is None:
        rng = global_rng

    # Row-major generation order matches the original nested-loop version, so
    # seeded RNGs produce identical values.
    return [[rng.random() * scale for _ in range(shape[1])] for _ in range(shape[0])]
50

51

52
@require_torch
class SpeechT5FeatureExtractionTester(unittest.TestCase):
    """Builds feature-extractor configuration dicts and dummy batches for the SpeechT5 tests."""

    def __init__(
        self,
        parent,
        batch_size=7,
        min_seq_length=400,
        max_seq_length=2000,
        feature_size=1,
        padding_value=0.0,
        sampling_rate=16000,
        do_normalize=True,
        num_mel_bins=80,
        hop_length=16,
        win_length=64,
        win_function="hann_window",
        fmin=80,
        fmax=7600,
        mel_floor=1e-10,
        return_attention_mask=True,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.min_seq_length = min_seq_length
        self.max_seq_length = max_seq_length
        # Step between consecutive example lengths so a batch spans [min_seq_length, max_seq_length).
        self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
        self.feature_size = feature_size
        self.padding_value = padding_value
        self.sampling_rate = sampling_rate
        self.do_normalize = do_normalize
        self.num_mel_bins = num_mel_bins
        self.hop_length = hop_length
        self.win_length = win_length
        self.win_function = win_function
        self.fmin = fmin
        self.fmax = fmax
        self.mel_floor = mel_floor
        self.return_attention_mask = return_attention_mask

    def prepare_feat_extract_dict(self):
        """Return the kwargs used to instantiate a `SpeechT5FeatureExtractor`."""
        config_keys = (
            "feature_size",
            "padding_value",
            "sampling_rate",
            "do_normalize",
            "num_mel_bins",
            "hop_length",
            "win_length",
            "win_function",
            "fmin",
            "fmax",
            "mel_floor",
            "return_attention_mask",
        )
        return {key: getattr(self, key) for key in config_keys}

    def prepare_inputs_for_common(self, equal_length=False, numpify=False):
        """Create a batch of dummy 1-D waveforms.

        Args:
            equal_length: If True, all examples have length `max_seq_length`;
                otherwise lengths increase from `min_seq_length` in steps of `seq_length_diff`.
            numpify: If True, return each waveform as a numpy array instead of a list.
        """
        if equal_length:
            speech_inputs = floats_list((self.batch_size, self.max_seq_length))
        else:
            # Make sure that inputs increase in size.
            speech_inputs = [
                list(itertools.chain.from_iterable(floats_list((length, self.feature_size))))
                for length in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
            ]

        if numpify:
            return [np.asarray(waveform) for waveform in speech_inputs]
        return speech_inputs

    def prepare_inputs_for_target(self, equal_length=False, numpify=False):
        """Create a batch of dummy mel-spectrogram targets of shape (length, num_mel_bins)."""
        if equal_length:
            speech_inputs = [floats_list((self.max_seq_length, self.num_mel_bins)) for _ in range(self.batch_size)]
        else:
            # Make sure that inputs increase in size.
            speech_inputs = [
                floats_list((length, self.num_mel_bins))
                for length in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
            ]

        if numpify:
            return [np.asarray(spectrogram) for spectrogram in speech_inputs]
        return speech_inputs
139

140

141
@require_torch
class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
    """Tests `SpeechT5FeatureExtractor` for both waveform inputs and mel-spectrogram targets."""

    feature_extraction_class = SpeechT5FeatureExtractor

    def setUp(self):
        self.feat_extract_tester = SpeechT5FeatureExtractionTester(self)

    def _check_zero_mean_unit_variance(self, input_vector):
        """Assert that `input_vector` is normalized to ~zero mean and ~unit variance."""
        self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
        self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))

    def test_call(self):
        # Tests that all call wrap to encode_plus and batch_encode_plus
        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        # create three inputs of length 800, 1000, and 1200
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

        # Test not batched input
        encoded_sequences_1 = feat_extract(speech_inputs[0], return_tensors="np").input_values
        encoded_sequences_2 = feat_extract(np_speech_inputs[0], return_tensors="np").input_values
        self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))

        # Test batched
        encoded_sequences_1 = feat_extract(speech_inputs, return_tensors="np").input_values
        encoded_sequences_2 = feat_extract(np_speech_inputs, return_tensors="np").input_values
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

    def test_zero_mean_unit_variance_normalization_np(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]

        paddings = ["longest", "max_length", "do_not_pad"]
        max_lengths = [None, 1600, None]
        for max_length, padding in zip(max_lengths, paddings):
            processed = feat_extract(speech_inputs, padding=padding, max_length=max_length, return_tensors="np")
            input_values = processed.input_values

            self._check_zero_mean_unit_variance(input_values[0][:800])
            self.assertTrue(input_values[0][800:].sum() < 1e-6)
            self._check_zero_mean_unit_variance(input_values[1][:1000])
            # BUGFIX: the padding region of the *second* example must be checked here
            # (the original checked input_values[0][1000:], re-testing the first example).
            self.assertTrue(input_values[1][1000:].sum() < 1e-6)
            self._check_zero_mean_unit_variance(input_values[2][:1200])

    def test_zero_mean_unit_variance_normalization(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        lengths = range(800, 1400, 200)
        speech_inputs = [floats_list((1, x))[0] for x in lengths]

        paddings = ["longest", "max_length", "do_not_pad"]
        max_lengths = [None, 1600, None]

        for max_length, padding in zip(max_lengths, paddings):
            processed = feat_extract(speech_inputs, max_length=max_length, padding=padding)
            input_values = processed.input_values

            self._check_zero_mean_unit_variance(input_values[0][:800])
            self._check_zero_mean_unit_variance(input_values[1][:1000])
            self._check_zero_mean_unit_variance(input_values[2][:1200])

    def test_zero_mean_unit_variance_normalization_trunc_np_max_length(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        processed = feat_extract(
            speech_inputs, truncation=True, max_length=1000, padding="max_length", return_tensors="np"
        )
        input_values = processed.input_values

        self._check_zero_mean_unit_variance(input_values[0, :800])
        self._check_zero_mean_unit_variance(input_values[1])
        self._check_zero_mean_unit_variance(input_values[2])

    def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        processed = feat_extract(
            speech_inputs, truncation=True, max_length=1000, padding="longest", return_tensors="np"
        )
        input_values = processed.input_values

        self._check_zero_mean_unit_variance(input_values[0, :800])
        self._check_zero_mean_unit_variance(input_values[1, :1000])
        self._check_zero_mean_unit_variance(input_values[2])

        # make sure that if max_length < longest -> then pad to max_length
        self.assertEqual(input_values.shape, (3, 1000))

        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        processed = feat_extract(
            speech_inputs, truncation=True, max_length=2000, padding="longest", return_tensors="np"
        )
        input_values = processed.input_values

        self._check_zero_mean_unit_variance(input_values[0, :800])
        self._check_zero_mean_unit_variance(input_values[1, :1000])
        self._check_zero_mean_unit_variance(input_values[2])

        # make sure that if max_length > longest -> then pad to longest
        self.assertEqual(input_values.shape, (3, 1200))

    def test_double_precision_pad(self):
        # Padding must downcast float64 inputs to float32 for both numpy and torch tensors.
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        np_speech_inputs = np.random.rand(100).astype(np.float64)
        py_speech_inputs = np_speech_inputs.tolist()

        for inputs in [py_speech_inputs, np_speech_inputs]:
            np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
            self.assertTrue(np_processed.input_values.dtype == np.float32)
            pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
            self.assertTrue(pt_processed.input_values.dtype == torch.float32)

    def test_call_target(self):
        # Tests that all call wrap to encode_plus and batch_encode_plus
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        # create three inputs of length 800, 1000, and 1200
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

        # Test feature size
        input_values = feature_extractor(audio_target=np_speech_inputs, padding=True, return_tensors="np").input_values
        self.assertTrue(input_values.ndim == 3)
        self.assertTrue(input_values.shape[-1] == feature_extractor.num_mel_bins)

        # Test not batched input
        encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="np").input_values
        encoded_sequences_2 = feature_extractor(np_speech_inputs[0], return_tensors="np").input_values
        self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))

        # Test batched
        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_values
        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_values
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

        # Test 2-D numpy arrays are batched.
        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
        np_speech_inputs = np.asarray(speech_inputs)
        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_values
        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_values
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

    def test_batch_feature_target(self):
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_target()
        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        input_name = feat_extract.model_input_names[0]

        processed_features = BatchFeature({input_name: speech_inputs})

        self.assertTrue(all(len(x) == len(y) for x, y in zip(speech_inputs, processed_features[input_name])))

        speech_inputs = self.feat_extract_tester.prepare_inputs_for_target(equal_length=True)
        processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="np")

        batch_features_input = processed_features[input_name]

        if len(batch_features_input.shape) < 3:
            batch_features_input = batch_features_input[:, :, None]

        self.assertEqual(
            batch_features_input.shape,
            (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.num_mel_bins),
        )

    @require_torch
    def test_batch_feature_target_pt(self):
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_target(equal_length=True)
        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        input_name = feat_extract.model_input_names[0]

        processed_features = BatchFeature({input_name: speech_inputs}, tensor_type="pt")

        batch_features_input = processed_features[input_name]

        if len(batch_features_input.shape) < 3:
            batch_features_input = batch_features_input[:, :, None]

        self.assertEqual(
            batch_features_input.shape,
            (self.feat_extract_tester.batch_size, len(speech_inputs[0]), self.feat_extract_tester.num_mel_bins),
        )

    @require_torch
    def test_padding_accepts_tensors_target_pt(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_target()
        input_name = feat_extract.model_input_names[0]

        processed_features = BatchFeature({input_name: speech_inputs})

        feat_extract.feature_size = feat_extract.num_mel_bins  # hack!

        input_np = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
        input_pt = feat_extract.pad(processed_features, padding="longest", return_tensors="pt")[input_name]

        self.assertTrue(abs(input_np.astype(np.float32).sum() - input_pt.numpy().astype(np.float32).sum()) < 1e-2)

    def test_attention_mask_target(self):
        feat_dict = self.feat_extract_dict
        feat_dict["return_attention_mask"] = True
        feat_extract = self.feature_extraction_class(**feat_dict)
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_target()
        input_lengths = [len(x) for x in speech_inputs]
        input_name = feat_extract.model_input_names[0]

        processed = BatchFeature({input_name: speech_inputs})

        feat_extract.feature_size = feat_extract.num_mel_bins  # hack!

        processed = feat_extract.pad(processed, padding="longest", return_tensors="np")
        self.assertIn("attention_mask", processed)
        self.assertListEqual(list(processed.attention_mask.shape), list(processed[input_name].shape[:2]))
        self.assertListEqual(processed.attention_mask.sum(-1).tolist(), input_lengths)

    def test_attention_mask_with_truncation_target(self):
        feat_dict = self.feat_extract_dict
        feat_dict["return_attention_mask"] = True
        feat_extract = self.feature_extraction_class(**feat_dict)
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_target()
        input_lengths = [len(x) for x in speech_inputs]
        input_name = feat_extract.model_input_names[0]

        processed = BatchFeature({input_name: speech_inputs})
        max_length = min(input_lengths)

        feat_extract.feature_size = feat_extract.num_mel_bins  # hack!

        processed_pad = feat_extract.pad(
            processed, padding="max_length", max_length=max_length, truncation=True, return_tensors="np"
        )
        self.assertIn("attention_mask", processed_pad)
        self.assertListEqual(
            list(processed_pad.attention_mask.shape), [processed_pad[input_name].shape[0], max_length]
        )
        self.assertListEqual(
            processed_pad.attention_mask[:, :max_length].sum(-1).tolist(), [max_length] * len(speech_inputs)
        )

    def _load_datasamples(self, num_samples):
        """Load `num_samples` decoded waveforms from the dummy LibriSpeech dataset."""
        from datasets import load_dataset

        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]

    def test_integration(self):
        # fmt: off
        EXPECTED_INPUT_VALUES = torch.tensor(
            [2.3804e-03, 2.0752e-03, 1.9836e-03, 2.1057e-03, 1.6174e-03,
             3.0518e-04, 9.1553e-05, 3.3569e-04, 9.7656e-04, 1.8311e-03,
             2.0142e-03, 2.1057e-03, 1.7395e-03, 4.5776e-04, -3.9673e-04,
             4.5776e-04, 1.0071e-03, 9.1553e-05, 4.8828e-04, 1.1597e-03,
             7.3242e-04, 9.4604e-04, 1.8005e-03, 1.8311e-03, 8.8501e-04,
             4.2725e-04, 4.8828e-04, 7.3242e-04, 1.0986e-03, 2.1057e-03]
        )
        # fmt: on

        input_speech = self._load_datasamples(1)
        feature_extractor = SpeechT5FeatureExtractor()
        input_values = feature_extractor(input_speech, return_tensors="pt").input_values
        # `assertEquals` is a deprecated alias removed in Python 3.12; use `assertEqual`.
        self.assertEqual(input_values.shape, (1, 93680))
        self.assertTrue(torch.allclose(input_values[0, :30], EXPECTED_INPUT_VALUES, atol=1e-6))

    def test_integration_target(self):
        # fmt: off
        EXPECTED_INPUT_VALUES = torch.tensor(
            [-2.6870, -3.0104, -3.1356, -3.5352, -3.0044, -3.0353, -3.4719, -3.6777,
             -3.1520, -2.9435, -2.6553, -2.8795, -2.9944, -2.5921, -3.0279, -3.0386,
             -3.0864, -3.1291, -3.2353, -2.7444, -2.6831, -2.7287, -3.1761, -3.1571,
             -3.2726, -3.0582, -3.1007, -3.4533, -3.4695, -3.0998]
        )
        # fmt: on

        input_speech = self._load_datasamples(1)
        feature_extractor = SpeechT5FeatureExtractor()
        input_values = feature_extractor(audio_target=input_speech, return_tensors="pt").input_values
        # `assertEquals` is a deprecated alias removed in Python 3.12; use `assertEqual`.
        self.assertEqual(input_values.shape, (1, 366, 80))
        self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
422

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.