# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Pix2Struct model. """

import copy
import inspect
import os
import tempfile
import unittest

import numpy as np
import requests

from transformers import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import (
        Pix2StructForConditionalGeneration,
        Pix2StructProcessor,
        Pix2StructTextModel,
        Pix2StructVisionModel,
    )
    from transformers.models.pix2struct.modeling_pix2struct import PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from PIL import Image


class Pix2StructVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=12,
        patch_embed_hidden_size=12,
        projection_dim=32,
        max_patches=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=1e-10,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_embed_hidden_size = patch_embed_hidden_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.max_patches = max_patches
        self.seq_length = self.max_patches
        # each flattened patch holds the raw patch pixels (patch_size**2 * num_channels)
        # plus 2 extra slots for its (row, col) position
        self.patch_proj_dim = ((patch_size**2) * num_channels) + 2

        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        flattened_patches = floats_tensor([self.batch_size, self.max_patches, self.patch_proj_dim])
        config = self.get_config()

        return config, flattened_patches

    def get_config(self):
        return Pix2StructVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
            patch_embed_hidden_size=self.patch_embed_hidden_size,
        )

    def create_and_check_model(self, config, flattened_patches):
        model = Pix2StructVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(flattened_patches)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, flattened_patches = config_and_inputs
        inputs_dict = {
            "flattened_patches": flattened_patches,
            "attention_mask": torch.randint(0, 2, (self.batch_size, self.max_patches)),
        }
        return config, inputs_dict


@require_torch
class Pix2StructVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as Pix2Struct does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (Pix2StructVisionModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = Pix2StructVisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=Pix2StructVisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="Pix2StructVision does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["flattened_patches"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`")
    def test_training(self):
        pass

    @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`")
212
    def test_retain_grad_hidden_states_attentions(self):
213
        pass
214

215
    @unittest.skip(reason="Pix2StructVisionModel has no base class and is not available in MODEL_MAPPING")
216
    def test_save_load_fast_init_from_base(self):
217
        pass
218

219
    @unittest.skip(reason="Pix2StructVisionModel has no base class and is not available in MODEL_MAPPING")
220
    def test_save_load_fast_init_to_base(self):
221
        pass
222

223
    @slow
224
    def test_model_from_pretrained(self):
225
        for model_name in PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
226
            model = Pix2StructVisionModel.from_pretrained(model_name)
227
            self.assertIsNotNone(model)
228

229

230
class Pix2StructTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=12,
        projection_dim=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        bos_token_id=0,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.d_kv = hidden_size // num_attention_heads
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope
        self.bos_token_id = bos_token_id

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        return Pix2StructTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            bos_token_id=self.bos_token_id,
            d_kv=self.d_kv,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        model = Pix2StructTextModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_ids, attention_mask=input_mask)
            result = model(input_ids)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class Pix2StructTextModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (Pix2StructTextModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = Pix2StructTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=Pix2StructTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`")
    def test_training(self):
        pass

    @unittest.skip(reason="Training is tested directly on `Pix2StructTextImageModelTest`")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="Pix2Struct does not use inputs_embeds")
361
    def test_inputs_embeds(self):
362
        pass
363

364
    @unittest.skip(reason="Pix2StructTextModel has no base class and is not available in MODEL_MAPPING")
365
    def test_save_load_fast_init_from_base(self):
366
        pass
367

368
    @unittest.skip(reason="Pix2StructTextModel has no base class and is not available in MODEL_MAPPING")
369
    def test_save_load_fast_init_to_base(self):
370
        pass
371

372
    @slow
373
    def test_model_from_pretrained(self):
374
        for model_name in PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
375
            model = Pix2StructTextModel.from_pretrained(model_name)
376
            self.assertIsNotNone(model)
377

378

379
class Pix2StructModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = Pix2StructTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = Pix2StructVisionModelTester(parent, **vision_kwargs)
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, flattened_patches = self.vision_model_tester.prepare_config_and_inputs()

        config = self.get_config(text_config, vision_config)

        return config, input_ids, attention_mask, flattened_patches

    def get_config(self, text_config, vision_config):
        return Pix2StructConfig.from_text_vision_configs(text_config, vision_config, projection_dim=64)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, decoder_attention_mask, flattened_patches = config_and_inputs

        # mask out padding patches, i.e. rows of the flattened patches that are entirely zero
        attention_mask = (flattened_patches.sum(dim=-1) != 0).float()

        inputs_dict = {
            "decoder_input_ids": input_ids,
            "labels": input_ids,
            "decoder_attention_mask": decoder_attention_mask,
            "flattened_patches": flattened_patches,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict


@require_torch
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
    pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = True
    test_attention_outputs = False
    test_torchscript = False

    def setUp(self):
        self.model_tester = Pix2StructModelTester(self)

    def test_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)

            output = model(**input_dict)
            self.assertEqual(
                output[1].shape,
                (
                    self.model_tester.vision_model_tester.batch_size,
                    self.model_tester.text_model_tester.seq_length,
                    self.model_tester.text_model_tester.vocab_size,
                ),
            )

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="Pix2StructModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "flattened_patches",
                "attention_mask",
                "decoder_input_ids",
                "decoder_attention_mask",
                "head_mask",
                "decoder_head_mask",
                "cross_attn_head_mask",
                "encoder_outputs",
                "past_key_values",
                "labels",
                "decoder_inputs_embeds",
                "use_cache",
            ]

            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_training(self):
        if not self.model_tester.is_training:
            return

        for model_class in self.all_model_classes[:-1]:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.return_dict = True

            model = model_class(config)
            model.to(torch_device)
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)

            # hardcode labels to be the same as input_ids
            inputs["labels"] = inputs["input_ids"]

            loss = model(**inputs).loss
            loss.backward()

    def test_training_gradient_checkpointing(self):
        if not self.model_tester.is_training:
            return

        for model_class in self.all_model_classes[:-1]:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.use_cache = False
            config.return_dict = True

            model = model_class(config)
            model.to(torch_device)
            model.gradient_checkpointing_enable()
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)

            # hardcode labels to be the same as input_ids
            inputs["labels"] = inputs["input_ids"]

            loss = model(**inputs).loss
            loss.backward()

    # override as the `logit_scale` parameter initialization is different for Pix2Struct
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if name == "logit_scale":
                        self.assertAlmostEqual(
                            param.data.item(),
                            np.log(1 / 0.07),
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    # overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig`
    def test_resize_tokens_embeddings(self):
        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            model_vocab_size = config.text_config.vocab_size
            # Retrieve the embeddings and clone them
            model_embed = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)

            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            # Decoder input ids should be clamped to the maximum size of the vocabulary
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = True
            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    # overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig`
    def test_resize_embeddings_untied(self):
        original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        original_config.tie_word_embeddings = False

        # if model cannot untie embeddings -> leave test
        if original_config.tie_word_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config).to(torch_device)

            # if no output embeddings -> leave test
            if model.get_output_embeddings() is None:
                continue

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_vocab_size = config.text_config.vocab_size
            model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10)
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            # Decoder input ids should be clamped to the maximum size of the vocabulary
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

    @unittest.skip(reason="Pix2Struct doesn't use tied weights")
    def test_tied_model_weights_key_ignore(self):
        pass

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no NaN
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                flattened_patches = inputs_dict["flattened_patches"]  # Pix2Struct needs flattened_patches
                traced_model = torch.jit.trace(model, (input_ids, flattened_patches))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_load_vision_text_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Save Pix2StructConfig and check if we can load Pix2StructVisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = Pix2StructVisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save Pix2StructConfig and check if we can load Pix2StructTextConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            text_config = Pix2StructTextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())


# We will verify our results on an image of a stop sign
def prepare_img():
    url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/australia.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@require_vision
@require_torch
@slow
class Pix2StructIntegrationTest(unittest.TestCase):
    def test_inference_image_captioning(self):
        model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base").to(torch_device)
        processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
        image = prepare_img()

        # image only
        inputs = processor(images=image, return_tensors="pt").to(torch_device)

        predictions = model.generate(**inputs)

        self.assertEqual(
            processor.decode(predictions[0], skip_special_tokens=True), "A stop sign is on a street corner."
        )

    def test_batched_inference_image_captioning(self):
        model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base").to(torch_device)
        processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
        image_1 = prepare_img()

        second_url = (
            "https://www.connollycove.com/wp-content/uploads/2019/06/temple-bar-dublin-world-famous-irish-pub.jpg"
        )
        image_2 = Image.open(requests.get(second_url, stream=True).raw)

        # images only
        inputs = processor(images=[image_1, image_2], return_tensors="pt").to(torch_device)

        predictions = model.generate(**inputs)

        self.assertEqual(
            processor.decode(predictions[0], skip_special_tokens=True), "A stop sign is on a street corner."
        )

        self.assertEqual(
            processor.decode(predictions[1], skip_special_tokens=True),
            "A row of books including The Temple Bar and Guiness.",
        )

    def test_batched_inference_image_captioning_conditioned(self):
        model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base").to(torch_device)
        processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
        image_1 = prepare_img()

        second_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/temple-bar-dublin-world-famous-irish-pub.jpg"
        image_2 = Image.open(requests.get(second_url, stream=True).raw)
        texts = ["A picture of", "An photography of"]

        # images + conditioning text prompts
        inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt", add_special_tokens=False).to(
            torch_device
        )

        predictions = model.generate(**inputs)

        self.assertEqual(
            processor.decode(predictions[0], skip_special_tokens=True),
            "A picture of a stop sign with a red stop sign",
        )

        self.assertEqual(
            processor.decode(predictions[1], skip_special_tokens=True),
            "An photography of the Temple Bar and other places in the city.",
        )

    def test_vqa_model(self):
        model_id = "google/pix2struct-ai2d-base"

        image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
        image = Image.open(requests.get(image_url, stream=True).raw)

        model = Pix2StructForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(
            torch_device
        )
        processor = Pix2StructProcessor.from_pretrained(model_id)

        # image + question text
        text = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"

        inputs = processor(images=image, return_tensors="pt", text=text).to(torch_device, torch.bfloat16)

        predictions = model.generate(**inputs)
        self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud")

    def test_vqa_model_batched(self):
        model_id = "google/pix2struct-ai2d-base"

        image_urls = [
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo-2.png",
        ]

        images = [Image.open(requests.get(image_url, stream=True).raw) for image_url in image_urls]

        texts = [
            "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
            "What is the producer in the diagram? (1) Phytoplankton (2) Zooplankton (3) Large fish (4) Small fish",
        ]

        model = Pix2StructForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(
            torch_device
        )
        processor = Pix2StructProcessor.from_pretrained(model_id)

        inputs = processor(images=images, return_tensors="pt", text=texts).to(torch_device, torch.bfloat16)

        predictions = model.generate(**inputs)
        self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud")
        self.assertEqual(processor.decode(predictions[1], skip_special_tokens=True), "Phytoplankton")
