# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

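"""Tests for the Flax DistilBERT models (FlaxDistilBertModel and its task-specific heads)."""
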
import unittest

import numpy as np

from transformers import DistilBertConfig, is_flax_available
from transformers.testing_utils import require_flax, slow

from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask

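# The Flax-specific imports live behind the `is_flax_available()` guard so that
# this module can still be collected by the test runner when Flax/JAX is not
# installed; `require_flax` then skips the test classes themselves.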
if is_flax_available():
    import jax.numpy as jnp

    from transformers.models.distilbert.modeling_flax_distilbert import (
        FlaxDistilBertForMaskedLM,
        FlaxDistilBertForMultipleChoice,
        FlaxDistilBertForQuestionAnswering,
        FlaxDistilBertForSequenceClassification,
        FlaxDistilBertForTokenClassification,
        FlaxDistilBertModel,
    )

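# Helper that builds a deliberately tiny random DistilBERT config plus dummy
# inputs (2 layers, hidden size 32, vocab of 99) so the shared model tests run
# quickly while still exercising the full forward pass.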
class FlaxDistilBertModelTester(unittest.TestCase):
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_attention_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_choices=4,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_choices = num_choices

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        attention_mask = None
        if self.use_attention_mask:
            attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        # DistilBertConfig uses its own argument names: dim, n_layers, n_heads
        # and hidden_dim correspond to the tester's hidden_size,
        # num_hidden_layers, num_attention_heads and intermediate_size.
        config = DistilBertConfig(
            vocab_size=self.vocab_size,
            dim=self.hidden_size,
            n_layers=self.num_hidden_layers,
            n_heads=self.num_attention_heads,
            hidden_dim=self.intermediate_size,
            hidden_act=self.hidden_act,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            tie_weights_=True,
        )

        return config, input_ids, attention_mask

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
        return config, inputs_dict

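# Runs the generic Flax model test suite provided by FlaxModelTesterMixin
# against every DistilBERT model class listed in `all_model_classes`.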
@require_flax
class FlaxDistilBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            FlaxDistilBertModel,
            FlaxDistilBertForMaskedLM,
            FlaxDistilBertForMultipleChoice,
            FlaxDistilBertForQuestionAnswering,
            FlaxDistilBertForSequenceClassification,
            FlaxDistilBertForTokenClassification,
        )
        if is_flax_available()
        else ()
    )

    def setUp(self):
        self.model_tester = FlaxDistilBertModelTester(self)

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
            model = model_class_name.from_pretrained("distilbert-base-uncased")
            outputs = model(np.ones((1, 1)))
            self.assertIsNotNone(outputs)

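# Integration test against the published "distilbert-base-uncased" checkpoint:
# the expected slice below pins a few reference hidden-state values so that
# regressions in the Flax port are caught to within an absolute tolerance of 1e-4.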
@require_flax
class FlaxDistilBertModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_no_head_absolute_embedding(self):
        model = FlaxDistilBertModel.from_pretrained("distilbert-base-uncased")
        input_ids = np.array([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
        attention_mask = np.array([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        output = model(input_ids, attention_mask=attention_mask)[0]
        expected_shape = (1, 11, 768)
        self.assertEqual(output.shape, expected_shape)
        expected_slice = np.array([[[-0.1639, 0.3299, 0.1648], [-0.1746, 0.3289, 0.1710], [-0.1884, 0.3357, 0.1810]]])

        self.assertTrue(jnp.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))