transformers

Форк
0
/
test_processor_flava.py 
244 строки · 9.5 Кб
1
# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import json
16
import os
17
import random
18
import shutil
19
import tempfile
20
import unittest
21

22
import numpy as np
23
import pytest
24

25
from transformers import BertTokenizer, BertTokenizerFast
26
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES
27
from transformers.testing_utils import require_vision
28
from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available
29

30

31
if is_vision_available():
32
    from PIL import Image
33

34
    from transformers import FlavaImageProcessor, FlavaProcessor
35
    from transformers.models.flava.image_processing_flava import (
36
        FLAVA_CODEBOOK_MEAN,
37
        FLAVA_CODEBOOK_STD,
38
        FLAVA_IMAGE_MEAN,
39
        FLAVA_IMAGE_STD,
40
    )
41

42

43
@require_vision
class FlavaProcessorTest(unittest.TestCase):
    """Tests for ``FlavaProcessor``.

    Covers save/load round-trips and checks that the processor's outputs match
    those of the ``BertTokenizer`` / ``FlavaImageProcessor`` it wraps.
    """

    def setUp(self):
        """Write a tiny BERT vocab and a FLAVA image-processor config into a fresh temp dir."""
        self.tmpdirname = tempfile.mkdtemp()
        # Register removal immediately: tearDown is skipped when setUp raises,
        # so a cleanup is the only way to guarantee the directory is deleted.
        self.addCleanup(shutil.rmtree, self.tmpdirname)

        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"]  # fmt: skip
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])

        # One token per line, each line newline-terminated.
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write("".join([x + "\n" for x in vocab_tokens]))

        image_processor_map = {
            "image_mean": FLAVA_IMAGE_MEAN,
            "image_std": FLAVA_IMAGE_STD,
            "do_normalize": True,
            "do_resize": True,
            "size": 224,
            "do_center_crop": True,
            "crop_size": 224,
            "input_size_patches": 14,
            "total_mask_patches": 75,
            "mask_group_max_patches": None,
            "mask_group_min_patches": 16,
            "mask_group_min_aspect_ratio": 0.3,
            "mask_group_max_aspect_ratio": None,
            "codebook_do_resize": True,
            "codebook_size": 112,
            "codebook_do_center_crop": True,
            "codebook_crop_size": 112,
            "codebook_do_map_pixels": True,
            "codebook_do_normalize": True,
            "codebook_image_mean": FLAVA_CODEBOOK_MEAN,
            "codebook_image_std": FLAVA_CODEBOOK_STD,
        }

        self.image_processor_file = os.path.join(self.tmpdirname, IMAGE_PROCESSOR_NAME)
        with open(self.image_processor_file, "w", encoding="utf-8") as fp:
            json.dump(image_processor_map, fp)

    def get_tokenizer(self, **kwargs):
        """Load the slow ``BertTokenizer`` from the temp dir, forwarding ``kwargs``."""
        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_rust_tokenizer(self, **kwargs):
        """Load the fast ``BertTokenizerFast`` from the temp dir, forwarding ``kwargs``."""
        return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

    def get_image_processor(self, **kwargs):
        """Load ``FlavaImageProcessor`` from the temp dir, forwarding ``kwargs``."""
        return FlavaImageProcessor.from_pretrained(self.tmpdirname, **kwargs)

    def prepare_image_inputs(self):
        """Return a list containing a single random RGB PIL image (30 px high, 400 px wide)."""
        # Channel-first uint8 array, moved to channel-last because PIL expects HWC.
        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]

        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]

        return image_inputs

    def test_save_load_pretrained_default(self):
        """Slow and fast processors survive a save/load round-trip with unchanged components."""
        tokenizer_slow = self.get_tokenizer()
        tokenizer_fast = self.get_rust_tokenizer()
        image_processor = self.get_image_processor()

        processor_slow = FlavaProcessor(tokenizer=tokenizer_slow, image_processor=image_processor)
        processor_slow.save_pretrained(self.tmpdirname)
        processor_slow = FlavaProcessor.from_pretrained(self.tmpdirname, use_fast=False)

        processor_fast = FlavaProcessor(tokenizer=tokenizer_fast, image_processor=image_processor)
        processor_fast.save_pretrained(self.tmpdirname)
        processor_fast = FlavaProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
        self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
        self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
        self.assertIsInstance(processor_slow.tokenizer, BertTokenizer)
        self.assertIsInstance(processor_fast.tokenizer, BertTokenizerFast)

        self.assertEqual(processor_slow.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertEqual(processor_fast.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertIsInstance(processor_slow.image_processor, FlavaImageProcessor)
        self.assertIsInstance(processor_fast.image_processor, FlavaImageProcessor)

    def test_save_load_pretrained_additional_features(self):
        """Extra kwargs given to ``FlavaProcessor.from_pretrained`` reach the tokenizer and image processor."""
        processor = FlavaProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)

        processor = FlavaProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, BertTokenizerFast)

        self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.image_processor, FlavaImageProcessor)

    def test_image_processor(self):
        """Image-only processor calls produce the same features as the wrapped image processor."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = FlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

        image_input = self.prepare_image_inputs()

        input_feat_extract = image_processor(image_input, return_tensors="np")
        input_processor = processor(images=image_input, return_tensors="np")

        for key in input_feat_extract:
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

        # With rest of the args.
        # The same seed is set before each call so both runs draw identical random values
        # (presumably for the image-mask generation — confirm against FlavaImageProcessor).
        # The global RNG state is restored afterwards so the fixed seed does not leak into other tests.
        self.addCleanup(random.setstate, random.getstate())
        random.seed(1234)
        input_feat_extract = image_processor(
            image_input, return_image_mask=True, return_codebook_pixels=True, return_tensors="np"
        )
        random.seed(1234)
        input_processor = processor(
            images=image_input, return_image_mask=True, return_codebook_pixels=True, return_tensors="np"
        )

        for key in input_feat_extract:
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        """Text-only processor calls produce the same encoding as the wrapped tokenizer."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = FlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"

        encoded_processor = processor(text=input_str)

        encoded_tok = tokenizer(input_str)

        for key in encoded_tok:
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_processor(self):
        """Combined text+image calls return the expected keys, and an empty call raises ``ValueError``."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = FlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"])

        # add extra args
        inputs = processor(text=input_str, images=image_input, return_codebook_pixels=True, return_image_mask=True)

        self.assertListEqual(
            list(inputs.keys()),
            [
                "input_ids",
                "token_type_ids",
                "attention_mask",
                "pixel_values",
                "codebook_pixel_values",
                "bool_masked_pos",
            ],
        )

        # test if it raises when no input is passed
        with self.assertRaises(ValueError):
            processor()

    def test_tokenizer_decode(self):
        """``batch_decode`` on the processor matches ``batch_decode`` on the tokenizer."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = FlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        """A text+image call returns exactly the keys listed in ``processor.model_input_names``, in that order."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = FlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)
245

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.