transformers
557 строк · 21.1 Кб
1# coding=utf-8
2# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15""" Testing suite for the PyTorch BEiT model. """
16
17
18import unittest
19
20from datasets import load_dataset
21from packaging import version
22
23from transformers import BeitConfig
24from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
25from transformers.utils import cached_property, is_torch_available, is_vision_available
26
27from ...test_backbone_common import BackboneTesterMixin
28from ...test_configuration_common import ConfigTester
29from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
30from ...test_pipeline_mixin import PipelineTesterMixin
31
32
33if is_torch_available():
34import torch
35from torch import nn
36
37from transformers import (
38BeitBackbone,
39BeitForImageClassification,
40BeitForMaskedImageModeling,
41BeitForSemanticSegmentation,
42BeitModel,
43)
44from transformers.models.auto.modeling_auto import MODEL_FOR_BACKBONE_MAPPING_NAMES, MODEL_MAPPING_NAMES
45from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST
46
47
48if is_vision_available():
49import PIL
50from PIL import Image
51
52from transformers import BeitImageProcessor
53
54
class BeitModelTester:
    """Build a tiny BEiT configuration plus random inputs, and run shape checks for each BEiT head.

    ``parent`` is the ``unittest.TestCase`` whose ``assert*`` methods are used for every check.
    """

    def __init__(
        self,
        parent,
        vocab_size=100,
        batch_size=13,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=4,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        type_sequence_label_size=10,
        initializer_range=0.02,
        num_labels=3,
        scope=None,
        out_indices=(1, 2, 3, 4),
        out_features=("stage1", "stage2", "stage3", "stage4"),
    ):
        self.parent = parent
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.scope = scope
        # The defaults are immutable tuples: a mutable default list would be one object shared by
        # every tester instance, so an in-place change would leak into all later instances.
        # Each instance keeps its own list copy (the original list-typed values); None passes through.
        self.out_indices = list(out_indices) if out_indices is not None else None
        self.out_features = list(out_features) if out_features is not None else None
        self.num_labels = num_labels

        # in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1

    def prepare_config_and_inputs(self):
        """Return ``(config, pixel_values, labels, pixel_labels)``.

        ``pixel_values`` has shape ``[batch_size, num_channels, image_size, image_size]``. ``labels`` (one class id
        per image) and ``pixel_labels`` (one class id per pixel) are ``None`` when ``use_labels`` is False.
        """
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

        labels = None
        pixel_labels = None
        if self.use_labels:
            labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            pixel_labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)

        config = self.get_config()

        return config, pixel_values, labels, pixel_labels

    def get_config(self):
        """Build a ``BeitConfig`` from the tester's hyper-parameters (encoder-only, ``is_decoder=False``)."""
        return BeitConfig(
            vocab_size=self.vocab_size,
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            is_decoder=False,
            initializer_range=self.initializer_range,
            out_indices=self.out_indices,
            out_features=self.out_features,
        )

    def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
        """Check that ``BeitModel`` returns one hidden vector per patch plus the [CLS] position."""
        model = BeitModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_backbone(self, config, pixel_values, labels, pixel_labels):
        """Check ``BeitBackbone`` feature maps/channels, first with the configured ``out_features``, then with None."""
        model = BeitBackbone(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)

        # verify feature maps: one per requested output feature, laid out as a patch grid
        self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
        expected_height = expected_width = self.image_size // config.patch_size
        self.parent.assertListEqual(
            list(result.feature_maps[0].shape), [self.batch_size, self.hidden_size, expected_height, expected_width]
        )

        # verify channels
        self.parent.assertEqual(len(model.channels), len(config.out_features))

        # verify backbone works with out_features=None
        config.out_features = None
        model = BeitBackbone(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)

        # verify feature maps: a single feature map is expected in this case
        self.parent.assertEqual(len(result.feature_maps), 1)
        self.parent.assertListEqual(
            list(result.feature_maps[0].shape), [self.batch_size, self.hidden_size, expected_height, expected_width]
        )

        # verify channels
        self.parent.assertEqual(len(model.channels), 1)

    def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
        """Check masked-image-modeling logits: one ``vocab_size`` prediction per patch ([CLS] excluded)."""
        model = BeitForMaskedImageModeling(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length - 1, self.vocab_size))

    def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
        """Check classification logits for RGB inputs and again for single-channel (greyscale) inputs.

        Note: mutates ``config`` (``num_labels`` and ``num_channels``).
        """
        config.num_labels = self.type_sequence_label_size
        model = BeitForImageClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values, labels=labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

        # test greyscale images
        config.num_channels = 1
        model = BeitForImageClassification(config)
        model.to(torch_device)
        model.eval()

        pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
        result = model(pixel_values, labels=labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

    def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels, pixel_labels):
        """Check segmentation logits shape, both without and with per-pixel labels.

        The asserted spatial size is ``2 * image_size`` for this tiny config.
        """
        config.num_labels = self.num_labels
        model = BeitForSemanticSegmentation(config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        self.parent.assertEqual(
            result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
        )
        result = model(pixel_values, labels=pixel_labels)
        self.parent.assertEqual(
            result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
        )

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, {"pixel_values": ...})`` for the shared ModelTesterMixin tests."""
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, labels, pixel_labels = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict
221
222
@require_torch
class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as BEiT does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (
        (
            BeitModel,
            BeitForImageClassification,
            BeitForMaskedImageModeling,
            BeitForSemanticSegmentation,
            BeitBackbone,
        )
        if is_torch_available()
        else ()
    )
    pipeline_model_mapping = (
        {
            "image-feature-extraction": BeitModel,
            "image-classification": BeitForImageClassification,
            "image-segmentation": BeitForSemanticSegmentation,
        }
        if is_torch_available()
        else {}
    )

    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = BeitModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BeitConfig, has_text_modality=False, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="BEiT does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @require_torch_multi_gpu
    @unittest.skip(reason="BEiT has some layers using `add_module` which doesn't work well with `nn.DataParallel`")
    def test_multi_gpu_data_parallel_forward(self):
        pass

    @unittest.skip(reason="BEiT does not support feedforward chunking yet")
    def test_feed_forward_chunking(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_backbone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_backbone(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_image_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

    def test_for_semantic_segmentation(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs)

    def _is_excluded_from_training(self, model_class):
        """Return True if ``model_class`` is not exercised by the training tests below.

        Excluded: every class name listed in the auto base-model and backbone mappings, plus
        ``BeitForMaskedImageModeling``. Shared by both training tests so the two lists cannot drift apart.
        """
        return model_class.__name__ in [
            *MODEL_MAPPING_NAMES.values(),
            *MODEL_FOR_BACKBONE_MAPPING_NAMES.values(),
            "BeitForMaskedImageModeling",
        ]

    def test_training(self):
        # Report a skip rather than a silent pass when training is disabled for this tester.
        if not self.model_tester.is_training:
            self.skipTest(reason="model_tester.is_training is set to False")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        for model_class in self.all_model_classes:
            if self._is_excluded_from_training(model_class):
                continue

            model = model_class(config)
            model.to(torch_device)
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            loss = model(**inputs).loss
            loss.backward()

    def test_training_gradient_checkpointing(self):
        # Check the training flag before building inputs (same order as test_training), and report a skip.
        if not self.model_tester.is_training:
            self.skipTest(reason="model_tester.is_training is set to False")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.use_cache = False
        config.return_dict = True

        for model_class in self.all_model_classes:
            if self._is_excluded_from_training(model_class) or not model_class.supports_gradient_checkpointing:
                continue

            model = model_class(config)
            model.gradient_checkpointing_enable()
            model.to(torch_device)
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            loss = model(**inputs).loss
            loss.backward()

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                # we skip lambda parameters as these require special initial values
                # determined by config.layer_scale_init_value
                if "lambda" in name:
                    continue
                if param.requires_grad:
                    # mean rounded to 9 decimals must be exactly 0.0 or 1.0 under the zero-init config
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    @slow
    def test_model_from_pretrained(self):
        for model_name in BEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = BeitModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
391
392
def prepare_img():
    """Open the COCO fixture image of cats that the integration tests verify their results against."""
    fixture_path = "./tests/fixtures/tests_samples/COCO/000000039769.png"
    return Image.open(fixture_path)
397
398
@require_torch
@require_vision
class BeitModelIntegrationTest(unittest.TestCase):
    """Slow tests that compare pretrained BEiT checkpoints against recorded reference logits."""

    @cached_property
    def default_image_processor(self):
        return BeitImageProcessor.from_pretrained("microsoft/beit-base-patch16-224") if is_vision_available() else None

    def _load_ade20k_segmentation_inputs(self):
        """Return ``(model, image_processor, inputs)`` for the ADE20k-finetuned 640x640 checkpoint.

        ``inputs`` are the processed tensors (on ``torch_device``) for the first image of the
        ``hf-internal-testing/fixtures_ade20k`` test split.
        """
        model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
        model = model.to(torch_device)

        image_processor = BeitImageProcessor(do_resize=True, size=640, do_center_crop=False)

        ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
        image = Image.open(ds[0]["file"])
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
        return model, image_processor, inputs

    @slow
    def test_inference_masked_image_modeling_head(self):
        model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k").to(torch_device)

        image_processor = self.default_image_processor
        image = prepare_img()
        pixel_values = image_processor(images=image, return_tensors="pt").pixel_values.to(torch_device)

        # prepare bool_masked_pos: mask every one of the 196 patch positions (matches expected_shape below)
        bool_masked_pos = torch.ones((1, 196), dtype=torch.bool, device=torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
        logits = outputs.logits

        # verify the logits
        expected_shape = torch.Size((1, 196, 8192))
        self.assertEqual(logits.shape, expected_shape)

        expected_slice = torch.tensor(
            [[-3.2437, 0.5072, -13.9174], [-3.2456, 0.4948, -13.9401], [-3.2033, 0.5121, -13.8550]],
            device=torch_device,
        )

        # rtol=1e-5 is torch.allclose's default, so the tolerance is unchanged; assert_close reports the mismatch.
        torch.testing.assert_close(logits[bool_masked_pos][:3, :3], expected_slice, rtol=1e-5, atol=1e-2)

    @slow
    def test_inference_image_classification_head_imagenet_1k(self):
        model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224").to(torch_device)

        image_processor = self.default_image_processor
        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits

        # verify the logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(logits.shape, expected_shape)

        expected_slice = torch.tensor([-1.2385, -1.0987, -1.0108], device=torch_device)

        torch.testing.assert_close(logits[0, :3], expected_slice, rtol=1e-5, atol=1e-4)

        expected_class_idx = 281
        self.assertEqual(logits.argmax(-1).item(), expected_class_idx)

    @slow
    def test_inference_image_classification_head_imagenet_22k(self):
        model = BeitForImageClassification.from_pretrained("microsoft/beit-large-patch16-224-pt22k-ft22k").to(
            torch_device
        )

        image_processor = self.default_image_processor
        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits

        # verify the logits
        expected_shape = torch.Size((1, 21841))
        self.assertEqual(logits.shape, expected_shape)

        expected_slice = torch.tensor([1.6881, -0.2787, 0.5901], device=torch_device)

        torch.testing.assert_close(logits[0, :3], expected_slice, rtol=1e-5, atol=1e-4)

        expected_class_idx = 2396
        self.assertEqual(logits.argmax(-1).item(), expected_class_idx)

    @slow
    def test_inference_semantic_segmentation(self):
        model, _, inputs = self._load_ade20k_segmentation_inputs()

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits

        # verify the logits
        expected_shape = torch.Size((1, 150, 160, 160))
        self.assertEqual(logits.shape, expected_shape)

        # Reference values differ depending on the installed Pillow major version.
        is_pillow_less_than_9 = version.parse(PIL.__version__) < version.parse("9.0.0")

        if is_pillow_less_than_9:
            expected_slice = torch.tensor(
                [
                    [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
                    [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
                    [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
                ],
                device=torch_device,
            )
        else:
            expected_slice = torch.tensor(
                [
                    [[-4.8960, -2.3688, -3.0355], [-2.8478, -0.9836, -1.7418], [-2.9449, -1.3332, -2.1456]],
                    [[-5.8081, -3.4124, -4.1006], [-3.8561, -2.2081, -3.0323], [-3.8365, -2.4601, -3.3669]],
                    [[-0.0309, 3.9868, 4.0540], [2.9640, 4.6877, 4.9976], [3.2081, 4.7690, 4.9942]],
                ],
                device=torch_device,
            )

        torch.testing.assert_close(logits[0, :3, :3, :3], expected_slice, rtol=1e-5, atol=1e-4)

    @slow
    def test_post_processing_semantic_segmentation(self):
        model, image_processor, inputs = self._load_ade20k_segmentation_inputs()

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        outputs.logits = outputs.logits.detach().cpu()

        # with explicit target sizes, the segmentation map is resized to (height, width) = (500, 300)
        segmentation = image_processor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
        expected_shape = torch.Size((500, 300))
        self.assertEqual(segmentation[0].shape, expected_shape)

        # without target sizes, the map keeps the logits' spatial size
        segmentation = image_processor.post_process_semantic_segmentation(outputs=outputs)
        expected_shape = torch.Size((160, 160))
        self.assertEqual(segmentation[0].shape, expected_shape)
549
550
@require_torch
class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin):
    """Run the shared backbone test-suite (BackboneTesterMixin) against ``BeitBackbone``."""

    config_class = BeitConfig
    all_model_classes = () if not is_torch_available() else (BeitBackbone,)

    def setUp(self):
        # Reuse the same tiny-config tester that drives BeitModelTest.
        self.model_tester = BeitModelTester(self)
558