# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest

from parameterized import parameterized

from transformers import GPTBigCodeConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        GPT2TokenizerFast,
        GPTBigCodeForCausalLM,
        GPTBigCodeForSequenceClassification,
        GPTBigCodeForTokenClassification,
        GPTBigCodeModel,
    )
    from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention
    from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
else:
    is_torch_greater_or_equal_than_1_12 = False


class GPTBigCodeModelTester:
    def __init__(
        self,
        parent,
        batch_size=14,
        seq_length=7,
        is_training=True,
        use_token_type_ids=True,
        use_input_mask=True,
        use_labels=True,
        use_mc_token_ids=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="relu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        multi_query=True,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_token_type_ids = use_token_type_ids
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.use_mc_token_ids = use_mc_token_ids
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = None
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 2
        self.pad_token_id = vocab_size - 3
        self.multi_query = multi_query

    def get_large_model_config(self):
        return GPTBigCodeConfig.from_pretrained("bigcode/gpt_bigcode-santacoder")

    def prepare_config_and_inputs(
        self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
    ):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        mc_token_ids = None
        if self.use_mc_token_ids:
            mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config(
            gradient_checkpointing=gradient_checkpointing,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )

        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(
        self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
    ):
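        # GPTBigCodeConfig keeps the GPT-2 style argument names, so the tester's
        # generic attributes are mapped onto them here (hidden_size -> n_embd,
        # num_hidden_layers -> n_layer, num_attention_heads -> n_head,
        # intermediate_size -> n_inner, max_position_embeddings -> n_positions).
        # `multi_query` selects multi-query attention (MQA) instead of standard
        # multi-head attention (MHA).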
        return GPTBigCodeConfig(
            vocab_size=self.vocab_size,
            n_embd=self.hidden_size,
            n_layer=self.num_hidden_layers,
            n_head=self.num_attention_heads,
            n_inner=self.intermediate_size,
            activation_function=self.hidden_act,
            resid_pdrop=self.hidden_dropout_prob,
            attn_pdrop=self.attention_probs_dropout_prob,
            n_positions=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            use_cache=True,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            gradient_checkpointing=gradient_checkpointing,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
            attention_softmax_in_fp32=False,
            scale_attention_softmax_in_fp32=False,
            multi_query=self.multi_query,
        )

    def get_pipeline_config(self):
        config = self.get_config()
        config.vocab_size = 300
        return config

    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = GPTBigCodeModel(config=config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(len(result.past_key_values), config.n_layer)

    def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
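        # KV-cache consistency check: feeding the extended sequence in one pass and
        # feeding only the new token together with `past_key_values` should yield
        # the same hidden state for that token (up to numerical tolerance).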
        model = GPTBigCodeModel(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
        outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)

        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)

        output, past = outputs.to_tuple()

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
        next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size)

        # append to next input_ids and token_type_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)

        output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_gpt_bigcode_model_attention_mask_past(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
    ):
        model = GPTBigCodeModel(config=config)
        model.to(torch_device)
        model.eval()

        # create attention mask
        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
        half_seq_length = self.seq_length // 2
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
        output, past = model(input_ids, attention_mask=attn_mask).to_tuple()

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens

        # append to next input_ids and attn_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        attn_mask = torch.cat(
            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
            dim=1,
        )

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_gpt_bigcode_model_past_large_inputs(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
    ):
        model = GPTBigCodeModel(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask, use_cache=True)

        output, past = outputs.to_tuple()

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_token_types = ids_tensor([self.batch_size, 3], self.type_vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and token_type_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask
        )["last_hidden_state"]
        output_from_past = model(
            next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past_key_values=past
        )["last_hidden_state"]
        self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = GPTBigCodeForCausalLM(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_forward_and_backwards(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
    ):
        model = GPTBigCodeForCausalLM(config)
        model.to(torch_device)
        if gradient_checkpointing:
            model.gradient_checkpointing_enable()

        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        result.loss.backward()

    def create_and_check_gpt_bigcode_for_sequence_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = GPTBigCodeForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_gpt_bigcode_for_token_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = GPTBigCodeForTokenClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_gpt_bigcode_weight_initialization(self, config, *args):
        model = GPTBigCodeModel(config)
        model_std = model.config.initializer_range / math.sqrt(2 * model.config.n_layer)
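        # Residual projections (`c_proj`) are expected to follow the GPT-2 style
        # scaled initialization: std = initializer_range / sqrt(2 * n_layer).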
        for key in model.state_dict().keys():
            if "c_proj" in key and "weight" in key:
                self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
                self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)

    def prepare_config_and_inputs_for_common(self):
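        # Hook consumed by the generic ModelTesterMixin tests: it must return a
        # config plus an `inputs_dict` whose keys match the model's forward()
        # argument names.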
        config_and_inputs = self.prepare_config_and_inputs()

        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "head_mask": head_mask,
        }

        return config, inputs_dict


@require_torch
class GPTBigCodeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    # TODO: Update the tests to use valid pretrained models.
    all_model_classes = (
        (
            GPTBigCodeModel,
            GPTBigCodeForCausalLM,
            GPTBigCodeForSequenceClassification,
            GPTBigCodeForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (GPTBigCodeForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": GPTBigCodeModel,
            "text-classification": GPTBigCodeForSequenceClassification,
            "text-generation": GPTBigCodeForCausalLM,
            "token-classification": GPTBigCodeForTokenClassification,
            "zero-shot": GPTBigCodeForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    fx_compatible = False
    test_missing_keys = False
    test_pruning = False
    test_torchscript = False
    multi_query = True

    # No model-specific input preparation is needed; defer to the common mixin.
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        return inputs_dict

    def setUp(self):
        self.model_tester = GPTBigCodeModelTester(self, multi_query=self.multi_query)
        self.config_tester = ConfigTester(self, config_class=GPTBigCodeConfig, n_embd=37)

    def tearDown(self):
        import gc

        gc.collect()

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip("MQA models do not support retain_grad")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip("Contrastive search not supported due to non-standard caching mechanism")
    def test_contrastive_generate(self):
        pass

    @unittest.skip("Contrastive search not supported due to non-standard caching mechanism")
    def test_contrastive_generate_dict_outputs_use_cache(self):
        pass

    @unittest.skip("CPU offload seems to be broken for some reason - tiny models keep hitting corner cases")
    def test_cpu_offload(self):
        pass

    @unittest.skip("Disk offload seems to be broken for some reason - tiny models keep hitting corner cases")
    def test_disk_offload(self):
        pass

    @unittest.skip("GPTBigCode has a non-standard KV cache format.")
    def test_past_key_values_format(self):
        pass

    def test_gpt_bigcode_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs)

    def test_gpt_bigcode_model_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_model_past(*config_and_inputs)

    def test_gpt_bigcode_model_att_mask_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_model_attention_mask_past(*config_and_inputs)

    def test_gpt_bigcode_model_past_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_model_past_large_inputs(*config_and_inputs)

    def test_gpt_bigcode_lm_head_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lm_head_model(*config_and_inputs)

    def test_gpt_bigcode_sequence_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_for_sequence_classification(*config_and_inputs)

    def test_gpt_bigcode_token_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_for_token_classification(*config_and_inputs)

    def test_gpt_bigcode_gradient_checkpointing(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)

    def test_gpt_bigcode_scale_attn_by_inverse_layer_idx(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs(scale_attn_by_inverse_layer_idx=True)
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)

    def test_gpt_bigcode_reorder_and_upcast_attn(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs(reorder_and_upcast_attn=True)
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)

    def test_gpt_bigcode_weight_initialization(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_weight_initialization(*config_and_inputs)


@require_torch
class GPTBigCodeMHAModelTest(GPTBigCodeModelTest):
    # `parameterized_class` breaks with mixins, so we use inheritance instead
    multi_query = False


@unittest.skipIf(
    not is_torch_greater_or_equal_than_1_12,
    reason="`GPTBigCode` checkpoints use `PytorchGELUTanh` which requires `torch>=1.12.0`.",
)
@slow
@require_torch
class GPTBigCodeModelLanguageGenerationTest(unittest.TestCase):
    def test_generate_simple(self):
        model = GPTBigCodeForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder").to(torch_device)
        tokenizer = GPT2TokenizerFast.from_pretrained("bigcode/gpt_bigcode-santacoder")

        input_ids = tokenizer("def print_hello_world():", return_tensors="pt").input_ids.to(torch_device)

        output_sequence = model.generate(input_ids)
        output_sentence = tokenizer.decode(output_sequence[0], skip_special_tokens=True)

        expected_output = """def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_"""
        self.assertEqual(output_sentence, expected_output)

    def test_generate_batched(self):
        tokenizer = GPT2TokenizerFast.from_pretrained("bigcode/gpt_bigcode-santacoder")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
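        # Decoder-only models generate from the last token of each prompt, so
        # batched generation requires left padding to keep prompts right-aligned.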

        model = GPTBigCodeForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder").to(torch_device)

        inputs = tokenizer(["def print_hello_world():", "def say_hello():"], return_tensors="pt", padding=True).to(
            torch_device
        )
        outputs = model.generate(**inputs)
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        expected_output = [
            'def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_',
            'def say_hello():\n    print("Hello, World!")\n\n\nsay_hello()',
        ]
        self.assertListEqual(outputs, expected_output)


@require_torch
class GPTBigCodeMQATest(unittest.TestCase):
    def get_attention(self, multi_query):
        config = GPTBigCodeConfig.from_pretrained(
            "bigcode/gpt_bigcode-santacoder",
            multi_query=multi_query,
            attn_pdrop=0,
            resid_pdrop=0,
        )
        return GPTBigCodeAttention(config)

    @parameterized.expand([(seed, is_train_mode) for seed in range(5) for is_train_mode in [True, False]])
    def test_mqa_reduces_to_mha(self, seed, is_train_mode=True):
        torch.manual_seed(seed)

        # CREATE MQA AND MHA ATTENTIONS
        attention_mqa = self.get_attention(True)
        attention_mha = self.get_attention(False)

        # ENFORCE MATCHING WEIGHTS
        num_heads = attention_mqa.num_heads
        embed_dim = attention_mqa.embed_dim
        head_dim = attention_mqa.head_dim
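        # In the MQA layout, `c_attn` packs the per-head query projections in the
        # first `embed_dim` output rows and a single key/value projection shared by
        # all heads in the remaining `2 * head_dim` rows. Expanding that shared KV
        # block across all heads below yields an equivalent MHA `c_attn` of shape
        # (3 * embed_dim, embed_dim).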
598

599
        with torch.no_grad():
600
            mqa_q_weight = attention_mqa.c_attn.weight[:embed_dim, :].view(num_heads, 1, head_dim, embed_dim)
601
            mqa_kv_weight = attention_mqa.c_attn.weight[embed_dim:, :].view(1, 2, head_dim, embed_dim)
602
            mha_c_weight = torch.cat(
603
                [mqa_q_weight, mqa_kv_weight.expand(num_heads, 2, head_dim, embed_dim)], dim=1
604
            ).view(3 * num_heads * head_dim, embed_dim)
605

606
            mqa_q_bias = attention_mqa.c_attn.bias[:embed_dim].view(num_heads, 1, head_dim)
607
            mqa_kv_bias = attention_mqa.c_attn.bias[embed_dim:].view(1, 2, head_dim)
608
            mha_c_bias = torch.cat([mqa_q_bias, mqa_kv_bias.expand(num_heads, 2, head_dim)], dim=1).view(
609
                3 * num_heads * head_dim
610
            )
611

612
            attention_mha.c_attn.weight.copy_(mha_c_weight)
613
            attention_mha.c_attn.bias.copy_(mha_c_bias)
614
            attention_mha.c_proj.weight.copy_(attention_mqa.c_proj.weight)
615
            attention_mha.c_proj.bias.copy_(attention_mqa.c_proj.bias)
616

617
        # PUT THE MODEL INTO THE CORRECT MODE
618
        attention_mha.train(is_train_mode)
619
        attention_mqa.train(is_train_mode)
620

621
        # RUN AN INPUT THROUGH THE MODELS
622
        num_tokens = 5
623
        hidden_states = torch.randn(1, num_tokens, embed_dim)
624
        attention_mha_result = attention_mha(hidden_states)[0]
625
        attention_mqa_result = attention_mqa(hidden_states)[0]
626

627
        # CHECK THAT ALL OUTPUTS ARE THE SAME
628
        self.assertTrue(torch.allclose(attention_mha_result, attention_mqa_result, atol=1e-5))
629
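
# Note: this module lives in the `transformers` test suite (typically at
# tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py; the exact path may differ
# between versions). The `@slow` generation tests above are skipped by default and
# only run when the RUN_SLOW environment variable is set, e.g.:
#   RUN_SLOW=1 python -m pytest tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py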
