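# Web-page residue (GitVerse file-viewer header), not part of the module:
# repository "google-research"; forks: 0; file size: 263 lines, 9.6 KB.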
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
r"""Functions for reading and adding noise to CIFAR datasets."""
17

18
import functools
19
import copy
20
import numpy as np
21
import tensorflow as tf
22
import tensorflow_datasets as tfds
23

24
# CIFAR-100 fine-label ids grouped by coarse class: row `c` lists the five
# fine classes that belong to coarse class `c`. The table has 20 rows of 5
# entries and contains every fine label 0..99 exactly once. It is used to pick
# a "confusable" replacement label from the same coarse class when simulating
# asymmetric label noise.
_COARSE_CLASSES = [
    [4, 72, 55, 30, 95],
    [32, 1, 67, 73, 91],
    [70, 82, 54, 92, 62],
    [9, 10, 16, 28, 61],
    [0, 83, 51, 53, 57],
    [39, 40, 86, 22, 87],
    [5, 20, 84, 25, 94],
    [6, 7, 14, 18, 24],
    [97, 3, 42, 43, 88],
    [68, 37, 12, 76, 17],
    [33, 71, 49, 23, 60],
    [38, 15, 19, 21, 31],
    [64, 66, 34, 75, 63],
    [99, 77, 45, 79, 26],
    [2, 35, 98, 11, 46],
    [44, 78, 93, 27, 29],
    [65, 36, 74, 80, 50],
    [96, 47, 52, 56, 59],
    [90, 8, 13, 48, 58],
    [69, 41, 81, 85, 89],
]
47

48

49
def preprocess_fn(*features,
                  mean,
                  std,
                  image_size=32,
                  augment=False,
                  noise_type='none'):
  """Normalize (and optionally augment) CIFAR images and select the labels.

  Args:
    features: Positional dataset elements. With clean data
      (`noise_type == 'none'`) the first element is a feature dict holding
      `image` and `label`. Otherwise the elements are the pair
      `(feature_dict, corrupted_label)` and the corrupted label is used
      instead of the one stored in the dict.
    mean: Channel-wise mean for normalizing image pixels.
    std: Channel-wise standard deviation for normalizing image pixels.
    image_size: Spatial height (=width) of the image.
    augment: A `Boolean` indicating whether to do data augmentation
      (pad/crop to `image_size + 4`, random crop back to `image_size`, random
      horizontal flip). The crop size has four entries, so this path expects a
      batched 4-D image tensor with 3 channels.
    noise_type: Noise type (`none` indicates clean data).

  Returns:
    A dict with keys `image` (float32, scaled to [0, 1] and then normalized
    with `mean`/`std`) and `label` (int32).
  """
  example = features[0]
  image = example['image']
  if noise_type != 'none':
    label = features[1]  # corrupted label zipped alongside the features
  else:
    label = example['label']
  label = tf.cast(label, tf.int32)

  image = tf.cast(image, tf.float32) / 255.0
  if augment:
    padded_size = image_size + 4
    image = tf.image.resize_with_crop_or_pad(image, padded_size, padded_size)
    # The static batch dimension is None when the dataset was batched without
    # `drop_remainder=True`; fall back to the dynamic size so the crop shape
    # stays valid in that case.
    batch_dim = image.shape[0]
    if batch_dim is None:
      batch_dim = tf.shape(image)[0]
    image = tf.image.random_crop(image,
                                 [batch_dim, image_size, image_size, 3])
    image = tf.image.random_flip_left_right(image)
  else:
    image = tf.image.resize_with_crop_or_pad(image, image_size, image_size)
  image = (image - mean) / std
  return dict(image=image, label=label)
88

89

90
def get_corrupted_labels(ds,
                         noise_type,
                         noisy_frac=0.2,
                         num_classes=10,
                         seed=1335):
  """Simulate corrupted or noisy labels.

  The dataset is traversed once, in order, and a label is produced for every
  example, so the result can be zipped back onto the same dataset.

  Noise types:
    * `none`: labels are returned unchanged.
    * `random_flip`: with probability `noisy_frac` a non-zero class offset
      (uniform over the other classes) is added modulo `num_classes`, so a
      flipped label is always incorrect.
    * `random`: with probability `noisy_frac` the label is replaced by a
      uniformly drawn class (which may coincide with the true one).
    * `random_flip_next`: with probability `noisy_frac` the label becomes
      `(label + 1) % num_classes`.
    * `random_flip_asym`: for `num_classes == 100`, with probability
      `noisy_frac` the label is replaced by another fine class of the same
      coarse class (requires a `coarse_label` feature); for
      `num_classes == 10`, a fixed class mapping is applied to a random
      subset of the examples (see comments below). For any other
      `num_classes` the labels are left unchanged.

  Args:
    ds: A Tensorflow dataset object yielding dicts with a `label` entry.
    noise_type: A string specifying noise type. One of
      none/random/random_flip/random_flip_asym/random_flip_next.
    noisy_frac: A float specifying the fraction of noisy examples.
    num_classes: Number of classes of the dataset (10 or 100 for CIFAR).
    seed: Random seed.

  Returns:
    A `numpy` 1-D array containing noisy labels, one per example of `ds`.
    The dtype is float64 (kept for backward compatibility; callers cast).

  Raises:
    ValueError: If `noise_type` is not one of the supported values.
  """
  rng = np.random.RandomState(seed)  # fix the random seed
  ds = ds.batch(
      2048, drop_remainder=False).prefetch(tf.data.experimental.AUTOTUNE)
  # Collect per-batch results instead of writing into a fixed-size buffer so
  # that datasets of any size are supported.
  noisy_batches = []
  count = 0
  for batch in ds:
    label = np.asarray(batch['label'])
    num_examples = len(label)
    if noise_type == 'none':
      label_c = label
    elif noise_type == 'random_flip':
      # noisy samples have randomly flipped label (always incorrect)
      offset_probs = np.concatenate([[1 - noisy_frac],
                                     np.ones(num_classes - 1) * noisy_frac /
                                     (num_classes - 1)])
      offsets = rng.choice(
          num_classes, size=num_examples, replace=True, p=offset_probs)
      label_c = np.mod(label + offsets, num_classes)
    elif noise_type == 'random':
      # noisy samples have random label (following
      # https://arxiv.org/pdf/1904.11238.pdf Sec 4.1)
      noisy_ids = rng.binomial(1, noisy_frac, num_examples)
      random_labels = rng.choice(num_classes, size=num_examples)
      # The 0/1 draws must be converted to a boolean mask before selecting.
      label_c = np.where(noisy_ids.astype(bool), random_labels, label)
    elif noise_type == 'random_flip_next':
      corrupted = rng.choice([0, 1],
                             size=num_examples,
                             replace=True,
                             p=[1. - noisy_frac, noisy_frac])
      label_c = (label + corrupted) % num_classes
    elif noise_type == 'random_flip_asym':
      # Drawn for every `num_classes` (even when unused below) so that the
      # random stream, and hence the generated noise, stays reproducible.
      corrupted = rng.choice([0, 1],
                             size=num_examples,
                             replace=True,
                             p=[1. - noisy_frac, noisy_frac])
      label_c = label  # cifar-10 is handled after the loop
      if num_classes == 100:  # cifar-100
        coarse_label = np.asarray(batch['coarse_label'])
        flipped = []
        for ll, cc in zip(label, coarse_label):
          # Candidates: the other fine classes of the same coarse class.
          choices = list(_COARSE_CLASSES[cc])
          choices.remove(ll)
          flipped.append(rng.choice(choices))
        flipped = np.array(flipped)
        label_c = np.where(corrupted == 0, label, flipped)
    else:
      raise ValueError('Unknown noisy type: {}'.format(noise_type))
    noisy_batches.append(np.asarray(label_c))
    count += num_examples
  if noisy_batches:
    labels_noisy = np.concatenate(noisy_batches).astype(np.float64)
  else:
    labels_noisy = np.zeros(0)
  if noise_type == 'random_flip_asym' and num_classes == 10:
    # cifar-10 classes: airplane : 0, automobile : 1, bird : 2, cat : 3,
    # deer : 4, dog : 5, frog : 6, horse : 7, ship : 8, truck : 9
    # noise:  truck -> automobile, bird -> airplane, deer -> horse, cat <-> dog
    # Also, only the subset of classes that can be mapped to other classes
    # are noisified. This corresponds to 50% of the total training examples.
    # Therefore, noisy_frac is divided by 2 to be consistent with CIFAR-100.
    # ref: https://arxiv.org/pdf/2006.13554.pdf
    noisy_frac = noisy_frac / 2.0
    candidate_idx = np.flatnonzero(np.isin(labels_noisy, [2, 3, 4, 5, 9]))
    candidate_idx = candidate_idx[rng.permutation(len(candidate_idx))]
    noisy_examples_idx = candidate_idx[:int(noisy_frac * len(labels_noisy))]
    label_c = labels_noisy[noisy_examples_idx]
    label_c[label_c == 2] = 0  #  bird -> airplane
    label_c[label_c == 4] = 7  #  deer -> horse
    # Remember the cats before dogs are relabeled as cats, so the swap does
    # not map the freshly relabeled dogs back again.
    idx_cat = label_c == 3
    label_c[label_c == 5] = 3  #  dog -> cat
    label_c[idx_cat] = 5  #  cat -> dog
    label_c[label_c == 9] = 1  #  truck -> automobile
    labels_noisy[noisy_examples_idx] = label_c
  print('total examples:', count)
  return labels_noisy
183

184

185
def get_dataset(batch_size,
                data='cifar10',
                num_classes=10,
                image_size=32,
                noise_type='none',
                noisy_frac=0.,
                train_on_full=False):
  r"""Create Tensorflow dataset objects for CIFAR.

  Args:
    batch_size: Size of the minibatches.
    data: A string specifying the dataset (cifar10/cifar100).
    num_classes: Number of classes, forwarded to the label-noise simulation.
      It is not derived from `data`, so pass 100 for cifar100.
    image_size: Spatial height (=width) of the image.
    noise_type: A string specifying Noise type
      (none/random/random_flip/random_flip_asym/random_flip_next).
    noisy_frac: A float specifying the fraction of noisy examples (only used
      when `noise_type` is not `none`).
    train_on_full: A `Boolean` specifying whether to train on full dataset
      (True) or 90% of the dataset (False).

  Returns:
    Tensorflow dataset objects for train, validation, and test. The train
    dataset repeats indefinitely and is augmented; train and validation carry
    corrupted labels when `noise_type` is not `none`; test labels are clean.

  Raises:
    ValueError: If `data` is neither `cifar10` nor `cifar100`.
  """
  # Per-dataset channel statistics as (mean, std).
  channel_stats = {
      'cifar10': ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
      'cifar100': ([0.5071, 0.4865, 0.4409], [0.2673, 0.2564, 0.2762]),
  }
  if data not in channel_stats:
    raise ValueError('Unknown dataset: {}'.format(data))
  mean_values, std_values = channel_stats[data]
  # Shaped [1, 1, 1, 3] to broadcast over batched images.
  mean = tf.constant(np.reshape(mean_values, [1, 1, 1, 3]), dtype=tf.float32)
  std = tf.constant(np.reshape(std_values, [1, 1, 1, 3]), dtype=tf.float32)

  preproc_fn_train = functools.partial(
      preprocess_fn,
      mean=mean,
      std=std,
      image_size=image_size,
      augment=True,
      noise_type=noise_type)
  preproc_fn_valid = functools.partial(
      preprocess_fn,
      mean=mean,
      std=std,
      image_size=image_size,
      noise_type=noise_type)
  preproc_fn_test = functools.partial(
      preprocess_fn, mean=mean, std=std, image_size=image_size)

  # Training split.
  train_split = 'train' if train_on_full else 'train[:90%]'
  ds = tfds.load(data, split=train_split, with_info=False)
  if noise_type != 'none':
    labels_noisy = get_corrupted_labels(ds, noise_type, noisy_frac, num_classes)
    labels_noisy = tf.data.Dataset.from_tensor_slices(labels_noisy)
    # NOTE(review): zipping assumes the dataset yields examples in the same
    # order as during the pass made by get_corrupted_labels -- confirm that
    # the tfds read order is deterministic in your setup.
    ds = tf.data.Dataset.zip((ds, labels_noisy))
  ds = ds.repeat().shuffle(
      batch_size * 4, seed=1).batch(
          batch_size,
          drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
  ds = ds.map(preproc_fn_train)

  # Validation split: always the last 10% of the training set, so it overlaps
  # with the training data when `train_on_full` is True.
  ds_valid = tfds.load(data, split='train[90%:]', with_info=False)
  if noise_type != 'none':
    labels_noisy = get_corrupted_labels(
        ds_valid, noise_type, noisy_frac, num_classes, seed=1338)
    labels_noisy = tf.data.Dataset.from_tensor_slices(labels_noisy)
    ds_valid = tf.data.Dataset.zip((ds_valid, labels_noisy))
  ds_valid = ds_valid.shuffle(
      10000, seed=1).batch(
          batch_size,
          drop_remainder=False).prefetch(tf.data.experimental.AUTOTUNE)
  ds_valid = ds_valid.map(preproc_fn_valid)

  # Test split (clean labels, no augmentation).
  ds_tst = tfds.load(data, split='test', with_info=False)
  ds_tst = ds_tst.shuffle(
      10000, seed=1).batch(
          batch_size,
          drop_remainder=False).prefetch(tf.data.experimental.AUTOTUNE)
  ds_tst = ds_tst.map(preproc_fn_test)
  return ds, ds_valid, ds_tst
264

# Web-page residue (GitVerse cookie banner), not part of the module:
# "Use of cookies. We use cookies in accordance with the Privacy Policy and
# the Cookie Policy. By clicking 'Accept', you give JSC SberTech your consent
# to the processing of your personal data for the purpose of improving our
# website and the GitVerse service and making them more convenient to use.
# You can disable cookies yourself in your browser settings."