3
from contextlib import contextmanager
4
from functools import partial
5
from string import capwords
6
from unittest.mock import MagicMock, patch
12
from gradio.blocks import Blocks
13
from gradio.components import Image, Textbox
14
from gradio.interface import Interface, TabbedInterface, close_all, os
15
from gradio.layouts import TabItem, Tabs
16
from gradio.utils import assert_configs_are_equivalent_besides_ids
21
# NOTE(review): the enclosing header is not visible in this excerpt (original
# lines 17-20 are absent); this appears to be the body of the
# `captured_output()` context manager used by test_block_thread — confirm.
# Swap sys.stdout / sys.stderr for fresh in-memory StringIO buffers, yield the
# buffers to the caller, then put the original streams back.
new_out, new_err = io.StringIO(), io.StringIO()
22
# Remember the real streams so they can be restored afterwards.
old_out, old_err = sys.stdout, sys.stderr
24
sys.stdout, sys.stderr = new_out, new_err
25
yield sys.stdout, sys.stderr
27
# NOTE(review): original lines 23 and 26 are absent here, presumably `try:` and
# `finally:`. The restore below needs to run in a `finally` around the yield:
# test_block_thread leaves this context via KeyboardInterrupt, and without it the
# real streams would stay replaced. Confirm against the full file.
sys.stdout, sys.stderr = old_out, old_err
32
# NOTE(review): the enclosing `def` (and class) header is not visible in this
# excerpt. The body launches an Interface, checks that its local URL answers
# with HTTP 200, and then expects a later request to the same URL to raise.
# NOTE(review): the local `io` shadows the `io` module used by the stdout-capture
# helper in this file; a name such as `demo` would avoid confusion.
io = Interface(lambda input: None, "textbox", "label")
33
# With prevent_thread_lock=True, launch() returns a 3-tuple; only the local URL
# (the middle element) is used here.
_, local_url, _ = io.launch(prevent_thread_lock=True)
34
response = httpx.get(local_url)
35
assert response.status_code == 200
# NOTE(review): original line 36 is absent; presumably it shuts the server down
# (e.g. io.close()), which is what should make the request below fail. Confirm.
# pytest.raises(Exception) accepts any error type; naming the specific httpx
# exception would make the intent explicit.
37
with pytest.raises(Exception):
38
response = httpx.get(local_url)
40
def test_close_all(self):
# Replace this Interface's close() with a MagicMock and assert it ends up being
# called. The line that triggers the call is not visible (see note below).
41
interface = Interface(lambda input: None, "textbox", "label")
42
interface.close = MagicMock()
# NOTE(review): original line 43 is absent from this excerpt; presumably it calls
# close_all() (imported from gradio.interface at the top of this file), which
# should in turn call interface.close(). Confirm.
44
interface.close.assert_called()
46
def test_no_input_or_output(self):
    """Building an Interface from only ``fn`` and ``examples`` raises TypeError.

    No input or output components are supplied in this call.
    """
    with pytest.raises(TypeError):
        Interface(fn=lambda x: x, examples=1234)
50
def test_partial_functions(self):
    """A ``functools.partial`` object can be used directly as an Interface's ``fn``."""

    def greet(name, formatter):
        greeting = f"Hello {name}!"
        return formatter(greeting)

    demo = Interface(
        fn=partial(greet, formatter=capwords), inputs="text", outputs="text"
    )
    assert demo("abubakar") == "Hello Abubakar!"
58
def test_input_labels_extracted_from_method(self):
# Input component labels are taken from fn's parameter names. Three cases:
#   1. a bound method: `self` is not used as a label;
#   2. a plain two-parameter function: labels are assigned in order;
#   3. a function whose first parameter is annotated `gradio.Request`: that
#      parameter is not used as a label.
# NOTE(review): many original lines are absent from this excerpt (59, 61-63, 66,
# 69-71, 75, 77-79). The names `A`, `t` and `i` are used below but not defined
# in the visible lines. Confirm against the full file.
60
def test(self, parameter_name):
64
Interface(A().test, t, "text")
65
assert t.label == "parameter_name"
67
def test(parameter_name1, parameter_name2):
68
return parameter_name1
72
Interface(test, [t, i], "text")
73
assert t.label == "parameter_name1"
74
assert i.label == "parameter_name2"
76
def special_args_test(req: gradio.Request, parameter_name):
80
# `t` is reused; its label is reassigned from special_args_test's signature.
Interface(special_args_test, t, "text")
81
assert t.label == "parameter_name"
83
def test_examples_valid_path(self):
# Pass a path to "../gradio/test_data/flagged_with_log" (relative to this test
# file) as `examples`, then inspect the generated config for a component whose
# type is "dataset".
# NOTE(review): original lines 84, 86 and 88 are absent. Presumably they are
# `path = os.path.join(`, its closing `)`, and the call (e.g. `assert any(`)
# that consumes the generator on the last line. Confirm.
# NOTE(review): `os` comes from `from gradio.interface import ... os` at the top
# of this file, i.e. a re-export; importing the stdlib `os` directly would be
# less fragile.
85
os.path.dirname(__file__), "../gradio/test_data/flagged_with_log"
87
interface = Interface(lambda x: 3 * x, "number", "number", examples=path)
89
c["type"] == "dataset" for c in interface.get_config_file()["components"]
94
def test_block_thread(self, mock_sleep):
# NOTE(review): the decorator that supplies `mock_sleep` (original line 93) is
# not visible; presumably it patches the sleep call that launch() uses while
# blocking the main thread.
# launch(prevent_thread_lock=False) is presumably expected to block. The mocked
# sleep raises KeyboardInterrupt, and the test expects that interrupt to escape,
# with stdout captured so the shutdown message can be checked.
95
with pytest.raises(KeyboardInterrupt):
96
with captured_output() as (out, _):
97
mock_sleep.side_effect = KeyboardInterrupt()
98
interface = Interface(lambda x: x, "textbox", "label")
99
interface.launch(prevent_thread_lock=False)
100
# NOTE(review): indentation is lost in this excerpt. This read can only run if
# it sits after the `pytest.raises` block has exited; confirm its nesting.
output = out.getvalue().strip()
102
# NOTE(review): original line 101 is absent (presumably `assert (`). As shown,
# the membership test below is a bare expression and asserts nothing.
"Keyboard interruption in main thread... closing server." in output
105
@patch("gradio.utils.colab_check")
106
@patch("gradio.networking.setup_tunnel")
107
# Stacked @patch decorators inject mocks bottom-up: the setup_tunnel mock is the
# first mock argument and the colab_check mock is the second.
def test_launch_colab_share_error(self, mock_setup_tunnel, mock_colab_check):
108
# Make tunnel setup raise RuntimeError while colab_check reports True; launch()
# is then expected to return None as the share URL rather than raise.
mock_setup_tunnel.side_effect = RuntimeError()
109
mock_colab_check.return_value = True
110
interface = Interface(lambda x: x, "textbox", "label")
111
_, _, share_url = interface.launch(prevent_thread_lock=True)
112
assert share_url is None
# NOTE(review): the launched server is not closed in the visible lines (original
# lines 113-114 are absent); confirm it is shut down, e.g. interface.close(),
# so it does not outlive the test.
115
def test_interface_representation(self):
# str(Interface) should have a first line containing the function's __name__,
# and a second line of the same length as the first.
# NOTE(review): original lines 117-118 (the body of prediction_fn) are absent.
116
def prediction_fn(x):
119
prediction_fn.__name__ = "prediction_fn"
120
# NOTE(review): `repr` shadows the builtin repr() inside this test; a name such
# as `lines` would be clearer.
repr = str(Interface(prediction_fn, "textbox", "label")).split("\n")
121
assert prediction_fn.__name__ in repr[0]
122
assert len(repr[0]) == len(repr[1])
124
@patch("webbrowser.open")
125
def test_interface_browser(self, mock_browser):
# launch(inbrowser=True) should call webbrowser.open exactly once. The call is
# patched with a mock, so no real browser is opened.
126
interface = Interface(lambda x: x, "textbox", "label")
127
interface.launch(inbrowser=True, prevent_thread_lock=True)
128
mock_browser.assert_called_once()
# NOTE(review): the launched server is not closed in the visible lines (original
# lines 129-130 are absent); confirm it is shut down afterwards.
131
def test_examples_list(self):
# A flat list of two examples yields two example rows of one value each, and
# examples_per_page is reflected as the dataset's "samples_per_page".
132
examples = ["test1", "test2"]
133
interface = Interface(
134
lambda x: x, "textbox", "label", examples=examples, examples_per_page=2
# NOTE(review): original line 135 (presumably the closing `)` of the call above)
# is absent from this excerpt.
136
interface.launch(prevent_thread_lock=True)
137
assert len(interface.examples_handler.examples) == 2
138
assert len(interface.examples_handler.examples[0]) == 1
139
assert interface.examples_handler.dataset.get_config()["samples_per_page"] == 2
# NOTE(review): the launched server is not closed in the visible lines (original
# lines 140-141 are absent); confirm it is shut down afterwards.
142
@patch("IPython.display.display")
143
def test_inline_display(self, mock_display):
# Each launch(inline=True) should call IPython.display.display (patched with a
# mock) once, so launching twice gives a call count of two.
144
interface = Interface(lambda x: x, "textbox", "label")
145
interface.launch(inline=True, prevent_thread_lock=True)
146
mock_display.assert_called_once()
147
# Relaunching the same Interface inline displays it again.
interface.launch(inline=True, prevent_thread_lock=True)
148
assert mock_display.call_count == 2
# NOTE(review): the launched server is not closed in the visible lines (original
# lines 149-150 are absent); confirm it is shut down afterwards.
151
def test_setting_interactive_false(self):
    """An output Textbox reports ``interactive`` falsy in its config by default,
    but keeps an explicit ``interactive=True``."""
    implicit_output = Textbox()
    Interface(lambda x: x, "textbox", implicit_output)
    implicit_config = implicit_output.get_config()
    assert not implicit_config["interactive"]

    explicit_output = Textbox(interactive=True)
    Interface(lambda x: x, "textbox", explicit_output)
    explicit_config = explicit_output.get_config()
    assert explicit_config["interactive"]
159
def test_get_api_info(self):
    """A single-function Interface exposes exactly one named API endpoint
    and no unnamed ones."""
    # `demo` rather than `io`, so the stdlib module name is not shadowed.
    demo = Interface(lambda x: x, Image(type="filepath"), "textbox")
    endpoints = demo.get_api_info()
    assert len(endpoints["named_endpoints"]) == 1
    assert len(endpoints["unnamed_endpoints"]) == 0
165
def test_api_name(self):
# An Interface built with api_name="echo" should contain a dependency in its
# config whose "api_name" is "echo".
# NOTE(review): the local `io` shadows the `io` module name used elsewhere in
# this file; a name such as `demo` would be clearer.
166
io = Interface(lambda x: x, "textbox", "textbox", api_name="echo")
# NOTE(review): original lines 167 and 169 are absent. Presumably they wrap the
# generator below as `assert next(` ... `)`, so the None default makes the
# assertion fail when no matching dependency exists. Confirm.
168
(d for d in io.config["dependencies"] if d["api_name"] == "echo"), None
171
def test_interface_in_blocks_does_not_error(self):
# Constructing an Interface (Textbox input, Image output) should not raise.
# NOTE(review): original line 172 is absent. Going by the test name, it
# presumably opens a Blocks context that this Interface is created inside.
# Confirm.
173
Interface(fn=lambda x: x, inputs=Textbox(), outputs=Image())
175
def test_interface_with_built_ins(self):
# When fn is the builtin `str`, the input component's label is "input 0"
# rather than a parameter name.
# NOTE(review): original line 176, which presumably creates the component `t`,
# is absent from this excerpt.
177
Interface(fn=str, inputs=t, outputs=Textbox())
178
assert t.label == "input 0"
180
def test_interface_additional_components_are_included_as_inputs(self):
# A component passed via additional_inputs is appended after the regular inputs
# in the Interface's input_components.
# NOTE(review): original line 181, which presumably creates `t`, is absent.
182
s = gradio.Slider(0, 100)
183
io = Interface(fn=str, inputs=t, outputs=Textbox(), additional_inputs=s)
184
assert io.input_components == [t, s]
187
class TestTabbedInterface:
# Tests for TabbedInterface.
188
def test_tabbed_interface_config_matches_manual_tab(self):
# A TabbedInterface of two Interfaces should produce a config equivalent (apart
# from component ids) to a Blocks(mode="tabbed_interface") laid out by hand
# with one TabItem per Interface, labelled "tab1" and "tab2".
189
interface1 = Interface(lambda x: x, "textbox", "textbox")
190
interface2 = Interface(lambda x: x, "image", "image")
192
with Blocks(mode="tabbed_interface") as demo:
# NOTE(review): original lines 193, 195 and 197 are absent. interface1/interface2
# are not used in the visible lines, and Tabs is imported but unused here.
# Presumably those lines are a `with Tabs():` wrapper and the calls that render
# interface1/interface2 inside each TabItem. Confirm.
194
with TabItem(label="tab1"):
196
with TabItem(label="tab2"):
199
# Separate but identically built Interfaces for the TabbedInterface side.
interface3 = Interface(lambda x: x, "textbox", "textbox")
200
interface4 = Interface(lambda x: x, "image", "image")
201
tabbed_interface = TabbedInterface([interface3, interface4], ["tab1", "tab2"])
203
assert assert_configs_are_equivalent_besides_ids(
204
demo.get_config_file(), tabbed_interface.get_config_file()
# NOTE(review): original line 205 (presumably the closing `)`) is absent.
208
# Runs once per combination of interface_type (4 values), live (2 values) and
# use_generator (2 values): 16 cases.
@pytest.mark.parametrize(
209
"interface_type", ["standard", "input_only", "output_only", "unified"]
# NOTE(review): original line 210 (presumably the closing `)`) is absent.
211
@pytest.mark.parametrize("live", [True, False])
212
@pytest.mark.parametrize("use_generator", [True, False])
213
def test_interface_adds_stop_button(interface_type, live, use_generator):
# NOTE(review): original lines 214-219 are absent; they presumably define the
# `func` and `gen_func` used below (a plain function and a generator function).
220
# Build one Interface shape per interface_type, passing gen_func when
# use_generator is True and func otherwise.
if interface_type == "standard":
221
interface = gradio.Interface(
222
gen_func if use_generator else func, "number", "number", live=live
224
elif interface_type == "input_only":
225
interface = gradio.Interface(
226
gen_func if use_generator else func, "number", None, live=live
228
elif interface_type == "output_only":
229
interface = gradio.Interface(
230
gen_func if use_generator else func, None, "number", live=live
# NOTE(review): original lines 231-232 are absent. Presumably they are the
# closing `)` and an `else:` for "unified", the only interface_type not matched
# above. That case reuses one Number component as both input and output.
233
num = gradio.Number()
234
interface = gradio.Interface(
235
gen_func if use_generator else func, num, num, live=live
241
# Select the config components whose props have variant "stop". The enclosing
# expression (original lines 236-240) is not visible.
for c in interface.config["components"]
242
if c["props"].get("variant", "") == "stop"
247
# NOTE(review): this test continues past the end of the excerpt. The visible
# condition suggests a stop button is expected only for a generator fn with
# live=False. Confirm against the full file.
if use_generator and not live: