# transformers
# 272 lines · 10.2 KB (viewer metadata, not part of the original file)
1# coding=utf-8
2# Copyright 2021 HuggingFace Inc.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16
17import unittest18
19from datasets import load_dataset20
21from transformers.testing_utils import require_torch, require_vision22from transformers.utils import is_torch_available, is_vision_available23
24from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs25
26
27if is_torch_available():28import torch29
30if is_vision_available():31from PIL import Image32
33from transformers import BeitImageProcessor34
35
class BeitImageProcessingTester(unittest.TestCase):
    """Holds the configuration used to build a `BeitImageProcessor` under test
    and fabricates matching image inputs / expected values."""

    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_center_crop=True,
        crop_size=None,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
        do_reduce_labels=False,
    ):
        # Fall back to the default geometry when the caller did not pin one down.
        if size is None:
            size = {"height": 20, "width": 20}
        if crop_size is None:
            crop_size = {"height": 18, "width": 18}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_reduce_labels = do_reduce_labels

    def prepare_image_processor_dict(self):
        """Return the kwargs dict used to instantiate the image processor under test."""
        # Every dict key mirrors the attribute of the same name, so build it by lookup.
        keys = (
            "do_resize",
            "size",
            "do_center_crop",
            "crop_size",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_reduce_labels",
        )
        return {key: getattr(self, key) for key in keys}

    def expected_output_image_shape(self, images):
        """Shape (channels, height, width) each processed image should have after cropping."""
        return (self.num_channels, self.crop_size["height"], self.crop_size["width"])

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        """Delegate to the shared test helper using this tester's geometry settings."""
        kwargs = {
            "batch_size": self.batch_size,
            "num_channels": self.num_channels,
            "min_resolution": self.min_resolution,
            "max_resolution": self.max_resolution,
            "equal_resolution": equal_resolution,
            "numpify": numpify,
            "torchify": torchify,
        }
        return prepare_image_inputs(**kwargs)
97
def prepare_semantic_single_inputs():
    """Load one (image, segmentation map) pair from the ADE20k test fixtures.

    Returns:
        tuple: `(PIL.Image, PIL.Image)` — the input image and its segmentation map.

    Note:
        Requires network access to the Hugging Face Hub on first call.
    """
    dataset = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")

    image = Image.open(dataset[0]["file"])
    # Renamed from `map`: avoid shadowing the builtin of the same name.
    segmentation_map = Image.open(dataset[1]["file"])

    return image, segmentation_map
106
def prepare_semantic_batch_inputs():
    """Load two (image, segmentation map) pairs from the ADE20k test fixtures.

    Returns:
        tuple: `([image1, image2], [map1, map2])`, all PIL images.
    """
    ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")

    # Even fixture indices hold the images, odd indices the matching maps.
    images = [Image.open(ds[index]["file"]) for index in (0, 2)]
    maps = [Image.open(ds[index]["file"]) for index in (1, 3)]

    return images, maps
117
@require_torch
@require_vision
class BeitImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    """Test suite for `BeitImageProcessor`, including segmentation-map handling."""

    image_processing_class = BeitImageProcessor if is_vision_available() else None

    def setUp(self):
        # One tester per test provides the processor kwargs and fake inputs.
        self.image_processor_tester = BeitImageProcessingTester(self)

    @property
    def image_processor_dict(self):
        """Kwargs used to instantiate the processor under test."""
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        expected_attributes = (
            "do_resize",
            "size",
            "do_center_crop",
            "center_crop",
            "do_normalize",
            "image_mean",
            "image_std",
        )
        for attribute in expected_attributes:
            self.assertTrue(hasattr(image_processing, attribute))

    def test_image_processor_from_dict_with_kwargs(self):
        # Without overrides, the values from the dict are taken verbatim.
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"height": 20, "width": 20})
        self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})
        self.assertEqual(image_processor.do_reduce_labels, False)

        # Integer overrides are expanded to square dicts, and the legacy
        # `reduce_labels` kwarg maps onto `do_reduce_labels`.
        image_processor = self.image_processing_class.from_dict(
            self.image_processor_dict, size=42, crop_size=84, reduce_labels=True
        )
        self.assertEqual(image_processor.size, {"height": 42, "width": 42})
        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
        self.assertEqual(image_processor.do_reduce_labels, True)

    def _check_segmentation_encoding(self, encoding, expected_batch_size):
        """Shared assertions for an (images, maps) encoding of the given batch size."""
        tester = self.image_processor_tester
        self.assertEqual(
            encoding["pixel_values"].shape,
            (
                expected_batch_size,
                tester.num_channels,
                tester.crop_size["height"],
                tester.crop_size["width"],
            ),
        )
        self.assertEqual(
            encoding["labels"].shape,
            (
                expected_batch_size,
                tester.crop_size["height"],
                tester.crop_size["width"],
            ),
        )
        self.assertEqual(encoding["labels"].dtype, torch.long)
        self.assertTrue(encoding["labels"].min().item() >= 0)
        self.assertTrue(encoding["labels"].max().item() <= 255)

    def test_call_segmentation_maps(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)

        # Random PyTorch tensors plus an all-zero segmentation map per image.
        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
        for image in image_inputs:
            self.assertIsInstance(image, torch.Tensor)
        maps = [torch.zeros(image.shape[-2:]).long() for image in image_inputs]

        # Single tensor input.
        encoding = image_processing(image_inputs[0], maps[0], return_tensors="pt")
        self._check_segmentation_encoding(encoding, 1)

        # Batched tensor input.
        encoding = image_processing(image_inputs, maps, return_tensors="pt")
        self._check_segmentation_encoding(encoding, self.image_processor_tester.batch_size)

        # Single PIL input.
        image, segmentation_map = prepare_semantic_single_inputs()
        encoding = image_processing(image, segmentation_map, return_tensors="pt")
        self._check_segmentation_encoding(encoding, 1)

        # Batched PIL input.
        images, segmentation_maps = prepare_semantic_batch_inputs()
        encoding = image_processing(images, segmentation_maps, return_tensors="pt")
        self._check_segmentation_encoding(encoding, 2)

    def test_reduce_labels(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)

        # ADE20k has 150 classes, and the background is included, so labels
        # should be between 0 and 150.
        image, segmentation_map = prepare_semantic_single_inputs()
        encoding = image_processing(image, segmentation_map, return_tensors="pt")
        self.assertTrue(encoding["labels"].min().item() >= 0)
        self.assertTrue(encoding["labels"].max().item() <= 150)

        # With label reduction, background (0) becomes 255 (ignore index),
        # so values may span the full uint8 range.
        image_processing.do_reduce_labels = True
        encoding = image_processing(image, segmentation_map, return_tensors="pt")
        self.assertTrue(encoding["labels"].min().item() >= 0)
        self.assertTrue(encoding["labels"].max().item() <= 255)