transformers

Форк
0
/
test_image_processing_chinese_clip.py 
166 строк · 6.5 Кб
# coding=utf-8
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15

16

17
import unittest
18

19
from transformers.testing_utils import require_torch, require_vision
20
from transformers.utils import is_torch_available, is_vision_available
21

22
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
23

24

25
if is_vision_available():
26
    from transformers import ChineseCLIPImageProcessor
27

28

29
if is_torch_available():
30
    pass
31

32

33
class ChineseCLIPImageProcessingTester(unittest.TestCase):
    """Configuration holder and input factory for the ChineseCLIP image-processor tests.

    An instance stores the keyword arguments used to build the image processor
    under test, the parameters forwarded to ``prepare_image_inputs``, and the
    output shape the tests should expect.
    """

    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_center_crop=True,
        crop_size=None,
        do_normalize=True,
        image_mean=[0.48145466, 0.4578275, 0.40821073],
        image_std=[0.26862954, 0.26130258, 0.27577711],
        do_convert_rgb=True,
    ):
        """Store the test configuration.

        Args:
            parent: The test case that owns this tester.
            batch_size: Forwarded to ``prepare_image_inputs``.
            num_channels: Forwarded to ``prepare_image_inputs``.
            image_size: Stored on the instance; not read by this class.
            min_resolution: Forwarded to ``prepare_image_inputs``.
            max_resolution: Forwarded to ``prepare_image_inputs``.
            do_resize: Image-processor ``do_resize`` setting.
            size: Image-processor ``size``; defaults to
                ``{"height": 224, "width": 224}`` when ``None``.
            do_center_crop: Image-processor ``do_center_crop`` setting.
            crop_size: Image-processor ``crop_size``; defaults to
                ``{"height": 18, "width": 18}`` when ``None``.
            do_normalize: Image-processor ``do_normalize`` setting.
            image_mean: Image-processor ``image_mean`` setting.
            image_std: Image-processor ``image_std`` setting.
            do_convert_rgb: Image-processor ``do_convert_rgb`` setting.
        """
        # Initialise the unittest.TestCase base: without it the inherited
        # __repr__/__str__/__hash__/__eq__ fail on the missing _testMethodName.
        super().__init__()

        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size if size is not None else {"height": 224, "width": 224}
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
        self.do_normalize = do_normalize
        # NOTE: the list defaults are shared between instances; treat them as read-only.
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    def prepare_image_processor_dict(self):
        """Return the keyword arguments used to construct the image processor."""
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_center_crop": self.do_center_crop,
            "crop_size": self.crop_size,
            "do_normalize": self.do_normalize,
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_convert_rgb": self.do_convert_rgb,
        }

    def expected_output_image_shape(self, images):
        """Return the expected ``(channels, height, width)`` of one processed image.

        The channel count is fixed at 3 and the spatial size comes from
        ``crop_size``; ``images`` is accepted but not inspected.
        """
        return 3, self.crop_size["height"], self.crop_size["width"]

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        """Build test inputs by delegating to ``prepare_image_inputs`` with the stored sizes."""
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )
93

94

95
@require_torch
@require_vision
class ChineseCLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    """Tests for ``ChineseCLIPImageProcessor`` with the tester's default 3-channel inputs."""

    # None when the vision dependencies are missing, so the module still imports.
    image_processing_class = ChineseCLIPImageProcessor if is_vision_available() else None

    def setUp(self):
        self.image_processor_tester = ChineseCLIPImageProcessingTester(self, do_center_crop=True)

    @property
    def image_processor_dict(self):
        """Keyword arguments used to construct the image processor under test."""
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        """The constructed processor exposes every configured attribute."""
        image_processing = self.image_processing_class(**self.image_processor_dict)
        expected_attributes = (
            "do_resize",
            "size",
            "do_center_crop",
            "center_crop",
            # crop_size is configured and asserted on below, so check it here too.
            "crop_size",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_convert_rgb",
        )
        for attribute in expected_attributes:
            self.assertTrue(
                hasattr(image_processing, attribute),
                msg=f"image processor is missing attribute {attribute!r}",
            )

    def test_image_processor_from_dict_with_kwargs(self):
        """``from_dict`` keeps the dict values and lets keyword arguments override them."""
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"height": 224, "width": 224})
        self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})

        # Integer overrides are expected to be normalised into size dictionaries.
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
        self.assertEqual(image_processor.size, {"shortest_edge": 42})
        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})

    @unittest.skip("ChineseCLIPImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")  # FIXME Amy
    def test_call_numpy_4_channels(self):
        pass
130

131

132
@require_torch
@require_vision
class ChineseCLIPImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.TestCase):
    """Tests for ``ChineseCLIPImageProcessor`` when the tester is configured for 4-channel inputs."""

    # None when the vision dependencies are missing, so the module still imports.
    image_processing_class = ChineseCLIPImageProcessor if is_vision_available() else None

    def setUp(self):
        self.image_processor_tester = ChineseCLIPImageProcessingTester(self, num_channels=4, do_center_crop=True)
        # Inputs have 4 channels, but encoded outputs are expected to have 3.
        # NOTE(review): presumably the result of do_convert_rgb - confirm.
        self.expected_encoded_image_num_channels = 3

    @property
    def image_processor_dict(self):
        """Keyword arguments used to construct the image processor under test."""
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        """The constructed processor exposes every configured attribute."""
        image_processing = self.image_processing_class(**self.image_processor_dict)
        expected_attributes = (
            "do_resize",
            "size",
            "do_center_crop",
            "center_crop",
            # crop_size is part of the constructor dict, so check it here too.
            "crop_size",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_convert_rgb",
        )
        for attribute in expected_attributes:
            self.assertTrue(
                hasattr(image_processing, attribute),
                msg=f"image processor is missing attribute {attribute!r}",
            )

    @unittest.skip("ChineseCLIPImageProcessor does not support 4 channels yet")  # FIXME Amy
    def test_call_numpy(self):
        return super().test_call_numpy()

    @unittest.skip("ChineseCLIPImageProcessor does not support 4 channels yet")  # FIXME Amy
    def test_call_pytorch(self):
        # Delegate to the parent method of the same name; the previous call
        # targeted `test_call_torch`, which does not match this override.
        return super().test_call_pytorch()

    @unittest.skip("ChineseCLIPImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")  # FIXME Amy
    def test_call_numpy_4_channels(self):
        pass
167

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.