1
"""Contains tests for networking.py and app.py"""
6
from contextlib import asynccontextmanager, closing
8
from unittest.mock import patch
10
import gradio_client as grc
15
import starlette.routing
16
from fastapi import FastAPI, Request
17
from fastapi.testclient import TestClient
18
from gradio_client import media_data
31
from gradio.route_utils import (
33
compare_passwords_securely,
41
io = Interface(lambda x: x + x, "text", "text")
42
app, _, _ = io.launch(prevent_thread_lock=True)
43
test_client = TestClient(app)
50
def test_get_main_route(self, test_client):
51
response = test_client.get("/")
52
assert response.status_code == 200
54
def test_static_files_served_safely(self, test_client):
56
response = test_client.get(r"/static/..%2findex.html")
57
assert response.status_code == 403
58
response = test_client.get(r"/static/..%2f..%2fapi_docs.html")
59
assert response.status_code == 403
61
def test_get_config_route(self, test_client):
62
response = test_client.get("/config/")
63
assert response.status_code == 200
65
def test_favicon_route(self, test_client):
66
response = test_client.get("/favicon.ico")
67
assert response.status_code == 200
69
def test_upload_path(self, test_client):
70
with open("test/test_files/alphabet.txt", "rb") as f:
71
response = test_client.post("/upload", files={"files": f})
72
assert response.status_code == 200
73
file = response.json()[0]
74
assert "alphabet" in file
75
assert file.endswith(".txt")
76
with open(file, "rb") as saved_file:
77
assert saved_file.read() == b"abcdefghijklmnopqrstuvwxyz"
79
def test_custom_upload_path(self, gradio_temp_dir):
80
io = Interface(lambda x: x + x, "text", "text")
81
app, _, _ = io.launch(prevent_thread_lock=True)
82
test_client = TestClient(app)
83
with open("test/test_files/alphabet.txt", "rb") as f:
84
response = test_client.post("/upload", files={"files": f})
85
assert response.status_code == 200
86
file = response.json()[0]
87
assert "alphabet" in file
88
assert file.startswith(str(gradio_temp_dir))
89
assert file.endswith(".txt")
90
with open(file, "rb") as saved_file:
91
assert saved_file.read() == b"abcdefghijklmnopqrstuvwxyz"
93
def test_predict_route(self, test_client):
94
response = test_client.post(
95
"/api/predict/", json={"data": ["test"], "fn_index": 0}
97
assert response.status_code == 200
98
output = dict(response.json())
99
assert output["data"] == ["testtest"]
101
def test_named_predict_route(self):
102
with Blocks() as demo:
105
i.change(lambda x: f"{x}1", i, o, api_name="p")
106
i.change(lambda x: f"{x}2", i, o, api_name="q")
108
app, _, _ = demo.launch(prevent_thread_lock=True)
109
client = TestClient(app)
110
response = client.post("/api/p/", json={"data": ["test"]})
111
assert response.status_code == 200
112
output = dict(response.json())
113
assert output["data"] == ["test1"]
115
response = client.post("/api/q/", json={"data": ["test"]})
116
assert response.status_code == 200
117
output = dict(response.json())
118
assert output["data"] == ["test2"]
120
def test_same_named_predict_route(self):
121
with Blocks() as demo:
124
i.change(lambda x: f"{x}0", i, o, api_name="p")
125
i.change(lambda x: f"{x}1", i, o, api_name="p")
127
app, _, _ = demo.launch(prevent_thread_lock=True)
128
client = TestClient(app)
129
response = client.post("/api/p/", json={"data": ["test"]})
130
assert response.status_code == 200
131
output = dict(response.json())
132
assert output["data"] == ["test0"]
134
response = client.post("/api/p_1/", json={"data": ["test"]})
135
assert response.status_code == 200
136
output = dict(response.json())
137
assert output["data"] == ["test1"]
139
def test_multiple_renamed(self):
140
with Blocks() as demo:
143
i.change(lambda x: f"{x}0", i, o, api_name="p")
144
i.change(lambda x: f"{x}1", i, o, api_name="p")
145
i.change(lambda x: f"{x}2", i, o, api_name="p_1")
147
app, _, _ = demo.launch(prevent_thread_lock=True)
148
client = TestClient(app)
149
response = client.post("/api/p/", json={"data": ["test"]})
150
assert response.status_code == 200
151
output = dict(response.json())
152
assert output["data"] == ["test0"]
154
response = client.post("/api/p_1/", json={"data": ["test"]})
155
assert response.status_code == 200
156
output = dict(response.json())
157
assert output["data"] == ["test1"]
159
response = client.post("/api/p_1_1/", json={"data": ["test"]})
160
assert response.status_code == 200
161
output = dict(response.json())
162
assert output["data"] == ["test2"]
164
def test_predict_route_without_fn_index(self, test_client):
165
response = test_client.post("/api/predict/", json={"data": ["test"]})
166
assert response.status_code == 200
167
output = dict(response.json())
168
assert output["data"] == ["testtest"]
170
def test_predict_route_batching(self):
174
results.append(f"Hello {word}")
177
with gr.Blocks() as demo:
180
btn.click(batch_fn, inputs=text, outputs=text, batch=True, api_name="pred")
182
demo.queue(api_open=True)
183
app, _, _ = demo.launch(prevent_thread_lock=True)
184
client = TestClient(app)
185
response = client.post("/api/pred/", json={"data": ["test"]})
186
output = dict(response.json())
187
assert output["data"] == ["Hello test"]
189
app, _, _ = demo.launch(prevent_thread_lock=True)
190
client = TestClient(app)
191
response = client.post(
192
"/api/pred/", json={"data": [["test", "test2"]], "batched": True}
194
output = dict(response.json())
195
assert output["data"] == [["Hello test", "Hello test2"]]
197
def test_state(self):
198
def predict(input, history):
202
return history, history
204
io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
205
app, _, _ = io.launch(prevent_thread_lock=True)
206
client = TestClient(app)
207
response = client.post(
209
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
211
output = dict(response.json())
212
assert output["data"] == ["test", None]
213
response = client.post(
215
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
217
output = dict(response.json())
218
assert output["data"] == ["testtest", None]
220
def test_get_allowed_paths(self):
221
allowed_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
222
allowed_file.write(media_data.BASE64_IMAGE)
225
io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
226
app, _, _ = io.launch(prevent_thread_lock=True)
227
client = TestClient(app)
228
file_response = client.get(f"/file={allowed_file.name}")
229
assert file_response.status_code == 403
232
io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
233
app, _, _ = io.launch(
234
prevent_thread_lock=True,
235
allowed_paths=[os.path.dirname(allowed_file.name)],
237
client = TestClient(app)
238
file_response = client.get(f"/file={allowed_file.name}")
239
assert file_response.status_code == 200
240
assert len(file_response.text) == len(media_data.BASE64_IMAGE)
243
io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
244
app, _, _ = io.launch(
245
prevent_thread_lock=True,
246
allowed_paths=[os.path.abspath(allowed_file.name)],
248
client = TestClient(app)
249
file_response = client.get(f"/file={allowed_file.name}")
250
assert file_response.status_code == 200
251
assert len(file_response.text) == len(media_data.BASE64_IMAGE)
254
def test_allowed_and_blocked_paths(self):
255
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
256
io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
257
app, _, _ = io.launch(
258
prevent_thread_lock=True,
259
allowed_paths=[os.path.dirname(tmp_file.name)],
261
client = TestClient(app)
262
file_response = client.get(f"/file={tmp_file.name}")
263
assert file_response.status_code == 200
265
os.remove(tmp_file.name)
267
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
268
io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
269
app, _, _ = io.launch(
270
prevent_thread_lock=True,
271
allowed_paths=[os.path.dirname(tmp_file.name)],
272
blocked_paths=[os.path.dirname(tmp_file.name)],
274
client = TestClient(app)
275
file_response = client.get(f"/file={tmp_file.name}")
276
assert file_response.status_code == 403
278
os.remove(tmp_file.name)
280
def test_get_file_created_by_app(self, test_client):
281
app, _, _ = gr.Interface(lambda s: s.name, gr.File(), gr.File()).launch(
282
prevent_thread_lock=True
284
client = TestClient(app)
285
with open("test/test_files/alphabet.txt", "rb") as f:
286
file_response = test_client.post("/upload", files={"files": f})
287
response = client.post(
292
"path": file_response.json()[0],
293
"size": os.path.getsize("test/test_files/alphabet.txt"),
300
created_file = response["data"][0]["path"]
301
file_response = client.get(f"/file={created_file}")
302
assert file_response.is_success
304
backwards_compatible_file_response = client.get(f"/file/{created_file}")
305
assert backwards_compatible_file_response.is_success
307
file_response_with_full_range = client.get(
308
f"/file={created_file}", headers={"Range": "bytes=0-"}
310
assert file_response_with_full_range.is_success
311
assert file_response.text == file_response_with_full_range.text
313
file_response_with_partial_range = client.get(
314
f"/file={created_file}", headers={"Range": "bytes=0-10"}
316
assert file_response_with_partial_range.is_success
317
assert len(file_response_with_partial_range.text) == 11
319
def test_mount_gradio_app(self):
323
lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
325
demo1 = gr.Interface(
326
lambda s: f"Hello from py, {s}!", "textbox", "textbox"
329
app = gr.mount_gradio_app(app, demo, path="/ps")
330
app = gr.mount_gradio_app(app, demo1, path="/py")
333
with TestClient(app) as client:
334
assert client.get("/ps").is_success
335
assert client.get("/py").is_success
337
def test_mount_gradio_app_with_app_kwargs(self):
340
demo = gr.Interface(lambda s: f"You said {s}!", "textbox", "textbox").queue()
342
app = gr.mount_gradio_app(
343
app, demo, path="/echo", app_kwargs={"docs_url": "/docs-custom"}
347
with TestClient(app) as client:
348
assert client.get("/echo/docs-custom").is_success
350
def test_mount_gradio_app_with_lifespan(self):
352
async def empty_lifespan(app: FastAPI):
355
app = FastAPI(lifespan=empty_lifespan)
358
lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
360
demo1 = gr.Interface(
361
lambda s: f"Hello from py, {s}!", "textbox", "textbox"
364
app = gr.mount_gradio_app(app, demo, path="/ps")
365
app = gr.mount_gradio_app(app, demo1, path="/py")
368
with TestClient(app) as client:
369
assert client.get("/ps").is_success
370
assert client.get("/py").is_success
372
def test_mount_gradio_app_with_startup(self):
375
@app.on_event("startup")
376
async def empty_startup():
380
lambda s: f"Hello from ps, {s}!", "textbox", "textbox"
382
demo1 = gr.Interface(
383
lambda s: f"Hello from py, {s}!", "textbox", "textbox"
386
app = gr.mount_gradio_app(app, demo, path="/ps")
387
app = gr.mount_gradio_app(app, demo1, path="/py")
390
with TestClient(app) as client:
391
assert client.get("/ps").is_success
392
assert client.get("/py").is_success
394
def test_mount_gradio_app_with_auth_dependency(self):
397
def get_user(request: Request):
398
return request.headers.get("user")
400
demo = gr.Interface(lambda s: f"Hello from ps, {s}!", "textbox", "textbox")
402
app = gr.mount_gradio_app(app, demo, path="/demo", auth_dependency=get_user)
404
with TestClient(app) as client:
405
assert client.get("/demo", headers={"user": "abubakar"}).is_success
406
assert not client.get("/demo").is_success
408
def test_static_file_missing(self, test_client):
409
response = test_client.get(r"/static/not-here.js")
410
assert response.status_code == 404
412
def test_asset_file_missing(self, test_client):
413
response = test_client.get(r"/assets/not-here.js")
414
assert response.status_code == 404
416
def test_cannot_access_files_in_working_directory(self, test_client):
417
response = test_client.get(r"/file=not-here.js")
418
assert response.status_code == 403
419
response = test_client.get(r"/file=subdir/.env")
420
assert response.status_code == 403
422
def test_cannot_access_directories_in_working_directory(self, test_client):
423
response = test_client.get(r"/file=gradio")
424
assert response.status_code == 403
426
def test_block_protocols_that_expose_windows_credentials(self, test_client):
427
response = test_client.get(r"/file=//11.0.225.200/share")
428
assert response.status_code == 403
430
def test_do_not_expose_existence_of_files_outside_working_directory(
433
response = test_client.get(r"/file=../fake-file-that-does-not-exist.js")
434
assert response.status_code == 403
436
def test_proxy_route_is_restricted_to_load_urls(self):
437
gr.context.Context.hf_token = "abcdef"
439
interface = gr.Interface(lambda x: x, "text", "text")
440
app.configure_app(interface)
441
with pytest.raises(PermissionError):
442
app.build_proxy_request(
443
"https://gradio-tests-test-loading-examples-private.hf.space/file=Bunny.obj"
445
with pytest.raises(PermissionError):
446
app.build_proxy_request("https://google.com")
447
interface.proxy_urls = {
448
"https://gradio-tests-test-loading-examples-private.hf.space"
450
app.build_proxy_request(
451
"https://gradio-tests-test-loading-examples-private.hf.space/file=Bunny.obj"
454
def test_proxy_does_not_leak_hf_token_externally(self):
455
gr.context.Context.hf_token = "abcdef"
457
interface = gr.Interface(lambda x: x, "text", "text")
458
interface.proxy_urls = {
459
"https://gradio-tests-test-loading-examples-private.hf.space",
460
"https://google.com",
462
app.configure_app(interface)
463
r = app.build_proxy_request(
464
"https://gradio-tests-test-loading-examples-private.hf.space/file=Bunny.obj"
466
assert "authorization" in dict(r.headers)
467
r = app.build_proxy_request("https://google.com")
468
assert "authorization" not in dict(r.headers)
470
def test_can_get_config_that_includes_non_pickle_able_objects(self):
471
my_dict = {"a": 1, "b": 2, "c": 3}
472
with Blocks() as demo:
473
gr.JSON(my_dict.keys())
475
app, _, _ = demo.launch(prevent_thread_lock=True)
476
client = TestClient(app)
477
response = client.get("/")
478
assert response.is_success
479
response = client.get("/config/")
480
assert response.is_success
482
def test_cors_restrictions(self):
483
io = gr.Interface(lambda s: s.name, gr.File(), gr.File())
484
app, _, _ = io.launch(prevent_thread_lock=True)
485
client = TestClient(app)
487
"host": "localhost:7860",
488
"origin": "https://example.com",
490
file_response = client.get("/config", headers=custom_headers)
491
assert "access-control-allow-origin" not in file_response.headers
493
"host": "localhost:7860",
494
"origin": "127.0.0.1",
496
file_response = client.get("/config", headers=custom_headers)
497
assert file_response.headers["access-control-allow-origin"] == "127.0.0.1"
500
def test_delete_cache(self, connect, gradio_temp_dir, capsys):
501
def check_num_files_exist(blocks: Blocks):
503
for temp_file_set in blocks.temp_file_sets:
504
for temp_file in temp_file_set:
505
if os.path.exists(temp_file):
509
demo = gr.Interface(lambda s: s, gr.Textbox(), gr.File(), delete_cache=None)
510
with connect(demo) as client:
511
client.predict("test/test_files/cheetah1.jpg")
512
assert check_num_files_exist(demo) == 1
514
demo_delete = gr.Interface(
515
lambda s: s, gr.Textbox(), gr.File(), delete_cache=(60, 30)
517
with connect(demo_delete) as client:
518
client.predict("test/test_files/alphabet.txt")
519
client.predict("test/test_files/bus.png")
520
assert check_num_files_exist(demo_delete) == 2
521
assert check_num_files_exist(demo_delete) == 0
522
assert check_num_files_exist(demo) == 1
525
async def mylifespan(app: FastAPI):
526
print("IN CUSTOM LIFESPAN")
528
print("AFTER CUSTOM LIFESPAN")
530
demo_custom_lifespan = gr.Interface(
531
lambda s: s, gr.Textbox(), gr.File(), delete_cache=(5, 1)
535
demo_custom_lifespan, app_kwargs={"lifespan": mylifespan}
537
client.predict("test/test_files/alphabet.txt")
538
assert check_num_files_exist(demo_custom_lifespan) == 0
539
captured = capsys.readouterr()
540
assert "IN CUSTOM LIFESPAN" in captured.out
541
assert "AFTER CUSTOM LIFESPAN" in captured.out
545
def test_create_app(self):
546
app = routes.App.create_app(Interface(lambda x: x, "text", "text"))
547
assert isinstance(app, FastAPI)
550
class TestAuthenticatedRoutes:
551
def test_post_login(self):
552
io = Interface(lambda x: x, "text", "text")
553
app, _, _ = io.launch(
554
auth=("test", "correct_password"),
555
prevent_thread_lock=True,
557
client = TestClient(app)
559
response = client.post(
561
data={"username": "test", "password": "correct_password"},
563
assert response.status_code == 200
565
response = client.post(
567
data={"username": "test", "password": "incorrect_password"},
569
assert response.status_code == 400
573
data={"username": "test", "password": "correct_password"},
575
response = client.post(
577
data={"username": " test ", "password": "correct_password"},
579
assert response.status_code == 200
581
def test_logout(self):
582
io = Interface(lambda x: x, "text", "text")
583
app, _, _ = io.launch(
584
auth=("test", "correct_password"),
585
prevent_thread_lock=True,
587
client = TestClient(app)
591
data={"username": "test", "password": "correct_password"},
594
response = client.post(
596
json={"data": ["test"]},
598
assert response.status_code == 200
600
response = client.get("/logout")
602
response = client.post(
604
json={"data": ["test"]},
606
assert response.status_code == 401
609
class TestQueueRoutes:
611
async def test_queue_join_routes_sets_app_if_none_set(self):
612
io = Interface(lambda x: x, "text", "text").queue()
613
io.launch(prevent_thread_lock=True)
614
io._queue.server_path = None
616
client = grc.Client(io.local_url)
617
client.predict("test")
619
assert io._queue.server_app == io.server_app
623
def test_mount_gradio_app_set_dev_mode_false(self):
628
return {"message": "Hello!"}
630
with gr.Blocks() as blocks:
631
gr.Textbox("Hello from gradio!")
633
app = routes.mount_gradio_app(app, blocks, path="/gradio")
634
gradio_fast_api = next(
635
route for route in app.routes if isinstance(route, starlette.routing.Mount)
637
assert not gradio_fast_api.app.blocks.dev_mode
640
class TestPassingRequest:
641
def test_request_included_with_interface(self):
642
def identity(name, request: gr.Request):
643
assert isinstance(request.client.host, str)
646
app, _, _ = gr.Interface(identity, "textbox", "textbox").launch(
647
prevent_thread_lock=True,
649
client = TestClient(app)
651
response = client.post("/api/predict/", json={"data": ["test"]})
652
assert response.status_code == 200
653
output = dict(response.json())
654
assert output["data"] == ["test"]
656
def test_request_included_with_chat_interface(self):
657
def identity(x, y, request: gr.Request):
658
assert isinstance(request.client.host, str)
661
app, _, _ = gr.ChatInterface(identity).launch(
662
prevent_thread_lock=True,
664
client = TestClient(app)
666
response = client.post("/api/chat/", json={"data": ["test", None]})
667
assert response.status_code == 200
668
output = dict(response.json())
669
assert output["data"] == ["test", None]
671
def test_request_included_with_chat_interface_when_streaming(self):
672
def identity(x, y, request: gr.Request):
673
assert isinstance(request.client.host, str)
674
for i in range(len(x)):
678
gr.ChatInterface(identity)
679
.queue(api_open=True)
681
prevent_thread_lock=True,
684
client = TestClient(app)
686
response = client.post("/api/chat/", json={"data": ["test", None]})
687
assert response.status_code == 200
688
output = dict(response.json())
689
assert output["data"] == ["t", None]
691
def test_request_get_headers(self):
692
def identity(name, request: gr.Request):
693
assert isinstance(request.headers["user-agent"], str)
694
assert isinstance(request.headers.items(), list)
695
assert isinstance(request.headers.keys(), list)
696
assert isinstance(request.headers.values(), list)
697
assert isinstance(dict(request.headers), dict)
698
user_agent = request.headers["user-agent"]
699
assert "testclient" in user_agent
702
app, _, _ = gr.Interface(identity, "textbox", "textbox").launch(
703
prevent_thread_lock=True,
705
client = TestClient(app)
707
response = client.post("/api/predict/", json={"data": ["test"]})
708
assert response.status_code == 200
709
output = dict(response.json())
710
assert output["data"] == ["test"]
712
def test_request_includes_username_as_none_if_no_auth(self):
713
def identity(name, request: gr.Request):
714
assert request.username is None
717
app, _, _ = gr.Interface(identity, "textbox", "textbox").launch(
718
prevent_thread_lock=True,
720
client = TestClient(app)
722
response = client.post("/api/predict/", json={"data": ["test"]})
723
assert response.status_code == 200
724
output = dict(response.json())
725
assert output["data"] == ["test"]
727
def test_request_includes_username_with_auth(self):
728
def identity(name, request: gr.Request):
729
assert request.username == "admin"
732
app, _, _ = gr.Interface(identity, "textbox", "textbox").launch(
733
prevent_thread_lock=True, auth=("admin", "password")
735
client = TestClient(app)
739
data={"username": "admin", "password": "password"},
741
response = client.post("/api/predict/", json={"data": ["test"]})
742
assert response.status_code == 200
743
output = dict(response.json())
744
assert output["data"] == ["test"]
747
def test_predict_route_is_blocked_if_api_open_false():
748
io = Interface(lambda x: x, "text", "text", examples=[["freddy"]]).queue(
751
app, _, _ = io.launch(prevent_thread_lock=True)
753
client = TestClient(app)
754
result = client.post(
755
"/api/predict", json={"fn_index": 0, "data": [5], "session_hash": "foo"}
757
assert result.status_code == 404
760
def test_predict_route_not_blocked_if_queue_disabled():
761
with Blocks() as demo:
767
lambda x: f"Hello, {x}!", input, output, queue=False, api_name="not_blocked"
769
button.click(lambda: 42, None, number, queue=True, api_name="blocked")
770
app, _, _ = demo.queue(api_open=False).launch(
771
prevent_thread_lock=True, show_api=True
774
client = TestClient(app)
776
result = client.post("/api/blocked", json={"data": [], "session_hash": "foo"})
777
assert result.status_code == 404
778
result = client.post(
779
"/api/not_blocked", json={"data": ["freddy"], "session_hash": "foo"}
781
assert result.status_code == 200
782
assert result.json()["data"] == ["Hello, freddy!"]
785
def test_predict_route_not_blocked_if_routes_open():
786
with Blocks() as demo:
791
lambda x: f"Hello, {x}!", input, output, queue=True, api_name="not_blocked"
793
app, _, _ = demo.queue(api_open=True).launch(
794
prevent_thread_lock=True, show_api=False
796
assert not demo.show_api
797
client = TestClient(app)
799
result = client.post(
800
"/api/not_blocked", json={"data": ["freddy"], "session_hash": "foo"}
802
assert result.status_code == 200
803
assert result.json()["data"] == ["Hello, freddy!"]
806
demo.queue(api_open=False).launch(prevent_thread_lock=True, show_api=False)
807
assert not demo.show_api
810
def test_show_api_queue_not_enabled():
811
io = Interface(lambda x: x, "text", "text", examples=[["freddy"]])
812
app, _, _ = io.launch(prevent_thread_lock=True)
815
io.launch(prevent_thread_lock=True, show_api=False)
816
assert not io.show_api
819
def test_orjson_serialization():
822
"date_1": pd.date_range("2021-01-01", periods=2),
823
"date_2": pd.date_range("2022-02-15", periods=2).strftime("%B %d, %Y, %r"),
824
"number": np.array([0.2233, 0.57281]),
825
"number_2": np.array([84, 23]).astype(np.int64),
826
"bool": [True, False],
827
"markdown": ["# Hello", "# Goodbye"],
831
with gr.Blocks() as demo:
833
app, _, _ = demo.launch(prevent_thread_lock=True)
834
test_client = TestClient(app)
835
response = test_client.get("/")
836
assert response.status_code == 200
840
def test_api_name_set_for_all_events(connect):
841
with gr.Blocks() as demo:
858
return "Goodbye " + i
866
say_goodbye.__name__ = "Say_$$_goodbye"
874
foo2.__name__ = "foo-2"
877
def __call__(self, a) -> str:
878
return "From __call__"
880
def from_partial(a, b):
883
part = functools.partial(from_partial, b="From partial: ")
885
btn.click(greet, i, o)
886
btn1.click(goodbye, i, o)
887
btn2.click(greet_me, i, o)
888
btn3.click(say_goodbye, i, o)
889
btn4.click(None, i, o)
890
btn5.click(foo, i, o)
891
btn6.click(foo2, i, o)
892
btn7.click(Callable(), i, o)
893
btn8.click(part, i, o)
895
with closing(demo) as io:
896
app, _, _ = io.launch(prevent_thread_lock=True)
897
client = TestClient(app)
899
"/api/greet", json={"data": ["freddy"], "session_hash": "foo"}
900
).json()["data"] == ["Hello freddy"]
902
"/api/goodbye", json={"data": ["freddy"], "session_hash": "foo"}
903
).json()["data"] == ["Goodbye freddy"]
905
"/api/greet_me", json={"data": ["freddy"], "session_hash": "foo"}
906
).json()["data"] == ["Hello"]
908
"/api/Say__goodbye", json={"data": ["freddy"], "session_hash": "foo"}
909
).json()["data"] == ["Goodbye"]
911
"/api/lambda", json={"data": ["freddy"], "session_hash": "foo"}
912
).json()["data"] == ["freddy"]
914
"/api/foo-2", json={"data": ["freddy"], "session_hash": "foo"}
915
).json()["data"] == ["freddy foo"]
917
"/api/Callable", json={"data": ["freddy"], "session_hash": "foo"}
918
).json()["data"] == ["From __call__"]
920
"/api/partial", json={"data": ["freddy"], "session_hash": "foo"}
921
).json()["data"] == ["From partial: freddy"]
922
with pytest.raises(FnIndexInferError):
924
"/api/Say_goodbye", json={"data": ["freddy"], "session_hash": "foo"}
927
with connect(demo) as client:
928
assert client.predict("freddy", api_name="/greet") == "Hello freddy"
929
assert client.predict("freddy", api_name="/goodbye") == "Goodbye freddy"
930
assert client.predict("freddy", api_name="/greet_me") == "Hello"
931
assert client.predict("freddy", api_name="/Say__goodbye") == "Goodbye"
935
@patch.object(wasm_utils, "IS_WASM", True)
936
def test_show_api_false_when_is_wasm_true(self):
937
interface = Interface(lambda x: x, "text", "text", examples=[["hannah"]])
939
interface.show_api is False
940
), "show_api should be False when IS_WASM is True"
942
@patch.object(wasm_utils, "IS_WASM", False)
943
def test_show_api_true_when_is_wasm_false(self):
944
interface = Interface(lambda x: x, "text", "text", examples=[["hannah"]])
946
interface.show_api is True
947
), "show_api should be True when IS_WASM is False"
950
def test_component_server_endpoints(connect):
951
here = os.path.dirname(os.path.abspath(__file__))
952
with gr.Blocks() as demo:
953
file_explorer = gr.FileExplorer(root=here)
955
with closing(demo) as io:
956
app, _, _ = io.launch(prevent_thread_lock=True)
957
client = TestClient(app)
958
success_req = client.post(
959
"/component_server/",
961
"session_hash": "123",
962
"component_id": file_explorer._id,
967
assert success_req.status_code == 200
968
assert len(success_req.json()) > 0
969
fail_req = client.post(
970
"/component_server/",
972
"session_hash": "123",
973
"component_id": file_explorer._id,
974
"fn_name": "preprocess",
978
assert fail_req.status_code == 404
981
@pytest.mark.parametrize(
982
"request_url, route_path, root_path, expected_root_url",
984
("http://localhost:7860/", "/", None, "http://localhost:7860"),
986
"http://localhost:7860/demo/test",
989
"http://localhost:7860",
992
"http://localhost:7860/demo/test/",
995
"http://localhost:7860",
998
"http://localhost:7860/demo/test?query=1",
1001
"http://localhost:7860",
1004
"http://localhost:7860/demo/test?query=1",
1007
"http://localhost:7860/gradio",
1010
"http://localhost:7860/demo/test?query=1",
1013
"http://localhost:7860/gradio",
1016
"https://localhost:7860/demo/test?query=1",
1019
"https://localhost:7860/gradio",
1022
"https://www.gradio.app/playground/",
1025
"https://www.gradio.app/playground",
1028
"https://www.gradio.app/playground/",
1031
"https://www.gradio.app/playground",
1034
"https://www.gradio.app/playground/",
1037
"https://www.gradio.app/playground",
1040
"https://www.gradio.app/playground/",
1042
"http://www.gradio.app/",
1043
"http://www.gradio.app",
1047
def test_get_root_url(
1048
request_url: str, route_path: str, root_path: str, expected_root_url: str
1053
"path": request_url,
1055
request = Request(scope)
1056
assert get_root_url(request, route_path, root_path) == expected_root_url
1059
@pytest.mark.parametrize(
1060
"headers, root_path, expected_root_url",
1062
({}, "/gradio/", "http://gradio.app/gradio"),
1063
({"x-forwarded-proto": "http"}, "/gradio/", "http://gradio.app/gradio"),
1064
({"x-forwarded-proto": "https"}, "/gradio/", "https://gradio.app/gradio"),
1065
({"x-forwarded-host": "gradio.dev"}, "/gradio/", "http://gradio.dev/gradio"),
1067
{"x-forwarded-host": "gradio.dev", "x-forwarded-proto": "https"},
1069
"https://gradio.dev",
1072
{"x-forwarded-host": "gradio.dev", "x-forwarded-proto": "https"},
1073
"http://google.com",
1074
"http://google.com",
1078
def test_get_root_url_headers(
1079
headers: Dict[str, str], root_path: str, expected_root_url: str
1083
"headers": [(k.encode(), v.encode()) for k, v in headers.items()],
1084
"path": "http://gradio.app",
1086
request = Request(scope)
1087
assert get_root_url(request, "/", root_path) == expected_root_url
1090
class TestSimpleAPIRoutes:
1092
with Blocks() as demo:
1098
return f"Hello, {x}!"
1101
for i in range(len(x)):
1103
yield f"Hello, {x[:i+1]}!"
1105
raise ValueError("Small input")
1110
btn1, btn2, btn3 = Button(), Button(), Button()
1111
btn1.click(fn_1, input, output, api_name="fn1")
1112
btn2.click(fn_2, input, output2, api_name="fn2")
1113
btn3.click(fn_3, None, [output, output2], api_name="fn3")
1116
def test_successful_simple_route(self):
1117
demo = self.get_demo()
1118
demo.launch(prevent_thread_lock=True)
1120
response = requests.post(f"{demo.local_url}call/fn1", json={"data": ["world"]})
1122
assert response.status_code == 200, "Failed to call fn1"
1123
response = response.json()
1124
event_id = response["event_id"]
1127
response = requests.get(f"{demo.local_url}call/fn1/{event_id}", stream=True)
1129
for line in response.iter_lines():
1131
output.append(line.decode("utf-8"))
1133
assert output == ["event: complete", 'data: ["Hello, world!"]']
1135
response = requests.post(f"{demo.local_url}call/fn3", json={"data": []})
1137
assert response.status_code == 200, "Failed to call fn3"
1138
response = response.json()
1139
event_id = response["event_id"]
1142
response = requests.get(f"{demo.local_url}call/fn3/{event_id}", stream=True)
1144
for line in response.iter_lines():
1146
output.append(line.decode("utf-8"))
1148
assert output == ["event: complete", 'data: ["a", "b"]']
1150
def test_generative_simple_route(self):
1151
demo = self.get_demo()
1152
demo.launch(prevent_thread_lock=True)
1154
response = requests.post(f"{demo.local_url}call/fn2", json={"data": ["world"]})
1156
assert response.status_code == 200, "Failed to call fn2"
1157
response = response.json()
1158
event_id = response["event_id"]
1161
response = requests.get(f"{demo.local_url}call/fn2/{event_id}", stream=True)
1163
for line in response.iter_lines():
1165
output.append(line.decode("utf-8"))
1168
"event: generating",
1169
'data: ["Hello, w!"]',
1170
"event: generating",
1171
'data: ["Hello, wo!"]',
1172
"event: generating",
1173
'data: ["Hello, wor!"]',
1174
"event: generating",
1175
'data: ["Hello, worl!"]',
1176
"event: generating",
1177
'data: ["Hello, world!"]',
1179
'data: ["Hello, world!"]',
1182
response = requests.post(f"{demo.local_url}call/fn2", json={"data": ["w"]})
1184
assert response.status_code == 200, "Failed to call fn2"
1185
response = response.json()
1186
event_id = response["event_id"]
1189
response = requests.get(f"{demo.local_url}call/fn2/{event_id}", stream=True)
1191
for line in response.iter_lines():
1193
output.append(line.decode("utf-8"))
1196
"event: generating",
1197
'data: ["Hello, w!"]',
1203
def test_compare_passwords_securely():
1204
password1 = "password"
1205
password2 = "pässword"
1206
assert compare_passwords_securely(password1, password1)
1207
assert not compare_passwords_securely(password1, password2)
1208
assert compare_passwords_securely(password2, password2)
1211
@pytest.mark.parametrize(
1214
("http://localhost:7860/", True),
1215
("https://localhost:7860/", True),
1216
("ftp://localhost:7860/", True),
1217
("smb://example.com", True),
1218
("ipfs://QmTzQ1Nj5R9BzF1djVQv8gvzZxVkJb1vhrLcXL1QyJzZE", True),
1219
("usr/local/bin", False),
1220
("localhost:7860", False),
1221
("localhost", False),
1222
("C:/Users/username", False),
1225
("/usr/bin//test", False),
1226
("/\\10.0.225.200/share", True),
1227
("\\/10.0.225.200/share", True),
1228
("/home//user", False),
1229
("C:\\folder\\file", False),
1232
def test_starts_with_protocol(string, expected):
1233
assert starts_with_protocol(string) == expected