# gradio (fork) / test_interfaces.py
# Forks: 0 · 250 lines · 8.7 KB
1
import io
2
import sys
3
from contextlib import contextmanager
4
from functools import partial
5
from string import capwords
6
from unittest.mock import MagicMock, patch
7

8
import httpx
9
import pytest
10

11
import gradio
12
from gradio.blocks import Blocks
13
from gradio.components import Image, Textbox
14
from gradio.interface import Interface, TabbedInterface, close_all, os
15
from gradio.layouts import TabItem, Tabs
16
from gradio.utils import assert_configs_are_equivalent_besides_ids
17

18

19
@contextmanager
def captured_output():
    """Temporarily redirect ``sys.stdout`` and ``sys.stderr`` to in-memory buffers.

    Yields:
        A ``(stdout, stderr)`` tuple of fresh ``io.StringIO`` objects that
        receive everything written to the standard streams while the context
        is active.

    The original streams are restored on exit, including when the body raises.
    """
    new_out, new_err = io.StringIO(), io.StringIO()
    old_out, old_err = sys.stdout, sys.stderr
    try:
        sys.stdout, sys.stderr = new_out, new_err
        yield sys.stdout, sys.stderr
    finally:
        # Always restore, so a failing test body cannot leave the process
        # writing into a throwaway buffer.
        sys.stdout, sys.stderr = old_out, old_err
28

29

30
class TestInterface:
    """Tests for ``Interface`` construction, launching, closing and config."""

    def test_close(self):
        """A launched Interface serves HTTP, and stops serving after close()."""
        demo = Interface(lambda input: None, "textbox", "label")
        _, local_url, _ = demo.launch(prevent_thread_lock=True)
        response = httpx.get(local_url)
        assert response.status_code == 200
        demo.close()
        with pytest.raises(Exception):
            httpx.get(local_url)

    def test_close_all(self):
        """close_all() calls close() on an existing Interface."""
        interface = Interface(lambda input: None, "textbox", "label")
        interface.close = MagicMock()
        close_all()
        interface.close.assert_called()

    def test_no_input_or_output(self):
        """Omitting inputs/outputs and passing a non-list ``examples`` is a TypeError."""
        with pytest.raises(TypeError):
            Interface(lambda x: x, examples=1234)

    def test_partial_functions(self):
        """A ``functools.partial`` object works as the Interface function."""

        def greet(name, formatter):
            return formatter(f"Hello {name}!")

        greet_upper_case = partial(greet, formatter=capwords)
        demo = Interface(fn=greet_upper_case, inputs="text", outputs="text")
        assert demo("abubakar") == "Hello Abubakar!"

    def test_input_labels_extracted_from_method(self):
        """Input component labels default to the callable's parameter names."""

        class A:
            def test(self, parameter_name):
                return parameter_name

        # Bound method: ``self`` is not counted as an input parameter.
        t = Textbox()
        Interface(A().test, t, "text")
        assert t.label == "parameter_name"

        def test(parameter_name1, parameter_name2):
            return parameter_name1

        # Labels are assigned positionally, one per parameter.
        t = Textbox()
        i = Image()
        Interface(test, [t, i], "text")
        assert t.label == "parameter_name1"
        assert i.label == "parameter_name2"

        def special_args_test(req: gradio.Request, parameter_name):
            return parameter_name

        # A ``gradio.Request``-annotated parameter is not mapped to a component,
        # so the first component takes the next parameter's name.
        t = Textbox()
        Interface(special_args_test, t, "text")
        assert t.label == "parameter_name"

    def test_examples_valid_path(self):
        """A directory path passed as ``examples`` produces a dataset component."""
        path = os.path.join(
            os.path.dirname(__file__), "../gradio/test_data/flagged_with_log"
        )
        interface = Interface(lambda x: 3 * x, "number", "number", examples=path)
        dataset_check = any(
            c["type"] == "dataset" for c in interface.get_config_file()["components"]
        )
        assert dataset_check

    @patch("time.sleep")
    def test_block_thread(self, mock_sleep):
        """A blocking launch() propagates the KeyboardInterrupt raised by the mocked sleep."""
        with pytest.raises(KeyboardInterrupt):
            with captured_output() as (out, _):
                mock_sleep.side_effect = KeyboardInterrupt()
                interface = Interface(lambda x: x, "textbox", "label")
                interface.launch(prevent_thread_lock=False)
                # NOTE(review): this check cannot affect a passing run. The test
                # only passes if launch() above raises KeyboardInterrupt, so the
                # lines below never execute. TODO: move the check after the
                # ``pytest.raises`` block once it is confirmed that the message
                # is actually printed before the interrupt escapes.
                output = out.getvalue().strip()
                assert (
                    "Keyboard interruption in main thread... closing server." in output
                )

    @patch("gradio.utils.colab_check")
    @patch("gradio.networking.setup_tunnel")
    def test_launch_colab_share_error(self, mock_setup_tunnel, mock_colab_check):
        """On Colab, a tunnel setup failure yields ``share_url is None``."""
        mock_setup_tunnel.side_effect = RuntimeError()
        mock_colab_check.return_value = True
        interface = Interface(lambda x: x, "textbox", "label")
        _, _, share_url = interface.launch(prevent_thread_lock=True)
        try:
            assert share_url is None
        finally:
            # Shut the server down even if the assertion fails.
            interface.close()

    def test_interface_representation(self):
        """str(Interface) names the function on its first line.

        The first line is followed by an equally long line.
        """

        def prediction_fn(x):
            return x

        prediction_fn.__name__ = "prediction_fn"
        # Named ``lines`` rather than ``repr`` to avoid shadowing the builtin.
        lines = str(Interface(prediction_fn, "textbox", "label")).split("\n")
        assert prediction_fn.__name__ in lines[0]
        assert len(lines[0]) == len(lines[1])

    @patch("webbrowser.open")
    def test_interface_browser(self, mock_browser):
        """launch(inbrowser=True) opens the browser exactly once."""
        interface = Interface(lambda x: x, "textbox", "label")
        interface.launch(inbrowser=True, prevent_thread_lock=True)
        try:
            mock_browser.assert_called_once()
        finally:
            interface.close()

    def test_examples_list(self):
        """A flat examples list becomes one single-field example per entry."""
        examples = ["test1", "test2"]
        interface = Interface(
            lambda x: x, "textbox", "label", examples=examples, examples_per_page=2
        )
        interface.launch(prevent_thread_lock=True)
        try:
            assert len(interface.examples_handler.examples) == 2
            assert len(interface.examples_handler.examples[0]) == 1
            assert (
                interface.examples_handler.dataset.get_config()["samples_per_page"]
                == 2
            )
        finally:
            interface.close()

    @patch("IPython.display.display")
    def test_inline_display(self, mock_display):
        """Each launch(inline=True) calls IPython's display once."""
        interface = Interface(lambda x: x, "textbox", "label")
        interface.launch(inline=True, prevent_thread_lock=True)
        try:
            mock_display.assert_called_once()
            interface.launch(inline=True, prevent_thread_lock=True)
            assert mock_display.call_count == 2
        finally:
            interface.close()

    def test_setting_interactive_false(self):
        """Output components are non-interactive unless ``interactive=True`` is set."""
        output_textbox = Textbox()
        Interface(lambda x: x, "textbox", output_textbox)
        assert not output_textbox.get_config()["interactive"]
        output_textbox = Textbox(interactive=True)
        Interface(lambda x: x, "textbox", output_textbox)
        assert output_textbox.get_config()["interactive"]

    def test_get_api_info(self):
        """A plain Interface exposes one named endpoint and no unnamed ones."""
        demo = Interface(lambda x: x, Image(type="filepath"), "textbox")
        api_info = demo.get_api_info()
        assert len(api_info["named_endpoints"]) == 1
        assert len(api_info["unnamed_endpoints"]) == 0

    def test_api_name(self):
        """A custom ``api_name`` appears in the dependency config."""
        demo = Interface(lambda x: x, "textbox", "textbox", api_name="echo")
        assert any(d["api_name"] == "echo" for d in demo.config["dependencies"])

    def test_interface_in_blocks_does_not_error(self):
        """Constructing an Interface inside a Blocks context does not raise."""
        with Blocks():
            Interface(fn=lambda x: x, inputs=Textbox(), outputs=Image())

    def test_interface_with_built_ins(self):
        """With a builtin ``fn``, the input label falls back to ``"input 0"``."""
        # Presumably because builtins like ``str`` expose no usable parameter
        # names to introspection; only the fallback label is asserted here.
        t = Textbox()
        Interface(fn=str, inputs=t, outputs=Textbox())
        assert t.label == "input 0"

    def test_interface_additional_components_are_included_as_inputs(self):
        """``additional_inputs`` are appended after the regular inputs."""
        t = Textbox()
        s = gradio.Slider(0, 100)
        demo = Interface(fn=str, inputs=t, outputs=Textbox(), additional_inputs=s)
        assert demo.input_components == [t, s]
185

186

187
class TestTabbedInterface:
    """Tests for ``TabbedInterface``."""

    def test_tabbed_interface_config_matches_manual_tab(self):
        """TabbedInterface produces the same config as hand-built tabs.

        The reference app renders two Interfaces inside ``Tabs``/``TabItem``
        within ``Blocks(mode="tabbed_interface")``. Its config is compared,
        ignoring component ids, against a TabbedInterface built from two
        equivalent Interfaces.
        """
        interface1 = Interface(lambda x: x, "textbox", "textbox")
        interface2 = Interface(lambda x: x, "image", "image")

        with Blocks(mode="tabbed_interface") as demo:
            with Tabs():
                with TabItem(label="tab1"):
                    interface1.render()
                with TabItem(label="tab2"):
                    interface2.render()

        # Fresh Interface instances, so the TabbedInterface does not reuse
        # components already rendered into ``demo`` above.
        interface3 = Interface(lambda x: x, "textbox", "textbox")
        interface4 = Interface(lambda x: x, "image", "image")
        tabbed_interface = TabbedInterface([interface3, interface4], ["tab1", "tab2"])

        assert assert_configs_are_equivalent_besides_ids(
            demo.get_config_file(), tabbed_interface.get_config_file()
        )
206

207

208
@pytest.mark.parametrize(
    "interface_type", ["standard", "input_only", "output_only", "unified"]
)
@pytest.mark.parametrize("live", [True, False])
@pytest.mark.parametrize("use_generator", [True, False])
def test_interface_adds_stop_button(interface_type, live, use_generator):
    """Check when an Interface gets a "stop"-variant component.

    A non-live Interface whose function is a generator must end up with
    exactly one such component. Every other combination must not. This is
    checked for each layout:

    - standard: inputs and outputs
    - input-only
    - output-only
    - unified: the same component instance serves as input and output
    """

    def gen_func(inp):
        yield inp

    def func(inp):
        return inp

    fn = gen_func if use_generator else func
    if interface_type == "standard":
        inputs, outputs = "number", "number"
    elif interface_type == "input_only":
        inputs, outputs = "number", None
    elif interface_type == "output_only":
        inputs, outputs = None, "number"
    else:
        # "unified": one component object is passed as both input and output.
        num = gradio.Number()
        inputs, outputs = num, num
    interface = gradio.Interface(fn, inputs, outputs, live=live)

    num_stop_components = sum(
        1
        for c in interface.config["components"]
        if c["props"].get("variant", "") == "stop"
    )
    has_stop = num_stop_components == 1
    if use_generator and not live:
        assert has_stop
    else:
        assert not has_stop
251

# Use of cookies

# We use cookies in accordance with the Privacy Policy and the Cookie Policy.

# By clicking "Accept", you consent to SberTech JSC processing your personal
# data in order to improve our website and the GitVerse service and to make
# them more convenient to use.

# You can disable cookies yourself in your browser settings.