# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile
import unittest
import unittest.mock as mock
from pathlib import Path

from huggingface_hub import HfFolder, delete_repo
from huggingface_hub.file_download import http_get
from requests.exceptions import HTTPError

from transformers import (
    AlbertTokenizer,
    AutoTokenizer,
    BertTokenizer,
    BertTokenizerFast,
    GPT2TokenizerFast,
    is_tokenizers_available,
)
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
from transformers.tokenization_utils import Trie


sys.path.append(str(Path(__file__).parent.parent / "utils"))

from test_module.custom_tokenization import CustomTokenizer  # noqa E402


if is_tokenizers_available():
    from test_module.custom_tokenization_fast import CustomTokenizerFast


class TokenizerUtilTester(unittest.TestCase):
    def test_cached_files_are_used_when_internet_is_down(self):
        # A mock response for an HTTP HEAD request, emulating a server outage
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Under the mock environment we get a 500 error when trying to reach the tokenizer.
        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
            _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
            # Check that the fake head request was actually called
            mock_head.assert_called()

    @require_tokenizers
    def test_cached_files_are_used_when_internet_is_down_missing_files(self):
        # A mock response for an HTTP HEAD request, emulating a server outage
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        _ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")

        # Under the mock environment we get a 500 error when trying to reach the tokenizer.
        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
            _ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
            # Check that the fake head request was actually called
            mock_head.assert_called()

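    # Note (illustrative, not exercised by this suite): the offline behavior
    # checked above can also be requested explicitly instead of relying on a
    # failing HEAD request, e.g.
    #     BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert", local_files_only=True)
    # `local_files_only=True` skips the network call and reads only from the cache.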

    def test_legacy_load_from_one_file(self):
        # This test is for deprecated behavior and can be removed in v5
        try:
            tmp_file = tempfile.mktemp()
            with open(tmp_file, "wb") as f:
                http_get("https://huggingface.co/albert/albert-base-v1/resolve/main/spiece.model", f)

            _ = AlbertTokenizer.from_pretrained(tmp_file)
        finally:
            os.remove(tmp_file)

        # Supporting this legacy load introduced a weird bug where the tokenizer would load local files if they are
        # in the current folder and have the right name.
        if os.path.isfile("tokenizer.json"):
            # We skip the test if the user has a `tokenizer.json` in this folder to avoid deleting it.
            return
        try:
            with open("tokenizer.json", "wb") as f:
                http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/blob/main/tokenizer.json", f)
            tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
            # The tiny random BERT has a vocab size of 1024; the tiny openai-community/gpt2 has a vocab size of 1000
            self.assertEqual(tokenizer.vocab_size, 1000)
            # The tokenizer should depend on the remote checkpoint, not the local tokenizer.json file.
        finally:
            os.remove("tokenizer.json")

    def test_legacy_load_from_url(self):
        # This test is for deprecated behavior and can be removed in v5
        _ = AlbertTokenizer.from_pretrained("https://huggingface.co/albert/albert-base-v1/resolve/main/spiece.model")


@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]

    @classmethod
    def setUpClass(cls):
        cls._token = TOKEN
        HfFolder.save_token(TOKEN)

    @classmethod
    def tearDownClass(cls):
        # Clean up every repo the tests may have created; swallow the HTTPError
        # if a repo was never created (or was already deleted) so teardown
        # never fails the suite.
        try:
            delete_repo(token=cls._token, repo_id="test-tokenizer")
        except HTTPError:
            pass

        try:
            delete_repo(token=cls._token, repo_id="valid_org/test-tokenizer-org")
        except HTTPError:
            pass

        try:
            delete_repo(token=cls._token, repo_id="test-dynamic-tokenizer")
        except HTTPError:
            pass

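    # Note: `push_to_hub` creates the target repo on the Hub if it does not
    # exist yet, so the tests below need no explicit `create_repo` call.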
    def test_push_to_hub(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            vocab_file = os.path.join(tmp_dir, "vocab.txt")
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
            tokenizer = BertTokenizer(vocab_file)

        tokenizer.push_to_hub("test-tokenizer", token=self._token)
        new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

        # Reset repo
        delete_repo(token=self._token, repo_id="test-tokenizer")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer.save_pretrained(tmp_dir, repo_id="test-tokenizer", push_to_hub=True, token=self._token)

        new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

    def test_push_to_hub_in_organization(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            vocab_file = os.path.join(tmp_dir, "vocab.txt")
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
            tokenizer = BertTokenizer(vocab_file)

        tokenizer.push_to_hub("valid_org/test-tokenizer-org", token=self._token)
        new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

        # Reset repo
        delete_repo(token=self._token, repo_id="valid_org/test-tokenizer-org")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer.save_pretrained(
                tmp_dir, repo_id="valid_org/test-tokenizer-org", push_to_hub=True, token=self._token
            )

        new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

    @require_tokenizers
    def test_push_to_hub_dynamic_tokenizer(self):
        CustomTokenizer.register_for_auto_class()
        with tempfile.TemporaryDirectory() as tmp_dir:
            vocab_file = os.path.join(tmp_dir, "vocab.txt")
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
            tokenizer = CustomTokenizer(vocab_file)

        # No fast custom tokenizer
        tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token)

        tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
        # Can't make an isinstance check because the new tokenizer comes from the CustomTokenizer class of a
        # dynamic module
        self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")

        # Fast and slow custom tokenizer
        CustomTokenizerFast.register_for_auto_class()
        with tempfile.TemporaryDirectory() as tmp_dir:
            vocab_file = os.path.join(tmp_dir, "vocab.txt")
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))

            bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
            bert_tokenizer.save_pretrained(tmp_dir)
            tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)

        tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token)

        tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
        # Can't make an isinstance check because the new tokenizer comes from the CustomTokenizerFast class of a
        # dynamic module
        self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
        tokenizer = AutoTokenizer.from_pretrained(
            f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
        )
        # Can't make an isinstance check because the new tokenizer comes from the CustomTokenizer class of a
        # dynamic module
        self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")


class TrieTest(unittest.TestCase):
    def test_trie(self):
        trie = Trie()
        trie.add("Hello 友達")
        # Tokens are stored character by character; the empty-string key marks
        # the end of a complete token.
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
        trie.add("Hello")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})

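    # Typical consumer pattern (sketch): a tokenizer registers its special
    # tokens in a `Trie`, then uses `split` to cut raw text around them
    # before the main tokenization step, e.g.
    #     trie = Trie()
    #     trie.add("[CLS]")
    #     trie.split("[CLS] hello")  # -> ["[CLS]", " hello"]
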
    def test_trie_split(self):
        trie = Trie()
        # With no tokens registered, the text is returned unsplit.
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
        trie.add("[CLS]")
        trie.add("extra_id_1")
        trie.add("extra_id_100")
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])

    def test_trie_single(self):
        trie = Trie()
        trie.add("A")
        self.assertEqual(trie.split("ABC"), ["A", "BC"])
        self.assertEqual(trie.split("BCA"), ["BC", "A"])

    def test_trie_final(self):
        trie = Trie()
        trie.add("TOKEN]")
        trie.add("[SPECIAL_TOKEN]")
        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])

    def test_trie_subtokens(self):
        trie = Trie()
        trie.add("A")
        trie.add("P")
        trie.add("[SPECIAL_TOKEN]")
        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])

    def test_trie_suffix_tokens(self):
        trie = Trie()
        trie.add("AB")
        trie.add("B")
        trie.add("C")
        # "AB" starts earlier than the overlapping "B", so it wins the match.
        self.assertEqual(trie.split("ABC"), ["AB", "C"])

    def test_trie_skip(self):
        trie = Trie()
        trie.add("ABC")
        trie.add("B")
        trie.add("CD")
        # "ABC" consumes the characters that "B" and "CD" would have matched.
        self.assertEqual(trie.split("ABCD"), ["ABC", "D"])

    def test_cut_text_hardening(self):
        # Even if the offsets are wrong, we necessarily output correct string
        # parts.
        trie = Trie()
        parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
        self.assertEqual(parts, ["AB", "C"])
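

if __name__ == "__main__":
    # Convenience entry point so this file can be run directly
    # (`python test_tokenization_utils.py`); normally the suite is driven by
    # the project's test runner instead.
    unittest.main()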