transformers
294 lines · 11.1 KB
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import inspect16import unittest17
18import numpy as np19
20from transformers import BeitConfig21from transformers.testing_utils import require_flax, require_vision, slow22from transformers.utils import cached_property, is_flax_available, is_vision_available23
24from ...test_configuration_common import ConfigTester25from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor26
27
28if is_flax_available():29import jax30
31from transformers import FlaxBeitForImageClassification, FlaxBeitForMaskedImageModeling, FlaxBeitModel32
33if is_vision_available():34from PIL import Image35
36from transformers import BeitImageProcessor37
38
class FlaxBeitModelTester(unittest.TestCase):
    """Build tiny BEiT configs and random inputs, and run shape checks on the Flax BEiT models.

    ``parent`` is the test case driving the checks; every assertion is delegated to it
    through ``self.parent.assert*``.

    NOTE(review): this helper subclasses ``unittest.TestCase`` but never calls
    ``super().__init__()`` and defines no ``test_*`` methods. The base class is left
    unchanged so that any existing isinstance checks or callers keep working.
    """

    def __init__(
        self,
        parent,
        vocab_size=100,
        batch_size=13,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        type_sequence_label_size=10,
        initializer_range=0.02,
        num_labels=3,
    ):
        self.parent = parent
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        # Previously accepted but silently discarded; keep it so callers can read it back.
        self.num_labels = num_labels

        # in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1

    def prepare_config_and_inputs(self):
        """Return ``(config, pixel_values, labels)``.

        ``pixel_values`` has shape (batch_size, num_channels, image_size, image_size).
        ``labels`` is None unless ``use_labels`` is set.
        """
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

        labels = None
        if self.use_labels:
            labels = ids_tensor([self.batch_size], self.type_sequence_label_size)

        config = BeitConfig(
            vocab_size=self.vocab_size,
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

        return config, pixel_values, labels

    def create_and_check_model(self, config, pixel_values, labels):
        """Check that the base model's last hidden state is (batch, seq_length, hidden_size)."""
        model = FlaxBeitModel(config=config)
        result = model(pixel_values)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_masked_lm(self, config, pixel_values, labels):
        """Check the masked-image-modeling logits shape.

        The check drops the [CLS] position and expects one vocab distribution per patch.
        """
        model = FlaxBeitForMaskedImageModeling(config=config)
        result = model(pixel_values)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length - 1, self.vocab_size))

    def create_and_check_for_image_classification(self, config, pixel_values, labels):
        """Check classification logits shape for RGB input, then for single-channel input.

        Mutates ``config`` in place (``num_labels`` and ``num_channels``).
        """
        config.num_labels = self.type_sequence_label_size
        model = FlaxBeitForImageClassification(config=config)
        result = model(pixel_values)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

        # test greyscale images: a 1-channel config must accept 1-channel input
        config.num_channels = 1
        model = FlaxBeitForImageClassification(config)

        pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
        result = model(pixel_values)
        # Previously the greyscale forward pass was run but its output never checked.
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)``; inputs_dict holds only ``pixel_values``."""
        config, pixel_values, _labels = self.prepare_config_and_inputs()
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict
141
@require_flax
class FlaxBeitModelTest(FlaxModelTesterMixin, unittest.TestCase):
    """Common Flax model tests, plus BEiT-specific overrides, for the three Flax BEiT heads."""

    all_model_classes = (
        (FlaxBeitModel, FlaxBeitForImageClassification, FlaxBeitForMaskedImageModeling)
        if is_flax_available()
        else ()
    )

    def setUp(self) -> None:
        self.model_tester = FlaxBeitModelTester(self)
        self.config_tester = ConfigTester(
            self,
            config_class=BeitConfig,
            has_text_modality=False,
            hidden_size=37,
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    # Overridden: BEiT's __call__ takes `pixel_values` first, unlike the text models.
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            # `parameters` is ordered, so iterating it yields names in positional order.
            param_names = list(inspect.signature(model.__call__).parameters)
            self.assertListEqual(param_names[:1], ["pixel_values"])

    # Overridden: BEiT is fed `pixel_values` instead of `input_ids`.
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                class_inputs = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def forward(pixel_values, **kwargs):
                    return model(pixel_values=pixel_values, **kwargs)

                with self.subTest("JIT Enabled"):
                    compiled_outputs = forward(**class_inputs).to_tuple()

                with self.subTest("JIT Disabled"), jax.disable_jit():
                    eager_outputs = forward(**class_inputs).to_tuple()

                # Compiled and eager runs must agree on output count and per-output shape.
                self.assertEqual(len(eager_outputs), len(compiled_outputs))
                for compiled, eager in zip(compiled_outputs, eager_outputs):
                    self.assertEqual(compiled.shape, eager.shape)

    def test_model(self):
        self.model_tester.create_and_check_model(*self.model_tester.prepare_config_and_inputs())

    def test_for_masked_lm(self):
        self.model_tester.create_and_check_for_masked_lm(*self.model_tester.prepare_config_and_inputs())

    def test_for_image_classification(self):
        self.model_tester.create_and_check_for_image_classification(*self.model_tester.prepare_config_and_inputs())

    @slow
    def test_model_from_pretrained(self):
        for model_class in self.all_model_classes:
            model = model_class.from_pretrained("microsoft/beit-base-patch16-224")
            outputs = model(np.ones((1, 3, 224, 224)))
            self.assertIsNotNone(outputs)
211
# We will verify our results on an image of cute cats
def prepare_img():
    """Open and return the COCO fixture image (cats) used by the integration tests.

    The path is relative, so this assumes the working directory is the repository root.
    """
    fixture_path = "./tests/fixtures/tests_samples/COCO/000000039769.png"
    return Image.open(fixture_path)
217
@require_vision
@require_flax
class FlaxBeitModelIntegrationTest(unittest.TestCase):
    """Slow end-to-end checks of pretrained Flax BEiT checkpoints on a fixed COCO image."""

    @cached_property
    def default_image_processor(self):
        return BeitImageProcessor.from_pretrained("microsoft/beit-base-patch16-224") if is_vision_available() else None

    @slow
    def test_inference_masked_image_modeling_head(self):
        model = FlaxBeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")

        image_processor = self.default_image_processor
        image = prepare_img()
        pixel_values = image_processor(images=image, return_tensors="np").pixel_values

        # prepare bool_masked_pos: mask all 196 positions
        bool_masked_pos = np.ones((1, 196), dtype=bool)

        # forward pass
        outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
        logits = outputs.logits

        # verify the logits
        expected_shape = (1, 196, 8192)
        self.assertEqual(logits.shape, expected_shape)

        expected_slice = np.array(
            [[-3.2437, 0.5072, -13.9174], [-3.2456, 0.4948, -13.9401], [-3.2033, 0.5121, -13.8550]]
        )

        # rtol=1e-5 is np.allclose's default, so the pass/fail criterion is unchanged;
        # assert_allclose additionally reports the mismatching elements on failure.
        np.testing.assert_allclose(logits[bool_masked_pos][:3, :3], expected_slice, rtol=1e-5, atol=1e-2)

    def _check_image_classification(self, checkpoint, expected_shape, expected_slice, expected_class_idx):
        """Run ``checkpoint`` on the fixture image and verify the logits.

        Checks the logits shape, the first three logits of the first example, and the
        argmax class.
        """
        model = FlaxBeitForImageClassification.from_pretrained(checkpoint)

        image_processor = self.default_image_processor
        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="np")

        # forward pass
        logits = model(**inputs).logits

        # verify the logits
        self.assertEqual(logits.shape, expected_shape)
        np.testing.assert_allclose(logits[0, :3], expected_slice, rtol=1e-5, atol=1e-4)
        self.assertEqual(logits.argmax(-1).item(), expected_class_idx)

    @slow
    def test_inference_image_classification_head_imagenet_1k(self):
        self._check_image_classification(
            "microsoft/beit-base-patch16-224",
            expected_shape=(1, 1000),
            expected_slice=np.array([-1.2385, -1.0987, -1.0108]),
            expected_class_idx=281,
        )

    @slow
    def test_inference_image_classification_head_imagenet_22k(self):
        self._check_image_classification(
            "microsoft/beit-large-patch16-224-pt22k-ft22k",
            expected_shape=(1, 21841),
            expected_slice=np.array([1.6881, -0.2787, 0.5901]),
            expected_class_idx=2396,
        )