# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import tempfile
import unittest

import numpy as np

from transformers import LxmertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow

from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_tf_available():
    import tensorflow as tf

    from transformers.models.lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel


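# Builds a tiny random LXMERT config plus matching language and visual inputs,
# so the shared test mixins can exercise the TF models cheaply.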
class TFLxmertModelTester:
    def __init__(
        self,
        parent,
        vocab_size=300,
        hidden_size=28,
        num_attention_heads=2,
        num_labels=2,
        intermediate_size=64,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        num_qa_labels=30,
        num_object_labels=16,
        num_attr_labels=4,
        num_visual_features=10,
        l_layers=2,
        x_layers=1,
        r_layers=1,
        visual_feat_dim=128,
        visual_pos_dim=4,
        visual_loss_normalizer=6.67,
        seq_length=20,
        batch_size=8,
        is_training=True,
        task_matched=True,
        task_mask_lm=True,
        task_obj_predict=True,
        task_qa=True,
        visual_obj_loss=True,
        visual_attr_loss=True,
        visual_feat_loss=True,
        use_token_type_ids=True,
        use_lang_mask=True,
        output_attentions=False,
        output_hidden_states=False,
        scope=None,
    ):
        self.parent = parent
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_labels = num_labels
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.pad_token_id = pad_token_id
        self.num_qa_labels = num_qa_labels
        self.num_object_labels = num_object_labels
        self.num_attr_labels = num_attr_labels
        self.l_layers = l_layers
        self.x_layers = x_layers
        self.r_layers = r_layers
        self.visual_feat_dim = visual_feat_dim
        self.visual_pos_dim = visual_pos_dim
        self.visual_loss_normalizer = visual_loss_normalizer
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.is_training = is_training
        self.use_lang_mask = use_lang_mask
        self.task_matched = task_matched
        self.task_mask_lm = task_mask_lm
        self.task_obj_predict = task_obj_predict
        self.task_qa = task_qa
        self.visual_obj_loss = visual_obj_loss
        self.visual_attr_loss = visual_attr_loss
        self.visual_feat_loss = visual_feat_loss
        self.num_visual_features = num_visual_features
        self.use_token_type_ids = use_token_type_ids
        self.output_attentions = output_attentions
        self.output_hidden_states = output_hidden_states
        self.scope = scope
        self.num_hidden_layers = {"vision": r_layers, "cross_encoder": x_layers, "language": l_layers}

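    # Random inputs for every pretraining head; labels are only created for the
    # tasks that are switched on, mirroring how LxmertConfig gates its losses.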
    def prepare_config_and_inputs(self):
        output_attentions = self.output_attentions
        input_ids = ids_tensor([self.batch_size, self.seq_length], vocab_size=self.vocab_size)
        visual_feats = tf.random.uniform((self.batch_size, self.num_visual_features, self.visual_feat_dim))
        bounding_boxes = tf.random.uniform((self.batch_size, self.num_visual_features, 4))

        input_mask = None
        if self.use_lang_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])
        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
        obj_labels = None
        if self.task_obj_predict:
            obj_labels = {}
        if self.visual_attr_loss and self.task_obj_predict:
            obj_labels["attr"] = (
                ids_tensor([self.batch_size, self.num_visual_features], self.num_attr_labels),
                ids_tensor([self.batch_size, self.num_visual_features], self.num_attr_labels),
            )
        if self.visual_feat_loss and self.task_obj_predict:
            obj_labels["feat"] = (
                ids_tensor(
                    [self.batch_size, self.num_visual_features, self.visual_feat_dim], self.num_visual_features
                ),
                ids_tensor([self.batch_size, self.num_visual_features], self.num_visual_features),
            )
        if self.visual_obj_loss and self.task_obj_predict:
            obj_labels["obj"] = (
                ids_tensor([self.batch_size, self.num_visual_features], self.num_object_labels),
                ids_tensor([self.batch_size, self.num_visual_features], self.num_object_labels),
            )
        ans = None
        if self.task_qa:
            ans = ids_tensor([self.batch_size], self.num_qa_labels)
        masked_lm_labels = None
        if self.task_mask_lm:
            masked_lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        matched_label = None
        if self.task_matched:
            matched_label = ids_tensor([self.batch_size], self.num_labels)

        config = LxmertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_labels=self.num_labels,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            layer_norm_eps=self.layer_norm_eps,
            pad_token_id=self.pad_token_id,
            num_qa_labels=self.num_qa_labels,
            num_object_labels=self.num_object_labels,
            num_attr_labels=self.num_attr_labels,
            l_layers=self.l_layers,
            x_layers=self.x_layers,
            r_layers=self.r_layers,
            visual_feat_dim=self.visual_feat_dim,
            visual_pos_dim=self.visual_pos_dim,
            visual_loss_normalizer=self.visual_loss_normalizer,
            task_matched=self.task_matched,
            task_mask_lm=self.task_mask_lm,
            task_obj_predict=self.task_obj_predict,
            task_qa=self.task_qa,
            visual_obj_loss=self.visual_obj_loss,
            visual_attr_loss=self.visual_attr_loss,
            visual_feat_loss=self.visual_feat_loss,
            output_attentions=self.output_attentions,
            output_hidden_states=self.output_hidden_states,
        )

        return (
            config,
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids,
            input_mask,
            obj_labels,
            masked_lm_labels,
            matched_label,
            ans,
            output_attentions,
        )

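    # Runs the base model with and without attentions / return_dict and checks
    # the three output shapes (language, vision, pooled).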
    def create_and_check_lxmert_model(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):
        model = TFLxmertModel(config=config)
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            output_attentions=output_attentions,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            output_attentions=not output_attentions,
        )
        result = model(input_ids, visual_feats, bounding_boxes, return_dict=False)
        result = model(input_ids, visual_feats, bounding_boxes, return_dict=True)

        self.parent.assertEqual(result.language_output.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(
            result.vision_output.shape, (self.batch_size, self.num_visual_features, self.hidden_size)
        )
        self.parent.assertEqual(result.pooled_output.shape, (self.batch_size, self.hidden_size))

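    # The common test mixins expect (config, inputs_dict); obj_labels are only
    # attached when a caller explicitly asks for them, otherwise the object
    # prediction head is disabled so the model does not expect labels.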
    def prepare_config_and_inputs_for_common(self, return_obj_labels=False):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids,
            input_mask,
            obj_labels,
            masked_lm_labels,
            matched_label,
            ans,
            output_attentions,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "visual_feats": visual_feats,
            "visual_pos": bounding_boxes,
            "token_type_ids": token_type_ids,
            "attention_mask": input_mask,
        }

        if return_obj_labels:
            inputs_dict["obj_labels"] = obj_labels
        else:
            config.task_obj_predict = False

        return config, inputs_dict

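    # Exercises TFLxmertForPreTraining with every combination of the optional
    # label arguments, then checks the masked-LM logits shape.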
    def create_and_check_lxmert_for_pretraining(
        self,
        config,
        input_ids,
        visual_feats,
        bounding_boxes,
        token_type_ids,
        input_mask,
        obj_labels,
        masked_lm_labels,
        matched_label,
        ans,
        output_attentions,
    ):
        model = TFLxmertForPreTraining(config=config)
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            obj_labels=obj_labels,
            matched_label=matched_label,
            ans=ans,
            output_attentions=output_attentions,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            output_attentions=not output_attentions,
            return_dict=False,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            obj_labels=obj_labels,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            matched_label=matched_label,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            ans=ans,
        )
        result = model(
            input_ids,
            visual_feats,
            bounding_boxes,
            token_type_ids=token_type_ids,
            attention_mask=input_mask,
            masked_lm_labels=masked_lm_labels,
            obj_labels=obj_labels,
            matched_label=matched_label,
            ans=ans,
            output_attentions=not output_attentions,
        )

        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))


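# Hooks the LXMERT tester into the shared TF model and pipeline test suites.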
@require_tf
class TFLxmertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
    pipeline_model_mapping = {"feature-extraction": TFLxmertModel} if is_tf_available() else {}
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFLxmertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=LxmertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_lxmert_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lxmert_model(*config_and_inputs)

    def test_lxmert_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lxmert_for_pretraining(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in ["unc-nlp/lxmert-base-uncased"]:
            model = TFLxmertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

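    # Overrides the common attention test: LXMERT returns three attention sets
    # (language, vision, cross-encoder) instead of a single tuple, so they are
    # read from the last three output slots and checked per encoder.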
    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        encoder_seq_length = (
            self.model_tester.encoder_seq_length
            if hasattr(self.model_tester, "encoder_seq_length")
            else self.model_tester.seq_length
        )
        encoder_key_length = (
            self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
        )

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            model = model_class(config)
            outputs = model(self._prepare_for_class(inputs_dict, model_class))
            language_attentions, vision_attentions, cross_encoder_attentions = (outputs[-3], outputs[-2], outputs[-1])

            self.assertEqual(model.config.output_hidden_states, False)

            self.assertEqual(len(language_attentions), self.model_tester.num_hidden_layers["language"])
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers["vision"])
            self.assertEqual(len(cross_encoder_attentions), self.model_tester.num_hidden_layers["cross_encoder"])

            attentions = [language_attentions, vision_attentions, cross_encoder_attentions]
            attention_shapes = [
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                [
                    self.model_tester.num_attention_heads,
                    self.model_tester.num_visual_features,
                    self.model_tester.num_visual_features,
                ],
                [self.model_tester.num_attention_heads, encoder_key_length, self.model_tester.num_visual_features],
            ]

            for attention, attention_shape in zip(attentions, attention_shapes):
                self.assertListEqual(list(attention[0].shape[-3:]), attention_shape)
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            outputs = model(self._prepare_for_class(inputs_dict, model_class))

            # 2 hidden states were added
            self.assertEqual(out_len + 2, len(outputs))
            language_attentions, vision_attentions, cross_encoder_attentions = (outputs[-3], outputs[-2], outputs[-1])
            self.assertEqual(len(language_attentions), self.model_tester.num_hidden_layers["language"])
            self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers["vision"])
            self.assertEqual(len(cross_encoder_attentions), self.model_tester.num_hidden_layers["cross_encoder"])

            attentions = [language_attentions, vision_attentions, cross_encoder_attentions]
            attention_shapes = [
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
                [
                    self.model_tester.num_attention_heads,
                    self.model_tester.num_visual_features,
                    self.model_tester.num_visual_features,
                ],
                [self.model_tester.num_attention_heads, encoder_key_length, self.model_tester.num_visual_features],
            ]

            for attention, attention_shape in zip(attentions, attention_shapes):
                self.assertListEqual(list(attention[0].shape[-3:]), attention_shape)

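    # Same idea for hidden states: LXMERT exposes separate language and vision
    # hidden-state tuples with num_layers + 1 entries each (the extra entry is
    # the embedding output, as elsewhere in the library).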
    def test_hidden_states_output(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def check_hidden_states_output(config, inputs_dict, model_class):
            model = model_class(config)
            outputs = model(self._prepare_for_class(inputs_dict, model_class))
            language_hidden_states, vision_hidden_states = outputs[-2], outputs[-1]

            self.assertEqual(len(language_hidden_states), self.model_tester.num_hidden_layers["language"] + 1)
            self.assertEqual(len(vision_hidden_states), self.model_tester.num_hidden_layers["vision"] + 1)

            seq_length = self.model_tester.seq_length
            num_visual_features = self.model_tester.num_visual_features

            self.assertListEqual(
                list(language_hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )
            self.assertListEqual(
                list(vision_hidden_states[0].shape[-2:]),
                [num_visual_features, self.model_tester.hidden_size],
            )

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(config, inputs_dict, model_class)

            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True
            check_hidden_states_output(config, inputs_dict, model_class)

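    # Converts a TF inputs dict to torch tensors for PT/TF equivalence checks:
    # floats become float32 torch tensors, bools pass through unchanged, and
    # everything else (token ids, masks) becomes int64.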
    def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
        import torch

        pt_inputs_dict = {}
        for key, value in tf_inputs_dict.items():
            if isinstance(value, dict):
                pt_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
            elif isinstance(value, (list, tuple)):
                pt_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value)
            elif isinstance(value, bool):
                pt_inputs_dict[key] = value
            elif key in ("input_values", "pixel_values", "input_features"):
                pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.float32)
            # other general float inputs
            elif value.dtype.is_floating:
                pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.float32)
            else:
                pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.long)

        return pt_inputs_dict

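    # Round-trips each model through save_pretrained / from_pretrained and
    # asserts the reloaded model produces identical outputs.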
    def test_save_load(self):
        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
                return_obj_labels="PreTraining" in model_class.__name__
            )

            model = model_class(config)
            outputs = model(self._prepare_for_class(inputs_dict, model_class))

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname)
                after_outputs = model(self._prepare_for_class(inputs_dict, model_class))

                self.assert_outputs_same(after_outputs, outputs)


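# Slow integration check against the released checkpoint: the visual inputs
# are drawn from a seeded NumPy RNG so the expected output slice stays
# reproducible across runs.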
@require_tf
class TFLxmertModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_masked_lm(self):
        model = TFLxmertModel.from_pretrained("unc-nlp/lxmert-base-uncased")
        input_ids = tf.constant([[101, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 102]])

        num_visual_features = 10
        # Re-seed before each draw so both tensors come from a fresh generator,
        # matching the original (tuple-assignment) behavior.
        np.random.seed(0)
        visual_feats = np.random.rand(1, num_visual_features, model.config.visual_feat_dim)
        np.random.seed(0)
        visual_pos = np.random.rand(1, num_visual_features, 4)
        visual_feats = tf.convert_to_tensor(visual_feats, dtype=tf.float32)
        visual_pos = tf.convert_to_tensor(visual_pos, dtype=tf.float32)
        output = model(input_ids, visual_feats=visual_feats, visual_pos=visual_pos)[0]
        expected_shape = [1, 11, 768]
        self.assertEqual(expected_shape, output.shape)
        expected_slice = tf.constant(
            [
                [
                    [0.24170142, -0.98075, 0.14797261],
                    [1.2540525, -0.83198136, 0.5112344],
                    [1.4070463, -1.1051831, 0.6990401],
                ]
            ]
        )
        tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)