# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch SAM model. """


import gc
import unittest

import requests

from transformers import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig, pipeline
from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import SamModel, SamProcessor
    from transformers.models.sam.modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from PIL import Image

class SamPromptEncoderTester:
    def __init__(
        self,
        hidden_size=32,
        input_image_size=24,
        patch_size=2,
        mask_input_channels=4,
        num_point_embeddings=4,
        hidden_act="gelu",
        batch_size=2,  # prepare_config_and_inputs below relies on self.batch_size
    ):
        self.hidden_size = hidden_size
        self.input_image_size = input_image_size
        self.patch_size = patch_size
        self.mask_input_channels = mask_input_channels
        self.num_point_embeddings = num_point_embeddings
        self.hidden_act = hidden_act
        self.batch_size = batch_size

    def get_config(self):
        return SamPromptEncoderConfig(
            image_size=self.input_image_size,
            patch_size=self.patch_size,
            mask_input_channels=self.mask_input_channels,
            hidden_size=self.hidden_size,
            num_point_embeddings=self.num_point_embeddings,
            hidden_act=self.hidden_act,
        )

    def prepare_config_and_inputs(self):
        dummy_points = floats_tensor([self.batch_size, 3, 2])
        config = self.get_config()

        return config, dummy_points


class SamMaskDecoderTester:
    def __init__(
        self,
        hidden_size=32,
        hidden_act="relu",
        mlp_dim=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        attention_downsample_rate=2,
        num_multimask_outputs=3,
        iou_head_depth=3,
        iou_head_hidden_dim=32,
        layer_norm_eps=1e-6,
        batch_size=2,  # prepare_config_and_inputs below relies on self.batch_size
    ):
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.mlp_dim = mlp_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_downsample_rate = attention_downsample_rate
        self.num_multimask_outputs = num_multimask_outputs
        self.iou_head_depth = iou_head_depth
        self.iou_head_hidden_dim = iou_head_hidden_dim
        self.layer_norm_eps = layer_norm_eps
        self.batch_size = batch_size

    def get_config(self):
        return SamMaskDecoderConfig(
            hidden_size=self.hidden_size,
            hidden_act=self.hidden_act,
            mlp_dim=self.mlp_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            attention_downsample_rate=self.attention_downsample_rate,
            num_multimask_outputs=self.num_multimask_outputs,
            iou_head_depth=self.iou_head_depth,
            iou_head_hidden_dim=self.iou_head_hidden_dim,
            layer_norm_eps=self.layer_norm_eps,
        )

    def prepare_config_and_inputs(self):
        config = self.get_config()

        dummy_inputs = {
            "image_embedding": floats_tensor([self.batch_size, self.hidden_size]),
        }

        return config, dummy_inputs


class SamModelTester:
    def __init__(
        self,
        parent,
        hidden_size=36,
        intermediate_size=72,
        projection_dim=62,
        output_channels=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_channels=3,
        image_size=24,
        patch_size=2,
        hidden_act="gelu",
        layer_norm_eps=1e-06,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        qkv_bias=True,
        mlp_ratio=4.0,
        use_abs_pos=True,
        use_rel_pos=True,
        rel_pos_zero_init=False,
        window_size=14,
        global_attn_indexes=[2, 5, 8, 11],
        num_pos_feats=16,
        mlp_dim=None,
        batch_size=2,
    ):
        self.parent = parent
        self.image_size = image_size
        self.patch_size = patch_size
        self.output_channels = output_channels
        self.num_channels = num_channels
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.qkv_bias = qkv_bias
        self.mlp_ratio = mlp_ratio
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        self.rel_pos_zero_init = rel_pos_zero_init
        self.window_size = window_size
        self.global_attn_indexes = global_attn_indexes
        self.num_pos_feats = num_pos_feats
        self.mlp_dim = mlp_dim
        self.batch_size = batch_size

        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1
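        # with the defaults above (image_size=24, patch_size=2) this gives 144 patches
        # and a sequence length of 145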

        self.prompt_encoder_tester = SamPromptEncoderTester()
        self.mask_decoder_tester = SamMaskDecoderTester()

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        vision_config = SamVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
            initializer_factor=self.initializer_factor,
            output_channels=self.output_channels,
            qkv_bias=self.qkv_bias,
            mlp_ratio=self.mlp_ratio,
            use_abs_pos=self.use_abs_pos,
            use_rel_pos=self.use_rel_pos,
            rel_pos_zero_init=self.rel_pos_zero_init,
            window_size=self.window_size,
            global_attn_indexes=self.global_attn_indexes,
            num_pos_feats=self.num_pos_feats,
            mlp_dim=self.mlp_dim,
        )

        prompt_encoder_config = self.prompt_encoder_tester.get_config()

        mask_decoder_config = self.mask_decoder_tester.get_config()

        return SamConfig(
            vision_config=vision_config,
            prompt_encoder_config=prompt_encoder_config,
            mask_decoder_config=mask_decoder_config,
        )

    def create_and_check_model(self, config, pixel_values):
        model = SamModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
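        # iou_scores carries one score per predicted mask: num_multimask_outputs defaults
        # to 3 in SamMaskDecoderTester above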
        self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3))
        self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3))

    def create_and_check_get_image_features(self, config, pixel_values):
        model = SamModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model.get_image_embeddings(pixel_values)
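        # the image embeddings form a 12 x 12 grid: image_size // patch_size = 24 // 2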
        self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12))

    def create_and_check_get_image_hidden_states(self, config, pixel_values):
        model = SamModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model.vision_encoder(
                pixel_values,
                output_hidden_states=True,
                return_dict=True,
            )

        # after computing the convolutional features
        expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
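        # i.e. (batch_size, image_size // patch_size, image_size // patch_size, hidden_size)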
        self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
        self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)

        with torch.no_grad():
            result = model.vision_encoder(
                pixel_values,
                output_hidden_states=True,
                return_dict=False,
            )

        # after computing the convolutional features
        expected_hidden_states_shape = (self.batch_size, 12, 12, 36)
        self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1)
        self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class SamModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use
    input_ids, inputs_embeds, attention_mask and seq_length.
    """

    all_model_classes = (SamModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": SamModel, "mask-generation": SamModel} if is_torch_available() else {}
    )
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torchscript = False

    # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working
    def is_pipeline_test_to_skip(
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
    ):
        return True

    def setUp(self):
        self.model_tester = SamModelTester(self)
        self.vision_config_tester = ConfigTester(self, config_class=SamVisionConfig, has_text_modality=False)
        self.prompt_encoder_config_tester = ConfigTester(
            self,
            config_class=SamPromptEncoderConfig,
            has_text_modality=False,
            num_attention_heads=12,
            num_hidden_layers=2,
        )
        self.mask_decoder_config_tester = ConfigTester(
            self, config_class=SamMaskDecoderConfig, has_text_modality=False
        )

    def test_config(self):
        self.vision_config_tester.run_common_tests()
        self.prompt_encoder_config_tester.run_common_tests()
        self.mask_decoder_config_tester.run_common_tests()

    @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_get_image_features(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_get_image_features(*config_and_inputs)

    def test_image_hidden_states(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        expected_vision_attention_shape = (
            self.model_tester.batch_size * self.model_tester.num_attention_heads,
            196,
            196,
        )
        expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32)
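        # 196 = window_size**2 (14 * 14): the number of tokens inside each windowed
        # self-attention block; 144 is presumably (image_size // patch_size)**2 = 12 * 12,
        # the number of image-embedding tokens attended to by the mask decoder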

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            vision_attentions = outputs.vision_attentions
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)

            mask_decoder_attentions = outputs.mask_decoder_attentions
            self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)

            # check that output_attentions also works using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            vision_attentions = outputs.vision_attentions
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers)

            mask_decoder_attentions = outputs.mask_decoder_attentions
            self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers)

            self.assertListEqual(
                list(vision_attentions[0].shape[-4:]),
                list(expected_vision_attention_shape),
            )

            self.assertListEqual(
                list(mask_decoder_attentions[0].shape[-4:]),
                list(expected_mask_decoder_attention_shape),
            )

    @unittest.skip(reason="SamModel does not support training")
    def test_training(self):
        pass

    @unittest.skip(reason="SamModel does not support training")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using gradient checkpointing, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using gradient checkpointing, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="SamModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @unittest.skip(reason="SamModel does not support training")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="Hidden states are tested in the create_and_check_model tests")
    def test_hidden_states_output(self):
        pass

    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
        # Use a slightly higher default tolerance to make the tests non-flaky
        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes)

    @slow
    def test_model_from_pretrained(self):
        for model_name in SAM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = SamModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


def prepare_image():
    img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    return raw_image


def prepare_dog_img():
    img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    return raw_image


@slow
class SamModelIntegrationTest(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        # clean up as much of the GPU memory occupied by PyTorch as possible
        gc.collect()
        backend_empty_cache(torch_device)

    def test_inference_mask_generation_no_point(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()
        inputs = processor(images=raw_image, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        masks = outputs.pred_masks[0, 0, 0, 0, :3]
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4))
        self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4))
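        # The assertion above compares raw mask logits. In real usage one would typically map
        # the masks back to the original image resolution, e.g. (a sketch relying on the
        # SamProcessor post-processing helper):
        #     masks = processor.post_process_masks(
        #         outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
        #     )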

    def test_inference_mask_generation_one_point_one_bb(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()
        input_boxes = [[[650, 900, 1000, 1250]]]
        input_points = [[[820, 1080]]]

        inputs = processor(
            images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        masks = outputs.pred_masks[0, 0, 0, 0, :3]
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4))
        self.assertTrue(
            torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4)
        )

    def test_inference_mask_generation_batched_points_batched_images(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()
        input_points = [
            [[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
            [[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]],
        ]

        inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="pt").to(
            torch_device
        )

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze().cpu()
        masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu()

        EXPECTED_SCORES = torch.tensor(
            [
                [
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                ],
                [
                    [0.3317, 0.7264, 0.7646],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                    [0.6765, 0.9379, 0.8803],
                ],
            ]
        )
        EXPECTED_MASKS = torch.tensor([-2.8550, -2.7988, -2.9625])
        self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3))
        self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3))

    def test_inference_mask_generation_one_point_one_bb_zero(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()
        input_boxes = [[[620, 900, 1000, 1255]]]
        input_points = [[[820, 1080]]]
        labels = [[0]]

        inputs = processor(
            images=raw_image,
            input_boxes=input_boxes,
            input_points=input_points,
            input_labels=labels,
            return_tensors="pt",
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()

        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7894), atol=1e-4))

    def test_inference_mask_generation_one_point(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        input_points = [[[400, 650]]]
        input_labels = [[1]]

        inputs = processor(
            images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))

        # With no label
        input_points = [[[400, 650]]]

        inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4))

    def test_inference_mask_generation_two_points(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        input_points = [[[400, 650], [800, 650]]]
        input_labels = [[1, 1]]

        inputs = processor(
            images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))

        # no labels
        inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()

        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4))

    def test_inference_mask_generation_two_points_batched(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        input_points = [[[400, 650], [800, 650]], [[400, 650]]]
        input_labels = [[1, 1], [1]]

        inputs = processor(
            images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="pt"
        ).to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9762), atol=1e-4))
        self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9637), atol=1e-4))

    def test_inference_mask_generation_one_box(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        input_boxes = [[[75, 275, 1725, 850]]]

        inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7937), atol=1e-4))

    def test_inference_mask_generation_batched_image_one_point(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()
        raw_dog_image = prepare_dog_img()

        input_points = [[[820, 1080]], [[220, 470]]]

        inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="pt").to(
            torch_device
        )

        with torch.no_grad():
            outputs = model(**inputs)
        scores_batched = outputs.iou_scores.squeeze()

        input_points = [[[220, 470]]]

        inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)
        scores_single = outputs.iou_scores.squeeze()
        self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4))

    def test_inference_mask_generation_two_points_point_batch(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        input_points = torch.Tensor([[[400, 650]], [[220, 470]]]).cpu()  # fmt: skip

        input_points = input_points.unsqueeze(0)
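        # shape is now (1, 2, 1, 2): one image, a point batch of two single-point prompts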

        inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)

        iou_scores = outputs.iou_scores.cpu()
        self.assertTrue(iou_scores.shape == (1, 2, 3))
        torch.testing.assert_close(
            iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4
        )

    def test_inference_mask_generation_three_boxes_point_batch(self):
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

        model.to(torch_device)
        model.eval()

        raw_image = prepare_image()

        # fmt: off
        input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu()
        EXPECTED_IOU = torch.tensor([[[0.9773, 0.9881, 0.9522],
                                      [0.5996, 0.7661, 0.7937],
                                      [0.5996, 0.7661, 0.7937]]])
        # fmt: on
        input_boxes = input_boxes.unsqueeze(0)

        inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**inputs)

        iou_scores = outputs.iou_scores.cpu()
        self.assertTrue(iou_scores.shape == (1, 3, 3))
        torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4)

    def test_dummy_pipeline_generation(self):
        generator = pipeline("mask-generation", model="facebook/sam-vit-base", device=torch_device)
        raw_image = prepare_image()

        _ = generator(raw_image, points_per_batch=64)
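        # The call above is only a smoke test: the "mask-generation" pipeline returns a dict
        # with "masks" and "scores", and points_per_batch controls how many prompt points are
        # run through the mask decoder at once (trading memory for speed).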