# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile
import unittest
import unittest.mock as mock
from pathlib import Path

from huggingface_hub import HfFolder, delete_repo
from huggingface_hub.file_download import http_get
from requests.exceptions import HTTPError

from transformers import (
    AlbertTokenizer,
    AutoTokenizer,
    BertTokenizer,
    BertTokenizerFast,
    GPT2TokenizerFast,
    is_tokenizers_available,
)
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
from transformers.tokenization_utils import Trie


sys.path.append(str(Path(__file__).parent.parent / "utils"))

from test_module.custom_tokenization import CustomTokenizer  # noqa: E402


if is_tokenizers_available():
    from test_module.custom_tokenization_fast import CustomTokenizerFast


class TokenizerUtilTester(unittest.TestCase):
    def test_cached_files_are_used_when_internet_is_down(self):
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Under the mock environment we get a 500 error when trying to reach the tokenizer.
        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
            _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
            # Check that we did call the fake head request
            mock_head.assert_called()
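            # No exception was raised above, which means the cached files were used as a fallback.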

    @require_tokenizers
    def test_cached_files_are_used_when_internet_is_down_missing_files(self):
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        _ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")

        # Under the mock environment we get a 500 error when trying to reach the tokenizer.
        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
            _ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
            # Check that we did call the fake head request
            mock_head.assert_called()

    def test_legacy_load_from_one_file(self):
        # This test is for deprecated behavior and can be removed in v5
        try:
            tmp_file = tempfile.mktemp()
            with open(tmp_file, "wb") as f:
                http_get("https://huggingface.co/albert/albert-base-v1/resolve/main/spiece.model", f)

            _ = AlbertTokenizer.from_pretrained(tmp_file)
        finally:
            os.remove(tmp_file)

        # Supporting this legacy load introduced a weird bug where the tokenizer would load local files if they were in
        # the current folder and had the right name.
        if os.path.isfile("tokenizer.json"):
            # We skip the test if the user has a `tokenizer.json` in this folder to avoid deleting it.
            return
        try:
            with open("tokenizer.json", "wb") as f:
                http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/blob/main/tokenizer.json", f)
            tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
            # The tiny random BERT has a vocab size of 1024, the tiny openai-community/gpt2 has a vocab size of 1000
            self.assertEqual(tokenizer.vocab_size, 1000)
            # Tokenizer should depend on the remote checkpoint, not the local tokenizer.json file.

        finally:
            os.remove("tokenizer.json")

    def test_legacy_load_from_url(self):
        # This test is for deprecated behavior and can be removed in v5
        _ = AlbertTokenizer.from_pretrained("https://huggingface.co/albert/albert-base-v1/resolve/main/spiece.model")


@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]

    @classmethod
    def setUpClass(cls):
        cls._token = TOKEN
        HfFolder.save_token(TOKEN)

    @classmethod
    def tearDownClass(cls):
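        # Clean up the repos the tests below create; ignore errors in case a test failed before creating its repo.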
        try:
            delete_repo(token=cls._token, repo_id="test-tokenizer")
        except HTTPError:
            pass

        try:
            delete_repo(token=cls._token, repo_id="valid_org/test-tokenizer-org")
        except HTTPError:
            pass

        try:
            delete_repo(token=cls._token, repo_id="test-dynamic-tokenizer")
        except HTTPError:
            pass

    def test_push_to_hub(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            vocab_file = os.path.join(tmp_dir, "vocab.txt")
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
            tokenizer = BertTokenizer(vocab_file)
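            # The temporary directory (and the vocab file) is deleted once this block exits; the
            # tokenizer keeps its vocabulary in memory, so it can still be pushed afterwards.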

        tokenizer.push_to_hub("test-tokenizer", token=self._token)
        new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

        # Reset repo
        delete_repo(token=self._token, repo_id="test-tokenizer")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer.save_pretrained(tmp_dir, repo_id="test-tokenizer", push_to_hub=True, token=self._token)

        new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

    def test_push_to_hub_in_organization(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            vocab_file = os.path.join(tmp_dir, "vocab.txt")
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
            tokenizer = BertTokenizer(vocab_file)

        tokenizer.push_to_hub("valid_org/test-tokenizer-org", token=self._token)
        new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

        # Reset repo
        delete_repo(token=self._token, repo_id="valid_org/test-tokenizer-org")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer.save_pretrained(
                tmp_dir, repo_id="valid_org/test-tokenizer-org", push_to_hub=True, token=self._token
            )

        new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

    @require_tokenizers
    def test_push_to_hub_dynamic_tokenizer(self):
        CustomTokenizer.register_for_auto_class()
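        # Registering the class above means push_to_hub also uploads the custom tokenizer code and
        # references it in the tokenizer config, so AutoTokenizer can reload it with trust_remote_code=True.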
        with tempfile.TemporaryDirectory() as tmp_dir:
            vocab_file = os.path.join(tmp_dir, "vocab.txt")
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
            tokenizer = CustomTokenizer(vocab_file)

        # No fast custom tokenizer
        tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token)

        tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
        # Can't make an isinstance check because the loaded tokenizer is built from the CustomTokenizer class of a dynamic module
        self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")

        # Fast and slow custom tokenizer
        CustomTokenizerFast.register_for_auto_class()
        with tempfile.TemporaryDirectory() as tmp_dir:
            vocab_file = os.path.join(tmp_dir, "vocab.txt")
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))

            bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
            bert_tokenizer.save_pretrained(tmp_dir)
            tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
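            # Only a plain vocab.txt was written by hand; the BertTokenizerFast round-trip above
            # generates the tokenizer.json that the fast custom tokenizer is loaded from.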

        tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token)

        tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
        # Can't make an isinstance check because the loaded tokenizer is built from the CustomTokenizerFast class of a dynamic module
        self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
        tokenizer = AutoTokenizer.from_pretrained(
            f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
        )
        # Can't make an isinstance check because the loaded tokenizer is built from the CustomTokenizer class of a dynamic module
        self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")


class TrieTest(unittest.TestCase):
    def test_trie(self):
        trie = Trie()
        trie.add("Hello 友達")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
        trie.add("Hello")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})

    def test_trie_split(self):
        trie = Trie()
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
        trie.add("[CLS]")
        trie.add("extra_id_1")
        trie.add("extra_id_100")
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])

    def test_trie_single(self):
        trie = Trie()
        trie.add("A")
        self.assertEqual(trie.split("ABC"), ["A", "BC"])
        self.assertEqual(trie.split("BCA"), ["BC", "A"])

    def test_trie_final(self):
        trie = Trie()
        trie.add("TOKEN]")
        trie.add("[SPECIAL_TOKEN]")
        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])

    def test_trie_subtokens(self):
        trie = Trie()
        trie.add("A")
        trie.add("P")
        trie.add("[SPECIAL_TOKEN]")
        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])

    def test_trie_suffix_tokens(self):
        trie = Trie()
        trie.add("AB")
        trie.add("B")
        trie.add("C")
        self.assertEqual(trie.split("ABC"), ["AB", "C"])

    def test_trie_skip(self):
        trie = Trie()
        trie.add("ABC")
        trie.add("B")
        trie.add("CD")
        self.assertEqual(trie.split("ABCD"), ["ABC", "D"])

    def test_cut_text_hardening(self):
        # Even if the offsets are wrong, we necessarily output correct string parts.
        trie = Trie()
        parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
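        # Zero-width cuts and offsets that go backwards are skipped, so the surviving offsets still
        # yield a clean partition of the input string.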
        self.assertEqual(parts, ["AB", "C"])