# gradio · Fork (0) / test_utils.py · 455 lines · 15.2 KB
# (header scraped from the repository web page; not part of the test module)
from __future__ import annotations
2

3
import json
4
import os
5
import sys
6
import warnings
7
from pathlib import Path
8
from unittest.mock import MagicMock, patch
9

10
import pytest
11
from typing_extensions import Literal
12

13
from gradio import EventData, Request
14
from gradio.external_utils import format_ner_list
15
from gradio.utils import (
16
    abspath,
17
    append_unique_suffix,
18
    assert_configs_are_equivalent_besides_ids,
19
    check_function_inputs_match,
20
    colab_check,
21
    delete_none,
22
    diff,
23
    download_if_url,
24
    get_continuous_fn,
25
    get_extension_from_file_path_or_url,
26
    get_type_hints,
27
    ipython_check,
28
    is_in_or_equal,
29
    is_special_typed_parameter,
30
    kaggle_check,
31
    sagemaker_check,
32
    sanitize_list_for_csv,
33
    sanitize_value_for_csv,
34
    tex2svg,
35
    validate_url,
36
)
37

38
# Set the GRADIO_ANALYTICS_ENABLED flag to the string "False" for this test
# module — presumably opts the run out of Gradio's usage analytics.
os.environ.update(GRADIO_ANALYTICS_ENABLED="False")
39

40

41
class TestUtils:
    """Tests for the environment-detection helpers and ``download_if_url``."""

    @patch("IPython.get_ipython")
    def test_colab_check_no_ipython(self, mock_get_ipython):
        mock_get_ipython.return_value = None
        assert colab_check() is False

    @patch("IPython.get_ipython")
    def test_ipython_check_import_fail(self, mock_get_ipython):
        mock_get_ipython.side_effect = ImportError()
        assert ipython_check() is False

    @patch("IPython.get_ipython")
    def test_ipython_check_no_ipython(self, mock_get_ipython):
        mock_get_ipython.return_value = None
        assert ipython_check() is False

    def test_download_if_url_doesnt_crash_on_connection_error(self):
        # Every input below must be returned unchanged instead of raising.
        in_article = "placeholder"
        out_article = download_if_url(in_article)
        assert out_article == in_article

        # non-printable characters are not allowed in URL address
        in_article = "text\twith\rnon-printable\nASCII\x00characters"
        out_article = download_if_url(in_article)
        assert out_article == in_article

        # only files with HTTP(S) URL can be downloaded
        in_article = "ftp://localhost/tmp/index.html"
        out_article = download_if_url(in_article)
        assert out_article == in_article

        in_article = "file:///C:/tmp/index.html"
        out_article = download_if_url(in_article)
        assert out_article == in_article

        # this address will raise ValueError during parsing
        in_article = "https://[unmatched_bracket#?:@/index.html"
        out_article = download_if_url(in_article)
        assert out_article == in_article

    @pytest.mark.flaky
    def test_download_if_url_correct_parse(self):
        # Downloads a real page over the network, so it is marked flaky like the
        # other network-dependent test in this module (TestValidateURL).
        in_article = "https://github.com/gradio-app/gradio/blob/master/README.md"
        out_article = download_if_url(in_article)
        assert out_article != in_article

    def test_sagemaker_check_false(self):
        assert not sagemaker_check()

    def test_sagemaker_check_false_if_boto3_not_installed(self):
        # A ``None`` entry in sys.modules makes ``import boto3`` raise
        # ImportError. Only that entry is patched: clearing all of sys.modules
        # is unnecessary and would force any import triggered inside the block
        # to re-execute from scratch.
        with patch.dict(sys.modules, {"boto3": None}):
            assert not sagemaker_check()

    @patch("boto3.session.Session.client")
    def test_sagemaker_check_true(self, mock_client):
        mock_client().get_caller_identity = MagicMock(
            return_value={
                "Arn": "arn:aws:sts::67364438:assumed-role/SageMaker-Datascients/SageMaker"
            }
        )
        assert sagemaker_check()

    def test_kaggle_check_false(self):
        # Drop the Kaggle marker variables for the duration of the test so the
        # result does not depend on where the suite runs; patch.dict restores
        # the original environment on exit.
        with patch.dict(os.environ):
            os.environ.pop("KAGGLE_KERNEL_RUN_TYPE", None)
            os.environ.pop("GFOOTBALL_DATA_DIR", None)
            assert not kaggle_check()

    def test_kaggle_check_true_when_run_type_set(self):
        with patch.dict(
            os.environ, {"KAGGLE_KERNEL_RUN_TYPE": "Interactive"}, clear=True
        ):
            assert kaggle_check()

    def test_kaggle_check_true_when_both_set(self):
        with patch.dict(
            os.environ,
            {"KAGGLE_KERNEL_RUN_TYPE": "Interactive", "GFOOTBALL_DATA_DIR": "./"},
            clear=True,
        ):
            assert kaggle_check()

    def test_kaggle_check_false_when_neither_set(self):
        with patch.dict(
            os.environ,
            {"KAGGLE_KERNEL_RUN_TYPE": "", "GFOOTBALL_DATA_DIR": ""},
            clear=True,
        ):
            assert not kaggle_check()
126

127

128
def test_assert_configs_are_equivalent():
    """Configs that differ only in ids are equivalent; a different config raises."""
    test_dir = Path(__file__).parent / "test_files"

    def load_config(name):
        # JSON is UTF-8 by specification; passing the encoding explicitly keeps
        # the test independent of the platform's locale default (e.g. cp1252
        # on Windows).
        with open(test_dir / name, encoding="utf-8") as fp:
            return json.load(fp)

    xray_config = load_config("xray_config.json")
    xray_config_diff_ids = load_config("xray_config_diff_ids.json")
    xray_config_wrong = load_config("xray_config_wrong.json")

    assert assert_configs_are_equivalent_besides_ids(xray_config, xray_config)
    assert assert_configs_are_equivalent_besides_ids(xray_config, xray_config_diff_ids)
    with pytest.raises(ValueError):
        assert_configs_are_equivalent_besides_ids(xray_config, xray_config_wrong)
141

142

143
class TestFormatNERList:
    """``format_ner_list`` splits text into (span, label) pairs."""

    def test_format_ner_list_standard(self):
        text = "Wolfgang lives in Berlin"
        entities = [
            {"entity_group": "PER", "start": 0, "end": 8},
            {"entity_group": "LOC", "start": 18, "end": 24},
        ]
        # Unlabelled gaps (including empty leading/trailing ones) map to None.
        expected = [
            ("", None),
            ("Wolfgang", "PER"),
            (" lives in ", None),
            ("Berlin", "LOC"),
            ("", None),
        ]
        assert expected == format_ner_list(text, entities)

    def test_format_ner_list_empty(self):
        text = "I live in a city"
        # With no entities the whole text is a single unlabelled span.
        assert format_ner_list(text, []) == [(text, None)]
164

165

166
class TestDeleteNone:
    """Tests for ``delete_none``.

    Credit: https://stackoverflow.com/questions/33797126/proper-way-to-remove-keys-in-dictionary-with-none-values-in-python
    """

    def test_delete_none(self):
        # ``data`` rather than ``input`` so the builtin is not shadowed.
        data = {
            "a": 12,
            "b": 34,
            "c": None,
            "k": {
                "d": 34,
                "t": None,
                "m": [{"k": 23, "t": None}, [None, 1, 2, 3], {1, 2, None}],
                None: 123,
            },
        }
        # Only the top-level "c": None entry is dropped; None values (and the
        # None key) inside the nested dict are kept as-is.
        expected = {
            "a": 12,
            "b": 34,
            "k": {
                "d": 34,
                "t": None,
                "m": [{"k": 23, "t": None}, [None, 1, 2, 3], {1, 2, None}],
                None: 123,
            },
        }
        assert delete_none(data) == expected
192

193

194
class TestSanitizeForCSV:
    """``sanitize_*_for_csv`` prefix risky string cells with an apostrophe."""

    def test_unsafe_value(self):
        # Each of these is expected back with a single leading apostrophe.
        for raw in ("=OPEN()", "=1+2", '=1+2";=1+2'):
            assert sanitize_value_for_csv(raw) == "'" + raw

    def test_safe_value(self):
        # Numbers and these strings are returned unchanged.
        for value in (4, -44.44, "1+1=2", "1aaa2"):
            assert sanitize_value_for_csv(value) == value

    def test_list(self):
        flat = [4, "def=", "=gh+ij"]
        assert sanitize_list_for_csv(flat) == [4, "def=", "'=gh+ij"]

        nested = [["=abc", "def", "gh,+ij"], ["abc", "=def", "+ghij"]]
        assert sanitize_list_for_csv(nested) == [
            ["'=abc", "def", "'gh,+ij"],
            ["abc", "'=def", "'+ghij"],
        ]

        mixed = [1, ["ab", "=de"]]
        assert sanitize_list_for_csv(mixed) == [1, ["ab", "'=de"]]
212

213

214
class TestValidateURL:
    """``validate_url`` accepts reachable http(s) URLs and rejects local paths."""

    @pytest.mark.flaky
    def test_valid_urls(self):
        # These point at real hosts (presumably contacted over the network),
        # hence the flaky marker.
        urls = (
            "https://www.gradio.app",
            "http://gradio.dev",
            "https://upload.wikimedia.org/wikipedia/commons/b/b0/Bengal_tiger_%28Panthera_tigris_tigris%29_female_3_crop.jpg",
            "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png",
        )
        for url in urls:
            assert validate_url(url)

    def test_invalid_urls(self):
        for local_path in ("C:/Users/", "C:\\Users\\", "/home/user"):
            assert not validate_url(local_path)
230

231

232
class TestAppendUniqueSuffix:
    """``append_unique_suffix`` adds the first free ``_N`` suffix when needed."""

    def test_no_suffix(self):
        # The bare name is free, so it is returned untouched.
        taken = ["test_1", "test_2"]
        assert append_unique_suffix("test", taken) == "test"

    def test_first_suffix(self):
        taken = ["test", "test_-1"]
        assert append_unique_suffix("test", taken) == "test_1"

    def test_later_suffix(self):
        taken = ["test", "test_1", "test_2", "test_3"]
        assert append_unique_suffix("test", taken) == "test_4"
247

248

249
class TestAbspath:
    """``abspath`` returns absolute paths while keeping symlink components."""

    # NOTE(review): the relative paths below assume the suite runs from the
    # repository root — TODO confirm.

    def test_abspath_no_symlink(self):
        resolved_path = str(abspath("../gradio/gradio/test_data/lion.jpg"))
        assert ".." not in resolved_path

    @pytest.mark.skipif(
        sys.platform.startswith("win"),
        reason="Windows doesn't allow creation of sym links without administrative privileges",
    )
    def test_abspath_symlink_path(self):
        os.symlink("gradio/test_data", "gradio/test_link", True)
        try:
            resolved_path = str(abspath("../gradio/gradio/test_link/lion.jpg"))
        finally:
            # Remove the link even if abspath raises; a leftover link would make
            # os.symlink fail with FileExistsError on the next run.
            os.unlink("gradio/test_link")
        assert "test_link" in resolved_path

    @pytest.mark.skipif(
        sys.platform.startswith("win"),
        reason="Windows doesn't allow creation of sym links without administrative privileges",
    )
    def test_abspath_symlink_dir(self):
        os.symlink("gradio/test_data", "gradio/test_link", True)
        try:
            full_path = os.path.join(os.getcwd(), "gradio/test_link/lion.jpg")
            resolved_path = str(abspath(full_path))
        finally:
            # Same cleanup guarantee as in test_abspath_symlink_path.
            os.unlink("gradio/test_link")
        assert "test_link" in resolved_path
        # An already-absolute path is expected back verbatim.
        assert full_path == resolved_path
275

276

277
class TestGetTypeHints:
    """``get_type_hints`` and ``is_special_typed_parameter`` on various callables."""

    def test_get_type_hints(self):
        class CallableObject:
            def __call__(self, s: str):
                return s

        class Holder:
            def f(self, s: str):
                return s

        def plain(s: str):
            return s

        class GenericObject:
            pass

        # Callable instances, bound methods and plain functions all expose
        # exactly one hint (``self`` is not included).
        for candidate in (CallableObject(), Holder().f, plain):
            hints = get_type_hints(candidate)
            assert len(hints) == 1
            assert hints["s"] is str

        # A non-callable object yields no hints at all.
        assert not get_type_hints(GenericObject())

    def test_is_special_typed_parameter(self):
        def func(a: list[str], b: Literal["a", "b"], c, d: Request):
            pass

        hints = get_type_hints(func)
        for ordinary in ("a", "b", "c"):
            assert not is_special_typed_parameter(ordinary, hints)
        assert is_special_typed_parameter("d", hints)

    def test_is_special_typed_parameter_with_pipe(self):
        def func(a: Request, b: str | int, c: list[str]):
            pass

        hints = get_type_hints(func)
        assert is_special_typed_parameter("a", hints)
        for ordinary in ("b", "c"):
            assert not is_special_typed_parameter(ordinary, hints)
320

321

322
class TestCheckFunctionInputsMatch:
    """``check_function_inputs_match`` must stay silent for EventData params."""

    def test_check_function_inputs_match(self):
        class CallableObject:
            def __call__(self, s: str, evt: EventData):
                return s

        class Holder:
            def f(self, s: str, evt: EventData):
                return s

        def plain(s: str, evt: EventData):
            return s

        candidates = (CallableObject(), Holder().f, plain)

        # One input ([None]) is supplied for two parameters; the EventData
        # parameter is presumably not counted, so no warning may be emitted.
        # Escalating warnings to errors makes any warning fail the test.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            for candidate in candidates:
                check_function_inputs_match(candidate, [None], False)
342

343

344
class TestGetContinuousFn:
    """``get_continuous_fn`` wraps a callable into a repeating async stream."""

    @pytest.mark.asyncio
    async def test_get_continuous_fn(self):
        def int_return(x):  # plain function: its result is emitted each step
            return x + 1

        def int_yield(x):  # generator over an immutable int
            for _ in range(2):
                yield x
                x += 1

        def list_yield(x):  # generator that extends its list argument in place
            for _ in range(2):
                yield x
                x += [1]

        wrap_return, wrap_int, wrap_list = [
            get_continuous_fn(fn=source, every=0.01)
            for source in (int_return, int_yield, list_yield)
        ]
        return_stream = wrap_return(1)
        int_stream = wrap_int(1)  # Primitive
        list_stream = wrap_list([1])  # Reference

        # The plain function is simply called again with the same argument.
        for expected in (2, 2):
            assert await return_stream.__anext__() == expected
        # Once exhausted, the int generator restarts from the original value.
        for expected in (1, 2, 1):
            assert await int_stream.__anext__() == expected
        # The restarted generator receives the same, now-extended, list object.
        for expected in ([1], [1, 1], [1, 1, 1]):
            assert expected == await list_stream.__anext__()

    @pytest.mark.asyncio
    async def test_get_continuous_fn_with_async_function(self):
        async def async_int_return(x):  # coroutine counterpart of int_return
            return x + 1

        stream = get_continuous_fn(fn=async_int_return, every=0.01)(1)
        for expected in (2, 2):
            assert await stream.__anext__() == expected

    @pytest.mark.asyncio
    async def test_get_continuous_fn_with_async_generator(self):
        async def async_int_yield(x):  # async generator over an immutable int
            for _ in range(2):
                yield x
                x += 1

        async def async_list_yield(x):  # async generator extending its list
            for _ in range(2):
                yield x
                x += [1]

        wrap_int, wrap_list = [
            get_continuous_fn(fn=source, every=0.01)
            for source in (async_int_yield, async_list_yield)
        ]
        int_stream = wrap_int(1)  # Primitive
        list_stream = wrap_list([1])  # Reference

        for expected in (1, 2, 1):
            assert await int_stream.__anext__() == expected
        for expected in ([1], [1, 1], [1, 1, 1]):
            assert expected == await list_stream.__anext__()
407

408

409
def test_tex2svg_preserves_matplotlib_backend():
    """``tex2svg`` must leave the active matplotlib backend unchanged, even on error."""
    import matplotlib

    # Remember the session's backend so that switching to "svg" here does not
    # leak into tests that run afterwards.
    original_backend = matplotlib.get_backend()
    matplotlib.use("svg")
    try:
        tex2svg("1+1=2")
        assert matplotlib.get_backend() == "svg"
        with pytest.raises(
            Exception  # specifically a pyparsing.ParseException but not important here
        ):
            tex2svg("$$$1+1=2$$$")
        assert matplotlib.get_backend() == "svg"
    finally:
        matplotlib.use(original_backend)
420

421

422
def test_is_in_or_equal():
    """``is_in_or_equal(path, parent)`` holds when path equals or lies inside parent."""
    cases = [
        ("files/lion.jpg", "files/lion.jpg", True),
        ("files/lion.jpg", "files", True),
        ("files", "files/lion.jpg", False),
        ("/home/usr/notes.txt", "/home/usr/", True),
        ("/home/usr/subdirectory", "/home/usr/notes.txt", False),
        # ".." segments that climb out of the parent must not count as inside.
        ("/home/usr/../../etc/notes.txt", "/home/usr/", False),
        ("/safe_dir/subdir/../../unsafe_file.txt", "/safe_dir/", False),
    ]
    for path, parent, expected in cases:
        assert bool(is_in_or_equal(path, parent)) is expected
430

431

432
@pytest.mark.parametrize(
    ("path_or_url", "extension"),
    [
        ("https://example.com/avatar/xxxx.mp4?se=2023-11-16T06:51:23Z&sp=r", "mp4"),
        ("/home/user/documents/example.pdf", "pdf"),
        ("C:\\Users\\user\\documents\\example.png", "png"),
        ("C:/Users/user/documents/example", ""),
    ],
)
def test_get_extension_from_file_path_or_url(path_or_url, extension):
    """The extension comes back without a dot (empty if none), ignoring URL queries."""
    result = get_extension_from_file_path_or_url(path_or_url)
    assert result == extension
443

444

445
@pytest.mark.parametrize(
    ("old", "new", "expected_diff"),
    [
        ({"a": 1, "b": 2}, {"a": 1, "b": 2}, []),
        ({}, {"a": 1, "b": 2}, [("add", ["a"], 1), ("add", ["b"], 2)]),
        (["a", "b"], {"a": 1, "b": 2}, [("replace", [], {"a": 1, "b": 2})]),
        ("abc", "abcdef", [("append", [], "def")]),
    ],
)
def test_diff(old, new, expected_diff):
    """``diff`` yields (operation, path, value) edits that turn ``old`` into ``new``."""
    edits = diff(old, new)
    assert edits == expected_diff
456

# --- Cookie notice scraped from the GitVerse web page (not part of the test module) ---
# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal data in order to improve our website and the GitVerse Service and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.