# transformers — 262 lines · 9.6 KB (file-viewer artifact; original file metadata)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch MGP-STR model. """
16
17import unittest18
19import requests20
21from transformers import MgpstrConfig22from transformers.testing_utils import require_torch, require_vision, slow, torch_device23from transformers.utils import is_torch_available, is_vision_available24
25from ...test_configuration_common import ConfigTester26from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor27from ...test_pipeline_mixin import PipelineTesterMixin28
29
30if is_torch_available():31import torch32from torch import nn33
34from transformers import MgpstrForSceneTextRecognition, MgpstrModel35
36
37if is_vision_available():38from PIL import Image39
40from transformers import MgpstrProcessor41
42
43class MgpstrModelTester:44def __init__(45self,46parent,47is_training=False,48batch_size=13,49image_size=(32, 128),50patch_size=4,51num_channels=3,52max_token_length=27,53num_character_labels=38,54num_bpe_labels=99,55num_wordpiece_labels=99,56hidden_size=32,57num_hidden_layers=2,58num_attention_heads=4,59mlp_ratio=4.0,60patch_embeds_hidden_size=257,61output_hidden_states=None,62):63self.parent = parent64self.is_training = is_training65self.batch_size = batch_size66self.image_size = image_size67self.patch_size = patch_size68self.num_channels = num_channels69self.max_token_length = max_token_length70self.num_character_labels = num_character_labels71self.num_bpe_labels = num_bpe_labels72self.num_wordpiece_labels = num_wordpiece_labels73self.hidden_size = hidden_size74self.num_hidden_layers = num_hidden_layers75self.num_attention_heads = num_attention_heads76self.mlp_ratio = mlp_ratio77self.patch_embeds_hidden_size = patch_embeds_hidden_size78self.output_hidden_states = output_hidden_states79
80def prepare_config_and_inputs(self):81pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])82config = self.get_config()83return config, pixel_values84
85def get_config(self):86return MgpstrConfig(87image_size=self.image_size,88patch_size=self.patch_size,89num_channels=self.num_channels,90max_token_length=self.max_token_length,91num_character_labels=self.num_character_labels,92num_bpe_labels=self.num_bpe_labels,93num_wordpiece_labels=self.num_wordpiece_labels,94hidden_size=self.hidden_size,95num_hidden_layers=self.num_hidden_layers,96num_attention_heads=self.num_attention_heads,97mlp_ratio=self.mlp_ratio,98output_hidden_states=self.output_hidden_states,99)100
101def create_and_check_model(self, config, pixel_values):102model = MgpstrForSceneTextRecognition(config)103model.to(torch_device)104model.eval()105with torch.no_grad():106generated_ids = model(pixel_values)107self.parent.assertEqual(108generated_ids[0][0].shape, (self.batch_size, self.max_token_length, self.num_character_labels)109)110
111def prepare_config_and_inputs_for_common(self):112config_and_inputs = self.prepare_config_and_inputs()113config, pixel_values = config_and_inputs114inputs_dict = {"pixel_values": pixel_values}115return config, inputs_dict116
117
118@require_torch
119class MgpstrModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):120all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()121pipeline_model_mapping = (122{"feature-extraction": MgpstrForSceneTextRecognition, "image-feature-extraction": MgpstrModel}123if is_torch_available()124else {}125)126fx_compatible = False127
128test_pruning = False129test_resize_embeddings = False130test_head_masking = False131test_attention_outputs = False132
133def setUp(self):134self.model_tester = MgpstrModelTester(self)135self.config_tester = ConfigTester(self, config_class=MgpstrConfig, has_text_modality=False)136
137def test_config(self):138self.config_tester.run_common_tests()139
140def test_model(self):141config_and_inputs = self.model_tester.prepare_config_and_inputs()142self.model_tester.create_and_check_model(*config_and_inputs)143
144@unittest.skip(reason="MgpstrModel does not use inputs_embeds")145def test_inputs_embeds(self):146pass147
148def test_model_common_attributes(self):149config, _ = self.model_tester.prepare_config_and_inputs_for_common()150
151for model_class in self.all_model_classes:152model = model_class(config)153self.assertIsInstance(model.get_input_embeddings(), (nn.Module))154x = model.get_output_embeddings()155self.assertTrue(x is None or isinstance(x, nn.Linear))156
157@unittest.skip(reason="MgpstrModel does not support feedforward chunking")158def test_feed_forward_chunking(self):159pass160
161def test_gradient_checkpointing_backward_compatibility(self):162config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()163
164for model_class in self.all_model_classes:165if not model_class.supports_gradient_checkpointing:166continue167
168config.gradient_checkpointing = True169model = model_class(config)170self.assertTrue(model.is_gradient_checkpointing)171
172def test_hidden_states_output(self):173def check_hidden_states_output(inputs_dict, config, model_class):174model = model_class(config)175model.to(torch_device)176model.eval()177
178with torch.no_grad():179outputs = model(**self._prepare_for_class(inputs_dict, model_class))180
181hidden_states = outputs.hidden_states182
183expected_num_layers = getattr(184self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1185)186self.assertEqual(len(hidden_states), expected_num_layers)187
188self.assertListEqual(189list(hidden_states[0].shape[-2:]),190[self.model_tester.patch_embeds_hidden_size, self.model_tester.hidden_size],191)192
193config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()194
195for model_class in self.all_model_classes:196inputs_dict["output_hidden_states"] = True197check_hidden_states_output(inputs_dict, config, model_class)198
199# check that output_hidden_states also work using config200del inputs_dict["output_hidden_states"]201config.output_hidden_states = True202
203check_hidden_states_output(inputs_dict, config, model_class)204
205# override as the `logit_scale` parameter initilization is different for MgpstrModel206def test_initialization(self):207config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()208
209configs_no_init = _config_zero_init(config)210for model_class in self.all_model_classes:211model = model_class(config=configs_no_init)212for name, param in model.named_parameters():213if isinstance(param, (nn.Linear, nn.Conv2d, nn.LayerNorm)):214if param.requires_grad:215self.assertIn(216((param.data.mean() * 1e9).round() / 1e9).item(),217[0.0, 1.0],218msg=f"Parameter {name} of model {model_class} seems not properly initialized",219)220
221@unittest.skip(reason="Retain_grad is tested in individual model tests")222def test_retain_grad_hidden_states_attentions(self):223pass224
225
226# We will verify our results on an image from the IIIT-5k dataset
227def prepare_img():228url = "https://i.postimg.cc/ZKwLg2Gw/367-14.png"229im = Image.open(requests.get(url, stream=True).raw).convert("RGB")230return im231
232
233@require_vision
234@require_torch
235class MgpstrModelIntegrationTest(unittest.TestCase):236@slow237def test_inference(self):238model_name = "alibaba-damo/mgp-str-base"239model = MgpstrForSceneTextRecognition.from_pretrained(model_name).to(torch_device)240processor = MgpstrProcessor.from_pretrained(model_name)241
242image = prepare_img()243inputs = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)244
245# forward pass246with torch.no_grad():247outputs = model(inputs)248
249# verify the logits250self.assertEqual(outputs.logits[0].shape, torch.Size((1, 27, 38)))251
252out_strs = processor.batch_decode(outputs.logits)253expected_text = "ticket"254
255self.assertEqual(out_strs["generated_text"][0], expected_text)256
257expected_slice = torch.tensor(258[[[-39.5397, -44.4024, -36.1844], [-61.4709, -63.8639, -58.3454], [-74.0225, -68.5494, -71.2164]]],259device=torch_device,260)261
262self.assertTrue(torch.allclose(outputs.logits[0][:, 1:4, 1:4], expected_slice, atol=1e-4))263