# transformers
# 901 lines · 33.7 KB
1# coding=utf-8
2# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15""" Testing suite for the PyTorch Marian model. """
16
17import tempfile18import unittest19
20from huggingface_hub.hf_api import list_models21
22from transformers import MarianConfig, is_torch_available23from transformers.testing_utils import (24require_sentencepiece,25require_tokenizers,26require_torch,27require_torch_fp16,28slow,29torch_device,30)
31from transformers.utils import cached_property32
33from ...generation.test_utils import GenerationTesterMixin34from ...test_configuration_common import ConfigTester35from ...test_modeling_common import ModelTesterMixin, ids_tensor36from ...test_pipeline_mixin import PipelineTesterMixin37
38
39if is_torch_available():40import torch41
42from transformers import (43AutoConfig,44AutoModelWithLMHead,45AutoTokenizer,46MarianModel,47MarianMTModel,48TranslationPipeline,49)50from transformers.models.marian.convert_marian_to_pytorch import (51ORG_NAME,52convert_hf_name_to_opus_name,53convert_opus_name_to_hf_name,54)55from transformers.models.marian.modeling_marian import (56MarianDecoder,57MarianEncoder,58MarianForCausalLM,59shift_tokens_right,60)61
62
def prepare_marian_inputs_dict(
    config,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    """Assemble the keyword arguments for a Marian forward pass.

    Any mask left as ``None`` is derived automatically: attention masks mask out
    padding positions (``!= config.pad_token_id``), head masks default to all-ones
    (no head pruned) on ``torch_device``.
    """
    if attention_mask is None:
        attention_mask = input_ids.ne(config.pad_token_id)
    if decoder_attention_mask is None:
        decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
    if head_mask is None:
        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
    if decoder_head_mask is None:
        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
    if cross_attn_head_mask is None:
        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
    return {
        "input_ids": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        # Fix: this entry previously mapped to `attention_mask`, silently discarding
        # the decoder mask computed (or passed in) above.
        "decoder_attention_mask": decoder_attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
    }
93
class MarianModelTester:
    """Builds tiny Marian configs/inputs and implements the shared checks run by MarianModelTest."""

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
        decoder_start_token_id=3,
    ):
        # `parent` is the unittest.TestCase, used for its assert* helpers.
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.decoder_start_token_id = decoder_start_token_id

        # forcing a certain token to be generated, sets all other tokens to -inf
        # if however the token to be generated is already at -inf then it can lead token
        # `nan` values and thus break generation
        self.forced_bos_token_id = None
        self.forced_eos_token_id = None

    def prepare_config_and_inputs(self):
        # clamp to >= 3 so none of the special ids (bos=0, pad=1, eos=2) appear mid-sequence
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
            3,
        )
        input_ids[:, -1] = self.eos_token_id  # Eos Token

        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = self.get_config()
        inputs_dict = prepare_marian_inputs_dict(config, input_ids, decoder_input_ids)
        return config, inputs_dict

    def get_config(self):
        """Return a tiny MarianConfig built from this tester's hyper-parameters."""
        return MarianConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
            forced_bos_token_id=self.forced_bos_token_id,
            forced_eos_token_id=self.forced_eos_token_id,
        )

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        """Check that feeding cached past_key_values matches a full no-cache forward pass."""
        model = MarianModel(config=config).get_decoder().to(torch_device).eval()
        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]
        head_mask = inputs_dict["head_mask"]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next token and extent to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append to next input_ids and
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
        """Check that the encoder and decoder, saved and reloaded standalone, reproduce
        the full model's hidden states."""
        model = MarianModel(config=config).to(torch_device).eval()
        outputs = model(**inputs_dict)

        encoder_last_hidden_state = outputs.encoder_last_hidden_state
        last_hidden_state = outputs.last_hidden_state

        with tempfile.TemporaryDirectory() as tmpdirname:
            encoder = model.get_encoder()
            encoder.save_pretrained(tmpdirname)
            encoder = MarianEncoder.from_pretrained(tmpdirname).to(torch_device)

        encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
            0
        ]

        self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)

        with tempfile.TemporaryDirectory() as tmpdirname:
            decoder = model.get_decoder()
            decoder.save_pretrained(tmpdirname)
            decoder = MarianDecoder.from_pretrained(tmpdirname).to(torch_device)

        last_hidden_state_2 = decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=encoder_last_hidden_state,
            encoder_attention_mask=inputs_dict["attention_mask"],
        )[0]

        self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
244
@require_torch
class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common model/generation/pipeline test suite specialized for Marian."""

    all_model_classes = (MarianModel, MarianMTModel) if is_torch_available() else ()
    all_generative_model_classes = (MarianMTModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "conversational": MarianMTModel,
            "feature-extraction": MarianModel,
            "summarization": MarianMTModel,
            "text-generation": MarianForCausalLM,
            "text2text-generation": MarianMTModel,
            "translation": MarianMTModel,
        }
        if is_torch_available()
        else {}
    )
    is_encoder_decoder = True
    fx_compatible = True
    test_pruning = False
    test_missing_keys = False

    def setUp(self):
        self.model_tester = MarianModelTester(self)
        self.config_tester = ConfigTester(self, config_class=MarianConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        # Round-trip each model class through save_pretrained and verify no weights go missing.
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_encoder_decoder_model_standalone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)

    @require_torch_fp16
    def test_generate_fp16(self):
        # Smoke test: generation must not crash in half precision.
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        model = MarianMTModel(config).eval().to(torch_device)
        model.half()
        model.generate(input_ids, attention_mask=attention_mask)
        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

    def test_share_encoder_decoder_embeddings(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()

        # check if embeddings are shared by default
        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIs(model.get_encoder().embed_tokens, model.get_decoder().embed_tokens)
            self.assertIs(model.get_encoder().embed_tokens.weight, model.get_decoder().embed_tokens.weight)

        # check if embeddings are not shared when config.share_encoder_decoder_embeddings = False
        config.share_encoder_decoder_embeddings = False
        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsNot(model.get_encoder().embed_tokens, model.get_decoder().embed_tokens)
            self.assertIsNot(model.get_encoder().embed_tokens.weight, model.get_decoder().embed_tokens.weight)

        # check if a model with shared embeddings can be saved and loaded with share_encoder_decoder_embeddings = False
        config, _ = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname, share_encoder_decoder_embeddings=False)
                self.assertIsNot(model.get_encoder().embed_tokens, model.get_decoder().embed_tokens)
                self.assertIsNot(model.get_encoder().embed_tokens.weight, model.get_decoder().embed_tokens.weight)

    def test_resize_decoder_token_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs()

        # check if resize_decoder_token_embeddings raises an error when embeddings are shared
        for model_class in self.all_model_classes:
            model = model_class(config)
            with self.assertRaises(ValueError):
                model.resize_decoder_token_embeddings(config.vocab_size + 1)

        # check if decoder embeddings are resized when config.share_encoder_decoder_embeddings = False
        config.share_encoder_decoder_embeddings = False
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.resize_decoder_token_embeddings(config.vocab_size + 1)
            self.assertEqual(model.get_decoder().embed_tokens.weight.shape, (config.vocab_size + 1, config.d_model))

        # check if lm_head is also resized
        config, _ = self.model_tester.prepare_config_and_inputs()
        config.share_encoder_decoder_embeddings = False
        model = MarianMTModel(config)
        model.resize_decoder_token_embeddings(config.vocab_size + 1)
        self.assertEqual(model.lm_head.weight.shape, (config.vocab_size + 1, config.d_model))

    def test_tie_word_embeddings_decoder(self):
        # Intentionally a no-op: overrides the common-mixin test, which does not apply here.
        pass

    @unittest.skip("Skipping for now, to fix @ArthurZ or @ydshieh")
    def test_pipeline_conversational(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass
375
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
    """If tensors have different shapes, different values or a and b are not both tensors,
    raise a nice AssertionError.

    Returns True when both are None or when `torch.allclose(a, b, atol=atol)` holds;
    otherwise raises AssertionError whose message is optionally prefixed with `prefix`.
    """
    if a is None and b is None:
        return True
    # The original used a bare `raise` with no active exception to jump into the
    # except-block — which only works because CPython raises RuntimeError there.
    # Compute the comparison result explicitly instead.
    try:
        close = torch.allclose(a, b, atol=atol)
    except Exception:
        # shape mismatch or non-tensor inputs: fall through to the diagnostic below
        close = False
    if close:
        return True
    try:
        pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item()
        if a.numel() > 100:
            msg = f"tensor values are {pct_different:.1%} percent different."
        else:
            msg = f"{a} != {b}"
    except Exception:
        # Inputs could not even be subtracted/inspected — still raise AssertionError
        # (the original let the raw error escape here, contradicting the docstring).
        msg = f"{a} != {b}"
    if prefix:
        msg = prefix + ": " + msg
    raise AssertionError(msg)
394
def _long_tensor(tok_lst):
    """Build a LongTensor on the module-level `torch_device` from a (nested) list of token ids."""
    return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
398
class ModelManagementTests(unittest.TestCase):
    """Sanity checks on the Marian checkpoints published on the Hugging Face Hub."""

    @slow
    @require_torch
    def test_model_names(self):
        # All org-owned model ids must use the converted ("-"/"_") naming — no raw
        # opus "+" group separators — and there should be a substantial number of them.
        model_list = list_models()
        model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
        # Fix: this previously tested `"+" in model_ids` (membership in the *list*),
        # which is always False, so no bad id could ever be detected.
        bad_model_ids = [mid for mid in model_ids if "+" in mid]
        self.assertListEqual([], bad_model_ids)
        self.assertGreater(len(model_ids), 500)
409
@require_torch
@require_sentencepiece
@require_tokenizers
class MarianIntegrationTest(unittest.TestCase):
    """Base class for slow translation tests: loads `Helsinki-NLP/opus-mt-{src}-{tgt}`
    and compares batch generation output against `expected_text`. Subclasses override
    `src`, `tgt`, `src_text` and `expected_text` for each language pair."""

    src = "en"
    tgt = "de"
    src_text = [
        "I am a small frog.",
        "Now I can forget the 100 words of german that I know.",
        "Tom asked his teacher for advice.",
        "That's how I would do it.",
        "Tom really admired Mary's courage.",
        "Turn around and close your eyes.",
    ]
    expected_text = [
        "Ich bin ein kleiner Frosch.",
        "Jetzt kann ich die 100 Wörter des Deutschen vergessen, die ich kenne.",
        "Tom bat seinen Lehrer um Rat.",
        "So würde ich das machen.",
        "Tom bewunderte Marias Mut wirklich.",
        "Drehen Sie sich um und schließen Sie die Augen.",
    ]
    # ^^ actual C++ output differs slightly: (1) des Deutschen removed, (2) ""-> "O", (3) tun -> machen

    @classmethod
    def setUpClass(cls) -> None:
        cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
        return cls

    @cached_property
    def tokenizer(self):
        # cached: one tokenizer download/load per test instance
        return AutoTokenizer.from_pretrained(self.model_name)

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_token_id

    @cached_property
    def model(self):
        """Load the checkpoint, sanity-check its config, and (on CUDA) cast to fp16."""
        model: MarianMTModel = AutoModelWithLMHead.from_pretrained(self.model_name).to(torch_device)
        c = model.config
        self.assertListEqual(c.bad_words_ids, [[c.pad_token_id]])
        self.assertEqual(c.max_length, 512)
        self.assertEqual(c.decoder_start_token_id, c.pad_token_id)

        if torch_device == "cuda":
            return model.half()
        else:
            return model

    def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
        generated_words = self.translate_src_text(**tokenizer_kwargs)
        self.assertListEqual(self.expected_text, generated_words)

    def translate_src_text(self, **tokenizer_kwargs):
        """Tokenize `src_text`, run beam-search generation, and return the decoded strings."""
        model_inputs = self.tokenizer(self.src_text, padding=True, return_tensors="pt", **tokenizer_kwargs).to(
            torch_device
        )
        self.assertEqual(self.model.device, model_inputs.input_ids.device)
        generated_ids = self.model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            num_beams=2,
            max_length=128,
            renormalize_logits=True,  # Marian should always renormalize its logits. See #25459
        )
        generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return generated_words
479
@require_sentencepiece
@require_tokenizers
class TestMarian_EN_DE_More(MarianIntegrationTest):
    """Additional en-de checks beyond batch generation: tokenizer ids, unk handling,
    pad handling, raw forward pass, and AutoConfig resolution."""

    @slow
    def test_forward(self):
        src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
        expected_ids = [38, 121, 14, 697, 38848, 0]

        model_inputs = self.tokenizer(src, text_target=tgt, return_tensors="pt").to(torch_device)

        self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())

        desired_keys = {
            "input_ids",
            "attention_mask",
            "labels",
        }
        self.assertSetEqual(desired_keys, set(model_inputs.keys()))
        # Build decoder inputs by shifting the labels right, as the model does internally.
        model_inputs["decoder_input_ids"] = shift_tokens_right(
            model_inputs.labels, self.tokenizer.pad_token_id, self.model.config.decoder_start_token_id
        )
        model_inputs["return_dict"] = True
        model_inputs["use_cache"] = False
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        max_indices = outputs.logits.argmax(-1)
        # Smoke check only: decoding the argmax ids must not raise.
        self.tokenizer.batch_decode(max_indices)

    def test_unk_support(self):
        # "||" is out of vocabulary: both characters should map to the unk token.
        t = self.tokenizer
        ids = t(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist()
        expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
        self.assertEqual(expected, ids)

    def test_pad_not_split(self):
        # A literal "<pad>" in the text must survive as a single pad token.
        input_ids_w_pad = self.tokenizer(["I am a small frog <pad>"], return_tensors="pt").input_ids[0].tolist()
        expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0]  # pad
        self.assertListEqual(expected_w_pad, input_ids_w_pad)

    @slow
    def test_batch_generation_en_de(self):
        self._assert_generated_batch_equal_expected()

    def test_auto_config(self):
        config = AutoConfig.from_pretrained(self.model_name)
        self.assertIsInstance(config, MarianConfig)
527
@require_sentencepiece
@require_tokenizers
class TestMarian_EN_FR(MarianIntegrationTest):
    """English→French batch-generation smoke test (opus-mt-en-fr)."""

    src = "en"
    tgt = "fr"
    src_text = [
        "I am a small frog.",
        "Now I can forget the 100 words of german that I know.",
    ]
    expected_text = [
        "Je suis une petite grenouille.",
        "Maintenant, je peux oublier les 100 mots d'allemand que je connais.",
    ]

    @slow
    def test_batch_generation_en_fr(self):
        self._assert_generated_batch_equal_expected()
546
@require_sentencepiece
@require_tokenizers
class TestMarian_FR_EN(MarianIntegrationTest):
    """French→English batch-generation smoke test (opus-mt-fr-en)."""

    src = "fr"
    tgt = "en"
    src_text = [
        "Donnez moi le micro.",
        "Tom et Mary étaient assis à une table.",  # Accents
    ]
    expected_text = [
        "Give me the microphone.",
        "Tom and Mary were sitting at a table.",
    ]

    @slow
    def test_batch_generation_fr_en(self):
        self._assert_generated_batch_equal_expected()
565
@require_sentencepiece
@require_tokenizers
class TestMarian_RU_FR(MarianIntegrationTest):
    """Russian→French batch-generation smoke test (opus-mt-ru-fr)."""

    src = "ru"
    tgt = "fr"
    src_text = ["Он показал мне рукопись своей новой пьесы."]
    expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."]

    @slow
    def test_batch_generation_ru_fr(self):
        self._assert_generated_batch_equal_expected()
578
@require_sentencepiece
@require_tokenizers
class TestMarian_MT_EN(MarianIntegrationTest):
    """Cover low resource/high perplexity setting. This breaks without adjust_logits_generation overwritten"""

    src = "mt"
    tgt = "en"
    src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
    expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]

    @slow
    def test_batch_generation_mt_en(self):
        self._assert_generated_batch_equal_expected()
593
@require_sentencepiece
@require_tokenizers
class TestMarian_en_zh(MarianIntegrationTest):
    """English→Chinese batch-generation smoke test (opus-mt-en-zh)."""

    src = "en"
    tgt = "zh"
    src_text = ["My name is Wolfgang and I live in Berlin"]
    expected_text = ["我叫沃尔夫冈 我住在柏林"]

    @slow
    def test_batch_generation_eng_zho(self):
        self._assert_generated_batch_equal_expected()
606
@require_sentencepiece
@require_tokenizers
class TestMarian_en_ROMANCE(MarianIntegrationTest):
    """Multilingual on target side."""

    # Target language is selected per-sentence via the ">>xx<<" prefix token.
    src = "en"
    tgt = "ROMANCE"
    src_text = [
        ">>fr<< Don't spend so much time watching TV.",
        ">>pt<< Your message has been sent.",
        ">>es<< He's two years older than me.",
    ]
    expected_text = [
        "Ne passez pas autant de temps à regarder la télé.",
        "A sua mensagem foi enviada.",
        "Es dos años más viejo que yo.",
    ]

    @slow
    def test_batch_generation_en_ROMANCE_multi(self):
        self._assert_generated_batch_equal_expected()

    @slow
    @require_torch
    def test_pipeline(self):
        # Same expectations as direct generation, but going through TranslationPipeline.
        pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt", device=torch_device)
        output = pipeline(self.src_text)
        self.assertEqual(self.expected_text, [x["translation_text"] for x in output])
636
@require_sentencepiece
@require_tokenizers
class TestMarian_FI_EN_V2(MarianIntegrationTest):
    """Finnish→English smoke test against an internal tatoeba v2 checkpoint
    (overrides the default Helsinki-NLP model name)."""

    src = "fi"
    tgt = "en"
    src_text = [
        "minä tykkään kirjojen lukemisesta",
        "Pidän jalkapallon katsomisesta",
    ]
    expected_text = ["I like to read books", "I like watching football"]

    @classmethod
    def setUpClass(cls) -> None:
        cls.model_name = "hf-internal-testing/test-opus-tatoeba-fi-en-v2"
        return cls

    @slow
    def test_batch_generation_fi_en(self):
        self._assert_generated_batch_equal_expected()
657
@require_torch
class TestConversionUtils(unittest.TestCase):
    """Round-trip tests for the opus↔HF checkpoint-name conversion helpers."""

    def test_renaming_multilingual(self):
        # NOTE: the duplicated "opus-mt-en-de" entry is deliberate — it checks that
        # standard names pass through conversion unchanged and deterministically.
        old_names = [
            "opus-mt-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
            "opus-mt-cmn+cn-fi",  # no group
            "opus-mt-en-de",  # standard name
            "opus-mt-en-de",  # standard name
        ]
        expected = ["opus-mt-ZH-fi", "opus-mt-cmn_cn-fi", "opus-mt-en-de", "opus-mt-en-de"]
        self.assertListEqual(expected, [convert_opus_name_to_hf_name(x) for x in old_names])

    def test_undoing_renaming(self):
        hf_names = ["opus-mt-ZH-fi", "opus-mt-cmn_cn-fi", "opus-mt-en-de", "opus-mt-en-de"]
        converted_opus_names = [convert_hf_name_to_opus_name(x) for x in hf_names]
        expected_opus_names = [
            "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
            "cmn+cn-fi",
            "en-de",  # standard name
            "en-de",
        ]
        self.assertListEqual(expected_opus_names, converted_opus_names)
681
class MarianStandaloneDecoderModelTester:
    """Builds tiny configs/inputs for MarianDecoder used standalone (decoder-only),
    and implements the past-key-values cache checks for MarianStandaloneDecoderModelTest."""

    def __init__(
        self,
        parent,
        vocab_size=99,
        batch_size=13,
        d_model=16,
        decoder_seq_length=7,
        is_training=True,
        is_decoder=True,
        use_attention_mask=True,
        use_cache=False,
        use_labels=True,
        decoder_start_token_id=2,
        decoder_ffn_dim=32,
        decoder_layers=2,
        encoder_attention_heads=4,
        decoder_attention_heads=4,
        max_position_embeddings=30,
        is_encoder_decoder=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.decoder_seq_length = decoder_seq_length
        # For common tests
        self.seq_length = self.decoder_seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_labels = use_labels

        self.vocab_size = vocab_size
        self.d_model = d_model
        # aliases expected by the common test mixins
        self.hidden_size = d_model
        self.num_hidden_layers = decoder_layers
        self.decoder_layers = decoder_layers
        self.decoder_ffn_dim = decoder_ffn_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_attention_heads = decoder_attention_heads
        self.num_attention_heads = decoder_attention_heads
        self.eos_token_id = eos_token_id
        self.bos_token_id = bos_token_id
        self.pad_token_id = pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.use_cache = use_cache
        self.max_position_embeddings = max_position_embeddings
        self.is_encoder_decoder = is_encoder_decoder

        self.scope = None
        self.decoder_key_length = decoder_seq_length
        self.base_model_out_len = 2
        self.decoder_attention_idx = 1

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

        attention_mask = None
        if self.use_attention_mask:
            # random 0/1 mask (vocab_size=2 draws values in {0, 1})
            attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)

        lm_labels = None
        if self.use_labels:
            lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)

        config = MarianConfig(
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            decoder_layers=self.decoder_layers,
            decoder_ffn_dim=self.decoder_ffn_dim,
            encoder_attention_heads=self.encoder_attention_heads,
            decoder_attention_heads=self.decoder_attention_heads,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            use_cache=self.use_cache,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
            max_position_embeddings=self.max_position_embeddings,
            is_encoder_decoder=self.is_encoder_decoder,
        )

        return (
            config,
            input_ids,
            attention_mask,
            lm_labels,
        )

    def create_and_check_decoder_model_past(
        self,
        config,
        input_ids,
        attention_mask,
        lm_labels,
    ):
        """Check that a one-token step with past_key_values matches the full forward pass."""
        config.use_cache = True
        model = MarianDecoder(config=config).to(torch_device).eval()
        # first forward pass
        outputs = model(input_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids)
        outputs_no_past = model(input_ids, use_cache=False)

        # with use_cache=True (explicitly or via config) the outputs carry one extra entry
        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)

        past_key_values = outputs["past_key_values"]

        # create hypothetical next token and extent to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # append to next input_ids and
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

        output_from_no_past = model(next_input_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)

    def create_and_check_decoder_model_attention_mask_past(
        self,
        config,
        input_ids,
        attention_mask,
        lm_labels,
    ):
        """Same cache-consistency check, but with half the sequence masked out; masked
        positions are mutated afterwards to prove they cannot influence the output."""
        model = MarianDecoder(config=config).to(torch_device).eval()

        # create attention mask
        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)

        half_seq_length = input_ids.shape[-1] // 2
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
        past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]

        # create hypothetical next token and extent to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens

        # append to next input_ids and attn_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        attn_mask = torch.cat(
            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
            dim=1,
        )

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            attention_mask,
            lm_labels,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict
870
@require_torch
class MarianStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """Common model/generation tests for MarianDecoder/MarianForCausalLM used standalone."""

    all_model_classes = (MarianDecoder, MarianForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (MarianForCausalLM,) if is_torch_available() else ()
    test_pruning = False
    is_encoder_decoder = False

    def setUp(
        self,
    ):
        self.model_tester = MarianStandaloneDecoderModelTester(self, is_training=False)
        self.config_tester = ConfigTester(self, config_class=MarianConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_decoder_model_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)

    def test_decoder_model_attn_mask_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)

    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        return

    @unittest.skip("The model doesn't support left padding")  # and it's not used enough to be worth fixing :)
    def test_left_padding_compatibility(self):
        pass