transformers
166 lines · 6.5 KB
1# coding=utf-8
2# Copyright 2021 HuggingFace Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16
17import unittest18
19from transformers.testing_utils import require_torch, require_vision20from transformers.utils import is_torch_available, is_vision_available21
22from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs23
24
25if is_vision_available():26from transformers import ChineseCLIPImageProcessor27
28
29if is_torch_available():30pass31
32
class ChineseCLIPImageProcessingTester(unittest.TestCase):
    """Helper that builds ChineseCLIPImageProcessor configuration dicts and dummy
    image inputs for the test classes below.

    It does not run any tests itself; the real test classes hold an instance of
    this tester and read its attributes / call its ``prepare_*`` methods.
    """

    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_center_crop=True,
        crop_size=None,
        do_normalize=True,
        image_mean=None,
        image_std=None,
        do_convert_rgb=True,
    ):
        # Use None sentinels instead of mutable (list/dict) default arguments,
        # falling back to the same values the original defaults carried
        # (the CLIP normalization constants for mean/std).
        size = size if size is not None else {"height": 224, "width": 224}
        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
        image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
        image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    def prepare_image_processor_dict(self):
        """Return the kwargs dict used to instantiate the image processor under test."""
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_center_crop": self.do_center_crop,
            "crop_size": self.crop_size,
            "do_normalize": self.do_normalize,
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_convert_rgb": self.do_convert_rgb,
        }

    def expected_output_image_shape(self, images):
        """Expected (channels, height, width) of a processed image.

        Always 3 channels (RGB) at the configured crop size; ``images`` is
        unused but kept for interface compatibility with the shared mixin.
        """
        return 3, self.crop_size["height"], self.crop_size["width"]

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        """Delegate to the shared ``prepare_image_inputs`` helper with this tester's sizes."""
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )
94
@require_torch
@require_vision
class ChineseCLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    """Standard (3-channel) test suite for ``ChineseCLIPImageProcessor``."""

    image_processing_class = ChineseCLIPImageProcessor if is_vision_available() else None

    def setUp(self):
        self.image_processor_tester = ChineseCLIPImageProcessingTester(self, do_center_crop=True)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        processor = self.image_processing_class(**self.image_processor_dict)
        # Every configuration knob the tester dict carries must surface as an
        # attribute on the instantiated processor.
        for attribute in (
            "do_resize",
            "size",
            "do_center_crop",
            "center_crop",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_convert_rgb",
        ):
            self.assertTrue(hasattr(processor, attribute))

    def test_image_processor_from_dict_with_kwargs(self):
        # Without overrides, from_dict honors the dict values verbatim.
        processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(processor.size, {"height": 224, "width": 224})
        self.assertEqual(processor.crop_size, {"height": 18, "width": 18})

        # Integer kwargs are expanded to the processor's dict conventions:
        # a bare size becomes shortest_edge, a bare crop_size a square crop.
        processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
        self.assertEqual(processor.size, {"shortest_edge": 42})
        self.assertEqual(processor.crop_size, {"height": 84, "width": 84})

    @unittest.skip("ChineseCLIPImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")  # FIXME Amy
    def test_call_numpy_4_channels(self):
        pass
131
@require_torch
@require_vision
class ChineseCLIPImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.TestCase):
    """4-channel (e.g. RGBA) input test suite for ``ChineseCLIPImageProcessor``.

    The processor does not support 4-channel inputs yet, so the mixin's call
    tests are overridden and skipped; ``expected_encoded_image_num_channels``
    records that outputs are expected to collapse to 3 channels.
    """

    image_processing_class = ChineseCLIPImageProcessor if is_vision_available() else None

    def setUp(self):
        self.image_processor_tester = ChineseCLIPImageProcessingTester(self, num_channels=4, do_center_crop=True)
        # Inputs have 4 channels, but the encoded output is expected to be RGB.
        self.expected_encoded_image_num_channels = 3

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processing, "do_resize"))
        self.assertTrue(hasattr(image_processing, "size"))
        self.assertTrue(hasattr(image_processing, "do_center_crop"))
        self.assertTrue(hasattr(image_processing, "center_crop"))
        self.assertTrue(hasattr(image_processing, "do_normalize"))
        self.assertTrue(hasattr(image_processing, "image_mean"))
        self.assertTrue(hasattr(image_processing, "image_std"))
        self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

    @unittest.skip("ChineseCLIPImageProcessor does not support 4 channels yet")  # FIXME Amy
    def test_call_numpy(self):
        return super().test_call_numpy()

    @unittest.skip("ChineseCLIPImageProcessor does not support 4 channels yet")  # FIXME Amy
    def test_call_pytorch(self):
        # Fix: the original delegated to super().test_call_torch(), a method
        # that does not exist on the mixin (it is named test_call_pytorch),
        # which would raise AttributeError the moment the skip is removed.
        return super().test_call_pytorch()

    @unittest.skip("ChineseCLIPImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")  # FIXME Amy
    def test_call_numpy_4_channels(self):
        pass