gradio

Форк
0
/
test_routes.py 
1233 строки · 41.4 Кб
1
"""Contains tests for networking.py and app.py"""
2
import functools
3
import os
4
import tempfile
5
import time
6
from contextlib import asynccontextmanager, closing
7
from typing import Dict
8
from unittest.mock import patch
9

10
import gradio_client as grc
11
import numpy as np
12
import pandas as pd
13
import pytest
14
import requests
15
import starlette.routing
16
from fastapi import FastAPI, Request
17
from fastapi.testclient import TestClient
18
from gradio_client import media_data
19

20
import gradio as gr
21
from gradio import (
22
    Blocks,
23
    Button,
24
    Interface,
25
    Number,
26
    Textbox,
27
    close_all,
28
    routes,
29
    wasm_utils,
30
)
31
from gradio.route_utils import (
32
    FnIndexInferError,
33
    compare_passwords_securely,
34
    get_root_url,
35
    starts_with_protocol,
36
)
37

38

39
@pytest.fixture()
def test_client():
    """Yield a TestClient for a freshly launched Interface that doubles its text.

    Teardown closes that Interface and then every other open Blocks app.
    """
    demo = Interface(lambda x: x + x, "text", "text")
    fastapi_app, _local_url, _share_url = demo.launch(prevent_thread_lock=True)
    yield TestClient(fastapi_app)
    demo.close()
    close_all()
47

48

49
class TestRoutes:
    """HTTP-level tests for the routes exposed by a launched Gradio app."""

    def test_get_main_route(self, test_client):
        """The index page is served."""
        response = test_client.get("/")
        assert response.status_code == 200

    def test_static_files_served_safely(self, test_client):
        """URL-encoded `..` segments must not escape the static folder."""
        # Make sure things outside the static folder are not accessible
        response = test_client.get(r"/static/..%2findex.html")
        assert response.status_code == 403
        response = test_client.get(r"/static/..%2f..%2fapi_docs.html")
        assert response.status_code == 403

    def test_get_config_route(self, test_client):
        """The config endpoint is reachable."""
        response = test_client.get("/config/")
        assert response.status_code == 200

    def test_favicon_route(self, test_client):
        """The favicon is served."""
        response = test_client.get("/favicon.ico")
        assert response.status_code == 200

    def test_upload_path(self, test_client):
        """An uploaded file is saved with its name, extension and bytes intact."""
        with open("test/test_files/alphabet.txt", "rb") as f:
            response = test_client.post("/upload", files={"files": f})
        assert response.status_code == 200
        file = response.json()[0]
        assert "alphabet" in file
        assert file.endswith(".txt")
        with open(file, "rb") as saved_file:
            assert saved_file.read() == b"abcdefghijklmnopqrstuvwxyz"

    def test_custom_upload_path(self, gradio_temp_dir):
        """Uploaded files are stored under the `gradio_temp_dir` fixture path."""
        io = Interface(lambda x: x + x, "text", "text")
        app, _, _ = io.launch(prevent_thread_lock=True)
        test_client = TestClient(app)
        with open("test/test_files/alphabet.txt", "rb") as f:
            response = test_client.post("/upload", files={"files": f})
        assert response.status_code == 200
        file = response.json()[0]
        assert "alphabet" in file
        assert file.startswith(str(gradio_temp_dir))
        assert file.endswith(".txt")
        with open(file, "rb") as saved_file:
            assert saved_file.read() == b"abcdefghijklmnopqrstuvwxyz"

    def test_predict_route(self, test_client):
        """`/api/predict/` runs the function selected by `fn_index`."""
        response = test_client.post(
            "/api/predict/", json={"data": ["test"], "fn_index": 0}
        )
        assert response.status_code == 200
        output = dict(response.json())
        assert output["data"] == ["testtest"]

    def test_named_predict_route(self):
        """Each `api_name` gets its own `/api/<name>/` route."""
        with Blocks() as demo:
            i = Textbox()
            o = Textbox()
            i.change(lambda x: f"{x}1", i, o, api_name="p")
            i.change(lambda x: f"{x}2", i, o, api_name="q")

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)
        response = client.post("/api/p/", json={"data": ["test"]})
        assert response.status_code == 200
        output = dict(response.json())
        assert output["data"] == ["test1"]

        response = client.post("/api/q/", json={"data": ["test"]})
        assert response.status_code == 200
        output = dict(response.json())
        assert output["data"] == ["test2"]

    def test_same_named_predict_route(self):
        """A duplicate `api_name` is exposed with a `_1` suffix."""
        with Blocks() as demo:
            i = Textbox()
            o = Textbox()
            i.change(lambda x: f"{x}0", i, o, api_name="p")
            i.change(lambda x: f"{x}1", i, o, api_name="p")

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)
        response = client.post("/api/p/", json={"data": ["test"]})
        assert response.status_code == 200
        output = dict(response.json())
        assert output["data"] == ["test0"]

        response = client.post("/api/p_1/", json={"data": ["test"]})
        assert response.status_code == 200
        output = dict(response.json())
        assert output["data"] == ["test1"]

    def test_multiple_renamed(self):
        """Suffixing keeps going when a renamed route collides with an explicit name."""
        with Blocks() as demo:
            i = Textbox()
            o = Textbox()
            i.change(lambda x: f"{x}0", i, o, api_name="p")
            i.change(lambda x: f"{x}1", i, o, api_name="p")
            i.change(lambda x: f"{x}2", i, o, api_name="p_1")

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)
        response = client.post("/api/p/", json={"data": ["test"]})
        assert response.status_code == 200
        output = dict(response.json())
        assert output["data"] == ["test0"]

        response = client.post("/api/p_1/", json={"data": ["test"]})
        assert response.status_code == 200
        output = dict(response.json())
        assert output["data"] == ["test1"]

        response = client.post("/api/p_1_1/", json={"data": ["test"]})
        assert response.status_code == 200
        output = dict(response.json())
        assert output["data"] == ["test2"]

    def test_predict_route_without_fn_index(self, test_client):
        """`fn_index` may be omitted for the single-function Interface."""
        response = test_client.post("/api/predict/", json={"data": ["test"]})
        assert response.status_code == 200
        output = dict(response.json())
        assert output["data"] == ["testtest"]

    def test_predict_route_batching(self):
        """A batch=True function works for single and explicitly batched payloads."""

        def batch_fn(x):
            results = []
            for word in x:
                results.append(f"Hello {word}")
            return (results,)

        with gr.Blocks() as demo:
            text = gr.Textbox()
            btn = gr.Button()
            btn.click(batch_fn, inputs=text, outputs=text, batch=True, api_name="pred")

        demo.queue(api_open=True)
        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)
        response = client.post("/api/pred/", json={"data": ["test"]})
        output = dict(response.json())
        assert output["data"] == ["Hello test"]

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)
        response = client.post(
            "/api/pred/", json={"data": [["test", "test2"]], "batched": True}
        )
        output = dict(response.json())
        assert output["data"] == [["Hello test", "Hello test2"]]

    def test_state(self):
        """Requests sharing a `session_hash` share state; the state output comes back as None."""

        def predict(input, history):
            if history is None:
                history = ""
            history += input
            return history, history

        io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
        app, _, _ = io.launch(prevent_thread_lock=True)
        client = TestClient(app)
        response = client.post(
            "/api/predict/",
            json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
        )
        output = dict(response.json())
        assert output["data"] == ["test", None]
        response = client.post(
            "/api/predict/",
            json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
        )
        output = dict(response.json())
        assert output["data"] == ["testtest", None]

    def test_get_allowed_paths(self):
        """An arbitrary file is served only after its directory or path is allowed.

        The temporary file is closed before it is served and deleted at the
        end, so the test leaks neither a file handle nor a file on disk.
        """
        # Closing the handle (via `with`) flushes the contents; `delete=False`
        # keeps the file on disk until the `finally` below removes it.
        with tempfile.NamedTemporaryFile(mode="w", delete=False) as allowed_file:
            allowed_file.write(media_data.BASE64_IMAGE)

        try:
            io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
            app, _, _ = io.launch(prevent_thread_lock=True)
            client = TestClient(app)
            file_response = client.get(f"/file={allowed_file.name}")
            assert file_response.status_code == 403
            io.close()

            # Allowing the parent directory grants access.
            io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
            app, _, _ = io.launch(
                prevent_thread_lock=True,
                allowed_paths=[os.path.dirname(allowed_file.name)],
            )
            client = TestClient(app)
            file_response = client.get(f"/file={allowed_file.name}")
            assert file_response.status_code == 200
            assert len(file_response.text) == len(media_data.BASE64_IMAGE)
            io.close()

            # Allowing the exact file path grants access as well.
            io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
            app, _, _ = io.launch(
                prevent_thread_lock=True,
                allowed_paths=[os.path.abspath(allowed_file.name)],
            )
            client = TestClient(app)
            file_response = client.get(f"/file={allowed_file.name}")
            assert file_response.status_code == 200
            assert len(file_response.text) == len(media_data.BASE64_IMAGE)
            io.close()
        finally:
            os.remove(allowed_file.name)

    def test_allowed_and_blocked_paths(self):
        """A path in both `allowed_paths` and `blocked_paths` is refused."""
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
            io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
            app, _, _ = io.launch(
                prevent_thread_lock=True,
                allowed_paths=[os.path.dirname(tmp_file.name)],
            )
            client = TestClient(app)
            file_response = client.get(f"/file={tmp_file.name}")
            assert file_response.status_code == 200
        io.close()
        os.remove(tmp_file.name)

        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
            io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
            app, _, _ = io.launch(
                prevent_thread_lock=True,
                allowed_paths=[os.path.dirname(tmp_file.name)],
                blocked_paths=[os.path.dirname(tmp_file.name)],
            )
            client = TestClient(app)
            file_response = client.get(f"/file={tmp_file.name}")
            assert file_response.status_code == 403
        io.close()
        os.remove(tmp_file.name)

    def test_get_file_created_by_app(self, test_client):
        """A file returned by an event can be fetched, including with Range headers."""
        app, _, _ = gr.Interface(lambda s: s.name, gr.File(), gr.File()).launch(
            prevent_thread_lock=True
        )
        client = TestClient(app)
        with open("test/test_files/alphabet.txt", "rb") as f:
            file_response = test_client.post("/upload", files={"files": f})
        response = client.post(
            "/api/predict/",
            json={
                "data": [
                    {
                        "path": file_response.json()[0],
                        "size": os.path.getsize("test/test_files/alphabet.txt"),
                    }
                ],
                "fn_index": 0,
                "session_hash": "_",
            },
        ).json()
        created_file = response["data"][0]["path"]
        file_response = client.get(f"/file={created_file}")
        assert file_response.is_success

        backwards_compatible_file_response = client.get(f"/file/{created_file}")
        assert backwards_compatible_file_response.is_success

        file_response_with_full_range = client.get(
            f"/file={created_file}", headers={"Range": "bytes=0-"}
        )
        assert file_response_with_full_range.is_success
        assert file_response.text == file_response_with_full_range.text

        file_response_with_partial_range = client.get(
            f"/file={created_file}", headers={"Range": "bytes=0-10"}
        )
        assert file_response_with_partial_range.is_success
        assert len(file_response_with_partial_range.text) == 11

    def test_mount_gradio_app(self):
        """Two demos can be mounted on one FastAPI app under different paths."""
        app = FastAPI()

        demo = gr.Interface(
            lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
        ).queue()
        demo1 = gr.Interface(
            lambda s: f"Hello from py, {s}!", "textbox", "textbox"
        ).queue()

        app = gr.mount_gradio_app(app, demo, path="/ps")
        app = gr.mount_gradio_app(app, demo1, path="/py")

        # Use context manager to trigger start up events
        with TestClient(app) as client:
            assert client.get("/ps").is_success
            assert client.get("/py").is_success

    def test_mount_gradio_app_with_app_kwargs(self):
        """`app_kwargs` are applied to the mounted sub-app (custom docs URL)."""
        app = FastAPI()

        demo = gr.Interface(lambda s: f"You said {s}!", "textbox", "textbox").queue()

        app = gr.mount_gradio_app(
            app, demo, path="/echo", app_kwargs={"docs_url": "/docs-custom"}
        )

        # Use context manager to trigger start up events
        with TestClient(app) as client:
            assert client.get("/echo/docs-custom").is_success

    def test_mount_gradio_app_with_lifespan(self):
        """Mounting works on a host app that defines its own lifespan."""

        @asynccontextmanager
        async def empty_lifespan(app: FastAPI):
            yield

        app = FastAPI(lifespan=empty_lifespan)

        demo = gr.Interface(
            lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
        ).queue()
        demo1 = gr.Interface(
            lambda s: f"Hello from py, {s}!", "textbox", "textbox"
        ).queue()

        app = gr.mount_gradio_app(app, demo, path="/ps")
        app = gr.mount_gradio_app(app, demo1, path="/py")

        # Use context manager to trigger start up events
        with TestClient(app) as client:
            assert client.get("/ps").is_success
            assert client.get("/py").is_success

    def test_mount_gradio_app_with_startup(self):
        """Mounting works on a host app with an `on_event("startup")` handler."""
        app = FastAPI()

        @app.on_event("startup")
        async def empty_startup():
            return

        demo = gr.Interface(
            lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
        ).queue()
        demo1 = gr.Interface(
            lambda s: f"Hello from py, {s}!", "textbox", "textbox"
        ).queue()

        app = gr.mount_gradio_app(app, demo, path="/ps")
        app = gr.mount_gradio_app(app, demo1, path="/py")

        # Use context manager to trigger start up events
        with TestClient(app) as client:
            assert client.get("/ps").is_success
            assert client.get("/py").is_success

    def test_mount_gradio_app_with_auth_dependency(self):
        """With `auth_dependency`, requests without the `user` header are refused."""
        app = FastAPI()

        def get_user(request: Request):
            return request.headers.get("user")

        demo = gr.Interface(lambda s: f"Hello from ps, {s}!", "textbox", "textbox")

        app = gr.mount_gradio_app(app, demo, path="/demo", auth_dependency=get_user)

        with TestClient(app) as client:
            assert client.get("/demo", headers={"user": "abubakar"}).is_success
            assert not client.get("/demo").is_success

    def test_static_file_missing(self, test_client):
        """A missing static file yields 404."""
        response = test_client.get(r"/static/not-here.js")
        assert response.status_code == 404

    def test_asset_file_missing(self, test_client):
        """A missing asset yields 404."""
        response = test_client.get(r"/assets/not-here.js")
        assert response.status_code == 404

    def test_cannot_access_files_in_working_directory(self, test_client):
        """Relative paths into the working directory are refused."""
        response = test_client.get(r"/file=not-here.js")
        assert response.status_code == 403
        response = test_client.get(r"/file=subdir/.env")
        assert response.status_code == 403

    def test_cannot_access_directories_in_working_directory(self, test_client):
        """Directories in the working directory are refused."""
        response = test_client.get(r"/file=gradio")
        assert response.status_code == 403

    def test_block_protocols_that_expose_windows_credentials(self, test_client):
        """UNC-style `//host/share` paths are refused."""
        response = test_client.get(r"/file=//11.0.225.200/share")
        assert response.status_code == 403

    def test_do_not_expose_existence_of_files_outside_working_directory(
        self, test_client
    ):
        """A missing file outside the working directory gives 403, not 404."""
        response = test_client.get(r"/file=../fake-file-that-does-not-exist.js")
        assert response.status_code == 403  # not a 404

    def test_proxy_route_is_restricted_to_load_urls(self):
        """Only URLs registered in `proxy_urls` may be proxied."""
        # Patch rather than assign so the process-wide Context is restored
        # afterwards instead of leaking the fake token into later tests.
        with patch.object(gr.context.Context, "hf_token", "abcdef"):
            app = routes.App()
            interface = gr.Interface(lambda x: x, "text", "text")
            app.configure_app(interface)
            with pytest.raises(PermissionError):
                app.build_proxy_request(
                    "https://gradio-tests-test-loading-examples-private.hf.space/file=Bunny.obj"
                )
            with pytest.raises(PermissionError):
                app.build_proxy_request("https://google.com")
            interface.proxy_urls = {
                "https://gradio-tests-test-loading-examples-private.hf.space"
            }
            app.build_proxy_request(
                "https://gradio-tests-test-loading-examples-private.hf.space/file=Bunny.obj"
            )

    def test_proxy_does_not_leak_hf_token_externally(self):
        """The HF token is attached for the hf.space URL but not for other hosts."""
        # Patch rather than assign so the process-wide Context is restored.
        with patch.object(gr.context.Context, "hf_token", "abcdef"):
            app = routes.App()
            interface = gr.Interface(lambda x: x, "text", "text")
            interface.proxy_urls = {
                "https://gradio-tests-test-loading-examples-private.hf.space",
                "https://google.com",
            }
            app.configure_app(interface)
            r = app.build_proxy_request(
                "https://gradio-tests-test-loading-examples-private.hf.space/file=Bunny.obj"
            )
            assert "authorization" in dict(r.headers)
            r = app.build_proxy_request("https://google.com")
            assert "authorization" not in dict(r.headers)

    def test_can_get_config_that_includes_non_pickle_able_objects(self):
        """A component initialised with `dict.keys()` still renders and serves config."""
        my_dict = {"a": 1, "b": 2, "c": 3}
        with Blocks() as demo:
            gr.JSON(my_dict.keys())

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)
        response = client.get("/")
        assert response.is_success
        response = client.get("/config/")
        assert response.is_success

    def test_cors_restrictions(self):
        """CORS allow-origin is set for the 127.0.0.1 origin, not for an external one."""
        io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
        app, _, _ = io.launch(prevent_thread_lock=True)
        client = TestClient(app)
        custom_headers = {
            "host": "localhost:7860",
            "origin": "https://example.com",
        }
        file_response = client.get("/config", headers=custom_headers)
        assert "access-control-allow-origin" not in file_response.headers
        custom_headers = {
            "host": "localhost:7860",
            "origin": "127.0.0.1",
        }
        file_response = client.get("/config", headers=custom_headers)
        assert file_response.headers["access-control-allow-origin"] == "127.0.0.1"
        io.close()

    def test_delete_cache(self, connect, gradio_temp_dir, capsys):
        """Check which cached files survive shutdown, and that a custom lifespan runs.

        - With `delete_cache=None`, cached files survive shutdown.
        - With `delete_cache` set, they are gone once the connection closes.
        - A custom `lifespan` passed via `app_kwargs` still runs.
        """

        def check_num_files_exist(blocks: Blocks):
            num_files = 0
            for temp_file_set in blocks.temp_file_sets:
                for temp_file in temp_file_set:
                    if os.path.exists(temp_file):
                        num_files += 1
            return num_files

        demo = gr.Interface(lambda s: s, gr.Textbox(), gr.File(), delete_cache=None)
        with connect(demo) as client:
            client.predict("test/test_files/cheetah1.jpg")
        assert check_num_files_exist(demo) == 1

        demo_delete = gr.Interface(
            lambda s: s, gr.Textbox(), gr.File(), delete_cache=(60, 30)
        )
        with connect(demo_delete) as client:
            client.predict("test/test_files/alphabet.txt")
            client.predict("test/test_files/bus.png")
            assert check_num_files_exist(demo_delete) == 2
        assert check_num_files_exist(demo_delete) == 0
        assert check_num_files_exist(demo) == 1

        @asynccontextmanager
        async def mylifespan(app: FastAPI):
            print("IN CUSTOM LIFESPAN")
            yield
            print("AFTER CUSTOM LIFESPAN")

        demo_custom_lifespan = gr.Interface(
            lambda s: s, gr.Textbox(), gr.File(), delete_cache=(5, 1)
        )

        with connect(
            demo_custom_lifespan, app_kwargs={"lifespan": mylifespan}
        ) as client:
            client.predict("test/test_files/alphabet.txt")
        assert check_num_files_exist(demo_custom_lifespan) == 0
        captured = capsys.readouterr()
        assert "IN CUSTOM LIFESPAN" in captured.out
        assert "AFTER CUSTOM LIFESPAN" in captured.out
542

543

544
class TestApp:
    """Tests for building the FastAPI application object directly."""

    def test_create_app(self):
        """`routes.App.create_app` turns an Interface into a FastAPI instance."""
        interface = Interface(lambda x: x, "text", "text")
        created = routes.App.create_app(interface)
        assert isinstance(created, FastAPI)
548

549

550
class TestAuthenticatedRoutes:
    """Login/logout behaviour of an app launched with `auth=(user, password)`."""

    def test_post_login(self):
        """Correct credentials give 200 and wrong ones 400.

        A username wrapped in spaces is still accepted.
        """
        io = Interface(lambda x: x, "text", "text")
        app, _, _ = io.launch(
            auth=("test", "correct_password"),
            prevent_thread_lock=True,
        )
        client = TestClient(app)

        response = client.post(
            "/login",
            data={"username": "test", "password": "correct_password"},
        )
        assert response.status_code == 200

        response = client.post(
            "/login",
            data={"username": "test", "password": "incorrect_password"},
        )
        assert response.status_code == 400

        client.post(
            "/login",
            data={"username": "test", "password": "correct_password"},
        )
        response = client.post(
            "/login",
            data={"username": " test ", "password": "correct_password"},
        )
        assert response.status_code == 200

    def test_logout(self):
        """After `/logout`, a previously authorised predict call is rejected with 401."""
        io = Interface(lambda x: x, "text", "text")
        app, _, _ = io.launch(
            auth=("test", "correct_password"),
            prevent_thread_lock=True,
        )
        client = TestClient(app)

        client.post(
            "/login",
            data={"username": "test", "password": "correct_password"},
        )

        response = client.post(
            "/run/predict",
            json={"data": ["test"]},
        )
        assert response.status_code == 200

        # The logout response itself is not inspected; only its effect on the
        # next request matters.
        client.get("/logout")

        response = client.post(
            "/run/predict",
            json={"data": ["test"]},
        )
        assert response.status_code == 401
607

608

609
class TestQueueRoutes:
    """Tests for the queue's join route."""

    @pytest.mark.asyncio
    async def test_queue_join_routes_sets_app_if_none_set(self):
        """After the queue's server app is cleared, a prediction leaves it set to `io.server_app`."""
        io = Interface(lambda x: x, "text", "text").queue()
        io.launch(prevent_thread_lock=True)
        try:
            # Clear the attribute the assertion below checks. The previous
            # code cleared `server_path` instead, which left `server_app`
            # populated and made the assertion pass trivially.
            io._queue.server_app = None

            client = grc.Client(io.local_url)
            client.predict("test")

            assert io._queue.server_app == io.server_app
        finally:
            io.close()
620

621

622
class TestDevMode:
    """Checks on the `dev_mode` flag of mounted Blocks."""

    def test_mount_gradio_app_set_dev_mode_false(self):
        """Blocks mounted into an existing FastAPI app report `dev_mode` as falsy."""
        host_app = FastAPI()

        @host_app.get("/")
        def read_main():
            return {"message": "Hello!"}

        with gr.Blocks() as blocks:
            gr.Textbox("Hello from gradio!")

        host_app = routes.mount_gradio_app(host_app, blocks, path="/gradio")

        # The first Mount route among the host app's routes wraps the Gradio sub-app.
        mounted = next(
            filter(
                lambda candidate: isinstance(candidate, starlette.routing.Mount),
                host_app.routes,
            )
        )
        assert not mounted.app.blocks.dev_mode
638

639

640
class TestPassingRequest:
    """Event handlers that declare a `gr.Request` parameter receive one."""

    @staticmethod
    def _post_data(client, route, data):
        """POST `{"data": data}` to `route`, require HTTP 200, return the reply's `data`."""
        reply = client.post(route, json={"data": data})
        assert reply.status_code == 200
        return dict(reply.json())["data"]

    def test_request_included_with_interface(self):
        def identity(name, request: gr.Request):
            assert isinstance(request.client.host, str)
            return name

        demo = gr.Interface(identity, "textbox", "textbox")
        fastapi_app, _, _ = demo.launch(prevent_thread_lock=True)
        http = TestClient(fastapi_app)

        assert self._post_data(http, "/api/predict/", ["test"]) == ["test"]

    def test_request_included_with_chat_interface(self):
        def identity(x, y, request: gr.Request):
            assert isinstance(request.client.host, str)
            return x

        demo = gr.ChatInterface(identity)
        fastapi_app, _, _ = demo.launch(prevent_thread_lock=True)
        http = TestClient(fastapi_app)

        assert self._post_data(http, "/api/chat/", ["test", None]) == ["test", None]

    def test_request_included_with_chat_interface_when_streaming(self):
        def identity(x, y, request: gr.Request):
            assert isinstance(request.client.host, str)
            # Stream growing prefixes: x[:1], x[:2], ..., x[:len(x)].
            for end in range(1, len(x) + 1):
                yield x[:end]

        demo = gr.ChatInterface(identity)
        demo.queue(api_open=True)
        fastapi_app, _, _ = demo.launch(prevent_thread_lock=True)
        http = TestClient(fastapi_app)

        assert self._post_data(http, "/api/chat/", ["test", None]) == ["t", None]

    def test_request_get_headers(self):
        def identity(name, request: gr.Request):
            headers = request.headers
            assert isinstance(headers["user-agent"], str)
            assert isinstance(headers.items(), list)
            assert isinstance(headers.keys(), list)
            assert isinstance(headers.values(), list)
            assert isinstance(dict(headers), dict)
            assert "testclient" in headers["user-agent"]
            return name

        demo = gr.Interface(identity, "textbox", "textbox")
        fastapi_app, _, _ = demo.launch(prevent_thread_lock=True)
        http = TestClient(fastapi_app)

        assert self._post_data(http, "/api/predict/", ["test"]) == ["test"]

    def test_request_includes_username_as_none_if_no_auth(self):
        def identity(name, request: gr.Request):
            assert request.username is None
            return name

        demo = gr.Interface(identity, "textbox", "textbox")
        fastapi_app, _, _ = demo.launch(prevent_thread_lock=True)
        http = TestClient(fastapi_app)

        assert self._post_data(http, "/api/predict/", ["test"]) == ["test"]

    def test_request_includes_username_with_auth(self):
        def identity(name, request: gr.Request):
            assert request.username == "admin"
            return name

        demo = gr.Interface(identity, "textbox", "textbox")
        fastapi_app, _, _ = demo.launch(
            prevent_thread_lock=True, auth=("admin", "password")
        )
        http = TestClient(fastapi_app)

        # Log in first so the predict call carries the authenticated session.
        credentials = {"username": "admin", "password": "password"}
        http.post("/login", data=credentials)

        assert self._post_data(http, "/api/predict/", ["test"]) == ["test"]
745

746

747
def test_predict_route_is_blocked_if_api_open_false():
    """With `api_open=False` the queued predict route returns 404.

    This holds even though `show_api` stays true.
    """
    interface = Interface(lambda x: x, "text", "text", examples=[["freddy"]])
    interface = interface.queue(api_open=False)
    fastapi_app, _local_url, _share_url = interface.launch(prevent_thread_lock=True)
    assert interface.show_api

    payload = {"fn_index": 0, "data": [5], "session_hash": "foo"}
    reply = TestClient(fastapi_app).post("/api/predict", json=payload)
    assert reply.status_code == 404
758

759

760
def test_predict_route_not_blocked_if_queue_disabled():
    """With `api_open=False`, queued events return 404 and `queue=False` events stay callable."""
    with Blocks() as demo:
        name_box = Textbox()
        greeting_box = Textbox()
        answer_box = Number()
        trigger = Button()
        trigger.click(
            lambda x: f"Hello, {x}!",
            name_box,
            greeting_box,
            queue=False,
            api_name="not_blocked",
        )
        trigger.click(lambda: 42, None, answer_box, queue=True, api_name="blocked")

    demo.queue(api_open=False)
    fastapi_app, _, _ = demo.launch(prevent_thread_lock=True, show_api=True)
    assert demo.show_api
    http = TestClient(fastapi_app)

    blocked_reply = http.post("/api/blocked", json={"data": [], "session_hash": "foo"})
    assert blocked_reply.status_code == 404

    open_reply = http.post(
        "/api/not_blocked", json={"data": ["freddy"], "session_hash": "foo"}
    )
    assert open_reply.status_code == 200
    assert open_reply.json()["data"] == ["Hello, freddy!"]
783

784

785
def test_predict_route_not_blocked_if_routes_open():
    """Queued routes stay callable with `api_open=True` even when `show_api=False`.

    Relaunching with `api_open=False` keeps `show_api` false as well. Each
    launch is closed in a `finally`, so a failed assertion does not leave a
    server running.
    """
    with Blocks() as demo:
        name_box = Textbox()
        greeting_box = Textbox()
        button = Button()
        button.click(
            lambda x: f"Hello, {x}!",
            name_box,
            greeting_box,
            queue=True,
            api_name="not_blocked",
        )

    app, _, _ = demo.queue(api_open=True).launch(
        prevent_thread_lock=True, show_api=False
    )
    try:
        assert not demo.show_api
        client = TestClient(app)

        result = client.post(
            "/api/not_blocked", json={"data": ["freddy"], "session_hash": "foo"}
        )
        assert result.status_code == 200
        assert result.json()["data"] == ["Hello, freddy!"]
    finally:
        demo.close()

    demo.queue(api_open=False).launch(prevent_thread_lock=True, show_api=False)
    try:
        assert not demo.show_api
    finally:
        # Previously this second launch was never closed.
        demo.close()
808

809

810
def test_show_api_queue_not_enabled():
    """``show_api`` defaults to True on launch and honours ``show_api=False``."""
    io = Interface(lambda x: x, "text", "text", examples=[["freddy"]])

    # closing() guarantees both launched servers are shut down, even on failure.
    with closing(io):
        io.launch(prevent_thread_lock=True)
        assert io.show_api

    with closing(io):
        io.launch(prevent_thread_lock=True, show_api=False)
        assert not io.show_api
817

818

819
def test_orjson_serialization():
    """The index page renders for a DataFrame holding mixed column dtypes.

    The frame combines pandas timestamps, formatted date strings, numpy float
    and int64 values, booleans and markdown text. Per the test name these are
    presumably serialized via orjson when the page config is built; the test
    only checks that ``GET /`` succeeds.
    """
    df = pd.DataFrame(
        {
            "date_1": pd.date_range("2021-01-01", periods=2),
            "date_2": pd.date_range("2022-02-15", periods=2).strftime("%B %d, %Y, %r"),
            "number": np.array([0.2233, 0.57281]),
            "number_2": np.array([84, 23]).astype(np.int64),
            "bool": [True, False],
            "markdown": ["# Hello", "# Goodbye"],
        }
    )

    with gr.Blocks() as demo:
        gr.DataFrame(df)

    # closing() shuts the server down even if the assertion fails.
    with closing(demo):
        app, _, _ = demo.launch(prevent_thread_lock=True)
        response = TestClient(app).get("/")
        assert response.status_code == 200
838

839

840
def test_api_name_set_for_all_events(connect):
    """Events registered without ``api_name`` get one derived from their handler.

    As the assertions show, the name comes from the handler's ``__name__`` with
    ``$`` characters dropped (``Say_$$_goodbye`` -> ``Say__goodbye``), while a
    lambda, a callable instance and a ``functools.partial`` are exposed as
    ``lambda``, their class name and ``partial`` respectively.
    """

    def greet(i):
        return "Hello " + i

    def goodbye(i):
        return "Goodbye " + i

    def greet_me(i):
        return "Hello"

    def say_goodbye(i):
        return "Goodbye"

    say_goodbye.__name__ = "Say_$$_goodbye"

    def foo2(s):
        return s + " foo"

    foo2.__name__ = "foo-2"

    class Callable:
        def __call__(self, a) -> str:
            return "From __call__"

    def from_partial(a, b):
        return b + a

    part = functools.partial(from_partial, b="From partial: ")

    # One button per handler, wired in this order.
    handlers = [
        greet,
        goodbye,
        greet_me,
        say_goodbye,
        None,  # an event with no Python function attached
        lambda s: s,  # exposed as "lambda"
        foo2,
        Callable(),
        part,
    ]
    with gr.Blocks() as demo:
        i = Textbox()
        o = Textbox()
        buttons = [Button() for _ in handlers]
        for button, handler in zip(buttons, handlers):
            button.click(handler, i, o)

    expected_outputs = {
        "/greet": "Hello freddy",
        "/goodbye": "Goodbye freddy",
        "/greet_me": "Hello",
        "/Say__goodbye": "Goodbye",
        "/lambda": "freddy",
        "/foo-2": "freddy foo",
        "/Callable": "From __call__",
        "/partial": "From partial: freddy",
    }
    payload = {"data": ["freddy"], "session_hash": "foo"}

    with closing(demo) as io:
        app, _, _ = io.launch(prevent_thread_lock=True)
        client = TestClient(app)
        for route, output in expected_outputs.items():
            response = client.post(f"/api{route}", json=payload)
            assert response.json()["data"] == [output]
        # Collapsing "$$" into a single underscore does not give a registered name.
        with pytest.raises(FnIndexInferError):
            client.post("/api/Say_goodbye", json=payload)

    with connect(demo) as client:
        for route in ("/greet", "/goodbye", "/greet_me", "/Say__goodbye"):
            assert client.predict("freddy", api_name=route) == expected_outputs[route]
932

933

934
class TestShowAPI:
    """``Interface.show_api`` defaults to the opposite of ``wasm_utils.IS_WASM``."""

    @patch.object(wasm_utils, "IS_WASM", True)
    def test_show_api_false_when_is_wasm_true(self):
        demo = Interface(lambda x: x, "text", "text", examples=[["hannah"]])
        shown = demo.show_api
        assert shown is False, "show_api should be False when IS_WASM is True"

    @patch.object(wasm_utils, "IS_WASM", False)
    def test_show_api_true_when_is_wasm_false(self):
        demo = Interface(lambda x: x, "text", "text", examples=[["hannah"]])
        shown = demo.show_api
        assert shown is True, "show_api should be True when IS_WASM is False"
948

949

950
def test_component_server_endpoints(connect):
    """``/component_server/`` only dispatches to a component's exposed functions.

    ``FileExplorer.ls`` answers 200 with a non-empty JSON body, while
    ``preprocess`` is rejected with 404.
    """
    test_dir = os.path.dirname(os.path.abspath(__file__))
    with gr.Blocks() as demo:
        explorer = gr.FileExplorer(root=test_dir)

    def call_server_fn(client, fn_name):
        # Same request body for every call; only the target function varies.
        return client.post(
            "/component_server/",
            json={
                "session_hash": "123",
                "component_id": explorer._id,
                "fn_name": fn_name,
                "data": None,
            },
        )

    with closing(demo) as io:
        app, _, _ = io.launch(prevent_thread_lock=True)
        client = TestClient(app)

        listing = call_server_fn(client, "ls")
        assert listing.status_code == 200
        assert len(listing.json()) > 0

        rejected = call_server_fn(client, "preprocess")
        assert rejected.status_code == 404
979

980

981
@pytest.mark.parametrize(
    "request_url, route_path, root_path, expected_root_url",
    [
        ("http://localhost:7860/", "/", None, "http://localhost:7860"),
        (
            "http://localhost:7860/demo/test",
            "/demo/test",
            None,
            "http://localhost:7860",
        ),
        (
            "http://localhost:7860/demo/test/",
            "/demo/test",
            None,
            "http://localhost:7860",
        ),
        (
            "http://localhost:7860/demo/test?query=1",
            "/demo/test",
            None,
            "http://localhost:7860",
        ),
        (
            "http://localhost:7860/demo/test?query=1",
            "/demo/test/",
            "/gradio/",
            "http://localhost:7860/gradio",
        ),
        (
            "http://localhost:7860/demo/test?query=1",
            "/demo/test",
            "/gradio/",
            "http://localhost:7860/gradio",
        ),
        (
            "https://localhost:7860/demo/test?query=1",
            "/demo/test",
            "/gradio/",
            "https://localhost:7860/gradio",
        ),
        (
            "https://www.gradio.app/playground/",
            "/",
            "/playground",
            "https://www.gradio.app/playground",
        ),
        (
            "https://www.gradio.app/playground/",
            "/",
            "",
            "https://www.gradio.app/playground",
        ),
        (
            "https://www.gradio.app/playground/",
            "/",
            "http://www.gradio.app/",
            "http://www.gradio.app",
        ),
    ],
)
def test_get_root_url(
    request_url: str, route_path: str, root_path: str, expected_root_url: str
):
    """Check ``get_root_url`` for a variety of request URLs and root paths.

    Cases cover trailing slashes, query strings, relative ``root_path``
    prefixes, an empty ``root_path``, ``root_path=None`` and an absolute
    ``root_path`` URL. Each case appears once; a verbatim duplicate of the
    ``/playground`` case was removed.
    """
    scope = {
        "type": "http",
        "headers": [],
        "path": request_url,
    }
    request = Request(scope)
    assert get_root_url(request, route_path, root_path) == expected_root_url
1057

1058

1059
@pytest.mark.parametrize(
    "headers, root_path, expected_root_url",
    [
        ({}, "/gradio/", "http://gradio.app/gradio"),
        ({"x-forwarded-proto": "http"}, "/gradio/", "http://gradio.app/gradio"),
        ({"x-forwarded-proto": "https"}, "/gradio/", "https://gradio.app/gradio"),
        ({"x-forwarded-host": "gradio.dev"}, "/gradio/", "http://gradio.dev/gradio"),
        (
            {"x-forwarded-host": "gradio.dev", "x-forwarded-proto": "https"},
            "/",
            "https://gradio.dev",
        ),
        (
            {"x-forwarded-host": "gradio.dev", "x-forwarded-proto": "https"},
            "http://google.com",
            "http://google.com",
        ),
    ],
)
def test_get_root_url_headers(
    headers: Dict[str, str], root_path: str, expected_root_url: str
):
    """Forwarded-proto/host headers override the request's scheme and host.

    When ``root_path`` is itself an absolute URL, it is returned unchanged
    regardless of the forwarding headers.
    """
    # ASGI scopes carry headers as a list of (bytes, bytes) pairs.
    raw_headers = []
    for name, value in headers.items():
        raw_headers.append((name.encode(), value.encode()))
    request = Request(
        {"type": "http", "headers": raw_headers, "path": "http://gradio.app"}
    )
    root_url = get_root_url(request, "/", root_path)
    assert root_url == expected_root_url
1088

1089

1090
class TestSimpleAPIRoutes:
    """Tests for the two-step ``/call/<api_name>`` HTTP API.

    A POST to ``/call/<api_name>`` returns an ``event_id``. A GET to
    ``/call/<api_name>/<event_id>`` then streams the outcome as server-sent
    event lines (``event: ...`` / ``data: ...``).
    """

    # Upper bound in seconds for connecting and for each read, so a stuck
    # server fails the test instead of hanging the whole run.
    _REQUEST_TIMEOUT = 30

    def get_demo(self):
        """Build a demo with a plain (fn1), a streaming (fn2) and a two-output (fn3) event."""
        with Blocks() as demo:
            name = Textbox()
            output = Textbox()
            output2 = Textbox()

            def fn_1(x):
                return f"Hello, {x}!"

            def fn_2(x):
                # Yields one growing greeting per character, then fails for
                # inputs shorter than 3 characters to exercise the error event.
                for end in range(1, len(x) + 1):
                    time.sleep(0.5)
                    yield f"Hello, {x[:end]}!"
                if len(x) < 3:
                    raise ValueError("Small input")

            def fn_3():
                return "a", "b"

            btn1, btn2, btn3 = Button(), Button(), Button()
            btn1.click(fn_1, name, output, api_name="fn1")
            btn2.click(fn_2, name, output2, api_name="fn2")
            btn3.click(fn_3, None, [output, output2], api_name="fn3")
        return demo

    def _stream_events(self, demo, api_name, data):
        """Submit ``data`` to ``api_name`` and return the non-empty SSE lines.

        Asserts that the submission succeeds before reading the result stream.
        """
        submitted = requests.post(
            f"{demo.local_url}call/{api_name}",
            json={"data": data},
            timeout=self._REQUEST_TIMEOUT,
        )
        assert submitted.status_code == 200, f"Failed to call {api_name}"
        event_id = submitted.json()["event_id"]
        # Use the response as a context manager so the streamed connection is released.
        with requests.get(
            f"{demo.local_url}call/{api_name}/{event_id}",
            stream=True,
            timeout=self._REQUEST_TIMEOUT,
        ) as stream:
            return [line.decode("utf-8") for line in stream.iter_lines() if line]

    def test_successful_simple_route(self):
        # closing() shuts the launched server down even if an assertion fails.
        with closing(self.get_demo()) as demo:
            demo.launch(prevent_thread_lock=True)

            assert self._stream_events(demo, "fn1", ["world"]) == [
                "event: complete",
                'data: ["Hello, world!"]',
            ]
            assert self._stream_events(demo, "fn3", []) == [
                "event: complete",
                'data: ["a", "b"]',
            ]

    def test_generative_simple_route(self):
        with closing(self.get_demo()) as demo:
            demo.launch(prevent_thread_lock=True)

            assert self._stream_events(demo, "fn2", ["world"]) == [
                "event: generating",
                'data: ["Hello, w!"]',
                "event: generating",
                'data: ["Hello, wo!"]',
                "event: generating",
                'data: ["Hello, wor!"]',
                "event: generating",
                'data: ["Hello, worl!"]',
                "event: generating",
                'data: ["Hello, world!"]',
                "event: complete",
                'data: ["Hello, world!"]',
            ]

            # A one-character input streams once and then raises inside fn_2.
            assert self._stream_events(demo, "fn2", ["w"]) == [
                "event: generating",
                'data: ["Hello, w!"]',
                "event: error",
                "data: null",
            ]
1201

1202

1203
def test_compare_passwords_securely():
    """Identical passwords match, including non-ASCII ones; different ones do not."""
    ascii_pw, unicode_pw = "password", "pässword"
    assert compare_passwords_securely(ascii_pw, ascii_pw)
    assert not compare_passwords_securely(ascii_pw, unicode_pw)
    assert compare_passwords_securely(unicode_pw, unicode_pw)
1209

1210

1211
@pytest.mark.parametrize(
    "string, expected",
    [
        # Explicit URL schemes.
        ("http://localhost:7860/", True),
        ("https://localhost:7860/", True),
        ("ftp://localhost:7860/", True),
        ("smb://example.com", True),
        ("ipfs://QmTzQ1Nj5R9BzF1djVQv8gvzZxVkJb1vhrLcXL1QyJzZE", True),
        # Relative paths, bare hosts and drive paths have no scheme.
        ("usr/local/bin", False),
        ("localhost:7860", False),
        ("localhost", False),
        ("C:/Users/username", False),
        # Two leading slashes/backslashes (in any mix) count as a prefix;
        # doubled slashes later in a path do not.
        ("//path", True),
        ("\\\\path", True),
        ("/usr/bin//test", False),
        ("/\\10.0.225.200/share", True),
        ("\\/10.0.225.200/share", True),
        ("/home//user", False),
        ("C:\\folder\\file", False),
    ],
)
def test_starts_with_protocol(string, expected):
    """``starts_with_protocol`` flags scheme and double-slash prefixes only."""
    result = starts_with_protocol(string)
    assert result == expected
1234

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.