# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch SAM model. """


import gc
import unittest

import requests

from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import SamModel, SamProcessor
    from transformers.models.sam.modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from PIL import Image


class SamPromptEncoderTester:
    def __init__(
        self,
        hidden_size=32,
        input_image_size=24,
        patch_size=2,
        mask_input_channels=4,
        num_point_embeddings=4,
        hidden_act="gelu",
    ):
        self.hidden_size = hidden_size
        self.input_image_size = input_image_size
        self.patch_size = patch_size
        self.mask_input_channels = mask_input_channels
        self.num_point_embeddings = num_point_embeddings
        self.hidden_act = hidden_act

    def get_config(self):
        return SamPromptEncoderConfig(
            image_size=self.input_image_size,
            patch_size=self.patch_size,
            mask_input_channels=self.mask_input_channels,
            hidden_size=self.hidden_size,
            num_point_embeddings=self.num_point_embeddings,
            hidden_act=self.hidden_act,
        )

    def prepare_config_and_inputs(self):
        dummy_points = floats_tensor([self.batch_size, 3, 2])
        config = self.get_config()

        return config, dummy_points


class SamMaskDecoderTester:
    def __init__(
        self,
        hidden_size=32,
        hidden_act="relu",
        mlp_dim=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        attention_downsample_rate=2,
        num_multimask_outputs=3,
        iou_head_depth=3,
        iou_head_hidden_dim=32,
        layer_norm_eps=1e-6,
    ):
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.mlp_dim = mlp_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_downsample_rate = attention_downsample_rate
        self.num_multimask_outputs = num_multimask_outputs
        self.iou_head_depth = iou_head_depth
        self.iou_head_hidden_dim = iou_head_hidden_dim
        self.layer_norm_eps = layer_norm_eps

    def get_config(self):
        return SamMaskDecoderConfig(
            hidden_size=self.hidden_size,
            hidden_act=self.hidden_act,
            mlp_dim=self.mlp_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            attention_downsample_rate=self.attention_downsample_rate,
            num_multimask_outputs=self.num_multimask_outputs,
            iou_head_depth=self.iou_head_depth,
            iou_head_hidden_dim=self.iou_head_hidden_dim,
            layer_norm_eps=self.layer_norm_eps,
        )

    def prepare_config_and_inputs(self):
        config = self.get_config()

        dummy_inputs = {
            "image_embedding": floats_tensor([self.batch_size, self.hidden_size]),
        }

        return config, dummy_inputs


class SamModelTester:
    def __init__(
        self,
        parent,
        hidden_size=36,
        intermediate_size=72,
        projection_dim=62,
        output_channels=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_channels=3,
        image_size=24,
        patch_size=2,
        hidden_act="gelu",
        layer_norm_eps=1e-06,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        qkv_bias=True,
        mlp_ratio=4.0,
        use_abs_pos=True,
        use_rel_pos=True,
        rel_pos_zero_init=False,
        window_size=14,
        global_attn_indexes=[2, 5, 8, 11],
        num_pos_feats=16,
        mlp_dim=None,
        batch_size=2,
    ):
        self.parent = parent
        self.image_size = image_size
        self.patch_size = patch_size
        self.output_channels = output_channels
        self.num_channels = num_channels
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.qkv_bias = qkv_bias
        self.mlp_ratio = mlp_ratio
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        self.rel_pos_zero_init = rel_pos_zero_init
        self.window_size = window_size
        self.global_attn_indexes = global_attn_indexes
        self.num_pos_feats = num_pos_feats
        self.mlp_dim = mlp_dim
        self.batch_size = batch_size

        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1
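        # (worked example with the defaults above: image_size=24 and patch_size=2 give
        # (24 // 2) ** 2 = 144 patches, so seq_length = 145)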

        self.prompt_encoder_tester = SamPromptEncoderTester()
        self.mask_decoder_tester = SamMaskDecoderTester()

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        vision_config = SamVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
            initializer_factor=self.initializer_factor,
            output_channels=self.output_channels,
            qkv_bias=self.qkv_bias,
            mlp_ratio=self.mlp_ratio,
            use_abs_pos=self.use_abs_pos,
            use_rel_pos=self.use_rel_pos,
            rel_pos_zero_init=self.rel_pos_zero_init,
            window_size=self.window_size,
            global_attn_indexes=self.global_attn_indexes,
            num_pos_feats=self.num_pos_feats,
            mlp_dim=self.mlp_dim,
        )

        prompt_encoder_config = self.prompt_encoder_tester.get_config()

        mask_decoder_config = self.mask_decoder_tester.get_config()

        return SamConfig(
            vision_config=vision_config,
            prompt_encoder_config=prompt_encoder_config,
            mask_decoder_config=mask_decoder_config,
        )

    def create_and_check_model(self, config, pixel_values):
        model = SamModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
        self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3))
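        # (the trailing 3 in both shapes matches num_multimask_outputs=3 in the mask decoder
        # tester above: SAM returns three candidate masks and IoU scores per prompt)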

    def create_and_check_get_image_features(self, config, pixel_values):
        model = SamModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model.get_image_embeddings(pixel_values)
        self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12))
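        # (12 x 12 is the patch grid: image_size=24 divided by patch_size=2 per side, with
        # output_channels=32 channels after the vision encoder's neck)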

    def create_and_check_get_image_hidden_states(self, config, pixel_values):
        model = SamModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model.vision_encoder(
                pixel_values,
                output_hidden_states=True,
                return_dict=True,
            )

        # after computing the convolutional features
        expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
        self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
        self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)

        with torch.no_grad():
            result = model.vision_encoder(
                pixel_values,
                output_hidden_states=True,
                return_dict=False,
            )

        # after computing the convolutional features
        expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
        self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
        self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict
283

284

285
@require_torch
286
class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
287
    """
288
    Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds,
289
    attention_mask and seq_length.
290
    """
291

292
    all_model_classes = (SamModel,) if is_torch_available() else ()
293
    pipeline_model_mapping = (
294
        {"feature-extraction": SamModel, "mask-generation": SamModel} if is_torch_available() else {}
295
    )
296
    fx_compatible = False
297
    test_pruning = False
298
    test_resize_embeddings = False
299
    test_head_masking = False
300
    test_torchscript = False
301

302
    # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
303
    def is_pipeline_test_to_skip(
304
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
305
    ):
306
        return True
307

308
    def setUp(self):
309
        self.model_tester = SamModelTester(self)
310
        self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
311
        self.prompt_encoder_config_tester = ConfigTester(
312
            self,
313
            config_class=SamPromptEncoderConfig,
314
            has_text_modality=False,
315
            num_attention_heads=12,
316
            num_hidden_layers=2,
317
        )
318
        self.mask_decoder_config_tester = ConfigTester(
319
            self, config_class=SamMaskDecoderConfig, has_text_modality=False
320
        )
321

322
    def test_config(self):
323
        self.vision_config_tester.run_common_tests()
324
        self.prompt_encoder_config_tester.run_common_tests()
325
        self.mask_decoder_config_tester.run_common_tests()
326

327
    @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
328
    def test_inputs_embeds(self):
329
        pass
330

331
    def test_model_common_attributes(self):
332
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
333

334
        for model_class in self.all_model_classes:
335
            model = model_class(config)
336
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
337
            x = model.get_output_embeddings()
338
            self.assertTrue(x is None or isinstance(x, nn.Linear))
339

340
    def test_model(self):
341
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
342
        self.model_tester.create_and_check_model(*config_and_inputs)
343

344
    def test_get_image_features(self):
345
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
346
        self.model_tester.create_and_check_get_image_features(*config_and_inputs)
347

348
    def test_image_hidden_states(self):
349
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
350
        self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs)
351

352
    def test_attention_outputs(self):
353
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
354
        config.return_dict = True
355

356
        expected_vision_attention_shape = (
357
            self.model_tester.batch_size * self.model_tester.num_attention_heads,
358
            196,
359
            196,
360
        )
361
        expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32)
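        # (196 = 14 * 14: with window_size=14 above, the windowed vision attention operates
        # over padded 14x14 token windows; 144 presumably corresponds to the 12 * 12
        # image-embedding positions the mask decoder cross-attends to)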

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            vision_attentions = outputs.vision_attentions
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)

            mask_decoder_attentions = outputs.mask_decoder_attentions
            self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)

            # check that output_attentions also works when set via the config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            vision_attentions = outputs.vision_attentions
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)

            mask_decoder_attentions = outputs.mask_decoder_attentions
            self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)

            self.assertListEqual(
                list(vision_attentions[0].shape[-4:]),
                list(expected_vision_attention_shape),
            )

            self.assertListEqual(
                list(mask_decoder_attentions[0].shape[-4:]),
                list(expected_mask_decoder_attention_shape),
            )

    @unittest.skip(reason="SamModel does not support training")
    def test_training(self):
        pass

    @unittest.skip(reason="SamModel does not support training")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @unittest.skip(reason="SamModel does not support training")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests")
    def test_hidden_states_output(self):
        pass

    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
        # Use a slightly higher default tol to make the tests non-flaky
        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes)

    @slow
    def test_model_from_pretrained(self):
        for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = SamModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


def prepare_image():
    img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    return raw_image


def prepare_dog_img():
    img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    return raw_image


@slow
class SamModelIntegrationTest(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        # clean up as much of the GPU memory occupied by PyTorch as possible
        gc.collect()
        backend_empty_cache(torch_device)

    def test_inference_mask_generation_no_point(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()
        inputs = processor(images=raw_image, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        masks = outputs.pred_masks[0, 0, 0, 0, :3]
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
        self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4))
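        # (pred_masks has shape (batch, point_batch, num_multimask_outputs, height, width);
        # the slice above only spot-checks the first three logits of the first predicted mask)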

    def test_inference_mask_generation_one_point_one_bb(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()
        input_boxes = [[[650, 900, 1000, 1250]]]
        input_points = [[[820, 1080]]]

        inputs = processor(
            images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        masks = outputs.pred_masks[0, 0, 0, 0, :3]
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
        self.assertTrue(
            torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4)
        )

    def test_inference_mask_generation_batched_points_batched_images(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()
        input_points = [
            [[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
            [[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
        ]

        inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="pt").to(
            torch_device
        )

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze().cpu()
        masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()

        EXPECTED_SCORES = torch.tensor(
            [
                [
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                ],
                [
                    [0.3317, 0.7264, 0.7646],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                ],
            ]
        )
        EXPECTED_MASKS = torch.tensor([-2.8550, -2.7988, -2.9625])
        self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
        self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))

    def test_inference_mask_generation_one_point_one_bb_zero(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()
        input_boxes = [[[620, 900, 1000, 1255]]]
        input_points = [[[820, 1080]]]
        labels = [[0]]

        inputs = processor(
            images=raw_image,
            input_boxes=input_boxes,
            input_points=input_points,
            input_labels=labels,
            return_tensors="pt",
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()

        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7894), atol=1e-4))

    def test_inference_mask_generation_one_point(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        input_points = [[[400, 650]]]
        input_labels = [[1]]

        inputs = processor(
            images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))

        # With no label
        input_points = [[[400, 650]]]

        inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))

    def test_inference_mask_generation_two_points(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        input_points = [[[400, 650], [800, 650]]]
        input_labels = [[1, 1]]

        inputs = processor(
            images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))

        # no labels
        inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()

        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))

    def test_inference_mask_generation_two_points_batched(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        input_points = [[[400, 650], [800, 650]], [[400, 650]]]
        input_labels = [[1, 1], [1]]

        inputs = processor(
            images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9762), atol=1e-4))
        self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9637), atol=1e-4))

    def test_inference_mask_generation_one_box(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        input_boxes = [[[75, 275, 1725, 850]]]

        inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7937), atol=1e-4))

    def test_inference_mask_generation_batched_image_one_point(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()
        raw_dog_image = prepare_dog_img()

        input_points = [[[820, 1080]], [[220, 470]]]

        inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="pt").to(
            torch_device
        )

        with torch.no_grad():
            outputs = model(**inputs)
        scores_batched = outputs.iou_scores.squeeze()

        input_points = [[[220, 470]]]

        inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores_single = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4))

    def test_inference_mask_generation_two_points_point_batch(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        input_points = torch.Tensor([[[400, 650]], [[220, 470]]]).cpu()  # fmt: skip

        input_points = input_points.unsqueeze(0)

        inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)

        iou_scores = outputs.iou_scores.cpu()
        self.assertTrue(iou_scores.shape == (1, 2, 3))
        torch.testing.assert_allclose(
            iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4
        )

    def test_inference_mask_generation_three_boxes_point_batch(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        # fmt: off
        input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]],  [[75, 275, 1725, 850]]]).cpu()
        EXPECTED_IOU = torch.tensor([[[0.9773, 0.9881, 0.9522],
         [0.5996, 0.7661, 0.7937],
         [0.5996, 0.7661, 0.7937]]])
        # fmt: on
        input_boxes = input_boxes.unsqueeze(0)

        inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)

        iou_scores = outputs.iou_scores.cpu()
        self.assertTrue(iou_scores.shape == (1, 3, 3))
        torch.testing.assert_allclose(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)

    def test_dummy_pipeline_generation(self):
        generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
        raw_image = prepare_image()

        _ = generator(raw_image, points_per_batch=64)
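        # (smoke test only: we just check that the mask-generation pipeline runs end to end
        # on a real image, without asserting anything about its output)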
