transformers

Форк
0
/
test_tokenization_marian.py 
155 строк · 8.6 Кб
# coding=utf-8
# Copyright 2020 Huggingface
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest
from pathlib import Path
from shutil import copyfile

from transformers import BatchEncoding, MarianTokenizer
from transformers.testing_utils import get_tests_dir, require_sentencepiece, slow
from transformers.utils import is_sentencepiece_available, is_tf_available, is_torch_available


if is_sentencepiece_available():
    # Marian tokenizer helpers are only imported when sentencepiece is installed
    # (the guard suggests the module cannot be imported without it).
    from transformers.models.marian.tokenization_marian import VOCAB_FILES_NAMES, save_json

from ...test_tokenization_common import TokenizerTesterMixin


# Small SentencePiece model from the test fixtures; the tests copy it in as both source and target spm.
SAMPLE_SP = get_tests_dir("fixtures/test_sentencepiece.model")

# Minimal tokenizer config written next to the toy vocabulary in the tests' setUp.
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
zh_code = ">>zh<<"
# Hub organisation prefix for the pretrained Marian checkpoints used by the tests.
ORG_NAME = "Helsinki-NLP/"

# Tensor type passed as ``return_tensors`` in the shape-checking tests:
# PyTorch if installed, otherwise TensorFlow, otherwise JAX.
# NOTE(review): the final fallback assumes jax is installed when neither torch nor TF is — confirm.
if is_torch_available():
    FRAMEWORK = "pt"
elif is_tf_available():
    FRAMEWORK = "tf"
else:
    FRAMEWORK = "jax"


@require_sentencepiece
class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    """Tests for the SentencePiece-based ``MarianTokenizer``.

    The local tests run against a 9-token toy vocabulary written to ``self.tmpdirname``
    in ``setUp``. The remaining tests load pretrained checkpoints by Hub id
    (``Helsinki-NLP/...`` and ``hf-internal-testing/...``).
    """

    tokenizer_class = MarianTokenizer
    test_rust_tokenizer = False
    test_sentencepiece = True

    def setUp(self):
        super().setUp()
        # Toy vocabulary: token -> index in this list.
        vocab = ["</s>", "<unk>", "▁This", "▁is", "▁a", "▁t", "est", "\u0120", "<pad>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        save_dir = Path(self.tmpdirname)
        save_json(vocab_tokens, save_dir / VOCAB_FILES_NAMES["vocab"])
        save_json(mock_tokenizer_config, save_dir / VOCAB_FILES_NAMES["tokenizer_config_file"])
        # The same sample SentencePiece model is used for both the source and the target side.
        if not (save_dir / VOCAB_FILES_NAMES["source_spm"]).exists():
            copyfile(SAMPLE_SP, save_dir / VOCAB_FILES_NAMES["source_spm"])
            copyfile(SAMPLE_SP, save_dir / VOCAB_FILES_NAMES["target_spm"])

        # Load and re-save so the directory also contains whatever save_pretrained writes.
        tokenizer = MarianTokenizer.from_pretrained(self.tmpdirname)
        tokenizer.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs) -> MarianTokenizer:
        return MarianTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self, tokenizer):
        return (
            "This is a test",
            "This is a test",
        )

    def test_convert_token_and_id(self):
        """Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
        token = "</s>"
        token_id = 0
        tokenizer = self.get_tokenizer()

        self.assertEqual(tokenizer._convert_token_to_id(token), token_id)
        self.assertEqual(tokenizer._convert_id_to_token(token_id), token)

    def test_get_vocab(self):
        # get_vocab() is expected to preserve the insertion order of the toy vocabulary.
        vocab_keys = list(self.get_tokenizer().get_vocab().keys())

        self.assertEqual(vocab_keys[0], "</s>")
        self.assertEqual(vocab_keys[1], "<unk>")
        self.assertEqual(vocab_keys[-1], "<pad>")
        self.assertEqual(len(vocab_keys), 9)

    def test_vocab_size(self):
        self.assertEqual(self.get_tokenizer().vocab_size, 9)

    def test_tokenizer_equivalence_en_de(self):
        en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
        batch = en_de_tokenizer(["I am a small frog"], return_tensors=None)
        self.assertIsInstance(batch, BatchEncoding)
        expected = [38, 121, 14, 697, 38848, 0]
        self.assertListEqual(expected, batch.input_ids[0])

        # Save/reload round-trip in a self-cleaning directory (mkdtemp() leaked the directory).
        with tempfile.TemporaryDirectory() as save_dir:
            en_de_tokenizer.save_pretrained(save_dir)
            contents = [x.name for x in Path(save_dir).glob("*")]
            self.assertIn("source.spm", contents)
            MarianTokenizer.from_pretrained(save_dir)

    def test_outputs_not_longer_than_maxlen(self):
        tok = self.get_tokenizer()

        batch = tok(
            ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
        )
        self.assertIsInstance(batch, BatchEncoding)
        # 512 is the truncation length this test expects.
        # NOTE(review): presumably the tokenizer's default model_max_length — confirm.
        self.assertEqual(batch.input_ids.shape, (2, 512))

    def test_outputs_can_be_shorter(self):
        tok = self.get_tokenizer()
        # Without truncation, padding=True pads only up to the longest sequence in the batch.
        batch_smaller = tok(["I am a tiny frog", "I am a small frog"], padding=True, return_tensors=FRAMEWORK)
        self.assertIsInstance(batch_smaller, BatchEncoding)
        self.assertEqual(batch_smaller.input_ids.shape, (2, 10))

    @slow
    def test_tokenizer_integration(self):
        # Expected token ids before padding, one row per input sequence.
        # fmt: off
        unpadded_ids = [
            [
                43495, 462, 20, 42164, 1369, 52, 464, 132, 1703, 492,
                13, 7491, 38999, 6, 8, 464, 132, 1703, 492, 13,
                4669, 37867, 13, 7525, 27, 1593, 988, 13, 33972, 7029,
                6, 20, 8251, 383, 2, 270, 5866, 3788, 2, 2353,
                8251, 12338, 2, 13958, 387, 2, 3629, 6953, 188, 2900,
                2, 13958, 8011, 11501, 23, 8460, 4073, 34009, 20, 435,
                11439, 27, 8, 8460, 4073, 6004, 20, 9988, 375, 27,
                33, 266, 1945, 1076, 1350, 37867, 3288, 5, 577, 1076,
                4374, 8, 5082, 5, 26453, 257, 556, 403, 2, 242,
                132, 383, 316, 492, 8, 10767, 6, 316, 304, 4239,
                3, 0,
            ],
            [
                148, 15722, 19, 1839, 12, 1350, 13, 22327, 5082, 5418,
                47567, 35938, 59, 318, 19552, 108, 2183, 54, 14976, 4835,
                32, 547, 1114, 8, 315, 2417, 5, 92, 19088, 3,
                0,
            ],
            [36, 6395, 12570, 39147, 11597, 6, 266, 4, 45405, 7296, 3, 0],
        ]
        # fmt: on
        # Every row is right-padded to the longest row (102 ids) with id 58100.
        # The attention mask is 1 on real tokens and 0 on padded positions.
        pad_token_id = 58100
        max_len = max(len(row) for row in unpadded_ids)
        expected_encoding = {
            "input_ids": [row + [pad_token_id] * (max_len - len(row)) for row in unpadded_ids],
            "attention_mask": [[1] * len(row) + [0] * (max_len - len(row)) for row in unpadded_ids],
        }

        self.tokenizer_integration_test_util(
            expected_encoding=expected_encoding,
            model_name=f"{ORG_NAME}opus-mt-en-de",
            revision="1a8c2263da11e68e50938f97e10cd57820bd504c",
            decode_kwargs={"use_source_tokenizer": True},
        )

    def test_tokenizer_integration_seperate_vocabs(self):
        # Checkpoint with distinct source and target vocabularies: plain calls encode with the
        # source side, ``text_target=`` encodes with the target side.
        tokenizer = MarianTokenizer.from_pretrained("hf-internal-testing/test-marian-two-vocabs")

        source_text = "Tämä on testi"
        target_text = "This is a test"

        expected_src_ids = [76, 7, 2047, 2]
        expected_target_ids = [69, 12, 11, 940, 2]

        src_ids = tokenizer(source_text).input_ids
        self.assertListEqual(src_ids, expected_src_ids)

        target_ids = tokenizer(text_target=target_text).input_ids
        self.assertListEqual(target_ids, expected_target_ids)

        decoded = tokenizer.decode(target_ids, skip_special_tokens=True)
        self.assertEqual(decoded, target_text)

    def test_tokenizer_decode(self):
        # Encoding followed by decoding without special tokens should round-trip the input text.
        tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-es")
        source_text = "Hello World"
        ids = tokenizer(source_text)["input_ids"]
        output_text = tokenizer.decode(ids, skip_special_tokens=True)
        self.assertEqual(source_text, output_text)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.