transformers
192 lines · 7.2 KB
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14import shutil
15import tempfile
16import unittest
17
18import numpy as np
19import pytest
20
21from transformers.testing_utils import require_torch, require_vision
22from transformers.utils import is_vision_available
23
24
25if is_vision_available():
26from PIL import Image
27
28from transformers import (
29AutoProcessor,
30Pix2StructImageProcessor,
31Pix2StructProcessor,
32PreTrainedTokenizerFast,
33T5Tokenizer,
34)
35
36
@require_vision
@require_torch
class Pix2StructProcessorTest(unittest.TestCase):
    """Round-trip and integration tests for `Pix2StructProcessor`.

    A processor (image processor + T5 tokenizer) is saved to a temporary
    directory in ``setUp``; each test reloads its components from there.
    """

    def setUp(self):
        # Build a fresh processor and persist it so tests can reload it by path.
        self.tmpdirname = tempfile.mkdtemp()

        img_proc = Pix2StructImageProcessor()
        tok = T5Tokenizer.from_pretrained("google-t5/t5-small")
        Pix2StructProcessor(img_proc, tok).save_pretrained(self.tmpdirname)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def prepare_image_inputs(self):
        """
        This function prepares a list of random PIL images of the same fixed size.
        """
        raw = np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)
        # Channels-first numpy array -> channels-last PIL image.
        return [Image.fromarray(np.moveaxis(raw, 0, -1))]

    def test_save_load_pretrained_additional_features(self):
        processor = Pix2StructProcessor(
            tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor()
        )
        processor.save_pretrained(self.tmpdirname)

        # Components loaded with extra kwargs serve as the expected references.
        tok_extra = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        img_proc_extra = self.get_image_processor(do_normalize=False, padding_value=1.0)

        reloaded = Pix2StructProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
        )

        self.assertEqual(reloaded.tokenizer.get_vocab(), tok_extra.get_vocab())
        self.assertIsInstance(reloaded.tokenizer, PreTrainedTokenizerFast)

        self.assertEqual(reloaded.image_processor.to_json_string(), img_proc_extra.to_json_string())
        self.assertIsInstance(reloaded.image_processor, Pix2StructImageProcessor)

    def test_image_processor(self):
        img_proc = self.get_image_processor()
        processor = Pix2StructProcessor(tokenizer=self.get_tokenizer(), image_processor=img_proc)

        images = self.prepare_image_inputs()

        direct = img_proc(images, return_tensors="np")
        via_processor = processor(images=images, return_tensors="np")

        # The processor must defer to the image processor: features agree numerically.
        for key in direct:
            self.assertAlmostEqual(direct[key].sum(), via_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        tok = self.get_tokenizer()
        processor = Pix2StructProcessor(tokenizer=tok, image_processor=self.get_image_processor())

        text = "lower newer"

        via_processor = processor(text=text)
        direct = tok(text, return_token_type_ids=False, add_special_tokens=True)

        # Text-only processing must match the bare tokenizer output.
        for key in direct:
            self.assertListEqual(direct[key], via_processor[key])

    def test_processor(self):
        processor = Pix2StructProcessor(
            tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor()
        )

        inputs = processor(text="lower newer", images=self.prepare_image_inputs())

        self.assertListEqual(
            list(inputs.keys()),
            ["flattened_patches", "attention_mask", "decoder_attention_mask", "decoder_input_ids"],
        )

        # Calling the processor with neither text nor images must raise.
        with pytest.raises(ValueError):
            processor()

    def test_processor_max_patches(self):
        processor = Pix2StructProcessor(
            tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor()
        )

        text = "lower newer"
        images = self.prepare_image_inputs()

        inputs = processor(text=text, images=images)

        max_patches = [512, 1024, 2048, 4096]
        expected_hidden_size = [770, 770, 770, 770]
        # with text
        for budget, hidden in zip(max_patches, expected_hidden_size):
            inputs = processor(text=text, images=images, max_patches=budget)
            self.assertEqual(inputs["flattened_patches"][0].shape[0], budget)
            self.assertEqual(inputs["flattened_patches"][0].shape[1], hidden)

        # without text input
        for budget, hidden in zip(max_patches, expected_hidden_size):
            inputs = processor(images=images, max_patches=budget)
            self.assertEqual(inputs["flattened_patches"][0].shape[0], budget)
            self.assertEqual(inputs["flattened_patches"][0].shape[1], hidden)

    def test_tokenizer_decode(self):
        tok = self.get_tokenizer()
        processor = Pix2StructProcessor(tokenizer=tok, image_processor=self.get_image_processor())

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        # batch_decode must be a straight pass-through to the tokenizer.
        self.assertListEqual(tok.batch_decode(predicted_ids), processor.batch_decode(predicted_ids))

    def test_model_input_names(self):
        processor = Pix2StructProcessor(
            tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor()
        )

        text = "lower newer"

        inputs = processor(text=text, images=self.prepare_image_inputs())

        # For now the processor supports only ["flattened_patches", "input_ids", "attention_mask", "decoder_attention_mask"]
        self.assertListEqual(
            list(inputs.keys()),
            ["flattened_patches", "attention_mask", "decoder_attention_mask", "decoder_input_ids"],
        )

        inputs = processor(text=text)

        # For now the processor supports only ["flattened_patches", "input_ids", "attention_mask", "decoder_attention_mask"]
        self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask"])
193