# gradio / test_processing_utils.py
# Repository page header: Fork 0 · 379 lines · 14.2 KB
import contextlib
import os
import shutil
import tempfile
from copy import deepcopy
from pathlib import Path
from unittest.mock import patch

import ffmpy
import numpy as np
import pytest
from gradio_client import media_data
from PIL import Image, ImageCms

from gradio import processing_utils, utils


class TestTempFileManagement:
    """Tests for file hashing and for saving files, base64 data and URLs to the cache."""

    @staticmethod
    def _count_files(directory):
        """Return how many regular files exist anywhere under ``directory``."""
        return sum(1 for path in directory.glob("**/*") if path.is_file())

    def test_hash_file(self):
        h1 = processing_utils.hash_file("gradio/test_data/cheetah1.jpg")
        h2 = processing_utils.hash_file("gradio/test_data/cheetah1-copy.jpg")
        h3 = processing_utils.hash_file("gradio/test_data/cheetah2.jpg")
        # The copy hashes like the original; a different image does not.
        assert h1 == h2
        assert h1 != h3

    def test_make_temp_copy_if_needed(self, gradio_temp_dir):
        f = processing_utils.save_file_to_cache(
            "gradio/test_data/cheetah1.jpg", cache_dir=gradio_temp_dir
        )
        # Remove the freshly cached file so the next save has to write it again.
        with contextlib.suppress(OSError):
            os.remove(f)

        f = processing_utils.save_file_to_cache(
            "gradio/test_data/cheetah1.jpg", cache_dir=gradio_temp_dir
        )
        assert self._count_files(gradio_temp_dir) == 1
        assert Path(f).name == "cheetah1.jpg"

        # Saving the same file again does not add a second cached file.
        processing_utils.save_file_to_cache(
            "gradio/test_data/cheetah1.jpg", cache_dir=gradio_temp_dir
        )
        assert self._count_files(gradio_temp_dir) == 1

        # The copy has the same hash (see test_hash_file) but a different name,
        # and still gets its own cached file.
        f = processing_utils.save_file_to_cache(
            "gradio/test_data/cheetah1-copy.jpg", cache_dir=gradio_temp_dir
        )
        assert self._count_files(gradio_temp_dir) == 2
        assert Path(f).name == "cheetah1-copy.jpg"

    def test_save_b64_to_cache(self, gradio_temp_dir):
        base64_file_1 = media_data.BASE64_IMAGE
        base64_file_2 = media_data.BASE64_AUDIO["data"]

        f = processing_utils.save_base64_to_cache(
            base64_file_1, cache_dir=gradio_temp_dir
        )
        # Remove the freshly cached file so the next save has to write it again.
        with contextlib.suppress(OSError):
            os.remove(f)

        processing_utils.save_base64_to_cache(
            base64_file_1, cache_dir=gradio_temp_dir
        )
        assert self._count_files(gradio_temp_dir) == 1

        # Same payload again: no additional file.
        processing_utils.save_base64_to_cache(
            base64_file_1, cache_dir=gradio_temp_dir
        )
        assert self._count_files(gradio_temp_dir) == 1

        # Different payload: a second file.
        processing_utils.save_base64_to_cache(
            base64_file_2, cache_dir=gradio_temp_dir
        )
        assert self._count_files(gradio_temp_dir) == 2

    @pytest.mark.flaky
    def test_save_url_to_cache(self, gradio_temp_dir):
        url1 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/test_image.png"
        url2 = "https://raw.githubusercontent.com/gradio-app/gradio/main/gradio/test_data/cheetah1.jpg"

        f = processing_utils.save_url_to_cache(url1, cache_dir=gradio_temp_dir)
        # Remove the freshly cached file so the next download has to write it again.
        with contextlib.suppress(OSError):
            os.remove(f)

        processing_utils.save_url_to_cache(url1, cache_dir=gradio_temp_dir)
        assert self._count_files(gradio_temp_dir) == 1

        # Same URL again: no additional file.
        processing_utils.save_url_to_cache(url1, cache_dir=gradio_temp_dir)
        assert self._count_files(gradio_temp_dir) == 1

        # Different URL: a second file.
        processing_utils.save_url_to_cache(url2, cache_dir=gradio_temp_dir)
        assert self._count_files(gradio_temp_dir) == 2

    @pytest.mark.flaky  # downloads over the network, like test_save_url_to_cache
    def test_save_url_to_cache_with_spaces(self, gradio_temp_dir):
        url = "https://huggingface.co/datasets/freddyaboulton/gradio-reviews/resolve/main00015-20230906102032-7778-Wonderwoman VintageMagStyle   _lora_SDXL-VintageMagStyle-Lora_1_, Very detailed, clean, high quality, sharp image.jpg"
        processing_utils.save_url_to_cache(url, cache_dir=gradio_temp_dir)
        assert self._count_files(gradio_temp_dir) == 1

    @pytest.mark.flaky  # downloads over the network, like test_save_url_to_cache
    def test_save_url_to_cache_with_redirect(self, gradio_temp_dir):
        url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png"
        processing_utils.save_url_to_cache(url, cache_dir=gradio_temp_dir)
        assert self._count_files(gradio_temp_dir) == 1


class TestImagePreprocessing:
    """Tests for image/plot base64 encoding and for caching PIL images and arrays."""

    def test_encode_plot_to_base64(self):
        with utils.MatplotlibBackendMananger():
            import matplotlib.pyplot as plt

            plt.plot([1, 2, 3, 4])
            output_base64 = processing_utils.encode_plot_to_base64(plt)
        # Expect a non-empty base64 data URI: "data:image/<format>;base64,<payload>".
        header, separator, payload = output_base64.partition(",")
        assert separator == ","
        assert header.startswith("data:image/")
        assert header.endswith(";base64")
        assert payload

    def test_encode_array_to_base64(self):
        img = Image.open("gradio/test_data/test_image.png")
        img = img.convert("RGB")
        numpy_data = np.asarray(img, dtype=np.uint8)
        output_base64 = processing_utils.encode_array_to_base64(numpy_data)
        assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)

    def test_encode_pil_to_base64(self):
        img = Image.open("gradio/test_data/test_image.png")
        img = img.convert("RGB")
        img.info = {}  # Strip metadata
        output_base64 = processing_utils.encode_pil_to_base64(img)
        # Same expected value as the numpy-array encoding test.
        assert output_base64 == deepcopy(media_data.ARRAY_TO_BASE64_IMAGE)

    def test_save_pil_to_file_keeps_pnginfo(self, gradio_temp_dir):
        input_img = Image.open("gradio/test_data/test_image.png")
        input_img = input_img.convert("RGB")
        input_img.info = {"key1": "value1", "key2": "value2"}
        input_img.save(gradio_temp_dir / "test_test_image.png")

        file_obj = processing_utils.save_pil_to_cache(
            input_img, cache_dir=gradio_temp_dir
        )
        output_img = Image.open(file_obj)

        # The image's ``info`` metadata survives the round trip through the cache.
        assert output_img.info == input_img.info

    def test_np_pil_encode_to_the_same(self, gradio_temp_dir):
        arr = np.random.randint(0, 255, size=(100, 100, 3), dtype=np.uint8)
        pil = Image.fromarray(arr)
        # The same pixels cached via PIL or via numpy end up at the same path.
        assert processing_utils.save_pil_to_cache(
            pil, cache_dir=gradio_temp_dir
        ) == processing_utils.save_img_array_to_cache(arr, cache_dir=gradio_temp_dir)

    def test_encode_pil_to_temp_file_metadata_color_profile(self, gradio_temp_dir):
        # Read image
        img = Image.open("gradio/test_data/test_image.png")
        img_metadata = Image.open("gradio/test_data/test_image.png")
        img_metadata.info = {"key1": "value1", "key2": "value2"}

        # Creating sRGB profile
        profile = ImageCms.createProfile("sRGB")
        profile2 = ImageCms.ImageCmsProfile(profile)
        img.save(
            gradio_temp_dir / "img_color_profile.png", icc_profile=profile2.tobytes()
        )
        img_cp1 = Image.open(str(gradio_temp_dir / "img_color_profile.png"))

        # Creating XYZ profile
        profile = ImageCms.createProfile("XYZ")
        profile2 = ImageCms.ImageCmsProfile(profile)
        img.save(
            gradio_temp_dir / "img_color_profile_2.png", icc_profile=profile2.tobytes()
        )
        img_cp2 = Image.open(str(gradio_temp_dir / "img_color_profile_2.png"))

        img_path = processing_utils.save_pil_to_cache(img, cache_dir=gradio_temp_dir)
        img_metadata_path = processing_utils.save_pil_to_cache(
            img_metadata, cache_dir=gradio_temp_dir
        )
        img_cp1_path = processing_utils.save_pil_to_cache(
            img_cp1, cache_dir=gradio_temp_dir
        )
        img_cp2_path = processing_utils.save_pil_to_cache(
            img_cp2, cache_dir=gradio_temp_dir
        )
        # Same pixels with different metadata or ICC profiles must not collide
        # in the cache: four distinct paths.
        assert len({img_path, img_metadata_path, img_cp1_path, img_cp2_path}) == 4

    def test_resize_and_crop(self):
        img = Image.open("gradio/test_data/test_image.png")
        new_img = processing_utils.resize_and_crop(img, (20, 20))
        assert new_img.size == (20, 20)
        # An unknown crop_type is rejected.
        with pytest.raises(ValueError):
            processing_utils.resize_and_crop(img=img, size=(20, 20), crop_type="test")


class TestAudioPreprocessing:
    """Tests for reading, writing and 16-bit conversion of audio data."""

    def test_audio_from_file(self):
        audio = processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
        # First element: the fixture's sample rate; second: the samples array.
        assert audio[0] == 22050
        assert isinstance(audio[1], np.ndarray)

    def test_audio_to_file(self, tmp_path):
        audio = processing_utils.audio_from_file("gradio/test_data/test_audio.wav")
        # Write into pytest's per-test temp dir instead of the CWD, so nothing
        # is left behind even when the assertion fails.
        output_path = tmp_path / "test_audio_to_file"
        processing_utils.audio_to_file(audio[0], audio[1], str(output_path))
        assert output_path.exists()

    def test_convert_to_16_bit_wav(self):
        # Generate a random audio sample and set the amplitude
        audio = np.random.randint(-100, 100, size=100, dtype="int16")
        audio[0] = -32767
        audio[1] = 32766

        # float64, float32 and already-int16 inputs must all come back as
        # int16 with the original values.
        for candidate in (audio.astype("float64"), audio.astype("float32"), audio):
            converted = processing_utils.convert_to_16_bit_wav(candidate)
            assert np.allclose(audio, converted)
            assert converted.dtype == "int16"


class TestOutputPreprocessing:
    """Tests for ``processing_utils._convert`` between floating-point dtypes."""

    # Distinct spellings of float dtypes: Python type, numpy aliases, strings.
    float_dtype_list = [
        float,
        np.double,
        np.single,
        np.float32,
        np.float64,
        "float32",
        "float64",
    ]

    def test_float_conversion_dtype(self):
        """Test any conversion from a float dtype to an other."""
        x = np.array([-1, 1])
        # Every (input dtype, output dtype) pair; each input is cast from the
        # pristine integer array rather than from the previous iteration.
        for dtype_in in TestOutputPreprocessing.float_dtype_list:
            x_in = x.astype(dtype_in)
            for dtype_out in TestOutputPreprocessing.float_dtype_list:
                y = processing_utils._convert(x_in, dtype_out)
                assert y.dtype == np.dtype(dtype_out)

    def test_subclass_conversion(self):
        """Converting to the abstract ``np.floating`` keeps the input's float dtype."""
        x = np.array([-1, 1])
        for dtype in TestOutputPreprocessing.float_dtype_list:
            x_in = x.astype(dtype)
            y = processing_utils._convert(x_in, np.floating)
            assert y.dtype == x_in.dtype


class TestVideoProcessing:
    """Tests for ``video_is_playable`` and ``convert_video_to_playable_mp4``."""

    def test_video_has_playable_codecs(self, test_file_dir):
        assert processing_utils.video_is_playable(
            str(test_file_dir / "video_sample.mp4")
        )
        assert processing_utils.video_is_playable(
            str(test_file_dir / "video_sample.ogg")
        )
        assert processing_utils.video_is_playable(
            str(test_file_dir / "video_sample.webm")
        )
        assert not processing_utils.video_is_playable(
            str(test_file_dir / "bad_video_sample.mp4")
        )

    # Deliberately a plain function (no ``self``): the class body below uses it
    # directly as a mock ``side_effect`` in ``parametrize`` and ``@patch``.
    def raise_ffmpy_runtime_exception(*args, **kwargs):
        raise ffmpy.FFRuntimeError("", "", "", "")

    @pytest.mark.parametrize(
        "exception_to_raise", [raise_ffmpy_runtime_exception, KeyError(), IndexError()]
    )
    def test_video_has_playable_codecs_catches_exceptions(
        self, exception_to_raise, test_file_dir
    ):
        # The temporary directory and the copied video inside it are removed
        # when the ``with`` block exits.
        with patch(
            "ffmpy.FFprobe.run", side_effect=exception_to_raise
        ), tempfile.TemporaryDirectory() as tmp_dir:
            not_playable_vid = str(Path(tmp_dir) / "out.avi")
            shutil.copy(str(test_file_dir / "bad_video_sample.mp4"), not_playable_vid)
            # When probing raises any of these errors, the video is reported
            # as playable.
            assert processing_utils.video_is_playable(not_playable_vid)

    def test_convert_video_to_playable_mp4(self, test_file_dir):
        # The input copy and whatever the conversion writes next to it are
        # removed together with the temporary directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            not_playable_vid = str(Path(tmp_dir) / "out.avi")
            shutil.copy(str(test_file_dir / "bad_video_sample.mp4"), not_playable_vid)
            with patch("os.remove", wraps=os.remove) as mock_remove:
                playable_vid = processing_utils.convert_video_to_playable_mp4(
                    not_playable_vid
                )
            # check tempfile got deleted
            assert not Path(mock_remove.call_args[0][0]).exists()
            assert processing_utils.video_is_playable(playable_vid)

    @patch("ffmpy.FFmpeg.run", side_effect=raise_ffmpy_runtime_exception)
    def test_video_conversion_returns_original_video_if_fails(
        self, mock_run, test_file_dir
    ):
        with tempfile.TemporaryDirectory() as tmp_dir:
            not_playable_vid = str(Path(tmp_dir) / "out.avi")
            shutil.copy(str(test_file_dir / "bad_video_sample.mp4"), not_playable_vid)
            playable_vid = processing_utils.convert_video_to_playable_mp4(
                not_playable_vid
            )
            # If the conversion succeeded it'd be .mp4
            assert Path(playable_vid).suffix == ".avi"


def test_add_root_url():
    """``add_root_url`` prefixes relative FileData urls with the root url.

    A url that already carries the previous root gets the new root instead,
    and absolute urls are left untouched.
    """

    def file_data(path, url):
        # Fresh dict on every call, so inputs and expected values never share
        # nested objects.
        return {"path": path, "url": url, "meta": {"_type": "gradio.FileData"}}

    root_url = "http://localhost:7860"
    external_url = "https://www.gradio.app"

    data = {
        "file": file_data("path", "/file=path"),
        "file2": file_data("path2", external_url),
    }
    expected = {
        "file": file_data("path", f"{root_url}/file=path"),
        "file2": file_data("path2", external_url),
    }
    assert processing_utils.add_root_url(data, root_url, None) == expected

    new_root_url = "https://1234.gradio.live"
    new_expected = {
        "file": file_data("path", f"{new_root_url}/file=path"),
        "file2": file_data("path2", external_url),
    }
    assert (
        processing_utils.add_root_url(expected, new_root_url, root_url) == new_expected
    )

# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you give JSC "SberTech" consent to the processing of
# your personal data in order to improve our website and the GitVerse Service
# and to make them more convenient to use.
#
# You can disable the use of cookies yourself in your browser settings.