transformers

Форк
0
/
test_processor_chinese_clip.py 
213 строк · 8.3 Кб
1
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15
import json
16
import os
17
import shutil
18
import tempfile
19
import unittest
20

21
import numpy as np
22
import pytest
23

24
from transformers import BertTokenizer, BertTokenizerFast
25
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES
26
from transformers.testing_utils import require_vision
27
from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available
28

29

30
# Vision dependencies (Pillow and the image/processor classes) are optional
# extras; import them only when the current environment provides them, so the
# module itself stays importable without vision support.
if is_vision_available():
    from PIL import Image

    from transformers import ChineseCLIPImageProcessor, ChineseCLIPProcessor
34

35

36
@require_vision
class ChineseCLIPProcessorTest(unittest.TestCase):
    """Tests for ``ChineseCLIPProcessor``.

    Verifies that the composite processor round-trips through
    ``save_pretrained`` / ``from_pretrained`` (slow and fast tokenizer
    variants), and that its text / image outputs match those produced by the
    underlying tokenizer and image processor when used directly.
    """

    def setUp(self):
        # All fixture files (vocab + image-processor config) live in a fresh
        # temporary directory, removed again in tearDown.
        self.tmpdirname = tempfile.mkdtemp()

        # Minimal WordPiece vocabulary covering the special tokens plus the
        # Chinese/Latin pieces used by the test input strings below.
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "的",
            "价",
            "格",
            "是",
            "15",
            "便",
            "alex",
            "##andra",
            ",",
            "。",
            "-",
            "t",
            "shirt",
        ]
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        # Image-processor config matching ChineseCLIPImageProcessor defaults,
        # with a small crop size to keep the tests fast.
        image_processor_map = {
            "do_resize": True,
            "size": {"height": 224, "width": 224},
            "do_center_crop": True,
            "crop_size": {"height": 18, "width": 18},
            "do_normalize": True,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
            "do_convert_rgb": True,
        }
        self.image_processor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
        with open(self.image_processor_file, "w", encoding="utf-8") as fp:
            json.dump(image_processor_map, fp)

    def get_tokenizer(self, **kwargs):
        """Return a slow (pure-Python) tokenizer built from the fixture vocab."""
        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_rust_tokenizer(self, **kwargs):
        """Return a fast (Rust-backed) tokenizer built from the fixture vocab."""
        return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

    def get_image_processor(self, **kwargs):
        """Return an image processor built from the fixture config file."""
        return ChineseCLIPImageProcessor.from_pretrained(self.tmpdirname, **kwargs)

    def tearDown(self):
        # Remove the fixture directory created in setUp.
        shutil.rmtree(self.tmpdirname)

    def prepare_image_inputs(self):
        """Return a list containing a single random 400x30 RGB PIL image.

        NOTE: the original docstring advertised ``numpify``/``torchify``
        switches copied from a shared test helper, but this method takes no
        arguments and always returns PIL images — the docstring was wrong.
        """
        # Random channels-first uint8 array ...
        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
        # ... converted to channels-last and wrapped as a PIL image.
        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
        return image_inputs

    def test_save_load_pretrained_default(self):
        """Saving then reloading preserves tokenizer vocab and image config."""
        tokenizer_slow = self.get_tokenizer()
        tokenizer_fast = self.get_rust_tokenizer()
        image_processor = self.get_image_processor()

        processor_slow = ChineseCLIPProcessor(tokenizer=tokenizer_slow, image_processor=image_processor)
        processor_slow.save_pretrained(self.tmpdirname)
        # use_fast=False forces the slow tokenizer back on reload.
        processor_slow = ChineseCLIPProcessor.from_pretrained(self.tmpdirname, use_fast=False)

        processor_fast = ChineseCLIPProcessor(tokenizer=tokenizer_fast, image_processor=image_processor)
        processor_fast.save_pretrained(self.tmpdirname)
        processor_fast = ChineseCLIPProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
        self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
        self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
        self.assertIsInstance(processor_slow.tokenizer, BertTokenizer)
        self.assertIsInstance(processor_fast.tokenizer, BertTokenizerFast)

        self.assertEqual(processor_slow.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertEqual(processor_fast.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertIsInstance(processor_slow.image_processor, ChineseCLIPImageProcessor)
        self.assertIsInstance(processor_fast.image_processor, ChineseCLIPImageProcessor)

    def test_save_load_pretrained_additional_features(self):
        """Extra kwargs passed to from_pretrained override saved settings."""
        processor = ChineseCLIPProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(cls_token="(CLS)", sep_token="(SEP)")
        image_processor_add_kwargs = self.get_image_processor(do_normalize=False)

        processor = ChineseCLIPProcessor.from_pretrained(
            self.tmpdirname, cls_token="(CLS)", sep_token="(SEP)", do_normalize=False
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, BertTokenizerFast)

        self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.image_processor, ChineseCLIPImageProcessor)

    def test_image_processor(self):
        """Processor image output matches the standalone image processor."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        image_input = self.prepare_image_inputs()

        input_feat_extract = image_processor(image_input, return_tensors="np")
        input_processor = processor(images=image_input, return_tensors="np")

        # Compare via sums with a small tolerance rather than elementwise
        # equality, since float normalization may differ in the last bits.
        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        """Processor text output matches the standalone tokenizer."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "Alexandra,T-shirt的价格是15便士。"

        encoded_processor = processor(text=input_str)

        encoded_tok = tokenizer(input_str)

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_processor(self):
        """Joint text+image call yields all expected keys; empty call raises."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "Alexandra,T-shirt的价格是15便士。"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"])

        # test if it raises when no input is passed
        with pytest.raises(ValueError):
            processor()

    def test_tokenizer_decode(self):
        """batch_decode delegates to the tokenizer's batch_decode."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        """Output keys agree with the processor's declared model_input_names."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "Alexandra,T-shirt的价格是15便士。"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)
214

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.