# coding=utf-8
2
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
""" Testing suite for the PyTorch NLLB-MoE model. """
16

17

18
import copy
19
import tempfile
20
import unittest
21

22
from transformers import NllbMoeConfig, is_torch_available, set_seed
23
from transformers.testing_utils import (
24
    require_sentencepiece,
25
    require_tokenizers,
26
    require_torch,
27
    require_torch_fp16,
28
    slow,
29
    torch_device,
30
)
31
from transformers.utils import cached_property
32

33
from ...generation.test_utils import GenerationTesterMixin
34
from ...test_configuration_common import ConfigTester
35
from ...test_modeling_common import ModelTesterMixin, ids_tensor
36
from ...test_pipeline_mixin import PipelineTesterMixin
37

38

39
if is_torch_available():
40
    import torch
41

42
    from transformers import NllbMoeForConditionalGeneration, NllbMoeModel, NllbTokenizer
43
    from transformers.models.nllb_moe.modeling_nllb_moe import NllbMoeDecoder, NllbMoeEncoder, NllbMoeTop2Router
44

45

46
class NllbMoeModelTester:
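    """Builds a tiny `NllbMoeConfig` and random model inputs shared by the tests below."""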
47
    def __init__(
48
        self,
49
        parent,
50
        batch_size=13,
51
        seq_length=7,
52
        is_training=True,
53
        use_labels=False,
54
        vocab_size=99,
55
        hidden_size=16,
56
        num_hidden_layers=2,
57
        num_attention_heads=4,
58
        intermediate_size=4,
59
        hidden_act="relu",
60
        hidden_dropout_prob=0.1,
61
        attention_probs_dropout_prob=0.1,
62
        encoder_layerdrop=0.0,
63
        decoder_layerdrop=0.0,
64
        max_position_embeddings=20,
65
        eos_token_id=2,
66
        pad_token_id=1,
67
        bos_token_id=0,
68
        num_experts=4,
69
        encoder_sparse_step=2,
70
        decoder_sparse_step=1,
71
        expert_capacity=100,
72
        router_jitter_noise=0.0,
73
    ):
74
        self.parent = parent
75
        self.batch_size = batch_size
76
        self.seq_length = seq_length
77
        self.is_training = is_training
78
        self.use_labels = use_labels
79
        self.vocab_size = vocab_size
80
        self.hidden_size = hidden_size
81
        self.num_hidden_layers = num_hidden_layers
82
        self.num_attention_heads = num_attention_heads
83
        self.intermediate_size = intermediate_size
84
        self.hidden_act = hidden_act
85
        self.hidden_dropout_prob = hidden_dropout_prob
86
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
87
        self.encoder_layerdrop = encoder_layerdrop
88
        self.decoder_layerdrop = decoder_layerdrop
89
        self.max_position_embeddings = max_position_embeddings
90
        self.eos_token_id = eos_token_id
91
        self.pad_token_id = pad_token_id
92
        self.bos_token_id = bos_token_id
93
        self.encoder_sparse_step = encoder_sparse_step
94
        self.decoder_sparse_step = decoder_sparse_step
95
        self.expert_capacity = expert_capacity
96
        self.router_jitter_noise = router_jitter_noise
97
        self.num_experts = num_experts
98

99
    def prepare_nllb_moe_inputs_dict(
100
        self,
101
        config,
102
        input_ids,
103
        decoder_input_ids,
104
        attention_mask=None,
105
        decoder_attention_mask=None,
106
        head_mask=None,
107
        decoder_head_mask=None,
108
        cross_attn_head_mask=None,
109
    ):
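        # fill in default attention masks and (cross-)attention head masks when they are not provided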
110
        if attention_mask is None:
111
            attention_mask = input_ids.ne(config.pad_token_id)
112
        if decoder_attention_mask is None:
113
            decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
114
        if head_mask is None:
115
            head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
116
        if decoder_head_mask is None:
117
            decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
118
        if cross_attn_head_mask is None:
119
            cross_attn_head_mask = torch.ones(
120
                config.decoder_layers, config.decoder_attention_heads, device=torch_device
121
            )
122
        return {
123
            "input_ids": input_ids,
124
            "decoder_input_ids": decoder_input_ids,
125
            "attention_mask": attention_mask,
126
            "decoder_attention_mask": attention_mask,
127
            "head_mask": head_mask,
128
            "decoder_head_mask": decoder_head_mask,
129
            "cross_attn_head_mask": cross_attn_head_mask,
130
        }
131

132
    def prepare_config_and_inputs(self):
133
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
134
        input_ids[:, -1] = self.eos_token_id  # Eos Token
135
        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
136

137
        # we need to clamp the input ids here to avoid having pad tokens in between:
        # for NllbMoe the position_ids are prepared such that all pad tokens get
        # position id 2 and the rest are between 2 and seq_length, where seq_length
        # here is seq_length - num_pad_tokens. When using past, there is no way of
        # knowing whether the past input ids contained pad tokens, which results in
        # an incorrect seq_length and, in turn, position_ids that are off by
        # num_pad_tokens for the past input
144
        input_ids = input_ids.clamp(self.pad_token_id + 1)
145
        decoder_input_ids = decoder_input_ids.clamp(self.pad_token_id + 1)
146

147
        config = self.get_config()
148
        inputs_dict = self.prepare_nllb_moe_inputs_dict(config, input_ids, decoder_input_ids)
149
        return config, inputs_dict
150

151
    def get_config(self):
152
        return NllbMoeConfig(
153
            vocab_size=self.vocab_size,
154
            d_model=self.hidden_size,
155
            encoder_layers=self.num_hidden_layers,
156
            decoder_layers=self.num_hidden_layers,
157
            encoder_attention_heads=self.num_attention_heads,
158
            decoder_attention_heads=self.num_attention_heads,
159
            encoder_ffn_dim=self.intermediate_size,
160
            decoder_ffn_dim=self.intermediate_size,
161
            dropout=self.hidden_dropout_prob,
162
            attention_dropout=self.attention_probs_dropout_prob,
163
            encoder_layerdrop=self.encoder_layerdrop,
164
            decoder_layerdrop=self.decoder_layerdrop,
165
            max_position_embeddings=self.max_position_embeddings,
166
            eos_token_id=self.eos_token_id,
167
            bos_token_id=self.bos_token_id,
168
            pad_token_id=self.pad_token_id,
169
            expert_capacity=self.expert_capacity,
170
            router_jitter_noise=self.router_jitter_noise,
171
            decoder_sparse_step=self.decoder_sparse_step,
172
            encoder_sparse_step=self.encoder_sparse_step,
173
            num_experts=self.num_experts,
174
        )
175

176
    def prepare_config_and_inputs_for_common(self):
177
        config, inputs_dict = self.prepare_config_and_inputs()
178
        return config, inputs_dict
179

180
    @require_torch
181
    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
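        # decoding step-by-step with cached past_key_values must match re-running the full sequence without cache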
182
        model = NllbMoeModel(config=config).get_decoder().to(torch_device).eval()
183
        input_ids = inputs_dict["input_ids"]
184
        attention_mask = inputs_dict["attention_mask"]
185
        head_mask = inputs_dict["head_mask"]
186

187
        # first forward pass
188
        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
189

190
        output, past_key_values = outputs.to_tuple()
191

192
        # create hypothetical multiple next tokens and extend next_input_ids accordingly
193
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
194
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)
195

196
        # append the new tokens to input_ids and attention_mask
197
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
198
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)
199

200
        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
201
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
202
            "last_hidden_state"
203
        ]
204

205
        # select random slice
206
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
207
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
208
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
209

210
        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
211

212
        # test that outputs are equal for slice
213
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
214

215
    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
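        # the encoder and decoder, saved and reloaded standalone, should reproduce the full model's hidden states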
216
        model = NllbMoeModel(config=config).to(torch_device).eval()
217
        outputs = model(**inputs_dict)
218

219
        encoder_last_hidden_state = outputs.encoder_last_hidden_state
220
        last_hidden_state = outputs.last_hidden_state
221

222
        with tempfile.TemporaryDirectory() as tmpdirname:
223
            encoder = model.get_encoder()
224
            encoder.save_pretrained(tmpdirname)
225
            encoder = NllbMoeEncoder.from_pretrained(tmpdirname).to(torch_device)
226

227
        encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
228
            0
229
        ]
230

231
        self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
232

233
        with tempfile.TemporaryDirectory() as tmpdirname:
234
            decoder = model.get_decoder()
235
            decoder.save_pretrained(tmpdirname)
236
            decoder = NllbMoeDecoder.from_pretrained(tmpdirname).to(torch_device)
237

238
        last_hidden_state_2 = decoder(
239
            input_ids=inputs_dict["decoder_input_ids"],
240
            attention_mask=inputs_dict["decoder_attention_mask"],
241
            encoder_hidden_states=encoder_last_hidden_state,
242
            encoder_attention_mask=inputs_dict["attention_mask"],
243
        )[0]
244

245
        self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
246

247

248
@require_torch
249
class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
250
    all_model_classes = (NllbMoeModel, NllbMoeForConditionalGeneration) if is_torch_available() else ()
251
    all_generative_model_classes = (NllbMoeForConditionalGeneration,) if is_torch_available() else ()
252
    pipeline_model_mapping = (
253
        {
254
            "conversational": NllbMoeForConditionalGeneration,
255
            "feature-extraction": NllbMoeModel,
256
            "summarization": NllbMoeForConditionalGeneration,
257
            "text2text-generation": NllbMoeForConditionalGeneration,
258
            "translation": NllbMoeForConditionalGeneration,
259
        }
260
        if is_torch_available()
261
        else {}
262
    )
263
    is_encoder_decoder = True
264
    fx_compatible = False
265
    test_pruning = False
266
    test_missing_keys = True
267
    test_torchscript = False
268

269
    # TODO: Fix the failed tests when this model gets more usage
270
    def is_pipeline_test_to_skip(
271
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
272
    ):
273
        # Saving the slow tokenizer after saving the fast tokenizer causes the loading of the latter to hang forever.
274
        return True
275

276
    def setUp(self):
277
        self.model_tester = NllbMoeModelTester(self)
278
        self.config_tester = ConfigTester(self, config_class=NllbMoeConfig)
279

280
    def test_config(self):
281
        self.config_tester.run_common_tests()
282

283
    def test_save_load_strict(self):
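        # a save/load round-trip must not report any missing keys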
284
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
285
        for model_class in self.all_model_classes:
286
            model = model_class(config)
287

288
            with tempfile.TemporaryDirectory() as tmpdirname:
289
                model.save_pretrained(tmpdirname)
290
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
291
            self.assertEqual(info["missing_keys"], [])
292

293
    def test_decoder_model_past_with_large_inputs(self):
294
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
295
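        # decoder_sparse_step = 0 should keep every decoder layer dense (no MoE), so the cache is tested on the
        # dense path only (assumption based on how the sparse_step config flags are used in the modeling code)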
        config.decoder_sparse_step = 0
296
        self.model_tester.create_and_check_decoder_model_past_large_inputs(config, inputs_dict)
297

298
    def test_encoder_decoder_model_standalone(self):
299
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
300
        self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
301

302
    def test_inputs_embeds(self):
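        # the model should accept `inputs_embeds` / `decoder_inputs_embeds` in place of `input_ids`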
303
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
304

305
        for model_class in (NllbMoeModel, NllbMoeForConditionalGeneration):
306
            model = model_class(config)
307
            model.to(torch_device)
308
            model.eval()
309

310
            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
311

312
            if not self.is_encoder_decoder:
313
                input_ids = inputs["input_ids"]
314
                del inputs["input_ids"]
315
            else:
316
                encoder_input_ids = inputs["input_ids"]
317
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
318
                del inputs["input_ids"]
319
                inputs.pop("decoder_input_ids", None)
320

321
            wte = model.get_input_embeddings()
322
            if not self.is_encoder_decoder:
323
                inputs["inputs_embeds"] = wte(input_ids)
324
            else:
325
                inputs["inputs_embeds"] = wte(encoder_input_ids)
326
                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
327

328
            with torch.no_grad():
329
                model(**inputs)[0]
330

331
    @require_torch_fp16
332
    def test_generate_fp16(self):
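        # greedy and beam-sample generation should run in float16 without crashing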
333
        config, input_dict = self.model_tester.prepare_config_and_inputs()
334
        input_ids = input_dict["input_ids"]
335
        attention_mask = input_ids.ne(1).to(torch_device)
336
        model = NllbMoeForConditionalGeneration(config).eval().to(torch_device)
337
        model.half()
338
        model.generate(input_ids, attention_mask=attention_mask)
339
        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
340

341
    def test_get_loss(self):
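        # with labels and output_router_logits=True the model should return a loss as well as router logits
        # for both the encoder and the decoder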
342
        config, input_dict = self.model_tester.prepare_config_and_inputs()
343
        input_dict["output_router_logits"] = True
344
        input_dict["labels"] = input_dict["input_ids"]
345
        model = NllbMoeForConditionalGeneration(config).eval().to(torch_device)
346
        out = model(**input_dict)
347
        self.assertIsNotNone(out.loss)
348
        self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
349
        self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])
350

351

352
@require_torch
353
@require_sentencepiece
354
@require_tokenizers
355
@slow
356
class NllbMoeModelIntegrationTests(unittest.TestCase):
357
    @require_torch
358
    @cached_property
359
    def model_inputs(self):
360
        return {
361
            "input_ids": torch.LongTensor(
362
                [
363
                    [28768, 248, 6399, 9, 65972, 452, 1925, 629, 123543, 248075, 2, 256047],
364
                    [117, 7027, 7195, 202, 44778, 248075, 2, 256047, 1, 1, 1, 1],
365
                ]
366
            ),
367
            "attention_mask": torch.Tensor(
368
                [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]
369
            ),
370
            "decoder_input_ids": torch.LongTensor([[2, 256057], [2, 256057]]),
371
        }
372

373
    @cached_property
374
    def tokenizer(self):
375
        return NllbTokenizer.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts")
376

377
    @cached_property
378
    def big_model(self):
379
        return NllbMoeForConditionalGeneration.from_pretrained("facebook/nllb-moe-54b")
380

381
    def test_inference_no_head(self):
382
        model = NllbMoeModel.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts").eval()
383
        with torch.no_grad():
384
            output = model(**self.model_inputs)
385
        # fmt: off
386
        EXPECTED_ENCODER_STATE = torch.Tensor([ 0.3920, -0.1974, -0.0279,  0.3463, -0.8306, -1.0629, -0.4643,  2.0563, 1.1123,  0.3566, -0.9291, -0.3840, -0.2527, -0.9858,  1.5185, -1.1346, 0.0323, -0.9103, -0.3647, -0.4462, -0.9720, -0.3541,  0.1777, -0.4647, 1.6970, -0.9062,  0.2727, -1.0737,  0.8785,  0.4324])
387
        EXPECTED_DECODER_STATE = torch.Tensor([-6.0425e-02, -2.0015e-01,  6.0575e-02, -8.6366e-01, -1.1310e+00, 6.8369e-01,  7.5615e-01,  7.3555e-01,  2.3071e-01,  1.5954e+00, -7.0728e-01, -2.2647e-01, -1.3292e+00,  4.8246e-01, -6.9153e-01, -1.8199e-02, -7.3664e-01,  1.5902e-03,  1.0760e-01,  1.0298e-01, -9.3933e-01, -4.6567e-01,  8.0417e-01,  1.5243e+00,  5.5844e-01, -9.9239e-02,  1.4885e+00,  7.1527e-02, -5.2612e-01,  9.4435e-02])
388
        # fmt: on
389

390
        torch.testing.assert_allclose(
391
            output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3
392
        )
393
        torch.testing.assert_allclose(
394
            output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3
395
        )
396

397
    def test_inference_logits(self):
398
        r"""
399
        Logits testing to check implementation consistency between `fairseq` implementation
400
        and `transformers` implementation of NLLB-MoE transformers. We only check the logits
401
        of the second sample of the batch, as it is padded.
402
        """
403
        model = NllbMoeForConditionalGeneration.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts").eval()
404
        with torch.no_grad():
405
            output = model(**self.model_inputs)
406

407
        EXPECTED_LOGITS = torch.Tensor([-0.3059, 0.0000, 9.3029, 0.6456, -0.9148, 1.7836, 0.6478, 0.9438, -0.5272, -0.6617, -1.2717, 0.4564, 0.1345, -0.2301, -1.0140, 1.1427, -1.5535, 0.1337, 0.2082, -0.8112, -0.3842, -0.3377, 0.1256, 0.6450, -0.0452, 0.0219, 1.4274, -0.4991, -0.2063, -0.4409,])  # fmt: skip
408
        torch.testing.assert_allclose(output.logits[1, 0, :30], EXPECTED_LOGITS, rtol=6e-3, atol=9e-3)
409

410
    @unittest.skip("This requires 300GB of RAM")
411
    def test_large_logits(self):
412
        model = self.big_model
413
        with torch.no_grad():
414
            output = model(**self.model_inputs)
415

416
        # fmt: off
417
        EXPECTED_ENCODER_STATE = torch.Tensor([ 0.1696, -0.0059,  0.0489,  0.0479, -0.4222, -0.2178, -0.1372, -0.0860, -0.4249, -0.0081, -0.1186,  0.6678,  0.0160,  0.4140,  0.1799,  0.0672, -0.4941,  0.0173, -0.0740,  0.0845, -0.2197,  0.4465,  0.2268, -0.1752, -0.0562,  0.1033, -0.0869, -0.5490,  0.0582,  0.2165])
418
        EXPECTED_DECODER_STATE = torch.Tensor([ 0.0374, -0.1055, -0.1060, -0.1711, -0.0540, -0.1183, -0.0779,  0.0610, -0.0279, -0.0848,  0.0222,  0.0372, -0.0298, -0.0861, -0.0354, -0.0103,  0.0538, -0.0148, -0.0105,  0.0224,  0.0629, -0.0291, -0.0671,  0.0173, -0.0066, -0.0245, -0.0499,  0.0760, -0.0067,  0.0086])
419
        EXPECTED_LOGITS = torch.Tensor([ 0.3834,  0.2057,  4.5399,  0.8301,  0.4810,  0.9325,  0.9928,  0.9574,  0.5517,  0.9156,  0.2698,  0.6728,  0.7121,  0.3080,  0.4693,  0.5756,  1.0407,  0.2219,  0.3714,  0.5699,  0.5547,  0.8472,  0.3178,  0.1286,  0.1791,  0.9391,  0.5153, -0.2146,  0.1689,  0.6816])
420
        # fmt: on
421

422
        torch.testing.assert_allclose(
423
            output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3
424
        )
425
        torch.testing.assert_allclose(
426
            output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3
427
        )
428
        torch.testing.assert_allclose(output.logits[1, 0, :30], EXPECTED_LOGITS, rtol=6e-3, atol=9e-3)
429

430
    @unittest.skip("This requires 300GB of RAM")
431
    def test_seq_to_seq_generation(self):
432
        model = self.big_model
433
        tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-moe-54b")
434

435
        # first 6 samples of load_dataset("facebook/flores", "eng_Latn-fra_Latn"), devtest. Truth are very similar to the fairseq translation files
436
        FIRST_6_FLORES_200 = [
437
            'We now have 4-month-old mice that are non-diabetic that used to be diabetic," he added.',
438
            "Dr. Ehud Ur, professor of medicine at Dalhousie University in Halifax, Nova Scotia and chair of the clinical and scientific division of the Canadian Diabetes Association cautioned that the research is still in its early days.",
439
            "Like some other experts, he is skeptical about whether diabetes can be cured, noting that these findings have no relevance to people who already have Type 1 diabetes.",
440
            "On Monday, Sara Danius, permanent secretary of the Nobel Committee for Literature at the Swedish Academy, publicly announced during a radio program on Sveriges Radio in Sweden the committee, unable to reach Bob Dylan directly about winning the 2016 Nobel Prize in Literature, had abandoned its efforts to reach him.",
441
            'Danius said, "Right now we are doing nothing. I have called and sent emails to his closest collaborator and received very friendly replies. For now, that is certainly enough."',
442
            "Previously, Ring's CEO, Jamie Siminoff, remarked the company started when his doorbell wasn't audible from his shop in his garage.",
443
        ]
444
        inputs = tokenizer(FIRST_6_FLORES_200, padding=True, return_tensors="pt").to(torch_device)
445
        batch_translation = model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id["fra_Latn"])
446

447
        EXPECTED_FAIRSEQ_TRANSLATION = [
448
            '"Nous avons maintenant des souris de 4 mois non diabétiques qui étaient diabétiques", a-t-il ajouté.',
449
            "Le docteur Ehud Ur, professeur de médecine à l'université Dalhousie, à Halifax, en Nouvelle-Écosse, et président de la division clinique et scientifique de l'Association canadienne du diabète, prévient que la recherche n'en est qu'à ses débuts.",
450
            "Comme d'autres spécialistes, il est sceptique quant à la guérison du diabète.",
451
            "Lundi, Sara Danius, secrétaire permanente du Comité Nobel de littérature à l'Académie suédoise, a annoncé publiquement lors d'une émission de radio sur Sveriges Radio en Suède que le comité, incapable de joindre Bob Dylan directement pour lui annoncer le prix Nobel de littérature 2016, avait abandonné ses efforts pour le joindre.",
452
            "Danius a déclaré: \"Pour l'instant, nous ne faisons rien. J'ai appelé et envoyé des courriels à son plus proche collaborateur et j'ai reçu des réponses très amicales. Pour l'instant, c'est certainement suffisant\".",
453
            "Auparavant, le PDG de Ring, Jamie Siminoff, a fait remarquer que la société avait commencé lorsque sa sonnette n'était pas audible depuis son magasin dans son garage.",
454
        ]
455

456
        translation = tokenizer.batch_decode(
457
            batch_translation.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
458
        )
459
        assert translation == EXPECTED_FAIRSEQ_TRANSLATION
460

461

462
@require_torch
463
class NllbMoeRouterTest(unittest.TestCase):
464
    r"""
465
    Switch Transformers has different blocks from classic transformer based models.
466
    The Swift MLP contains a Router class, that has to be tested to check if it is correctly implemented
467

468
    Original implementation of the routers here:
469

470
    """
471

472
    config = NllbMoeConfig(
473
        num_experts=4,
474
        hidden_size=32,
475
        d_ff=16,
476
        expert_capacity=4,
477
    )
478
    batch_size = 2
479
    sequence_length = 20
480

481
    def test_top_2_routing(self):
482
        # test routing with minimal reproduction
483
        mask = torch.ones((self.batch_size, self.sequence_length), dtype=torch.bool)
484
        mask[0][0] = False
485
        mask[1][0] = False
486
        mask = mask.reshape(-1)
487
        set_seed(0)
488
        hidden_states = torch.rand((self.batch_size, self.sequence_length, self.config.hidden_size))
489
        classifier = torch.nn.Linear(self.config.hidden_size, self.config.num_experts)
        hf_router = NllbMoeTop2Router(self.config)

        _, _, hidden_dim = hidden_states.shape
        logits = classifier(hidden_states.reshape((self.batch_size * self.sequence_length), hidden_dim))
494
        top_1_mask, router_probs = hf_router.route_tokens(logits, padding_mask=mask)
495
        torch.argmax(top_1_mask, dim=-1)
496
        router_mask = router_probs.bool()
497
        set_seed(0)
498
        experts = [
499
            torch.nn.Linear(hidden_dim, hidden_dim),
500
            torch.nn.Linear(hidden_dim, hidden_dim),
501
            torch.nn.Linear(hidden_dim, hidden_dim),
502
            torch.nn.Linear(hidden_dim, hidden_dim),
503
        ]
504
        hidden_states = hidden_states.reshape((self.batch_size * self.sequence_length), hidden_dim)
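        # manual dispatch / combine, mimicking what the sparse MLP is expected to do: replicate each token per
        # expert, zero out copies that were not routed to that expert, run every expert on its own tokens, scale
        # by the routing probabilities (and by 1 - moe_token_dropout, assumed to match the eval-time scaling in
        # the modeling code) and sum the expert outputs back together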
505
        masked_hidden_states = torch.einsum("bm,be->ebm", hidden_states, router_mask)
506
        for idx, expert in enumerate(experts):
507
            token_indices = router_mask[:, idx]
508
            combining_weights = router_probs[token_indices, idx]
509
            expert_output = expert(masked_hidden_states[idx, token_indices])
510
            expert_output *= 1 - self.config.moe_token_dropout
511
            masked_hidden_states[idx, token_indices] = torch.einsum("b,be->be", combining_weights, expert_output)
512
        hidden_states = masked_hidden_states.sum(dim=0).reshape(self.batch_size, self.sequence_length, hidden_dim)
513

514
        EXPECTED_MEAN_FAIRSEQ_HIDDEN_STATES = torch.Tensor([[ 7.0340e-04,  2.7997e-03, -1.3351e-02, -7.6705e-03, -3.5089e-03,3.9773e-03,  7.4593e-03,  1.2566e-02,  3.5860e-03, -2.7448e-02,-1.3731e-02, -1.0534e-02, -1.3606e-02, -1.5048e-02, -2.8914e-03,-5.0371e-03, -1.3963e-03,  6.0076e-03, -1.1380e-02, -1.4620e-02, 5.2401e-03,  8.4660e-04, -1.5319e-03, -1.6735e-02,  1.1302e-02, 3.6119e-03,  4.6084e-03, -1.3458e-02,  7.7792e-05,  1.4312e-02, 4.9107e-03, -5.0936e-03], [-4.4538e-03,  3.1026e-03,  1.4121e-04, -4.8121e-03, -5.6279e-03, 7.2493e-03,  3.9769e-03,  1.1114e-02, -1.5666e-03, -2.3477e-02, 8.7268e-03,  1.3446e-02, -2.8845e-05, -1.7287e-02,  8.7619e-03, -4.5316e-03, -1.2164e-02,  5.7461e-03, -4.5861e-03, -9.3907e-03, 2.9808e-02,  8.9206e-04, -7.6232e-04, -1.4173e-02,  3.0208e-03, 1.5310e-02,  9.7717e-03,  3.1014e-03,  7.8042e-03,  8.0197e-03, 3.4784e-03, -7.1728e-03]])  # fmt: skip
515
        self.assertTrue(torch.allclose(hidden_states.mean(1), EXPECTED_MEAN_FAIRSEQ_HIDDEN_STATES, 1e-4))
516

517
    def test_batch_prioritized_routing(self):
518
        set_seed(0)
519
        config = NllbMoeConfig(
520
            num_experts=4, hidden_size=32, d_ff=16, expert_capacity=4, second_expert_policy="random"
521
        )
522
        mask = torch.zeros((self.batch_size * self.sequence_length), dtype=torch.bool)
523
        logits = torch.rand((self.batch_size * self.sequence_length, 4))
524
        config.batch_prioritized_routing = True
525
        router = NllbMoeTop2Router(config)
526
        top_1_mask, _ = router.route_tokens(logits, padding_mask=mask)
527
        # check that the routing is batch-first: one of the last tokens is routed even though expert capacity is
        # very small, which means it had a greater probability of being routed
529
        assert top_1_mask[-1, 0] == 1
530

531
    def test_second_expert_policy(self):
532
        config = NllbMoeConfig(
533
            num_experts=4,
534
            hidden_size=32,
535
            d_ff=16,
536
            expert_capacity=40,
537
        )
538
        set_seed(0)
539
        mask = torch.zeros((self.batch_size * self.sequence_length), dtype=torch.bool)
540
        logits = torch.rand((self.batch_size * self.sequence_length, 4))
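        # route the same logits under the three `second_expert_policy` values ("random", "sampling", "all")
        # and compare each result against the recorded reference routing probabilities and top-1 masks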
541

542
        set_seed(0)
543
        config.second_expert_policy = "random"
544
        router = NllbMoeTop2Router(config)
545
        top_1_mask, router_probs = router.route_tokens(logits, padding_mask=mask)
546

547
        set_seed(0)
548
        config.second_expert_policy = "sampling"
549
        router = NllbMoeTop2Router(config)
550
        top_1_mask_sp, router_probs_sp = router.route_tokens(logits, padding_mask=mask)
551

552
        set_seed(0)
553
        config.second_expert_policy = "all"
554
        router = NllbMoeTop2Router(config)
555
        top_1_mask_all, router_probs_all = router.route_tokens(logits, padding_mask=mask)
556

557
        # fmt: off
558
        EXPECTED_ROUTER_ALL = torch.tensor([[0.3902, 0.0000, 0.0000, 0.6098], [0.0000, 0.0000, 0.7770, 0.2230], [0.0000, 0.0000, 0.2726, 0.7274], [0.4221, 0.0000, 0.5779, 0.0000], [0.0000, 0.0000, 0.7810, 0.2190], [0.5518, 0.4482, 0.0000, 0.0000], [0.0000, 0.4060, 0.5940, 0.0000], [0.7340, 0.0000, 0.0000, 0.2660], [0.4778, 0.5222, 0.0000, 0.0000], [0.0000, 0.3984, 0.0000, 0.6016], [0.0000, 0.0548, 0.9452, 0.0000], [0.6796, 0.0000, 0.0000, 0.3204], [0.0700, 0.0000, 0.9300, 0.0000], [0.1854, 0.0000, 0.8146, 0.0000], [0.6775, 0.3225, 0.0000, 0.0000], [0.0000, 0.0000, 0.5027, 0.4973], [0.0000, 0.6577, 0.0000, 0.3423], [0.0000, 0.7767, 0.0000, 0.2233], [0.1944, 0.8056, 0.0000, 0.0000], [0.0000, 0.3073, 0.0000, 0.6927], [0.0000, 0.5655, 0.4345, 0.0000], [0.5791, 0.0000, 0.0000, 0.4209], [0.0440, 0.0000, 0.9560, 0.0000], [0.0083, 0.9917, 0.0000, 0.0000], [0.0000, 0.8395, 0.0000, 0.1605], [0.0000, 0.1458, 0.0000, 0.8542], [0.0000, 0.8534, 0.1466, 0.0000], [0.4938, 0.0000, 0.0000, 0.5062], [0.1329, 0.8671, 0.0000, 0.0000], [0.3058, 0.0000, 0.6942, 0.0000], [0.4458, 0.0000, 0.0000, 0.5542], [0.9053, 0.0947, 0.0000, 0.0000], [0.0000, 0.7563, 0.2437, 0.0000], [0.0000, 0.0000, 0.4096, 0.5904], [0.4551, 0.0000, 0.0000, 0.5449], [0.8502, 0.1498, 0.0000, 0.0000], [0.0000, 0.6312, 0.3688, 0.0000], [0.8920, 0.0000, 0.0000, 0.1080], [0.1913, 0.0000, 0.0000, 0.8087], [0.2491, 0.7509, 0.0000, 0.0000]])
559
        EXPECTED_ROUTER_SP = torch.tensor([[0.0000, 0.6539, 0.0000, 0.3461], [0.0000, 0.0000, 0.3998, 0.6002], [0.0000, 0.5574, 0.0000, 0.4426], [0.0000, 0.0000, 0.4441, 0.5559], [0.0000, 0.6545, 0.3455, 0.0000], [0.4419, 0.5581, 0.0000, 0.0000], [0.0000, 0.4014, 0.5986, 0.0000], [0.3215, 0.0000, 0.0000, 0.6785], [0.4765, 0.5235, 0.0000, 0.0000], [0.0000, 0.5467, 0.0000, 0.4533], [0.0000, 0.4156, 0.5844, 0.0000], [0.3370, 0.0000, 0.6630, 0.0000], [0.0000, 0.0000, 0.4558, 0.5442], [0.4659, 0.0000, 0.5341, 0.0000], [0.6179, 0.3821, 0.0000, 0.0000], [0.6277, 0.0000, 0.3723, 0.0000], [0.5836, 0.4164, 0.0000, 0.0000], [0.0000, 0.6600, 0.0000, 0.3400], [0.0000, 0.4933, 0.0000, 0.5067], [0.6016, 0.0000, 0.0000, 0.3984], [0.0000, 0.5160, 0.4840, 0.0000], [0.5799, 0.0000, 0.0000, 0.4201], [0.0000, 0.0000, 0.4826, 0.5174], [0.5426, 0.4574, 0.0000, 0.0000], [0.5362, 0.4638, 0.0000, 0.0000], [0.6448, 0.0000, 0.0000, 0.3552], [0.0000, 0.5909, 0.4091, 0.0000], [0.4196, 0.0000, 0.0000, 0.5804], [0.3191, 0.6809, 0.0000, 0.0000], [0.0000, 0.0000, 0.4886, 0.5114], [0.4899, 0.0000, 0.0000, 0.5101], [0.4123, 0.0000, 0.5877, 0.0000], [0.0000, 0.3736, 0.0000, 0.6264], [0.0000, 0.0000, 0.6009, 0.3991], [0.4246, 0.0000, 0.0000, 0.5754], [0.4997, 0.0000, 0.5003, 0.0000], [0.0000, 0.3595, 0.6405, 0.0000], [0.5433, 0.0000, 0.0000, 0.4567], [0.0000, 0.6806, 0.0000, 0.3194], [0.6689, 0.3311, 0.0000, 0.0000]])
560
        EXPECTED_ROUTER = torch.tensor([[0.4324, 0.5676, 0.0000, 0.0000], [0.0000, 0.4348, 0.0000, 0.5652], [0.4559, 0.5441, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.4744, 0.5256, 0.0000, 0.0000], [0.0000, 0.5103, 0.0000, 0.4897], [0.0000, 0.0000, 1.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.0000, 1.0000, 0.0000, 0.0000], [0.0000, 0.5467, 0.0000, 0.4533], [0.0000, 0.0000, 1.0000, 0.0000], [0.0000, 0.0000, 1.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.0000, 0.0000, 1.0000, 0.0000], [1.0000, 0.0000, 0.0000, 0.0000], [0.5063, 0.4937, 0.0000, 0.0000], [0.5396, 0.0000, 0.0000, 0.4604], [0.4576, 0.5424, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.5134, 0.0000, 0.4866, 0.0000], [0.0000, 0.5160, 0.4840, 0.0000], [0.5439, 0.0000, 0.4561, 0.0000], [0.4849, 0.0000, 0.0000, 0.5151], [0.5426, 0.4574, 0.0000, 0.0000], [0.5362, 0.4638, 0.0000, 0.0000], [1.0000, 0.0000, 0.0000, 0.0000], [0.0000, 1.0000, 0.0000, 0.0000], [0.0000, 0.4448, 0.0000, 0.5552], [0.0000, 1.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.4886, 0.5114], [0.4899, 0.0000, 0.0000, 0.5101], [0.0000, 0.0000, 0.5296, 0.4704], [0.0000, 0.0000, 0.4469, 0.5531], [0.0000, 0.4053, 0.5947, 0.0000], [0.0000, 0.0000, 0.4460, 0.5540], [0.4997, 0.0000, 0.5003, 0.0000], [0.0000, 0.0000, 0.5851, 0.4149], [1.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.5010, 0.4990, 0.0000], [1.0000, 0.0000, 0.0000, 0.0000]])
561

562
        EXPECTED_TOP_1_ALL = torch.LongTensor([[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0]])
563
        EXPECTED_TOP_1_SP = torch.LongTensor([[0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 0]])
564
        # `sampling` and `random` do not affect the mask of the top_1 router
565
        # fmt: on
566

567
        torch.testing.assert_allclose(router_probs_all, EXPECTED_ROUTER_ALL, 1e-4, 1e-4)
568
        torch.testing.assert_allclose(router_probs_sp, EXPECTED_ROUTER_SP, 1e-4, 1e-4)
569
        torch.testing.assert_allclose(router_probs, EXPECTED_ROUTER, 1e-4, 1e-4)
570

571
        torch.testing.assert_allclose(top_1_mask_all, EXPECTED_TOP_1_ALL, 1e-4, 1e-4)
572
        torch.testing.assert_allclose(top_1_mask_sp, EXPECTED_TOP_1_SP, 1e-4, 1e-4)
573
        torch.testing.assert_allclose(top_1_mask, EXPECTED_TOP_1_SP, 1e-4, 1e-4)
574
