test_helpers.py · 882 lines · 28.4 KB
import asyncio
import json
import os
import shutil
import subprocess
import tempfile
import time
from pathlib import Path
from unittest.mock import patch

import gradio_client as grc
import pytest
from gradio_client import media_data
from gradio_client import utils as client_utils
from pydub import AudioSegment
from starlette.testclient import TestClient
from tqdm import tqdm

import gradio as gr
from gradio import utils

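# NOTE: each test class below patches `gradio.utils.get_cache_folder` to a
# fresh temp directory so cached example predictions never leak between runs.
# A minimal sketch of the same pattern for a standalone test (the test name
# here is illustrative, not part of this file):
#
#     @patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
#     def test_my_cached_examples(patched_cache_folder):
#         io = gr.Interface(lambda x: x, "text", "text",
#                           examples=[["hi"]], cache_examples=True)
#         assert io.examples_handler.load_from_cache(0)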

@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamples:
    def test_handle_single_input(self, patched_cache_folder):
        examples = gr.Examples(["hello", "hi"], gr.Textbox())
        assert examples.processed_examples == [["hello"], ["hi"]]

        examples = gr.Examples([["hello"]], gr.Textbox())
        assert examples.processed_examples == [["hello"]]

        examples = gr.Examples(["test/test_files/bus.png"], gr.Image())
        assert (
            client_utils.encode_file_to_base64(
                examples.processed_examples[0][0]["path"]
            )
            == media_data.BASE64_IMAGE
        )

    def test_handle_multiple_inputs(self, patched_cache_folder):
        examples = gr.Examples(
            [["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
        )
        assert examples.processed_examples[0][0] == "hello"
        assert (
            client_utils.encode_file_to_base64(
                examples.processed_examples[0][1]["path"]
            )
            == media_data.BASE64_IMAGE
        )

    def test_handle_directory(self, patched_cache_folder):
        examples = gr.Examples("test/test_files/images", gr.Image())
        assert len(examples.processed_examples) == 2
        for row in examples.processed_examples:
            for output in row:
                assert (
                    client_utils.encode_file_to_base64(output["path"])
                    == media_data.BASE64_IMAGE
                )

    def test_handle_directory_with_log_file(self, patched_cache_folder):
        examples = gr.Examples(
            "test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
        )
        ex = client_utils.traverse(
            examples.processed_examples,
            lambda s: client_utils.encode_file_to_base64(s["path"]),
            lambda x: isinstance(x, dict) and Path(x["path"]).exists(),
        )
        assert ex == [
            [media_data.BASE64_IMAGE, "hello"],
            [media_data.BASE64_IMAGE, "hi"],
        ]
        for sample in examples.dataset.samples:
            assert os.path.isabs(sample[0]["path"])

    def test_examples_per_page(self, patched_cache_folder):
        examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
        assert examples.dataset.get_config()["samples_per_page"] == 2

    def test_no_preprocessing(self, patched_cache_folder):
        with gr.Blocks():
            image = gr.Image()
            textbox = gr.Textbox()

            examples = gr.Examples(
                examples=["test/test_files/bus.png"],
                inputs=image,
                outputs=textbox,
                fn=lambda x: x["path"],
                cache_examples=True,
                preprocess=False,
            )

        prediction = examples.load_from_cache(0)
        assert (
            client_utils.encode_file_to_base64(prediction[0]) == media_data.BASE64_IMAGE
        )

    def test_no_postprocessing(self, patched_cache_folder):
        def im(x):
            return [
                {
                    "image": {
                        "path": "test/test_files/bus.png",
                    },
                    "caption": "hi",
                }
            ]

        with gr.Blocks():
            text = gr.Textbox()
            gall = gr.Gallery()

            examples = gr.Examples(
                examples=["hi"],
                inputs=text,
                outputs=gall,
                fn=im,
                cache_examples=True,
                postprocess=False,
            )

        prediction = examples.load_from_cache(0)
        file = prediction[0].root[0].image.path
        assert client_utils.encode_url_or_file_to_base64(
            file
        ) == client_utils.encode_url_or_file_to_base64("test/test_files/bus.png")

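# The GRADIO_EXAMPLES_CACHE environment variable (tested below) overrides
# where cached example predictions are written; by default they go under the
# folder returned by gradio.utils.get_cache_folder().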

def test_setting_cache_dir_env_variable(monkeypatch):
    temp_dir = tempfile.mkdtemp()
    monkeypatch.setenv("GRADIO_EXAMPLES_CACHE", temp_dir)
    with gr.Blocks():
        image = gr.Image()
        image2 = gr.Image()

        examples = gr.Examples(
            examples=["test/test_files/bus.png"],
            inputs=image,
            outputs=image2,
            fn=lambda x: x,
            cache_examples=True,
        )
    prediction = examples.load_from_cache(0)
    path_to_cached_file = Path(prediction[0].path)
    assert utils.is_in_or_equal(path_to_cached_file, temp_dir)
    monkeypatch.delenv("GRADIO_EXAMPLES_CACHE", raising=False)

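# Dataset headers come from the input components' `label` values: no labels
# yields [], and an unlabeled component in an otherwise labeled set
# contributes "".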

@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamplesDataset:
    def test_no_headers(self, patched_cache_folder):
        examples = gr.Examples("test/test_files/images_log", [gr.Image(), gr.Text()])
        assert examples.dataset.headers == []

    def test_all_headers(self, patched_cache_folder):
        examples = gr.Examples(
            "test/test_files/images_log",
            [gr.Image(label="im"), gr.Text(label="your text")],
        )
        assert examples.dataset.headers == ["im", "your text"]

    def test_some_headers(self, patched_cache_folder):
        examples = gr.Examples(
            "test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
        )
        assert examples.dataset.headers == ["im", ""]


def test_example_caching_relaunch(connect):
    def combine(a, b):
        return a + " " + b

    with gr.Blocks() as demo:
        txt = gr.Textbox(label="Input")
        txt_2 = gr.Textbox(label="Input 2")
        txt_3 = gr.Textbox(value="", label="Output")
        btn = gr.Button(value="Submit")
        btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
        gr.Examples(
            [["hi", "Adam"], ["hello", "Eve"]],
            [txt, txt_2],
            txt_3,
            combine,
            cache_examples=True,
            api_name="examples",
        )

    with connect(demo) as client:
        assert client.predict(1, api_name="/examples") == (
            "hello",
            "Eve",
            "hello Eve",
        )

    # Let the server shut down
    time.sleep(1)

    with connect(demo) as client:
        assert client.predict(1, api_name="/examples") == (
            "hello",
            "Eve",
            "hello Eve",
        )

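# Cached examples registered with an api_name (here "examples") are exposed as
# an API endpoint: predicting with an example index returns the example inputs
# plus the cached output, and the on-disk cache survives a relaunch.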

@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestProcessExamples:
    def test_caching(self, patched_cache_folder):
        io = gr.Interface(
            lambda x: f"Hello {x}",
            "text",
            "text",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == "Hello Dunya"

    def test_example_caching_relaunch(self, patched_cache_folder, connect):
        def combine(a, b):
            return a + " " + b

        with gr.Blocks() as demo:
            txt = gr.Textbox(label="Input")
            txt_2 = gr.Textbox(label="Input 2")
            txt_3 = gr.Textbox(value="", label="Output")
            btn = gr.Button(value="Submit")
            btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
            gr.Examples(
                [["hi", "Adam"], ["hello", "Eve"]],
                [txt, txt_2],
                txt_3,
                combine,
                cache_examples=True,
                api_name="examples",
            )

        with connect(demo) as client:
            assert client.predict(1, api_name="/examples") == (
                "hello",
                "Eve",
                "hello Eve",
            )

        with connect(demo) as client:
            assert client.predict(1, api_name="/examples") == (
                "hello",
                "Eve",
                "hello Eve",
            )

    def test_caching_image(self, patched_cache_folder):
        io = gr.Interface(
            lambda x: x,
            "image",
            "image",
            examples=[["test/test_files/bus.png"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert client_utils.encode_url_or_file_to_base64(prediction[0].path).startswith(
            "data:image/png;base64,iVBORw0KGgoAAA"
        )

    def test_caching_audio(self, patched_cache_folder):
        io = gr.Interface(
            lambda x: x,
            "audio",
            "audio",
            examples=[["test/test_files/audio_sample.wav"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        file = prediction[0].path
        assert client_utils.encode_url_or_file_to_base64(file).startswith(
            "data:audio/wav;base64,UklGRgA/"
        )
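
    # gr.update(...) values are cached as plain dicts tagged with
    # "__type__": "update"; e.g. gr.update(visible=False) round-trips through
    # the cache as {"visible": False, "__type__": "update"}, which the next
    # two tests assert.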

    def test_caching_with_update(self, patched_cache_folder):
        io = gr.Interface(
            lambda x: gr.update(visible=False),
            "text",
            "image",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "visible": False,
            "__type__": "update",
        }

    def test_caching_with_mix_update(self, patched_cache_folder):
        io = gr.Interface(
            lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
            "text",
            ["text", "image"],
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "lines": 4,
            "value": "hello",
            "__type__": "update",
        }

    def test_caching_with_dict(self, patched_cache_folder):
        text = gr.Textbox()
        out = gr.Label()

        io = gr.Interface(
            lambda _: {text: gr.update(lines=4, interactive=False), out: "lion"},
            "textbox",
            [text, out],
            examples=["abc"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == [
            {"lines": 4, "__type__": "update", "interactive": False},
            gr.Label.data_model(**{"label": "lion", "confidences": None}),
        ]

    def test_caching_with_generators(self, patched_cache_folder):
        def test_generator(x):
            for y in range(len(x)):
                yield "Your output: " + x[: y + 1]

        io = gr.Interface(
            test_generator,
            "textbox",
            "textbox",
            examples=["abcdef"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction[0] == "Your output: abcdef"
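
    # For generator functions only the final yielded value ends up in the
    # cache, as asserted above. Streaming output components are the exception:
    # their yielded chunks are concatenated, so the next test expects cached
    # audio roughly 3x the input length after three yields.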

    def test_caching_with_generators_and_streamed_output(self, patched_cache_folder):
        file_dir = Path(Path(__file__).parent, "test_files")
        audio = str(file_dir / "audio_sample.wav")

        def test_generator(x):
            for y in range(int(x)):
                yield audio, y * 5

        io = gr.Interface(
            test_generator,
            "number",
            [gr.Audio(streaming=True), "number"],
            examples=[3],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        len_input_audio = len(AudioSegment.from_wav(audio))
        len_output_audio = len(AudioSegment.from_wav(prediction[0].path))
        length_ratio = len_output_audio / len_input_audio
        assert round(length_ratio, 1) == 3.0  # might not be exactly 3x
        assert float(prediction[1]) == 10.0

    def test_caching_with_async_generators(self, patched_cache_folder):
        async def test_generator(x):
            for y in range(len(x)):
                yield "Your output: " + x[: y + 1]

        io = gr.Interface(
            test_generator,
            "textbox",
            "textbox",
            examples=["abcdef"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction[0] == "Your output: abcdef"

    def test_raise_helpful_error_message_if_providing_partial_examples(
        self, patched_cache_folder, tmp_path
    ):
        def foo(a, b):
            return a + b

        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo"], ["bar"]],
                    cache_examples=True,
                )

        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo", "bar"], ["bar", None]],
                    cache_examples=True,
                )

        def foo_no_exception(a, b=2):
            return a * b

        gr.Interface(
            foo_no_exception,
            inputs=["text", "number"],
            outputs=["text"],
            examples=[["foo"], ["bar"]],
            cache_examples=True,
        )

        def many_missing(a, b, c):
            return a * b

        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    many_missing,
                    inputs=["text", "number", "number"],
                    outputs=["text"],
                    examples=[["foo", None, None], ["bar", 2, 3]],
                    cache_examples=True,
                )
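
    # In batch mode the function receives each input as a list of samples and
    # returns lists of equal length; caching still evaluates one example row
    # at a time, so load_from_cache(0) returns the outputs for a single row.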

    def test_caching_with_batch(self, patched_cache_folder):
        def trim_words(words, lens):
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return [trimmed_words]

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox"],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == ["hel"]

    def test_caching_with_batch_multiple_outputs(self, patched_cache_folder):
        def trim_words(words, lens):
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return trimmed_words, lens

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox", gr.Number(precision=0)],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == ["hel", "3"]

    def test_caching_with_non_io_component(self, patched_cache_folder):
        def predict(name):
            return name, gr.update(visible=True)

        with gr.Blocks():
            t1 = gr.Textbox()
            with gr.Column(visible=False) as c:
                t2 = gr.Textbox()

            examples = gr.Examples(
                [["John"], ["Mary"]],
                fn=predict,
                inputs=[t1],
                outputs=[t2, c],
                cache_examples=True,
            )

        prediction = examples.load_from_cache(0)
        assert prediction == ["John", {"visible": True, "__type__": "update"}]
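
    # The end-to-end tests below hit the HTTP API directly with Starlette's
    # TestClient: posting an example's row index to /api/<api_name>/ returns a
    # component update dict when examples are uncached, and the flat cached
    # values when cache_examples=True.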

    def test_end_to_end(self, patched_cache_folder):
        def concatenate(str1, str2):
            return str1 + str2

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                [["Hello,", None], ["Michael", None]],
                inputs=[t1, t2],
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)

        response = client.post("/api/load_example/", json={"data": [0]})
        assert response.json()["data"] == [
            {
                "lines": 1,
                "max_lines": 20,
                "show_label": True,
                "container": True,
                "min_width": 160,
                "autofocus": False,
                "autoscroll": True,
                "elem_classes": [],
                "rtl": False,
                "show_copy_button": False,
                "__type__": "update",
                "visible": True,
                "value": "Hello,",
                "type": "text",
            }
        ]

        response = client.post("/api/load_example/", json={"data": [1]})
        assert response.json()["data"] == [
            {
                "lines": 1,
                "max_lines": 20,
                "show_label": True,
                "container": True,
                "min_width": 160,
                "autofocus": False,
                "autoscroll": True,
                "elem_classes": [],
                "rtl": False,
                "show_copy_button": False,
                "__type__": "update",
                "visible": True,
                "value": "Michael",
                "type": "text",
            }
        ]

    def test_end_to_end_cache_examples(self, patched_cache_folder):
        def concatenate(str1, str2):
            return f"{str1} {str2}"

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                examples=[["Hello,", "World"], ["Michael", "Jordan"]],
                inputs=[t1, t2],
                outputs=[t2],
                fn=concatenate,
                cache_examples=True,
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)

        response = client.post("/api/load_example/", json={"data": [0]})
        assert response.json()["data"] == ["Hello,", "World", "Hello, World"]

        response = client.post("/api/load_example/", json={"data": [1]})
        assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]


def test_multiple_file_flagging(tmp_path):
    with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
        io = gr.Interface(
            fn=lambda *x: list(x),
            inputs=[
                gr.Image(type="filepath", label="frame 1"),
                gr.Image(type="filepath", label="frame 2"),
            ],
            outputs=[gr.Files()],
            examples=[["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)

        assert len(prediction[0].root) == 2
        assert all(isinstance(d, gr.FileData) for d in prediction[0].root)

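# Cached example files keep their full multi-part suffix ("foo.bar.txt" stays
# "foo.bar.txt", not "foo.txt"), even for same-named files from different
# directories, which is what the next test checks.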

def test_examples_keep_all_suffixes(tmp_path):
    with patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())):
        file_1 = tmp_path / "foo.bar.txt"
        file_1.write_text("file 1")
        file_2 = tmp_path / "file_2"
        file_2.mkdir(parents=True)
        file_2 = file_2 / "foo.bar.txt"
        file_2.write_text("file 2")
        io = gr.Interface(
            fn=lambda x: x.name,
            inputs=gr.File(),
            outputs=[gr.File()],
            examples=[[str(file_1)], [str(file_2)]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert Path(prediction[0].path).read_text() == "file 1"
        assert prediction[0].orig_name == "foo.bar.txt"
        assert prediction[0].path.endswith("foo.bar.txt")
        prediction = io.examples_handler.load_from_cache(1)
        assert Path(prediction[0].path).read_text() == "file 2"
        assert prediction[0].orig_name == "foo.bar.txt"
        assert prediction[0].path.endswith("foo.bar.txt")


def test_make_waveform_with_spaces_in_filename():
    with tempfile.TemporaryDirectory() as tmpdirname:
        audio = os.path.join(tmpdirname, "test audio.wav")
        shutil.copy("test/test_files/audio_sample.wav", audio)
        waveform = gr.make_waveform(audio)
        assert waveform.endswith(".mp4")

        try:
            command = [
                "ffprobe",
                "-v",
                "error",
                "-select_streams",
                "v:0",
                "-show_entries",
                "stream=width,height",
                "-of",
                "json",
                waveform,
            ]

            result = subprocess.run(command, capture_output=True, text=True, check=True)
            output = result.stdout
            data = json.loads(output)

            width = data["streams"][0]["width"]
            height = data["streams"][0]["height"]
            assert width == 1000
            assert height == 400

        except subprocess.CalledProcessError as e:
            print("Error retrieving resolution of output waveform video:", e)

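# make_waveform shells out to ffmpeg/ffprobe, so the resolution check above is
# wrapped in a try/except and only prints on CalledProcessError rather than
# failing on machines without ffprobe.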

def test_make_waveform_raises_if_ffmpeg_fails(tmp_path, monkeypatch):
    """
    Test that make_waveform raises an exception if ffmpeg fails,
    instead of returning a path to a non-existent or empty file.
    """
    audio = tmp_path / "test audio.wav"
    shutil.copy("test/test_files/audio_sample.wav", audio)

    def _failing_ffmpeg(*args, **kwargs):
        raise subprocess.CalledProcessError(1, "ffmpeg")

    monkeypatch.setattr(subprocess, "call", _failing_ffmpeg)
    with pytest.raises(Exception):
        gr.make_waveform(str(audio))

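# The progress tests poll job.status() from a gradio_client Job and record
# each distinct (index, desc) pair from progress_data, skipping consecutive
# duplicates, then compare the collected sequence against the expected one.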

class TestProgressBar:
    @pytest.mark.asyncio
    async def test_progress_bar(self):
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress()):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        status_updates = []
        while not job.done():
            status = job.status()
            update = (
                status.progress_data[0].index if status.progress_data else None,
                status.progress_data[0].desc if status.progress_data else None,
            )
            if update != (None, None) and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates == [
            (None, "start"),
            (0, None),
            (1, None),
            (2, None),
            (3, None),
            (4, None),
        ]
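
    # Without track_tqdm, the plain tqdm "alphabet" loop above is invisible to
    # gr.Progress, so updates stop at (4, None). With track_tqdm=True (next
    # test) gr.Progress also intercepts third-party tqdm calls inside the
    # handler, and the "alphabet" iterations are reported as well.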

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm(self):
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress(track_tqdm=True)):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        status_updates = []
        while not job.done():
            status = job.status()
            update = (
                status.progress_data[0].index if status.progress_data else None,
                status.progress_data[0].desc if status.progress_data else None,
            )
            if update != (None, None) and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates == [
            (None, "start"),
            (0, None),
            (1, None),
            (2, None),
            (3, None),
            (4, None),
            (0, "alphabet"),
            (1, "alphabet"),
            (2, "alphabet"),
        ]
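
    # track_tqdm also covers tqdm used without an iterable, i.e.
    # tqdm(total=...) plus manual .update() calls; the reported unit then
    # defaults to "steps", as the next test expects.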

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm_without_iterable(self):
        def greet(s, _=gr.Progress(track_tqdm=True)):
            with tqdm(total=len(s)) as progress_bar:
                for _c in s:
                    progress_bar.update()
                    time.sleep(0.15)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        status_updates = []
        while not job.done():
            status = job.status()
            update = (
                status.progress_data[0].index if status.progress_data else None,
                status.progress_data[0].unit if status.progress_data else None,
            )
            if update != (None, None) and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates == [
            (1, "steps"),
            (2, "steps"),
            (3, "steps"),
            (4, "steps"),
            (5, "steps"),
            (6, "steps"),
        ]
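
    # gr.Info / gr.Warning messages raised inside a handler surface to the
    # client as (message, level) tuples on job.status().log, polled the same
    # way as progress updates.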

    @pytest.mark.asyncio
    async def test_info_and_warning_alerts(self):
        def greet(s):
            for _c in s:
                gr.Info(f"Letter {_c}")
                time.sleep(0.15)
            if len(s) < 5:
                gr.Warning("Too short!")
                time.sleep(0.15)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Jon")

        status_updates = []
        while not job.done():
            status = job.status()
            update = status.log
            if update is not None and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates == [
            ("Letter J", "info"),
            ("Letter o", "info"),
            ("Letter n", "info"),
            ("Too short!", "warning"),
        ]

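# With concurrency_limit=2 the two sessions below can run at the same time;
# each client must only ever observe its own gr.Info messages.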

@pytest.mark.asyncio
@pytest.mark.parametrize("async_handler", [True, False])
async def test_info_isolation(async_handler: bool):
    async def greet_async(name):
        await asyncio.sleep(2)
        gr.Info(f"Hello {name}")
        await asyncio.sleep(1)
        return name

    def greet_sync(name):
        time.sleep(2)
        gr.Info(f"Hello {name}")
        time.sleep(1)
        return name

    demo = gr.Interface(
        greet_async if async_handler else greet_sync,
        "text",
        "text",
        concurrency_limit=2,
    )
    demo.launch(prevent_thread_lock=True)

    async def session_interaction(name, delay=0):
        # Honor the requested stagger so concurrent sessions do not start at
        # exactly the same moment.
        await asyncio.sleep(delay)
        client = grc.Client(demo.local_url)
        job = client.submit(name)

        status_updates = []
        while not job.done():
            status = job.status()
            update = status.log
            if update is not None and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)
        return status_updates[-1][0] if status_updates else None

    alice_logs, bob_logs = await asyncio.gather(
        session_interaction("Alice"),
        session_interaction("Bob", delay=1),
    )

    assert alice_logs == "Hello Alice"
    assert bob_logs == "Hello Bob"