google-research

Форк
0
89 строк · 3.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for cifar.data_lib."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
from absl import flags
23
from absl.testing import absltest
24
from absl.testing import parameterized
25

26
import tensorflow.compat.v2 as tf
27

28
from uq_benchmark_2019 import image_data_utils
29
from uq_benchmark_2019.cifar import data_lib
30

31
flags.DEFINE_bool('fake_data', True, 'Bypass tests that rely on real data and '
32
                  'use dummy random data for the remaining tests.')
33

34
tf.enable_v2_behavior()
35

36

37
class DataLibTest(parameterized.TestCase):
38

39
  @parameterized.parameters(['train', 'test', 'valid'])
40
  def test_fake_data(self, split):
41
    # config is ignored for fake data
42
    config = image_data_utils.DataConfig(split)
43
    dataset = data_lib.build_dataset(config, fake_data=True)
44
    image_shape = next(iter(dataset))[0].numpy().shape
45
    self.assertEqual(image_shape, data_lib.CIFAR_SHAPE)
46

47
  @parameterized.parameters(['train', 'test', 'valid'])
48
  def test_roll_pixels(self, split):
49
    config = image_data_utils.DataConfig(split, roll_pixels=5)
50
    if not flags.FLAGS.fake_data:
51
      dataset = data_lib.build_dataset(config)
52
      image_shape = next(iter(dataset))[0].numpy().shape
53
      self.assertEqual(image_shape, data_lib.CIFAR_SHAPE)
54

55
  @parameterized.parameters(['train', 'test', 'valid'])
56
  def test_static_cifar_c(self, split):
57
    if not flags.FLAGS.fake_data:
58
      config = image_data_utils.DataConfig(
59
          split, corruption_static=True, corruption_level=3,
60
          corruption_type='pixelate')
61
      if split in ['train', 'valid']:
62
        with self.assertRaises(ValueError):
63
          data_lib.build_dataset(config)
64
      else:
65
        dataset = data_lib.build_dataset(config)
66
        image_shape = next(iter(dataset))[0].numpy().shape
67
        self.assertEqual(image_shape, data_lib.CIFAR_SHAPE)
68

69
  @parameterized.parameters(['train', 'test', 'valid'])
70
  def test_array_cifar_c(self, split):
71
    if not flags.FLAGS.fake_data:
72
      config = image_data_utils.DataConfig(
73
          split, corruption_level=4, corruption_type='glass_blur')
74
      dataset = data_lib.build_dataset(config)
75
      image_shape = next(iter(dataset))[0].numpy().shape
76
      self.assertEqual(image_shape, data_lib.CIFAR_SHAPE)
77

78
  @parameterized.parameters(['train', 'test', 'valid'])
79
  def test_value_cifar_c(self, split):
80
    if not flags.FLAGS.fake_data:
81
      config = image_data_utils.DataConfig(
82
          split, corruption_value=.25, corruption_type='brightness')
83
      dataset = data_lib.build_dataset(config)
84
      image_shape = next(iter(dataset))[0].numpy().shape
85
      self.assertEqual(image_shape, data_lib.CIFAR_SHAPE)
86

87

88
if __name__ == '__main__':
89
  absltest.main()
90

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.