# Source: transformers repository (file-viewer metadata: 244 lines, 9.5 KB)
# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import random
import shutil
import tempfile
import unittest

import numpy as np
import pytest

from transformers import BertTokenizer, BertTokenizerFast
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES
from transformers.testing_utils import require_vision
from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available


if is_vision_available():
    from PIL import Image

    from transformers import FlavaImageProcessor, FlavaProcessor
    from transformers.models.flava.image_processing_flava import (
        FLAVA_CODEBOOK_MEAN,
        FLAVA_CODEBOOK_STD,
        FLAVA_IMAGE_MEAN,
        FLAVA_IMAGE_STD,
    )


@require_vision
class FlavaProcessorTest(unittest.TestCase):
    """Tests for `FlavaProcessor`, which pairs a BERT tokenizer with a `FlavaImageProcessor`.

    Every test works against a fresh temporary directory (created in `setUp`, removed in
    `tearDown`) holding a tiny WordPiece vocabulary file and an image processor config file.
    """

    def setUp(self):
        """Create a temporary directory containing a vocab file and an image processor config."""
        self.tmpdirname = tempfile.mkdtemp()

        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"]  # fmt: skip
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])

        # One token per line.
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write("".join(f"{token}\n" for token in vocab_tokens))

        image_processor_map = {
            "image_mean": FLAVA_IMAGE_MEAN,
            "image_std": FLAVA_IMAGE_STD,
            "do_normalize": True,
            "do_resize": True,
            "size": 224,
            "do_center_crop": True,
            "crop_size": 224,
            "input_size_patches": 14,
            "total_mask_patches": 75,
            "mask_group_max_patches": None,
            "mask_group_min_patches": 16,
            "mask_group_min_aspect_ratio": 0.3,
            "mask_group_max_aspect_ratio": None,
            "codebook_do_resize": True,
            "codebook_size": 112,
            "codebook_do_center_crop": True,
            "codebook_crop_size": 112,
            "codebook_do_map_pixels": True,
            "codebook_do_normalize": True,
            "codebook_image_mean": FLAVA_CODEBOOK_MEAN,
            "codebook_image_std": FLAVA_CODEBOOK_STD,
        }

        self.image_processor_file = os.path.join(self.tmpdirname, IMAGE_PROCESSOR_NAME)
        with open(self.image_processor_file, "w", encoding="utf-8") as fp:
            json.dump(image_processor_map, fp)

    def get_tokenizer(self, **kwargs):
        """Load the slow `BertTokenizer` from the temporary directory, forwarding `kwargs`."""
        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_rust_tokenizer(self, **kwargs):
        """Load the fast `BertTokenizerFast` from the temporary directory, forwarding `kwargs`."""
        return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

    def get_image_processor(self, **kwargs):
        """Load the `FlavaImageProcessor` from the temporary directory, forwarding `kwargs`."""
        return FlavaImageProcessor.from_pretrained(self.tmpdirname, **kwargs)

    def tearDown(self):
        """Remove the temporary directory created in `setUp`."""
        shutil.rmtree(self.tmpdirname)

    def prepare_image_inputs(self):
        """Return a list containing a single random RGB PIL image.

        The pixel data is generated as a channels-first `uint8` array of shape (3, 30, 400) and
        moved to channels-last before being handed to `Image.fromarray`.
        """
        # `randint(255)` draws from [0, 255): the upper bound is exclusive.
        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]

        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]

        return image_inputs

    def test_save_load_pretrained_default(self):
        """Saving then reloading a processor preserves tokenizer vocab, tokenizer class and image processor."""
        tokenizer_slow = self.get_tokenizer()
        tokenizer_fast = self.get_rust_tokenizer()
        image_processor = self.get_image_processor()

        processor_slow = FlavaProcessor(tokenizer=tokenizer_slow, image_processor=image_processor)
        processor_slow.save_pretrained(self.tmpdirname)
        processor_slow = FlavaProcessor.from_pretrained(self.tmpdirname, use_fast=False)

        processor_fast = FlavaProcessor(tokenizer=tokenizer_fast, image_processor=image_processor)
        processor_fast.save_pretrained(self.tmpdirname)
        processor_fast = FlavaProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
        self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
        self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
        self.assertIsInstance(processor_slow.tokenizer, BertTokenizer)
        self.assertIsInstance(processor_fast.tokenizer, BertTokenizerFast)

        self.assertEqual(processor_slow.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertEqual(processor_fast.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertIsInstance(processor_slow.image_processor, FlavaImageProcessor)
        self.assertIsInstance(processor_fast.image_processor, FlavaImageProcessor)

    def test_save_load_pretrained_additional_features(self):
        """Extra kwargs given to `FlavaProcessor.from_pretrained` reach both the tokenizer and the image processor."""
        processor = FlavaProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
        processor.save_pretrained(self.tmpdirname)

        # Reference components loaded directly with the same overrides.
        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)

        processor = FlavaProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, BertTokenizerFast)

        self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.image_processor, FlavaImageProcessor)

    def test_image_processor(self):
        """Calling the processor with images only matches calling the image processor directly."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = FlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

        image_input = self.prepare_image_inputs()

        input_feat_extract = image_processor(image_input, return_tensors="np")
        input_processor = processor(images=image_input, return_tensors="np")

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

        # With rest of the args
        # Re-seed before each call so both runs see the same `random` state
        # (presumably the image mask is drawn from the `random` module -- confirm against the image processor).
        random.seed(1234)
        input_feat_extract = image_processor(
            image_input, return_image_mask=True, return_codebook_pixels=True, return_tensors="np"
        )
        random.seed(1234)
        input_processor = processor(
            images=image_input, return_image_mask=True, return_codebook_pixels=True, return_tensors="np"
        )

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        """Calling the processor with text only matches calling the tokenizer directly."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = FlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"

        encoded_processor = processor(text=input_str)

        encoded_tok = tokenizer(input_str)

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_processor(self):
        """The processor returns the expected keys for text + images and raises when given no input."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = FlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"])

        # add extra args
        inputs = processor(text=input_str, images=image_input, return_codebook_pixels=True, return_image_mask=True)

        self.assertListEqual(
            list(inputs.keys()),
            [
                "input_ids",
                "token_type_ids",
                "attention_mask",
                "pixel_values",
                "codebook_pixel_values",
                "bool_masked_pos",
            ],
        )

        # test if it raises when no input is passed
        with pytest.raises(ValueError):
            processor()

    def test_tokenizer_decode(self):
        """`processor.batch_decode` returns the same result as the tokenizer's `batch_decode`."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = FlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        """The keys produced for text + images equal `processor.model_input_names`, in order."""
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = FlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)