# coding=utf-8
# Copyright 2018 LXMERT Authors, The Hugging Face Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import unittest

import numpy as np

from transformers import LxmertConfig, is_tf_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        MODEL_FOR_PRETRAINING_MAPPING,
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        LxmertForPreTraining,
        LxmertForQuestionAnswering,
        LxmertModel,
    )
    from transformers.models.lxmert.modeling_lxmert import LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST


if is_tf_available():
    import tensorflow as tf


class LxmertModelTester:
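    # Helper that builds a small random LXMERT config plus matching text/visual
    # inputs and labels, and runs shape checks for the model classes under test.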
    def __init__(
        self,
        parent,
        vocab_size=300,
        hidden_size=28,
        num_attention_heads=2,
        num_labels=2,
        intermediate_size=64,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        num_qa_labels=30,
        num_object_labels=16,
        num_attr_labels=4,
        num_visual_features=10,
        l_layers=2,
        x_layers=1,
        r_layers=1,
        visual_feat_dim=128,
        visual_pos_dim=4,
        visual_loss_normalizer=6.67,
        seq_length=20,
        batch_size=4,
        is_training=True,
        task_matched=True,
        task_mask_lm=True,
        task_obj_predict=True,
        task_qa=True,
        visual_obj_loss=True,
        visual_attr_loss=True,
        visual_feat_loss=True,
        use_token_type_ids=True,
        use_lang_mask=True,
        output_attentions=False,
        output_hidden_states=False,
        scope=None,
    ):
        self.parent = parent
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_labels = num_labels
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.pad_token_id = pad_token_id
        self.num_qa_labels = num_qa_labels
        self.num_object_labels = num_object_labels
        self.num_attr_labels = num_attr_labels
        self.l_layers = l_layers
        self.x_layers = x_layers
        self.r_layers = r_layers
        self.visual_feat_dim = visual_feat_dim
        self.visual_pos_dim = visual_pos_dim
        self.visual_loss_normalizer = visual_loss_normalizer
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.is_training = is_training
        self.use_lang_mask = use_lang_mask
        self.task_matched = task_matched
        self.task_mask_lm = task_mask_lm
        self.task_obj_predict = task_obj_predict
        self.task_qa = task_qa
        self.visual_obj_loss = visual_obj_loss
        self.visual_attr_loss = visual_attr_loss
        self.visual_feat_loss = visual_feat_loss
        self.num_visual_features = num_visual_features
        self.use_token_type_ids = use_token_type_ids
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.scope = scope
        self.num_hidden_layers = {"vision": r_layers, "cross_encoder": x_layers, "language": l_layers}

    def prepare_config_and_inputs(self):
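        # Random token ids, visual features and bounding boxes, plus the optional
        # masks/labels required by each pretraining task that is switched on.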
        output_attentions = self.output_attentions
        input_ids = ids_tensor([self.batch_size, self.seq_length], vocab_size=self.vocab_size)
        visual_feats = torch.rand(self.batch_size, self.num_visual_features, self.visual_feat_dim, device=torch_device)
        bounding_boxes = torch.rand(self.batch_size, self.num_visual_features, 4, device=torch_device)

        input_mask = None
        if self.use_lang_mask:
            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
        obj_labels = None
        if self.task_obj_predict:
            obj_labels = {}
        if self.visual_attr_loss and self.task_obj_predict:
            obj_labels["attr"] = (
                ids_tensor([self.batch_size, self.num_visual_features], self.num_attr_labels),
                ids_tensor([self.batch_size, self.num_visual_features], self.num_attr_labels),
            )
        if self.visual_feat_loss and self.task_obj_predict:
            obj_labels["feat"] = (
                ids_tensor(
                    [self.batch_size, self.num_visual_features, self.visual_feat_dim], self.num_visual_features
                ),
                ids_tensor([self.batch_size, self.num_visual_features], self.num_visual_features),
            )
        if self.visual_obj_loss and self.task_obj_predict:
            obj_labels["obj"] = (
                ids_tensor([self.batch_size, self.num_visual_features], self.num_object_labels),
                ids_tensor([self.batch_size, self.num_visual_features], self.num_object_labels),
            )
        ans = None
        if self.task_qa:
            ans = ids_tensor([self.batch_size], self.num_qa_labels)
        masked_lm_labels = None
        if self.task_mask_lm:
            masked_lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        matched_label = None
        if self.task_matched:
            matched_label = ids_tensor([self.batch_size], self.num_labels)

        config = self.get_config()

        return (
            config,
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids,
            input_mask,
            obj_labels,
            masked_lm_labels,
            matched_label,
            ans,
            output_attentions,
        )

    def get_config(self):
        return LxmertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_labels=self.num_labels,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            layer_norm_eps=self.layer_norm_eps,
            pad_token_id=self.pad_token_id,
            num_qa_labels=self.num_qa_labels,
            num_object_labels=self.num_object_labels,
            num_attr_labels=self.num_attr_labels,
            l_layers=self.l_layers,
            x_layers=self.x_layers,
            r_layers=self.r_layers,
            visual_feat_dim=self.visual_feat_dim,
            visual_pos_dim=self.visual_pos_dim,
            visual_loss_normalizer=self.visual_loss_normalizer,
            task_matched=self.task_matched,
            task_mask_lm=self.task_mask_lm,
            task_obj_predict=self.task_obj_predict,
            task_qa=self.task_qa,
            visual_obj_loss=self.visual_obj_loss,
            visual_attr_loss=self.visual_attr_loss,
            visual_feat_loss=self.visual_feat_loss,
            output_attentions=self.output_attentions,
            output_hidden_states=self.output_hidden_states,
        )

    def create_and_check_lxmert_model(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):
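        # Base model smoke test: run with and without attentions / return_dict and
        # check the language, vision and pooled output shapes.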
        model = LxmertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            output_attentions=output_attentions,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            output_attentions=not output_attentions,
        )
        result = model(input_ids, visual_feats, bounding_boxes, return_dict=False)
        result = model(input_ids, visual_feats, bounding_boxes, return_dict=True)

        self.parent.assertEqual(result.language_output.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(
            result.vision_output.shape, (self.batch_size, self.num_visual_features, self.hidden_size)
        )
        self.parent.assertEqual(result.pooled_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_lxmert_for_question_answering(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):
        model = LxmertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            labels=ans,
            output_attentions=output_attentions,
        )
        result = model(input_ids, visual_feats, bounding_boxes, labels=ans)
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            labels=ans,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            output_attentions=output_attentions,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            labels=ans,
            output_attentions=not output_attentions,
        )

        self.parent.assertEqual(result.question_answering_score.shape, (self.batch_size, self.num_qa_labels))

    def create_and_check_lxmert_for_pretraining(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):
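        # Pretraining head smoke test: exercise different label combinations
        # (masked LM, object prediction, matching, QA) and check the LM logits shape.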
        model = LxmertForPreTraining(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            obj_labels=obj_labels,
            matched_label=matched_label,
            ans=ans,
            output_attentions=output_attentions,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            output_attentions=not output_attentions,
            return_dict=False,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            obj_labels=obj_labels,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            matched_label=matched_label,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=ans,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            obj_labels=obj_labels,
            matched_label=matched_label,
            ans=ans,
            output_attentions=not output_attentions,
        )

        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def resize_lxmert_num_qa_labels(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):
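        # Check that resize_num_qa_labels() shrinks/grows the QA head for both the
        # pretraining and the question-answering models, and that the score shapes follow.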
        start_labels = config.num_qa_labels
        num_large_labels = config.num_qa_labels * 2
        num_small_labels = int(config.num_qa_labels / 2)
        less_labels_ans = ids_tensor([self.batch_size], num_small_labels)
        more_labels_ans = ids_tensor([self.batch_size], num_large_labels)
        model_pretrain = LxmertForPreTraining(config=config).to(torch_device)
        model_qa = LxmertForQuestionAnswering(config=config).to(torch_device)
        config.num_labels = num_small_labels
        end_labels = config.num_labels

        result_pretrain = model_pretrain(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=ans,
        )

        result_qa = model_qa(
            input_ids,
            visual_feats,
            bounding_boxes,
            labels=ans,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
        )

        model_pretrain.resize_num_qa_labels(num_small_labels)
        model_qa.resize_num_qa_labels(num_small_labels)

        result_pretrain_less = model_pretrain(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=less_labels_ans,
        )

        result_qa_less = model_qa(
            input_ids,
            visual_feats,
            bounding_boxes,
            labels=less_labels_ans,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
        )

        model_pretrain.resize_num_qa_labels(num_large_labels)
        model_qa.resize_num_qa_labels(num_large_labels)

        result_pretrain_more = model_pretrain(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=more_labels_ans,
        )

        result_qa_more = model_qa(
            input_ids,
            visual_feats,
            bounding_boxes,
            labels=more_labels_ans,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
        )

        model_qa_labels = model_qa.num_qa_labels

        self.parent.assertNotEqual(start_labels, end_labels)
        self.parent.assertNotEqual(model_qa_labels, start_labels)
        self.parent.assertEqual(result_qa.question_answering_score.shape, (self.batch_size, start_labels))
        self.parent.assertEqual(result_pretrain.question_answering_score.shape, (self.batch_size, start_labels))
        self.parent.assertEqual(result_qa_less.question_answering_score.shape, (self.batch_size, num_small_labels))
        self.parent.assertEqual(
            result_pretrain_less.question_answering_score.shape, (self.batch_size, num_small_labels)
        )
        self.parent.assertEqual(result_qa_more.question_answering_score.shape, (self.batch_size, num_large_labels))
        self.parent.assertEqual(
            result_pretrain_more.question_answering_score.shape, (self.batch_size, num_large_labels)
        )

    def prepare_config_and_inputs_for_common(self, return_obj_labels=False):
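        # Shared entry point for the common tests; obj_labels are only kept when a
        # test explicitly asks for them, otherwise task_obj_predict is disabled.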
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids,
            input_mask,
            obj_labels,
            masked_lm_labels,
            matched_label,
            ans,
            output_attentions,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "visual_feats": visual_feats,
            "visual_pos": bounding_boxes,
            "token_type_ids": token_type_ids,
            "attention_mask": input_mask,
        }

        if return_obj_labels:
            inputs_dict["obj_labels"] = obj_labels
        else:
            config.task_obj_predict = False

        return config, inputs_dict


@require_torch
class LxmertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
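    # Common + pipeline test suite for the PyTorch LXMERT models.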
    all_model_classes = (LxmertModel, LxmertForPreTraining, LxmertForQuestionAnswering) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": LxmertModel, "question-answering": LxmertForQuestionAnswering}
        if is_torch_available()
        else {}
    )

    fx_compatible = True
    test_head_masking = False
    test_pruning = False
    test_torchscript = False

    # overwrite function because QA models take a different input label shape
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = copy.deepcopy(inputs_dict)

        if return_labels:
            if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                inputs_dict["labels"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
            elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                # special case for models like BERT that use multi-loss training for PreTraining
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
        return inputs_dict

    def setUp(self):
        self.model_tester = LxmertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=LxmertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_lxmert_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lxmert_model(*config_and_inputs)

    def test_lxmert_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lxmert_for_question_answering(*config_and_inputs)

    def test_lxmert_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lxmert_for_pretraining(*config_and_inputs)

    def test_lxmert_question_answering_labels_resize(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.resize_lxmert_num_qa_labels(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = LxmertModel.from_pretrained(model_name)
            model.to(torch_device)
            self.assertIsNotNone(model)

    def test_attention_outputs(self):
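        # Overridden from the common tests: LXMERT returns three attention tuples
        # (language, vision, cross-encoder) instead of a single one.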
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        seq_len = getattr(self.model_tester, "seq_length", None)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
        chunk_length = getattr(self.model_tester, "chunk_length", None)
        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            language_attentions, vision_attentions, cross_encoder_attentions = (outputs[-3], outputs[-2], outputs[-1])

            self.assertEqual(len(language_attentions), self.model_tester.num_hidden_layers["language"])
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers["vision"])
            self.assertEqual(len(cross_encoder_attentions), self.model_tester.num_hidden_layers["cross_encoder"])

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            language_attentions, vision_attentions, cross_encoder_attentions = (outputs[-3], outputs[-2], outputs[-1])
            self.assertEqual(len(language_attentions), self.model_tester.num_hidden_layers["language"])
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers["vision"])
            self.assertEqual(len(cross_encoder_attentions), self.model_tester.num_hidden_layers["cross_encoder"])

            attentions = [language_attentions, vision_attentions, cross_encoder_attentions]
            attention_shapes = [
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                [
                    self.model_tester.num_attention_heads,
                    self.model_tester.num_visual_features,
                    self.model_tester.num_visual_features,
                ],
                [self.model_tester.num_attention_heads, encoder_key_length, self.model_tester.num_visual_features],
            ]

            for attention, attention_shape in zip(attentions, attention_shapes):
                self.assertListEqual(list(attention[0].shape[-3:]), attention_shape)
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            # 2 hidden states were added
            self.assertEqual(out_len + 2, len(outputs))

            language_attentions, vision_attentions, cross_encoder_attentions = (outputs[-3], outputs[-2], outputs[-1])
            self.assertEqual(len(language_attentions), self.model_tester.num_hidden_layers["language"])
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers["vision"])
            self.assertEqual(len(cross_encoder_attentions), self.model_tester.num_hidden_layers["cross_encoder"])

            attentions = [language_attentions, vision_attentions, cross_encoder_attentions]
            attention_shapes = [
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                [
                    self.model_tester.num_attention_heads,
                    self.model_tester.num_visual_features,
                    self.model_tester.num_visual_features,
                ],
                [self.model_tester.num_attention_heads, encoder_key_length, self.model_tester.num_visual_features],
            ]

            for attention, attention_shape in zip(attentions, attention_shapes):
                self.assertListEqual(list(attention[0].shape[-3:]), attention_shape)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            language_hidden_states, vision_hidden_states = outputs[-2], outputs[-1]

            self.assertEqual(len(language_hidden_states), self.model_tester.num_hidden_layers["language"] + 1)
            self.assertEqual(len(vision_hidden_states), self.model_tester.num_hidden_layers["vision"] + 1)

            seq_length = self.model_tester.seq_length
            num_visual_features = self.model_tester.num_visual_features

            self.assertListEqual(
                list(language_hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )
            self.assertListEqual(
                list(vision_hidden_states[0].shape[-2:]),
                [num_visual_features, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_retain_grad_hidden_states_attentions(self):
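        # Overridden from the common tests: gradients are checked separately for the
        # language and vision hidden states/attentions.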
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = True

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        inputs = self._prepare_for_class(inputs_dict, model_class)

        outputs = model(**inputs)

        hidden_states_lang = outputs.language_hidden_states[0]
        attentions_lang = outputs.language_attentions[0]

        hidden_states_vision = outputs.vision_hidden_states[0]
        attentions_vision = outputs.vision_attentions[0]

        hidden_states_lang.retain_grad()
        attentions_lang.retain_grad()
        hidden_states_vision.retain_grad()
        attentions_vision.retain_grad()

        outputs.language_output.flatten()[0].backward(retain_graph=True)
        outputs.vision_output.flatten()[0].backward(retain_graph=True)

        self.assertIsNotNone(hidden_states_lang.grad)
        self.assertIsNotNone(attentions_lang.grad)
        self.assertIsNotNone(hidden_states_vision.grad)
        self.assertIsNotNone(attentions_vision.grad)

    def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):
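        # Convert a dict of PyTorch inputs into TensorFlow tensors (floats become
        # tf.float32, everything else tf.int32); presumably used by the PT/TF
        # cross-framework equivalence tests.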
        tf_inputs_dict = {}
        for key, value in pt_inputs_dict.items():
            # skip key that does not exist in tf
            if isinstance(value, dict):
                tf_inputs_dict[key] = self.prepare_tf_inputs_from_pt_inputs(value)
            elif isinstance(value, (list, tuple)):
                tf_inputs_dict[key] = (self.prepare_tf_inputs_from_pt_inputs(iter_value) for iter_value in value)
            elif isinstance(value, bool):
                tf_inputs_dict[key] = value
            elif key == "input_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            elif key == "pixel_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            elif key == "input_features":
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            # other general float inputs
            elif value.is_floating_point():
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            else:
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32)

        return tf_inputs_dict


@require_torch
class LxmertModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_no_head_absolute_embedding(self):
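        # Integration test: run the pretrained base model on fixed inputs and
        # compare a slice of the language output against reference values.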
        model = LxmertModel.from_pretrained(LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST[0])
        input_ids = torch.tensor([[101, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 102]])
        num_visual_features = 10
        np.random.seed(0)
        visual_feats = np.random.rand(1, num_visual_features, model.config.visual_feat_dim)
        np.random.seed(0)
        visual_pos = np.random.rand(1, num_visual_features, 4)
        visual_feats = torch.as_tensor(visual_feats, dtype=torch.float32)
        visual_pos = torch.as_tensor(visual_pos, dtype=torch.float32)
        output = model(input_ids, visual_feats=visual_feats, visual_pos=visual_pos)[0]
        expected_shape = torch.Size([1, 11, 768])
        self.assertEqual(expected_shape, output.shape)
        expected_slice = torch.tensor(
            [[[0.2417, -0.9807, 0.1480], [1.2541, -0.8320, 0.5112], [1.4070, -1.1052, 0.6990]]]
        )

        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))