# Tests for the InstructBLIP processor (transformers).
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest

import numpy as np
import pytest

from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available


# PIL and the vision-dependent transformers classes are only importable when
# the optional vision extras are installed; the tests themselves are gated by
# the @require_vision decorator on the test class.
if is_vision_available():
    from PIL import Image

    from transformers import (
        AutoProcessor,
        BertTokenizerFast,
        BlipImageProcessor,
        GPT2Tokenizer,
        InstructBlipProcessor,
        PreTrainedTokenizerFast,
    )
@require_vision
class InstructBlipProcessorTest(unittest.TestCase):
    """Tests for ``InstructBlipProcessor``.

    InstructBLIP composes three sub-processors — an image processor, a main
    tokenizer and a Q-Former tokenizer. These tests verify that the composed
    processor round-trips through ``save_pretrained``/``from_pretrained`` and
    that its outputs agree with the individual components.
    """

    def setUp(self):
        # Build a tiny processor from lightweight test checkpoints and persist
        # it so the getter helpers below can reload it via AutoProcessor.
        self.tmpdirname = tempfile.mkdtemp()

        image_processor = BlipImageProcessor()
        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-GPT2Model")
        qformer_tokenizer = BertTokenizerFast.from_pretrained("hf-internal-testing/tiny-random-bert")

        processor = InstructBlipProcessor(image_processor, tokenizer, qformer_tokenizer)

        processor.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        # Reload the saved processor and return only its main tokenizer.
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_image_processor(self, **kwargs):
        # Reload the saved processor and return only its image processor.
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def get_qformer_tokenizer(self, **kwargs):
        # Reload the saved processor and return only its Q-Former tokenizer.
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).qformer_tokenizer

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def prepare_image_inputs(self):
        """Return a list containing a single small random PIL image."""
        image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]

        # Convert channel-first arrays (C, H, W) to PIL's (H, W, C) layout.
        image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]

        return image_inputs

    def test_save_load_pretrained_additional_features(self):
        # Saving then reloading with extra kwargs must apply those kwargs to
        # the respective sub-processors.
        processor = InstructBlipProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=self.get_image_processor(),
            qformer_tokenizer=self.get_qformer_tokenizer(),
        )
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)

        processor = InstructBlipProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, PreTrainedTokenizerFast)

        self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.image_processor, BlipImageProcessor)
        self.assertIsInstance(processor.qformer_tokenizer, BertTokenizerFast)

    def test_image_processor(self):
        # The processor's image output must match the bare image processor's.
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()
        qformer_tokenizer = self.get_qformer_tokenizer()

        processor = InstructBlipProcessor(
            tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
        )

        image_input = self.prepare_image_inputs()

        input_feat_extract = image_processor(image_input, return_tensors="np")
        input_processor = processor(images=image_input, return_tensors="np")

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        # Text-only inputs must match both tokenizers; the Q-Former outputs are
        # exposed under a "qformer_" key prefix.
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()
        qformer_tokenizer = self.get_qformer_tokenizer()

        processor = InstructBlipProcessor(
            tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
        )

        input_str = "lower newer"

        encoded_processor = processor(text=input_str)

        encoded_tokens = tokenizer(input_str, return_token_type_ids=False)
        encoded_tokens_qformer = qformer_tokenizer(input_str, return_token_type_ids=False)

        for key in encoded_tokens.keys():
            self.assertListEqual(encoded_tokens[key], encoded_processor[key])

        for key in encoded_tokens_qformer.keys():
            self.assertListEqual(encoded_tokens_qformer[key], encoded_processor["qformer_" + key])

    def test_processor(self):
        # Text + image input must produce exactly these output keys; calling
        # the processor with no input at all must raise.
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()
        qformer_tokenizer = self.get_qformer_tokenizer()

        processor = InstructBlipProcessor(
            tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
        )

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(
            list(inputs.keys()),
            ["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"],
        )

        # test if it raises when no input is passed
        with pytest.raises(ValueError):
            processor()

    def test_tokenizer_decode(self):
        # batch_decode must delegate to the main tokenizer.
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()
        qformer_tokenizer = self.get_qformer_tokenizer()

        processor = InstructBlipProcessor(
            tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
        )

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        # The processor's output keys are the model's expected input names.
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()
        qformer_tokenizer = self.get_qformer_tokenizer()

        processor = InstructBlipProcessor(
            tokenizer=tokenizer, image_processor=image_processor, qformer_tokenizer=qformer_tokenizer
        )

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(
            list(inputs.keys()),
            ["input_ids", "attention_mask", "qformer_input_ids", "qformer_attention_mask", "pixel_values"],
        )