transformers

Форк
0
/
test_modeling_realm.py 
554 строки · 20.2 Кб
1
# coding=utf-8
2
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
""" Testing suite for the PyTorch REALM model. """
16

17
import copy
18
import unittest
19

20
import numpy as np
21

22
from transformers import RealmConfig, is_torch_available
23
from transformers.testing_utils import require_torch, slow, torch_device
24

25
from ...test_configuration_common import ConfigTester
26
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
27
from ...test_pipeline_mixin import PipelineTesterMixin
28

29

30
if is_torch_available():
31
    import torch
32

33
    from transformers import (
34
        RealmEmbedder,
35
        RealmForOpenQA,
36
        RealmKnowledgeAugEncoder,
37
        RealmReader,
38
        RealmRetriever,
39
        RealmScorer,
40
        RealmTokenizer,
41
    )
42

43

44
class RealmModelTester:
    """Builds a small `RealmConfig` and synthetic inputs for the REALM model tests.

    The constructor arguments are grouped the same way REALM itself is organised:
    general BERT-style encoder settings, reader-specific settings, and
    searcher-specific settings. All sizes are deliberately tiny so the tests run fast.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        retriever_proj_size=128,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        span_hidden_size=50,
        max_span_width=10,
        reader_layer_norm_eps=1e-3,
        reader_beam_size=4,
        reader_seq_len=288 + 32,
        num_block_records=13353718,
        searcher_beam_size=8,
        searcher_seq_len=64,
        num_labels=3,
        num_choices=4,
        num_candidates=10,
        scope=None,
    ):
        # General config
        self.parent = parent
        self.batch_size = batch_size
        self.retriever_proj_size = retriever_proj_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        # Reader config
        self.span_hidden_size = span_hidden_size
        self.max_span_width = max_span_width
        self.reader_layer_norm_eps = reader_layer_norm_eps
        self.reader_beam_size = reader_beam_size
        self.reader_seq_len = reader_seq_len

        # Searcher config
        self.num_block_records = num_block_records
        self.searcher_beam_size = searcher_beam_size
        self.searcher_seq_len = searcher_seq_len

        self.num_labels = num_labels
        self.num_choices = num_choices
        self.num_candidates = num_candidates
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Create a config plus random ids, masks and labels for the embedder, scorer encoder and reader.

        Returns a 9-tuple: (config, input_ids, token_type_ids, input_mask,
        scorer_encoder_inputs, reader_inputs, sequence_labels, token_labels, choice_labels).
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        # NOTE: these locals were previously misspelled `candiate_*`; renamed for
        # consistency with `candidate_token_type_ids` below (locals only, no API change).
        candidate_input_ids = ids_tensor([self.batch_size, self.num_candidates, self.seq_length], self.vocab_size)
        reader_input_ids = ids_tensor([self.reader_beam_size, self.reader_seq_len], self.vocab_size)

        input_mask = None
        candidate_input_mask = None
        reader_input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])
            candidate_input_mask = random_attention_mask([self.batch_size, self.num_candidates, self.seq_length])
            reader_input_mask = random_attention_mask([self.reader_beam_size, self.reader_seq_len])

        token_type_ids = None
        candidate_token_type_ids = None
        reader_token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            candidate_token_type_ids = ids_tensor(
                [self.batch_size, self.num_candidates, self.seq_length], self.type_vocab_size
            )
            reader_token_type_ids = ids_tensor([self.reader_beam_size, self.reader_seq_len], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        # inputs with additional num_candidates axis.
        scorer_encoder_inputs = (candidate_input_ids, candidate_input_mask, candidate_token_type_ids)
        # reader inputs
        reader_inputs = (reader_input_ids, reader_input_mask, reader_token_type_ids)

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            scorer_encoder_inputs,
            reader_inputs,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(self):
        """Build a tiny `RealmConfig` from the tester's encoder settings."""
        return RealmConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            retriever_proj_size=self.retriever_proj_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_candidates=self.num_candidates,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
        )

    def create_and_check_embedder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        scorer_encoder_inputs,
        reader_inputs,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        """Run `RealmEmbedder` in eval mode and check the projected-score shape."""
        model = RealmEmbedder(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.projected_score.shape, (self.batch_size, self.retriever_proj_size))

    def create_and_check_encoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        scorer_encoder_inputs,
        reader_inputs,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        """Run `RealmKnowledgeAugEncoder` with a random relevance score and check logits shape."""
        model = RealmKnowledgeAugEncoder(config=config)
        model.to(torch_device)
        model.eval()
        relevance_score = floats_tensor([self.batch_size, self.num_candidates])
        result = model(
            scorer_encoder_inputs[0],
            attention_mask=scorer_encoder_inputs[1],
            token_type_ids=scorer_encoder_inputs[2],
            relevance_score=relevance_score,
            labels=token_labels,
        )
        # The encoder flattens the candidate axis into the batch axis.
        self.parent.assertEqual(
            result.logits.shape, (self.batch_size * self.num_candidates, self.seq_length, self.vocab_size)
        )

    def create_and_check_reader(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        scorer_encoder_inputs,
        reader_inputs,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        """Run `RealmReader` and check that its outputs are scalar tensors."""
        model = RealmReader(config=config)
        model.to(torch_device)
        model.eval()
        relevance_score = floats_tensor([self.reader_beam_size])
        result = model(
            reader_inputs[0],
            attention_mask=reader_inputs[1],
            token_type_ids=reader_inputs[2],
            relevance_score=relevance_score,
        )
        self.parent.assertEqual(result.block_idx.shape, ())
        self.parent.assertEqual(result.candidate.shape, ())
        self.parent.assertEqual(result.start_pos.shape, ())
        self.parent.assertEqual(result.end_pos.shape, ())

    def create_and_check_scorer(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        scorer_encoder_inputs,
        reader_inputs,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        """Run `RealmScorer` and check relevance / query / candidate score shapes."""
        model = RealmScorer(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            candidate_input_ids=scorer_encoder_inputs[0],
            candidate_attention_mask=scorer_encoder_inputs[1],
            candidate_token_type_ids=scorer_encoder_inputs[2],
        )
        self.parent.assertEqual(result.relevance_score.shape, (self.batch_size, self.num_candidates))
        self.parent.assertEqual(result.query_score.shape, (self.batch_size, self.retriever_proj_size))
        self.parent.assertEqual(
            result.candidate_score.shape, (self.batch_size, self.num_candidates, self.retriever_proj_size)
        )

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` in the format expected by the common `ModelTesterMixin` tests."""
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            scorer_encoder_inputs,
            reader_inputs,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict
304

305

306
@require_torch
class RealmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Fast common tests for REALM's embedder and knowledge-augmented encoder.

    `RealmReader` and `RealmForOpenQA` are exercised separately (in `test_training`
    and the slow integration tests) because they need retriever machinery.
    """

    all_model_classes = (
        (
            RealmEmbedder,
            RealmKnowledgeAugEncoder,
            # RealmScorer is excluded from common tests as it is a container model
            # consisting of two RealmEmbedders & a simple inner product calculation.
            # RealmScorer
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = ()
    # No pipeline tasks map to REALM; both branches of the previous
    # `{} if is_torch_available() else {}` were identical, so the conditional is dropped.
    pipeline_model_mapping = {}

    # disable these tests because there is no base_model in Realm
    test_save_load_fast_init_from_base = False
    test_save_load_fast_init_to_base = False

    def setUp(self):
        self.test_pruning = False
        self.model_tester = RealmModelTester(self)
        self.config_tester = ConfigTester(self, config_class=RealmConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_embedder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_embedder(*config_and_inputs)

    def test_encoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_encoder(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # `embedding_type` rather than `type`, which shadowed the builtin.
        for embedding_type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = embedding_type
            self.model_tester.create_and_check_embedder(*config_and_inputs)
            self.model_tester.create_and_check_encoder(*config_and_inputs)

    def test_scorer(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_scorer(*config_and_inputs)

    def test_training(self):
        """Run a forward/backward pass through RealmKnowledgeAugEncoder and RealmForOpenQA."""
        if not self.model_tester.is_training:
            return

        config, *inputs = self.model_tester.prepare_config_and_inputs()
        input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4]
        config.return_dict = True

        tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")

        # RealmKnowledgeAugEncoder training
        model = RealmKnowledgeAugEncoder(config)
        model.to(torch_device)
        model.train()

        inputs_dict = {
            "input_ids": scorer_encoder_inputs[0].to(torch_device),
            "attention_mask": scorer_encoder_inputs[1].to(torch_device),
            "token_type_ids": scorer_encoder_inputs[2].to(torch_device),
            "relevance_score": floats_tensor([self.model_tester.batch_size, self.model_tester.num_candidates]),
        }
        inputs_dict["labels"] = torch.zeros(
            (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
        )
        inputs = inputs_dict
        loss = model(**inputs).loss
        loss.backward()

        # RealmForOpenQA training: use a deep copy so the encoder's config is untouched.
        openqa_config = copy.deepcopy(config)
        openqa_config.vocab_size = 30522  # the retrieved texts will inevitably have more than 99 vocabs.
        openqa_config.num_block_records = 5
        openqa_config.searcher_beam_size = 2

        block_records = np.array(
            [
                b"This is the first record.",
                b"This is the second record.",
                b"This is the third record.",
                b"This is the fourth record.",
                b"This is the fifth record.",
            ],
            dtype=object,
        )
        retriever = RealmRetriever(block_records, tokenizer)
        model = RealmForOpenQA(openqa_config, retriever)
        model.to(torch_device)
        model.train()

        # RealmForOpenQA operates on a single query, hence the [:1] slices.
        inputs_dict = {
            "input_ids": input_ids[:1].to(torch_device),
            "attention_mask": input_mask[:1].to(torch_device),
            "token_type_ids": token_type_ids[:1].to(torch_device),
            "answer_ids": input_ids[:1].tolist(),
        }
        inputs = self._prepare_for_class(inputs_dict, RealmForOpenQA)
        loss = model(**inputs).reader_output.loss
        loss.backward()

        # Test model.block_embedding_to
        device = torch.device("cpu")
        model.block_embedding_to(device)
        loss = model(**inputs).reader_output.loss
        loss.backward()
        self.assertEqual(model.block_emb.device.type, device.type)

    @slow
    def test_embedder_from_pretrained(self):
        model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
        self.assertIsNotNone(model)

    @slow
    def test_encoder_from_pretrained(self):
        model = RealmKnowledgeAugEncoder.from_pretrained("google/realm-cc-news-pretrained-encoder")
        self.assertIsNotNone(model)

    @slow
    def test_open_qa_from_pretrained(self):
        model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa")
        self.assertIsNotNone(model)

    @slow
    def test_reader_from_pretrained(self):
        model = RealmReader.from_pretrained("google/realm-orqa-nq-reader")
        self.assertIsNotNone(model)

    @slow
    def test_scorer_from_pretrained(self):
        model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer")
        self.assertIsNotNone(model)
443

444

445
@require_torch
class RealmModelIntegrationTest(unittest.TestCase):
    """Slow integration tests running the pretrained REALM checkpoints end to end."""

    @slow
    def test_inference_embedder(self):
        """The pretrained embedder projects a short sequence to the 128-dim retrieval space."""
        model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")

        projected_score = model(torch.tensor([[0, 1, 2, 3, 4, 5]]))[0]

        self.assertEqual(projected_score.shape, torch.Size((1, 128)))
        expected_slice = torch.tensor([[-0.0714, -0.0837, -0.1314]])
        self.assertTrue(torch.allclose(projected_score[:, :3], expected_slice, atol=1e-4))

    @slow
    def test_inference_encoder(self):
        """The pretrained encoder produces per-token vocab logits for two candidates."""
        model = RealmKnowledgeAugEncoder.from_pretrained(
            "google/realm-cc-news-pretrained-encoder", num_candidates=2
        )

        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
        relevance_score = torch.tensor([[0.3, 0.7]], dtype=torch.float32)
        logits = model(input_ids, relevance_score=relevance_score)[0]

        self.assertEqual(logits.shape, torch.Size((2, 6, 30522)))

        expected_slice = torch.tensor([[[-11.0888, -11.2544], [-10.2170, -10.3874]]])
        self.assertTrue(torch.allclose(logits[1, :2, :2], expected_slice, atol=1e-4))

    @slow
    def test_inference_open_qa(self):
        """End-to-end open-domain QA: retrieve, read, and decode the predicted answer."""
        from transformers.models.realm.retrieval_realm import RealmRetriever

        tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
        retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
        model = RealmForOpenQA.from_pretrained(
            "google/realm-orqa-nq-openqa",
            retriever=retriever,
        )

        encoded_question = tokenizer(
            ["Who is the pioneer in modern computer science?"],
            padding=True,
            truncation=True,
            max_length=model.config.searcher_seq_len,
            return_tensors="pt",
        ).to(model.device)

        predicted_answer_ids = model(**encoded_question).predicted_answer_ids

        self.assertEqual(tokenizer.decode(predicted_answer_ids), "alan mathison turing")

    @slow
    def test_inference_reader(self):
        """The pretrained reader selects the expected block and answer span."""
        config = RealmConfig(reader_beam_size=2, max_span_width=3)
        model = RealmReader.from_pretrained("google/realm-orqa-nq-reader", config=config)

        output = model(
            torch.arange(10).view((2, 5)),
            token_type_ids=torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64),
            relevance_score=torch.tensor([0.3, 0.7], dtype=torch.float32),
            block_mask=torch.tensor([[0, 0, 1, 1, 0], [0, 0, 1, 1, 0]], dtype=torch.int64),
            return_dict=True,
        )

        # block_idx is a scalar; start/end positions are 1-element tensors.
        self.assertEqual(output.block_idx.shape, torch.Size(()))
        self.assertEqual(output.start_pos.shape, torch.Size((1,)))
        self.assertEqual(output.end_pos.shape, torch.Size((1,)))

        self.assertTrue(torch.allclose(output.block_idx, torch.tensor(1), atol=1e-4))
        self.assertTrue(torch.allclose(output.start_pos, torch.tensor(3), atol=1e-4))
        self.assertTrue(torch.allclose(output.end_pos, torch.tensor(3), atol=1e-4))

    @slow
    def test_inference_scorer(self):
        """The pretrained scorer computes relevance scores for two candidates."""
        model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=2)

        relevance_score = model(
            torch.tensor([[0, 1, 2, 3, 4, 5]]),
            candidate_input_ids=torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]),
        )[0]

        self.assertEqual(relevance_score.shape, torch.Size((1, 2)))
        expected_slice = torch.tensor([[0.7410, 0.7170]])
        self.assertTrue(torch.allclose(relevance_score, expected_slice, atol=1e-4))
555

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.