# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch SegGpt model."""


import inspect
import unittest

from datasets import load_dataset

from transformers import SegGptConfig
from transformers.testing_utils import (
    require_torch,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import SegGptForImageSegmentation, SegGptModel
    from transformers.models.seggpt.modeling_seggpt import SEGGPT_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from transformers import SegGptImageProcessor


class SegGptModelTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=False,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        initializer_range=0.02,
        mlp_ratio=2.0,
        merge_index=0,
        intermediate_hidden_state_indices=[1],
        pretrain_image_size=10,
        decoder_hidden_size=10,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.mlp_ratio = mlp_ratio
        self.merge_index = merge_index
        self.intermediate_hidden_state_indices = intermediate_hidden_state_indices
        self.pretrain_image_size = pretrain_image_size
        self.decoder_hidden_size = decoder_hidden_size

        # in SegGpt, the seq length equals the number of patches (we don't use the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches
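        # Added note (illustration, not in the original file): with the default
        # image_size=30 and patch_size=2, num_patches = (30 // 2) ** 2 = 225,
        # so self.seq_length is 225.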

    def prepare_config_and_inputs(self):
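        # Added note (hedged): the tensors below use height image_size // 2
        # because SegGpt concatenates the prompt and input images along the
        # height axis, so each image covers half of the model's input height.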
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size // 2, self.image_size])
        prompt_pixel_values = floats_tensor(
            [self.batch_size, self.num_channels, self.image_size // 2, self.image_size]
        )
        prompt_masks = floats_tensor([self.batch_size, self.num_channels, self.image_size // 2, self.image_size])

        labels = None
        if self.use_labels:
            labels = floats_tensor([self.batch_size, self.num_channels, self.image_size // 2, self.image_size])

        config = self.get_config()

        return config, pixel_values, prompt_pixel_values, prompt_masks, labels

    def get_config(self):
        return SegGptConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            initializer_range=self.initializer_range,
            mlp_ratio=self.mlp_ratio,
            merge_index=self.merge_index,
            intermediate_hidden_state_indices=self.intermediate_hidden_state_indices,
            pretrain_image_size=self.pretrain_image_size,
            decoder_hidden_size=self.decoder_hidden_size,
        )

    def create_and_check_model(self, config, pixel_values, prompt_pixel_values, prompt_masks, labels):
        model = SegGptModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values, prompt_pixel_values, prompt_masks)
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (
                self.batch_size,
                self.image_size // self.patch_size,
                self.image_size // self.patch_size,
                self.hidden_size,
            ),
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            pixel_values,
            prompt_pixel_values,
            prompt_masks,
            labels,
        ) = config_and_inputs
        inputs_dict = {
            "pixel_values": pixel_values,
            "prompt_pixel_values": prompt_pixel_values,
            "prompt_masks": prompt_masks,
        }
        return config, inputs_dict


@require_torch
class SegGptModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as SegGpt does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (SegGptModel, SegGptForImageSegmentation) if is_torch_available() else ()
    fx_compatible = False

    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torchscript = False
    pipeline_model_mapping = (
        {"feature-extraction": SegGptModel, "mask-generation": SegGptModel} if is_torch_available() else {}
    )

    def setUp(self):
        self.model_tester = SegGptModelTester(self)
        self.config_tester = ConfigTester(self, config_class=SegGptConfig, has_text_modality=False)

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="SegGpt does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values", "prompt_pixel_values", "prompt_masks"]
            self.assertListEqual(arg_names[:3], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            patch_height = patch_width = config.image_size // config.patch_size

            self.assertListEqual(
                list(hidden_states[0].shape[-3:]),
                [patch_height, patch_width, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also works via the config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    @slow
    def test_model_from_pretrained(self):
        for model_name in SEGGPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = SegGptModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


def prepare_img():
    ds = load_dataset("EduardoPacheco/seggpt-example-data")["train"]
    images = [image.convert("RGB") for image in ds["image"]]
    masks = [image.convert("RGB") for image in ds["mask"]]
    return images, masks
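
# Added note (commentary, not in the original file): prepare_img() returns
# parallel lists, so images[i] pairs with masks[i]; the integration tests below
# use images[1] as the query image and the other indices as prompt examples.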


@require_torch
@require_vision
class SegGptModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_image_processor(self):
        return SegGptImageProcessor.from_pretrained("BAAI/seggpt-vit-large") if is_vision_available() else None

    @slow
    def test_one_shot_inference(self):
        model = SegGptForImageSegmentation.from_pretrained("BAAI/seggpt-vit-large").to(torch_device)

        image_processor = self.default_image_processor

        images, masks = prepare_img()
        input_image = images[1]
        prompt_image = images[0]
        prompt_mask = masks[0]

        inputs = image_processor(
            images=input_image, prompt_images=prompt_image, prompt_masks=prompt_mask, return_tensors="pt"
        )

        inputs = inputs.to(torch_device)
        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        expected_shape = torch.Size((1, 3, 896, 448))
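        # Added note (hedged): the height is 896 = 2 * 448 because the model
        # operates on the prompt and input images stacked vertically, each
        # resized to 448 x 448 by the image processor.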
        self.assertEqual(outputs.pred_masks.shape, expected_shape)

        expected_slice = torch.tensor(
            [
                [[-2.1208, -2.1190, -2.1198], [-2.1237, -2.1228, -2.1227], [-2.1232, -2.1226, -2.1228]],
                [[-2.0405, -2.0396, -2.0403], [-2.0434, -2.0434, -2.0433], [-2.0428, -2.0432, -2.0434]],
                [[-1.8102, -1.8088, -1.8099], [-1.8131, -1.8126, -1.8129], [-1.8130, -1.8128, -1.8131]],
            ]
        ).to(torch_device)

        self.assertTrue(torch.allclose(outputs.pred_masks[0, :, :3, :3], expected_slice, atol=1e-4))

        result = image_processor.post_process_semantic_segmentation(outputs, [input_image.size[::-1]])[0]
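        # Added note: PIL's Image.size is (width, height), while target_sizes
        # expects (height, width), hence the [::-1] reversal above.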

        result_expected_shape = torch.Size((170, 297))
        expected_area = 1082
        area = (result > 0).sum().item()
        self.assertEqual(result.shape, result_expected_shape)
        self.assertEqual(area, expected_area)

    @slow
    def test_few_shot_inference(self):
        model = SegGptForImageSegmentation.from_pretrained("BAAI/seggpt-vit-large").to(torch_device)
        image_processor = self.default_image_processor

        images, masks = prepare_img()
        input_images = [images[1]] * 2
        prompt_images = [images[0], images[2]]
        prompt_masks = [masks[0], masks[2]]

        inputs = image_processor(
            images=input_images, prompt_images=prompt_images, prompt_masks=prompt_masks, return_tensors="pt"
        )

        inputs = {k: v.to(torch_device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs, feature_ensemble=True)
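        # Added note (hedged): feature_ensemble=True averages intermediate
        # features across the prompt examples, which is how SegGpt combines
        # multiple prompts for few-shot inference.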

        expected_shape = torch.Size((2, 3, 896, 448))
        expected_slice = torch.tensor(
            [
                [[-2.1201, -2.1192, -2.1189], [-2.1217, -2.1210, -2.1204], [-2.1216, -2.1202, -2.1194]],
                [[-2.0393, -2.0390, -2.0387], [-2.0402, -2.0402, -2.0397], [-2.0400, -2.0394, -2.0388]],
                [[-1.8083, -1.8076, -1.8077], [-1.8105, -1.8102, -1.8099], [-1.8105, -1.8095, -1.8090]],
            ]
        ).to(torch_device)

        self.assertEqual(outputs.pred_masks.shape, expected_shape)
        self.assertTrue(torch.allclose(outputs.pred_masks[0, :, 448:451, :3], expected_slice, atol=4e-4))
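        # Added note (hedged): rows 448:451 are the first rows of the lower
        # half of the 896-pixel canvas, i.e. the region corresponding to the
        # input image rather than the prompt.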