transformers
213 строк · 8.3 Кб
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15import json
16import os
17import shutil
18import tempfile
19import unittest
20
21import numpy as np
22import pytest
23
24from transformers import BertTokenizer, BertTokenizerFast
25from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES
26from transformers.testing_utils import require_vision
27from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available
28
29
30if is_vision_available():
31from PIL import Image
32
33from transformers import ChineseCLIPImageProcessor, ChineseCLIPProcessor
34
35
@require_vision
class ChineseCLIPProcessorTest(unittest.TestCase):
    """Tests for ``ChineseCLIPProcessor`` built from a BERT tokenizer and a ``ChineseCLIPImageProcessor``.

    Each test runs against a tiny vocabulary and an image-processor config that ``setUp``
    writes into a fresh temporary directory.
    """

    # Mixed Latin/Chinese sample text shared by the text-input tests.
    INPUT_STR = "Alexandra,T-shirt的价格是15便士。"

    def setUp(self):
        """Create a temp dir containing a small BERT vocab file and an image-processor JSON config."""
        self.tmpdirname = tempfile.mkdtemp()

        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "的",
            "价",
            "格",
            "是",
            "15",
            "便",
            "alex",
            "##andra",
            ",",
            "。",
            "-",
            "t",
            "shirt",
        ]
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            # One token per line, each newline-terminated.
            vocab_writer.write("".join(f"{token}\n" for token in vocab_tokens))

        image_processor_map = {
            "do_resize": True,
            "size": {"height": 224, "width": 224},
            "do_center_crop": True,
            "crop_size": {"height": 18, "width": 18},
            "do_normalize": True,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
            "do_convert_rgb": True,
        }
        self.image_processor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
        with open(self.image_processor_file, "w", encoding="utf-8") as fp:
            json.dump(image_processor_map, fp)

    def get_tokenizer(self, **kwargs):
        """Load the slow ``BertTokenizer`` from the temp dir, forwarding ``kwargs`` to ``from_pretrained``."""
        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_rust_tokenizer(self, **kwargs):
        """Load the fast ``BertTokenizerFast`` from the temp dir, forwarding ``kwargs`` to ``from_pretrained``."""
        return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

    def get_image_processor(self, **kwargs):
        """Load ``ChineseCLIPImageProcessor`` from the temp dir, forwarding ``kwargs`` to ``from_pretrained``."""
        return ChineseCLIPImageProcessor.from_pretrained(self.tmpdirname, **kwargs)

    def tearDown(self):
        """Remove the temp dir created in ``setUp``."""
        shutil.rmtree(self.tmpdirname)

    def prepare_image_inputs(self):
        """Return a one-element list holding a random 3-channel PIL image, 400 px wide by 30 px tall.

        The pixels are drawn channel-first as ``(3, 30, 400)`` uint8 values. They are then moved to
        channel-last ``(30, 400, 3)`` before ``Image.fromarray``.
        """
        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]

        # PIL expects (height, width, channels), so move the channel axis to the end.
        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]

        return image_inputs

    def test_save_load_pretrained_default(self):
        """Round-tripping through save/load preserves the tokenizer vocab, tokenizer class and image-processor config."""
        tokenizer_slow = self.get_tokenizer()
        tokenizer_fast = self.get_rust_tokenizer()
        image_processor = self.get_image_processor()

        processor_slow = ChineseCLIPProcessor(tokenizer=tokenizer_slow, image_processor=image_processor)
        processor_slow.save_pretrained(self.tmpdirname)
        processor_slow = ChineseCLIPProcessor.from_pretrained(self.tmpdirname, use_fast=False)

        processor_fast = ChineseCLIPProcessor(tokenizer=tokenizer_fast, image_processor=image_processor)
        processor_fast.save_pretrained(self.tmpdirname)
        processor_fast = ChineseCLIPProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
        self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
        self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
        self.assertIsInstance(processor_slow.tokenizer, BertTokenizer)
        self.assertIsInstance(processor_fast.tokenizer, BertTokenizerFast)

        self.assertEqual(processor_slow.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertEqual(processor_fast.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertIsInstance(processor_slow.image_processor, ChineseCLIPImageProcessor)
        self.assertIsInstance(processor_fast.image_processor, ChineseCLIPImageProcessor)

    def test_save_load_pretrained_additional_features(self):
        """Extra kwargs passed to ``from_pretrained`` reach both the tokenizer and the image processor."""
        processor = ChineseCLIPProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(cls_token="(CLS)", sep_token="(SEP)")
        image_processor_add_kwargs = self.get_image_processor(do_normalize=False)

        processor = ChineseCLIPProcessor.from_pretrained(
            self.tmpdirname, cls_token="(CLS)", sep_token="(SEP)", do_normalize=False
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        # No use_fast argument was given, so the loaded tokenizer is expected to be the fast one.
        self.assertIsInstance(processor.tokenizer, BertTokenizerFast)

        self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.image_processor, ChineseCLIPImageProcessor)

    def test_image_processor(self):
        """Image-only processor calls match the image processor's output key by key (compared by sum)."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        image_input = self.prepare_image_inputs()

        input_image_proc = image_processor(image_input, return_tensors="np")
        input_processor = processor(images=image_input, return_tensors="np")

        for key in input_image_proc:
            self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        """Text-only processor calls match the tokenizer's output key by key."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        encoded_processor = processor(text=self.INPUT_STR)
        encoded_tok = tokenizer(self.INPUT_STR)

        for key in encoded_tok:
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_processor(self):
        """Joint text+image calls return text and pixel keys; calling with no input raises ``ValueError``."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        image_input = self.prepare_image_inputs()

        inputs = processor(text=self.INPUT_STR, images=image_input)

        self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"])

        # Calling the processor with neither text nor images must be rejected.
        with self.assertRaises(ValueError):
            processor()

    def test_tokenizer_decode(self):
        """``processor.batch_decode`` delegates to the tokenizer and yields identical strings."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        """The keys of a joint text+image call match ``processor.model_input_names`` in order."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        image_input = self.prepare_image_inputs()

        inputs = processor(text=self.INPUT_STR, images=image_input)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)
214