test_modeling_bert.py
# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest

from transformers import BertConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import CaptureLogger, require_torch, require_torch_accelerator, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        MODEL_FOR_PRETRAINING_MAPPING,
        BertForMaskedLM,
        BertForMultipleChoice,
        BertForNextSentencePrediction,
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BertForTokenClassification,
        BertLMHeadModel,
        BertModel,
        logging,
    )
    from transformers.models.bert.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST


class BertModelTester:
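    """
    Builds a tiny BertConfig plus synthetic input tensors and provides the
    create_and_check_* helpers that assert output shapes for each BERT head.
    `parent` is the unittest.TestCase that owns the assertions.
    """
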
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
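        """
        Returns a tiny config together with random input_ids, an attention mask,
        token_type_ids and the label tensors used by the head-specific checks.
        """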
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """
        Returns a tiny configuration by default.
        """
        return BertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

    def prepare_config_and_inputs_for_decoder(self):
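        """
        Same as `prepare_config_and_inputs`, but switches the config to decoder mode
        and adds encoder hidden states and an encoder attention mask for
        cross-attention tests.
        """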
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = BertModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_model_for_causal_lm_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = BertLMHeadModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            encoder_hidden_states=encoder_hidden_states,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
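        """
        Checks that decoding with cached past_key_values produces the same hidden
        states as a full forward pass over the concatenated sequence.
        """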
        config.is_decoder = True
        config.add_cross_attention = True
        model = BertLMHeadModel(config=config).to(torch_device).eval()

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create a few hypothetical next tokens and extend next_input_ids with them
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append the new tokens to input_ids and the new mask to the attention mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_for_next_sequence_prediction(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForNextSentencePrediction(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=sequence_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))

    def create_and_check_for_pretraining(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForPreTraining(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            labels=token_labels,
            next_sentence_label=sequence_labels,
        )
        self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = BertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = BertForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = BertForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_choices = self.num_choices
        model = BertForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
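    """
    Runs the shared ModelTesterMixin, GenerationTesterMixin and PipelineTesterMixin
    suites against the BERT model classes, plus a few BERT-specific checks.
    """
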
    all_model_classes = (
        (
            BertModel,
            BertLMHeadModel,
            BertForMaskedLM,
            BertForMultipleChoice,
            BertForNextSentencePrediction,
            BertForPreTraining,
            BertForQuestionAnswering,
            BertForSequenceClassification,
            BertForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": BertModel,
            "fill-mask": BertForMaskedLM,
            "question-answering": BertForQuestionAnswering,
            "text-classification": BertForSequenceClassification,
            "text-generation": BertLMHeadModel,
            "token-classification": BertForTokenClassification,
            "zero-shot": BertForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    fx_compatible = True

    # special case for ForPreTraining model
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
                inputs_dict["next_sentence_label"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
        return inputs_dict

    def setUp(self):
        self.model_tester = BertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_as_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

    def test_model_as_decoder_with_default_input_mask(self):
        # This regression test was failing with PyTorch < 1.3
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        ) = self.model_tester.prepare_config_and_inputs_for_decoder()

        input_mask = None

        self.model_tester.create_and_check_model_as_decoder(
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def test_for_causal_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_causal_lm_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs)

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_decoder_model_past_with_large_inputs_relative_pos_emb(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        config_and_inputs[0].position_embedding_type = "relative_key"
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

    def test_for_next_sequence_prediction(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_next_sequence_prediction(*config_and_inputs)

    def test_for_pretraining(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_pretraining(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    def test_for_warning_if_padding_and_no_attention_mask(self):
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.model_tester.prepare_config_and_inputs()

        # Set pad tokens in the input_ids
        input_ids[0, 0] = config.pad_token_id

        # Check for warnings if the attention_mask is missing.
        logger = logging.get_logger("transformers.modeling_utils")
        # clear cache so we can test the warning is emitted (from `warning_once`).
        logger.warning_once.cache_clear()

        with CaptureLogger(logger) as cl:
            model = BertModel(config=config)
            model.to(torch_device)
            model.eval()
            model(input_ids, attention_mask=None, token_type_ids=token_type_ids)
        self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out)

    @slow
    def test_model_from_pretrained(self):
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = BertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    @slow
    @require_torch_accelerator
    def test_torchscript_device_change(self):
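        """
        Traces each model on CPU with torch.jit, saves and reloads the TorchScript
        module onto the accelerator device, then runs a forward pass on it.
        """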
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            # BertForMultipleChoice behaves incorrectly in JIT environments.
            if model_class == BertForMultipleChoice:
                return

            config.torchscript = True
            model = model_class(config=config)

            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            traced_model = torch.jit.trace(
                model, (inputs_dict["input_ids"].to("cpu"), inputs_dict["attention_mask"].to("cpu"))
            )

            with tempfile.TemporaryDirectory() as tmp:
                torch.jit.save(traced_model, os.path.join(tmp, "bert.pt"))
                loaded = torch.jit.load(os.path.join(tmp, "bert.pt"), map_location=torch_device)
                loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))


@require_torch
class BertModelIntegrationTest(unittest.TestCase):
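    """
    Slow integration tests that load real checkpoints from the Hub and compare a
    small slice of the output hidden states against reference values.
    """
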
    @slow
    def test_inference_no_head_absolute_embedding(self):
        model = BertModel.from_pretrained("google-bert/bert-base-uncased")
        input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
        attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        with torch.no_grad():
            output = model(input_ids, attention_mask=attention_mask)[0]
        expected_shape = torch.Size((1, 11, 768))
        self.assertEqual(output.shape, expected_shape)
        expected_slice = torch.tensor([[[0.4249, 0.1008, 0.7531], [0.3771, 0.1188, 0.7467], [0.4152, 0.1098, 0.7108]]])

        self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))

    @slow
    def test_inference_no_head_relative_embedding_key(self):
        model = BertModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key")
        input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
        attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        with torch.no_grad():
            output = model(input_ids, attention_mask=attention_mask)[0]
        expected_shape = torch.Size((1, 11, 768))
        self.assertEqual(output.shape, expected_shape)
        expected_slice = torch.tensor(
            [[[0.0756, 0.3142, -0.5128], [0.3761, 0.3462, -0.5477], [0.2052, 0.3760, -0.1240]]]
        )

        self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))

    @slow
    def test_inference_no_head_relative_embedding_key_query(self):
        model = BertModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key-query")
        input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
        attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
        with torch.no_grad():
            output = model(input_ids, attention_mask=attention_mask)[0]
        expected_shape = torch.Size((1, 11, 768))
        self.assertEqual(output.shape, expected_shape)
        expected_slice = torch.tensor(
            [[[0.6496, 0.3784, 0.8203], [0.8148, 0.5656, 0.2636], [-0.0681, 0.5597, 0.7045]]]
        )

        self.assertTrue(torch.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))