google-research
89 строк · 3.3 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for cifar.data_lib."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from absl import flags
23from absl.testing import absltest
24from absl.testing import parameterized
25
26import tensorflow.compat.v2 as tf
27
28from uq_benchmark_2019 import image_data_utils
29from uq_benchmark_2019.cifar import data_lib
30
# When set (the default), the tests that need the real CIFAR / CIFAR-C data
# are bypassed; only the dummy-data test exercises `build_dataset`.
flags.DEFINE_bool(
    name='fake_data',
    default=True,
    help=('Bypass tests that rely on real data and '
          'use dummy random data for the remaining tests.'))

# TF2 behavior (including eager execution) lets the tests convert dataset
# elements with `.numpy()`.
tf.enable_v2_behavior()
35
36
class DataLibTest(parameterized.TestCase):
  """Smoke tests for `data_lib.build_dataset` over splits and CIFAR-C options.

  Only `test_fake_data` runs by default. Every other test needs the real
  CIFAR / CIFAR-C data and is skipped (rather than passing without asserting
  anything) while `--fake_data` is set, which is the default. Run with
  `--nofake_data` to exercise them.
  """

  def _skip_if_fake_data(self):
    """Skips the current test when `--fake_data` is set."""
    if flags.FLAGS.fake_data:
      self.skipTest('Requires real data; run with --nofake_data.')

  def _assert_first_image_has_cifar_shape(self, dataset):
    """Asserts the image of the first element of `dataset` has CIFAR_SHAPE.

    Args:
      dataset: iterable whose elements are indexable; index 0 of an element
        is treated as the image tensor.
    """
    image_shape = next(iter(dataset))[0].numpy().shape
    self.assertEqual(image_shape, data_lib.CIFAR_SHAPE)

  # NOTE: the splits are passed as a single *list* on purpose; absl's
  # `parameters` treats a lone tuple argument as one multi-argument case.
  @parameterized.parameters(['train', 'test', 'valid'])
  def test_fake_data(self, split):
    # config is ignored for fake data
    config = image_data_utils.DataConfig(split)
    dataset = data_lib.build_dataset(config, fake_data=True)
    self._assert_first_image_has_cifar_shape(dataset)

  @parameterized.parameters(['train', 'test', 'valid'])
  def test_roll_pixels(self, split):
    self._skip_if_fake_data()
    config = image_data_utils.DataConfig(split, roll_pixels=5)
    dataset = data_lib.build_dataset(config)
    self._assert_first_image_has_cifar_shape(dataset)

  @parameterized.parameters(['train', 'test', 'valid'])
  def test_static_cifar_c(self, split):
    self._skip_if_fake_data()
    config = image_data_utils.DataConfig(
        split, corruption_static=True, corruption_level=3,
        corruption_type='pixelate')
    if split in ['train', 'valid']:
      # Static corruptions are expected to be rejected for non-test splits.
      with self.assertRaises(ValueError):
        data_lib.build_dataset(config)
    else:
      dataset = data_lib.build_dataset(config)
      self._assert_first_image_has_cifar_shape(dataset)

  @parameterized.parameters(['train', 'test', 'valid'])
  def test_array_cifar_c(self, split):
    self._skip_if_fake_data()
    config = image_data_utils.DataConfig(
        split, corruption_level=4, corruption_type='glass_blur')
    dataset = data_lib.build_dataset(config)
    self._assert_first_image_has_cifar_shape(dataset)

  @parameterized.parameters(['train', 'test', 'valid'])
  def test_value_cifar_c(self, split):
    self._skip_if_fake_data()
    config = image_data_utils.DataConfig(
        split, corruption_value=.25, corruption_type='brightness')
    dataset = data_lib.build_dataset(config)
    self._assert_first_image_has_cifar_shape(dataset)
86
87
if __name__ == '__main__':
  # absltest.main() parses absl flags (e.g. --fake_data) and then runs the
  # TestCase classes defined in this module.
  absltest.main()
90