# transformers — file-listing metadata (302 lines · 11.1 KB), kept as a comment
1# coding=utf-8
2# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15""" Testing suite for the PyTorch ViTDet model. """
16
17
18import unittest19
20from transformers import VitDetConfig21from transformers.testing_utils import is_flaky, require_torch, torch_device22from transformers.utils import is_torch_available23
24from ...test_backbone_common import BackboneTesterMixin25from ...test_configuration_common import ConfigTester26from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor27from ...test_pipeline_mixin import PipelineTesterMixin28
29
30if is_torch_available():31import torch32from torch import nn33
34from transformers import VitDetBackbone, VitDetModel35
36
class VitDetModelTester:
    """Builds tiny VitDet configs and random inputs for the unit tests below."""

    def __init__(
        self,
        parent,
        batch_size=13,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        type_sequence_label_size=10,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.scope = scope

        # One patch per `patch_size` pixels along each spatial axis.
        self.num_patches_one_direction = self.image_size // self.patch_size
        self.seq_length = self.num_patches_one_direction**2

    def prepare_config_and_inputs(self):
        """Return (config, pixel_values, labels); labels are None when disabled."""
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        labels = ids_tensor([self.batch_size], self.type_sequence_label_size) if self.use_labels else None
        return self.get_config(), pixel_values, labels

    def get_config(self):
        """Build a VitDetConfig mirroring this tester's hyperparameters."""
        return VitDetConfig(
            image_size=self.image_size,
            pretrain_image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values, labels):
        """Run VitDetModel and verify the last hidden state's shape."""
        model = VitDetModel(config=config)
        model.to(torch_device)
        model.eval()
        outputs = model(pixel_values)
        side = self.num_patches_one_direction
        self.parent.assertEqual(
            outputs.last_hidden_state.shape,
            (self.batch_size, self.hidden_size, side, side),
        )

    def create_and_check_backbone(self, config, pixel_values, labels):
        """Exercise VitDetBackbone with explicit and default `out_features`."""
        expected_shape = [
            self.batch_size,
            self.hidden_size,
            self.num_patches_one_direction,
            self.num_patches_one_direction,
        ]

        model = VitDetBackbone(config=config)
        model.to(torch_device)
        model.eval()
        outputs = model(pixel_values)

        # Feature maps and channels must line up with the configured out_features.
        self.parent.assertEqual(len(outputs.feature_maps), len(config.out_features))
        self.parent.assertListEqual(list(outputs.feature_maps[0].shape), expected_shape)
        self.parent.assertEqual(len(model.channels), len(config.out_features))
        self.parent.assertListEqual(model.channels, [config.hidden_size])

        # With out_features=None the backbone should expose a single feature map.
        config.out_features = None
        model = VitDetBackbone(config=config)
        model.to(torch_device)
        model.eval()
        outputs = model(pixel_values)

        self.parent.assertEqual(len(outputs.feature_maps), 1)
        self.parent.assertListEqual(list(outputs.feature_maps[0].shape), expected_shape)
        self.parent.assertEqual(len(model.channels), 1)
        self.parent.assertListEqual(model.channels, [config.hidden_size])

    def prepare_config_and_inputs_for_common(self):
        """Return (config, inputs_dict) for the shared ModelTesterMixin tests."""
        config, pixel_values, _ = self.prepare_config_and_inputs()
        return config, {"pixel_values": pixel_values}
158
159@require_torch
160class VitDetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):161"""162Here we also overwrite some of the tests of test_modeling_common.py, as VitDet does not use input_ids, inputs_embeds,
163attention_mask and seq_length.
164"""
165
166all_model_classes = (VitDetModel, VitDetBackbone) if is_torch_available() else ()167pipeline_model_mapping = {"feature-extraction": VitDetModel} if is_torch_available() else {}168
169fx_compatible = False170test_pruning = False171test_resize_embeddings = False172test_head_masking = False173
174def setUp(self):175self.model_tester = VitDetModelTester(self)176self.config_tester = ConfigTester(self, config_class=VitDetConfig, has_text_modality=False, hidden_size=37)177
178@is_flaky(max_attempts=3, description="`torch.nn.init.trunc_normal_` is flaky.")179def test_initialization(self):180super().test_initialization()181
182# TODO: Fix me (once this model gets more usage)183@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")184def test_cpu_offload(self):185super().test_cpu_offload()186
187# TODO: Fix me (once this model gets more usage)188@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")189def test_disk_offload_bin(self):190super().test_disk_offload()191
192@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")193def test_disk_offload_safetensors(self):194super().test_disk_offload()195
196# TODO: Fix me (once this model gets more usage)197@unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")198def test_model_parallelism(self):199super().test_model_parallelism()200
201def test_config(self):202self.config_tester.run_common_tests()203
204@unittest.skip(reason="VitDet does not use inputs_embeds")205def test_inputs_embeds(self):206pass207
208def test_model_common_attributes(self):209config, _ = self.model_tester.prepare_config_and_inputs_for_common()210
211for model_class in self.all_model_classes:212model = model_class(config)213self.assertIsInstance(model.get_input_embeddings(), (nn.Module))214x = model.get_output_embeddings()215self.assertTrue(x is None or isinstance(x, nn.Linear))216
217def test_model(self):218config_and_inputs = self.model_tester.prepare_config_and_inputs()219self.model_tester.create_and_check_model(*config_and_inputs)220
221def test_backbone(self):222config_and_inputs = self.model_tester.prepare_config_and_inputs()223self.model_tester.create_and_check_backbone(*config_and_inputs)224
225def test_hidden_states_output(self):226def check_hidden_states_output(inputs_dict, config, model_class):227model = model_class(config)228model.to(torch_device)229model.eval()230
231with torch.no_grad():232outputs = model(**self._prepare_for_class(inputs_dict, model_class))233
234hidden_states = outputs.hidden_states235
236expected_num_stages = self.model_tester.num_hidden_layers237self.assertEqual(len(hidden_states), expected_num_stages + 1)238
239# VitDet's feature maps are of shape (batch_size, num_channels, height, width)240self.assertListEqual(241list(hidden_states[0].shape[-2:]),242[243self.model_tester.num_patches_one_direction,244self.model_tester.num_patches_one_direction,245],246)247
248config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()249
250for model_class in self.all_model_classes:251inputs_dict["output_hidden_states"] = True252check_hidden_states_output(inputs_dict, config, model_class)253
254# check that output_hidden_states also work using config255del inputs_dict["output_hidden_states"]256config.output_hidden_states = True257
258check_hidden_states_output(inputs_dict, config, model_class)259
260# overwrite since VitDet only supports retraining gradients of hidden states261def test_retain_grad_hidden_states_attentions(self):262config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()263config.output_hidden_states = True264config.output_attentions = self.has_attentions265
266# no need to test all models as different heads yield the same functionality267model_class = self.all_model_classes[0]268model = model_class(config)269model.to(torch_device)270
271inputs = self._prepare_for_class(inputs_dict, model_class)272
273outputs = model(**inputs)274
275output = outputs[0]276
277# Encoder-/Decoder-only models278hidden_states = outputs.hidden_states[0]279hidden_states.retain_grad()280
281output.flatten()[0].backward(retain_graph=True)282
283self.assertIsNotNone(hidden_states.grad)284
285@unittest.skip(reason="VitDet does not support feedforward chunking")286def test_feed_forward_chunking(self):287pass288
289@unittest.skip(reason="VitDet does not have standalone checkpoints since it used as backbone in other models")290def test_model_from_pretrained(self):291pass292
293
@require_torch
class VitDetBackboneTest(unittest.TestCase, BackboneTesterMixin):
    """Runs the shared backbone test-suite (BackboneTesterMixin) against VitDetBackbone."""

    all_model_classes = (VitDetBackbone,) if is_torch_available() else ()
    config_class = VitDetConfig

    # NOTE(review): presumably tells the shared backbone tests to skip attention-output
    # checks for this model — confirm against BackboneTesterMixin.
    has_attentions = False

    def setUp(self):
        # Reuse the same tester the model tests use to build configs and inputs.
        self.model_tester = VitDetModelTester(self)