# transformers — 339 lines · 12.5 KB
1# coding=utf-8
2# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15""" Testing suite for the PyTorch SegGpt model. """
16
17
18import inspect
19import unittest
20
21from datasets import load_dataset
22
23from transformers import SegGptConfig
24from transformers.testing_utils import (
25require_torch,
26require_vision,
27slow,
28torch_device,
29)
30from transformers.utils import cached_property, is_torch_available, is_vision_available
31
32from ...test_configuration_common import ConfigTester
33from ...test_modeling_common import ModelTesterMixin, floats_tensor
34from ...test_pipeline_mixin import PipelineTesterMixin
35
36
37if is_torch_available():
38import torch
39from torch import nn
40
41from transformers import SegGptForImageSegmentation, SegGptModel
42from transformers.models.seggpt.modeling_seggpt import SEGGPT_PRETRAINED_MODEL_ARCHIVE_LIST
43
44
45if is_vision_available():
46from transformers import SegGptImageProcessor
47
48
class SegGptModelTester:
    """Builds a small SegGptConfig and matching random inputs for the shared model tests.

    All pixel inputs are generated with shape
    ``(batch_size, num_channels, image_size // 2, image_size)``.
    """

    def __init__(
        self,
        parent,
        batch_size=2,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=False,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        initializer_range=0.02,
        mlp_ratio=2.0,
        merge_index=0,
        intermediate_hidden_state_indices=(1,),
        pretrain_image_size=10,
        decoder_hidden_size=10,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.mlp_ratio = mlp_ratio
        self.merge_index = merge_index
        # Copy into a fresh list so instances never share (or mutate) the default
        # argument — the previous default was a mutable list shared across instances.
        self.intermediate_hidden_state_indices = list(intermediate_hidden_state_indices)
        self.pretrain_image_size = pretrain_image_size
        self.decoder_hidden_size = decoder_hidden_size

        # in SegGpt, the seq length equals the number of patches (we don't use the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches

    def prepare_config_and_inputs(self):
        """Return a config plus random pixel/prompt/mask (and optional label) tensors."""
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size // 2, self.image_size])
        prompt_pixel_values = floats_tensor(
            [self.batch_size, self.num_channels, self.image_size // 2, self.image_size]
        )
        prompt_masks = floats_tensor([self.batch_size, self.num_channels, self.image_size // 2, self.image_size])

        labels = None
        if self.use_labels:
            labels = floats_tensor([self.batch_size, self.num_channels, self.image_size // 2, self.image_size])

        config = self.get_config()

        return config, pixel_values, prompt_pixel_values, prompt_masks, labels

    def get_config(self):
        """Build a SegGptConfig from the tester's hyperparameters."""
        return SegGptConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            initializer_range=self.initializer_range,
            mlp_ratio=self.mlp_ratio,
            merge_index=self.merge_index,
            intermediate_hidden_state_indices=self.intermediate_hidden_state_indices,
            pretrain_image_size=self.pretrain_image_size,
            decoder_hidden_size=self.decoder_hidden_size,
        )

    def create_and_check_model(self, config, pixel_values, prompt_pixel_values, prompt_masks, labels):
        """Run a forward pass through SegGptModel and verify the last hidden state's shape."""
        model = SegGptModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values, prompt_pixel_values, prompt_masks)
        # Hidden states keep a spatial (batch, height_patches, width_patches, channels) layout.
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (
                self.batch_size,
                self.image_size // self.patch_size,
                self.image_size // self.patch_size,
                self.hidden_size,
            ),
        )

    def prepare_config_and_inputs_for_common(self):
        """Adapt prepare_config_and_inputs() to the (config, inputs_dict) pair the mixins expect."""
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            pixel_values,
            prompt_pixel_values,
            prompt_masks,
            labels,
        ) = config_and_inputs
        inputs_dict = {
            "pixel_values": pixel_values,
            "prompt_pixel_values": prompt_pixel_values,
            "prompt_masks": prompt_masks,
        }
        return config, inputs_dict
159
160
@require_torch
class SegGptModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as SegGpt does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (SegGptModel, SegGptForImageSegmentation) if is_torch_available() else ()
    fx_compatible = False

    # Common-test feature switches: these features are not exercised for SegGpt.
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torchscript = False
    pipeline_model_mapping = (
        {"feature-extraction": SegGptModel, "mask-generation": SegGptModel} if is_torch_available() else {}
    )

    def setUp(self):
        # Helper that builds tiny configs/inputs, plus the generic config-test driver.
        self.model_tester = SegGptModelTester(self)
        self.config_tester = ConfigTester(self, config_class=SegGptConfig, has_text_modality=False)

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="SegGpt does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        # Every model class must expose an input-embedding module via get_input_embeddings().
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))

    def test_forward_signature(self):
        # SegGpt's forward takes image/prompt tensors instead of the usual input_ids.
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values", "prompt_pixel_values", "prompt_masks"]
            self.assertListEqual(arg_names[:3], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            # Run one forward pass and verify the number and shape of hidden states.
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            # Default expectation: embeddings output + one entry per transformer layer.
            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            # SegGpt hidden states keep a spatial (patch_height, patch_width, channels) layout.
            patch_height = patch_width = config.image_size // config.patch_size

            self.assertListEqual(
                list(hidden_states[0].shape[-3:]),
                [patch_height, patch_width, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            # First: request hidden states through the forward kwarg.
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    @slow
    def test_model_from_pretrained(self):
        # Only the first released checkpoint is loaded to keep the slow test cheap.
        for model_name in SEGGPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = SegGptModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
253
254
def prepare_img():
    """Load the SegGpt example dataset and return its images and masks as RGB PIL images."""
    train_split = load_dataset("EduardoPacheco/seggpt-example-data")["train"]
    rgb_images = []
    for img in train_split["image"]:
        rgb_images.append(img.convert("RGB"))
    rgb_masks = []
    for mask_img in train_split["mask"]:
        rgb_masks.append(mask_img.convert("RGB"))
    return rgb_images, rgb_masks
260
261
@require_torch
@require_vision
class SegGptModelIntegrationTest(unittest.TestCase):
    """Slow end-to-end inference checks against the released BAAI/seggpt-vit-large checkpoint."""

    @cached_property
    def default_image_processor(self):
        # Returns None when vision dependencies are missing; the tests below are
        # additionally gated by @require_vision on the class.
        return SegGptImageProcessor.from_pretrained("BAAI/seggpt-vit-large") if is_vision_available() else None

    @slow
    def test_one_shot_inference(self):
        # One prompt image/mask pair guides segmentation of a single input image.
        model = SegGptForImageSegmentation.from_pretrained("BAAI/seggpt-vit-large").to(torch_device)

        image_processor = self.default_image_processor

        images, masks = prepare_img()
        input_image = images[1]
        prompt_image = images[0]
        prompt_mask = masks[0]

        inputs = image_processor(
            images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt"
        )

        inputs = inputs.to(torch_device)
        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        expected_shape = torch.Size((1, 3, 896, 448))
        self.assertEqual(outputs.pred_masks.shape, expected_shape)

        # Reference values captured from a known-good run of this checkpoint.
        expected_slice = torch.tensor(
            [
                [[-2.1208, -2.1190, -2.1198], [-2.1237, -2.1228, -2.1227], [-2.1232, -2.1226, -2.1228]],
                [[-2.0405, -2.0396, -2.0403], [-2.0434, -2.0434, -2.0433], [-2.0428, -2.0432, -2.0434]],
                [[-1.8102, -1.8088, -1.8099], [-1.8131, -1.8126, -1.8129], [-1.8130, -1.8128, -1.8131]],
            ]
        ).to(torch_device)

        self.assertTrue(torch.allclose(outputs.pred_masks[0, :, :3, :3], expected_slice, atol=1e-4))

        # Post-process back to the input image's resolution; PIL's .size is (width,
        # height), hence the [::-1] to get (height, width).
        result = image_processor.post_process_semantic_segmentation(outputs, [input_image.size[::-1]])[0]

        result_expected_shape = torch.Size((170, 297))
        # Expected count of pixels with a positive (foreground) label in the result.
        expected_area = 1082
        area = (result > 0).sum().item()
        self.assertEqual(result.shape, result_expected_shape)
        self.assertEqual(area, expected_area)

    @slow
    def test_few_shot_inference(self):
        # Two prompt image/mask pairs for the same input image, combined via feature_ensemble.
        model = SegGptForImageSegmentation.from_pretrained("BAAI/seggpt-vit-large").to(torch_device)
        image_processor = self.default_image_processor

        images, masks = prepare_img()
        # The same input image duplicated, each copy paired with a different prompt example.
        input_images = [images[1]] * 2
        prompt_images = [images[0], images[2]]
        prompt_masks = [masks[0], masks[2]]

        inputs = image_processor(
            images=input_images, prompt_images=prompt_images, prompt_masks=prompt_masks, return_tensors="pt"
        )

        inputs = {k: v.to(torch_device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs, feature_ensemble=True)

        expected_shape = torch.Size((2, 3, 896, 448))
        # Reference values captured from a known-good run of this checkpoint.
        expected_slice = torch.tensor(
            [
                [[-2.1201, -2.1192, -2.1189], [-2.1217, -2.1210, -2.1204], [-2.1216, -2.1202, -2.1194]],
                [[-2.0393, -2.0390, -2.0387], [-2.0402, -2.0402, -2.0397], [-2.0400, -2.0394, -2.0388]],
                [[-1.8083, -1.8076, -1.8077], [-1.8105, -1.8102, -1.8099], [-1.8105, -1.8095, -1.8090]],
            ]
        ).to(torch_device)

        self.assertEqual(outputs.pred_masks.shape, expected_shape)
        self.assertTrue(torch.allclose(outputs.pred_masks[0, :, 448:451, :3], expected_slice, atol=4e-4))
340