# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import itertools
import os
import random
import tempfile
import unittest
import unittest.mock  # needed explicitly for the @unittest.mock.patch decorator below

import numpy as np
from datasets import load_dataset

from transformers import WhisperFeatureExtractor
from transformers.testing_utils import check_json_file_has_correct_format, require_torch
from transformers.utils.import_utils import is_torch_available

from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin


if is_torch_available():
    import torch

global_rng = random.Random()

def floats_list(shape, scale=1.0, rng=None, name=None):
    """Creates a nested list of random floats with the given 2-D shape."""
    if rng is None:
        rng = global_rng

    values = []
    for batch_idx in range(shape[0]):
        values.append([])
        for _ in range(shape[1]):
            values[-1].append(rng.random() * scale)

    return values
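
# For illustration: floats_list((2, 3)) returns something like
# [[0.64, 0.02, 0.77], [0.13, 0.58, 0.30]], i.e. shape[0] rows of shape[1]
# uniform floats in [0, scale); the exact values depend on the RNG state.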


class WhisperFeatureExtractionTester(unittest.TestCase):
    def __init__(
        self,
        parent,
        batch_size=7,
        min_seq_length=400,
        max_seq_length=2000,
        feature_size=10,
        hop_length=160,
        chunk_length=8,
        padding_value=0.0,
        sampling_rate=4_000,
        return_attention_mask=False,
        do_normalize=True,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.min_seq_length = min_seq_length
        self.max_seq_length = max_seq_length
        self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
        self.padding_value = padding_value
        self.sampling_rate = sampling_rate
        self.return_attention_mask = return_attention_mask
        self.do_normalize = do_normalize
        self.feature_size = feature_size
        self.chunk_length = chunk_length
        self.hop_length = hop_length

    def prepare_feat_extract_dict(self):
        return {
            "feature_size": self.feature_size,
            "hop_length": self.hop_length,
            "chunk_length": self.chunk_length,
            "padding_value": self.padding_value,
            "sampling_rate": self.sampling_rate,
            "return_attention_mask": self.return_attention_mask,
            "do_normalize": self.do_normalize,
        }

    def prepare_inputs_for_common(self, equal_length=False, numpify=False):
        def _flatten(list_of_lists):
            return list(itertools.chain(*list_of_lists))

        if equal_length:
            speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
        else:
            # make sure that inputs increase in size
            speech_inputs = [
                floats_list((x, self.feature_size))
                for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
            ]
        if numpify:
            speech_inputs = [np.asarray(x) for x in speech_inputs]
        return speech_inputs
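
# With the defaults above, prepare_inputs_for_common yields batch_size=7
# inputs whose lengths step from 400 up to 1996 in increments of 266, so the
# shared mixin tests exercise both ragged batching and padding.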


class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
    feature_extraction_class = WhisperFeatureExtractor

    def setUp(self):
        self.feat_extract_tester = WhisperFeatureExtractionTester(self)

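    # Round-trip check: save_pretrained followed by from_pretrained must
    # preserve both the serialized config and the precomputed mel filter bank.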
    def test_feat_extract_from_and_save_pretrained(self):
        feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
            check_json_file_has_correct_format(saved_file)
            feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)

        dict_first = feat_extract_first.to_dict()
        dict_second = feat_extract_second.to_dict()
        mel_1 = feat_extract_first.mel_filters
        mel_2 = feat_extract_second.mel_filters
        self.assertTrue(np.allclose(mel_1, mel_2))
        self.assertEqual(dict_first, dict_second)

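    # The same round trip as above, but through an explicit JSON file via
    # to_json_file/from_json_file.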
    def test_feat_extract_to_json_file(self):
        feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            json_file_path = os.path.join(tmpdirname, "feat_extract.json")
            feat_extract_first.to_json_file(json_file_path)
            feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)

        dict_first = feat_extract_first.to_dict()
        dict_second = feat_extract_second.to_dict()
        mel_1 = feat_extract_first.mel_filters
        mel_2 = feat_extract_second.mel_filters
        self.assertTrue(np.allclose(mel_1, mel_2))
        self.assertEqual(dict_first, dict_second)

    def test_call(self):
        # Tests that __call__ returns consistent features for list, numpy,
        # batched, and over-long (truncated) inputs
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        # create three inputs of length 800, 1000, and 1200
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

        # Test feature size
        input_features = feature_extractor(np_speech_inputs, padding="max_length", return_tensors="np").input_features
        self.assertTrue(input_features.ndim == 3)
        self.assertTrue(input_features.shape[-1] == feature_extractor.nb_max_frames)
        self.assertTrue(input_features.shape[-2] == feature_extractor.feature_size)

        # Test not batched input
        encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="np").input_features
        encoded_sequences_2 = feature_extractor(np_speech_inputs[0], return_tensors="np").input_features
        self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))

        # Test batched
        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

        # Test 2-D numpy arrays are batched.
        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
        np_speech_inputs = np.asarray(speech_inputs)
        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

        # Test truncation required
        speech_inputs = [floats_list((1, x))[0] for x in range(200, (feature_extractor.n_samples + 500), 200)]
        np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

        speech_inputs_truncated = [x[: feature_extractor.n_samples] for x in speech_inputs]
        np_speech_inputs_truncated = [np.asarray(speech_input) for speech_input in speech_inputs_truncated]

        encoded_sequences_1 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
        encoded_sequences_2 = feature_extractor(np_speech_inputs_truncated, return_tensors="np").input_features
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
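
    # Note: WhisperFeatureExtractor always pads or truncates audio to a fixed
    # window of n_samples = chunk_length * sampling_rate samples (30 s at
    # 16 kHz, i.e. 480_000, with the library defaults), which is why the
    # over-long inputs above match their manually truncated counterparts.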
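    # pad() must never emit double precision: float64 inputs come back as
    # float32, whether returned as numpy arrays or torch tensors.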
    @require_torch
    def test_double_precision_pad(self):
        import torch

        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
        py_speech_inputs = np_speech_inputs.tolist()

        for inputs in [py_speech_inputs, np_speech_inputs]:
            np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
            self.assertTrue(np_processed.input_features.dtype == np.float32)
            pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
            self.assertTrue(pt_processed.input_features.dtype == torch.float32)

    def _load_datasamples(self, num_samples):
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]

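    # Golden-value test: the first 30 values of the first mel channel are
    # compared against constants precomputed with the default extractor.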
    @require_torch
    def test_torch_integration(self):
        # fmt: off
        EXPECTED_INPUT_FEATURES = torch.tensor(
            [
                0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
                0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
                0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
                -0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
            ]
        )
        # fmt: on

        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="pt").input_features
        self.assertEqual(input_features.shape, (1, 80, 3000))
        self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))

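    # Patching is_torch_available to False forces the pure-numpy STFT code
    # path; the expected values are identical to the torch test above.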
    @unittest.mock.patch("transformers.models.whisper.feature_extraction_whisper.is_torch_available", lambda: False)
    def test_numpy_integration(self):
        # fmt: off
        EXPECTED_INPUT_FEATURES = np.array(
            [
                0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
                0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
                0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
                -0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
            ]
        )
        # fmt: on

        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="np").input_features
        self.assertEqual(input_features.shape, (1, 80, 3000))
        self.assertTrue(np.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))

    def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        audio = self._load_datasamples(1)[0]
        audio = ((audio - audio.min()) / (audio.max() - audio.min())) * 65535  # Rescale to [0, 65535] to simulate an unnormalized 16-bit signal
        audio = feat_extract.zero_mean_unit_var_norm([audio], attention_mask=None)[0]

        self.assertTrue(np.abs(np.mean(audio)) < 1e-3)
        self.assertTrue(np.abs(np.var(audio) - 1) < 1e-3)