gradio

Форк
0
/
test_flagging.py 
169 строк · 6.5 Кб
1
import os
2
import tempfile
3
from unittest.mock import MagicMock, patch
4

5
import pytest
6

7
import gradio as gr
8
from gradio import flagging
9

10
# Set GRADIO_ANALYTICS_ENABLED to "False" for this test module
# (presumably keeps Gradio from sending usage analytics during tests).
os.environ.update({"GRADIO_ANALYTICS_ENABLED": "False"})
11

12

13
class TestDefaultFlagging:
    """Tests for the default flagging callback that ``gr.Interface`` sets up."""

    def test_default_flagging_callback(self):
        """Each ``flag`` call appends a row and returns the data-row count."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            try:
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # 2 rows written including header
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 2  # 3 rows written including header
            finally:
                # Close even when an assertion fails so the launched app is
                # not left running for subsequent tests.
                io.close()

    def test_flagging_does_not_create_unnecessary_directories(self):
        """Flagging plain text writes only ``log.csv`` into the flagging dir."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            try:
                io.flagging_callback.flag(["test", "test"])
                assert os.listdir(tmpdirname) == ["log.csv"]
            finally:
                # The app was launched above, so it must be shut down here.
                io.close()
30

31

32
class TestSimpleFlagging:
    """Tests for ``flagging.SimpleCSVLogger``."""

    def test_simple_csv_flagging_callback(self):
        """``flag`` returns 0 then 1 because SimpleCSVLogger writes no header."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.SimpleCSVLogger(),
            )
            io.launch(prevent_thread_lock=True)
            try:
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 0  # no header in SimpleCSVLogger
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # no header in SimpleCSVLogger
            finally:
                # Close even when an assertion fails so the launched app is
                # not left running for subsequent tests.
                io.close()
48

49

50
class TestHuggingFaceDatasetSaver:
    """Tests for ``flagging.HuggingFaceDatasetSaver`` with Hub calls mocked.

    ``@patch`` decorators inject their mocks bottom-up: the lowest decorator's
    mock is the first argument after ``self``. Several mocks are unused by the
    test bodies but must still be accepted as parameters.
    """

    @staticmethod
    def _assert_only_expected_files(root, allowed):
        """Assert every file under *root* is named in *allowed* or ends in ``.lock``."""
        for _, _, filenames in os.walk(root):
            # os.walk already yields bare file names; no basename() needed.
            for fname in filenames:
                assert fname in allowed or fname.endswith(".lock")

    @patch(
        "huggingface_hub.create_repo",
        return_value=MagicMock(repo_id="gradio-tests/test"),
    )
    @patch("huggingface_hub.hf_hub_download")
    @patch("huggingface_hub.metadata_update")
    def test_saver_setup(self, metadata_update, mock_download, mock_create):
        """``setup`` creates the repo exactly once and downloads from the Hub."""
        flagger = flagging.HuggingFaceDatasetSaver("test_token", "test")
        with tempfile.TemporaryDirectory() as tmpdirname:
            flagger.setup([gr.Audio, gr.Textbox], tmpdirname)
        mock_create.assert_called_once()
        mock_download.assert_called()

    @patch(
        "huggingface_hub.create_repo",
        return_value=MagicMock(repo_id="gradio-tests/test"),
    )
    @patch("huggingface_hub.hf_hub_download")
    @patch("huggingface_hub.upload_folder")
    @patch("huggingface_hub.upload_file")
    @patch("huggingface_hub.metadata_update")
    def test_saver_flag_same_dir(
        self, metadata_update, mock_upload_file, mock_upload, mock_download, mock_create
    ):
        """Default mode: only ``data.csv``/``dataset_info.json`` (and locks) appear."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
            )
            row_count = io.flagging_callback.flag(["test", "test"], "")
            assert row_count == 1  # 2 rows written including header
            row_count = io.flagging_callback.flag(["test", "test"])
            assert row_count == 2  # 3 rows written including header
            self._assert_only_expected_files(
                tmpdirname, ("data.csv", "dataset_info.json")
            )

    @patch(
        "huggingface_hub.create_repo",
        return_value=MagicMock(repo_id="gradio-tests/test"),
    )
    @patch("huggingface_hub.hf_hub_download")
    @patch("huggingface_hub.upload_folder")
    @patch("huggingface_hub.upload_file")
    @patch("huggingface_hub.metadata_update")
    def test_saver_flag_separate_dirs(
        self, metadata_update, mock_upload_file, mock_upload, mock_download, mock_create
    ):
        """``separate_dirs=True``: only ``metadata.jsonl``/``dataset_info.json`` (and locks) appear."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.HuggingFaceDatasetSaver(
                    "test", "test", separate_dirs=True
                ),
            )
            row_count = io.flagging_callback.flag(["test", "test"], "")
            assert row_count == 1  # 2 rows written including header
            row_count = io.flagging_callback.flag(["test", "test"])
            assert row_count == 2  # 3 rows written including header
            self._assert_only_expected_files(
                tmpdirname, ("metadata.jsonl", "dataset_info.json")
            )
126

127

128
class TestDisableFlagging:
    """Tests for interfaces created with flagging disabled."""

    def test_flagging_no_permission_error_with_flagging_disabled(self):
        """A non-writable ``flagging_dir`` causes no error when flagging is ``"never"``."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            os.chmod(tmpdirname, 0o444)  # Make directory read-only
            try:
                nonwritable_path = os.path.join(tmpdirname, "flagging_dir")
                io = gr.Interface(
                    lambda x: x,
                    "text",
                    "text",
                    allow_flagging="never",
                    flagging_dir=nonwritable_path,
                )
                io.launch(prevent_thread_lock=True)
                io.close()
            finally:
                # Restore owner permissions so the temporary directory can be
                # removed instead of being left behind read-only.
                os.chmod(tmpdirname, 0o700)
142

143

144
class TestInterfaceSetsUpFlagging:
    """Tests for how ``gr.Interface`` configures flagging from its arguments."""

    @pytest.mark.parametrize(
        "allow_flagging, called",
        [
            ("manual", True),
            ("auto", True),
            ("never", False),
        ],
    )
    def test_flag_method_init_called(self, allow_flagging, called):
        """``FlagMethod`` is constructed unless ``allow_flagging`` is ``"never"``."""
        # Patch __init__ only for the duration of this test. Assigning a mock
        # straight onto the class would leak into every later test that
        # constructs a FlagMethod. return_value=None because __init__ must
        # return None.
        with patch.object(
            flagging.FlagMethod, "__init__", return_value=None
        ) as mock_init:
            gr.Interface(lambda x: x, "text", "text", allow_flagging=allow_flagging)
        assert mock_init.called == called

    @pytest.mark.parametrize(
        "options, processed_options",
        [
            (None, [("Flag", "")]),
            (["yes", "no"], [("Flag as yes", "yes"), ("Flag as no", "no")]),
            ([("abc", "de"), ("123", "45")], [("abc", "de"), ("123", "45")]),
        ],
    )
    def test_flagging_options_processed_correctly(self, options, processed_options):
        """``flagging_options`` are normalized to ``(label, value)`` pairs."""
        io = gr.Interface(lambda x: x, "text", "text", flagging_options=options)
        assert io.flagging_options == processed_options
170

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.