# coding=utf-8
# Copyright 2018 LXMERT Authors, The Hugging Face Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import numpy as np

from transformers import LxmertConfig, is_tf_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        MODEL_FOR_PRETRAINING_MAPPING,
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        LxmertForPreTraining,
        LxmertForQuestionAnswering,
        LxmertModel,
    )
    from transformers.models.lxmert.modeling_lxmert import LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST


if is_tf_available():
    import tensorflow as tf

class LxmertModelTester:
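    """Builds a small LxmertConfig and matching random text/visual inputs for the tests below."""
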
    def __init__(
        self,
        parent,
        vocab_size=300,
        hidden_size=28,
        num_attention_heads=2,
        num_labels=2,
        intermediate_size=64,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        num_qa_labels=30,
        num_object_labels=16,
        num_attr_labels=4,
        num_visual_features=10,
        l_layers=2,
        x_layers=1,
        r_layers=1,
        visual_feat_dim=128,
        visual_pos_dim=4,
        visual_loss_normalizer=6.67,
        seq_length=20,
        batch_size=4,
        is_training=True,
        task_matched=True,
        task_mask_lm=True,
        task_obj_predict=True,
        task_qa=True,
        visual_obj_loss=True,
        visual_attr_loss=True,
        visual_feat_loss=True,
        use_token_type_ids=True,
        use_lang_mask=True,
        output_attentions=False,
        output_hidden_states=False,
        scope=None,
    ):
        self.parent = parent
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_labels = num_labels
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.pad_token_id = pad_token_id
        self.num_qa_labels = num_qa_labels
        self.num_object_labels = num_object_labels
        self.num_attr_labels = num_attr_labels
        self.l_layers = l_layers
        self.x_layers = x_layers
        self.r_layers = r_layers
        self.visual_feat_dim = visual_feat_dim
        self.visual_pos_dim = visual_pos_dim
        self.visual_loss_normalizer = visual_loss_normalizer
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.is_training = is_training
        self.use_lang_mask = use_lang_mask
        self.task_matched = task_matched
        self.task_mask_lm = task_mask_lm
        self.task_obj_predict = task_obj_predict
        self.task_qa = task_qa
        self.visual_obj_loss = visual_obj_loss
        self.visual_attr_loss = visual_attr_loss
        self.visual_feat_loss = visual_feat_loss
        self.num_visual_features = num_visual_features
        self.use_token_type_ids = use_token_type_ids
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.scope = scope
        self.num_hidden_layers = {"vision": r_layers, "cross_encoder": x_layers, "language": l_layers}

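    # Builds the full tuple of test inputs; the label tensors (obj_labels, masked_lm_labels,
    # matched_label, ans) are only generated when the corresponding task_* flag is set.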
    def prepare_config_and_inputs(self):
        output_attentions = self.output_attentions
        input_ids = ids_tensor([self.batch_size, self.seq_length], vocab_size=self.vocab_size)
        visual_feats = torch.rand(self.batch_size, self.num_visual_features, self.visual_feat_dim, device=torch_device)
        bounding_boxes = torch.rand(self.batch_size, self.num_visual_features, 4, device=torch_device)

        input_mask = None
        if self.use_lang_mask:
            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
        obj_labels = None
        if self.task_obj_predict:
            obj_labels = {}
        if self.visual_attr_loss and self.task_obj_predict:
            obj_labels["attr"] = (
                ids_tensor([self.batch_size, self.num_visual_features], self.num_attr_labels),
                ids_tensor([self.batch_size, self.num_visual_features], self.num_attr_labels),
            )
        if self.visual_feat_loss and self.task_obj_predict:
            obj_labels["feat"] = (
                ids_tensor(
                    [self.batch_size, self.num_visual_features, self.visual_feat_dim], self.num_visual_features
                ),
                ids_tensor([self.batch_size, self.num_visual_features], self.num_visual_features),
            )
        if self.visual_obj_loss and self.task_obj_predict:
            obj_labels["obj"] = (
                ids_tensor([self.batch_size, self.num_visual_features], self.num_object_labels),
                ids_tensor([self.batch_size, self.num_visual_features], self.num_object_labels),
            )
        ans = None
        if self.task_qa:
            ans = ids_tensor([self.batch_size], self.num_qa_labels)
        masked_lm_labels = None
        if self.task_mask_lm:
            masked_lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        matched_label = None
        if self.task_matched:
            matched_label = ids_tensor([self.batch_size], self.num_labels)

        config = self.get_config()

        return (
            config,
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids,
            input_mask,
            obj_labels,
            masked_lm_labels,
            matched_label,
            ans,
            output_attentions,
        )

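    # Mirror the tester's attributes into an LxmertConfig.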
    def get_config(self):
        return LxmertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_labels=self.num_labels,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            layer_norm_eps=self.layer_norm_eps,
            pad_token_id=self.pad_token_id,
            num_qa_labels=self.num_qa_labels,
            num_object_labels=self.num_object_labels,
            num_attr_labels=self.num_attr_labels,
            l_layers=self.l_layers,
            x_layers=self.x_layers,
            r_layers=self.r_layers,
            visual_feat_dim=self.visual_feat_dim,
            visual_pos_dim=self.visual_pos_dim,
            visual_loss_normalizer=self.visual_loss_normalizer,
            task_matched=self.task_matched,
            task_mask_lm=self.task_mask_lm,
            task_obj_predict=self.task_obj_predict,
            task_qa=self.task_qa,
            visual_obj_loss=self.visual_obj_loss,
            visual_attr_loss=self.visual_attr_loss,
            visual_feat_loss=self.visual_feat_loss,
            output_attentions=self.output_attentions,
            output_hidden_states=self.output_hidden_states,
        )

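    # Base model: check the language, vision, and pooled output shapes under several
    # combinations of output_attentions and return_dict.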
    def create_and_check_lxmert_model(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):
        model = LxmertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            output_attentions=output_attentions,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            output_attentions=not output_attentions,
        )
        result = model(input_ids, visual_feats, bounding_boxes, return_dict=False)
        result = model(input_ids, visual_feats, bounding_boxes, return_dict=True)

        self.parent.assertEqual(result.language_output.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(
            result.vision_output.shape, (self.batch_size, self.num_visual_features, self.hidden_size)
        )
        self.parent.assertEqual(result.pooled_output.shape, (self.batch_size, self.hidden_size))

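    # QA head: the score tensor must have one column per QA label.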
    def create_and_check_lxmert_for_question_answering(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):
        model = LxmertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            labels=ans,
            output_attentions=output_attentions,
        )
        result = model(input_ids, visual_feats, bounding_boxes, labels=ans)
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            labels=ans,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            output_attentions=output_attentions,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            labels=ans,
            output_attentions=not output_attentions,
        )

        self.parent.assertEqual(result.question_answering_score.shape, (self.batch_size, self.num_qa_labels))

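    # Pre-training head: exercise the masked-LM, object-prediction, matched-label, and QA
    # losses both together and one at a time, and check the masked-LM logits shape.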
    def create_and_check_lxmert_for_pretraining(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):
        model = LxmertForPreTraining(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            obj_labels=obj_labels,
            matched_label=matched_label,
            ans=ans,
            output_attentions=output_attentions,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            output_attentions=not output_attentions,
            return_dict=False,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            obj_labels=obj_labels,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            matched_label=matched_label,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=ans,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            obj_labels=obj_labels,
            matched_label=matched_label,
            ans=ans,
            output_attentions=not output_attentions,
        )

        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

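    # resize_num_qa_labels should shrink or grow the QA head while keeping both models usable;
    # the score shape must track the new label count after each resize.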
    def resize_lxmert_num_qa_labels(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):
        start_labels = config.num_qa_labels
        num_large_labels = config.num_qa_labels * 2
        num_small_labels = config.num_qa_labels // 2  # was `* 2`, which made the "small" count equal to the large one
        less_labels_ans = ids_tensor([self.batch_size], num_small_labels)
        more_labels_ans = ids_tensor([self.batch_size], num_large_labels)
        model_pretrain = LxmertForPreTraining(config=config).to(torch_device)
        model_qa = LxmertForQuestionAnswering(config=config).to(torch_device)
        config.num_labels = num_small_labels
        end_labels = config.num_labels

        result_pretrain = model_pretrain(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=ans,
        )

        result_qa = model_qa(
            input_ids,
            visual_feats,
            bounding_boxes,
            labels=ans,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
        )

        model_pretrain.resize_num_qa_labels(num_small_labels)
        model_qa.resize_num_qa_labels(num_small_labels)

        result_pretrain_less = model_pretrain(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=less_labels_ans,
        )

        result_qa_less = model_qa(
            input_ids,
            visual_feats,
            bounding_boxes,
            labels=less_labels_ans,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
        )

        model_pretrain.resize_num_qa_labels(num_large_labels)
        model_qa.resize_num_qa_labels(num_large_labels)

        result_pretrain_more = model_pretrain(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=more_labels_ans,
        )

        result_qa_more = model_qa(
            input_ids,
            visual_feats,
            bounding_boxes,
            labels=more_labels_ans,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
        )

        model_qa_labels = model_qa.num_qa_labels

        self.parent.assertNotEqual(start_labels, end_labels)
        self.parent.assertNotEqual(model_qa_labels, start_labels)
        self.parent.assertEqual(result_qa.question_answering_score.shape, (self.batch_size, start_labels))
        self.parent.assertEqual(result_pretrain.question_answering_score.shape, (self.batch_size, start_labels))
        self.parent.assertEqual(result_qa_less.question_answering_score.shape, (self.batch_size, num_small_labels))
        self.parent.assertEqual(
            result_pretrain_less.question_answering_score.shape, (self.batch_size, num_small_labels)
        )
        self.parent.assertEqual(result_qa_more.question_answering_score.shape, (self.batch_size, num_large_labels))
        self.parent.assertEqual(
            result_pretrain_more.question_answering_score.shape, (self.batch_size, num_large_labels)
        )

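    # Pack the shared inputs into the dict format ModelTesterMixin expects; obj_labels are
    # only kept when a test explicitly asks for them, otherwise the object-prediction task
    # is disabled so the common tests do not require them.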
    def prepare_config_and_inputs_for_common(self, return_obj_labels=False):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids,
            input_mask,
            obj_labels,
            masked_lm_labels,
            matched_label,
            ans,
            output_attentions,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "visual_feats": visual_feats,
            "visual_pos": bounding_boxes,
            "token_type_ids": token_type_ids,
            "attention_mask": input_mask,
        }

        if return_obj_labels:
            inputs_dict["obj_labels"] = obj_labels
        else:
            config.task_obj_predict = False

        return config, inputs_dict


@require_torch
class LxmertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (LxmertModel, LxmertForPreTraining, LxmertForQuestionAnswering) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": LxmertModel, "question-answering": LxmertForQuestionAnswering}
        if is_torch_available()
        else {}
    )

    fx_compatible = True
    test_head_masking = False
    test_pruning = False
    test_torchscript = False

    # overwrite function because QA models take labels with a different shape
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = copy.deepcopy(inputs_dict)

        if return_labels:
            if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                inputs_dict["labels"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
            elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                # special case for models like BERT that use multi-loss training for PreTraining
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
        return inputs_dict

    def setUp(self):
        self.model_tester = LxmertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=LxmertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_lxmert_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lxmert_model(*config_and_inputs)

    def test_lxmert_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lxmert_for_question_answering(*config_and_inputs)

    def test_lxmert_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lxmert_for_pretraining(*config_and_inputs)

    def test_lxmert_question_answering_labels_resize(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.resize_lxmert_num_qa_labels(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = LxmertModel.from_pretrained(model_name)
            model.to(torch_device)
            self.assertIsNotNone(model)

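    # LXMERT returns three attention streams (language, vision, cross-encoder), so the
    # generic attention-output test is overridden to check each against its own layer count.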
    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        seq_len = getattr(self.model_tester, "seq_length", None)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
        chunk_length = getattr(self.model_tester, "chunk_length", None)
        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            language_attentions, vision_attentions, cross_encoder_attentions = (outputs[-3], outputs[-2], outputs[-1])

            self.assertEqual(len(language_attentions), self.model_tester.num_hidden_layers["language"])
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers["vision"])
            self.assertEqual(len(cross_encoder_attentions), self.model_tester.num_hidden_layers["cross_encoder"])

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            language_attentions, vision_attentions, cross_encoder_attentions = (outputs[-3], outputs[-2], outputs[-1])
            self.assertEqual(len(language_attentions), self.model_tester.num_hidden_layers["language"])
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers["vision"])
            self.assertEqual(len(cross_encoder_attentions), self.model_tester.num_hidden_layers["cross_encoder"])

            attentions = [language_attentions, vision_attentions, cross_encoder_attentions]
            attention_shapes = [
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                [
                    self.model_tester.num_attention_heads,
                    self.model_tester.num_visual_features,
                    self.model_tester.num_visual_features,
                ],
                [self.model_tester.num_attention_heads, encoder_key_length, self.model_tester.num_visual_features],
            ]

            for attention, attention_shape in zip(attentions, attention_shapes):
                self.assertListEqual(list(attention[0].shape[-3:]), attention_shape)
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            # 2 hidden states were added
            self.assertEqual(out_len + 2, len(outputs))

            language_attentions, vision_attentions, cross_encoder_attentions = (outputs[-3], outputs[-2], outputs[-1])
            self.assertEqual(len(language_attentions), self.model_tester.num_hidden_layers["language"])
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers["vision"])
            self.assertEqual(len(cross_encoder_attentions), self.model_tester.num_hidden_layers["cross_encoder"])

            attentions = [language_attentions, vision_attentions, cross_encoder_attentions]
            attention_shapes = [
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                [
                    self.model_tester.num_attention_heads,
                    self.model_tester.num_visual_features,
                    self.model_tester.num_visual_features,
                ],
                [self.model_tester.num_attention_heads, encoder_key_length, self.model_tester.num_visual_features],
            ]

            for attention, attention_shape in zip(attentions, attention_shapes):
                self.assertListEqual(list(attention[0].shape[-3:]), attention_shape)

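    # Hidden states likewise come as separate language/vision tuples, each with the
    # embedding output prepended (hence the +1 below).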
    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            language_hidden_states, vision_hidden_states = outputs[-2], outputs[-1]

            self.assertEqual(len(language_hidden_states), self.model_tester.num_hidden_layers["language"] + 1)
            self.assertEqual(len(vision_hidden_states), self.model_tester.num_hidden_layers["vision"] + 1)

            seq_length = self.model_tester.seq_length
            num_visual_features = self.model_tester.num_visual_features

            self.assertListEqual(
                list(language_hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )
            self.assertListEqual(
                list(vision_hidden_states[0].shape[-2:]),
                [num_visual_features, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

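    # Check that gradients can be retained on both modalities' hidden states and attentions.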
    def test_retain_grad_hidden_states_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = True

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        inputs = self._prepare_for_class(inputs_dict, model_class)

        outputs = model(**inputs)

        hidden_states_lang = outputs.language_hidden_states[0]
        attentions_lang = outputs.language_attentions[0]

        hidden_states_vision = outputs.vision_hidden_states[0]
        attentions_vision = outputs.vision_attentions[0]

        hidden_states_lang.retain_grad()
        attentions_lang.retain_grad()
        hidden_states_vision.retain_grad()
        attentions_vision.retain_grad()

        outputs.language_output.flatten()[0].backward(retain_graph=True)
        outputs.vision_output.flatten()[0].backward(retain_graph=True)

        self.assertIsNotNone(hidden_states_lang.grad)
        self.assertIsNotNone(attentions_lang.grad)  # was asserted on attentions_vision twice
        self.assertIsNotNone(hidden_states_vision.grad)
        self.assertIsNotNone(attentions_vision.grad)

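    # Converts PyTorch inputs to TensorFlow tensors for the PT/TF equivalence tests
    # (only relevant when TensorFlow is available).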
    def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):
        tf_inputs_dict = {}
        for key, value in pt_inputs_dict.items():
            # recurse into nested containers; was mistakenly calling prepare_pt_inputs_from_tf_inputs
            if isinstance(value, dict):
                tf_inputs_dict[key] = self.prepare_tf_inputs_from_pt_inputs(value)
            elif isinstance(value, (list, tuple)):
                tf_inputs_dict[key] = tuple(self.prepare_tf_inputs_from_pt_inputs(iter_value) for iter_value in value)
            elif isinstance(value, bool):
                tf_inputs_dict[key] = value
            elif key == "input_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            elif key == "pixel_values":
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            elif key == "input_features":
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            # other general float inputs
            elif value.is_floating_point():
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
            else:
                tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32)

        return tf_inputs_dict


@require_torch
class LxmertModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_no_head_absolute_embedding(self):
        model = LxmertModel.from_pretrained(LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST[0])
        input_ids = torch.tensor([[101, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 102]])
        num_visual_features = 10
        # seed before each draw so both random inputs are reproducible
        np.random.seed(0)
        visual_feats = np.random.rand(1, num_visual_features, model.config.visual_feat_dim)
        np.random.seed(0)
        visual_pos = np.random.rand(1, num_visual_features, 4)
        visual_feats = torch.as_tensor(visual_feats, dtype=torch.float32)
        visual_pos = torch.as_tensor(visual_pos, dtype=torch.float32)
        output = model(input_ids, visual_feats=visual_feats, visual_pos=visual_pos)[0]
        expected_shape = torch.Size([1, 11, 768])
        self.assertEqual(expected_shape, output.shape)
        expected_slice = torch.tensor(
            [[[0.2417, -0.9807, 0.1480], [1.2541, -0.8320, 0.5112], [1.4070, -1.1052, 0.6990]]]
        )

        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))