gradio

Форк
0
/
test_events.py 
185 строк · 5.7 Кб
1
import ast
2
import inspect
3
from pathlib import Path
4

5
import pytest
6
from fastapi.testclient import TestClient
7

8
import gradio as gr
9

10

11
class TestEvent:
    """Tests for registering event listeners and the dependency config they produce."""

    @staticmethod
    def _make_load_chain_demo():
        """Return a Blocks whose ``load`` event is chained to a ``.then`` event.

        Both events call the same counter function and write to one Textbox.
        """
        calls = 0

        def increment():
            nonlocal calls
            calls += 1
            return str(calls)

        with gr.Blocks() as demo:
            out = gr.Textbox(label="Call counter")
            demo.load(increment, inputs=None, outputs=out).then(
                increment, inputs=None, outputs=out
            )
        return demo

    def test_clear_event(self):
        """An Image ``clear`` listener is recorded with the ``"clear"`` trigger."""

        def fn_img_cleared():
            print("image cleared")

        with gr.Blocks() as demo:
            img = gr.Image(
                type="pil", label="Start by uploading an image", elem_id="input_image"
            )

            img.clear(fn_img_cleared, [], [])

        assert demo.config["dependencies"][0]["targets"][0][1] == "clear"

    def test_event_data(self):
        """``gr.SelectData`` is built from the request's ``event_data`` payload."""
        with gr.Blocks() as demo:
            text = gr.Textbox()
            gallery = gr.Gallery()

            def fn_img_index(evt: gr.SelectData):
                return evt.index

            gallery.select(fn_img_index, None, text)

        app, _, _ = demo.launch(prevent_thread_lock=True)
        try:
            client = TestClient(app)
            resp = client.post(
                f"{demo.local_url}run/predict",
                json={
                    "fn_index": 0,
                    "data": [],
                    "event_data": {"index": 1, "value": None},
                },
            )
            assert resp.status_code == 200
            # The select index 1 comes back through the Textbox output as "1".
            assert resp.json()["data"][0] == "1"
        finally:
            # prevent_thread_lock=True makes launch() return while the app keeps
            # running; close it so it cannot outlive this test.
            demo.close()

    def test_consecutive_events(self):
        """``.then``/``.success`` chains keep correct indices when a child Blocks is rendered."""

        def double(x):
            return x + x

        def reverse(x):
            return x[::-1]

        def clear():
            return ""

        with gr.Blocks() as child:
            txt1 = gr.Textbox()
            txt2 = gr.Textbox()
            txt3 = gr.Textbox()

            txt1.submit(double, txt1, txt2).then(reverse, txt2, txt3).success(
                clear, None, txt1
            )

        with gr.Blocks() as parent:
            txt0 = gr.Textbox()
            txt0.submit(lambda x: x, txt0, txt0)
            child.render()

        deps = parent.config["dependencies"]
        # The parent registers its own submit before rendering the child, so the
        # child's chain occupies indices 1..3 and its trigger_after values must
        # point at those shifted indices.
        assert deps[1]["trigger_after"] is None
        assert deps[2]["trigger_after"] == 1
        assert deps[3]["trigger_after"] == 2

        # `.then` is recorded without the success-only flag; `.success` with it.
        assert not deps[2]["trigger_only_on_success"]
        assert deps[3]["trigger_only_on_success"]

    def test_on_listener(self):
        """``gr.on`` records every given trigger; with no triggers it uses input ``change``."""
        with gr.Blocks() as demo:
            name = gr.Textbox(label="Name")
            output = gr.Textbox(label="Output Box")
            greet_btn = gr.Button("Greet")

            def greet(name):
                return "Hello " + name + "!"

            gr.on(
                triggers=[name.submit, greet_btn.click, demo.load],
                fn=greet,
                inputs=name,
                outputs=output,
            )

            with gr.Row():
                num1 = gr.Slider(1, 10)
                num2 = gr.Slider(1, 10)
                num3 = gr.Slider(1, 10)
            total = gr.Number(label="Sum")

            # Named add_three rather than `sum` to avoid shadowing the builtin.
            @gr.on(inputs=[num1, num2, num3], outputs=total)
            def add_three(a, b, c):
                return a + b + c

        assert demo.config["dependencies"][0]["targets"] == [
            (name._id, "submit"),
            (greet_btn._id, "click"),
            (demo._id, "load"),
        ]
        assert demo.config["dependencies"][1]["targets"] == [
            (num1._id, "change"),
            (num2._id, "change"),
            (num3._id, "change"),
        ]

    def test_load_chaining(self):
        """A ``.then`` after ``load`` is recorded as triggered by the load dependency."""
        demo = self._make_load_chain_demo()

        deps = demo.config["dependencies"]
        assert deps[0]["targets"][0][1] == "load"
        assert deps[0]["trigger_after"] is None
        assert deps[1]["targets"][0][1] == "then"
        assert deps[1]["trigger_after"] == 0

    def test_load_chaining_reuse(self):
        """The load/then chain keeps its indices when rendered inside another Blocks."""
        demo = self._make_load_chain_demo()

        with gr.Blocks() as demo2:
            demo.render()

        deps = demo2.config["dependencies"]
        assert deps[0]["targets"][0][1] == "load"
        assert deps[0]["trigger_after"] is None
        assert deps[1]["targets"][0][1] == "then"
        assert deps[1]["trigger_after"] == 0
153

154

155
class TestEventErrors:
    """Event listeners may only be attached while a ``gr.Blocks`` context is open."""

    def test_event_defined_invalid_scope(self):
        """Registering listeners after the ``with gr.Blocks()`` block exits raises AttributeError."""
        with gr.Blocks() as blocks:
            box = gr.Textbox()
            box.blur(lambda x: x + x, box, box)

        # Both registrations below happen outside the Blocks context, and each
        # one is expected to fail on its own.
        out_of_scope_calls = (
            lambda: blocks.load(lambda: "hello", None, box),
            lambda: box.change(lambda x: x + x, box, box),
        )
        for register in out_of_scope_calls:
            with pytest.raises(AttributeError):
                register()
166

167

168
def test_event_pyi_file_matches_source_code():
    """Test that the template used to create pyi files matches EventListener._setup.

    The ``click`` stub in ``button.pyi`` is generated from INTERFACE_TEMPLATE in
    ``component_meta``. Every parameter of the runtime ``gr.Button.click``
    (except the ``block`` argument) must appear as a parameter of that stub.
    Argument names are compared via the AST rather than by substring search,
    because the stub's docstring mentions each parameter and would otherwise
    make the check pass even when the signature is missing one.
    """
    pyi_path = Path(__file__).parent / ".." / "gradio" / "components" / "button.pyi"
    # Decode explicitly so the result does not depend on the platform locale.
    code = pyi_path.read_text(encoding="utf-8")

    click_node = next(
        (
            node
            for node in ast.walk(ast.parse(code))
            if isinstance(node, ast.FunctionDef) and node.name == "click"
        ),
        None,
    )
    # This would fail if Button no longer has a click method
    assert click_node is not None

    stub_args = click_node.args
    stub_params = {
        arg.arg
        for arg in (*stub_args.posonlyargs, *stub_args.args, *stub_args.kwonlyargs)
    }
    for star_arg in (stub_args.vararg, stub_args.kwarg):
        if star_arg is not None:
            stub_params.add(star_arg.arg)

    sig = inspect.signature(gr.Button.click)
    missing = [
        param_name
        for param_name in sig.parameters
        if param_name != "block" and param_name not in stub_params
    ]
    assert not missing, f"button.pyi click() is missing parameters: {missing}"
186

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.