3
from unittest.mock import MagicMock, patch
8
from gradio import flagging
10
# Turn Gradio analytics off for this whole test module. This runs at import
# time, so it is in effect before any Interface below is created.
os.environ.update(GRADIO_ANALYTICS_ENABLED="False")
13
# Tests for the flagging callback gr.Interface provides when only
# `flagging_dir` is passed (no explicit `flagging_callback`).
# NOTE(review): this region looks like a line-sampled, de-indented extract:
# the bare integers appear to be original line numbers, and some lines
# between them are missing. Confirm against the real file.
class TestDefaultFlagging:
# Launches an identity text->text Interface that writes into a temporary
# flagging dir, then flags the same ["test", "test"] sample twice through
# io.flagging_callback.
# NOTE(review): `row_count` is assigned twice, but no check on it is visible
# here. Presumably the assertions (and any cleanup) sit in the missing lines;
# verify.
14
def test_default_flagging_callback(self):
15
with tempfile.TemporaryDirectory() as tmpdirname:
16
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
17
io.launch(prevent_thread_lock=True)
18
row_count = io.flagging_callback.flag(["test", "test"])
20
row_count = io.flagging_callback.flag(["test", "test"])
24
def test_flagging_does_not_create_unnecessary_directories(self):
    """Flagging one sample should leave only ``log.csv`` in ``flagging_dir``.

    The launched app is closed in a ``finally`` block, so its server does not
    keep running after the test, even when the assertion fails.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        demo = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
        demo.launch(prevent_thread_lock=True)
        try:
            demo.flagging_callback.flag(["test", "test"])
            # Checked while the temp dir still exists; exactly one entry is expected.
            assert os.listdir(tmpdirname) == ["log.csv"]
        finally:
            demo.close()
32
# Tests flagging through an explicitly supplied flagging.SimpleCSVLogger.
class TestSimpleFlagging:
# Flags the same ["test", "test"] sample twice through a SimpleCSVLogger
# callback whose flagging dir is a temporary directory.
# NOTE(review): the constructor call that the `flagging_dir=` and
# `flagging_callback=` keywords below belong to is not visible in this
# extract; it is presumably `io = gr.Interface(...)`. Any check on
# `row_count` is also missing. Confirm against the full file.
33
def test_simple_csv_flagging_callback(self):
34
with tempfile.TemporaryDirectory() as tmpdirname:
39
flagging_dir=tmpdirname,
40
flagging_callback=flagging.SimpleCSVLogger(),
42
io.launch(prevent_thread_lock=True)
43
row_count = io.flagging_callback.flag(["test", "test"])
45
row_count = io.flagging_callback.flag(["test", "test"])
50
# Tests for flagging.HuggingFaceDatasetSaver with huggingface_hub functions
# (create_repo, hf_hub_download, metadata_update, upload_file, upload_folder)
# patched out via unittest.mock.patch.
class TestHuggingFaceDatasetSaver:
# The next two lines are the arguments of a patch on
# huggingface_hub.create_repo. The stub returns an object whose .repo_id is
# "gradio-tests/test".
# NOTE(review): the enclosing `@patch(` opening and closing lines are missing
# from this extract. The resulting mock is presumably the `mock_create`
# parameter. Stacked @patch decorators inject mocks bottom-up, so the lowest
# decorator (metadata_update) becomes the first mock argument.
52
"huggingface_hub.create_repo",
53
return_value=MagicMock(repo_id="gradio-tests/test"),
55
@patch("huggingface_hub.hf_hub_download")
56
@patch("huggingface_hub.metadata_update")
57
def test_saver_setup(self, metadata_update, mock_download, mock_create):
    """Run ``HuggingFaceDatasetSaver.setup`` against patched hub functions.

    After ``setup`` has been given two component classes and a scratch
    directory, the patched ``create_repo`` must have been called exactly once
    and ``hf_hub_download`` at least once.
    """
    saver = flagging.HuggingFaceDatasetSaver("test_token", "test")
    with tempfile.TemporaryDirectory() as scratch_dir:
        saver.setup([gr.Audio, gr.Textbox], scratch_dir)
    mock_create.assert_called_once()
    mock_download.assert_called()
# Arguments for a patch on huggingface_hub.create_repo. The stub returns an
# object whose .repo_id is "gradio-tests/test".
# NOTE(review): the enclosing `@patch(` opening and closing lines are not
# visible in this extract. Confirm against the full file.
65
"huggingface_hub.create_repo",
66
return_value=MagicMock(repo_id="gradio-tests/test"),
68
@patch("huggingface_hub.hf_hub_download")
69
@patch("huggingface_hub.upload_folder")
70
@patch("huggingface_hub.upload_file")
71
@patch("huggingface_hub.metadata_update")
# Flags the same sample twice through HuggingFaceDatasetSaver("test", "test")
# (no separate_dirs argument), then checks the name of every file left under
# the flagging dir. Mock parameters follow the bottom-up decorator order.
# NOTE(review): the following are missing from this extract: the signature's
# closing `):`, the Interface construction, any row_count checks, the inner
# loop that defines `f`, and the tail of the final assertion.
72
def test_saver_flag_same_dir(
73
self, metadata_update, mock_upload_file, mock_upload, mock_download, mock_create
75
with tempfile.TemporaryDirectory() as tmpdirname:
80
flagging_dir=tmpdirname,
81
flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
# The first flag() call passes an extra positional argument ("") that the
# second call omits.
83
row_count = io.flagging_callback.flag(["test", "test"], "")
85
row_count = io.flagging_callback.flag(["test", "test"])
# Each file name found must be data.csv or dataset_info.json, or end with the
# suffix given in the (truncated) endswith(...) call.
87
for _, _, filenames in os.walk(tmpdirname):
89
fname = os.path.basename(f)
90
assert fname in ["data.csv", "dataset_info.json"] or fname.endswith(
# Arguments for a patch on huggingface_hub.create_repo. The stub returns an
# object whose .repo_id is "gradio-tests/test".
# NOTE(review): the enclosing `@patch(` opening and closing lines are not
# visible in this extract. Confirm against the full file.
95
"huggingface_hub.create_repo",
96
return_value=MagicMock(repo_id="gradio-tests/test"),
98
@patch("huggingface_hub.hf_hub_download")
99
@patch("huggingface_hub.upload_folder")
100
@patch("huggingface_hub.upload_file")
101
@patch("huggingface_hub.metadata_update")
# Flags the same sample twice through a HuggingFaceDatasetSaver created with
# separate_dirs=True and asserts the returned row counts are 1, then 2. It
# then checks the names of the files left under the flagging dir. Mock
# parameters follow the bottom-up decorator order.
# NOTE(review): the following are missing from this extract: the signature's
# closing `):`, the Interface construction, the inner loop that defines `f`,
# and most of the allowed-name list that precedes
# `] or fname.endswith(".lock")`.
102
def test_saver_flag_separate_dirs(
103
self, metadata_update, mock_upload_file, mock_upload, mock_download, mock_create
105
with tempfile.TemporaryDirectory() as tmpdirname:
110
flagging_dir=tmpdirname,
111
flagging_callback=flagging.HuggingFaceDatasetSaver(
112
"test", "test", separate_dirs=True
115
row_count = io.flagging_callback.flag(["test", "test"], "")
116
assert row_count == 1
117
row_count = io.flagging_callback.flag(["test", "test"])
118
assert row_count == 2
# Each file name must be in an allowed list (mostly not visible here) or end
# with ".lock".
119
for _, _, filenames in os.walk(tmpdirname):
121
fname = os.path.basename(f)
125
] or fname.endswith(".lock")
128
# Tests the allow_flagging="never" setting against a non-writable flagging dir.
class TestDisableFlagging:
# Creates a fresh temp dir and sets its mode to read-only (0o444). It then
# points flagging_dir at a "flagging_dir" path inside it and launches an
# Interface with allow_flagging="never". Per the test name, the expectation
# is that no permission error is raised.
# NOTE(review): the `io = gr.Interface(` opening of the constructor call is
# not visible in this extract. The mkdtemp() directory is also never removed
# in the visible lines. Consider explicit cleanup if the missing lines don't
# handle it.
129
def test_flagging_no_permission_error_with_flagging_disabled(self):
130
tmpdirname = tempfile.mkdtemp()
131
os.chmod(tmpdirname, 0o444)
132
nonwritable_path = os.path.join(tmpdirname, "flagging_dir")
137
allow_flagging="never",
138
flagging_dir=nonwritable_path,
140
io.launch(prevent_thread_lock=True)
144
# Tests of the flagging-related setup gr.Interface performs at construction.
class TestInterfaceSetsUpFlagging:
# Parametrizes (allow_flagging, called) for test_flag_method_init_called.
# NOTE(review): the parameter cases and the closing bracket of this
# parametrize call are not visible in this extract.
145
@pytest.mark.parametrize(
146
"allow_flagging, called",
153
def test_flag_method_init_called(self, allow_flagging, called):
    """Check whether building an Interface instantiates ``flagging.FlagMethod``.

    ``FlagMethod.__init__`` is replaced by a mock returning ``None``, but only
    while ``gr.Interface(...)`` is being constructed. The real initializer is
    restored afterwards, so the stub cannot leak into other tests.

    Args:
        allow_flagging: value passed to ``gr.Interface(allow_flagging=...)``.
        called: whether ``FlagMethod.__init__`` is expected to have been called.
    """
    # patch.object restores the real __init__ on exit, even if construction raises.
    with patch.object(flagging.FlagMethod, "__init__", return_value=None) as mock_init:
        gr.Interface(lambda x: x, "text", "text", allow_flagging=allow_flagging)
    # The mock keeps its call record after the patch has been undone.
    assert mock_init.called == called
159
# Cases are (flagging_options passed to gr.Interface, expected
# io.flagging_options):
# - None yields a single ("Flag", "") pair.
# - Plain strings become ("Flag as <s>", <s>).
# - (label, value) tuples pass through unchanged.
# NOTE(review): the case list's opening `[` and the closing `],` / `)` lines
# are not visible in this extract.
@pytest.mark.parametrize(
160
"options, processed_options",
162
(None, [("Flag", "")]),
163
(["yes", "no"], [("Flag as yes", "yes"), ("Flag as no", "no")]),
164
([("abc", "de"), ("123", "45")], [("abc", "de"), ("123", "45")]),
167
def test_flagging_options_processed_correctly(self, options, processed_options):
    """Check that ``flagging_options`` is normalised into (label, value) pairs.

    Args:
        options: raw value given to ``gr.Interface(flagging_options=...)``.
        processed_options: the list of (label, value) tuples the Interface
            is expected to expose as ``flagging_options``.
    """
    interface = gr.Interface(
        lambda x: x, "text", "text", flagging_options=options
    )
    actual = interface.flagging_options
    assert actual == processed_options