google-research
190 строк · 6.7 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# Copyright 2016 Google Inc. All Rights Reserved.
17#
18# Licensed under the Apache License, Version 2.0 (the "License");
19# you may not use this file except in compliance with the License.
20# You may obtain a copy of the License at
21#
22# http://www.apache.org/licenses/LICENSE-2.0
23#
24# Unless required by applicable law or agreed to in writing, software
25# distributed under the License is distributed on an "AS IS" BASIS,
26# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27# See the License for the specific language governing permissions and
28# limitations under the License.
29# ==============================================================================
30
31"""Contains code for loading and preprocessing the CIFAR data."""
32
33import cifar100_dataset
34import cifar10_dataset
35import tensorflow as tf
36import tensorflow.contrib.slim as slim
37from tensorflow.contrib.slim.python.slim.data import dataset_data_provider
38from tensorflow.python.ops import control_flow_ops
39
40
# Registry from dataset name to the module whose `get_split` builds that
# dataset; `_get_dataset` looks names up here.
datasets_map = dict(
    cifar10=cifar10_dataset,
    cifar100=cifar100_dataset,
)
45
46
def provide_resnet_data(dataset_name,
                        split_name,
                        batch_size,
                        dataset_dir=None,
                        num_epochs=None):
  """Provides batches of CIFAR images for resnet.

  When `split_name == 'train'` the examples are shuffled and augmented: each
  image is padded/cropped to 36x36, randomly cropped back to 32x32, randomly
  flipped left/right and color-distorted. Any other split is only
  cropped/padded to 32x32. Both paths rescale pixels to roughly [-1, 1]
  (NOTE(review): assumes decoded images hold values in [0, 255]; confirm
  against the dataset decoders).

  Args:
    dataset_name: Either 'cifar10' or 'cifar100'.
    split_name: Either 'train' or 'test'. Only 'train' enables shuffling and
      data augmentation; every other value takes the evaluation path.
    batch_size: The number of images in each batch.
    dataset_dir: The directory where the CIFAR data can be found.
    num_epochs: The number of times each data source is read. If left as None,
      the data will be cycled through indefinitely.

  Returns:
    images: A `Tensor` of size [batch_size, 32, 32, 3]. Because
      `allow_smaller_final_batch=True`, the last batch may be smaller.
    one_hot_labels: A `Tensor` of size [batch_size, NUM_CLASSES], where
      each row has a single element set to one and the rest set to zeros.
    num_samples: The number of total samples in the dataset.
    num_classes: The number of total classes in the dataset.

  Raises:
    ValueError: If `dataset_name` is not a key of `datasets_map`. This
      function does not itself validate `split_name`; an unsupported split is
      presumably rejected by the dataset module's `get_split` (NOTE(review):
      confirm).
  """
  dataset = _get_dataset(dataset_name, split_name, dataset_dir=dataset_dir)
  is_training = split_name == 'train'

  provider = dataset_data_provider.DatasetDataProvider(
      dataset,
      common_queue_capacity=2 * batch_size,
      common_queue_min=batch_size,
      shuffle=is_training,
      num_epochs=num_epochs)

  [image, label] = provider.get(['image', 'label'])

  image = tf.to_float(image)

  image_size = 32
  if is_training:
    # Pad by 4 pixels on each side, then take a random 32x32 crop
    # (random-translation augmentation).
    image = tf.image.resize_image_with_crop_or_pad(image, image_size + 4,
                                                   image_size + 4)
    image = tf.random_crop(image, [image_size, image_size, 3])
    image = tf.image.random_flip_left_right(image)
    # Color distortion operates on [0, 1] images; distort_color clips its
    # output back into [0, 1].
    image /= 255
    # Pick one of the two color-op orderings at random for each example.
    image = _apply_with_random_selector(image, distort_color, num_cases=2)
    # Map [0, 1] to [-1, 1].
    image = 2 * (image - 0.5)
  else:
    image = tf.image.resize_image_with_crop_or_pad(image, image_size,
                                                   image_size)
    # Map [0, 255] to [-1, 1], the same range as the training path.
    image = (image - 127.5) / 127.5

  # Creates a QueueRunner for the pre-fetching operation.
  images, labels = tf.train.batch(
      [image, label],
      batch_size=batch_size,
      num_threads=1,
      capacity=5 * batch_size,
      allow_smaller_final_batch=True)

  one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
  # NOTE(review): assumes batched labels have shape [batch_size, 1], so the
  # one-hot tensor is [batch_size, 1, num_classes] and axis 1 is squeezed away.
  one_hot_labels = tf.squeeze(one_hot_labels, 1)
  return images, one_hot_labels, dataset.num_samples, dataset.num_classes
114
115
def _get_dataset(name, split_name, **kwargs):
  """Builds the `split_name` split of the dataset registered as `name`.

  Args:
    name: String key into `datasets_map`, e.g. 'cifar10'.
    split_name: A train/test split name, forwarded to `get_split`.
    **kwargs: Extra keyword arguments forwarded to `get_split`, for example
      `dataset_dir`.

  Returns:
    Whatever the dataset module's `get_split` returns, with its `name`
    attribute set to `name`. The returned object must accept attribute
    assignment, so it cannot be a plain namedtuple (presumably a slim
    `Dataset`; NOTE(review): confirm).

  Raises:
    ValueError: If `name` is not a key of `datasets_map`.
  """
  if name in datasets_map:
    split = datasets_map[name].get_split(split_name, **kwargs)
    # Tag the split with its registry name for downstream consumers.
    split.name = name
    return split
  raise ValueError('Name of dataset unknown %s' % name)
135
136
def _apply_with_random_selector(x, func, num_cases):
  """Applies `func(x, case)` for a `case` sampled uniformly at run time.

  One graph branch is built for every case in range(num_cases). A scalar
  int32 selector is drawn from [0, num_cases); `switch` routes the real `x`
  only into the branch whose index equals the selector, and `merge` forwards
  that branch's output.

  Args:
    x: Input Tensor.
    func: Python callable invoked as `func(tensor, case)`, where `case` is a
      Python int.
    num_cases: Python int, number of cases to sample the selector from.

  Returns:
    The output of `func(x, sel)` for the dynamically sampled `sel`.
  """
  selector = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
  branch_outputs = []
  for case in range(num_cases):
    # switch() yields (output_false, output_true); only the true output
    # carries data when this case is the selected one.
    _, routed_x = control_flow_ops.switch(x, tf.equal(selector, case))
    branch_outputs.append(func(routed_x, case))
  # merge() yields (output, value_index); keep just the output tensor.
  merged, _ = control_flow_ops.merge(branch_outputs)
  return merged
154
155
def distort_color(image, color_ordering=0, scope=None):
  """Randomly distorts the brightness, saturation, hue and contrast of an image.

  The color ops do not commute, so their order changes the result. Rather
  than permuting the ops fully at random, the caller picks one of two fixed
  orderings via `color_ordering`.

  Args:
    image: Tensor containing a single image; the final clip assumes values
      are meant to lie in [0, 1].
    color_ordering: Python int selecting the op order (valid values: 0, 1).
    scope: Optional scope for name_scope.

  Returns:
    The color-distorted image, clipped to [0.0, 1.0].

  Raises:
    ValueError: if color_ordering is not in {0, 1}.
  """
  with tf.name_scope(scope, 'distort_color', [image]):

    def _brightness(img):
      return tf.image.random_brightness(img, max_delta=32. / 255.)

    def _saturation(img):
      return tf.image.random_saturation(img, lower=0.5, upper=1.5)

    def _hue(img):
      return tf.image.random_hue(img, max_delta=0.2)

    def _contrast(img):
      return tf.image.random_contrast(img, lower=0.5, upper=1.5)

    if color_ordering == 0:
      pipeline = (_brightness, _saturation, _hue, _contrast)
    elif color_ordering == 1:
      pipeline = (_brightness, _contrast, _saturation, _hue)
    else:
      raise ValueError('color_ordering must be in {0, 1}')

    for step in pipeline:
      image = step(image)

    # The random_* ops do not necessarily clamp, so force the result back
    # into [0, 1].
    return tf.clip_by_value(image, 0.0, 1.0)
191