# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest

from transformers import BertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import CaptureLogger, require_torch, require_torch_accelerator, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        MODEL_FOR_PRETRAINING_MAPPING,
        BertForMaskedLM,
        BertForMultipleChoice,
        BertForNextSentencePrediction,
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BertForTokenClassification,
        BertLMHeadModel,
        BertModel,
        logging,
    )
    from transformers.models.bert.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST


class BertModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope
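    # NOTE: the dimensions above (hidden_size=32, num_hidden_layers=2, ...) are
    # deliberately tiny so that every check below builds and runs a model quickly;
    # the helpers that follow synthesize matching random inputs and labels.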
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """
        Returns a tiny configuration by default.
        """
        return BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )
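    # The decoder variant reuses the encoder inputs and adds random
    # encoder_hidden_states / encoder_attention_mask so that cross-attention
    # paths can be exercised as well.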
    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = BertModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_model_for_causal_lm_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
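    # Cache-consistency check: running the full sequence in one pass must match
    # running only the new tokens with `past_key_values` carried over from the
    # prefix, position by position, for the new tokens.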
    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.is_decoder = True
        config.add_cross_attention = True
        model = BertLMHeadModel(config=config).to(torch_device).eval()

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create a few hypothetical next tokens to extend the sequence with
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append them to the input ids and attention mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select a random slice of the hidden dimension
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that the outputs are equal on the slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
    def create_and_check_for_next_sequence_prediction(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForNextSentencePrediction(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=sequence_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))

    def create_and_check_for_pretraining(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForPreTraining(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            next_sentence_label=sequence_labels,
        )
        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = BertForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = BertForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
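    # For multiple choice, every input is tiled to (batch_size, num_choices,
    # seq_length); all choices share the same tokens here, which is enough to
    # verify the output shape.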
    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_choices = self.num_choices
        model = BertForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
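    # Entry point used by the shared ModelTesterMixin tests; a rough sketch of
    # the protocol they follow:
    #     tester = BertModelTester(self)
    #     config, inputs_dict = tester.prepare_config_and_inputs_for_common()
    #     model = SomeBertClass(config); model(**inputs_dict)
    # Labels are deliberately left out of inputs_dict: the mixin adds them per
    # model class through `_prepare_for_class` (overridden below).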
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            BertModel,
            BertLMHeadModel,
            BertForMaskedLM,
            BertForMultipleChoice,
            BertForNextSentencePrediction,
            BertForPreTraining,
            BertForQuestionAnswering,
            BertForSequenceClassification,
            BertForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": BertModel,
            "fill-mask": BertForMaskedLM,
            "question-answering": BertForQuestionAnswering,
            "text-classification": BertForSequenceClassification,
            "text-generation": BertLMHeadModel,
            "token-classification": BertForTokenClassification,
            "zero-shot": BertForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    fx_compatible = True

    # special case for the ForPreTraining model: it needs both masked-lm `labels`
    # and a `next_sentence_label`, so add dummies for both here
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
                inputs_dict["next_sentence_label"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
        return inputs_dict

    def setUp(self):
        self.model_tester = BertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()
    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)
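    # BERT supports the learned "absolute" position embeddings plus the
    # "relative_key" / "relative_key_query" relative-position variants; the
    # model check is repeated under each scheme.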
    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_as_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

    def test_model_as_decoder_with_default_input_mask(self):
        # This regression test was failing with PyTorch < 1.3
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        ) = self.model_tester.prepare_config_and_inputs_for_decoder()

        input_mask = None

        self.model_tester.create_and_check_model_as_decoder(
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def test_for_causal_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_causal_lm_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs)

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        config_and_inputs[0].position_embedding_type = "relative_key"
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

    def test_for_next_sequence_prediction(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_next_sequence_prediction(*config_and_inputs)

    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_pretraining(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
    def test_for_warning_if_padding_and_no_attention_mask(self):
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.model_tester.prepare_config_and_inputs()

        # Set pad tokens in the input_ids
        input_ids[0, 0] = config.pad_token_id

        # Check for warnings if the attention_mask is missing.
        logger = logging.get_logger("transformers.modeling_utils")
        # clear cache so we can test the warning is emitted (from `warning_once`).
        logger.warning_once.cache_clear()

        with CaptureLogger(logger) as cl:
            model = BertModel(config=config)
            model.to(torch_device)
            model.eval()
            model(input_ids, attention_mask=None, token_type_ids=token_type_ids)
        self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)
    @slow
    def test_model_from_pretrained(self):
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = BertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
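    # TorchScript round-trip: trace each model on CPU, save it, then reload it
    # directly onto the accelerator via `map_location` and run it there.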
    @slow
    @require_torch_accelerator
    def test_torchscript_device_change(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            # BertForMultipleChoice behaves incorrectly in JIT environments.
            # `continue` (rather than `return`) so the remaining classes still run.
            if model_class == BertForMultipleChoice:
                continue

            config.torchscript = True
            model = model_class(config=config)

            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            traced_model = torch.jit.trace(
                model, (inputs_dict["input_ids"].to("cpu"), inputs_dict["attention_mask"].to("cpu"))
            )

            with tempfile.TemporaryDirectory() as tmp:
                torch.jit.save(traced_model, os.path.join(tmp, "bert.pt"))
                loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
                loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
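

# The integration tests below load real pretrained checkpoints and compare a
# small slice of the output against precomputed reference values; they only run
# when slow tests are enabled.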
@require_torch
class BertModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_no_head_absolute_embedding(self):
        model = BertModel.from_pretrained("google-bert/bert-base-uncased")
        input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
        attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        with torch.no_grad():
            output = model(input_ids, attention_mask=attention_mask)[0]
        expected_shape = torch.Size((1, 11, 768))
        self.assertEqual(output.shape, expected_shape)
        expected_slice = torch.tensor([[[0.4249, 0.1008, 0.7531], [0.3771, 0.1188, 0.7467], [0.4152, 0.1098, 0.7108]]])

        self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))

    @slow
    def test_inference_no_head_relative_embedding_key(self):
        model = BertModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key")
        input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
        attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        with torch.no_grad():
            output = model(input_ids, attention_mask=attention_mask)[0]
        expected_shape = torch.Size((1, 11, 768))
        self.assertEqual(output.shape, expected_shape)
        expected_slice = torch.tensor(
            [[[0.0756, 0.3142, -0.5128], [0.3761, 0.3462, -0.5477], [0.2052, 0.3760, -0.1240]]]
        )

        self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))

    @slow
    def test_inference_no_head_relative_embedding_key_query(self):
        model = BertModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key-query")
        input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
        attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        with torch.no_grad():
            output = model(input_ids, attention_mask=attention_mask)[0]
        expected_shape = torch.Size((1, 11, 768))
        self.assertEqual(output.shape, expected_shape)
        expected_slice = torch.tensor(
            [[[0.6496, 0.3784, 0.8203], [0.8148, 0.5656, 0.2636], [-0.0681, 0.5597, 0.7045]]]
        )

        self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))