google-research
127 строк · 4.8 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Common utilities for CIFAR and ImageNet datasets."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import attr
23import numpy as np
24import tensorflow.compat.v1 as tf
25
# Corruption names accepted by the `DataConfig.corruption_type` validator.
# NOTE(review): Hendrycks and Dietterich (2019) name the compression
# corruption 'jpeg_compression'; confirm that 'static_compression' is the
# intended spelling before relying on it to look up benchmark data.
CORRUPTION_TYPES = [
    'brightness',
    'contrast',
    'defocus_blur',
    'elastic_transform',
    'fog',
    'frost',
    'gaussian_blur',
    'gaussian_noise',
    'glass_blur',
    'impulse_noise',
    'static_compression',
    'pixelate',
    'saturate',
    'shot_noise',
    'spatter',
    'speckle_noise',
    'zoom_blur',
]

# Default number of examples produced by `make_fake_data`.
_TINY_DATA_SIZE = 99
33
34
@attr.s
class DataConfig(object):
  """Configuration for (optionally) corrupted ImageNet and CIFAR data.

  Attributes:
    split: String, dataset split (e.g. 'train', 'valid' or 'test').
    roll_pixels: Int, number of pixels by which to roll the image.
    corruption_type: String, the name of the corruption function to apply.
      Must be one of CORRUPTION_TYPES, or None for no corruption.
    corruption_static: Bool. If True, use the corrupted images provided by
      Hendrycks and Dietterich (2019) as static images. If False, apply
      corruption functions to standard images.
    corruption_level: Int in [0, 5]. Levels 1 to 5 select the corruption
      values defined by Hendrycks and Dietterich (2019). If 0, then use
      `corruption_value` instead.
    corruption_value: Float or tuple, corruption value to apply to the image
      data. If None, then use `corruption_level` instead.
    alt_dataset_name: Optional name of an alternate dataset (e.g. SVHN for
      OOD CIFAR experiments).
  """
  split = attr.ib()
  roll_pixels = attr.ib(default=0)
  corruption_type = attr.ib(
      default=None,
      validator=attr.validators.in_(CORRUPTION_TYPES + [None]))
  corruption_static = attr.ib(default=False)
  corruption_level = attr.ib(
      default=0,
      validator=attr.validators.in_(range(6)))
  corruption_value = attr.ib(default=None)
  alt_dataset_name = attr.ib(default=None)
63
64
# Shared default configs for each split; every other field keeps its default
# (no roll, no corruption, no alternate dataset).
DATA_CONFIG_TRAIN = DataConfig(split='train')
DATA_CONFIG_VALID = DataConfig(split='valid')
DATA_CONFIG_TEST = DataConfig(split='test')
68
69
def make_fake_data(image_shape, num_examples=_TINY_DATA_SIZE, num_classes=10):
  """Build a small dataset of random images and random integer labels.

  Args:
    image_shape: Sequence of ints, the shape of a single image.
    num_examples: Int, number of (image, label) pairs to generate.
    num_classes: Int, labels are drawn uniformly from [0, num_classes). The
      default of 10 preserves the previous hard-coded behaviour.

  Returns:
    A tf.data.Dataset of (image, label) pairs sliced along the first axis of
    the generated arrays. Images are uniform random floats in [0, 1).
  """
  images = np.random.rand(num_examples, *image_shape)
  labels = np.random.randint(0, num_classes, num_examples)
  return tf.data.Dataset.from_tensor_slices((images, labels))
74
75
def make_static_dataset(config, data_reader_fn):
  """Make a tf.Dataset of pre-corrupted images read from disk.

  Args:
    config: Config object providing `corruption_level`, `split`,
      `corruption_value` and `corruption_type` attributes.
    data_reader_fn: Callable invoked as
      `data_reader_fn(corruption_type, corruption_level)`; it must return a
      dataset exposing a `map` method over (image, label) pairs.

  Returns:
    The dataset from `data_reader_fn`, with each image passed through
    `tf.image.convert_image_dtype(image, tf.float32)`.

  Raises:
    ValueError: If the corruption level is not in 1..5, the split is not
      'test', or `corruption_value` is set.
  """
  static_levels = range(1, 6)
  # The checks run in this fixed order so the first violated constraint is
  # the one reported.
  if config.corruption_level not in static_levels:
    raise ValueError(
        'Corruption level of the static images must be between 1 and 5.')
  if config.split != 'test':
    raise ValueError('Split must be `test` for corrupted images.')
  if config.corruption_value is not None:
    raise ValueError('`corruption_value` must be `None` for static images.')

  def _to_float32(image, label):
    return tf.image.convert_image_dtype(image, tf.float32), label

  corrupted = data_reader_fn(config.corruption_type, config.corruption_level)
  return corrupted.map(_to_float32)
93
94
def get_data_config(name):
  """Parse a data-config name into a DataConfig.

  Recognized names:
    * 'train', 'valid', 'test': the shared base configs.
    * 'roll-{pixels}': test split rolled by an integer number of pixels,
      e.g. 'roll-24'.
    * 'corrupt-static-{type}-{level}' / 'corrupt-array-{type}-{level}': test
      split with an integer corruption level, e.g.
      'corrupt-static-brightness-1'.
    * 'corrupt-value-{type}-{value}': test split with a float corruption
      value, e.g. 'corrupt-value-brightness-0.5'.
    * 'svhn', 'celeb_a': test split of the named alternate dataset.

  Fields are separated by '-' because corruption types themselves contain
  underscores (e.g. 'defocus_blur').

  Args:
    name: String, one of the forms listed above.

  Returns:
    DataConfig instance.

  Raises:
    ValueError: If `name` does not match a recognized form (including a
      'roll' or 'corrupt' name with missing fields), if a numeric field
      cannot be parsed, or if DataConfig validation rejects a field.
  """
  base_configs = {
      'train': DATA_CONFIG_TRAIN,
      'valid': DATA_CONFIG_VALID,
      'test': DATA_CONFIG_TEST,
  }
  if name in base_configs:
    return base_configs[name]

  # `partition` never fails: without a '-' the options string is empty, so
  # malformed names fall through to the ValueError below instead of raising
  # IndexError.
  family, _, options = name.partition('-')
  if family == 'roll':
    if options:
      return DataConfig('test', roll_pixels=int(options))
  elif family == 'corrupt':
    # maxsplit=2 keeps any further '-' inside the last field, so negative
    # values such as 'corrupt-value-brightness--0.5' still parse.
    fields = options.split('-', 2)
    if len(fields) == 3:
      corruption_src, corruption_type, corruption_x = fields
      if corruption_src in ('static', 'array'):
        return DataConfig('test',
                          corruption_type=corruption_type,
                          corruption_static=corruption_src == 'static',
                          corruption_level=int(corruption_x))
      if corruption_src == 'value':
        return DataConfig('test',
                          corruption_type=corruption_type,
                          corruption_value=float(corruption_x))
  elif family == 'svhn':
    return DataConfig('test', alt_dataset_name='svhn_cropped')
  elif family == 'celeb_a':
    return DataConfig('test', alt_dataset_name='celeb_a')
  raise ValueError('Data config name not recognized: %s' % name)
128