transformers

Форк
0
/
test_modeling_plbart.py 
675 строк · 26.6 Кб
1
# coding=utf-8
2
# Copyright 2022, The HuggingFace Inc. team. All rights reserved.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
""" Testing suite for the PyTorch PLBART model. """
16

17

18
import copy
19
import tempfile
20
import unittest
21

22
from transformers import PLBartConfig, is_torch_available
23
from transformers.testing_utils import (
24
    require_sentencepiece,
25
    require_tokenizers,
26
    require_torch,
27
    require_torch_fp16,
28
    slow,
29
    torch_device,
30
)
31
from transformers.utils import cached_property
32

33
from ...generation.test_utils import GenerationTesterMixin
34
from ...test_configuration_common import ConfigTester
35
from ...test_modeling_common import ModelTesterMixin, ids_tensor
36
from ...test_pipeline_mixin import PipelineTesterMixin
37

38

39
if is_torch_available():
40
    import torch
41

42
    from transformers import (
43
        AutoTokenizer,
44
        PLBartForCausalLM,
45
        PLBartForConditionalGeneration,
46
        PLBartForSequenceClassification,
47
        PLBartModel,
48
    )
49
    from transformers.models.plbart.modeling_plbart import PLBartDecoder, PLBartEncoder
50

51

52
def prepare_plbart_inputs_dict(
    config,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    """Assemble the keyword arguments for a PLBart encoder-decoder forward pass.

    Every mask left as ``None`` receives a default:

    * ``attention_mask`` / ``decoder_attention_mask`` are derived by comparing the
      corresponding ids against ``config.pad_token_id`` (positions equal to the pad
      id are ``False``).
    * ``head_mask``, ``decoder_head_mask`` and ``cross_attn_head_mask`` are all-ones
      tensors of shape ``(layers, heads)`` created on ``torch_device``.

    Args:
        config: object exposing ``pad_token_id``, ``encoder_layers``, ``decoder_layers``,
            ``encoder_attention_heads`` and ``decoder_attention_heads``.
        input_ids: encoder token ids.
        decoder_input_ids: decoder token ids.
        attention_mask: optional encoder padding mask.
        decoder_attention_mask: optional decoder padding mask.
        head_mask: optional encoder head mask.
        decoder_head_mask: optional decoder self-attention head mask.
        cross_attn_head_mask: optional decoder cross-attention head mask.

    Returns:
        dict mapping the forward-argument names to the (possibly defaulted) tensors.
    """
    if attention_mask is None:
        attention_mask = input_ids.ne(config.pad_token_id)
    if decoder_attention_mask is None:
        decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
    if head_mask is None:
        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
    if decoder_head_mask is None:
        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
    if cross_attn_head_mask is None:
        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
    return {
        "input_ids": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        # Return the decoder's own mask. The encoder mask used to be returned under
        # this key, which silently dropped both the computed default and any mask
        # supplied by the caller.
        "decoder_attention_mask": decoder_attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
    }
81

82

83
class PLBartModelTester:
    """Factory for a tiny PLBart config and random inputs, plus checks shared by the test case.

    ``parent`` is the object whose ``assertTrue`` is called by the ``check_*`` and
    ``create_and_check_*`` helpers. Every other constructor argument is stored as an
    attribute of the same name; most of them are consumed by :meth:`get_config` or
    used to size the tensors built in :meth:`prepare_config_and_inputs`.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=100,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
    ):
        self.parent = parent

        # Input sizing and test-mode flags.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels

        # Model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings

        # Special token ids.
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id

    def prepare_config_and_inputs(self):
        """Return ``(config, inputs_dict)`` built from freshly drawn encoder/decoder ids."""
        # Encoder ids are clamped to a minimum of 3 and every row ends with EOS.
        encoder_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(3)
        encoder_ids[:, -1] = self.eos_token_id

        decoder_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = self.get_config()
        return config, prepare_plbart_inputs_dict(config, encoder_ids, decoder_ids)

    def get_config(self):
        """Build a ``PLBartConfig`` from the tester's attributes.

        Encoder and decoder share the same layer count, head count and FFN width.
        """
        config_kwargs = {
            "vocab_size": self.vocab_size,
            "d_model": self.hidden_size,
            "encoder_layers": self.num_hidden_layers,
            "decoder_layers": self.num_hidden_layers,
            "encoder_attention_heads": self.num_attention_heads,
            "decoder_attention_heads": self.num_attention_heads,
            "encoder_ffn_dim": self.intermediate_size,
            "decoder_ffn_dim": self.intermediate_size,
            "dropout": self.hidden_dropout_prob,
            "attention_dropout": self.attention_probs_dropout_prob,
            "max_position_embeddings": self.max_position_embeddings,
            "eos_token_id": self.eos_token_id,
            "bos_token_id": self.bos_token_id,
            "pad_token_id": self.pad_token_id,
        }
        return PLBartConfig(**config_kwargs)

    def prepare_config_and_inputs_for_common(self):
        """Return the same ``(config, inputs_dict)`` pair as :meth:`prepare_config_and_inputs`."""
        common_config, common_inputs = self.prepare_config_and_inputs()
        return common_config, common_inputs

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        """Check that feeding three extra tokens with cached key/values matches a full pass."""
        decoder = PLBartModel(config=config).get_decoder().to(torch_device).eval()
        prefix_ids = inputs_dict["input_ids"]
        prefix_mask = inputs_dict["attention_mask"]
        prefix_head_mask = inputs_dict["head_mask"]

        # Run the prefix once to obtain the cache.
        first_pass = decoder(prefix_ids, attention_mask=prefix_mask, head_mask=prefix_head_mask, use_cache=True)
        _, past_key_values = first_pass.to_tuple()

        # Three hypothetical follow-up tokens and a 0/1 mask for them.
        new_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        new_mask = ids_tensor((self.batch_size, 3), 2)

        extended_ids = torch.cat([prefix_ids, new_tokens], dim=-1)
        extended_mask = torch.cat([prefix_mask, new_mask], dim=-1)

        hidden_full = decoder(extended_ids, attention_mask=extended_mask)["last_hidden_state"]
        cached_outputs = decoder(new_tokens, attention_mask=extended_mask, past_key_values=past_key_values)
        hidden_cached = cached_outputs["last_hidden_state"]

        # Compare a single randomly chosen hidden channel over the three new positions.
        channel = ids_tensor((1,), hidden_cached.shape[-1]).item()
        full_slice = hidden_full[:, -3:, channel].detach()
        cached_slice = hidden_cached[:, :, channel].detach()

        self.parent.assertTrue(cached_slice.shape[1] == new_tokens.shape[1])
        self.parent.assertTrue(torch.allclose(cached_slice, full_slice, atol=1e-3))

    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
        """Check that a saved-and-reloaded encoder and decoder reproduce the full model's outputs."""
        full_model = PLBartModel(config=config).to(torch_device).eval()
        reference = full_model(**inputs_dict)

        reference_encoder_state = reference.encoder_last_hidden_state
        reference_decoder_state = reference.last_hidden_state

        # Round-trip the encoder through disk, then run it on its own.
        with tempfile.TemporaryDirectory() as encoder_dir:
            full_model.get_encoder().save_pretrained(encoder_dir)
            standalone_encoder = PLBartEncoder.from_pretrained(encoder_dir).to(torch_device)

        encoder_outputs = standalone_encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])
        encoder_gap = (encoder_outputs[0] - reference_encoder_state).abs().max().item()
        self.parent.assertTrue(encoder_gap < 1e-3)

        # Same for the decoder, fed with the full model's encoder output.
        with tempfile.TemporaryDirectory() as decoder_dir:
            full_model.get_decoder().save_pretrained(decoder_dir)
            standalone_decoder = PLBartDecoder.from_pretrained(decoder_dir).to(torch_device)

        decoder_outputs = standalone_decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=reference_encoder_state,
            encoder_attention_mask=inputs_dict["attention_mask"],
        )
        decoder_gap = (decoder_outputs[0] - reference_decoder_state).abs().max().item()
        self.parent.assertTrue(decoder_gap < 1e-3)
221

222

223
@require_torch
class PLBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common, generation and pipeline test suites applied to the PLBart encoder-decoder models."""

    all_model_classes = (
        (PLBartModel, PLBartForConditionalGeneration, PLBartForSequenceClassification) if is_torch_available() else ()
    )
    all_generative_model_classes = (PLBartForConditionalGeneration,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "conversational": PLBartForConditionalGeneration,
            "feature-extraction": PLBartModel,
            "summarization": PLBartForConditionalGeneration,
            "text-classification": PLBartForSequenceClassification,
            "text-generation": PLBartForCausalLM,
            "text2text-generation": PLBartForConditionalGeneration,
            "translation": PLBartForConditionalGeneration,
            "zero-shot": PLBartForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    is_encoder_decoder = True
    fx_compatible = False  # Fix me Michael
    test_pruning = False
    test_missing_keys = False

    # TODO: Fix the failed tests
    def is_pipeline_test_to_skip(
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
    ):
        """Return True only for the translation pipeline tests, which are skipped for this model."""
        # Translation fails with
        # `ValueError: Translation requires a `src_lang` and a `tgt_lang` for this model`.
        # `PLBartConfig` was never used in pipeline tests: cannot create a simple tokenizer.
        return pipeline_test_casse_name == "TranslationPipelineTests"

    def setUp(self):
        """Create the model tester and config tester used by the tests below."""
        self.model_tester = PLBartModelTester(self)
        self.config_tester = ConfigTester(self, config_class=PLBartConfig)

    def test_config(self):
        """Run the shared configuration checks for ``PLBartConfig``."""
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        """Saving and reloading each model class must report no missing keys."""
        config, _ = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            original = model_class(config)

            with tempfile.TemporaryDirectory() as save_dir:
                original.save_pretrained(save_dir)
                _, loading_info = model_class.from_pretrained(save_dir, output_loading_info=True)
            self.assertEqual(loading_info["missing_keys"], [])

    def test_decoder_model_past_with_large_inputs(self):
        """Delegate to the tester's cached-decoding consistency check."""
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*prepared)

    def test_encoder_decoder_model_standalone(self):
        """Delegate to the tester's standalone encoder/decoder round-trip check."""
        prepared = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_encoder_decoder_model_standalone(*prepared)

    # PLBartForSequenceClassification does not support inputs_embeds
    def test_inputs_embeds(self):
        """Run a forward pass driven by embeddings instead of token ids."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in (PLBartModel, PLBartForConditionalGeneration):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            # Take the id tensors out of the kwargs; they are replaced by embeddings below.
            source_ids = inputs.pop("input_ids")
            if self.is_encoder_decoder:
                # Fall back to the encoder ids when no decoder ids were prepared.
                target_ids = inputs.pop("decoder_input_ids", source_ids)

            embed = model.get_input_embeddings()
            inputs["inputs_embeds"] = embed(source_ids)
            if self.is_encoder_decoder:
                inputs["decoder_inputs_embeds"] = embed(target_ids)

            with torch.no_grad():
                model(**inputs)[0]

    @require_torch_fp16
    def test_generate_fp16(self):
        """Call ``generate`` on a half-precision model, with and without explicit inputs."""
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        input_ids = input_dict["input_ids"]
        # Mask out positions whose id equals 1.
        attention_mask = input_ids.ne(1).to(torch_device)
        model = PLBartForConditionalGeneration(config).eval().to(torch_device)
        model.half()
        model.generate(input_ids, attention_mask=attention_mask)
        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

    @unittest.skip("Failing since #26752")
    def test_sample_generate(self):
        """Skipped; see the skip reason."""
        pass
327

328

329
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
    """Return True when ``a`` and ``b`` match within ``atol``; otherwise raise ``AssertionError``.

    Two ``None`` values count as equal. An ``AssertionError`` is raised when exactly one
    argument is not a tensor, when the shapes differ, when the tensors cannot be compared
    (e.g. ``torch.allclose`` rejects them), or when any element differs by more than
    ``atol``.

    Args:
        a: first tensor (or ``None``).
        b: second tensor (or ``None``).
        atol: absolute tolerance passed to ``torch.allclose`` and used to count
            differing elements.
        prefix: optional label prepended to the error message as ``"<prefix>: "``.

    Returns:
        ``True`` when the inputs are considered equal.

    Raises:
        AssertionError: with a message describing the mismatch.
    """
    if a is None and b is None:
        return True

    def _fail(msg, cause=None):
        # Single exit point so every failure carries the optional prefix.
        if prefix:
            msg = prefix + ": " + msg
        raise AssertionError(msg) from cause

    # Check the preconditions explicitly, so a bad input yields the promised
    # AssertionError rather than a TypeError/RuntimeError from tensor arithmetic.
    if not (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)):
        _fail(f"{a} != {b}")
    if a.shape != b.shape:
        _fail(f"shape mismatch: {tuple(a.shape)} != {tuple(b.shape)}")

    try:
        if torch.allclose(a, b, atol=atol):
            return True
        pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item()
    except (RuntimeError, TypeError) as err:
        _fail(f"tensors could not be compared ({err})", cause=err)

    # Large tensors get a summary; small ones are printed in full.
    if a.numel() > 100:
        msg = f"tensor values are {pct_different:.1%} percent different."
    else:
        msg = f"{a} != {b}"
    _fail(msg)
346

347

348
def _long_tensor(tok_lst):
    """Build a ``torch.long`` tensor on ``torch_device`` from ``tok_lst``."""
    tensor_kwargs = {"dtype": torch.long, "device": torch_device}
    return torch.tensor(tok_lst, **tensor_kwargs)
350

351

352
@require_torch
@require_sentencepiece
@require_tokenizers
class AbstractSeq2SeqIntegrationTest(unittest.TestCase):
    """Base class for checkpoint-backed tests; subclasses set ``checkpoint_name``.

    The tokenizer is loaded once per class in ``setUpClass``; the model is loaded
    on first access of ``self.model``.
    """

    maxDiff = 1000  # longer string compare tracebacks
    checkpoint_name = None

    @classmethod
    def setUpClass(cls):
        """Load the slow (non-fast) tokenizer for ``cls.checkpoint_name`` and return ``cls``."""
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name, use_fast=False)
        return cls

    @cached_property
    def model(self):
        """Only load the model if needed."""
        loaded = PLBartForConditionalGeneration.from_pretrained(self.checkpoint_name).to(torch_device)
        # Switch to half precision when the device string names a CUDA device.
        return loaded.half() if "cuda" in torch_device else loaded
371

372

373
@require_torch
@require_sentencepiece
@require_tokenizers
class PLBartJavaCsIntegrationTest(AbstractSeq2SeqIntegrationTest):
    """Generation and config checks against the ``uclanlp/plbart-java-cs`` checkpoint."""

    checkpoint_name = "uclanlp/plbart-java-cs"
    # Java inputs and the C# text the checkpoint is expected to generate for them.
    src_text = [
        "public int maximum(int a, int b, int c){return Math.max(a, Math.max(b, c));}",
        "public int product(int a, int b, int c){return a*b*c;}",
    ]
    tgt_text = [
        "public int maximum(int a, int b, int c){return Math.Max(",
        "public int Product(int a, int b, int c){return a * b *",
    ]

    @slow
    def test_java_cs_generate_one(self):
        """Generate from the first source snippet alone and compare with its expected text."""
        batch = self.tokenizer([self.src_text[0]], return_tensors="pt")
        batch = batch.to(torch_device)
        translated_tokens = self.model.generate(**batch)
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        self.assertEqual(self.tgt_text[0], decoded[0])

    @slow
    def test_java_cs_generate_batch(self):
        """Generate from both source snippets as one padded batch and compare all outputs."""
        batch = self.tokenizer(self.src_text, return_tensors="pt", padding=True, truncation=True)
        batch = batch.to(torch_device)
        translated_tokens = self.model.generate(**batch)
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        # assertEqual instead of a bare `assert`: it still runs under `python -O`,
        # prints a diff on failure, and matches the single-input test above.
        self.assertEqual(self.tgt_text, decoded)

    def test_plbart_java_cs_config(self):
        """The published config must carry the expected attribute values."""
        plbart_models = ["uclanlp/plbart-java-cs"]
        expected = {"scale_embedding": True}
        for name in plbart_models:
            config = PLBartConfig.from_pretrained(name)
            for k, v in expected.items():
                try:
                    self.assertEqual(v, getattr(config, k))
                except AssertionError as e:
                    # Attach the checkpoint and attribute name to the failure.
                    e.args += (name, k)
                    raise

    def test_plbart_fast_forward(self):
        """A small randomly initialised model must return logits of shape ``(*labels.shape, vocab_size)``."""
        config = PLBartConfig(
            vocab_size=99,
            d_model=24,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=2,
            decoder_attention_heads=2,
            encoder_ffn_dim=32,
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            add_final_layer_norm=True,
        )
        lm_model = PLBartForConditionalGeneration(config).to(torch_device)
        context = torch.tensor(
            [[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], device=torch_device, dtype=torch.long
        )
        summary = torch.tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], device=torch_device, dtype=torch.long)
        result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
        expected_shape = (*summary.shape, config.vocab_size)
        self.assertEqual(result.logits.shape, expected_shape)
439

440

441
@require_torch
@require_sentencepiece
@require_tokenizers
class PLBartBaseIntegrationTest(AbstractSeq2SeqIntegrationTest):
    """Generation checks against the ``uclanlp/plbart-base`` checkpoint."""

    checkpoint_name = "uclanlp/plbart-base"
    src_text = ["Is 0 the first Fibonacci number ?", "Find the sum of all prime numbers ."]
    tgt_text = ["0 the first Fibonacci number?", "the sum of all prime numbers.......... the the"]

    def test_base_generate(self):
        """Generate from the first source sentence, starting the decoder at the ``en_XX`` language id."""
        encoded = self.tokenizer([self.src_text[0]], return_tensors="pt").to(torch_device)
        lang_code = self.tokenizer._convert_lang_code_special_format("en_XX")
        generated = self.model.generate(
            input_ids=encoded["input_ids"].to(torch_device),
            decoder_start_token_id=self.tokenizer.lang_code_to_id[lang_code],
        )
        texts = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        self.assertEqual(self.tgt_text[0], texts[0])

    @slow
    def test_fill_mask(self):
        """Generate from a sentence containing ``<mask>`` tokens and compare with the recorded output."""
        encoded = self.tokenizer(["Is 0 the <mask> Fibonacci <mask> ?"], return_tensors="pt").to(torch_device)
        lang_code = self.tokenizer._convert_lang_code_special_format("en_XX")
        generated = self.model.generate(
            encoded["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id[lang_code], num_beams=1
        )
        texts = self.tokenizer.batch_decode(generated, clean_up_tokenization_spaces=True, skip_special_tokens=True)
        prediction: str = texts[0]
        self.assertEqual(prediction, "0 0 the 0 the 0 the 0 the 0 the 0 the 0 the 0 the")
470

471

472
class PLBartStandaloneDecoderModelTester:
    """Builds a small decoder-only PLBart config with random inputs and hosts cache-consistency checks.

    ``parent`` is the object whose ``assert*`` methods the checks call. ``is_decoder``
    is accepted but not stored or used by this class.
    """

    def __init__(
        self,
        parent,
        vocab_size=99,
        batch_size=13,
        d_model=16,
        decoder_seq_length=7,
        is_training=True,
        is_decoder=True,
        use_attention_mask=True,
        use_cache=False,
        use_labels=True,
        decoder_start_token_id=2,
        decoder_ffn_dim=32,
        decoder_layers=2,
        encoder_attention_heads=4,
        decoder_attention_heads=4,
        max_position_embeddings=30,
        is_encoder_decoder=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.decoder_seq_length = decoder_seq_length
        # For common tests
        self.seq_length = self.decoder_seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_labels = use_labels

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.hidden_size = d_model
        self.num_hidden_layers = decoder_layers
        self.decoder_layers = decoder_layers
        self.decoder_ffn_dim = decoder_ffn_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_attention_heads = decoder_attention_heads
        self.num_attention_heads = decoder_attention_heads
        self.eos_token_id = eos_token_id
        self.bos_token_id = bos_token_id
        self.pad_token_id = pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.use_cache = use_cache
        self.max_position_embeddings = max_position_embeddings
        self.is_encoder_decoder = is_encoder_decoder

        # Honour the caller's value; this used to be hard-coded to None, which
        # silently discarded the `scope` argument.
        self.scope = scope
        self.decoder_key_length = decoder_seq_length
        self.base_model_out_len = 2
        self.decoder_attention_idx = 1

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, attention_mask, lm_labels)``.

        ``attention_mask`` is ``None`` unless ``use_attention_mask`` is set, and
        ``lm_labels`` is ``None`` unless ``use_labels`` is set.
        """
        input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

        attention_mask = None
        if self.use_attention_mask:
            # 0/1 mask drawn with a vocabulary of size 2.
            attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)

        lm_labels = None
        if self.use_labels:
            lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

        config = PLBartConfig(
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            decoder_layers=self.decoder_layers,
            decoder_ffn_dim=self.decoder_ffn_dim,
            encoder_attention_heads=self.encoder_attention_heads,
            decoder_attention_heads=self.decoder_attention_heads,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            use_cache=self.use_cache,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
            max_position_embeddings=self.max_position_embeddings,
            is_encoder_decoder=self.is_encoder_decoder,
        )

        return (config, input_ids, attention_mask, lm_labels)

    def create_and_check_decoder_model_past(
        self,
        config,
        input_ids,
        attention_mask,
        lm_labels,
    ):
        """Check that decoding one extra token with cached key/values matches a full pass.

        ``attention_mask`` and ``lm_labels`` are accepted for signature compatibility
        with :meth:`prepare_config_and_inputs` but are not used here.
        """
        config.use_cache = True
        model = PLBartDecoder(config=config).to(torch_device).eval()
        # first forward pass
        outputs = model(input_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids)
        outputs_no_past = model(input_ids, use_cache=False)

        # The config default must behave like use_cache=True, and disabling the
        # cache must drop exactly one output entry.
        self.parent.assertEqual(len(outputs), len(outputs_use_cache_conf))
        self.parent.assertEqual(len(outputs), len(outputs_no_past) + 1)

        past_key_values = outputs["past_key_values"]

        # create a hypothetical next token and append it to the inputs
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

        output_from_no_past = model(next_input_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]

        # compare one randomly chosen hidden channel at the last position
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_decoder_model_attention_mask_past(
        self,
        config,
        input_ids,
        attention_mask,
        lm_labels,
    ):
        """Check cached decoding when the second half of the prefix is masked out.

        A token in the masked half is overwritten after the cache is built; since that
        position is masked, the cached and uncached results for the appended token are
        expected to agree. Note that ``input_ids`` is modified in place. The
        ``attention_mask`` and ``lm_labels`` arguments are not used; the mask is
        built locally.
        """
        model = PLBartDecoder(config=config).to(torch_device).eval()

        # create attention mask: first half visible, second half masked
        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)

        half_seq_length = input_ids.shape[-1] // 2
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
        past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]

        # create a hypothetical next token
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids (index counted from the end)
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens

        # append to next input_ids and attn_mask; the new token is visible
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        attn_mask = torch.cat(
            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
            dim=1,
        )

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # compare one randomly chosen hidden channel at the last position
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` holding only ``input_ids`` and ``attention_mask``."""
        config_and_inputs = self.prepare_config_and_inputs()
        (config, input_ids, attention_mask, lm_labels) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
        return config, inputs_dict
645

646

647
@require_torch
class PLBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """Common and generation test suites applied to the standalone PLBart decoder models."""

    all_model_classes = (PLBartDecoder, PLBartForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (PLBartForCausalLM,) if is_torch_available() else ()
    test_pruning = False
    is_encoder_decoder = False

    def setUp(self):
        """Create the decoder tester (with ``is_training=False``) and the config tester."""
        self.model_tester = PLBartStandaloneDecoderModelTester(self, is_training=False)
        self.config_tester = ConfigTester(self, config_class=PLBartConfig)

    def test_config(self):
        """Run the shared configuration checks for ``PLBartConfig``."""
        self.config_tester.run_common_tests()

    def test_decoder_model_past(self):
        """Delegate to the tester's cached-decoding check."""
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past(*prepared)

    def test_decoder_model_attn_mask_past(self):
        """Delegate to the tester's cached-decoding check with a partially masked prefix."""
        prepared = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_attention_mask_past(*prepared)

    def test_retain_grad_hidden_states_attentions(self):
        """Intentionally a no-op override: decoder cannot keep gradients."""
        return None

    @unittest.skip("The model doesn't support left padding")  # and it's not used enough to be worth fixing :)
    def test_left_padding_compatibility(self):
        """Skipped; see the skip reason."""
        pass
676

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.