# transformers — 383 lines · 14.0 KB
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Nat model. """

import collections
import unittest

from transformers import NatConfig
from transformers.testing_utils import require_natten, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import NatBackbone, NatForImageClassification, NatModel
    from transformers.models.nat.modeling_nat import NAT_PRETRAINED_MODEL_ARCHIVE_LIST

if is_vision_available():
    from PIL import Image

    from transformers import AutoImageProcessor
42
43class NatModelTester:44def __init__(45self,46parent,47batch_size=13,48image_size=64,49patch_size=4,50num_channels=3,51embed_dim=16,52depths=[1, 2, 1],53num_heads=[2, 4, 8],54kernel_size=3,55mlp_ratio=2.0,56qkv_bias=True,57hidden_dropout_prob=0.0,58attention_probs_dropout_prob=0.0,59drop_path_rate=0.1,60hidden_act="gelu",61patch_norm=True,62initializer_range=0.02,63layer_norm_eps=1e-5,64is_training=True,65scope=None,66use_labels=True,67num_labels=10,68out_features=["stage1", "stage2"],69out_indices=[1, 2],70):71self.parent = parent72self.batch_size = batch_size73self.image_size = image_size74self.patch_size = patch_size75self.num_channels = num_channels76self.embed_dim = embed_dim77self.depths = depths78self.num_heads = num_heads79self.kernel_size = kernel_size80self.mlp_ratio = mlp_ratio81self.qkv_bias = qkv_bias82self.hidden_dropout_prob = hidden_dropout_prob83self.attention_probs_dropout_prob = attention_probs_dropout_prob84self.drop_path_rate = drop_path_rate85self.hidden_act = hidden_act86self.patch_norm = patch_norm87self.layer_norm_eps = layer_norm_eps88self.initializer_range = initializer_range89self.is_training = is_training90self.scope = scope91self.use_labels = use_labels92self.num_labels = num_labels93self.out_features = out_features94self.out_indices = out_indices95
96def prepare_config_and_inputs(self):97pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])98
99labels = None100if self.use_labels:101labels = ids_tensor([self.batch_size], self.num_labels)102
103config = self.get_config()104
105return config, pixel_values, labels106
107def get_config(self):108return NatConfig(109num_labels=self.num_labels,110image_size=self.image_size,111patch_size=self.patch_size,112num_channels=self.num_channels,113embed_dim=self.embed_dim,114depths=self.depths,115num_heads=self.num_heads,116kernel_size=self.kernel_size,117mlp_ratio=self.mlp_ratio,118qkv_bias=self.qkv_bias,119hidden_dropout_prob=self.hidden_dropout_prob,120attention_probs_dropout_prob=self.attention_probs_dropout_prob,121drop_path_rate=self.drop_path_rate,122hidden_act=self.hidden_act,123patch_norm=self.patch_norm,124layer_norm_eps=self.layer_norm_eps,125initializer_range=self.initializer_range,126out_features=self.out_features,127out_indices=self.out_indices,128)129
130def create_and_check_model(self, config, pixel_values, labels):131model = NatModel(config=config)132model.to(torch_device)133model.eval()134result = model(pixel_values)135
136expected_height = expected_width = (config.image_size // config.patch_size) // (2 ** (len(config.depths) - 1))137expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))138
139self.parent.assertEqual(140result.last_hidden_state.shape, (self.batch_size, expected_height, expected_width, expected_dim)141)142
143def create_and_check_for_image_classification(self, config, pixel_values, labels):144model = NatForImageClassification(config)145model.to(torch_device)146model.eval()147result = model(pixel_values, labels=labels)148self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))149
150# test greyscale images151config.num_channels = 1152model = NatForImageClassification(config)153model.to(torch_device)154model.eval()155
156pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])157result = model(pixel_values)158self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))159
160def create_and_check_backbone(self, config, pixel_values, labels):161model = NatBackbone(config=config)162model.to(torch_device)163model.eval()164result = model(pixel_values)165
166# verify hidden states167self.parent.assertEqual(len(result.feature_maps), len(config.out_features))168self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])169
170# verify channels171self.parent.assertEqual(len(model.channels), len(config.out_features))172
173# verify backbone works with out_features=None174config.out_features = None175model = NatBackbone(config=config)176model.to(torch_device)177model.eval()178result = model(pixel_values)179
180# verify feature maps181self.parent.assertEqual(len(result.feature_maps), 1)182self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])183
184# verify channels185self.parent.assertEqual(len(model.channels), 1)186
187def prepare_config_and_inputs_for_common(self):188config_and_inputs = self.prepare_config_and_inputs()189config, pixel_values, labels = config_and_inputs190inputs_dict = {"pixel_values": pixel_values}191return config, inputs_dict192
193
194@require_natten
195@require_torch
196class NatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):197all_model_classes = (198(199NatModel,200NatForImageClassification,201NatBackbone,202)203if is_torch_available()204else ()205)206pipeline_model_mapping = (207{"image-feature-extraction": NatModel, "image-classification": NatForImageClassification}208if is_torch_available()209else {}210)211fx_compatible = False212
213test_torchscript = False214test_pruning = False215test_resize_embeddings = False216test_head_masking = False217
218def setUp(self):219self.model_tester = NatModelTester(self)220self.config_tester = ConfigTester(self, config_class=NatConfig, embed_dim=37)221
222def test_config(self):223self.create_and_test_config_common_properties()224self.config_tester.create_and_test_config_to_json_string()225self.config_tester.create_and_test_config_to_json_file()226self.config_tester.create_and_test_config_from_and_save_pretrained()227self.config_tester.create_and_test_config_with_num_labels()228self.config_tester.check_config_can_be_init_without_params()229self.config_tester.check_config_arguments_init()230
231def create_and_test_config_common_properties(self):232return233
234def test_model(self):235config_and_inputs = self.model_tester.prepare_config_and_inputs()236self.model_tester.create_and_check_model(*config_and_inputs)237
238def test_for_image_classification(self):239config_and_inputs = self.model_tester.prepare_config_and_inputs()240self.model_tester.create_and_check_for_image_classification(*config_and_inputs)241
242def test_backbone(self):243config_and_inputs = self.model_tester.prepare_config_and_inputs()244self.model_tester.create_and_check_backbone(*config_and_inputs)245
246@unittest.skip(reason="Nat does not use inputs_embeds")247def test_inputs_embeds(self):248pass249
250@unittest.skip(reason="Nat does not use feedforward chunking")251def test_feed_forward_chunking(self):252pass253
254def test_model_common_attributes(self):255config, _ = self.model_tester.prepare_config_and_inputs_for_common()256
257for model_class in self.all_model_classes:258model = model_class(config)259self.assertIsInstance(model.get_input_embeddings(), (nn.Module))260x = model.get_output_embeddings()261self.assertTrue(x is None or isinstance(x, nn.Linear))262
263def test_attention_outputs(self):264self.skipTest("Nat's attention operation is handled entirely by NATTEN.")265
266def check_hidden_states_output(self, inputs_dict, config, model_class, image_size):267model = model_class(config)268model.to(torch_device)269model.eval()270
271with torch.no_grad():272outputs = model(**self._prepare_for_class(inputs_dict, model_class))273
274hidden_states = outputs.hidden_states275
276expected_num_layers = getattr(277self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1278)279self.assertEqual(len(hidden_states), expected_num_layers)280
281# Nat has a different seq_length282patch_size = (283config.patch_size284if isinstance(config.patch_size, collections.abc.Iterable)285else (config.patch_size, config.patch_size)286)287
288height = image_size[0] // patch_size[0]289width = image_size[1] // patch_size[1]290
291self.assertListEqual(292list(hidden_states[0].shape[-3:]),293[height, width, self.model_tester.embed_dim],294)295
296if model_class.__name__ != "NatBackbone":297reshaped_hidden_states = outputs.reshaped_hidden_states298self.assertEqual(len(reshaped_hidden_states), expected_num_layers)299
300batch_size, num_channels, height, width = reshaped_hidden_states[0].shape301reshaped_hidden_states = (302reshaped_hidden_states[0].view(batch_size, num_channels, height, width).permute(0, 2, 3, 1)303)304self.assertListEqual(305list(reshaped_hidden_states.shape[-3:]),306[height, width, self.model_tester.embed_dim],307)308
309def test_hidden_states_output(self):310config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()311
312image_size = (313self.model_tester.image_size314if isinstance(self.model_tester.image_size, collections.abc.Iterable)315else (self.model_tester.image_size, self.model_tester.image_size)316)317
318for model_class in self.all_model_classes:319inputs_dict["output_hidden_states"] = True320self.check_hidden_states_output(inputs_dict, config, model_class, image_size)321
322# check that output_hidden_states also work using config323del inputs_dict["output_hidden_states"]324config.output_hidden_states = True325
326self.check_hidden_states_output(inputs_dict, config, model_class, image_size)327
328@slow329def test_model_from_pretrained(self):330for model_name in NAT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:331model = NatModel.from_pretrained(model_name)332self.assertIsNotNone(model)333
334def test_initialization(self):335config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()336
337configs_no_init = _config_zero_init(config)338for model_class in self.all_model_classes:339model = model_class(config=configs_no_init)340for name, param in model.named_parameters():341if "embeddings" not in name and param.requires_grad:342self.assertIn(343((param.data.mean() * 1e9).round() / 1e9).item(),344[0.0, 1.0],345msg=f"Parameter {name} of model {model_class} seems not properly initialized",346)347
348
349@require_natten
350@require_vision
351@require_torch
352class NatModelIntegrationTest(unittest.TestCase):353@cached_property354def default_image_processor(self):355return AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224") if is_vision_available() else None356
357@slow358def test_inference_image_classification_head(self):359model = NatForImageClassification.from_pretrained("shi-labs/nat-mini-in1k-224").to(torch_device)360image_processor = self.default_image_processor361
362image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")363inputs = image_processor(images=image, return_tensors="pt").to(torch_device)364
365# forward pass366with torch.no_grad():367outputs = model(**inputs)368
369# verify the logits370expected_shape = torch.Size((1, 1000))371self.assertEqual(outputs.logits.shape, expected_shape)372expected_slice = torch.tensor([0.3805, -0.8676, -0.3912]).to(torch_device)373self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))374
375
376@require_torch
377@require_natten
378class NatBackboneTest(unittest.TestCase, BackboneTesterMixin):379all_model_classes = (NatBackbone,) if is_torch_available() else ()380config_class = NatConfig381
382def setUp(self):383self.model_tester = NatModelTester(self)384