gradio

Форк
0
/
test_pipelines.py 
267 строк · 11.2 Кб
1
import unittest
2
from unittest.mock import MagicMock
3

4
import pytest
5
import transformers
6
from diffusers import (
7
    StableDiffusionDepth2ImgPipeline,  # type: ignore
8
    StableDiffusionImageVariationPipeline,  # type: ignore
9
    StableDiffusionImg2ImgPipeline,  # type: ignore
10
    StableDiffusionInpaintPipeline,  # type: ignore
11
    StableDiffusionInstructPix2PixPipeline,  # type: ignore
12
    StableDiffusionPipeline,  # type: ignore
13
    StableDiffusionUpscalePipeline,  # type: ignore
14
)
15
from transformers import (
16
    AudioClassificationPipeline,
17
    AutomaticSpeechRecognitionPipeline,
18
    DocumentQuestionAnsweringPipeline,
19
    FeatureExtractionPipeline,
20
    FillMaskPipeline,
21
    ImageClassificationPipeline,
22
    ImageToTextPipeline,
23
    ObjectDetectionPipeline,
24
    QuestionAnsweringPipeline,
25
    SummarizationPipeline,
26
    Text2TextGenerationPipeline,
27
    TextClassificationPipeline,
28
    TextGenerationPipeline,
29
    TranslationPipeline,
30
    VisualQuestionAnsweringPipeline,
31
    ZeroShotClassificationPipeline,
32
)
33

34
import gradio as gr
35
from gradio.pipelines_utils import (
36
    handle_diffusers_pipeline,
37
    handle_transformers_pipeline,
38
)
39

40

41
@pytest.mark.flaky
def test_text_to_text_model_from_pipeline():
    """Wrap a tiny BART pipeline in an Interface and check it returns a string."""
    bart_pipeline = transformers.pipeline(model="sshleifer/bart-tiny-random")
    interface = gr.Interface.from_pipeline(bart_pipeline)
    result = interface("My name is Sylvain and I work at Hugging Face in Brooklyn")
    assert isinstance(result, str)
47

48

49
@pytest.mark.flaky
def test_stable_diffusion_pipeline():
    """Wrap a tiny Stable Diffusion pipeline in an Interface and check the output type."""
    sd_pipeline = StableDiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-sd-pipe"
    )
    interface = gr.Interface.from_pipeline(sd_pipeline)
    # The first two values line up with the "Prompt" / "Negative prompt" inputs
    # checked in TestHandleDiffusersPipelines; the two numbers are presumably the
    # step count and guidance scale -- TODO confirm against the handler.
    call_args = ("An astronaut", "low quality", 3, 7.5)
    result = interface(*call_args)
    assert isinstance(result, str)
55

56

57
@pytest.mark.flaky
def test_interface_in_blocks():
    """Nest two pipeline-backed Interfaces in tabs of one Blocks app, then launch and close it."""
    # Both pipelines are created up front, before the Blocks context is opened.
    text_pipelines = [
        transformers.pipeline(model="sshleifer/bart-tiny-random") for _ in range(2)
    ]
    with gr.Blocks() as demo:
        for text_pipeline in text_pipelines:
            with gr.Tab("Image Inference"):
                gr.Interface.from_pipeline(text_pipeline)
    # prevent_thread_lock=True is passed so launch() returns and close() can run.
    demo.launch(prevent_thread_lock=True)
    demo.close()
68

69

70
@pytest.mark.flaky
def test_transformers_load_from_pipeline():
    """Check the component labels of an Interface built from a question-answering model.

    The pipeline is created from the ``deepset/roberta-base-squad2`` checkpoint,
    so the test is marked flaky like the other tests in this module that load a
    model through ``transformers.pipeline``.
    """
    # Use the module-level ``transformers`` import instead of re-importing
    # ``pipeline`` locally.
    pipe = transformers.pipeline(model="deepset/roberta-base-squad2")
    io = gr.Interface.from_pipeline(pipe)

    input_labels = [component.label for component in io.input_components[:2]]
    output_labels = [component.label for component in io.output_components[:2]]
    assert input_labels == ["Context", "Question"]
    assert output_labels == ["Answer", "Score"]
79

80

81
class TestHandleTransformersPipelines(unittest.TestCase):
    """Label checks for ``handle_transformers_pipeline`` on spec'd mock pipelines."""

    @staticmethod
    def _components_for(pipeline_cls):
        """Run the handler on a mock of *pipeline_cls* and return (inputs, outputs)."""
        info = handle_transformers_pipeline(MagicMock(spec=pipeline_cls))
        assert info is not None
        return info["inputs"], info["outputs"]

    def test_audio_classification_pipeline(self):
        inputs, outputs = self._components_for(AudioClassificationPipeline)
        assert inputs.label == "Input"
        assert outputs.label == "Class"

    def test_automatic_speech_recognition_pipeline(self):
        inputs, outputs = self._components_for(AutomaticSpeechRecognitionPipeline)
        assert inputs.label == "Input"
        assert outputs.label == "Output"

    def test_object_detection_pipeline(self):
        inputs, outputs = self._components_for(ObjectDetectionPipeline)
        assert inputs.label == "Input Image"
        assert outputs.label == "Objects Detected"

    def test_feature_extraction_pipeline(self):
        inputs, outputs = self._components_for(FeatureExtractionPipeline)
        assert inputs.label == "Input"
        assert outputs.label == "Output"

    def test_fill_mask_pipeline(self):
        inputs, outputs = self._components_for(FillMaskPipeline)
        assert inputs.label == "Input"
        assert outputs.label == "Classification"

    def test_image_classification_pipeline(self):
        inputs, outputs = self._components_for(ImageClassificationPipeline)
        assert inputs.label == "Input Image"
        assert outputs.label == "Classification"

    def test_question_answering_pipeline(self):
        # Both sides are indexable here: two inputs and two outputs.
        inputs, outputs = self._components_for(QuestionAnsweringPipeline)
        assert inputs[0].label == "Context"
        assert inputs[1].label == "Question"
        assert outputs[0].label == "Answer"
        assert outputs[1].label == "Score"

    def test_summarization_pipeline(self):
        inputs, outputs = self._components_for(SummarizationPipeline)
        assert inputs.label == "Input"
        assert outputs.label == "Summary"

    def test_text_classification_pipeline(self):
        inputs, outputs = self._components_for(TextClassificationPipeline)
        assert inputs.label == "Input"
        assert outputs.label == "Classification"

    def test_text_generation_pipeline(self):
        inputs, outputs = self._components_for(TextGenerationPipeline)
        assert inputs.label == "Input"
        assert outputs.label == "Output"

    def test_translation_pipeline(self):
        inputs, outputs = self._components_for(TranslationPipeline)
        assert inputs.label == "Input"
        assert outputs.label == "Translation"

    def test_text2text_generation_pipeline(self):
        inputs, outputs = self._components_for(Text2TextGenerationPipeline)
        assert inputs.label == "Input"
        assert outputs.label == "Generated Text"

    def test_zero_shot_classification_pipeline(self):
        inputs, outputs = self._components_for(ZeroShotClassificationPipeline)
        assert inputs[0].label == "Input"
        assert inputs[1].label == "Possible class names (comma-separated)"
        assert inputs[2].label == "Allow multiple true classes"
        assert outputs.label == "Classification"

    def test_document_question_answering_pipeline(self):
        inputs, outputs = self._components_for(DocumentQuestionAnsweringPipeline)
        assert inputs[0].label == "Input Document"
        assert inputs[1].label == "Question"
        assert outputs.label == "Label"

    def test_visual_question_answering_pipeline(self):
        inputs, outputs = self._components_for(VisualQuestionAnsweringPipeline)
        assert inputs[0].label == "Input Image"
        assert inputs[1].label == "Question"
        assert outputs.label == "Score"

    def test_image_to_text_pipeline(self):
        inputs, outputs = self._components_for(ImageToTextPipeline)
        assert inputs.label == "Input Image"
        assert outputs.label == "Text"

    def test_unsupported_pipeline(self):
        # A MagicMock without a spec is not an instance of any pipeline class.
        unspecced_mock = MagicMock()
        with self.assertRaises(ValueError):
            handle_transformers_pipeline(unspecced_mock)
206

207

208
class TestHandleDiffusersPipelines(unittest.TestCase):
    """Label checks for ``handle_diffusers_pipeline`` on spec'd mock pipelines.

    Each test passes a ``MagicMock`` whose ``spec`` is one diffusers pipeline
    class and asserts the ``label`` of the returned ``inputs`` / ``outputs``
    entries.
    """

    def test_stable_diffusion_pipeline(self):
        """StableDiffusionPipeline: prompt and negative-prompt inputs, one output."""
        pipe = MagicMock(spec=StableDiffusionPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_img2img_pipeline(self):
        """StableDiffusionImg2ImgPipeline: first two inputs are the prompts."""
        pipe = MagicMock(spec=StableDiffusionImg2ImgPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_inpaint_pipeline(self):
        """StableDiffusionInpaintPipeline: first two inputs are the prompts."""
        pipe = MagicMock(spec=StableDiffusionInpaintPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_depth2img_pipeline(self):
        """StableDiffusionDepth2ImgPipeline: first two inputs are the prompts."""
        pipe = MagicMock(spec=StableDiffusionDepth2ImgPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_image_variation_pipeline(self):
        """StableDiffusionImageVariationPipeline: the first input is labelled "Image"."""
        pipe = MagicMock(spec=StableDiffusionImageVariationPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Image"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_instruct_pix2pix_pipeline(self):
        """StableDiffusionInstructPix2PixPipeline: first two inputs are the prompts."""
        pipe = MagicMock(spec=StableDiffusionInstructPix2PixPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_upscale_pipeline(self):
        """StableDiffusionUpscalePipeline: first two inputs are the prompts."""
        pipe = MagicMock(spec=StableDiffusionUpscalePipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_unsupported_pipeline(self):
        """A mock with no pipeline spec must be rejected with ``ValueError``."""
        pipe = MagicMock()
        # Call the diffusers handler: this class tests handle_diffusers_pipeline,
        # and the transformers handler's rejection path is already covered by
        # TestHandleTransformersPipelines.test_unsupported_pipeline.
        with self.assertRaises(ValueError):
            handle_diffusers_pipeline(pipe)
268

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.