2
from unittest.mock import MagicMock
7
StableDiffusionDepth2ImgPipeline, # type: ignore
8
StableDiffusionImageVariationPipeline, # type: ignore
9
StableDiffusionImg2ImgPipeline, # type: ignore
10
StableDiffusionInpaintPipeline, # type: ignore
11
StableDiffusionInstructPix2PixPipeline, # type: ignore
12
StableDiffusionPipeline, # type: ignore
13
StableDiffusionUpscalePipeline, # type: ignore
15
from transformers import (
16
AudioClassificationPipeline,
17
AutomaticSpeechRecognitionPipeline,
18
DocumentQuestionAnsweringPipeline,
19
FeatureExtractionPipeline,
21
ImageClassificationPipeline,
23
ObjectDetectionPipeline,
24
QuestionAnsweringPipeline,
25
SummarizationPipeline,
26
Text2TextGenerationPipeline,
27
TextClassificationPipeline,
28
TextGenerationPipeline,
30
VisualQuestionAnsweringPipeline,
31
ZeroShotClassificationPipeline,
35
from gradio.pipelines_utils import (
36
handle_diffusers_pipeline,
37
handle_transformers_pipeline,
42
def test_text_to_text_model_from_pipeline():
    """Wrap a tiny text2text pipeline in an Interface and check it yields text."""
    text2text = transformers.pipeline(model="sshleifer/bart-tiny-random")
    text_interface = gr.Interface.from_pipeline(text2text)

    prompt = "My name is Sylvain and I work at Hugging Face in Brooklyn"
    generated = text_interface(prompt)

    assert isinstance(generated, str)
50
def test_stable_diffusion_pipeline():
    """Wrap a tiny Stable Diffusion pipeline in an Interface and run it once."""
    sd_pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe")
    sd_interface = gr.Interface.from_pipeline(sd_pipe)

    # Positional inputs: prompt and negative prompt, followed by two numeric
    # settings (presumably step count and guidance scale — confirm against
    # gradio.pipelines_utils).
    call_args = ("An astronaut", "low quality", 3, 7.5)
    generated = sd_interface(*call_args)

    # The generated image comes back as a str (presumably a file path).
    assert isinstance(generated, str)
58
def test_interface_in_blocks():
    """Render two pipeline-backed Interfaces inside tabs of one Blocks app.

    The app is launched with ``prevent_thread_lock=True`` so the call returns
    instead of blocking the test. It is then closed explicitly, so the
    background server does not keep running (and holding its port) for the
    rest of the test session.
    """
    pipe1 = transformers.pipeline(model="sshleifer/bart-tiny-random")
    pipe2 = transformers.pipeline(model="sshleifer/bart-tiny-random")
    with gr.Blocks() as demo:
        with gr.Tab("Image Inference"):
            gr.Interface.from_pipeline(pipe1)
        with gr.Tab("Image Inference"):
            gr.Interface.from_pipeline(pipe2)
    demo.launch(prevent_thread_lock=True)
    # Shut down the non-blocking server started above.
    demo.close()
70
def test_transformers_load_from_pipeline():
    """Check the component labels assigned to an extractive-QA pipeline Interface."""
    from transformers import pipeline

    qa_pipe = pipeline(model="deepset/roberta-base-squad2")
    qa_interface = gr.Interface.from_pipeline(qa_pipe)

    inputs = qa_interface.input_components
    outputs = qa_interface.output_components
    # QA takes (context, question) and yields (answer, score).
    assert inputs[0].label == "Context"
    assert inputs[1].label == "Question"
    assert outputs[0].label == "Answer"
    assert outputs[1].label == "Score"
81
class TestHandleTransformersPipelines(unittest.TestCase):
    """Unit tests for ``handle_transformers_pipeline`` driven by spec'd mocks.

    ``MagicMock(spec=cls)`` makes ``isinstance(mock, cls)`` return True. This
    lets each pipeline type be exercised without instantiating a real
    pipeline.

    Each test checks that the handler returns a non-``None`` mapping whose
    ``"inputs"`` / ``"outputs"`` entries carry the expected labels. Each entry
    is either a single component or an indexable sequence of components.
    """

    def test_audio_classification_pipeline(self):
        pipe = MagicMock(spec=AudioClassificationPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input"
        assert pipeline_info["outputs"].label == "Class"

    def test_automatic_speech_recognition_pipeline(self):
        pipe = MagicMock(spec=AutomaticSpeechRecognitionPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input"
        assert pipeline_info["outputs"].label == "Output"

    def test_object_detection_pipeline(self):
        pipe = MagicMock(spec=ObjectDetectionPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input Image"
        assert pipeline_info["outputs"].label == "Objects Detected"

    def test_feature_extraction_pipeline(self):
        pipe = MagicMock(spec=FeatureExtractionPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input"
        assert pipeline_info["outputs"].label == "Output"

    def test_fill_mask_pipeline(self):
        pipe = MagicMock(spec=FillMaskPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input"
        assert pipeline_info["outputs"].label == "Classification"

    def test_image_classification_pipeline(self):
        pipe = MagicMock(spec=ImageClassificationPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input Image"
        assert pipeline_info["outputs"].label == "Classification"

    def test_question_answering_pipeline(self):
        pipe = MagicMock(spec=QuestionAnsweringPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Context"
        assert pipeline_info["inputs"][1].label == "Question"
        assert pipeline_info["outputs"][0].label == "Answer"
        assert pipeline_info["outputs"][1].label == "Score"

    def test_summarization_pipeline(self):
        pipe = MagicMock(spec=SummarizationPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input"
        assert pipeline_info["outputs"].label == "Summary"

    def test_text_classification_pipeline(self):
        pipe = MagicMock(spec=TextClassificationPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input"
        assert pipeline_info["outputs"].label == "Classification"

    def test_text_generation_pipeline(self):
        pipe = MagicMock(spec=TextGenerationPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input"
        assert pipeline_info["outputs"].label == "Output"

    def test_translation_pipeline(self):
        pipe = MagicMock(spec=TranslationPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input"
        assert pipeline_info["outputs"].label == "Translation"

    def test_text2text_generation_pipeline(self):
        pipe = MagicMock(spec=Text2TextGenerationPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input"
        assert pipeline_info["outputs"].label == "Generated Text"

    def test_zero_shot_classification_pipeline(self):
        pipe = MagicMock(spec=ZeroShotClassificationPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Input"
        # Must be an assert: a bare comparison statement is evaluated and
        # discarded, so the label would never actually be checked.
        assert (
            pipeline_info["inputs"][1].label
            == "Possible class names (comma-separated)"
        )
        assert pipeline_info["inputs"][2].label == "Allow multiple true classes"
        assert pipeline_info["outputs"].label == "Classification"

    def test_document_question_answering_pipeline(self):
        pipe = MagicMock(spec=DocumentQuestionAnsweringPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Input Document"
        assert pipeline_info["inputs"][1].label == "Question"
        assert pipeline_info["outputs"].label == "Label"

    def test_visual_question_answering_pipeline(self):
        pipe = MagicMock(spec=VisualQuestionAnsweringPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Input Image"
        assert pipeline_info["inputs"][1].label == "Question"
        assert pipeline_info["outputs"].label == "Score"

    def test_image_to_text_pipeline(self):
        pipe = MagicMock(spec=ImageToTextPipeline)
        pipeline_info = handle_transformers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"].label == "Input Image"
        assert pipeline_info["outputs"].label == "Text"

    def test_unsupported_pipeline(self):
        # A MagicMock without a spec is not an instance of any transformers
        # pipeline class, so the handler is expected to reject it.
        pipe = MagicMock()
        with self.assertRaises(ValueError):
            handle_transformers_pipeline(pipe)
208
class TestHandleDiffusersPipelines(unittest.TestCase):
    """Unit tests for ``handle_diffusers_pipeline`` driven by spec'd mocks.

    ``MagicMock(spec=cls)`` makes ``isinstance(mock, cls)`` return True. This
    lets each supported diffusers pipeline type be exercised without loading
    any model weights.

    Each test checks that the handler returns a non-``None`` mapping whose
    ``"inputs"`` / ``"outputs"`` components carry the expected labels.
    """

    def test_stable_diffusion_pipeline(self):
        pipe = MagicMock(spec=StableDiffusionPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_img2img_pipeline(self):
        pipe = MagicMock(spec=StableDiffusionImg2ImgPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_inpaint_pipeline(self):
        pipe = MagicMock(spec=StableDiffusionInpaintPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_depth2img_pipeline(self):
        pipe = MagicMock(spec=StableDiffusionDepth2ImgPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_image_variation_pipeline(self):
        pipe = MagicMock(spec=StableDiffusionImageVariationPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Image"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_instruct_pix2pix_pipeline(self):
        pipe = MagicMock(spec=StableDiffusionInstructPix2PixPipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_stable_diffusion_upscale_pipeline(self):
        pipe = MagicMock(spec=StableDiffusionUpscalePipeline)
        pipeline_info = handle_diffusers_pipeline(pipe)
        assert pipeline_info is not None
        assert pipeline_info["inputs"][0].label == "Prompt"
        assert pipeline_info["inputs"][1].label == "Negative prompt"
        assert pipeline_info["outputs"].label == "Generated Image"

    def test_unsupported_pipeline(self):
        # A MagicMock without a spec is not an instance of any diffusers
        # pipeline class. The function under test in this class is the
        # diffusers handler, so that is what must reject it.
        pipe = MagicMock()
        with self.assertRaises(ValueError):
            handle_diffusers_pipeline(pipe)