# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import unittest
import warnings

from transformers import AutoTokenizer, MarianConfig, MarianTokenizer, TranslationPipeline, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property

from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_tf_available():
    import tensorflow as tf

    from transformers import TFAutoModelForSeq2SeqLM, TFMarianModel, TFMarianMTModel

@require_tf
class TFMarianModelTester:
    config_cls = MarianConfig
    config_updates = {}
    hidden_act = "gelu"

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id

    def prepare_config_and_inputs_for_common(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
        eos_tensor = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
        input_ids = tf.concat([input_ids, eos_tensor], axis=1)

        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = self.config_cls(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_ids=[2],
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.pad_token_id,
            **self.config_updates,
        )
        inputs_dict = prepare_marian_inputs_dict(config, input_ids, decoder_input_ids)
        return config, inputs_dict
    def check_decoder_model_past_large_inputs(self, config, inputs_dict):
        model = TFMarianModel(config=config).get_decoder()
        input_ids = inputs_dict["input_ids"]

        input_ids = input_ids[:1, :]
        attention_mask = inputs_dict["attention_mask"][:1, :]
        head_mask = inputs_dict["head_mask"]
        self.batch_size = 1

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)

        # append the new tokens to input_ids and attention_mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)[0]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[0]

        self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])

        # select random slice
        random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        # test that outputs are equal for slice
        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)


def prepare_marian_inputs_dict(
    config,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    if attention_mask is None:
        attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
    if decoder_attention_mask is None:
        decoder_attention_mask = tf.concat(
            [
                tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8),
                tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8),
            ],
            axis=-1,
        )
    if head_mask is None:
        head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
    if decoder_head_mask is None:
        decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
    if cross_attn_head_mask is None:
        cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
    return {
        "input_ids": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
    }
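
# Shapes produced by prepare_marian_inputs_dict with the tester defaults
# (batch_size=13, seq_length=7, num_hidden_layers=2, num_attention_heads=4):
#     attention_mask / decoder_attention_mask: (13, 7), zeros at pad positions
#         (the first decoder position is always kept)
#     head_mask / decoder_head_mask / cross_attn_head_mask: (2, 4) all-ones,
#         i.e. (num_layers, num_attention_heads)
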

@require_tf
class TFMarianModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (TFMarianMTModel, TFMarianModel) if is_tf_available() else ()
    all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
    pipeline_model_mapping = (
        {
            "conversational": TFMarianMTModel,
            "feature-extraction": TFMarianModel,
            "summarization": TFMarianMTModel,
            "text2text-generation": TFMarianMTModel,
            "translation": TFMarianMTModel,
        }
        if is_tf_available()
        else {}
    )
    is_encoder_decoder = True
    test_pruning = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFMarianModelTester(self)
        self.config_tester = ConfigTester(self, config_class=MarianConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_decoder_model_past_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)

    @unittest.skip("Skipping for now, to fix @ArthurZ or @ydshieh")
    def test_pipeline_conversational(self):
        pass

@require_tf
class AbstractMarianIntegrationTest(unittest.TestCase):
    maxDiff = 1000  # show more chars for failing integration tests

    @classmethod
    def setUpClass(cls) -> None:
        cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
        return cls

    @cached_property
    def tokenizer(self) -> MarianTokenizer:
        return AutoTokenizer.from_pretrained(self.model_name)

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_token_id

    @cached_property
    def model(self):
        warnings.simplefilter("error")
        model: TFMarianMTModel = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        assert isinstance(model, TFMarianMTModel)
        c = model.config
        self.assertListEqual(c.bad_words_ids, [[c.pad_token_id]])
        self.assertEqual(c.max_length, 512)
        self.assertEqual(c.decoder_start_token_id, c.pad_token_id)
        return model

    def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
        generated_words = self.translate_src_text(**tokenizer_kwargs)
        self.assertListEqual(self.expected_text, generated_words)

    def translate_src_text(self, **tokenizer_kwargs):
        model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, padding=True, return_tensors="tf")
        generated_ids = self.model.generate(
            model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
        )
        generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)
        return generated_words
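
    # A minimal standalone sketch of the path exercised above, kept as a comment so it
    # never runs at import time; the checkpoint name mirrors TestMarian_en_zh below:
    #
    #     tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
    #     model = TFAutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
    #     batch = tokenizer(["My name is Wolfgang and I live in Berlin"], padding=True, return_tensors="tf")
    #     ids = model.generate(batch.input_ids, attention_mask=batch.attention_mask, num_beams=2, max_length=128)
    #     print(tokenizer.batch_decode(ids.numpy(), skip_special_tokens=True))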

@require_sentencepiece
@require_tokenizers
@require_tf
class TestMarian_MT_EN(AbstractMarianIntegrationTest):
    """Cover the low-resource/high-perplexity setting. This breaks if the pad_token_id logits are not set to LARGE_NEGATIVE."""

    src = "mt"
    tgt = "en"
    src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
    expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]

    @unittest.skip("Skipping until #12647 is resolved.")
    @slow
    def test_batch_generation_mt_en(self):
        self._assert_generated_batch_equal_expected()


@require_sentencepiece
@require_tokenizers
@require_tf
class TestMarian_en_zh(AbstractMarianIntegrationTest):
    src = "en"
    tgt = "zh"
    src_text = ["My name is Wolfgang and I live in Berlin"]
    expected_text = ["我叫沃尔夫冈 我住在柏林"]

    @unittest.skip("Skipping until #12647 is resolved.")
    @slow
    def test_batch_generation_en_zh(self):
        self._assert_generated_batch_equal_expected()


@require_sentencepiece
@require_tokenizers
@require_tf
class TestMarian_en_ROMANCE(AbstractMarianIntegrationTest):
    """Multilingual on target side."""
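
    # The ">>fr<<" / ">>pt<<" / ">>es<<" prefixes in src_text below are target-language
    # tokens: the multilingual en-ROMANCE checkpoint uses them to select which Romance
    # language to emit.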
    src = "en"
    tgt = "ROMANCE"
    src_text = [
        ">>fr<< Don't spend so much time watching TV.",
        ">>pt<< Your message has been sent.",
        ">>es<< He's two years older than me.",
    ]
    expected_text = [
        "Ne passez pas autant de temps à regarder la télé.",
        "A sua mensagem foi enviada.",
        "Es dos años más viejo que yo.",
    ]
    @unittest.skip("Skipping until #12647 is resolved.")
    @slow
    def test_batch_generation_en_ROMANCE_multi(self):
        self._assert_generated_batch_equal_expected()

    @unittest.skip("Skipping until #12647 is resolved.")
    @slow
    def test_pipeline(self):
        pipeline = TranslationPipeline(self.model, self.tokenizer, framework="tf")
        output = pipeline(self.src_text)
        self.assertEqual(self.expected_text, [x["translation_text"] for x in output])