google-research

Форк
0
/
cifar_data_provider.py 
190 строк · 6.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
# Copyright 2016 Google Inc. All Rights Reserved.
17
#
18
# Licensed under the Apache License, Version 2.0 (the "License");
19
# you may not use this file except in compliance with the License.
20
# You may obtain a copy of the License at
21
#
22
# http://www.apache.org/licenses/LICENSE-2.0
23
#
24
# Unless required by applicable law or agreed to in writing, software
25
# distributed under the License is distributed on an "AS IS" BASIS,
26
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27
# See the License for the specific language governing permissions and
28
# limitations under the License.
29
# ==============================================================================
30

31
"""Contains code for loading and preprocessing the CIFAR data."""
32

33
import cifar100_dataset
34
import cifar10_dataset
35
import tensorflow as tf
36
import tensorflow.contrib.slim as slim
37
from tensorflow.contrib.slim.python.slim.data import dataset_data_provider
38
from tensorflow.python.ops import control_flow_ops
39

40

41
# Maps each supported dataset name to the module whose `get_split` builds it.
datasets_map = {'cifar10': cifar10_dataset, 'cifar100': cifar100_dataset}
45

46

47
def provide_resnet_data(dataset_name,
                        split_name,
                        batch_size,
                        dataset_dir=None,
                        num_epochs=None):
  """Provides batches of CIFAR images for resnet.

  For the 'train' split each image is resized (cropped or padded) to 36x36,
  randomly cropped back to 32x32x3, randomly flipped left/right, divided by
  255, color-distorted with one of two randomly chosen op orderings and
  finally mapped from [0, 1] to [-1, 1]. For any other split the image is only
  cropped or padded to 32x32 and rescaled with (x - 127.5) / 127.5.

  Args:
    dataset_name: Either 'cifar10' or 'cifar100'.
    split_name: Either 'train' or 'test'. Only 'train' enables shuffling and
      data augmentation.
    batch_size: The number of images in each batch.
    dataset_dir: The directory where the CIFAR data can be found.
    num_epochs: The number of times each data source is read. If left as None,
      the data will be cycled through indefinitely.

  Returns:
    images: A `Tensor` of size [batch_size, 32, 32, 3]. The final batch may be
      smaller because batching uses `allow_smaller_final_batch=True`.
    one_hot_labels: A `Tensor` of size [batch_size, NUM_CLASSES], where
      each row has a single element set to one and the rest set to zeros.
    num_samples: The number of total samples in the dataset.
    num_classes: The number of total classes in the dataset.

  Raises:
    ValueError: If `dataset_name` is not a key of `datasets_map`. An
      unsupported `split_name` is not checked here; it is presumably rejected
      by the dataset module's `get_split` (not visible in this file).
  """
  dataset = _get_dataset(dataset_name, split_name, dataset_dir=dataset_dir)
  is_training = split_name == 'train'

  provider = dataset_data_provider.DatasetDataProvider(
      dataset,
      common_queue_capacity=2 * batch_size,
      common_queue_min=batch_size,
      shuffle=is_training,
      num_epochs=num_epochs)

  [image, label] = provider.get(['image', 'label'])

  image = tf.to_float(image)

  image_size = 32
  if is_training:
    # Grow the image by 4 pixels per dimension, then take a random 32x32
    # window so the content is randomly translated.
    image = tf.image.resize_image_with_crop_or_pad(image, image_size + 4,
                                                   image_size + 4)
    image = tf.random_crop(image, [image_size, image_size, 3])
    image = tf.image.random_flip_left_right(image)
    # distort_color clips to [0, 1], so rescale before applying it.
    image /= 255
    image = _apply_with_random_selector(image, distort_color, num_cases=2)
    # Map [0, 1] to [-1, 1].
    image = 2 * (image - 0.5)
  else:
    image = tf.image.resize_image_with_crop_or_pad(image, image_size,
                                                   image_size)
    # Same target range as the training branch, assuming 8-bit pixel values.
    image = (image - 127.5) / 127.5

  # Creates a QueueRunner for the pre-fetching operation.
  images, labels = tf.train.batch(
      [image, label],
      batch_size=batch_size,
      num_threads=1,
      capacity=5 * batch_size,
      allow_smaller_final_batch=True)

  one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
  # Drop the singleton axis at position 1 left over from the label tensor.
  one_hot_labels = tf.squeeze(one_hot_labels, 1)
  return images, one_hot_labels, dataset.num_samples, dataset.num_classes
114

115

116
def _get_dataset(name, split_name, **kwargs):
  """Looks up a dataset module by name and returns the requested split.

  Args:
    name: String, name of the dataset; must be a key of `datasets_map`.
    split_name: A train/test split name, forwarded to `get_split`.
    **kwargs: Extra kwargs for get_split, for example dataset_dir.

  Returns:
    The object returned by the dataset module's `get_split`, with its `name`
    attribute set to `name`.

  Raises:
    ValueError: if dataset unknown.
  """
  is_known = name in datasets_map
  if not is_known:
    raise ValueError('Name of dataset unknown %s' % name)

  dataset_module = datasets_map[name]
  split = dataset_module.get_split(split_name, **kwargs)
  # Tag the split with the name it was requested under.
  split.name = name
  return split
135

136

137
def _apply_with_random_selector(x, func, num_cases):
  """Applies `func(x, case)` for one randomly drawn `case`.

  One call to `func` is built into the graph for every case in
  [0, num_cases - 1]; a scalar int32 selector sampled at run time decides
  which of those calls receives the real input.

  Args:
    x: input Tensor.
    func: Python function to apply; called as `func(tensor, case)`.
    num_cases: Python int32, number of cases to sample the selector from.

  Returns:
    The result of func(x, sel), where func receives the value of
    the selector as a python integer, but sel is sampled dynamically.
  """
  selector = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)

  branch_outputs = []
  for case in range(num_cases):
    # Pass the real x only to one of the func calls: element [1] of the
    # switch result is the output taken when the predicate holds.
    routed_input = control_flow_ops.switch(x, tf.equal(selector, case))[1]
    branch_outputs.append(func(routed_input, case))

  # Element [0] of the merge result is the merged tensor itself.
  merged = control_flow_ops.merge(branch_outputs)
  return merged[0]
154

155

156
def distort_color(image, color_ordering=0, scope=None):
  """Distort the color of the image.

  Each color distortion is non-commutative and thus ordering of the color ops
  matters. Ideally we would randomly permute the ordering of the color ops.
  Rather than adding that level of complication, two fixed orderings are
  offered and the caller selects one through `color_ordering`.

  Both orderings start with a random brightness change and apply the same
  four ops; they differ in whether contrast comes last (ordering 0) or right
  after brightness (ordering 1). The result is clipped to [0, 1].

  Args:
    image: Tensor containing single image.
    color_ordering: Python int, a type of distortion (valid values: 0, 1).
    scope: Optional scope for name_scope.

  Returns:
    color-distorted image
  Raises:
    ValueError: if color_ordering is not in {0, 1}.
  """
  # Validate before opening the name scope so that an invalid call fails
  # without adding anything to the graph.
  if not (color_ordering == 0 or color_ordering == 1):
    raise ValueError('color_ordering must be in {0, 1}')

  with tf.name_scope(scope, 'distort_color', [image]):
    if color_ordering == 0:
      image = tf.image.random_brightness(image, max_delta=32. / 255.)
      image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
      image = tf.image.random_hue(image, max_delta=0.2)
      image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
    else:
      image = tf.image.random_brightness(image, max_delta=32. / 255.)
      image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
      image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
      image = tf.image.random_hue(image, max_delta=0.2)

    # The random_* ops do not necessarily clamp.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image
191

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.