# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch ViTMSN model. """


import unittest

from transformers import ViTMSNConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import ViTMSNForImageClassification, ViTMSNModel
    from transformers.models.vit_msn.modeling_vit_msn import VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from PIL import Image

    from transformers import ViTImageProcessor


class ViTMSNModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        type_sequence_label_size=10,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.scope = scope

        # in ViT MSN, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1

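    # Builds a tiny random config plus matching inputs: a batch of pixel values and, when
    # use_labels is True, random classification labels.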
    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

        labels = None
        if self.use_labels:
            labels = ids_tensor([self.batch_size], self.type_sequence_label_size)

        config = self.get_config()

        return config, pixel_values, labels

    def get_config(self):
        return ViTMSNConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            initializer_range=self.initializer_range,
        )

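    # Runs the bare ViTMSNModel in eval mode and checks that the last hidden state has the
    # expected (batch_size, seq_length, hidden_size) shape.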
    def create_and_check_model(self, config, pixel_values, labels):
        model = ViTMSNModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

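    # Checks the classification head: logits shape with labels on RGB input, then again on
    # single-channel (grayscale) input.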
    def create_and_check_for_image_classification(self, config, pixel_values, labels):
        config.num_labels = self.type_sequence_label_size
        model = ViTMSNForImageClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values, labels=labels)
        print(f"Pixel and labels shape: {pixel_values.shape}, {labels.shape}")
        print(f"Labels: {labels}")
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

        # test greyscale images
        config.num_channels = 1
        model = ViTMSNForImageClassification(config)
        model.to(torch_device)
        model.eval()

        pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
        result = model(pixel_values)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

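    # The common ModelTesterMixin tests consume only the config and the model inputs, so the
    # labels are dropped here.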
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, labels = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class ViTMSNModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as ViTMSN does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (ViTMSNModel, ViTMSNForImageClassification) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"image-feature-extraction": ViTMSNModel, "image-classification": ViTMSNForImageClassification}
        if is_torch_available()
        else {}
    )

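    # Features ViTMSN does not support; the corresponding common tests are skipped.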
    test_pruning = False
    test_torchscript = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = ViTMSNModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ViTMSNConfig, has_text_modality=False, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="ViTMSN does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_image_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in VIT_MSN_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = ViTMSNModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_torch
@require_vision
class ViTMSNModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_image_processor(self):
        return ViTImageProcessor.from_pretrained("facebook/vit-msn-small") if is_vision_available() else None

    @slow
    def test_inference_image_classification_head(self):
        torch.manual_seed(2)
        model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small").to(torch_device)

        image_processor = self.default_image_processor
        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)

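        # reference values for the first three logits of this image under the
        # facebook/vit-msn-small checkpoint, checked to a tolerance of 1e-4 below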
        expected_slice = torch.tensor([0.5588, 0.6853, -0.5929]).to(torch_device)

        self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))