# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers import BertConfig, is_flax_available
from transformers.testing_utils import require_flax, slow

from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


if is_flax_available():
    from transformers.models.bert.modeling_flax_bert import (
        FlaxBertForMaskedLM,
        FlaxBertForMultipleChoice,
        FlaxBertForNextSentencePrediction,
        FlaxBertForPreTraining,
        FlaxBertForQuestionAnswering,
        FlaxBertForSequenceClassification,
        FlaxBertForTokenClassification,
        FlaxBertModel,
    )


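# Lightweight fixture class: builds a deliberately small random BertConfig and
# matching dummy inputs (input_ids, token_type_ids, attention_mask) that the
# shared Flax model tests below consume.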
class FlaxBertModelTester(unittest.TestCase):
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_attention_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_choices=4,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_choices = num_choices

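    # Draws random input_ids (plus an optional attention mask and token type
    # ids) and pairs them with a tiny BertConfig so the tests stay fast.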
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        attention_mask = None
        if self.use_attention_mask:
            attention_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        config = BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

        return config, input_ids, token_type_ids, attention_mask

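    # Repackages the random inputs into the keyword-argument dict expected by
    # the shared FlaxModelTesterMixin tests.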
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, token_type_ids, attention_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
        return config, inputs_dict

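    # Variant for decoder/cross-attention tests: marks the config as a decoder
    # and adds random encoder hidden states plus an encoder attention mask.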
    def prepare_config_and_inputs_for_decoder(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, token_type_ids, attention_mask = config_and_inputs

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
        )


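# Runs the shared Flax model test suite against every BERT head listed in
# all_model_classes, and adds a slow integration check that loads pretrained
# weights from the Hub.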
@require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
    test_head_masking = True

    all_model_classes = (
        (
            FlaxBertModel,
            FlaxBertForPreTraining,
            FlaxBertForMaskedLM,
            FlaxBertForMultipleChoice,
            FlaxBertForQuestionAnswering,
            FlaxBertForNextSentencePrediction,
            FlaxBertForSequenceClassification,
            FlaxBertForTokenClassification,
        )
        if is_flax_available()
        else ()
    )

    def setUp(self):
        self.model_tester = FlaxBertModelTester(self)

    @slow
    def test_model_from_pretrained(self):
        # Only check this for the base model; it is not necessary for all model classes.
        # This also helps speed up the tests.
        model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
        outputs = model(np.ones((1, 1)))
        self.assertIsNotNone(outputs)
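
# A typical way to run just this module locally (the path is an assumption,
# inferred from the relative import of test_modeling_flax_common above):
#
#   python -m pytest tests/models/bert/test_modeling_flax_bert.py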