transformers
345 lines · 14.1 KB
1# coding=utf-8
2# Copyright 2020 The HuggingFace Team. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16
17import unittest18
19from transformers import AlbertConfig, is_torch_available20from transformers.models.auto import get_values21from transformers.testing_utils import require_torch, slow, torch_device22
23from ...test_configuration_common import ConfigTester24from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask25from ...test_pipeline_mixin import PipelineTesterMixin26
27
28if is_torch_available():29import torch30
31from transformers import (32MODEL_FOR_PRETRAINING_MAPPING,33AlbertForMaskedLM,34AlbertForMultipleChoice,35AlbertForPreTraining,36AlbertForQuestionAnswering,37AlbertForSequenceClassification,38AlbertForTokenClassification,39AlbertModel,40)41from transformers.models.albert.modeling_albert import ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST42
43
class AlbertModelTester:
    """Builds tiny ALBERT configs/inputs and runs shape checks for every ALBERT head.

    All sizes below are deliberately small so the checks run fast on CPU. The
    ``parent`` argument is the owning ``unittest.TestCase``; its assertion
    methods are used to report failures.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        embedding_size=16,
        hidden_size=36,
        num_hidden_layers=2,
        # this needs to be the same as `num_hidden_layers`!
        num_hidden_groups=2,
        num_attention_heads=6,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        # Store every hyperparameter verbatim; they are read back by
        # prepare_config_and_inputs()/get_config() and by the test class.
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_hidden_groups = num_hidden_groups
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Return a config plus random input/label tensors sized from the tester fields.

        Returns a 7-tuple:
        ``(config, input_ids, token_type_ids, input_mask, sequence_labels,
        token_labels, choice_labels)`` — the optional entries are ``None``
        when the corresponding ``use_*`` flag is False.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """Build an ``AlbertConfig`` from the tester's (tiny) hyperparameters."""
        return AlbertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            num_hidden_groups=self.num_hidden_groups,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run the bare ``AlbertModel`` with three input combinations and check output shapes."""
        model = AlbertModel(config=config)
        model.to(torch_device)
        model.eval()
        # Call with decreasing amounts of optional inputs; only the last
        # result's shapes are asserted (earlier calls just must not raise).
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_for_pretraining(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check MLM prediction logits and sentence-order (SOP) logits shapes."""
        model = AlbertForPreTraining(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            sentence_order_label=sequence_labels,
        )
        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertEqual(result.sop_logits.shape, (self.batch_size, config.num_labels))

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check the masked-LM head returns per-token vocabulary logits."""
        model = AlbertForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check the QA head returns per-token start/end logits."""
        model = AlbertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check the sequence-classification head returns one logit row per example."""
        config.num_labels = self.num_labels
        model = AlbertForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check the token-classification head returns per-token label logits."""
        config.num_labels = self.num_labels
        model = AlbertForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check the multiple-choice head on inputs tiled to (batch, num_choices, seq)."""
        config.num_choices = self.num_choices
        model = AlbertForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
        # Replicate each example across the choice dimension.
        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def prepare_config_and_inputs_for_common(self):
        """Adapter for ModelTesterMixin: return ``(config, inputs_dict)`` without labels."""
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict
242
@require_torch
class AlbertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common test-suite entry point for ALBERT.

    Wires ``AlbertModelTester`` and ``ConfigTester`` into the shared
    ``ModelTesterMixin``/``PipelineTesterMixin`` machinery and adds one
    per-head smoke test each.
    """

    all_model_classes = (
        (
            AlbertModel,
            AlbertForPreTraining,
            AlbertForMaskedLM,
            AlbertForMultipleChoice,
            AlbertForSequenceClassification,
            AlbertForTokenClassification,
            AlbertForQuestionAnswering,
        )
        if is_torch_available()
        else ()
    )
    pipeline_model_mapping = (
        {
            "feature-extraction": AlbertModel,
            "fill-mask": AlbertForMaskedLM,
            "question-answering": AlbertForQuestionAnswering,
            "text-classification": AlbertForSequenceClassification,
            "token-classification": AlbertForTokenClassification,
            "zero-shot": AlbertForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    fx_compatible = True

    # special case for ForPreTraining model: besides the common inputs it
    # also needs `labels` (MLM) and `sentence_order_label` (SOP) tensors.
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
                inputs_dict["sentence_order_label"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
        return inputs_dict

    def setUp(self):
        self.model_tester = AlbertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_pretraining(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # FIX: loop variable was named `type`, shadowing the builtin;
        # renamed to `embedding_type` (purely local, behavior unchanged).
        for embedding_type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = embedding_type
            self.model_tester.create_and_check_model(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        # Only the first archive entry is exercised to keep the slow test bounded.
        for model_name in ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = AlbertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
329
@require_torch
class AlbertModelIntegrationTest(unittest.TestCase):
    """Slow integration test running a real pretrained ALBERT checkpoint end to end."""

    @slow
    def test_inference_no_head_absolute_embedding(self):
        # Downloads/loads the public "albert/albert-base-v2" checkpoint
        # (network and disk access — hence @slow).
        model = AlbertModel.from_pretrained("albert/albert-base-v2")
        # One hand-picked 11-token sequence; the first mask position is 0,
        # i.e. the leading token is deliberately masked out of attention.
        input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
        attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        with torch.no_grad():
            output = model(input_ids, attention_mask=attention_mask)[0]
        # albert-base-v2 hidden size is 768, so the last hidden state is (1, 11, 768).
        expected_shape = torch.Size((1, 11, 768))
        self.assertEqual(output.shape, expected_shape)
        # Golden values recorded from a known-good run; compared on a
        # small 3x3 slice with a loose tolerance to absorb numeric noise.
        expected_slice = torch.tensor(
            [[[-0.6513, 1.5035, -0.2766], [-0.6515, 1.5046, -0.2780], [-0.6512, 1.5049, -0.2784]]]
        )

        self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))