transformers
326 строк · 11.9 Кб
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch ResNet model. """
16
17
18import unittest19
20from transformers import ResNetConfig21from transformers.testing_utils import require_torch, require_vision, slow, torch_device22from transformers.utils import cached_property, is_torch_available, is_vision_available23
24from ...test_backbone_common import BackboneTesterMixin25from ...test_configuration_common import ConfigTester26from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor27from ...test_pipeline_mixin import PipelineTesterMixin28
29
30if is_torch_available():31import torch32from torch import nn33
34from transformers import ResNetBackbone, ResNetForImageClassification, ResNetModel35from transformers.models.resnet.modeling_resnet import RESNET_PRETRAINED_MODEL_ARCHIVE_LIST36
37
38if is_vision_available():39from PIL import Image40
41from transformers import AutoImageProcessor42
43
class ResNetModelTester:
    """Build tiny ResNet configs and inputs and run shape checks on behalf of a test case.

    Every assertion is delegated to ``parent`` (the ``unittest.TestCase`` that owns this
    tester). List-valued arguments default to immutable tuples and are copied into fresh
    lists per instance, so no two testers -- and no configs built from them -- share a
    mutable default list.
    """

    def __init__(
        self,
        parent,
        batch_size=3,
        image_size=32,
        num_channels=3,
        embeddings_size=10,
        hidden_sizes=(10, 20, 30, 40),
        depths=(1, 1, 2, 1),
        is_training=True,
        use_labels=True,
        hidden_act="relu",
        num_labels=3,
        scope=None,
        out_features=("stage2", "stage3", "stage4"),
        out_indices=(2, 3, 4),
    ):
        def as_list(value):
            # Fresh list per instance; an explicit None is passed through unchanged.
            return None if value is None else list(value)

        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.embeddings_size = embeddings_size
        self.hidden_sizes = as_list(hidden_sizes)
        self.depths = as_list(depths)
        self.is_training = is_training  # not read in this class; presumably used by shared mixins
        self.use_labels = use_labels
        self.hidden_act = hidden_act
        self.num_labels = num_labels
        self.scope = scope  # stored but not read by this class
        self.num_stages = len(self.hidden_sizes)
        self.out_features = as_list(out_features)
        self.out_indices = as_list(out_indices)

    def prepare_config_and_inputs(self):
        """Return ``(config, pixel_values, labels)``; ``labels`` is None unless ``use_labels``."""
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

        labels = None
        if self.use_labels:
            labels = ids_tensor([self.batch_size], self.num_labels)

        config = self.get_config()

        return config, pixel_values, labels

    def get_config(self):
        """Build a ``ResNetConfig`` from this tester's settings."""
        return ResNetConfig(
            num_channels=self.num_channels,
            embeddings_size=self.embeddings_size,
            hidden_sizes=self.hidden_sizes,
            depths=self.depths,
            hidden_act=self.hidden_act,
            num_labels=self.num_labels,
            out_features=self.out_features,
            out_indices=self.out_indices,
        )

    def create_and_check_model(self, config, pixel_values, labels):
        """Check the base model's last hidden state shape."""
        model = ResNetModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        # expected last hidden states: B, C, H // 32, W // 32
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
        )

    def create_and_check_for_image_classification(self, config, pixel_values, labels):
        """Check the classification head returns one logit per label for each sample."""
        config.num_labels = self.num_labels
        model = ResNetForImageClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values, labels=labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_backbone(self, config, pixel_values, labels):
        """Check backbone feature maps and channels, first with the configured stages, then with ``out_features=None``."""
        model = ResNetBackbone(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)

        # verify feature maps; the first default stage ("stage2") is image_size // 8
        # on each side (4x4 for the default 32px input)
        self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
        self.parent.assertListEqual(
            list(result.feature_maps[0].shape),
            [self.batch_size, self.hidden_sizes[1], self.image_size // 8, self.image_size // 8],
        )

        # verify channels
        self.parent.assertEqual(len(model.channels), len(config.out_features))
        self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])

        # verify backbone works with out_features=None
        config.out_features = None
        model = ResNetBackbone(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)

        # verify feature maps: only the last stage, at the same // 32 resolution as
        # the base model's last hidden state
        self.parent.assertEqual(len(result.feature_maps), 1)
        self.parent.assertListEqual(
            list(result.feature_maps[0].shape),
            [self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32],
        )

        # verify channels
        self.parent.assertEqual(len(model.channels), 1)
        self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` in the form the shared test mixins expect."""
        config, pixel_values, _ = self.prepare_config_and_inputs()
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict
155
@require_torch
class ResNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as ResNet does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (
        (
            ResNetModel,
            ResNetForImageClassification,
            ResNetBackbone,
        )
        if is_torch_available()
        else ()
    )
    pipeline_model_mapping = (
        {"image-feature-extraction": ResNetModel, "image-classification": ResNetForImageClassification}
        if is_torch_available()
        else {}
    )

    fx_compatible = True
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    has_attentions = False

    def setUp(self):
        self.model_tester = ResNetModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ResNetConfig, has_text_modality=False)

    def test_config(self):
        # Runs the config tester's checks individually; the common-properties check is
        # routed to this class's no-op override below instead.
        self.create_and_test_config_common_properties()
        self.config_tester.create_and_test_config_to_json_string()
        self.config_tester.create_and_test_config_to_json_file()
        self.config_tester.create_and_test_config_from_and_save_pretrained()
        self.config_tester.create_and_test_config_with_num_labels()
        self.config_tester.check_config_can_be_init_without_params()
        self.config_tester.check_config_arguments_init()

    def create_and_test_config_common_properties(self):
        # Deliberately a no-op: the common-properties check is skipped for ResNet.
        return

    @unittest.skip(reason="ResNet does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="ResNet does not support input and output embeddings")
    def test_model_common_attributes(self):
        pass

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_backbone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_backbone(*config_and_inputs)

    def test_initialization(self):
        """Every BatchNorm2d/GroupNorm layer must start with weight 1 and bias 0."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            for name, module in model.named_modules():
                if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                    self.assertTrue(
                        torch.all(module.weight == 1),
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )
                    self.assertTrue(
                        torch.all(module.bias == 0),
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    def test_hidden_states_output(self):
        """Check hidden-state count and first-stage resolution for both layer types, requested via kwarg and via config."""

        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            # one hidden state per stage, plus the embedding output
            expected_num_stages = self.model_tester.num_stages
            self.assertEqual(len(hidden_states), expected_num_stages + 1)

            # ResNet's feature maps are of shape (batch_size, num_channels, height, width)
            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [self.model_tester.image_size // 4, self.model_tester.image_size // 4],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        layers_type = ["basic", "bottleneck"]
        for model_class in self.all_model_classes:
            for layer_type in layers_type:
                config.layer_type = layer_type
                # Reset the config flag so this first check exercises only the
                # `output_hidden_states` input kwarg; otherwise the `True` set by the
                # previous iteration's config-based check would leak into it.
                config.output_hidden_states = False
                inputs_dict["output_hidden_states"] = True
                check_hidden_states_output(inputs_dict, config, model_class)

                # check that output_hidden_states also work using config
                del inputs_dict["output_hidden_states"]
                config.output_hidden_states = True

                check_hidden_states_output(inputs_dict, config, model_class)

    @unittest.skip(reason="ResNet does not use feedforward chunking")
    def test_feed_forward_chunking(self):
        pass

    def test_for_image_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = ResNetModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
280
# Fixture image (cute cats) on which the integration results are verified.
def prepare_img():
    """Open and return the COCO fixture image used by the integration test."""
    fixture_path = "./tests/fixtures/tests_samples/COCO/000000039769.png"
    return Image.open(fixture_path)
286
@require_torch
@require_vision
class ResNetModelIntegrationTest(unittest.TestCase):
    """Slow end-to-end check of the first pretrained ResNet checkpoint on the fixture image."""

    @cached_property
    def default_image_processor(self):
        # Image processor for the same checkpoint the model is loaded from; None without vision support.
        return (
            AutoImageProcessor.from_pretrained(RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
            if is_vision_available()
            else None
        )

    @slow
    def test_inference_image_classification_head(self):
        model = ResNetForImageClassification.from_pretrained(RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0]).to(torch_device)

        image_processor = self.default_image_processor
        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_slice = torch.tensor([-11.1069, -9.7877, -8.3777]).to(torch_device)

        # assert_close reports the mismatching values on failure, unlike
        # assertTrue(torch.allclose(...)); rtol=1e-5 is torch.allclose's default,
        # so the accepted tolerance is unchanged.
        torch.testing.assert_close(outputs.logits[0, :3], expected_slice, rtol=1e-5, atol=1e-4)
318
@require_torch
class ResNetBackboneTest(BackboneTesterMixin, unittest.TestCase):
    """Run the shared backbone test-suite against ``ResNetBackbone``."""

    config_class = ResNetConfig
    has_attentions = False
    all_model_classes = () if not is_torch_available() else (ResNetBackbone,)

    def setUp(self):
        # Reuse the tester that builds the tiny configs and pixel inputs.
        self.model_tester = ResNetModelTester(parent=self)