1
from __future__ import annotations
7
from pathlib import Path
8
from unittest.mock import MagicMock, patch
11
from typing_extensions import Literal
13
from gradio import EventData, Request
14
from gradio.external_utils import format_ner_list
15
from gradio.utils import (
18
assert_configs_are_equivalent_besides_ids,
19
check_function_inputs_match,
25
get_extension_from_file_path_or_url,
29
is_special_typed_parameter,
32
sanitize_list_for_csv,
33
sanitize_value_for_csv,
38
# Opt this test module out of Gradio analytics via the GRADIO_ANALYTICS_ENABLED
# environment variable before any test runs.
os.environ.update({"GRADIO_ANALYTICS_ENABLED": "False"})
42
@patch("IPython.get_ipython")
def test_colab_check_no_ipython(self, mock_get_ipython):
    """colab_check() is exactly False when get_ipython() returns None."""
    mock_get_ipython.return_value = None
    in_colab = colab_check()
    assert in_colab is False
47
@patch("IPython.get_ipython")
def test_ipython_check_import_fail(self, mock_get_ipython):
    """ipython_check() is exactly False when calling get_ipython() raises ImportError."""
    mock_get_ipython.side_effect = ImportError()
    in_ipython = ipython_check()
    assert in_ipython is False
52
@patch("IPython.get_ipython")
def test_ipython_check_no_ipython(self, mock_get_ipython):
    """ipython_check() is exactly False when get_ipython() returns None."""
    mock_get_ipython.return_value = None
    in_ipython = ipython_check()
    assert in_ipython is False
57
def test_download_if_url_doesnt_crash_on_connection_error(self):
    """download_if_url() hands back its argument unchanged for anything it cannot fetch."""
    unfetchable_inputs = [
        # a plain string that is not a URL at all
        "placeholder",
        # non-printable characters are not allowed in URL address
        "text\twith\rnon-printable\nASCII\x00characters",
        # only files with HTTP(S) URL can be downloaded
        "ftp://localhost/tmp/index.html",
        "file:///C:/tmp/index.html",
        # this address will raise ValueError during parsing
        "https://[unmatched_bracket#?:@/index.html",
    ]
    for article in unfetchable_inputs:
        assert download_if_url(article) == article
81
def test_download_if_url_correct_parse(self):
    """A well-formed HTTPS address is not returned verbatim by download_if_url().

    Presumably this needs network access to reach github.com.
    """
    readme_url = "https://github.com/gradio-app/gradio/blob/master/README.md"
    assert download_if_url(readme_url) != readme_url
86
def test_sagemaker_check_false(self):
    """sagemaker_check() is falsy in the (presumably non-SageMaker) test environment."""
    on_sagemaker = sagemaker_check()
    assert not on_sagemaker
89
def test_sagemaker_check_false_if_boto3_not_installed(self):
    """sagemaker_check() is falsy when boto3 cannot be imported."""
    # A None entry in sys.modules makes `import boto3` raise ImportError;
    # clear=True empties sys.modules for the duration of the block.
    with patch.dict(sys.modules, {"boto3": None}, clear=True):
        on_sagemaker = sagemaker_check()
        assert not on_sagemaker
93
# Simulates running under a SageMaker role: boto3's Session.client factory is
# patched and get_caller_identity() replaced by a MagicMock (presumably returning
# a dict whose "Arn" names a SageMaker assumed role); sagemaker_check() must then
# be truthy.
# NOTE(review): original lines 96 and 98-99 are missing from this view, so the
# MagicMock(...) call below is truncated — restore it from the full file.
@patch("boto3.session.Session.client")
94
def test_sagemaker_check_true(self, mock_client):
95
mock_client().get_caller_identity = MagicMock(
97
"Arn": "arn:aws:sts::67364438:assumed-role/SageMaker-Datascients/SageMaker"
100
assert sagemaker_check()
102
def test_kaggle_check_false(self):
    """kaggle_check() is falsy in the (presumably non-Kaggle) test environment."""
    on_kaggle = kaggle_check()
    assert not on_kaggle
105
# kaggle_check() must be truthy when KAGGLE_KERNEL_RUN_TYPE is set.
# NOTE(review): original lines 106 and 108 are missing from this view; the
# argument line below presumably belongs to a `with patch.dict(...)` that
# replaces os.environ (clear=True) around the assertion — confirm.
def test_kaggle_check_true_when_run_type_set(self):
107
os.environ, {"KAGGLE_KERNEL_RUN_TYPE": "Interactive"}, clear=True
109
assert kaggle_check()
111
# kaggle_check() must be truthy when both KAGGLE_KERNEL_RUN_TYPE and
# GFOOTBALL_DATA_DIR are set.
# NOTE(review): original lines 112-113 and 115-116 are missing from this view;
# the dict below is presumably the replacement environment passed to a
# `with patch.dict(os.environ, ..., clear=True)` — confirm.
def test_kaggle_check_true_when_both_set(self):
114
{"KAGGLE_KERNEL_RUN_TYPE": "Interactive", "GFOOTBALL_DATA_DIR": "./"},
117
assert kaggle_check()
119
# kaggle_check() must be falsy when both variables are present but empty,
# i.e. an empty string counts as "not set".
# NOTE(review): original lines 120-121 and 123-124 are missing from this view;
# the dict below is presumably the replacement environment passed to a
# `with patch.dict(os.environ, ..., clear=True)` — confirm.
def test_kaggle_check_false_when_neither_set(self):
122
{"KAGGLE_KERNEL_RUN_TYPE": "", "GFOOTBALL_DATA_DIR": ""},
125
assert not kaggle_check()
128
def test_assert_configs_are_equivalent():
    """Check assert_configs_are_equivalent_besides_ids against three JSON fixtures.

    A config must compare equal to itself and to the ``_diff_ids`` variant
    (which, per its name, differs only in ids), while comparing against the
    ``_wrong`` variant must raise ``ValueError``.
    """
    test_dir = Path(__file__).parent / "test_files"

    def load_config(filename):
        # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
        # the platform's locale default, which differs between systems.
        with open(test_dir / filename, encoding="utf-8") as fp:
            return json.load(fp)

    xray_config = load_config("xray_config.json")
    xray_config_diff_ids = load_config("xray_config_diff_ids.json")
    xray_config_wrong = load_config("xray_config_wrong.json")

    assert assert_configs_are_equivalent_besides_ids(xray_config, xray_config)
    assert assert_configs_are_equivalent_besides_ids(xray_config, xray_config_diff_ids)
    with pytest.raises(ValueError):
        assert_configs_are_equivalent_besides_ids(xray_config, xray_config_wrong)
143
# Tests for format_ner_list(string, groups). It turns entity spans
# ({"entity_group", "start", "end"}) into a list of (text, label) tuples,
# with None as the label for unlabelled text.
# NOTE(review): original lines 146, 149-152, 154-156 (the `groups = [` wrapper and
# most of `result`) and 161 (presumably `groups = []`) are missing from this view.
class TestFormatNERList:
144
def test_format_ner_list_standard(self):
145
string = "Wolfgang lives in Berlin"
147
# start/end are end-exclusive character offsets:
# string[0:8] == "Wolfgang", string[18:24] == "Berlin".
{"entity_group": "PER", "start": 0, "end": 8},
148
{"entity_group": "LOC", "start": 18, "end": 24},
153
(" lives in ", None),
157
assert format_ner_list(string, groups) == result
159
# With no entity spans the whole string comes back as one unlabelled chunk.
def test_format_ner_list_empty(self):
160
string = "I live in a city"
162
result = [("I live in a city", None)]
163
assert format_ner_list(string, groups) == result
167
# Tests for delete_none(), which presumably strips None-valued keys from a nested
# dict (per the credited StackOverflow question).
# NOTE(review): the class header (original ~166) and most of the `input`/`truth`
# literals (original 170-176, 178-186, 188-190) are missing from this view.
"""Credit: https://stackoverflow.com/questions/33797126/proper-way-to-remove-keys-in-dictionary-with-none-values-in-python"""
169
def test_delete_none(self):
177
# The "m" value is identical in the input and the expected output, which
# suggests None elements inside lists/sets (and dicts nested in lists) are left
# untouched — confirm once the missing lines are restored.
"m": [{"k": 23, "t": None}, [None, 1, 2, 3], {1, 2, None}],
187
"m": [{"k": 23, "t": None}, [None, 1, 2, 3], {1, 2, None}],
191
assert delete_none(input) == truth
194
# Tests for CSV value sanitization: string cells that look like spreadsheet
# formulas (e.g. leading "=" or "+") get a leading apostrophe, presumably to
# neutralize formula injection. Non-strings and benign strings pass through as-is.
class TestSanitizeForCSV:
195
def test_unsafe_value(self):
196
assert sanitize_value_for_csv("=OPEN()") == "'=OPEN()"
197
assert sanitize_value_for_csv("=1+2") == "'=1+2"
198
assert sanitize_value_for_csv('=1+2";=1+2') == "'=1+2\";=1+2"
200
# Numbers are returned unchanged, and "=" / "+" in a non-leading position is harmless.
def test_safe_value(self):
201
assert sanitize_value_for_csv(4) == 4
202
assert sanitize_value_for_csv(-44.44) == -44.44
203
assert sanitize_value_for_csv("1+1=2") == "1+1=2"
204
assert sanitize_value_for_csv("1aaa2") == "1aaa2"
207
# NOTE(review): the `def` line of this third test method (original 206) is missing
# from this view; the asserts below are its body. They show that lists are
# sanitized element-wise, including nested lists. "gh,+ij" is also prefixed
# (presumably because of the ",+" sequence) — confirm against the implementation.
assert sanitize_list_for_csv([4, "def=", "=gh+ij"]) == [4, "def=", "'=gh+ij"]
208
assert sanitize_list_for_csv(
209
[["=abc", "def", "gh,+ij"], ["abc", "=def", "+ghij"]]
210
) == [["'=abc", "def", "'gh,+ij"], ["abc", "'=def", "'+ghij"]]
211
assert sanitize_list_for_csv([1, ["ab", "=de"]]) == [1, ["ab", "'=de"]]
214
# Tests for validate_url(): HTTP(S) URLs are accepted; Windows and POSIX
# filesystem paths are rejected.
# NOTE(review): original line 215 (between the class header and the first method,
# possibly a decorator) and lines 219, 221-222, 224-225 (the `assert validate_url(`
# / `)` wrappers around the two long URLs) are missing from this view.
class TestValidateURL:
216
def test_valid_urls(self):
217
assert validate_url("https://www.gradio.app")
218
assert validate_url("http://gradio.dev")
220
# Percent-encoded characters in the path must still validate.
"https://upload.wikimedia.org/wikipedia/commons/b/b0/Bengal_tiger_%28Panthera_tigris_tigris%29_female_3_crop.jpg"
223
"https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/bread_small.png"
226
def test_invalid_urls(self):
227
assert not (validate_url("C:/Users/"))
228
assert not (validate_url("C:\\Users\\"))
229
assert not (validate_url("/home/user"))
232
# Tests for append_unique_suffix(name, list_of_names). The name is returned
# unchanged when it is not taken; otherwise "_<n>" is appended, and the cases
# here are consistent with n being the lowest positive integer not already used.
# NOTE(review): the `name = ...` assignments (original lines 234, 239, 244) are
# missing from this view; the expected values imply name == "test" in at least
# the last two tests — confirm.
class TestAppendUniqueSuffix:
233
def test_no_suffix(self):
235
list_of_names = ["test_1", "test_2"]
236
assert append_unique_suffix(name, list_of_names) == name
238
# "test_-1" does not occupy suffix 1, so the first suffix is still "_1".
def test_first_suffix(self):
240
list_of_names = ["test", "test_-1"]
241
assert append_unique_suffix(name, list_of_names) == "test_1"
243
def test_later_suffix(self):
245
list_of_names = ["test", "test_1", "test_2", "test_3"]
246
assert append_unique_suffix(name, list_of_names) == "test_4"
250
def test_abspath_no_symlink(self):
    """abspath() leaves no '..' segments in a path that contains no symlinks."""
    resolved = abspath("../gradio/gradio/test_data/lion.jpg")
    assert ".." not in str(resolved)
255
# NOTE(review): fragment of the decorator for test_abspath_symlink_path. Its
# opening line (original 254, presumably `@pytest.mark.skipif(`) and closing line
# (257) are missing from this view. Per the reason string, it skips the test on Windows.
sys.platform.startswith("win"),
256
reason="Windows doesn't allow creation of sym links without administrative privileges",
258
def test_abspath_symlink_path(self):
    """abspath() keeps a symlinked directory component of a relative path.

    A temporary directory symlink ``gradio/test_link`` (target
    ``gradio/test_data``) is created for the check. It is removed in a
    ``finally`` clause so an exception from abspath() cannot leave the link
    behind. A leftover link would make the next ``os.symlink`` call for the
    same path fail with FileExistsError.
    """
    os.symlink("gradio/test_data", "gradio/test_link", target_is_directory=True)
    try:
        resolved_path = str(abspath("../gradio/gradio/test_link/lion.jpg"))
    finally:
        os.unlink("gradio/test_link")
    assert "test_link" in resolved_path
265
# NOTE(review): fragment of the decorator for test_abspath_symlink_dir. Its
# opening line (original 264, presumably `@pytest.mark.skipif(`) and closing line
# (267) are missing from this view. Per the reason string, it skips the test on Windows.
sys.platform.startswith("win"),
266
reason="Windows doesn't allow creation of sym links without administrative privileges",
268
def test_abspath_symlink_dir(self):
    """abspath() returns an absolute path through a symlinked directory unchanged.

    A temporary directory symlink ``gradio/test_link`` (target
    ``gradio/test_data``) is created for the check. It is removed in a
    ``finally`` clause so an exception from abspath() cannot leave the link
    behind. A leftover link would make the next ``os.symlink`` call for the
    same path fail with FileExistsError.
    """
    os.symlink("gradio/test_data", "gradio/test_link", target_is_directory=True)
    try:
        full_path = os.path.join(os.getcwd(), "gradio/test_link/lion.jpg")
        resolved_path = str(abspath(full_path))
    finally:
        os.unlink("gradio/test_link")
    assert "test_link" in resolved_path
    assert full_path == resolved_path
277
# Tests for get_type_hints() (presumably gradio.utils' version, which also accepts
# callable objects and bound methods) and for is_special_typed_parameter().
# NOTE(review): original lines 279, 281-292 (the helper classes F/C, function f and
# GenericObject), 294-295 (presumably `for x in test_objs:`), 304 and 314 (the nested
# `func` bodies) are missing from this view.
class TestGetTypeHints:
278
def test_get_type_hints(self):
280
def __call__(self, s: str):
293
# A callable instance, a bound method and a plain function: each should expose
# exactly one hint (`s: str`), with `self` excluded.
test_objs = [F(), C().f, f]
296
hints = get_type_hints(x)
297
assert len(hints) == 1
298
assert hints["s"] == str
300
# An object without annotations yields no hints.
assert len(get_type_hints(GenericObject())) == 0
302
# Only the Request-annotated parameter is "special"; generic, Literal and
# unannotated parameters are not.
def test_is_special_typed_parameter(self):
303
def func(a: list[str], b: Literal["a", "b"], c, d: Request):
306
hints = get_type_hints(func)
307
assert not is_special_typed_parameter("a", hints)
308
assert not is_special_typed_parameter("b", hints)
309
assert not is_special_typed_parameter("c", hints)
310
assert is_special_typed_parameter("d", hints)
312
# Same check with a PEP 604 union (`str | int`) among the annotations.
def test_is_special_typed_parameter_with_pipe(self):
313
def func(a: Request, b: str | int, c: list[str]):
316
hints = get_type_hints(func)
317
assert is_special_typed_parameter("a", hints)
318
assert not is_special_typed_parameter("b", hints)
319
assert not is_special_typed_parameter("c", hints)
322
# check_function_inputs_match(fn, inputs, False) must not warn when a callable takes
# one regular argument plus an EventData-typed argument and is given one input
# (presumably because the EventData parameter is not counted as an input).
# NOTE(review): original lines 324, 326-328, 330-331, 333-334, 336 (helper class
# headers/bodies) and 339-340 (presumably `for x in test_objs:`) are missing here.
class TestCheckFunctionInputsMatch:
323
def test_check_function_inputs_match(self):
325
def __call__(self, s: str, evt: EventData):
329
def f(self, s: str, evt: EventData):
332
def f(s: str, evt: EventData):
335
# Callable instance, bound method and plain function variants.
test_objs = [F(), C().f, f]
337
with warnings.catch_warnings():
338
warnings.simplefilter("error") # Turn warnings into errors so any warning fails the test.
341
check_function_inputs_match(x, [None], False)
344
# Tests for get_continuous_fn(fn=..., every=...), which wraps fn into a factory
# whose result is an async generator that keeps producing values.
# NOTE(review): the helper function bodies (original 348-349, 351-354, 356-359,
# 378-379, 388-391, 393-396) and lines 345/375/385 (possibly async-test
# decorators) are missing from this view.
class TestGetContinuousFn:
346
async def test_get_continuous_fn(self):
347
def int_return(x): # for origin condition
350
def int_yield(x): # new condition
355
def list_yield(x): # new condition
360
agen_int_return = get_continuous_fn(fn=int_return, every=0.01)
361
agen_int_yield = get_continuous_fn(fn=int_yield, every=0.01)
362
agen_list_yield = get_continuous_fn(fn=list_yield, every=0.01)
363
agener_int_return = agen_int_return(1)
364
agener_int = agen_int_yield(1) # Primitive
365
agener_list = agen_list_yield([1]) # Reference
366
# A plain function is re-called on every step, producing the same value each
# time (2 for input 1, presumably x + 1).
assert await agener_int_return.__anext__() == 2
367
assert await agener_int_return.__anext__() == 2
368
# For a generator function the sequence 1, 2, 1 suggests the generator is
# restarted with the original argument once it is exhausted.
assert await agener_int.__anext__() == 1
369
assert await agener_int.__anext__() == 2
370
assert await agener_int.__anext__() == 1
371
# A mutable argument keeps growing across steps ([1], [1, 1], [1, 1, 1]),
# presumably because the same list object is passed to every call.
assert [1] == await agener_list.__anext__()
372
assert [1, 1] == await agener_list.__anext__()
373
assert [1, 1, 1] == await agener_list.__anext__()
376
# Same contract as above, with a coroutine function as fn.
async def test_get_continuous_fn_with_async_function(self):
377
async def async_int_return(x): # for origin condition
380
agen_int_return = get_continuous_fn(fn=async_int_return, every=0.01)
381
agener_int_return = agen_int_return(1)
382
assert await agener_int_return.__anext__() == 2
383
assert await agener_int_return.__anext__() == 2
386
# Same contract as above, with async generator functions as fn.
async def test_get_continuous_fn_with_async_generator(self):
387
async def async_int_yield(x): # new condition
392
async def async_list_yield(x): # new condition
397
agen_int_yield = get_continuous_fn(fn=async_int_yield, every=0.01)
398
agen_list_yield = get_continuous_fn(fn=async_list_yield, every=0.01)
399
agener_int = agen_int_yield(1) # Primitive
400
agener_list = agen_list_yield([1]) # Reference
401
assert await agener_int.__anext__() == 1
402
assert await agener_int.__anext__() == 2
403
assert await agener_int.__anext__() == 1
404
assert [1] == await agener_list.__anext__()
405
assert [1, 1] == await agener_list.__anext__()
406
assert [1, 1, 1] == await agener_list.__anext__()
409
# tex2svg() must leave the active matplotlib backend ("svg" here) unchanged, even
# when it raises on malformed TeX input.
# NOTE(review): original lines 410-411 (presumably `import matplotlib`), 413
# (presumably a successful tex2svg call), and 415/417 (presumably the
# `with pytest.raises(` / `):` around `Exception`) are missing from this view.
def test_tex2svg_preserves_matplotlib_backend():
412
matplotlib.use("svg")
414
assert matplotlib.get_backend() == "svg"
416
Exception # specifically a pyparsing.ParseException but not important here
418
tex2svg("$$$1+1=2$$$")
419
assert matplotlib.get_backend() == "svg"
422
def test_is_in_or_equal():
    """is_in_or_equal(path, directory): equality, containment, and '..' escapes."""
    cases = [
        ("files/lion.jpg", "files/lion.jpg", True),
        ("files/lion.jpg", "files", True),
        ("files", "files/lion.jpg", False),
        ("/home/usr/notes.txt", "/home/usr/", True),
        ("/home/usr/subdirectory", "/home/usr/notes.txt", False),
        # ".." segments that climb out of the directory must not count as inside it
        ("/home/usr/../../etc/notes.txt", "/home/usr/", False),
        ("/safe_dir/subdir/../../unsafe_file.txt", "/safe_dir/", False),
    ]
    for path, directory, expected in cases:
        assert bool(is_in_or_equal(path, directory)) is expected
432
# get_extension_from_file_path_or_url() returns the extension without the leading
# dot. It ignores a URL's query string, handles Windows backslash paths, and
# returns "" when there is no extension.
# NOTE(review): original lines 434 (`[`), 439 (`]`) and 440 (`)`) of the parametrize
# call are missing from this view.
@pytest.mark.parametrize(
433
"path_or_url, extension",
435
("https://example.com/avatar/xxxx.mp4?se=2023-11-16T06:51:23Z&sp=r", "mp4"),
436
("/home/user/documents/example.pdf", "pdf"),
437
("C:\\Users\\user\\documents\\example.png", "png"),
438
("C:/Users/user/documents/example", ""),
441
def test_get_extension_from_file_path_or_url(path_or_url, extension):
442
assert get_extension_from_file_path_or_url(path_or_url) == extension
445
# diff(old, new) returns a list of (op, path, value) edits:
# - equal inputs give [];
# - new dict keys give one "add" per key at path [key];
# - a type change gives a single "replace" at the root path [];
# - a string that extends the old one gives an "append" of the suffix.
# NOTE(review): original lines 447 (`[`), 452 (`]`) and 453 (`)`) of the parametrize
# call are missing from this view.
@pytest.mark.parametrize(
446
"old, new, expected_diff",
448
({"a": 1, "b": 2}, {"a": 1, "b": 2}, []),
449
({}, {"a": 1, "b": 2}, [("add", ["a"], 1), ("add", ["b"], 2)]),
450
(["a", "b"], {"a": 1, "b": 2}, [("replace", [], {"a": 1, "b": 2})]),
451
("abc", "abcdef", [("append", [], "def")]),
454
def test_diff(old, new, expected_diff):
455
assert diff(old, new) == expected_diff