8
from pathlib import Path
9
from unittest.mock import patch
11
import gradio_client as grc
13
from gradio_client import media_data
14
from gradio_client import utils as client_utils
15
from pydub import AudioSegment
16
from starlette.testclient import TestClient
20
from gradio import utils
23
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
25
def test_handle_single_input(self, patched_cache_folder):
26
examples = gr.Examples(["hello", "hi"], gr.Textbox())
27
assert examples.processed_examples == [["hello"], ["hi"]]
29
examples = gr.Examples([["hello"]], gr.Textbox())
30
assert examples.processed_examples == [["hello"]]
32
examples = gr.Examples(["test/test_files/bus.png"], gr.Image())
34
client_utils.encode_file_to_base64(
35
examples.processed_examples[0][0]["path"]
37
== media_data.BASE64_IMAGE
40
def test_handle_multiple_inputs(self, patched_cache_folder):
41
examples = gr.Examples(
42
[["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
44
assert examples.processed_examples[0][0] == "hello"
46
client_utils.encode_file_to_base64(
47
examples.processed_examples[0][1]["path"]
49
== media_data.BASE64_IMAGE
52
def test_handle_directory(self, patched_cache_folder):
53
examples = gr.Examples("test/test_files/images", gr.Image())
54
assert len(examples.processed_examples) == 2
55
for row in examples.processed_examples:
58
client_utils.encode_file_to_base64(output["path"])
59
== media_data.BASE64_IMAGE
62
def test_handle_directory_with_log_file(self, patched_cache_folder):
63
examples = gr.Examples(
64
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
66
ex = client_utils.traverse(
67
examples.processed_examples,
68
lambda s: client_utils.encode_file_to_base64(s["path"]),
69
lambda x: isinstance(x, dict) and Path(x["path"]).exists(),
72
[media_data.BASE64_IMAGE, "hello"],
73
[media_data.BASE64_IMAGE, "hi"],
75
for sample in examples.dataset.samples:
76
assert os.path.isabs(sample[0]["path"])
78
def test_examples_per_page(self, patched_cache_folder):
79
examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
80
assert examples.dataset.get_config()["samples_per_page"] == 2
82
def test_no_preprocessing(self, patched_cache_folder):
85
textbox = gr.Textbox()
87
examples = gr.Examples(
88
examples=["test/test_files/bus.png"],
91
fn=lambda x: x["path"],
96
prediction = examples.load_from_cache(0)
98
client_utils.encode_file_to_base64(prediction[0]) == media_data.BASE64_IMAGE
101
def test_no_postprocessing(self, patched_cache_folder):
106
"path": "test/test_files/bus.png",
116
examples = gr.Examples(
125
prediction = examples.load_from_cache(0)
126
file = prediction[0].root[0].image.path
127
assert client_utils.encode_url_or_file_to_base64(
129
) == client_utils.encode_url_or_file_to_base64("test/test_files/bus.png")
132
def test_setting_cache_dir_env_variable(monkeypatch):
133
temp_dir = tempfile.mkdtemp()
134
monkeypatch.setenv("GRADIO_EXAMPLES_CACHE", temp_dir)
139
examples = gr.Examples(
140
examples=["test/test_files/bus.png"],
146
prediction = examples.load_from_cache(0)
147
path_to_cached_file = Path(prediction[0].path)
148
assert utils.is_in_or_equal(path_to_cached_file, temp_dir)
149
monkeypatch.delenv("GRADIO_EXAMPLES_CACHE", raising=False)
152
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
153
class TestExamplesDataset:
154
def test_no_headers(self, patched_cache_folder):
155
examples = gr.Examples("test/test_files/images_log", [gr.Image(), gr.Text()])
156
assert examples.dataset.headers == []
158
def test_all_headers(self, patched_cache_folder):
159
examples = gr.Examples(
160
"test/test_files/images_log",
161
[gr.Image(label="im"), gr.Text(label="your text")],
163
assert examples.dataset.headers == ["im", "your text"]
165
def test_some_headers(self, patched_cache_folder):
166
examples = gr.Examples(
167
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
169
assert examples.dataset.headers == ["im", ""]
172
def test_example_caching_relaunch(connect):
176
with gr.Blocks() as demo:
177
txt = gr.Textbox(label="Input")
178
txt_2 = gr.Textbox(label="Input 2")
179
txt_3 = gr.Textbox(value="", label="Output")
180
btn = gr.Button(value="Submit")
181
btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
183
[["hi", "Adam"], ["hello", "Eve"]],
191
with connect(demo) as client:
192
assert client.predict(1, api_name="/examples") == (
198
# Let the server shut down
201
with connect(demo) as client:
202
assert client.predict(1, api_name="/examples") == (
209
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
210
class TestProcessExamples:
211
def test_caching(self, patched_cache_folder):
213
lambda x: f"Hello {x}",
216
examples=[["World"], ["Dunya"], ["Monde"]],
219
prediction = io.examples_handler.load_from_cache(1)
220
assert prediction[0] == "Hello Dunya"
222
def test_example_caching_relaunch(self, patched_cache_folder, connect):
226
with gr.Blocks() as demo:
227
txt = gr.Textbox(label="Input")
228
txt_2 = gr.Textbox(label="Input 2")
229
txt_3 = gr.Textbox(value="", label="Output")
230
btn = gr.Button(value="Submit")
231
btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
233
[["hi", "Adam"], ["hello", "Eve"]],
241
with connect(demo) as client:
242
assert client.predict(1, api_name="/examples") == (
248
with connect(demo) as client:
249
assert client.predict(1, api_name="/examples") == (
255
def test_caching_image(self, patched_cache_folder):
260
examples=[["test/test_files/bus.png"]],
263
prediction = io.examples_handler.load_from_cache(0)
264
assert client_utils.encode_url_or_file_to_base64(prediction[0].path).startswith(
265
"data:image/png;base64,iVBORw0KGgoAAA"
268
def test_caching_audio(self, patched_cache_folder):
273
examples=[["test/test_files/audio_sample.wav"]],
276
prediction = io.examples_handler.load_from_cache(0)
277
file = prediction[0].path
278
assert client_utils.encode_url_or_file_to_base64(file).startswith(
279
"data:audio/wav;base64,UklGRgA/"
282
def test_caching_with_update(self, patched_cache_folder):
284
lambda x: gr.update(visible=False),
287
examples=[["World"], ["Dunya"], ["Monde"]],
290
prediction = io.examples_handler.load_from_cache(1)
291
assert prediction[0] == {
293
"__type__": "update",
296
def test_caching_with_mix_update(self, patched_cache_folder):
298
lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
301
examples=[["World"], ["Dunya"], ["Monde"]],
304
prediction = io.examples_handler.load_from_cache(1)
305
assert prediction[0] == {
308
"__type__": "update",
311
def test_caching_with_dict(self, patched_cache_folder):
316
lambda _: {text: gr.update(lines=4, interactive=False), out: "lion"},
322
prediction = io.examples_handler.load_from_cache(0)
323
assert prediction == [
324
{"lines": 4, "__type__": "update", "interactive": False},
325
gr.Label.data_model(**{"label": "lion", "confidences": None}),
328
def test_caching_with_generators(self, patched_cache_folder):
329
def test_generator(x):
330
for y in range(len(x)):
331
yield "Your output: " + x[: y + 1]
340
prediction = io.examples_handler.load_from_cache(0)
341
assert prediction[0] == "Your output: abcdef"
343
def test_caching_with_generators_and_streamed_output(self, patched_cache_folder):
344
file_dir = Path(Path(__file__).parent, "test_files")
345
audio = str(file_dir / "audio_sample.wav")
347
def test_generator(x):
348
for y in range(int(x)):
354
[gr.Audio(streaming=True), "number"],
358
prediction = io.examples_handler.load_from_cache(0)
359
len_input_audio = len(AudioSegment.from_wav(audio))
360
len_output_audio = len(AudioSegment.from_wav(prediction[0].path))
361
length_ratio = len_output_audio / len_input_audio
362
assert round(length_ratio, 1) == 3.0 # might not be exactly 3x
363
assert float(prediction[1]) == 10.0
365
def test_caching_with_async_generators(self, patched_cache_folder):
366
async def test_generator(x):
367
for y in range(len(x)):
368
yield "Your output: " + x[: y + 1]
377
prediction = io.examples_handler.load_from_cache(0)
378
assert prediction[0] == "Your output: abcdef"
380
def test_raise_helpful_error_message_if_providing_partial_examples(
381
self, patched_cache_folder, tmp_path
388
match="^Examples are being cached but not all input components have",
390
with pytest.raises(Exception):
393
inputs=["text", "text"],
395
examples=[["foo"], ["bar"]],
401
match="^Examples are being cached but not all input components have",
403
with pytest.raises(Exception):
406
inputs=["text", "text"],
408
examples=[["foo", "bar"], ["bar", None]],
412
def foo_no_exception(a, b=2):
417
inputs=["text", "number"],
419
examples=[["foo"], ["bar"]],
423
def many_missing(a, b, c):
428
match="^Examples are being cached but not all input components have",
430
with pytest.raises(Exception):
433
inputs=["text", "number", "number"],
435
examples=[["foo", None, None], ["bar", 2, 3]],
439
def test_caching_with_batch(self, patched_cache_folder):
440
def trim_words(words, lens):
441
trimmed_words = [word[:length] for word, length in zip(words, lens)]
442
return [trimmed_words]
446
["textbox", gr.Number(precision=0)],
450
examples=[["hello", 3], ["hi", 4]],
453
prediction = io.examples_handler.load_from_cache(0)
454
assert prediction == ["hel"]
456
def test_caching_with_batch_multiple_outputs(self, patched_cache_folder):
457
def trim_words(words, lens):
458
trimmed_words = [word[:length] for word, length in zip(words, lens)]
459
return trimmed_words, lens
463
["textbox", gr.Number(precision=0)],
464
["textbox", gr.Number(precision=0)],
467
examples=[["hello", 3], ["hi", 4]],
470
prediction = io.examples_handler.load_from_cache(0)
471
assert prediction == ["hel", "3"]
473
def test_caching_with_non_io_component(self, patched_cache_folder):
475
return name, gr.update(visible=True)
479
with gr.Column(visible=False) as c:
482
examples = gr.Examples(
483
[["John"], ["Mary"]],
490
prediction = examples.load_from_cache(0)
491
assert prediction == ["John", {"visible": True, "__type__": "update"}]
493
def test_end_to_end(self, patched_cache_folder):
494
def concatenate(str1, str2):
497
with gr.Blocks() as demo:
500
t1.submit(concatenate, [t1, t2], t2)
503
[["Hello,", None], ["Michael", None]],
505
api_name="load_example",
508
app, _, _ = demo.launch(prevent_thread_lock=True)
509
client = TestClient(app)
511
response = client.post("/api/load_example/", json={"data": [0]})
512
assert response.json()["data"] == [
523
"show_copy_button": False,
524
"__type__": "update",
531
response = client.post("/api/load_example/", json={"data": [1]})
532
assert response.json()["data"] == [
543
"show_copy_button": False,
544
"__type__": "update",
551
def test_end_to_end_cache_examples(self, patched_cache_folder):
552
def concatenate(str1, str2):
553
return f"{str1} {str2}"
555
with gr.Blocks() as demo:
558
t1.submit(concatenate, [t1, t2], t2)
561
examples=[["Hello,", "World"], ["Michael", "Jordan"]],
566
api_name="load_example",
569
app, _, _ = demo.launch(prevent_thread_lock=True)
570
client = TestClient(app)
572
response = client.post("/api/load_example/", json={"data": [0]})
573
assert response.json()["data"] == ["Hello,", "World", "Hello, World"]
575
response = client.post("/api/load_example/", json={"data": [1]})
576
assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]
579
def test_multiple_file_flagging(tmp_path):
580
with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
582
fn=lambda *x: list(x),
584
gr.Image(type="filepath", label="frame 1"),
585
gr.Image(type="filepath", label="frame 2"),
587
outputs=[gr.Files()],
588
examples=[["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]],
591
prediction = io.examples_handler.load_from_cache(0)
593
assert len(prediction[0].root) == 2
594
assert all(isinstance(d, gr.FileData) for d in prediction[0].root)
597
def test_examples_keep_all_suffixes(tmp_path):
598
with patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())):
599
file_1 = tmp_path / "foo.bar.txt"
600
file_1.write_text("file 1")
601
file_2 = tmp_path / "file_2"
602
file_2.mkdir(parents=True)
603
file_2 = file_2 / "foo.bar.txt"
604
file_2.write_text("file 2")
609
examples=[[str(file_1)], [str(file_2)]],
612
prediction = io.examples_handler.load_from_cache(0)
613
assert Path(prediction[0].path).read_text() == "file 1"
614
assert prediction[0].orig_name == "foo.bar.txt"
615
assert prediction[0].path.endswith("foo.bar.txt")
616
prediction = io.examples_handler.load_from_cache(1)
617
assert Path(prediction[0].path).read_text() == "file 2"
618
assert prediction[0].orig_name == "foo.bar.txt"
619
assert prediction[0].path.endswith("foo.bar.txt")
622
def test_make_waveform_with_spaces_in_filename():
623
with tempfile.TemporaryDirectory() as tmpdirname:
624
audio = os.path.join(tmpdirname, "test audio.wav")
625
shutil.copy("test/test_files/audio_sample.wav", audio)
626
waveform = gr.make_waveform(audio)
627
assert waveform.endswith(".mp4")
637
"stream=width,height",
643
result = subprocess.run(command, capture_output=True, text=True, check=True)
644
output = result.stdout
645
data = json.loads(output)
647
width = data["streams"][0]["width"]
648
height = data["streams"][0]["height"]
652
except subprocess.CalledProcessError as e:
653
print("Error retrieving resolution of output waveform video:", e)
656
def test_make_waveform_raises_if_ffmpeg_fails(tmp_path, monkeypatch):
658
Test that make_waveform raises an exception if ffmpeg fails,
659
instead of returning a path to a non-existent or empty file.
661
audio = tmp_path / "test audio.wav"
662
shutil.copy("test/test_files/audio_sample.wav", audio)
664
def _failing_ffmpeg(*args, **kwargs):
665
raise subprocess.CalledProcessError(1, "ffmpeg")
667
monkeypatch.setattr(subprocess, "call", _failing_ffmpeg)
668
with pytest.raises(Exception):
669
gr.make_waveform(str(audio))
672
class TestProgressBar:
674
async def test_progress_bar(self):
675
with gr.Blocks() as demo:
677
greeting = gr.Textbox()
678
button = gr.Button(value="Greet")
680
def greet(s, prog=gr.Progress()):
681
prog(0, desc="start")
683
for _ in prog.tqdm(range(4), unit="iter"):
686
for _ in tqdm(["a", "b", "c"], desc="alphabet"):
688
return f"Hello, {s}!"
690
button.click(greet, name, greeting)
691
demo.queue(max_size=1).launch(prevent_thread_lock=True)
693
client = grc.Client(demo.local_url)
694
job = client.submit("Gradio")
697
while not job.done():
698
status = job.status()
700
status.progress_data[0].index if status.progress_data else None,
701
status.progress_data[0].desc if status.progress_data else None,
703
if update != (None, None) and (
704
len(status_updates) == 0 or status_updates[-1] != update
706
status_updates.append(update)
709
assert status_updates == [
719
async def test_progress_bar_track_tqdm(self):
720
with gr.Blocks() as demo:
722
greeting = gr.Textbox()
723
button = gr.Button(value="Greet")
725
def greet(s, prog=gr.Progress(track_tqdm=True)):
726
prog(0, desc="start")
728
for _ in prog.tqdm(range(4), unit="iter"):
731
for _ in tqdm(["a", "b", "c"], desc="alphabet"):
733
return f"Hello, {s}!"
735
button.click(greet, name, greeting)
736
demo.queue(max_size=1).launch(prevent_thread_lock=True)
738
client = grc.Client(demo.local_url)
739
job = client.submit("Gradio")
742
while not job.done():
743
status = job.status()
745
status.progress_data[0].index if status.progress_data else None,
746
status.progress_data[0].desc if status.progress_data else None,
748
if update != (None, None) and (
749
len(status_updates) == 0 or status_updates[-1] != update
751
status_updates.append(update)
754
assert status_updates == [
767
async def test_progress_bar_track_tqdm_without_iterable(self):
768
def greet(s, _=gr.Progress(track_tqdm=True)):
769
with tqdm(total=len(s)) as progress_bar:
771
progress_bar.update()
773
return f"Hello, {s}!"
775
demo = gr.Interface(greet, "text", "text")
776
demo.queue().launch(prevent_thread_lock=True)
778
client = grc.Client(demo.local_url)
779
job = client.submit("Gradio")
782
while not job.done():
783
status = job.status()
785
status.progress_data[0].index if status.progress_data else None,
786
status.progress_data[0].unit if status.progress_data else None,
788
if update != (None, None) and (
789
len(status_updates) == 0 or status_updates[-1] != update
791
status_updates.append(update)
794
assert status_updates == [
804
async def test_info_and_warning_alerts(self):
807
gr.Info(f"Letter {_c}")
810
gr.Warning("Too short!")
812
return f"Hello, {s}!"
814
demo = gr.Interface(greet, "text", "text")
815
demo.queue().launch(prevent_thread_lock=True)
817
client = grc.Client(demo.local_url)
818
job = client.submit("Jon")
821
while not job.done():
822
status = job.status()
824
if update is not None and (
825
len(status_updates) == 0 or status_updates[-1] != update
827
status_updates.append(update)
830
assert status_updates == [
831
("Letter J", "info"),
832
("Letter o", "info"),
833
("Letter n", "info"),
834
("Too short!", "warning"),
839
@pytest.mark.parametrize("async_handler", [True, False])
840
async def test_info_isolation(async_handler: bool):
841
async def greet_async(name):
842
await asyncio.sleep(2)
843
gr.Info(f"Hello {name}")
844
await asyncio.sleep(1)
847
def greet_sync(name):
849
gr.Info(f"Hello {name}")
854
greet_async if async_handler else greet_sync,
859
demo.launch(prevent_thread_lock=True)
861
async def session_interaction(name, delay=0):
862
client = grc.Client(demo.local_url)
863
job = client.submit(name)
866
while not job.done():
867
status = job.status()
869
if update is not None and (
870
len(status_updates) == 0 or status_updates[-1] != update
872
status_updates.append(update)
874
return status_updates[-1][0] if status_updates else None
876
alice_logs, bob_logs = await asyncio.gather(
877
session_interaction("Alice"),
878
session_interaction("Bob", delay=1),
881
assert alice_logs == "Hello Alice"
882
assert bob_logs == "Hello Bob"