google-research

Форк
0
/
image_data_utils.py 
127 строк · 4.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Common utilities for CIFAR and ImageNet datasets."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import attr
23
import numpy as np
24
import tensorflow.compat.v1 as tf
25

26
# Corruption names accepted by DataConfig.corruption_type (kept in
# alphabetical order). They are meant to match the corruption names used by
# Hendrycks and Dietterich (2019); 'jpeg_compression' is the standard name of
# the JPEG corruption in that benchmark.
CORRUPTION_TYPES = [
    'brightness', 'contrast', 'defocus_blur', 'elastic_transform', 'fog',
    'frost', 'gaussian_blur', 'gaussian_noise', 'glass_blur', 'impulse_noise',
    'jpeg_compression', 'pixelate', 'saturate', 'shot_noise', 'spatter',
    'speckle_noise', 'zoom_blur']

# Default number of examples generated by make_fake_data.
_TINY_DATA_SIZE = 99
33

34

35
@attr.s
class DataConfig(object):
  """Define config for (optionally) corrupted ImageNet and CIFAR data.

  Attributes:
    split: String, dataset split ('train', 'valid' or 'test').
    roll_pixels: Int, number of pixels by which to roll the image.
    corruption_type: String, the name of the corruption function to apply
      (must be one of CORRUPTION_TYPES), or None for no corruption.
    corruption_static: Bool. If True, use the corrupted images provided by
      Hendrycks and Dietterich (2019) as static images. If False, apply
      corruption functions to standard images.
    corruption_level: Int, level (from 1 to 5) of the corruption values
      defined by Hendrycks and Dietterich (2019). If 0, then use
      `corruption_value` instead.
    corruption_value: Float or tuple, corruption value to apply to the image
      data. If None, then use `corruption_level` instead.
    alt_dataset_name: Optional name of an alternate dataset (e.g. SVHN for
      OOD CIFAR experiments).
  """
  split = attr.ib()
  roll_pixels = attr.ib(0)
  # None means "no corruption"; any other value must be in CORRUPTION_TYPES.
  corruption_type = attr.ib(
      None, validator=attr.validators.in_(CORRUPTION_TYPES + [None]))
  corruption_static = attr.ib(False)
  # 0 means "use corruption_value"; 1-5 select a benchmark corruption level.
  corruption_level = attr.ib(0, validator=attr.validators.in_(range(6)))
  corruption_value = attr.ib(None)
  alt_dataset_name = attr.ib(None)
63

64

65
# Shared configs for the three standard, uncorrupted splits (all other
# DataConfig fields keep their defaults).
DATA_CONFIG_TRAIN = DataConfig(split='train')
DATA_CONFIG_VALID = DataConfig(split='valid')
DATA_CONFIG_TEST = DataConfig(split='test')
68

69

70
def make_fake_data(image_shape, num_examples=_TINY_DATA_SIZE, num_classes=10):
  """Make a small tf.data.Dataset of random images and labels.

  Args:
    image_shape: Sequence of ints, shape of a single image.
    num_examples: Int, number of (image, label) pairs to generate.
    num_classes: Int, labels are drawn uniformly from [0, num_classes).

  Returns:
    tf.data.Dataset of (image, label) pairs. Images come from np.random.rand,
    so they are float64 with values in [0, 1). Labels come from
    np.random.randint.
  """
  images = np.random.rand(num_examples, *image_shape)
  labels = np.random.randint(0, num_classes, num_examples)
  return tf.data.Dataset.from_tensor_slices((images, labels))
74

75

76
def make_static_dataset(config, data_reader_fn):
  """Build a dataset of pre-corrupted ("static") images read from disk.

  Args:
    config: DataConfig-like object. Only `corruption_level`, `split`,
      `corruption_value` and `corruption_type` are read here.
    data_reader_fn: Callable invoked as
      `data_reader_fn(corruption_type, corruption_level)`. Its result must
      provide `.map(fn)`; presumably a tf.data.Dataset of (image, label)
      pairs (confirm against callers).

  Returns:
    The reader's dataset mapped so that each image goes through
    `tf.image.convert_image_dtype(image, tf.float32)`; labels pass through
    unchanged.

  Raises:
    ValueError: if `corruption_level` is not in 1..5, if `split` is not
      'test', or if `corruption_value` is not None (checked in that order).
  """
  static_levels = range(1, 6)
  if config.corruption_level not in static_levels:
    raise ValueError('Corruption level of the static images must be between 1'
                     ' and 5.')
  if config.split != 'test':
    raise ValueError('Split must be `test` for corrupted images.')
  if config.corruption_value is not None:
    raise ValueError('`corruption_value` must be `None` for static images.')

  raw_dataset = data_reader_fn(config.corruption_type, config.corruption_level)
  # Normalize the image dtype to float32; the label is left untouched.
  return raw_dataset.map(
      lambda image, label: (tf.image.convert_image_dtype(image, tf.float32),
                            label))
93

94

95
def get_data_config(name):
  """Parse a data-config name into a DataConfig.

  Args:
    name: Either a base split name ("train", "valid", "test") or a string of
      the form "{family}-{options}". Fields are separated by hyphens, because
      corruption types themselves contain underscores. Recognized forms:
        - "roll-{pixels}", e.g. "roll-24".
        - "corrupt-{static|array}-{corruption_type}-{level}", e.g.
          "corrupt-static-brightness-1"; "static" sets
          `corruption_static=True`.
        - "corrupt-value-{corruption_type}-{value}", e.g.
          "corrupt-value-fog-0.5"; the value is parsed as a float.
        - "svhn" or "celeb_a" (anything after a hyphen is ignored).

  Returns:
    DataConfig instance. Base split names return the shared module-level
    DATA_CONFIG_* objects; every other form builds a new 'test'-split config.

  Raises:
    ValueError: if `name` is not recognized or is malformed (missing or
      non-numeric options). DataConfig's attrs `in_` validators also raise
      ValueError for an unknown corruption type or an out-of-range level.
  """
  base_configs = {'train': DATA_CONFIG_TRAIN,
                  'valid': DATA_CONFIG_VALID,
                  'test': DATA_CONFIG_TEST}
  if name in base_configs:
    return base_configs[name]

  # partition() never raises: options is '' when there is no hyphen, so a bare
  # "roll" or "corrupt" is reported as a ValueError below instead of an
  # IndexError.
  family, _, options = name.partition('-')
  if family == 'roll':
    try:
      roll_pixels = int(options)
    except ValueError:
      raise ValueError('Invalid roll data config (expected "roll-{pixels}"): '
                       '%s' % name)
    return DataConfig('test', roll_pixels=roll_pixels)
  elif family == 'corrupt':
    # maxsplit=2 keeps any further hyphens (e.g. a negative value) in the
    # last field.
    fields = options.split('-', 2)
    if len(fields) != 3:
      raise ValueError(
          'Invalid corrupt data config (expected '
          '"corrupt-{static|array|value}-{type}-{level_or_value}"): '
          '%s' % name)
    corruption_src, corruption_type, corruption_x = fields
    if corruption_src in ('static', 'array'):
      return DataConfig('test',
                        corruption_type=corruption_type,
                        corruption_static=corruption_src == 'static',
                        corruption_level=int(corruption_x))
    elif corruption_src == 'value':
      return DataConfig('test',
                        corruption_type=corruption_type,
                        corruption_value=float(corruption_x))
  elif family == 'svhn':
    return DataConfig('test', alt_dataset_name='svhn_cropped')
  elif family == 'celeb_a':
    return DataConfig('test', alt_dataset_name='celeb_a')
  raise ValueError('Data config name not recognized: %s' % name)
128

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.