# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest

from parameterized import parameterized

from transformers import GPTBigCodeConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        GPT2TokenizerFast,
        GPTBigCodeForCausalLM,
        GPTBigCodeForSequenceClassification,
        GPTBigCodeForTokenClassification,
        GPTBigCodeModel,
    )
    from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention
    from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_12
else:
    is_torch_greater_or_equal_than_1_12 = False


class GPTBigCodeModelTester:
    def __init__(
        self,
        parent,
        batch_size=14,
        seq_length=7,
        is_training=True,
        use_token_type_ids=True,
        use_input_mask=True,
        use_labels=True,
        use_mc_token_ids=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="relu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        multi_query=True,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_token_type_ids = use_token_type_ids
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.use_mc_token_ids = use_mc_token_ids
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = None
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 2
        self.pad_token_id = vocab_size - 3
        self.multi_query = multi_query

    def get_large_model_config(self):
        return GPTBigCodeConfig.from_pretrained("bigcode/gpt_bigcode-santacoder")

    def prepare_config_and_inputs(
        self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
    ):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        mc_token_ids = None
        if self.use_mc_token_ids:
            mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config(
            gradient_checkpointing=gradient_checkpointing,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )

        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(
        self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
    ):
        return GPTBigCodeConfig(
            vocab_size=self.vocab_size,
            n_embd=self.hidden_size,
            n_layer=self.num_hidden_layers,
            n_head=self.num_attention_heads,
            n_inner=self.intermediate_size,
            activation_function=self.hidden_act,
            resid_pdrop=self.hidden_dropout_prob,
            attn_pdrop=self.attention_probs_dropout_prob,
            n_positions=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            use_cache=True,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            gradient_checkpointing=gradient_checkpointing,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
            attention_softmax_in_fp32=False,
            scale_attention_softmax_in_fp32=False,
            multi_query=self.multi_query,
        )

    def get_pipeline_config(self):
        config = self.get_config()
        config.vocab_size = 300
        return config

    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_gpt_bigcode_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = GPTBigCodeModel(config=config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(len(result.past_key_values), config.n_layer)

    def create_and_check_gpt_bigcode_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = GPTBigCodeModel(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
        outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)
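        # The config sets use_cache=True, so the explicit use_cache=True call and the config-default call
        # should return the same number of outputs, while use_cache=False should drop the past_key_values entry.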
        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)

        output, past = outputs.to_tuple()

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
        next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size)

        # append to next input_ids and token_type_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)

        output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_gpt_bigcode_model_attention_mask_past(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
    ):
        model = GPTBigCodeModel(config=config)
        model.to(torch_device)
        model.eval()

        # create attention mask
        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
        half_seq_length = self.seq_length // 2
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
        output, past = model(input_ids, attention_mask=attn_mask).to_tuple()

        # create hypothetical next token and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens

        # append to next input_ids and attn_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        attn_mask = torch.cat(
            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
            dim=1,
        )

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_gpt_bigcode_model_past_large_inputs(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
    ):
        model = GPTBigCodeModel(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask, use_cache=True)

        output, past = outputs.to_tuple()

        # create hypothetical next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_token_types = ids_tensor([self.batch_size, 3], self.type_vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and token_type_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask
        )["last_hidden_state"]
        output_from_past = model(
            next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past_key_values=past
        )["last_hidden_state"]
        self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = GPTBigCodeForCausalLM(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_forward_and_backwards(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
    ):
        model = GPTBigCodeForCausalLM(config)
        model.to(torch_device)
        if gradient_checkpointing:
            model.gradient_checkpointing_enable()

        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        result.loss.backward()

    def create_and_check_gpt_bigcode_for_sequence_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = GPTBigCodeForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_gpt_bigcode_for_token_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = GPTBigCodeForTokenClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
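    # The check below relies on GPT-2-style scaled initialization: residual projection (c_proj) weights
    # are expected to have std ~= initializer_range / sqrt(2 * n_layer) and mean ~= 0.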
    def create_and_check_gpt_bigcode_weight_initialization(self, config, *args):
        model = GPTBigCodeModel(config)
        model_std = model.config.initializer_range / math.sqrt(2 * model.config.n_layer)
        for key in model.state_dict().keys():
            if "c_proj" in key and "weight" in key:
                self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
                self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()

        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "head_mask": head_mask,
        }

        return config, inputs_dict


@require_torch
class GPTBigCodeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    # TODO: Update the tests to use valid pretrained models.
    all_model_classes = (
        (
            GPTBigCodeModel,
            GPTBigCodeForCausalLM,
            GPTBigCodeForSequenceClassification,
            GPTBigCodeForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (GPTBigCodeForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": GPTBigCodeModel,
            "text-classification": GPTBigCodeForSequenceClassification,
            "text-generation": GPTBigCodeForCausalLM,
            "token-classification": GPTBigCodeForTokenClassification,
            "zero-shot": GPTBigCodeForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    fx_compatible = False
    test_missing_keys = False
    test_pruning = False
    test_torchscript = False
    multi_query = True

    # special case for DoubleHeads model
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        return inputs_dict

    def setUp(self):
        self.model_tester = GPTBigCodeModelTester(self, multi_query=self.multi_query)
        self.config_tester = ConfigTester(self, config_class=GPTBigCodeConfig, n_embd=37)

    def tearDown(self):
        import gc

        gc.collect()

    def test_config(self):
        self.config_tester.run_common_tests()
    @unittest.skip("MQA models do not support retain_grad")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip("Contrastive search not supported due to non-standard caching mechanism")
    def test_contrastive_generate(self):
        pass

    @unittest.skip("Contrastive search not supported due to non-standard caching mechanism")
    def test_contrastive_generate_dict_outputs_use_cache(self):
        pass

    @unittest.skip("CPU offload seems to be broken for some reason - tiny models keep hitting corner cases")
    def test_cpu_offload(self):
        pass

    @unittest.skip("Disk offload seems to be broken for some reason - tiny models keep hitting corner cases")
    def test_disk_offload(self):
        pass

    @unittest.skip("GPTBigCode has a non-standard KV cache format.")
    def test_past_key_values_format(self):
        pass
    def test_gpt_bigcode_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs)

    def test_gpt_bigcode_model_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_model_past(*config_and_inputs)

    def test_gpt_bigcode_model_att_mask_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_model_attention_mask_past(*config_and_inputs)

    def test_gpt_bigcode_model_past_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_model_past_large_inputs(*config_and_inputs)

    def test_gpt_bigcode_lm_head_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lm_head_model(*config_and_inputs)

    def test_gpt_bigcode_sequence_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_for_sequence_classification(*config_and_inputs)

    def test_gpt_bigcode_token_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_for_token_classification(*config_and_inputs)

    def test_gpt_bigcode_gradient_checkpointing(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)

    def test_gpt_bigcode_scale_attn_by_inverse_layer_idx(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs(scale_attn_by_inverse_layer_idx=True)
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)

    def test_gpt_bigcode_reorder_and_upcast_attn(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs(reorder_and_upcast_attn=True)
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)

    def test_gpt_bigcode_weight_initialization(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_weight_initialization(*config_and_inputs)


@require_torch
class GPTBigCodeMHAModelTest(GPTBigCodeModelTest):
    # `parameterized_class` breaks with mixins, so we use inheritance instead
    multi_query = False


@unittest.skipIf(
    not is_torch_greater_or_equal_than_1_12,
    reason="`GPTBigCode` checkpoints use `PytorchGELUTanh` which requires `torch>=1.12.0`.",
)
@slow
@require_torch
class GPTBigCodeModelLanguageGenerationTest(unittest.TestCase):
    def test_generate_simple(self):
        model = GPTBigCodeForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder").to(torch_device)
        tokenizer = GPT2TokenizerFast.from_pretrained("bigcode/gpt_bigcode-santacoder")

        input_ids = tokenizer("def print_hello_world():", return_tensors="pt").input_ids.to(torch_device)

        output_sequence = model.generate(input_ids)
        output_sentence = tokenizer.decode(output_sequence[0], skip_special_tokens=True)

        expected_output = """def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_"""
        self.assertEqual(output_sentence, expected_output)

    def test_generate_batched(self):
        tokenizer = GPT2TokenizerFast.from_pretrained("bigcode/gpt_bigcode-santacoder")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        model = GPTBigCodeForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder").to(torch_device)

        inputs = tokenizer(["def print_hello_world():", "def say_hello():"], return_tensors="pt", padding=True).to(
            torch_device
        )
        outputs = model.generate(**inputs)
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        expected_output = [
            'def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_',
            'def say_hello():\n    print("Hello, World!")\n\n\nsay_hello()',
        ]
        self.assertListEqual(outputs, expected_output)


@require_torch
class GPTBigCodeMQATest(unittest.TestCase):
    def get_attention(self, multi_query):
        config = GPTBigCodeConfig.from_pretrained(
            "bigcode/gpt_bigcode-santacoder",
            multi_query=multi_query,
            attn_pdrop=0,
            resid_pdrop=0,
        )
        return GPTBigCodeAttention(config)

    @parameterized.expand([(seed, is_train_mode) for seed in range(5) for is_train_mode in [True, False]])
    def test_mqa_reduces_to_mha(self, seed, is_train_mode=True):
        torch.manual_seed(seed)

        # CREATE MQA AND MHA ATTENTIONS
        attention_mqa = self.get_attention(True)
        attention_mha = self.get_attention(False)

        # ENFORCE MATCHING WEIGHTS
        num_heads = attention_mqa.num_heads
        embed_dim = attention_mqa.embed_dim
        head_dim = attention_mqa.head_dim
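        # In multi-query attention a single key/value head is shared by every query head, so expanding
        # (broadcasting) the shared K/V weights and biases across num_heads below yields an equivalent
        # multi-head parameterization of the same attention function.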
        with torch.no_grad():
            mqa_q_weight = attention_mqa.c_attn.weight[:embed_dim, :].view(num_heads, 1, head_dim, embed_dim)
            mqa_kv_weight = attention_mqa.c_attn.weight[embed_dim:, :].view(1, 2, head_dim, embed_dim)
            mha_c_weight = torch.cat(
                [mqa_q_weight, mqa_kv_weight.expand(num_heads, 2, head_dim, embed_dim)], dim=1
            ).view(3 * num_heads * head_dim, embed_dim)

            mqa_q_bias = attention_mqa.c_attn.bias[:embed_dim].view(num_heads, 1, head_dim)
            mqa_kv_bias = attention_mqa.c_attn.bias[embed_dim:].view(1, 2, head_dim)
            mha_c_bias = torch.cat([mqa_q_bias, mqa_kv_bias.expand(num_heads, 2, head_dim)], dim=1).view(
                3 * num_heads * head_dim
            )

            attention_mha.c_attn.weight.copy_(mha_c_weight)
            attention_mha.c_attn.bias.copy_(mha_c_bias)
            attention_mha.c_proj.weight.copy_(attention_mqa.c_proj.weight)
            attention_mha.c_proj.bias.copy_(attention_mqa.c_proj.bias)

        # PUT THE MODEL INTO THE CORRECT MODE
        attention_mha.train(is_train_mode)
        attention_mqa.train(is_train_mode)

        # RUN AN INPUT THROUGH THE MODELS
        num_tokens = 5
        hidden_states = torch.randn(1, num_tokens, embed_dim)
        attention_mha_result = attention_mha(hidden_states)[0]
        attention_mqa_result = attention_mqa(hidden_states)[0]

        # CHECK THAT ALL OUTPUTS ARE THE SAME
        self.assertTrue(torch.allclose(attention_mha_result, attention_mqa_result, atol=1e-5))