transformers

Форк
0
/
test_modeling_tf_rembert.py 
728 строк · 27.4 Кб
1
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16

17
from __future__ import annotations
18

19
import unittest
20

21
from transformers import RemBertConfig, is_tf_available
22
from transformers.testing_utils import require_tf, slow
23

24
from ...test_configuration_common import ConfigTester
25
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
26
from ...test_pipeline_mixin import PipelineTesterMixin
27

28

29
if is_tf_available():
30
    import tensorflow as tf
31

32
    from transformers import (
33
        TFRemBertForCausalLM,
34
        TFRemBertForMaskedLM,
35
        TFRemBertForMultipleChoice,
36
        TFRemBertForQuestionAnswering,
37
        TFRemBertForSequenceClassification,
38
        TFRemBertForTokenClassification,
39
        TFRemBertModel,
40
    )
41

42

43
class TFRemBertModelTester:
    """Build tiny random RemBERT configs/inputs and run shape and cache-consistency checks.

    Every ``create_and_check_*`` method instantiates one TF RemBERT model class from a
    config produced by ``prepare_config_and_inputs`` (or its decoder variant) and
    reports failures through the assertion methods of ``parent``.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        input_embedding_size=18,
        output_embedding_size=43,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        """Store the test parent and the sizes used to build configs and random inputs.

        Args:
            parent: test case whose ``assert*`` methods are used for every check.
            The remaining arguments are stored as same-named attributes and read by
            ``prepare_config_and_inputs`` and the ``create_and_check_*`` helpers.
        """
        # Honour every constructor argument (the defaults are unchanged), so callers can
        # actually resize the test model instead of silently getting hard-coded values.
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.input_embedding_size = input_embedding_size
        self.output_embedding_size = output_embedding_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Return a random ``RemBertConfig`` together with random input and label tensors.

        Returns:
            ``(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels,
            choice_labels)``. ``input_mask`` / ``token_type_ids`` are ``None`` when the matching
            ``use_*`` flag is off, and the three label tensors are ``None`` unless ``use_labels``.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = RemBertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            input_embedding_size=self.input_embedding_size,
            output_embedding_size=self.output_embedding_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            return_dict=True,
        )

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def prepare_config_and_inputs_for_decoder(self):
        """Like ``prepare_config_and_inputs`` but with ``config.is_decoder = True``.

        Additionally returns random ``encoder_hidden_states`` of shape
        ``(batch_size, seq_length, hidden_size)`` and a random 0/1 ``encoder_attention_mask``.
        """
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Call ``TFRemBertModel`` with dict, list and bare-tensor inputs and check the output shape."""
        model = TFRemBertModel(config=config)

        # Keyword-dict inputs (this dict used to be built and then discarded without a call).
        inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
        result = model(inputs)

        # Positional list inputs.
        inputs = [input_ids, input_mask]
        result = model(inputs)

        # A single tensor.
        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_causal_lm_base_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Same as ``create_and_check_model`` but with ``config.is_decoder = True``."""
        config.is_decoder = True

        model = TFRemBertModel(config=config)
        inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
        result = model(inputs)

        inputs = [input_ids, input_mask]
        result = model(inputs)

        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run ``TFRemBertModel`` with cross-attention enabled, with and without encoder outputs."""
        config.add_cross_attention = True

        model = TFRemBertModel(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
        }
        result = model(inputs)

        inputs = [input_ids, input_mask]
        result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)

        # Also check the case where encoder outputs are not passed
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_causal_lm_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that ``TFRemBertForCausalLM`` produces ``(batch, seq, vocab)`` logits."""
        config.is_decoder = True
        model = TFRemBertForCausalLM(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
        }
        prediction_scores = model(inputs)["logits"]
        self.parent.assertListEqual(
            list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
        )

    def create_and_check_causal_lm_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Check ``TFRemBertForCausalLM`` logits shape with cross-attention and encoder outputs."""
        config.add_cross_attention = True

        model = TFRemBertForCausalLM(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
        }
        result = model(inputs)

        inputs = [input_ids, input_mask]
        result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)

        prediction_scores = result["logits"]
        self.parent.assertListEqual(
            list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
        )

    def create_and_check_causal_lm_model_past(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        """Check that feeding one extra token with ``past_key_values`` matches a full re-run.

        Compares a random feature of the first returned hidden state at the new position.
        """
        config.is_decoder = True

        model = TFRemBertForCausalLM(config=config)

        # first forward pass
        outputs = model(input_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids)
        outputs_no_past = model(input_ids, use_cache=False)

        # With the cache enabled the output carries exactly one extra field (past_key_values).
        self.parent.assertEqual(len(outputs), len(outputs_use_cache_conf))
        self.parent.assertEqual(len(outputs), len(outputs_no_past) + 1)

        past_key_values = outputs.past_key_values

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # append to next input_ids
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)

        output_from_no_past = model(next_input_ids, output_hidden_states=True).hidden_states[0]
        output_from_past = model(
            next_tokens, past_key_values=past_key_values, output_hidden_states=True
        ).hidden_states[0]

        # select random slice
        random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
        output_from_past_slice = output_from_past[:, 0, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)

    def create_and_check_causal_lm_model_past_with_attn_mask(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        """Check cached decoding when the second half of the prompt is masked out.

        A token inside the masked half is replaced before the no-cache re-run; because it
        is masked, the result must still match the cached run.
        """
        config.is_decoder = True

        model = TFRemBertForCausalLM(config=config)

        # create attention mask: first half attended (1), second half masked (0)
        half_seq_length = self.seq_length // 2
        attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
        attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
        attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)

        # first forward pass
        outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        past_key_values = outputs.past_key_values

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
        vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
        condition = tf.transpose(
            tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size))
        )
        input_ids = tf.where(condition, random_other_next_tokens, input_ids)

        # append to next input_ids and attention mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        attn_mask = tf.concat(
            [attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)],
            axis=1,
        )

        output_from_no_past = model(
            next_input_ids,
            attention_mask=attn_mask,
            output_hidden_states=True,
        ).hidden_states[0]
        output_from_past = model(
            next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_hidden_states=True
        ).hidden_states[0]

        # select random slice
        random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
        output_from_past_slice = output_from_past[:, 0, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)

    def create_and_check_causal_lm_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        """Check cached decoding of three new tokens (with a random mask) on a single sequence."""
        config.is_decoder = True

        model = TFRemBertForCausalLM(config=config)

        # Use a single sequence; keep the size local instead of mutating self.batch_size.
        batch_size = 1
        input_ids = input_ids[:batch_size, :]
        input_mask = input_mask[:batch_size, :]

        # first forward pass
        outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
        past_key_values = outputs.past_key_values

        # create hypothetical next tokens and extend to next_input_ids
        next_tokens = ids_tensor((batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((batch_size, 3), 2)

        # append to next input_ids and attention mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            output_hidden_states=True,
        ).hidden_states[0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        ).hidden_states[0]

        self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])

        # select random slice
        random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Same as ``create_and_check_causal_lm_model_past_large_inputs`` but with cross-attention."""
        config.add_cross_attention = True

        model = TFRemBertForCausalLM(config=config)

        # Use a single sequence; keep the size local instead of mutating self.batch_size.
        batch_size = 1
        input_ids = input_ids[:batch_size, :]
        input_mask = input_mask[:batch_size, :]
        encoder_hidden_states = encoder_hidden_states[:batch_size, :, :]
        encoder_attention_mask = encoder_attention_mask[:batch_size, :]

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical next tokens and extend to next_input_ids
        next_tokens = ids_tensor((batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((batch_size, 3), 2)

        # append to next input_ids and attention mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        ).hidden_states[0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        ).hidden_states[0]

        self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])

        # select random slice
        random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that ``TFRemBertForMaskedLM`` produces ``(batch, seq, vocab)`` logits."""
        model = TFRemBertForMaskedLM(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
        }
        result = model(inputs)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that ``TFRemBertForSequenceClassification`` produces ``(batch, num_labels)`` logits."""
        config.num_labels = self.num_labels
        model = TFRemBertForSequenceClassification(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
        }

        result = model(inputs)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Tile inputs across ``num_choices`` and check ``(batch, num_choices)`` logits."""
        config.num_choices = self.num_choices
        model = TFRemBertForMultipleChoice(config=config)
        multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
        multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
        multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
        inputs = {
            "input_ids": multiple_choice_inputs_ids,
            "attention_mask": multiple_choice_input_mask,
            "token_type_ids": multiple_choice_token_type_ids,
        }
        result = model(inputs)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that ``TFRemBertForTokenClassification`` produces ``(batch, seq, num_labels)`` logits."""
        config.num_labels = self.num_labels
        model = TFRemBertForTokenClassification(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
        }
        result = model(inputs)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that ``TFRemBertForQuestionAnswering`` returns ``(batch, seq)`` start/end logits."""
        model = TFRemBertForQuestionAnswering(config=config)
        inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "token_type_ids": token_type_ids,
        }

        result = model(inputs)
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` with ``input_ids``, ``token_type_ids`` and ``attention_mask``."""
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict
573

574

575
@require_tf
class TFRemBertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Shared TF model/pipeline tests plus RemBERT-specific head and decoder-cache checks."""

    all_model_classes = (
        (
            TFRemBertModel,
            TFRemBertForCausalLM,
            TFRemBertForMaskedLM,
            TFRemBertForQuestionAnswering,
            TFRemBertForSequenceClassification,
            TFRemBertForTokenClassification,
            TFRemBertForMultipleChoice,
        )
        if is_tf_available()
        else ()
    )
    pipeline_model_mapping = (
        {
            "feature-extraction": TFRemBertModel,
            "fill-mask": TFRemBertForMaskedLM,
            "question-answering": TFRemBertForQuestionAnswering,
            "text-classification": TFRemBertForSequenceClassification,
            "text-generation": TFRemBertForCausalLM,
            "token-classification": TFRemBertForTokenClassification,
            "zero-shot": TFRemBertForSequenceClassification,
        }
        if is_tf_available()
        else {}
    )

    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFRemBertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=RemBertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        """Test the base model"""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_causal_lm_base_model(self):
        """Test the base model of the causal LM model

        is_decoder=True, no cross_attention, no encoder outputs
        """
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_causal_lm_base_model(*config_and_inputs)

    def test_model_as_decoder(self):
        """Test the base model as a decoder (of an encoder-decoder architecture)

        is_decoder=True + cross_attention + pass encoder outputs
        """
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_causal_lm(self):
        """Test the causal LM model"""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_causal_lm_model(*config_and_inputs)

    def test_causal_lm_model_as_decoder(self):
        """Test the causal LM model as a decoder"""
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_causal_lm_model_as_decoder(*config_and_inputs)

    def test_causal_lm_model_past(self):
        """Test causal LM model with `past_key_values`"""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_causal_lm_model_past(*config_and_inputs)

    def test_causal_lm_model_past_with_attn_mask(self):
        """Test the causal LM model with `past_key_values` and `attention_mask`"""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_causal_lm_model_past_with_attn_mask(*config_and_inputs)

    def test_causal_lm_model_past_with_large_inputs(self):
        """Test the causal LM model with `past_key_values` and a longer decoder sequence length"""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_causal_lm_model_past_large_inputs(*config_and_inputs)

    def test_decoder_model_past_with_large_inputs(self):
        """Similar to `test_causal_lm_model_past_with_large_inputs` but with cross-attention"""
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        model = TFRemBertModel.from_pretrained("google/rembert")
        self.assertIsNotNone(model)
689

690

691
@require_tf
class TFRemBertModelIntegrationTest(unittest.TestCase):
    """Slow checks of the pretrained ``google/rembert`` checkpoint against stored activations."""

    @slow
    def test_inference_model(self):
        """Run one short sequence through the base model and compare a slice of its last hidden state."""
        rembert = TFRemBertModel.from_pretrained("google/rembert")

        token_ids = tf.constant([[312, 56498, 313, 2125, 313]])
        type_ids = tf.constant([[0, 0, 0, 1, 1]])
        outputs = rembert(token_ids, token_type_ids=type_ids, output_hidden_states=True)
        last_hidden = outputs["last_hidden_state"]

        # One sequence of five tokens, each with a 1152-dimensional hidden state.
        hidden_dim = 1152
        self.assertEqual(last_hidden.shape, [1, 5, hidden_dim])

        # Reference values for the first three features of every token.
        expected_slice = tf.constant(
            [
                [
                    [0.0754, -0.2022, 0.1904],
                    [-0.3354, -0.3692, -0.4791],
                    [-0.2314, -0.6729, -0.0749],
                    [-0.0396, -0.3105, -0.4234],
                    [-0.1571, -0.0525, 0.5353],
                ]
            ]
        )
        tf.debugging.assert_near(last_hidden[:, :, :3], expected_slice, atol=1e-4)

        # The original TF implementation produces slightly different values for the same input.
        # The cause of this variation is not yet understood.
        # TODO: Find the reason for the discrepancy.
        # expected_original_implementation = [[
        #     [0.07630594074726105, -0.20146065950393677, 0.19107051193714142],
        #     [-0.3405614495277405, -0.36971670389175415, -0.4808273911476135],
        #     [-0.22587086260318756, -0.6656315922737122, -0.07844287157058716],
        #     [-0.04145475849509239, -0.3077218234539032, -0.42316967248916626],
        #     [-0.15887849032878876, -0.054529931396245956, 0.5356100797653198]
        # ]]
729

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.