# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Falcon model. """


import tempfile
import unittest

from parameterized import parameterized

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    FalconConfig,
    is_torch_available,
    set_seed,
)
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_sdpa, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        FalconForCausalLM,
        FalconForQuestionAnswering,
        FalconForSequenceClassification,
        FalconForTokenClassification,
        FalconModel,
    )


class FalconModelTester:
    def __init__(
        self,
        parent,
        batch_size=3,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        return FalconConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            pad_token_id=1,
            new_decoder_architecture=True,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = FalconModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = FalconModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = FalconForCausalLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.is_decoder = True
        config.add_cross_attention = True
        model = FalconForCausalLM(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and next_attention_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            FalconModel,
            FalconForCausalLM,
            FalconForSequenceClassification,
            FalconForTokenClassification,
            FalconForQuestionAnswering,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (FalconForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": FalconModel,
            "question-answering": FalconForQuestionAnswering,
            "text-classification": FalconForSequenceClassification,
            "text-generation": FalconForCausalLM,
            "token-classification": FalconForTokenClassification,
            "zero-shot": FalconForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False

    # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
    def is_pipeline_test_to_skip(
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
    ):
        return True

    def setUp(self):
        self.model_tester = FalconModelTester(self)
        self.config_tester = ConfigTester(self, config_class=FalconConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_position_embedding_types(self):
        config, *inputs = self.model_tester.prepare_config_and_inputs()
        for alibi in [True, False]:
            config.alibi = alibi
            self.model_tester.create_and_check_model(config, *inputs)

    def test_falcon_sequence_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = FalconForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_falcon_sequence_classification_model_for_single_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "single_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = FalconForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_falcon_sequence_classification_model_for_multi_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "multi_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor(
            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
        ).to(torch.float)
        model = FalconForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_past_key_values_format(self):
        # Falcon can have different numbers of KV-heads than the number of query heads, so we need
        # to override this test to use the right head counts.
        for model_class in self.all_generative_model_classes:
            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

            # If it doesn't support cache, pass the test
            if not hasattr(config, "use_cache"):
                return

            model = model_class(config).to(torch_device)
            if "use_cache" not in inputs:
                inputs["use_cache"] = True
            outputs = model(**inputs)

            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
            if "past_key_values" not in outputs:
                return

            num_hidden_layers = (
                getattr(config, "decoder_layers", None)
                or getattr(config, "num_decoder_layers", None)
                or config.num_hidden_layers
            )
            num_attention_heads = getattr(config, "num_kv_heads", config.num_attention_heads)
            embed_dim = getattr(config, "d_model", config.hidden_size)
            per_head_embed_dim = embed_dim // num_attention_heads

            past_kv = outputs["past_key_values"]
            self.assertEqual(len(past_kv), num_hidden_layers)

            batch_size, seq_length = inputs["input_ids"].shape
            for i in range(num_hidden_layers):
                if config.new_decoder_architecture:
                    num_attention_heads = config.num_attention_heads
                elif config.multi_query:
                    num_attention_heads = 1
                self.assertEqual(len(past_kv[0]), 2)  # K V for the decoder = 2
                self.assertEqual(
                    past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
                )
                self.assertEqual(
                    past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
                )

    @parameterized.expand([("linear",), ("dynamic",)])
    def test_model_rope_scaling(self, scaling_type):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        short_input = ids_tensor([1, 10], config.vocab_size)
        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        original_model = FalconModel(config)
        original_model.to(torch_device)
        original_model.eval()
        original_short_output = original_model(short_input).last_hidden_state
        original_long_output = original_model(long_input).last_hidden_state

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
        scaled_model = FalconModel(config)
        scaled_model.to(torch_device)
        scaled_model.eval()
        scaled_short_output = scaled_model(short_input).last_hidden_state
        scaled_long_output = scaled_model(long_input).last_hidden_state

        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
        # maximum sequence length, so the outputs for the short input should match.
        if scaling_type == "dynamic":
            self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
        else:
            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))

        # The output should be different for long inputs
        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

    @require_torch_sdpa
    @slow
    def test_eager_matches_sdpa_generate(self):
        max_new_tokens = 30

        if len(self.all_generative_model_classes) == 0:
            self.skipTest(f"{self.__class__.__name__} tests a model that does not support generate: skipping this test")

        for model_class in self.all_generative_model_classes:
            if not model_class._supports_sdpa:
                self.skipTest(f"{model_class.__name__} does not support SDPA")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            dummy_input = inputs_dict[model_class.main_input_name]
            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                dummy_input = dummy_input.to(torch.float16)

            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))

                model_sdpa = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

                model_eager = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    low_cpu_mem_usage=True,
                    attn_implementation="eager",
                ).to(torch_device)

                self.assertTrue(model_eager.config._attn_implementation == "eager")

                # NOTE: This check is disabled for Falcon as the non-SDPA/SDPA implementation is in the same class (legacy reason).
                # for name, submodule in model_eager.named_modules():
                #     if "SdpaAttention" in submodule.__class__.__name__:
                #         raise ValueError("The eager model should not have SDPA attention layers")

                # has_sdpa = False
                # for name, submodule in model_sdpa.named_modules():
                #     if "SdpaAttention" in submodule.__class__.__name__:
                #         has_sdpa = True
                #         break
                # if not has_sdpa:
                #     raise ValueError("The SDPA model should have SDPA attention layers")

                # Just test that a large cache works as expected
                res_eager = model_eager.generate(
                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
                )

                res_sdpa = model_sdpa.generate(
                    dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False
                )

                self.assertTrue(torch.allclose(res_eager, res_sdpa))


@require_torch
class FalconLanguageGenerationTest(unittest.TestCase):
    @slow
    def test_lm_generate_falcon(self):
        tokenizer = AutoTokenizer.from_pretrained("Rocketknight1/falcon-rw-1b")
        model = FalconForCausalLM.from_pretrained("Rocketknight1/falcon-rw-1b")
        model.eval()
        model.to(torch_device)
        inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)

        EXPECTED_OUTPUT = (
            "My favorite food is pizza. I love it so much that I have a pizza party every year for my birthday."
        )

        output_ids = model.generate(**inputs, do_sample=False, max_new_tokens=19)
        output_str = tokenizer.batch_decode(output_ids)[0]

        self.assertEqual(output_str, EXPECTED_OUTPUT)

    @slow
    def test_lm_generation_big_models(self):
        # The big models are way too big for the CI, so we use tiny random models that resemble their
        # architectures but with much smaller and fewer layers
        for repo in ["Rocketknight1/tiny-random-falcon-7b", "Rocketknight1/tiny-random-falcon-40b"]:
            tokenizer = AutoTokenizer.from_pretrained(repo)
            model = FalconForCausalLM.from_pretrained(repo)
            model.eval()
            model.to(torch_device)
            inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)

            # We just test that these run without errors - the models are randomly initialized
            # and so the actual text outputs will be garbage
            model.generate(**inputs, do_sample=False, max_new_tokens=4)
            model.generate(**inputs, do_sample=True, max_new_tokens=4)
            model.generate(**inputs, num_beams=2, max_new_tokens=4)

    @slow
    def test_lm_generation_use_cache(self):
        # The big models are way too big for the CI, so we use tiny random models that resemble their
        # architectures but with much smaller and fewer layers
        with torch.no_grad():
            for repo in [
                "Rocketknight1/falcon-rw-1b",
                "Rocketknight1/tiny-random-falcon-7b",
                "Rocketknight1/tiny-random-falcon-40b",
            ]:
                tokenizer = AutoTokenizer.from_pretrained(repo)
                model = FalconForCausalLM.from_pretrained(repo)
                model.eval()
                model.to(device=torch_device)
                inputs = tokenizer("My favorite food is", return_tensors="pt").to(torch_device)

                # Test results are the same with and without cache
                outputs_no_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=False)
                outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
                self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0)

    @require_bitsandbytes
    @slow
    def test_batched_generation(self):
        tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left")
        tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            "tiiuae/falcon-7b",
            device_map="auto",
            load_in_4bit=True,
        )

        test_text = "A sequence: 1, 2"  # should generate the rest of the sequence

        unpadded_inputs = tokenizer([test_text], return_tensors="pt").to("cuda:0")
        unpadded_gen_out = model.generate(**unpadded_inputs, max_new_tokens=20)
        unpadded_gen_text = tokenizer.batch_decode(unpadded_gen_out, skip_special_tokens=True)

        dummy_text = "This is a longer text " * 2  # forces left-padding on `test_text`
        padded_inputs = tokenizer([test_text, dummy_text], return_tensors="pt", padding=True).to("cuda:0")
        padded_gen_out = model.generate(**padded_inputs, max_new_tokens=20)
        padded_gen_text = tokenizer.batch_decode(padded_gen_out, skip_special_tokens=True)

        expected_output = "A sequence: 1, 2, 3, 4, 5, 6, 7, 8, "
        self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1])  # left-padding exists
        self.assertEqual(unpadded_gen_text[0], expected_output)
        self.assertEqual(padded_gen_text[0], expected_output)