# transformers — test_modeling_nllb_moe.py (573 lines · 33.3 KB)
1# coding=utf-8
2# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15""" Testing suite for the PyTorch NLLB-MoE model. """
16
17
18import copy19import tempfile20import unittest21
22from transformers import NllbMoeConfig, is_torch_available, set_seed23from transformers.testing_utils import (24require_sentencepiece,25require_tokenizers,26require_torch,27require_torch_fp16,28slow,29torch_device,30)
31from transformers.utils import cached_property32
33from ...generation.test_utils import GenerationTesterMixin34from ...test_configuration_common import ConfigTester35from ...test_modeling_common import ModelTesterMixin, ids_tensor36from ...test_pipeline_mixin import PipelineTesterMixin37
38
39if is_torch_available():40import torch41
42from transformers import NllbMoeForConditionalGeneration, NllbMoeModel, NllbTokenizer43from transformers.models.nllb_moe.modeling_nllb_moe import NllbMoeDecoder, NllbMoeEncoder, NllbMoeTop2Router44
45
class NllbMoeModelTester:
    """
    Builds a tiny random `NllbMoeConfig` plus matching encoder/decoder inputs and
    implements the model checks shared by `NllbMoeModelTest`.

    All sizes default to small values so the tests run quickly on CPU.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="relu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        max_position_embeddings=20,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
        num_experts=4,
        encoder_sparse_step=2,
        decoder_sparse_step=1,
        expert_capacity=100,
        router_jitter_noise=0.0,
    ):
        # `parent` is the unittest.TestCase that owns this tester; assertions are
        # delegated to it (self.parent.assertTrue, ...).
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.encoder_sparse_step = encoder_sparse_step
        self.decoder_sparse_step = decoder_sparse_step
        self.expert_capacity = expert_capacity
        self.router_jitter_noise = router_jitter_noise
        self.num_experts = num_experts

    def prepare_nllb_moe_inputs_dict(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
    ):
        """Fill any missing masks with sensible defaults and return the full model inputs dict."""
        if attention_mask is None:
            attention_mask = input_ids.ne(config.pad_token_id)
        if decoder_attention_mask is None:
            decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
        if head_mask is None:
            head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
        if decoder_head_mask is None:
            decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
        if cross_attn_head_mask is None:
            cross_attn_head_mask = torch.ones(
                config.decoder_layers, config.decoder_attention_heads, device=torch_device
            )
        return {
            "input_ids": input_ids,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            # BUG FIX: the encoder `attention_mask` was previously passed here, so the
            # `decoder_attention_mask` computed above was silently discarded.
            "decoder_attention_mask": decoder_attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
        }

    def prepare_config_and_inputs(self):
        """Return a tiny config and a matching inputs dict with pad-free token ids."""
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_ids[:, -1] = self.eos_token_id  # Eos Token
        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        # we need to clamp the input ids here to avoid having pad token in between
        # this is because for NllbMoe the position_ids are prepared such that
        # all pad tokens have pos id = 2 and rest are between 2..seq_length
        # and the seq_length here is seq_length - num_pad_tokens
        # but when using past, there is no way of knowing if the past input ids had
        # pad tokens in them, which results in incorrect seq_length and which in turn results in
        # position_ids being off by num_pad_tokens in past input
        input_ids = input_ids.clamp(self.pad_token_id + 1)
        decoder_input_ids = decoder_input_ids.clamp(self.pad_token_id + 1)

        config = self.get_config()
        inputs_dict = self.prepare_nllb_moe_inputs_dict(config, input_ids, decoder_input_ids)
        return config, inputs_dict

    def get_config(self):
        """Map the tester's generic hyper-parameter names onto `NllbMoeConfig` fields."""
        return NllbMoeConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            encoder_layerdrop=self.encoder_layerdrop,
            decoder_layerdrop=self.decoder_layerdrop,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            expert_capacity=self.expert_capacity,
            router_jitter_noise=self.router_jitter_noise,
            decoder_sparse_step=self.decoder_sparse_step,
            encoder_sparse_step=self.encoder_sparse_step,
            num_experts=self.num_experts,
        )

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    @require_torch
    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        """Check that decoding with a KV cache matches decoding the full sequence from scratch."""
        model = NllbMoeModel(config=config).get_decoder().to(torch_device).eval()
        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]
        head_mask = inputs_dict["head_mask"]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next token and extent to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append to next input_ids and
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
        """Check that the encoder and decoder can be saved/reloaded standalone and still
        reproduce the full model's hidden states."""
        model = NllbMoeModel(config=config).to(torch_device).eval()
        outputs = model(**inputs_dict)

        encoder_last_hidden_state = outputs.encoder_last_hidden_state
        last_hidden_state = outputs.last_hidden_state

        with tempfile.TemporaryDirectory() as tmpdirname:
            encoder = model.get_encoder()
            encoder.save_pretrained(tmpdirname)
            encoder = NllbMoeEncoder.from_pretrained(tmpdirname).to(torch_device)

        encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
            0
        ]

        self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)

        with tempfile.TemporaryDirectory() as tmpdirname:
            decoder = model.get_decoder()
            decoder.save_pretrained(tmpdirname)
            decoder = NllbMoeDecoder.from_pretrained(tmpdirname).to(torch_device)

        last_hidden_state_2 = decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=encoder_last_hidden_state,
            encoder_attention_mask=inputs_dict["attention_mask"],
        )[0]

        self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
247
@require_torch
class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Standard model-tester suite (save/load, cache, embeddings, fp16, loss) for NLLB-MoE."""

    all_model_classes = (NllbMoeModel, NllbMoeForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (NllbMoeForConditionalGeneration,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "conversational": NllbMoeForConditionalGeneration,
            "feature-extraction": NllbMoeModel,
            "summarization": NllbMoeForConditionalGeneration,
            "text2text-generation": NllbMoeForConditionalGeneration,
            "translation": NllbMoeForConditionalGeneration,
        }
        if is_torch_available()
        else {}
    )
    is_encoder_decoder = True
    fx_compatible = False
    test_pruning = False
    test_missing_keys = True
    test_torchscript = False

    # TODO: Fix the failed tests when this model gets more usage
    def is_pipeline_test_to_skip(
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
    ):
        # Saving the slow tokenizer after saving the fast tokenizer causes the loading of the later hanging forever.
        return True

    def setUp(self):
        self.model_tester = NllbMoeModelTester(self)
        self.config_tester = ConfigTester(self, config_class=NllbMoeConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        """Round-trip every model class through save_pretrained and assert no missing keys."""
        config, _ = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmp_dir:
                model.save_pretrained(tmp_dir)
                _, loading_info = model_class.from_pretrained(tmp_dir, output_loading_info=True)
                self.assertEqual(loading_info["missing_keys"], [])

    def test_decoder_model_past_with_large_inputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        # force a fully dense decoder for this check
        config.decoder_sparse_step = 0
        self.model_tester.create_and_check_decoder_model_past_large_inputs(config, inputs_dict)

    def test_encoder_decoder_model_standalone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)

    def test_inputs_embeds(self):
        """The model must accept `inputs_embeds` / `decoder_inputs_embeds` in place of token ids."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in (NllbMoeModel, NllbMoeForConditionalGeneration):
            model = model_class(config).to(torch_device).eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
            embed_tokens = model.get_input_embeddings()

            if self.is_encoder_decoder:
                encoder_ids = inputs.pop("input_ids")
                decoder_ids = inputs.pop("decoder_input_ids", encoder_ids)
                inputs["inputs_embeds"] = embed_tokens(encoder_ids)
                inputs["decoder_inputs_embeds"] = embed_tokens(decoder_ids)
            else:
                inputs["inputs_embeds"] = embed_tokens(inputs.pop("input_ids"))

            with torch.no_grad():
                model(**inputs)[0]

    @require_torch_fp16
    def test_generate_fp16(self):
        config, inputs = self.model_tester.prepare_config_and_inputs()
        input_ids = inputs["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        model = NllbMoeForConditionalGeneration(config).eval().to(torch_device)
        model.half()
        # greedy generation from explicit inputs
        model.generate(input_ids, attention_mask=attention_mask)
        # sampled beam search starting from the model's default decoder prompt
        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

    def test_get_loss(self):
        """With labels and router logits requested, the forward pass must return a loss
        and router logits for both encoder and decoder."""
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        input_dict["output_router_logits"] = True
        input_dict["labels"] = input_dict["input_ids"]
        model = NllbMoeForConditionalGeneration(config).eval().to(torch_device)
        outputs = model(**input_dict)
        self.assertIsNotNone(outputs.loss)
        self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
        self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])
351
@require_torch
@require_sentencepiece
@require_tokenizers
@slow
class NllbMoeModelIntegrationTests(unittest.TestCase):
    """
    Slow integration tests comparing `transformers` outputs against reference values
    from the original fairseq implementation. Uses a tiny random 2-expert checkpoint
    for the fast checks and the real 54B checkpoint (skipped by default, ~300GB of
    RAM) for the large ones.
    """

    @cached_property
    def model_inputs(self):
        """Static batch of two tokenized sentences; the second row is right-padded."""
        return {
            "input_ids": torch.LongTensor(
                [
                    [28768, 248, 6399, 9, 65972, 452, 1925, 629, 123543, 248075, 2, 256047],
                    [117, 7027, 7195, 202, 44778, 248075, 2, 256047, 1, 1, 1, 1],
                ]
            ),
            "attention_mask": torch.Tensor(
                [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]
            ),
            "decoder_input_ids": torch.LongTensor([[2, 256057], [2, 256057]]),
        }

    @cached_property
    def tokenizer(self):
        return NllbTokenizer.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts")

    @cached_property
    def big_model(self):
        return NllbMoeForConditionalGeneration.from_pretrained("facebook/nllb-moe-54b")

    def test_inference_no_head(self):
        # BUG FIX: this method was named `inference_no_head`, so unittest never
        # collected it and the check silently never ran.
        model = NllbMoeModel.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts").eval()
        with torch.no_grad():
            output = model(**self.model_inputs)
        # fmt: off
        EXPECTED_ENCODER_STATE = torch.Tensor([ 0.3920, -0.1974, -0.0279, 0.3463, -0.8306, -1.0629, -0.4643, 2.0563, 1.1123, 0.3566, -0.9291, -0.3840, -0.2527, -0.9858, 1.5185, -1.1346, 0.0323, -0.9103, -0.3647, -0.4462, -0.9720, -0.3541, 0.1777, -0.4647, 1.6970, -0.9062, 0.2727, -1.0737, 0.8785, 0.4324])
        EXPECTED_DECODER_STATE = torch.Tensor([-6.0425e-02, -2.0015e-01, 6.0575e-02, -8.6366e-01, -1.1310e+00, 6.8369e-01, 7.5615e-01, 7.3555e-01, 2.3071e-01, 1.5954e+00, -7.0728e-01, -2.2647e-01, -1.3292e+00, 4.8246e-01, -6.9153e-01, -1.8199e-02, -7.3664e-01, 1.5902e-03, 1.0760e-01, 1.0298e-01, -9.3933e-01, -4.6567e-01, 8.0417e-01, 1.5243e+00, 5.5844e-01, -9.9239e-02, 1.4885e+00, 7.1527e-02, -5.2612e-01, 9.4435e-02])
        # fmt: on

        torch.testing.assert_allclose(
            output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3
        )
        torch.testing.assert_allclose(
            output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3
        )

    def test_inference_logits(self):
        r"""
        Logits testing to check implementation consistency between `fairseq` implementation
        and `transformers` implementation of NLLB-MoE transformers. We only check the logits
        of the second sample of the batch, as it is padded.
        """
        model = NllbMoeForConditionalGeneration.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts").eval()
        with torch.no_grad():
            output = model(**self.model_inputs)

        EXPECTED_LOGITS = torch.Tensor([-0.3059, 0.0000, 9.3029, 0.6456, -0.9148, 1.7836, 0.6478, 0.9438, -0.5272, -0.6617, -1.2717, 0.4564, 0.1345, -0.2301, -1.0140, 1.1427, -1.5535, 0.1337, 0.2082, -0.8112, -0.3842, -0.3377, 0.1256, 0.6450, -0.0452, 0.0219, 1.4274, -0.4991, -0.2063, -0.4409,])  # fmt: skip
        torch.testing.assert_allclose(output.logits[1, 0, :30], EXPECTED_LOGITS, rtol=6e-3, atol=9e-3)

    @unittest.skip("This requires 300GB of RAM")
    def test_large_logits(self):
        model = self.big_model
        with torch.no_grad():
            output = model(**self.model_inputs)

        # fmt: off
        EXPECTED_ENCODER_STATE = torch.Tensor([ 0.1696, -0.0059, 0.0489, 0.0479, -0.4222, -0.2178, -0.1372, -0.0860, -0.4249, -0.0081, -0.1186, 0.6678, 0.0160, 0.4140, 0.1799, 0.0672, -0.4941, 0.0173, -0.0740, 0.0845, -0.2197, 0.4465, 0.2268, -0.1752, -0.0562, 0.1033, -0.0869, -0.5490, 0.0582, 0.2165])
        EXPECTED_DECODER_STATE = torch.Tensor([ 0.0374, -0.1055, -0.1060, -0.1711, -0.0540, -0.1183, -0.0779, 0.0610, -0.0279, -0.0848, 0.0222, 0.0372, -0.0298, -0.0861, -0.0354, -0.0103, 0.0538, -0.0148, -0.0105, 0.0224, 0.0629, -0.0291, -0.0671, 0.0173, -0.0066, -0.0245, -0.0499, 0.0760, -0.0067, 0.0086])
        EXPECTED_LOGITS = torch.Tensor([ 0.3834, 0.2057, 4.5399, 0.8301, 0.4810, 0.9325, 0.9928, 0.9574, 0.5517, 0.9156, 0.2698, 0.6728, 0.7121, 0.3080, 0.4693, 0.5756, 1.0407, 0.2219, 0.3714, 0.5699, 0.5547, 0.8472, 0.3178, 0.1286, 0.1791, 0.9391, 0.5153, -0.2146, 0.1689, 0.6816])
        # fmt: on

        torch.testing.assert_allclose(
            output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3
        )
        torch.testing.assert_allclose(
            output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3
        )
        torch.testing.assert_allclose(output.logits[1, 0, :30], EXPECTED_LOGITS, rtol=6e-3, atol=9e-3)

    @unittest.skip("This requires 300GB of RAM")
    def test_seq_to_seq_generation(self):
        model = self.big_model
        tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-moe-54b")

        # first 6 samples of load_dataset("facebook/flores", "eng_Latn-fra_Latn"), devtest. Truth are very similar to the fairseq translation files
        FIRST_6_FLORES_200 = [
            'We now have 4-month-old mice that are non-diabetic that used to be diabetic," he added.',
            "Dr. Ehud Ur, professor of medicine at Dalhousie University in Halifax, Nova Scotia and chair of the clinical and scientific division of the Canadian Diabetes Association cautioned that the research is still in its early days.",
            "Like some other experts, he is skeptical about whether diabetes can be cured, noting that these findings have no relevance to people who already have Type 1 diabetes.",
            "On Monday, Sara Danius, permanent secretary of the Nobel Committee for Literature at the Swedish Academy, publicly announced during a radio program on Sveriges Radio in Sweden the committee, unable to reach Bob Dylan directly about winning the 2016 Nobel Prize in Literature, had abandoned its efforts to reach him.",
            'Danius said, "Right now we are doing nothing. I have called and sent emails to his closest collaborator and received very friendly replies. For now, that is certainly enough."',
            "Previously, Ring's CEO, Jamie Siminoff, remarked the company started when his doorbell wasn't audible from his shop in his garage.",
        ]
        inputs = tokenizer(FIRST_6_FLORES_200, padding=True, return_tensors="pt").to(torch_device)
        batch_translation = model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id["fra_Latn"])

        EXPECTED_FAIRSEQ_TRANSLATION = [
            '"Nous avons maintenant des souris de 4 mois non diabétiques qui étaient diabétiques", a-t-il ajouté.',
            "Le docteur Ehud Ur, professeur de médecine à l'université Dalhousie, à Halifax, en Nouvelle-Écosse, et président de la division clinique et scientifique de l'Association canadienne du diabète, prévient que la recherche n'en est qu'à ses débuts.",
            "Comme d'autres spécialistes, il est sceptique quant à la guérison du diabète.",
            "Lundi, Sara Danius, secrétaire permanente du Comité Nobel de littérature à l'Académie suédoise, a annoncé publiquement lors d'une émission de radio sur Sveriges Radio en Suède que le comité, incapable de joindre Bob Dylan directement pour lui annoncer le prix Nobel de littérature 2016, avait abandonné ses efforts pour le joindre.",
            "Danius a déclaré: \"Pour l'instant, nous ne faisons rien. J'ai appelé et envoyé des courriels à son plus proche collaborateur et j'ai reçu des réponses très amicales. Pour l'instant, c'est certainement suffisant\".",
            "Auparavant, le PDG de Ring, Jamie Siminoff, a fait remarquer que la société avait commencé lorsque sa sonnette n'était pas audible depuis son magasin dans son garage.",
        ]

        translation = tokenizer.batch_decode(
            batch_translation.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
        )
        assert translation == EXPECTED_FAIRSEQ_TRANSLATION
461
@require_torch
class NllbMoeRouterTest(unittest.TestCase):
    r"""
    NLLB-MoE replaces some dense FFN blocks with sparse Mixture-of-Experts layers
    driven by `NllbMoeTop2Router`. These tests exercise the routing logic in
    isolation — top-2 token-to-expert assignment, batch-prioritized routing, and
    the different second-expert policies — and compare the results against values
    obtained from the original fairseq implementation.
    """

    # Tiny router config shared by the tests below (4 experts, hidden size 32).
    config = NllbMoeConfig(
        num_experts=4,
        hidden_size=32,
        d_ff=16,
        expert_capacity=4,
    )
    batch_size = 2
    sequence_length = 20

    def test_top_2_routing(self):
        # test routing with minimal reproduction
        # Padding mask: the first token of each sequence is marked as padding,
        # then flattened to (batch * seq,) as the router expects.
        mask = torch.ones((self.batch_size, self.sequence_length), dtype=torch.bool)
        mask[0][0] = False
        mask[1][0] = False
        mask = mask.reshape(-1)
        set_seed(0)
        hidden_states = torch.rand((self.batch_size, self.sequence_length, self.config.hidden_size))
        # NOTE(review): "classfier" is a typo for "classifier" — kept as-is to preserve the code.
        classfier = torch.nn.Linear(self.config.hidden_size, self.config.num_experts)
        hf_router = NllbMoeTop2Router(self.config)

        _, _, hidden_dim = hidden_states.shape
        logits = classfier(hidden_states.reshape((self.batch_size * self.sequence_length), hidden_dim))
        top_1_mask, router_probs = hf_router.route_tokens(logits, padding_mask=mask)
        torch.argmax(top_1_mask, dim=-1)  # result unused; kept for parity with the original
        router_mask = router_probs.bool()
        # Re-seed so the expert weights below are reproducible independently of the
        # tensors drawn above — statement order matters for the RNG stream.
        set_seed(0)
        experts = [
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.Linear(hidden_dim, hidden_dim),
        ]
        hidden_states = hidden_states.reshape((self.batch_size * self.sequence_length), hidden_dim)
        # Scatter each token's hidden state into its expert's slot: (experts, tokens, hidden).
        masked_hidden_states = torch.einsum("bm,be->ebm", hidden_states, router_mask)
        for idx, expert in enumerate(experts):
            token_indices = router_mask[:, idx]
            combining_weights = router_probs[token_indices, idx]
            expert_output = expert(masked_hidden_states[idx, token_indices])
            # Scale by (1 - moe_token_dropout); `moe_token_dropout` comes from the
            # NllbMoeConfig default — presumably mirroring eval-time MoE scaling. TODO confirm.
            expert_output *= 1 - self.config.moe_token_dropout
            masked_hidden_states[idx, token_indices] = torch.einsum("b,be->be", combining_weights, expert_output)
        # Combine the (weighted) expert outputs back into a (batch, seq, hidden) tensor.
        hidden_states = masked_hidden_states.sum(dim=0).reshape(self.batch_size, self.sequence_length, hidden_dim)

        EXPECTED_MEAN_FAIRSEQ_HIDDEN_STATES = torch.Tensor([[ 7.0340e-04, 2.7997e-03, -1.3351e-02, -7.6705e-03, -3.5089e-03, 3.9773e-03, 7.4593e-03, 1.2566e-02, 3.5860e-03, -2.7448e-02, -1.3731e-02, -1.0534e-02, -1.3606e-02, -1.5048e-02, -2.8914e-03, -5.0371e-03, -1.3963e-03, 6.0076e-03, -1.1380e-02, -1.4620e-02, 5.2401e-03, 8.4660e-04, -1.5319e-03, -1.6735e-02, 1.1302e-02, 3.6119e-03, 4.6084e-03, -1.3458e-02, 7.7792e-05, 1.4312e-02, 4.9107e-03, -5.0936e-03], [-4.4538e-03, 3.1026e-03, 1.4121e-04, -4.8121e-03, -5.6279e-03, 7.2493e-03, 3.9769e-03, 1.1114e-02, -1.5666e-03, -2.3477e-02, 8.7268e-03, 1.3446e-02, -2.8845e-05, -1.7287e-02, 8.7619e-03, -4.5316e-03, -1.2164e-02, 5.7461e-03, -4.5861e-03, -9.3907e-03, 2.9808e-02, 8.9206e-04, -7.6232e-04, -1.4173e-02, 3.0208e-03, 1.5310e-02, 9.7717e-03, 3.1014e-03, 7.8042e-03, 8.0197e-03, 3.4784e-03, -7.1728e-03]])  # fmt: skip
        # third positional argument of torch.allclose is rtol
        self.assertTrue(torch.allclose(hidden_states.mean(1), EXPECTED_MEAN_FAIRSEQ_HIDDEN_STATES, 1e-4))

    def test_batch_prioritized_routing(self):
        set_seed(0)
        config = NllbMoeConfig(
            num_experts=4, hidden_size=32, d_ff=16, expert_capacity=4, second_expert_policy="random"
        )
        mask = torch.zeros((self.batch_size * self.sequence_length), dtype=torch.bool)
        logits = torch.rand((self.batch_size * self.sequence_length, 4))
        config.batch_prioritized_routing = True
        router = NllbMoeTop2Router(config)
        top_1_mask, _ = router.route_tokens(logits, padding_mask=mask)
        # check that the routing is batch first. One of the last token is routed while expert capacity is very small
        # this means that it had a greater probability of being routed
        assert top_1_mask[-1, 0] == 1

    def test_second_expert_policy(self):
        """Compare the router probabilities / top-1 masks produced by the three
        second-expert policies ("random", "sampling", "all") on identical logits."""
        config = NllbMoeConfig(
            num_experts=4,
            hidden_size=32,
            d_ff=16,
            expert_capacity=40,
        )
        set_seed(0)
        mask = torch.zeros((self.batch_size * self.sequence_length), dtype=torch.bool)
        logits = torch.rand((self.batch_size * self.sequence_length, 4))

        # Each policy is run from the same seed so any stochastic choices are comparable.
        set_seed(0)
        config.second_expert_policy = "random"
        router = NllbMoeTop2Router(config)
        top_1_mask, router_probs = router.route_tokens(logits, padding_mask=mask)

        set_seed(0)
        config.second_expert_policy = "sampling"
        router = NllbMoeTop2Router(config)
        top_1_mask_sp, router_probs_sp = router.route_tokens(logits, padding_mask=mask)

        set_seed(0)
        config.second_expert_policy = "all"
        router = NllbMoeTop2Router(config)
        top_1_mask_all, router_probs_all = router.route_tokens(logits, padding_mask=mask)

        # fmt: off
        EXPECTED_ROUTER_ALL = torch.tensor([[0.3902, 0.0000, 0.0000, 0.6098], [0.0000, 0.0000, 0.7770, 0.2230], [0.0000, 0.0000, 0.2726, 0.7274], [0.4221, 0.0000, 0.5779, 0.0000], [0.0000, 0.0000, 0.7810, 0.2190], [0.5518, 0.4482, 0.0000, 0.0000], [0.0000, 0.4060, 0.5940, 0.0000], [0.7340, 0.0000, 0.0000, 0.2660], [0.4778, 0.5222, 0.0000, 0.0000], [0.0000, 0.3984, 0.0000, 0.6016], [0.0000, 0.0548, 0.9452, 0.0000], [0.6796, 0.0000, 0.0000, 0.3204], [0.0700, 0.0000, 0.9300, 0.0000], [0.1854, 0.0000, 0.8146, 0.0000], [0.6775, 0.3225, 0.0000, 0.0000], [0.0000, 0.0000, 0.5027, 0.4973], [0.0000, 0.6577, 0.0000, 0.3423], [0.0000, 0.7767, 0.0000, 0.2233], [0.1944, 0.8056, 0.0000, 0.0000], [0.0000, 0.3073, 0.0000, 0.6927], [0.0000, 0.5655, 0.4345, 0.0000], [0.5791, 0.0000, 0.0000, 0.4209], [0.0440, 0.0000, 0.9560, 0.0000], [0.0083, 0.9917, 0.0000, 0.0000], [0.0000, 0.8395, 0.0000, 0.1605], [0.0000, 0.1458, 0.0000, 0.8542], [0.0000, 0.8534, 0.1466, 0.0000], [0.4938, 0.0000, 0.0000, 0.5062], [0.1329, 0.8671, 0.0000, 0.0000], [0.3058, 0.0000, 0.6942, 0.0000], [0.4458, 0.0000, 0.0000, 0.5542], [0.9053, 0.0947, 0.0000, 0.0000], [0.0000, 0.7563, 0.2437, 0.0000], [0.0000, 0.0000, 0.4096, 0.5904], [0.4551, 0.0000, 0.0000, 0.5449], [0.8502, 0.1498, 0.0000, 0.0000], [0.0000, 0.6312, 0.3688, 0.0000], [0.8920, 0.0000, 0.0000, 0.1080], [0.1913, 0.0000, 0.0000, 0.8087], [0.2491, 0.7509, 0.0000, 0.0000]])
        EXPECTED_ROUTER_SP = torch.tensor([[0.0000, 0.6539, 0.0000, 0.3461], [0.0000, 0.0000, 0.3998, 0.6002], [0.0000, 0.5574, 0.0000, 0.4426], [0.0000, 0.0000, 0.4441, 0.5559], [0.0000, 0.6545, 0.3455, 0.0000], [0.4419, 0.5581, 0.0000, 0.0000], [0.0000, 0.4014, 0.5986, 0.0000], [0.3215, 0.0000, 0.0000, 0.6785], [0.4765, 0.5235, 0.0000, 0.0000], [0.0000, 0.5467, 0.0000, 0.4533], [0.0000, 0.4156, 0.5844, 0.0000], [0.3370, 0.0000, 0.6630, 0.0000], [0.0000, 0.0000, 0.4558, 0.5442], [0.4659, 0.0000, 0.5341, 0.0000], [0.6179, 0.3821, 0.0000, 0.0000], [0.6277, 0.0000, 0.3723, 0.0000], [0.5836, 0.4164, 0.0000, 0.0000], [0.0000, 0.6600, 0.0000, 0.3400], [0.0000, 0.4933, 0.0000, 0.5067], [0.6016, 0.0000, 0.0000, 0.3984], [0.0000, 0.5160, 0.4840, 0.0000], [0.5799, 0.0000, 0.0000, 0.4201], [0.0000, 0.0000, 0.4826, 0.5174], [0.5426, 0.4574, 0.0000, 0.0000], [0.5362, 0.4638, 0.0000, 0.0000], [0.6448, 0.0000, 0.0000, 0.3552], [0.0000, 0.5909, 0.4091, 0.0000], [0.4196, 0.0000, 0.0000, 0.5804], [0.3191, 0.6809, 0.0000, 0.0000], [0.0000, 0.0000, 0.4886, 0.5114], [0.4899, 0.0000, 0.0000, 0.5101], [0.4123, 0.0000, 0.5877, 0.0000], [0.0000, 0.3736, 0.0000, 0.6264], [0.0000, 0.0000, 0.6009, 0.3991], [0.4246, 0.0000, 0.0000, 0.5754], [0.4997, 0.0000, 0.5003, 0.0000], [0.0000, 0.3595, 0.6405, 0.0000], [0.5433, 0.0000, 0.0000, 0.4567], [0.0000, 0.6806, 0.0000, 0.3194], [0.6689, 0.3311, 0.0000, 0.0000]])
        EXPECTED_ROUTER = torch.tensor([[0.4324, 0.5676, 0.0000, 0.0000], [0.0000, 0.4348, 0.0000, 0.5652], [0.4559, 0.5441, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.4744, 0.5256, 0.0000, 0.0000], [0.0000, 0.5103, 0.0000, 0.4897], [0.0000, 0.0000, 1.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.0000, 1.0000, 0.0000, 0.0000], [0.0000, 0.5467, 0.0000, 0.4533], [0.0000, 0.0000, 1.0000, 0.0000], [0.0000, 0.0000, 1.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.0000, 0.0000, 1.0000, 0.0000], [1.0000, 0.0000, 0.0000, 0.0000], [0.5063, 0.4937, 0.0000, 0.0000], [0.5396, 0.0000, 0.0000, 0.4604], [0.4576, 0.5424, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.5134, 0.0000, 0.4866, 0.0000], [0.0000, 0.5160, 0.4840, 0.0000], [0.5439, 0.0000, 0.4561, 0.0000], [0.4849, 0.0000, 0.0000, 0.5151], [0.5426, 0.4574, 0.0000, 0.0000], [0.5362, 0.4638, 0.0000, 0.0000], [1.0000, 0.0000, 0.0000, 0.0000], [0.0000, 1.0000, 0.0000, 0.0000], [0.0000, 0.4448, 0.0000, 0.5552], [0.0000, 1.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.4886, 0.5114], [0.4899, 0.0000, 0.0000, 0.5101], [0.0000, 0.0000, 0.5296, 0.4704], [0.0000, 0.0000, 0.4469, 0.5531], [0.0000, 0.4053, 0.5947, 0.0000], [0.0000, 0.0000, 0.4460, 0.5540], [0.4997, 0.0000, 0.5003, 0.0000], [0.0000, 0.0000, 0.5851, 0.4149], [1.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.5010, 0.4990, 0.0000], [1.0000, 0.0000, 0.0000, 0.0000]])

        EXPECTED_TOP_1_ALL = torch.LongTensor([[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0]])
        EXPECTED_TOP_1_SP = torch.LongTensor([[0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 0]])
        # `sampling` and `random` do not affect the mask of the top_1 router
        # fmt: on

        torch.testing.assert_allclose(router_probs_all, EXPECTED_ROUTER_ALL, 1e-4, 1e-4)
        torch.testing.assert_allclose(router_probs_sp, EXPECTED_ROUTER_SP, 1e-4, 1e-4)
        torch.testing.assert_allclose(router_probs, EXPECTED_ROUTER, 1e-4, 1e-4)

        torch.testing.assert_allclose(top_1_mask_all, EXPECTED_TOP_1_ALL, 1e-4, 1e-4)
        torch.testing.assert_allclose(top_1_mask_sp, EXPECTED_TOP_1_SP, 1e-4, 1e-4)
        torch.testing.assert_allclose(top_1_mask, EXPECTED_TOP_1_SP, 1e-4, 1e-4)