# coding=utf-8
# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import re
import tempfile
import unittest

from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, T5Tokenizer, T5TokenizerFast
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_seqio, require_tokenizers, slow
from transformers.utils import cached_property, is_tf_available, is_torch_available

from ...test_tokenization_common import TokenizerTesterMixin


SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")

if is_torch_available():
    FRAMEWORK = "pt"
elif is_tf_available():
    FRAMEWORK = "tf"
else:
    FRAMEWORK = "jax"


@require_sentencepiece
@require_tokenizers
class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = T5Tokenizer
    rust_tokenizer_class = T5TokenizerFast
    test_rust_tokenizer = True
    test_sentencepiece = True

    def setUp(self):
        super().setUp()

        # We have a SentencePiece fixture for testing
        tokenizer = T5Tokenizer(SAMPLE_VOCAB)
        tokenizer.save_pretrained(self.tmpdirname)

    def test_convert_token_and_id(self):
        """Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
        token = "<s>"
        token_id = 1

        self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
        self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)

    def test_get_vocab(self):
        vocab_keys = list(self.get_tokenizer().get_vocab().keys())

        self.assertEqual(vocab_keys[0], "<unk>")
        self.assertEqual(vocab_keys[1], "<s>")
        self.assertEqual(vocab_keys[1100], "<pad>")
        self.assertEqual(len(vocab_keys), 1_101)

    def test_vocab_size(self):
        self.assertEqual(self.get_tokenizer().vocab_size, 1000)
        self.assertEqual(len(self.get_tokenizer()), 1101)
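        # `vocab_size` reports only the underlying 1000-piece SentencePiece model, while `len()`
        # also counts the added tokens: the 100 default `<extra_id_*>` sentinels plus the `<pad>`
        # token checked in `test_get_vocab` above, i.e. 1000 + 100 + 1 = 1101.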

    def test_full_tokenizer(self):
        tokenizer = T5Tokenizer(SAMPLE_VOCAB)

        tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "é",
                ".",
            ],
        )
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4])

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(
            back_tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "<unk>",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "<unk>",
                ".",
            ],
        )

    @cached_property
    def t5_base_tokenizer(self):
        return T5Tokenizer.from_pretrained("google-t5/t5-base")

    @cached_property
    def t5_base_tokenizer_fast(self):
        return T5TokenizerFast.from_pretrained("google-t5/t5-base")

    def get_tokenizer(self, **kwargs) -> T5Tokenizer:
        return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)

    def get_rust_tokenizer(self, **kwargs) -> T5TokenizerFast:
        return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)

    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer:
            return

        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer()

        sequence = "I was born in 92000, and this is falsé."

        tokens = tokenizer.tokenize(sequence)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        ids = tokenizer.encode(sequence, add_special_tokens=False)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        rust_tokenizer = self.get_rust_tokenizer()
        ids = tokenizer.encode(sequence)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

    def test_eos_treatment(self):
        tokenizer = self.t5_base_tokenizer
        batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"])
        batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
        self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])
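        # T5Tokenizer always appends the EOS token itself, so spelling "</s>" out in the raw text
        # must not produce a second EOS. A quick way to see the automatic appending (illustrative,
        # not asserted here): tokenizer("hi")["input_ids"][-1] == tokenizer.eos_token_id.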

    def test_prepare_batch(self):
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
        expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
        self.assertIsInstance(batch, BatchEncoding)

        if FRAMEWORK != "jax":
            result = list(batch.input_ids.numpy()[0])
        else:
            result = list(batch.input_ids.tolist()[0])

        self.assertListEqual(expected_src_tokens, result)

        self.assertEqual((2, 9), batch.input_ids.shape)
        self.assertEqual((2, 9), batch.attention_mask.shape)

    def test_empty_target_text(self):
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
        # check if input_ids are returned and no decoder_input_ids
        self.assertIn("input_ids", batch)
        self.assertIn("attention_mask", batch)
        self.assertNotIn("decoder_input_ids", batch)
        self.assertNotIn("decoder_attention_mask", batch)

    def test_max_length(self):
        tokenizer = self.t5_base_tokenizer
        tgt_text = [
            "Summary of the text.",
            "Another summary.",
        ]
        targets = tokenizer(
            text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
        )
        self.assertEqual(32, targets["input_ids"].shape[1])

    def test_outputs_not_longer_than_maxlen(self):
        tokenizer = self.t5_base_tokenizer

        batch = tokenizer(
            ["I am a small frog" * 1000, "I am a small frog"], padding=True, truncation=True, return_tensors=FRAMEWORK
        )
        self.assertIsInstance(batch, BatchEncoding)
        # Since T5 does NOT have a max input length,
        # this test should be changed to the following in Transformers v5:
        # self.assertEqual(batch.input_ids.shape, (2, 8001))
        self.assertEqual(batch.input_ids.shape, (2, 512))

    def test_eos_in_input(self):
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summarization. </s>"]
        tgt_text = ["Summary of the text. </s>"]
        expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
        expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1]

        batch = tokenizer(src_text, text_target=tgt_text)

        self.assertEqual(expected_src_tokens, batch["input_ids"][0])
        self.assertEqual(expected_tgt_tokens, batch["labels"][0])

    def test_token_type_ids(self):
        src_text_1 = ["A first paragraph for summarization."]
        src_text_2 = ["A second paragraph for summarization."]

        fast_token_type_ids = self.t5_base_tokenizer_fast(
            src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
        ).token_type_ids
        slow_token_type_ids = self.t5_base_tokenizer(
            src_text_1, src_text_2, add_special_tokens=True, return_token_type_ids=True
        ).token_type_ids

        self.assertEqual(slow_token_type_ids, fast_token_type_ids)
        self.assertEqual(len(slow_token_type_ids[0]), 18)

    def test_fast_and_slow_same_result(self):
        src_text = "<pad> Today is <unk> nice day </s>"
        tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]
        tgt_text = "<pad> Today is<unk> nice day</s>"

        fast_ids = self.t5_base_tokenizer_fast(src_text, add_special_tokens=False).input_ids
        slow_ids = self.t5_base_tokenizer(src_text, add_special_tokens=False).input_ids
        self.assertEqual(tgt_ids, fast_ids)
        self.assertEqual(tgt_ids, slow_ids)

        fast_text = self.t5_base_tokenizer_fast.decode(fast_ids)
        slow_text = self.t5_base_tokenizer.decode(fast_ids)
        self.assertEqual(tgt_text, fast_text)
        self.assertEqual(tgt_text, slow_text)
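        # Note that `tgt_text` is not identical to `src_text`: the encode/decode round trip loses
        # the space that preceded `<unk>` and `</s>`, and the shared expectation above pins down
        # that the slow and fast tokenizers agree on this behaviour.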

    def test_special_tokens_initialization(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                added_tokens = [f"<extra_id_{i}>" for i in range(100)] + [AddedToken("<special>", lstrip=True)]

                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
                )
                tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
                )
                tokenizer_p = self.tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
                )

                p_output = tokenizer_p.encode("Hey this is a <special> token")
                r_output = tokenizer_r.encode("Hey this is a <special> token")
                cr_output = tokenizer_cr.encode("Hey this is a <special> token")

                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]

                self.assertEqual(p_output, r_output)
                self.assertEqual(cr_output, r_output)
                self.assertTrue(special_token_id in p_output)
                self.assertTrue(special_token_id in r_output)
                self.assertTrue(special_token_id in cr_output)

    def test_special_tokens_initialization_with_non_empty_additional_special_tokens(self):
        tokenizer_list = []
        if self.test_slow_tokenizer:
            tokenizer_list.append((self.tokenizer_class, self.get_tokenizer()))

        if self.test_rust_tokenizer:
            tokenizer_list.append((self.rust_tokenizer_class, self.get_rust_tokenizer()))

        for tokenizer_class, tokenizer_utils in tokenizer_list:
            with tempfile.TemporaryDirectory() as tmp_dir:
                tokenizer_utils.save_pretrained(tmp_dir)

                with open(os.path.join(tmp_dir, "special_tokens_map.json"), encoding="utf-8") as json_file:
                    special_tokens_map = json.load(json_file)

                with open(os.path.join(tmp_dir, "tokenizer_config.json"), encoding="utf-8") as json_file:
                    tokenizer_config = json.load(json_file)

                added_tokens_extra_ids = [f"<extra_id_{i}>" for i in range(100)]

                special_tokens_map["additional_special_tokens"] = added_tokens_extra_ids + [
                    "an_additional_special_token"
                ]
                tokenizer_config["additional_special_tokens"] = added_tokens_extra_ids + [
                    "an_additional_special_token"
                ]

                with open(os.path.join(tmp_dir, "special_tokens_map.json"), "w", encoding="utf-8") as outfile:
                    json.dump(special_tokens_map, outfile)
                with open(os.path.join(tmp_dir, "tokenizer_config.json"), "w", encoding="utf-8") as outfile:
                    json.dump(tokenizer_config, outfile)

                # The following checks verify that the test works as expected, i.e. that the tokenizer takes
                # into account the new value of additional_special_tokens given in the "tokenizer_config.json"
                # and "special_tokens_map.json" files
                tokenizer_without_change_in_init = tokenizer_class.from_pretrained(
                    tmp_dir,
                )
                self.assertIn(
                    "an_additional_special_token", tokenizer_without_change_in_init.additional_special_tokens
                )
                # self.assertIn("an_additional_special_token",tokenizer_without_change_in_init.get_vocab()) # ByT5Tokenization no vocab
                self.assertEqual(
                    ["an_additional_special_token"],
                    tokenizer_without_change_in_init.convert_ids_to_tokens(
                        tokenizer_without_change_in_init.convert_tokens_to_ids(["an_additional_special_token"])
                    ),
                )

                # Now we test that we can change the value of additional_special_tokens in the from_pretrained call
                new_added_tokens = added_tokens_extra_ids + [AddedToken("a_new_additional_special_token", lstrip=True)]
                tokenizer = tokenizer_class.from_pretrained(
                    tmp_dir,
                    additional_special_tokens=new_added_tokens,
                )

                self.assertIn("a_new_additional_special_token", tokenizer.additional_special_tokens)
                self.assertEqual(
                    ["a_new_additional_special_token"],
                    tokenizer.convert_ids_to_tokens(
                        tokenizer.convert_tokens_to_ids(["a_new_additional_special_token"])
                    ),
                )

    # overwritten from `test_tokenization_common` since T5 has no max length
    def test_pretrained_model_lists(self):
        # We should have at least one default checkpoint for each tokenizer
        # We should specify the max input length as well (used in some part to list the pretrained checkpoints)
        self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1)
        self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1)

    @slow
    def test_tokenizer_integration(self):
        expected_encoding = {'input_ids': [[31220, 7, 41, 14034, 801, 38, 3, 102, 63, 17, 127, 524, 18, 7031, 2032, 277, 11, 3, 102, 63, 17, 127, 524, 18, 2026, 17, 10761, 18, 7041, 61, 795, 879, 18, 19681, 4648, 7, 41, 12920, 382, 6, 350, 6383, 4949, 6, 2158, 12920, 382, 9, 6, 3, 4, 11160, 6, 2043, 17153, 279, 49, 17, 6, 3, 4, 434, 9688, 11439, 21, 6869, 10509, 17725, 41, 567, 9138, 61, 11, 6869, 10509, 11946, 41, 18207, 517, 61, 28, 147, 3538, 1220, 7140, 10761, 2250, 16, 910, 1220, 8024, 11, 1659, 1413, 32, 883, 2020, 344, 2215, 226, 6, 12901, 382, 127, 524, 11, 4738, 7, 127, 15390, 5, 1], [272, 24203, 19, 876, 12, 554, 18, 9719, 1659, 2647, 26352, 6497, 7, 45, 73, 9339, 400, 26, 1499, 57, 22801, 10760, 30, 321, 646, 11, 269, 2625, 16, 66, 7500, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [37, 1704, 4216, 3, 20400, 4418, 7, 147, 8, 19743, 1782, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}  # fmt: skip

        self.tokenizer_integration_test_util(
            expected_encoding=expected_encoding,
            model_name="google-t5/t5-base",
            revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b",
        )

    def test_get_sentinel_tokens(self):
        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
        sentinel_tokens = tokenizer.get_sentinel_tokens()
        self.assertEqual(len(sentinel_tokens), 10)
        self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{str(i)}>" for i in range(0, 10)]))
        self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))

    def test_get_sentinel_token_ids(self):
        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
        self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted(range(1000, 1010)))

    def test_get_sentinel_tokens_for_fasttokenizer(self):
        tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
        sentinel_tokens = tokenizer.get_sentinel_tokens()
        self.assertEqual(len(sentinel_tokens), 10)
        self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{str(i)}>" for i in range(0, 10)]))
        self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))

    def test_get_sentinel_token_ids_for_fasttokenizer(self):
        tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
        self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted(range(1000, 1010)))
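
    # The `<extra_id_*>` sentinels checked above are the placeholders T5's span-corruption
    # pretraining objective uses to mark masked spans; with the 1000-piece test vocab and
    # `extra_ids=10` they sit at ids 1000-1009, right after the SentencePiece pieces, for both
    # the slow and the fast tokenizer.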

    def test_some_edge_cases(self):
        tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)

        sp_tokens = tokenizer.sp_model.encode("</s>>", out_type=str)
        self.assertEqual(sp_tokens, ["<", "/", "s", ">", ">"])
        tokens = tokenizer.tokenize("</s>>")
        self.assertNotEqual(sp_tokens, tokens)
        self.assertEqual(tokens, ["</s>", ">"])

        tokens = tokenizer.tokenize("")
        self.assertEqual(tokens, [])
        self.assertEqual(tokens, tokenizer.sp_model.encode("", out_type=str))

        tokens = tokenizer.tokenize(" ")
        self.assertEqual(tokens, [])
        self.assertEqual(tokens, tokenizer.sp_model.encode(" ", out_type=str))

        tokens = tokenizer.tokenize("▁")
        self.assertEqual(tokens, [])
        self.assertEqual(tokens, tokenizer.sp_model.encode("▁", out_type=str))

        tokens = tokenizer.tokenize(" ▁")
        self.assertEqual(tokens, [])
        self.assertEqual(tokens, tokenizer.sp_model.encode("▁", out_type=str))
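        # Unlike the raw sp_model, `tokenize` recognizes "</s>" as a single special token instead
        # of splitting it character by character, while whitespace-only inputs ("", " ", "▁")
        # collapse to an empty token list exactly as they do in SentencePiece.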

    def test_fast_slow_edge_cases(self):
        # We are testing spaces before and after special tokens + space transformations
        slow_tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
        fast_tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-base", legacy=False, from_slow=True)
        slow_tokenizer.add_tokens(AddedToken("<new_token_test_>", rstrip=False, lstrip=False, normalized=False))
        fast_tokenizer.add_tokens(AddedToken("<new_token_test_>", rstrip=False, lstrip=False, normalized=False))

        edge_case = "Hey!<new_token_test_>. How</s>Hey <new_token_test_>!"
        EXPECTED_SLOW = ["▁Hey", "!", "<new_token_test_>", ".", "▁How", "</s>", "He", "y", "<new_token_test_>", "!"]  # fmt: skip
        with self.subTest(f"slow {edge_case} normalized = False"):
            self.assertEqual(slow_tokenizer.tokenize(edge_case), EXPECTED_SLOW)
        with self.subTest(f"fast {edge_case} normalized = False"):
            self.assertEqual(fast_tokenizer.tokenize(edge_case), EXPECTED_SLOW)

        hard_case = "Hey! <new_token_test_>. How</s>   Hey   <new_token_test_>  !     .     "
        EXPECTED_SLOW = ["▁Hey", "!", "<new_token_test_>", ".", "▁How", "</s>", "▁Hey", "<new_token_test_>", "▁", "!", "▁", "."]  # fmt: skip
        with self.subTest(f"slow {hard_case} normalized = False"):
            self.assertEqual(slow_tokenizer.tokenize(hard_case), EXPECTED_SLOW)
        with self.subTest(f"fast {hard_case} normalized = False"):
            self.assertEqual(fast_tokenizer.tokenize(hard_case), EXPECTED_SLOW)

        fast_tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-base", legacy=False, from_slow=True)
        fast_tokenizer.add_tokens(AddedToken("<new_token_test_>", rstrip=False, lstrip=False, normalized=True))

        # `normalized=True` is the default normalization scheme when adding a token. Normalize -> don't strip the space.
        # The issue now is that our slow tokenizer should NOT strip the space if we want to simulate sentencepiece token addition.

        EXPECTED_FAST = ["▁Hey", "!", "<new_token_test_>", ".", "▁How", "</s>", "He", "y", "▁", "<new_token_test_>", "!"]  # fmt: skip
        with self.subTest(f"fast {edge_case} normalized = True"):
            self.assertEqual(fast_tokenizer.tokenize(edge_case), EXPECTED_FAST)

        EXPECTED_FAST = ["▁Hey", "!", "▁", "<new_token_test_>", ".", "▁How", "</s>", "▁Hey", "▁", "<new_token_test_>", "▁", "!", "▁", "."]  # fmt: skip
        with self.subTest(f"fast {hard_case} normalized = True"):
            self.assertEqual(fast_tokenizer.tokenize(hard_case), EXPECTED_FAST)

    def test_add_prefix_space(self):
        pretrained_name = "google-t5/t5-base"
        inputs = "Hey how are you doing"
        EXPECTED_WITH_SPACE = [9459, 149, 33, 25, 692, 1]
        EXPECTED_WO_SPACE = [3845, 63, 149, 33, 25, 692, 1]

        slow_ = self.tokenizer_class.from_pretrained(pretrained_name, add_prefix_space=False, legacy=False)
        fast_ = self.rust_tokenizer_class.from_pretrained(
            pretrained_name, add_prefix_space=False, legacy=False, from_slow=True
        )
        self.assertEqual(slow_.encode(inputs), EXPECTED_WO_SPACE)
        self.assertEqual(slow_.encode(inputs), fast_.encode(inputs))
        self.assertEqual(slow_.tokenize(inputs), ["He", "y", "▁how", "▁are", "▁you", "▁doing"])
        self.assertEqual(slow_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True), inputs)
        self.assertEqual(
            slow_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True),
            fast_.decode(EXPECTED_WO_SPACE, skip_special_tokens=True),
        )

        slow_ = self.tokenizer_class.from_pretrained(pretrained_name, add_prefix_space=True, legacy=False)
        fast_ = self.rust_tokenizer_class.from_pretrained(pretrained_name, add_prefix_space=True, legacy=False)
        self.assertEqual(slow_.encode(inputs), EXPECTED_WITH_SPACE)
        self.assertEqual(slow_.encode(inputs), fast_.encode(inputs))
        self.assertEqual(slow_.tokenize(inputs), ["▁Hey", "▁how", "▁are", "▁you", "▁doing"])
        self.assertEqual(slow_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True), inputs)
        self.assertEqual(
            slow_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True),
            fast_.decode(EXPECTED_WITH_SPACE, skip_special_tokens=True),
        )
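        # With `add_prefix_space=True` the leading word gets the usual SentencePiece word marker
        # ("▁Hey"), whereas with `add_prefix_space=False` it is split as a word continuation
        # ("He", "y"); both settings still round-trip back to the original string on decode.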


@require_sentencepiece
@require_tokenizers
class CommonSpmIntegrationTests(unittest.TestCase):
    """
    A class that groups important tests to make sure that we properly handle the special tokens.
    """

    @classmethod
    def setUpClass(cls):
        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=0, legacy=False)
        tokenizer.add_special_tokens(
            {"additional_special_tokens": [AddedToken("<extra_id_0>", rstrip=False, lstrip=False)]}
        )
        # TODO ArthurZ: the above is necessary because AddedToken initialization is broken and the Trie is
        # not correctly created, so the extra ids get split otherwise.
        cls.tokenizer = tokenizer

    def test_add_dummy_prefix(self):
        # make sure `'▁'` is prepended, and outputs match sp_model's
        # `sentencepiece.NormalizerSpec.add_dummy_prefix` attribute
        input_ids = self.tokenizer.encode(". Hello", add_special_tokens=False)
        self.assertEqual(input_ids, [7, 4, 156, 86, 20])
        sp_encode = self.tokenizer.sp_model.encode(". Hello")
        self.assertEqual(input_ids, [7] + sp_encode)
        tokens = self.tokenizer.tokenize(". Hello")
        self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"])

        tokens = self.tokenizer.tokenize("")
        self.assertEqual(tokens, [])
        self.assertEqual(tokens, self.tokenizer.sp_model.encode("", out_type=str))

        tokens = self.tokenizer.tokenize(" ")
        self.assertEqual(tokens, [])
        self.assertEqual(tokens, self.tokenizer.sp_model.encode(" ", out_type=str))

        tokens = self.tokenizer.tokenize("▁")
        self.assertEqual(tokens, [])
        self.assertEqual(tokens, self.tokenizer.sp_model.encode("▁", out_type=str))
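        # This mirrors SentencePiece's `add_dummy_prefix` normalizer option: a "▁" is prepended to
        # non-empty input before segmentation (hence the leading "▁" piece for ". Hello"), while
        # pure-whitespace inputs normalize away entirely.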

    def test_remove_extra_whitespaces(self):
        # make sure the extra spaces are eaten
        # sentencepiece.NormalizerSpec.remove_extra_whitespaces attribute
        input_ids = self.tokenizer.encode("       . Hello", add_special_tokens=False)
        self.assertEqual(input_ids, [7, 4, 156, 86, 20])
        sp_encode = self.tokenizer.sp_model.encode("       . Hello")
        self.assertEqual(input_ids, [7] + sp_encode)
        tokens = self.tokenizer.tokenize(" . Hello")
        self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"])

        # `'▁'` is also a whitespace
        input_ids = self.tokenizer.encode("▁He is not")
        self.assertEqual(input_ids, [156, 46, 44, 2])
        tokens = self.tokenizer.tokenize("▁He is not")
        self.assertEqual(tokens, ["▁He", "▁is", "▁not"])  # no extra space added

        input_ids = self.tokenizer.encode("▁He is not<extra_id_0>             ▁He")
        # here t5x does not eat the spaces with lstrip, so there is an extra ▁He in the original one
        self.assertEqual(input_ids, [156, 46, 44, 1001, 156, 2])
        tokens = self.tokenizer.tokenize("▁He is not<extra_id_0>              ▁He")
        self.assertEqual(tokens, ["▁He", "▁is", "▁not", "<extra_id_0>", "▁He"])  # spaces are eaten by spm
        # make sure that the output after the extra id is the same as if
        # the extra_id was not there
        input_ids = self.tokenizer.encode("▁He is not             ▁He")
        self.assertEqual(input_ids, [156, 46, 44, 156, 2])
        tokens = self.tokenizer.tokenize("▁He is not              ▁He")
        self.assertEqual(tokens, ["▁He", "▁is", "▁not", "▁He"])  # spaces are eaten by spm even if not at the start

    def test_character_after_special_token(self):
        # Make sure that `tokenizer.tokenize` is similar to
        # adding the equivalent special token to the vocab
        input_ids = self.tokenizer.encode("Hey <extra_id_0>I")
        self.assertEqual(input_ids, [156, 30, 1001, 100, 2])
        tokens = self.tokenizer.tokenize("Hey <extra_id_0>I")
        self.assertEqual(tokens, ["▁He", "y", "<extra_id_0>", "I"])

        input_ids = self.tokenizer.encode("Hello, <extra_id_0>,")
        self.assertEqual(input_ids, [156, 86, 20, 3, 1001, 3, 2])
        tokens = self.tokenizer.tokenize("Hello, <extra_id_0>,")
        self.assertEqual(tokens, ["▁He", "ll", "o", ",", "<extra_id_0>", ","])

    def test_special_tokens_strip(self):
        input_ids = self.tokenizer.encode(" <extra_id_0> ,")
        self.assertEqual(input_ids, [1001, 7, 3, 2])
        tokens = self.tokenizer.tokenize(" <extra_id_0> ,")
        # spaces are no longer eaten by rstrip and lstrip
        self.assertEqual(tokens, ["<extra_id_0>", "▁", ","])

        # test with a begin-of-word piece like `▁He`
        input_ids = self.tokenizer.encode("No <extra_id_0> He")
        self.assertEqual(input_ids, [284, 1001, 156, 2])
        # spaces would be eaten by rstrip / lstrip, so this is expected. Don't strip, otherwise you break the `▁He` that follows.
        tokens = self.tokenizer.tokenize("No <extra_id_0> He")
        self.assertEqual(tokens, ["▁No", "<extra_id_0>", "▁He"])

        # Make sure this does not happen if we don't strip
        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=0)
        tokenizer.add_special_tokens({"bos_token": AddedToken("<bos>")})
        input_ids = tokenizer.encode("No <bos> He")
        self.assertEqual(input_ids, [284, 1001, 156, 2])
        tokens = tokenizer.tokenize("No <bos> He")
        # the first `' '` after `'No'` is eaten by spm:
        self.assertEqual(tokenizer.sp_model.encode("No         ", out_type=str), ["▁No"])
        self.assertEqual(tokens, ["▁No", "<bos>", "▁He"])

    @require_seqio
    @unittest.skipIf(
        os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
        "RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
    )
    def test_integration_seqio(self):
        from datasets import load_dataset
        from seqio import SentencePieceVocabulary

        ds = load_dataset("xnli", "all_languages", split="train+test+validation")

        # TODO @ArthurZucker fix the 3 commented tests with #23909
        input_texts = [
            "Bonjour <extra_id_0>.",
            # "Bonjour<extra_id_0>.",  # this will fail: in T5 the special token has to be at the end,
            # because in T5 they add `_<extra_id_0>` to the vocab, not `<extra_id_0>`.
            "                   Hey <extra_id_0>I love you",
            # "Hey <extra_id_0> I love you",  # this will fail: we strip left, so we get `_I` vs `I`
            # "Hey <extra_id_0>▁He",  # this will fail for the same reason: we replace `_` then strip
        ]

        import tqdm

        # Test with umt5
        vocab_path = "gs://t5-data/vocabs/umt5.256000/sentencepiece.model"
        t5x_tokenizer = SentencePieceVocabulary(vocab_path, extra_ids=300)
        hf_tokenizer = T5Tokenizer.from_pretrained("google/umt5-small", legacy=False)
        for text in input_texts:
            self.assertEqual(
                hf_tokenizer.encode(text, add_special_tokens=False), t5x_tokenizer.tokenizer.tokenize(text), f"{text}"
            )
        for texts in tqdm.tqdm(ds["premise"]):
            for text in texts:
                self.assertEqual(
                    hf_tokenizer.encode(text, add_special_tokens=False),
                    t5x_tokenizer.tokenizer.tokenize(text),
                    f"{text}",
                )

        # Test with T5
        hf_tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
        vocab_path = "gs://t5-data/vocabs/cc_all.32000/sentencepiece.model"
        t5x_tokenizer = SentencePieceVocabulary(vocab_path, extra_ids=300)
        for text in input_texts:
            self.assertEqual(
                hf_tokenizer.encode(text, add_special_tokens=False), t5x_tokenizer.tokenizer.tokenize(text), f"{text}"
            )
        for texts in tqdm.tqdm(ds["premise"]):
            for text in texts:
                self.assertEqual(
                    hf_tokenizer.encode(text, add_special_tokens=False),
                    t5x_tokenizer.tokenizer.tokenize(text),
                    f"{text}",
                )