# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Pix2Struct model. """

import copy
import inspect
import os
import tempfile
import unittest

import numpy as np
import requests

from transformers import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import (
        Pix2StructForConditionalGeneration,
        Pix2StructProcessor,
        Pix2StructTextModel,
        Pix2StructVisionModel,
    )
    from transformers.models.pix2struct.modeling_pix2struct import PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from PIL import Image


class Pix2StructVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=12,
        patch_embed_hidden_size=12,
        projection_dim=32,
        max_patches=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=1e-10,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_embed_hidden_size = patch_embed_hidden_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.max_patches = max_patches
        self.seq_length = self.max_patches
        self.patch_proj_dim = ((patch_size**2) * num_channels) + 2
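        # The "+ 2" above mirrors the flattened-patch layout of the Pix2Struct image
        # processor: each patch carries (patch_size**2 * num_channels) pixel values
        # plus two leading entries encoding the patch's (row, col) position.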

        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        flattened_patches = floats_tensor([self.batch_size, self.max_patches, self.patch_proj_dim])
        config = self.get_config()

        return config, flattened_patches

    def get_config(self):
        return Pix2StructVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
            patch_embed_hidden_size=self.patch_embed_hidden_size,
        )

    def create_and_check_model(self, config, flattened_patches):
        model = Pix2StructVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(flattened_patches)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, flattened_patches = config_and_inputs
        inputs_dict = {
            "flattened_patches": flattened_patches,
            "attention_mask": torch.randint(0, 2, (self.batch_size, self.max_patches)),
        }
        return config, inputs_dict


@require_torch
class Pix2StructVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as Pix2Struct does not use input_ids,
    inputs_embeds, attention_mask and seq_length.
    """

    all_model_classes = (Pix2StructVisionModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = Pix2StructVisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=Pix2StructVisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="Pix2StructVision does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["flattened_patches"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`")
    def test_training(self):
        pass

    @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using gradient checkpointing, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using gradient checkpointing, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="Pix2StructVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="Pix2StructVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_name in PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = Pix2StructVisionModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


class Pix2StructTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=12,
        projection_dim=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        bos_token_id=0,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.d_kv = hidden_size // num_attention_heads
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope
        self.bos_token_id = bos_token_id

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])
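
        # Rewrite the random mask as a contiguous run of 1s followed by 0s
        # (i.e. right padding), so every example has a valid unmasked prefix.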
        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        return Pix2StructTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            bos_token_id=self.bos_token_id,
            d_kv=self.d_kv,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        model = Pix2StructTextModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_ids, attention_mask=input_mask)
            result = model(input_ids)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class Pix2StructTextModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (Pix2StructTextModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = Pix2StructTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=Pix2StructTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`")
    def test_training(self):
        pass

    @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using gradient checkpointing, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture does not seem to compute gradients properly when using gradient checkpointing, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Pix2Struct does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Pix2StructTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="Pix2StructTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_name in PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = Pix2StructTextModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


class Pix2StructModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = Pix2StructTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = Pix2StructVisionModelTester(parent, **vision_kwargs)
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, flattened_patches = self.vision_model_tester.prepare_config_and_inputs()

        config = self.get_config(text_config, vision_config)

        return config, input_ids, attention_mask, flattened_patches

    def get_config(self, text_config, vision_config):
        return Pix2StructConfig.from_text_vision_configs(text_config, vision_config, projection_dim=64)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, decoder_attention_mask, flattened_patches = config_and_inputs
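
        # Treat a patch as valid iff any of its features is non-zero. The image
        # processor pads unused patch slots with all-zero rows, so this derived mask
        # matches the attention mask Pix2StructProcessor would return.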
        attention_mask = (flattened_patches.sum(dim=-1) != 0).float()

        inputs_dict = {
            "decoder_input_ids": input_ids,
            "labels": input_ids,
            "decoder_attention_mask": decoder_attention_mask,
            "flattened_patches": flattened_patches,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict


@require_torch
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
    pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = True
    test_attention_outputs = False
    test_torchscript = False

    def setUp(self):
        self.model_tester = Pix2StructModelTester(self)

    def test_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)

            output = model(**input_dict)
            self.assertEqual(
                output[1].shape,
                (
                    self.model_tester.vision_model_tester.batch_size,
                    self.model_tester.text_model_tester.seq_length,
                    self.model_tester.text_model_tester.vocab_size,
                ),
            )

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="Pix2StructModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "flattened_patches",
                "attention_mask",
                "decoder_input_ids",
                "decoder_attention_mask",
                "head_mask",
                "decoder_head_mask",
                "cross_attn_head_mask",
                "encoder_outputs",
                "past_key_values",
                "labels",
                "decoder_inputs_embeds",
                "use_cache",
            ]

            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_training(self):
        if not self.model_tester.is_training:
            return

        for model_class in self.all_model_classes[:-1]:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.return_dict = True

            model = model_class(config)
            model.to(torch_device)
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)

            # hardcode labels to be the same as input_ids
            inputs["labels"] = inputs["input_ids"]

            loss = model(**inputs).loss
            loss.backward()

    def test_training_gradient_checkpointing(self):
        if not self.model_tester.is_training:
            return

        for model_class in self.all_model_classes[:-1]:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.use_cache = False
            config.return_dict = True

            model = model_class(config)
            model.to(torch_device)
            model.gradient_checkpointing_enable()
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)

            # hardcode labels to be the same as input_ids
            inputs["labels"] = inputs["input_ids"]

            loss = model(**inputs).loss
            loss.backward()

    # override as the `logit_scale` parameter initialization is different for Pix2Struct
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
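
        # The expected value `np.log(1 / 0.07)` below is the CLIP-style temperature
        # initialization, presumably carried over from the original implementation.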
        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if name == "logit_scale":
                        self.assertAlmostEqual(
                            param.data.item(),
                            np.log(1 / 0.07),
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    # overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig`
    def test_resize_tokens_embeddings(self):
        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            model_vocab_size = config.text_config.vocab_size
            # Retrieve the embeddings and clone them
            model_embed = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)

            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            # Decoder input ids should be clamped to the maximum size of the vocabulary
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = True
            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    # overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig`
    def test_resize_embeddings_untied(self):
        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        original_config.tie_word_embeddings = False

        # if the model cannot untie embeddings -> leave test
        if original_config.tie_word_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config).to(torch_device)

            # if no output embeddings -> leave test
            if model.get_output_embeddings() is None:
                continue

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_vocab_size = config.text_config.vocab_size
            model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            # Decoder input ids should be clamped to the maximum size of the vocabulary
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

    @unittest.skip(reason="Pix2Struct doesn't use tied weights")
    def test_tied_model_weights_key_ignore(self):
        pass

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no NaN
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                flattened_patches = inputs_dict["flattened_patches"]  # Pix2Struct needs flattened_patches
                traced_model = torch.jit.trace(model, (input_ids, flattened_patches))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()
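
            # Keys present only in the traced state dict are non-persistent buffers
            # materialized by tracing; set them aside so the key sets can be compared,
            # then match each one against the eager model's buffers below.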
            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_load_vision_text_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Save Pix2StructConfig and check if we can load Pix2StructVisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = Pix2StructVisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save Pix2StructConfig and check if we can load Pix2StructTextConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            text_config = Pix2StructTextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())


# We will verify our results on an image of a stop sign
def prepare_img():
    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/australia.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@require_vision
@require_torch
@slow
class Pix2StructIntegrationTest(unittest.TestCase):
    def test_inference_image_captioning(self):
        model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base").to(torch_device)
        processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
        image = prepare_img()

        # image only
        inputs = processor(images=image, return_tensors="pt").to(torch_device)

        predictions = model.generate(**inputs)

        self.assertEqual(
            processor.decode(predictions[0], skip_special_tokens=True), "A stop sign is on a street corner."
        )

    def test_batched_inference_image_captioning(self):
        model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base").to(torch_device)
        processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
        image_1 = prepare_img()

        second_url = (
            "https://www.connollycove.com/wp-content/uploads/2019/06/temple-bar-dublin-world-famous-irish-pub.jpg"
        )
        image_2 = Image.open(requests.get(second_url, stream=True).raw)

        # images only
        inputs = processor(images=[image_1, image_2], return_tensors="pt").to(torch_device)

        predictions = model.generate(**inputs)

        self.assertEqual(
            processor.decode(predictions[0], skip_special_tokens=True), "A stop sign is on a street corner."
        )

        self.assertEqual(
            processor.decode(predictions[1], skip_special_tokens=True),
            "A row of books including The Temple Bar and Guiness.",
        )

    def test_batched_inference_image_captioning_conditioned(self):
        model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base").to(torch_device)
        processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
        image_1 = prepare_img()

        second_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/temple-bar-dublin-world-famous-irish-pub.jpg"
        image_2 = Image.open(requests.get(second_url, stream=True).raw)
        texts = ["A picture of", "An photography of"]

        # images conditioned on text prompts
        inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt", add_special_tokens=False).to(
            torch_device
        )

        predictions = model.generate(**inputs)

        self.assertEqual(
            processor.decode(predictions[0], skip_special_tokens=True),
            "A picture of a stop sign with a red stop sign",
        )

        self.assertEqual(
            processor.decode(predictions[1], skip_special_tokens=True),
            "An photography of the Temple Bar and other places in the city.",
        )

    def test_vqa_model(self):
        model_id = "google/pix2struct-ai2d-base"

        image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
        image = Image.open(requests.get(image_url, stream=True).raw)

        model = Pix2StructForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(
            torch_device
        )
        processor = Pix2StructProcessor.from_pretrained(model_id)

        # image and question
        text = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"

        inputs = processor(images=image, return_tensors="pt", text=text).to(torch_device, torch.bfloat16)

        predictions = model.generate(**inputs)
        self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud")

    def test_vqa_model_batched(self):
        model_id = "google/pix2struct-ai2d-base"

        image_urls = [
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo-2.png",
        ]

        images = [Image.open(requests.get(image_url, stream=True).raw) for image_url in image_urls]

        texts = [
            "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
            "What is the producer in the diagram? (1) Phytoplankton (2) Zooplankton (3) Large fish (4) Small fish",
        ]

        model = Pix2StructForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(
            torch_device
        )
        processor = Pix2StructProcessor.from_pretrained(model_id)

        inputs = processor(images=images, return_tensors="pt", text=texts).to(torch_device, torch.bfloat16)

        predictions = model.generate(**inputs)
        self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud")
        self.assertEqual(processor.decode(predictions[1], skip_special_tokens=True), "Phytoplankton")