transformers

Форк
0
/
test_tokenization_utils.py 
285 строк · 12.7 Кб
1
# coding=utf-8
2
# Copyright 2018 HuggingFace Inc..
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
"""
16
isort:skip_file
17
"""
18
import os
19
import pickle
20
import tempfile
21
import unittest
22
from typing import Callable, Optional
23

24
import numpy as np
25

26
from transformers import (
27
    BatchEncoding,
28
    BertTokenizer,
29
    BertTokenizerFast,
30
    PreTrainedTokenizer,
31
    PreTrainedTokenizerFast,
32
    TensorType,
33
    TokenSpan,
34
    is_tokenizers_available,
35
)
36
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
37
from transformers.testing_utils import CaptureStderr, require_flax, require_tf, require_tokenizers, require_torch, slow
38

39

40
if is_tokenizers_available():
41
    from tokenizers import Tokenizer
42
    from tokenizers.models import WordPiece
43

44

45
class TokenizerUtilsTest(unittest.TestCase):
    """Tests for tokenizer utilities: loading from pretrained checkpoints,
    BatchEncoding pickling, tensor conversion, padding of pre-tensorized
    features, and direct instantiation from `tokenizers` objects."""

    def check_tokenizer_from_pretrained(self, tokenizer_class):
        """Load the first known checkpoint for `tokenizer_class` and sanity-check
        its type and that every special token maps to an integer id."""
        s3_models = list(tokenizer_class.max_model_input_sizes.keys())
        # Only the first checkpoint is exercised to keep the (slow) test short.
        for model_name in s3_models[:1]:
            tokenizer = tokenizer_class.from_pretrained(model_name)
            self.assertIsNotNone(tokenizer)
            self.assertIsInstance(tokenizer, tokenizer_class)
            self.assertIsInstance(tokenizer, PreTrainedTokenizer)

            for special_tok in tokenizer.all_special_tokens:
                self.assertIsInstance(special_tok, str)
                special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
                self.assertIsInstance(special_tok_id, int)

    def assert_dump_and_restore(self, be_original: BatchEncoding, equal_op: Optional[Callable] = None):
        """Pickle-roundtrip `be_original` and assert the restored object matches.

        Args:
            be_original: the BatchEncoding to serialize and restore.
            equal_op: optional binary predicate used to compare values when plain
                `==` is not meaningful (e.g. `np.array_equal` for tensors).
        """
        batch_encoding_str = pickle.dumps(be_original)
        self.assertIsNotNone(batch_encoding_str)

        be_restored = pickle.loads(batch_encoding_str)

        # Ensure is_fast is correctly restored
        self.assertEqual(be_restored.is_fast, be_original.is_fast)

        # Fast (Rust) tokenizers carry `encodings`; slow (Python) ones do not —
        # the roundtrip must preserve that distinction.
        if be_original.is_fast:
            self.assertIsNotNone(be_restored.encodings)
        else:
            self.assertIsNone(be_restored.encodings)

        # Ensure the values survived the roundtrip (keys iterate in insertion order)
        for original_v, restored_v in zip(be_original.values(), be_restored.values()):
            if equal_op:
                self.assertTrue(equal_op(restored_v, original_v))
            else:
                self.assertEqual(restored_v, original_v)

    @slow
    def test_pretrained_tokenizers(self):
        self.check_tokenizer_from_pretrained(GPT2Tokenizer)

    def test_tensor_type_from_str(self):
        # The TensorType enum must accept its short string aliases.
        self.assertEqual(TensorType("tf"), TensorType.TENSORFLOW)
        self.assertEqual(TensorType("pt"), TensorType.PYTORCH)
        self.assertEqual(TensorType("np"), TensorType.NUMPY)

    @require_tokenizers
    def test_batch_encoding_pickle(self):
        """BatchEncoding pickles for both slow and fast tokenizers, with and without numpy tensors."""
        # NOTE: `np` is already imported at module level; no local import needed.
        tokenizer_p = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
        tokenizer_r = BertTokenizerFast.from_pretrained("google-bert/bert-base-cased")

        # Python no tensor
        with self.subTest("BatchEncoding (Python, return_tensors=None)"):
            self.assert_dump_and_restore(tokenizer_p("Small example to encode"))

        with self.subTest("BatchEncoding (Python, return_tensors=NUMPY)"):
            self.assert_dump_and_restore(
                tokenizer_p("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal
            )

        with self.subTest("BatchEncoding (Rust, return_tensors=None)"):
            self.assert_dump_and_restore(tokenizer_r("Small example to encode"))

        with self.subTest("BatchEncoding (Rust, return_tensors=NUMPY)"):
            self.assert_dump_and_restore(
                tokenizer_r("Small example to encode", return_tensors=TensorType.NUMPY), np.array_equal
            )

    @require_tf
    @require_tokenizers
    def test_batch_encoding_pickle_tf(self):
        """BatchEncoding with TensorFlow tensors pickles for both slow and fast tokenizers."""
        import tensorflow as tf

        def tf_array_equals(t1, t2):
            return tf.reduce_all(tf.equal(t1, t2))

        tokenizer_p = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
        tokenizer_r = BertTokenizerFast.from_pretrained("google-bert/bert-base-cased")

        with self.subTest("BatchEncoding (Python, return_tensors=TENSORFLOW)"):
            self.assert_dump_and_restore(
                tokenizer_p("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals
            )

        with self.subTest("BatchEncoding (Rust, return_tensors=TENSORFLOW)"):
            self.assert_dump_and_restore(
                tokenizer_r("Small example to encode", return_tensors=TensorType.TENSORFLOW), tf_array_equals
            )

    @require_torch
    @require_tokenizers
    def test_batch_encoding_pickle_pt(self):
        """BatchEncoding with PyTorch tensors pickles for both slow and fast tokenizers."""
        import torch

        tokenizer_p = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
        tokenizer_r = BertTokenizerFast.from_pretrained("google-bert/bert-base-cased")

        with self.subTest("BatchEncoding (Python, return_tensors=PYTORCH)"):
            self.assert_dump_and_restore(
                tokenizer_p("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal
            )

        with self.subTest("BatchEncoding (Rust, return_tensors=PYTORCH)"):
            self.assert_dump_and_restore(
                tokenizer_r("Small example to encode", return_tensors=TensorType.PYTORCH), torch.equal
            )

    @require_tokenizers
    def test_batch_encoding_is_fast(self):
        tokenizer_p = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
        tokenizer_r = BertTokenizerFast.from_pretrained("google-bert/bert-base-cased")

        with self.subTest("Python Tokenizer"):
            self.assertFalse(tokenizer_p("Small example to_encode").is_fast)

        with self.subTest("Rust Tokenizer"):
            self.assertTrue(tokenizer_r("Small example to_encode").is_fast)

    @require_tokenizers
    def test_batch_encoding_word_to_tokens(self):
        tokenizer_r = BertTokenizerFast.from_pretrained("google-bert/bert-base-cased")
        # "\xad" (soft hyphen) is stripped by BERT normalization, so word 1
        # produces no tokens and word_to_tokens must return None for it.
        encoded = tokenizer_r(["Test", "\xad", "test"], is_split_into_words=True)

        self.assertEqual(encoded.word_to_tokens(0), TokenSpan(start=1, end=2))
        self.assertIsNone(encoded.word_to_tokens(1))
        self.assertEqual(encoded.word_to_tokens(2), TokenSpan(start=2, end=3))

    def test_batch_encoding_with_labels(self):
        """convert_to_tensors("np") handles an extra non-nested "labels" key and batch-axis prepending."""
        batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
        tensor_batch = batch.convert_to_tensors(tensor_type="np")
        self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
        self.assertEqual(tensor_batch["labels"].shape, (2,))
        # test converting the converted
        with CaptureStderr() as cs:
            tensor_batch = batch.convert_to_tensors(tensor_type="np")
        self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")

        batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
        tensor_batch = batch.convert_to_tensors(tensor_type="np", prepend_batch_axis=True)
        self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
        self.assertEqual(tensor_batch["labels"].shape, (1,))

    @require_torch
    def test_batch_encoding_with_labels_pt(self):
        """Same as test_batch_encoding_with_labels, for the PyTorch backend."""
        batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
        tensor_batch = batch.convert_to_tensors(tensor_type="pt")
        self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
        self.assertEqual(tensor_batch["labels"].shape, (2,))
        # test converting the converted
        with CaptureStderr() as cs:
            tensor_batch = batch.convert_to_tensors(tensor_type="pt")
        self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")

        batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
        tensor_batch = batch.convert_to_tensors(tensor_type="pt", prepend_batch_axis=True)
        self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
        self.assertEqual(tensor_batch["labels"].shape, (1,))

    @require_tf
    def test_batch_encoding_with_labels_tf(self):
        """Same as test_batch_encoding_with_labels, for the TensorFlow backend."""
        batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
        tensor_batch = batch.convert_to_tensors(tensor_type="tf")
        self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
        self.assertEqual(tensor_batch["labels"].shape, (2,))
        # test converting the converted
        with CaptureStderr() as cs:
            tensor_batch = batch.convert_to_tensors(tensor_type="tf")
        self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")

        batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
        tensor_batch = batch.convert_to_tensors(tensor_type="tf", prepend_batch_axis=True)
        self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
        self.assertEqual(tensor_batch["labels"].shape, (1,))

    @require_flax
    def test_batch_encoding_with_labels_jax(self):
        """Same as test_batch_encoding_with_labels, for the JAX backend."""
        batch = BatchEncoding({"inputs": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]})
        tensor_batch = batch.convert_to_tensors(tensor_type="jax")
        self.assertEqual(tensor_batch["inputs"].shape, (2, 3))
        self.assertEqual(tensor_batch["labels"].shape, (2,))
        # test converting the converted
        with CaptureStderr() as cs:
            tensor_batch = batch.convert_to_tensors(tensor_type="jax")
        self.assertFalse(len(cs.err), msg=f"should have no warning, but got {cs.err}")

        batch = BatchEncoding({"inputs": [1, 2, 3], "labels": 0})
        tensor_batch = batch.convert_to_tensors(tensor_type="jax", prepend_batch_axis=True)
        self.assertEqual(tensor_batch["inputs"].shape, (1, 3))
        self.assertEqual(tensor_batch["labels"].shape, (1,))

    def test_padding_accepts_tensors(self):
        """tokenizer.pad accepts features that are already numpy arrays."""
        features = [{"input_ids": np.array([0, 1, 2])}, {"input_ids": np.array([0, 1, 2, 3])}]
        tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")

        batch = tokenizer.pad(features, padding=True)
        self.assertIsInstance(batch["input_ids"], np.ndarray)
        self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
        batch = tokenizer.pad(features, padding=True, return_tensors="np")
        self.assertIsInstance(batch["input_ids"], np.ndarray)
        self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])

    @require_torch
    def test_padding_accepts_tensors_pt(self):
        """tokenizer.pad accepts features that are already PyTorch tensors."""
        import torch

        features = [{"input_ids": torch.tensor([0, 1, 2])}, {"input_ids": torch.tensor([0, 1, 2, 3])}]
        tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")

        batch = tokenizer.pad(features, padding=True)
        self.assertIsInstance(batch["input_ids"], torch.Tensor)
        self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
        batch = tokenizer.pad(features, padding=True, return_tensors="pt")
        self.assertIsInstance(batch["input_ids"], torch.Tensor)
        self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])

    @require_tf
    def test_padding_accepts_tensors_tf(self):
        """tokenizer.pad accepts features that are already TensorFlow tensors."""
        import tensorflow as tf

        features = [{"input_ids": tf.constant([0, 1, 2])}, {"input_ids": tf.constant([0, 1, 2, 3])}]
        tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")

        batch = tokenizer.pad(features, padding=True)
        self.assertIsInstance(batch["input_ids"], tf.Tensor)
        self.assertEqual(batch["input_ids"].numpy().tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
        batch = tokenizer.pad(features, padding=True, return_tensors="tf")
        self.assertIsInstance(batch["input_ids"], tf.Tensor)
        self.assertEqual(batch["input_ids"].numpy().tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])

    @require_tokenizers
    def test_instantiation_from_tokenizers(self):
        # A raw `tokenizers.Tokenizer` object can back a PreTrainedTokenizerFast.
        bert_tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
        PreTrainedTokenizerFast(tokenizer_object=bert_tokenizer)

    @require_tokenizers
    def test_instantiation_from_tokenizers_json_file(self):
        # A serialized tokenizer.json file can back a PreTrainedTokenizerFast.
        bert_tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
        with tempfile.TemporaryDirectory() as tmpdirname:
            bert_tokenizer.save(os.path.join(tmpdirname, "tokenizer.json"))
            PreTrainedTokenizerFast(tokenizer_file=os.path.join(tmpdirname, "tokenizer.json"))
286

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.