# google-research — 263 lines · 9.6 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Functions for reading and adding noise to CIFAR datasets."""

import copy
import functools

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
24# Define Coarse classes for CIFAR-100
25_COARSE_CLASSES = [26[4, 72, 55, 30, 95], #pylint: disable27[32, 1, 67, 73, 91],28[70, 82, 54, 92, 62],29[9, 10, 16, 28, 61],30[0, 83, 51, 53, 57],31[39, 40, 86, 22, 87],32[5, 20, 84, 25, 94],33[6, 7, 14, 18, 24],34[97, 3, 42, 43, 88],35[68, 37, 12, 76, 17],36[33, 71, 49, 23, 60],37[38, 15, 19, 21, 31],38[64, 66, 34, 75, 63],39[99, 77, 45, 79, 26],40[2, 35, 98, 11, 46],41[44, 78, 93, 27, 29],42[65, 36, 74, 80, 50],43[96, 47, 52, 56, 59],44[90, 8, 13, 48, 58],45[69, 41, 81, 85, 89]46]
def preprocess_fn(*features,
                  mean,
                  std,
                  image_size=32,
                  augment=False,
                  noise_type='none'):
  """Preprocess CIFAR-10 dataset.

  Args:
    features: Tuple of original features and corrupted labels.
    mean: Channel-wise mean for normalizing image pixels.
    std: Channel-wise standard deviation for normalizing image pixels.
    image_size: Spatial height (=width) of the image.
    augment: A `Boolean` indicating whether to do data augmentation.
    noise_type: Noise type (`none` indicates clean data).

  Returns:
    A dict of preprocessed images and labels
  """
  if noise_type == 'none':
    clean = features[0]
    image = clean['image']
    label = tf.cast(clean['label'], tf.int32)
  else:
    image = features[0]['image']
    label = tf.cast(features[1], tf.int32)  # corrupted label

  image = tf.cast(image, tf.float32) / 255.0
  if augment:
    # Standard CIFAR augmentation: pad by 4 pixels per side, random crop
    # back to `image_size`, then random horizontal flip.
    image = tf.image.resize_with_crop_or_pad(image, image_size + 4,
                                             image_size + 4)
    image = tf.image.random_crop(
        image, [image.shape[0], image_size, image_size, 3])
    image = tf.image.random_flip_left_right(image)
  else:
    image = tf.image.resize_with_crop_or_pad(image, image_size, image_size)
  # Channel-wise normalization is common to both branches.
  image = (image - mean) / std
  return {'image': image, 'label': label}
def get_corrupted_labels(ds,
                         noise_type,
                         noisy_frac=0.2,
                         num_classes=10,
                         seed=1335):
  """Simulate corrupted or noisy labels.

  Args:
    ds: A Tensorflow dataset object.
    noise_type: A string specifying noise type. One of
      none/random/random_flip/random_flip_asym/random_flip_next.
    noisy_frac: A float specifying the fraction of noisy examples.
    num_classes: Number of classes in the dataset (10 for CIFAR-10, 100 for
      CIFAR-100).
    seed: Random seed.

  Returns:
    A `numpy` 1-D array containing noisy labels.

  Raises:
    ValueError: If `noise_type` is not one of the supported types.
  """
  rng = np.random.RandomState(seed)  # fix the random seed
  ds = ds.batch(
      2048, drop_remainder=False).prefetch(tf.data.experimental.AUTOTUNE)
  # Upper bound on the number of examples (full CIFAR train split has 50000);
  # the array is trimmed to the actual count after the loop.
  labels_noisy = np.zeros(50000)
  count = 0
  for batch in ds:
    label = batch['label']
    label_c = label
    if noise_type == 'random_flip':
      # noisy samples have randomly flipped label (always incorrect)
      label_c = label + rng.choice(
          num_classes,
          size=len(label),
          replace=True,
          p=np.concatenate([[1 - noisy_frac],
                            np.ones(num_classes - 1) * noisy_frac /
                            (num_classes - 1)]))
      label_c = tf.math.floormod(label_c, num_classes)
    elif noise_type == 'random':
      # noisy samples have random label (following
      # https://arxiv.org/pdf/1904.11238.pdf Sec 4.1)
      noisy_ids = rng.binomial(1, noisy_frac, len(label))
      # Fix: `tf.where` requires a boolean condition; the 0/1 integer
      # indicators from `rng.binomial` must be cast to bool explicitly.
      label_c = tf.where(noisy_ids.astype(bool),
                         rng.choice(num_classes, size=len(label)), label)
    elif noise_type == 'random_flip_next':
      corrupted = rng.choice([0, 1],
                             size=len(label),
                             replace=True,
                             p=[1. - noisy_frac, noisy_frac])
      label_c = (label + corrupted) % num_classes
    elif noise_type == 'random_flip_asym':
      corrupted = rng.choice([0, 1],
                             size=len(label),
                             replace=True,
                             p=[1. - noisy_frac, noisy_frac])
      if num_classes == 10:  # cifar-10: asymmetric flips applied after loop
        label_c = label
      elif num_classes == 100:  # cifar-100: flip within the same coarse class
        coarse_label = batch['coarse_label']
        label_c = []
        for ll, cc in zip(label, coarse_label):
          # Choose uniformly among the other 4 fine labels of the superclass.
          choices = copy.deepcopy(_COARSE_CLASSES[cc])
          choices.remove(ll)
          label_c.append(rng.choice(choices))
        label_c = np.array(label_c)
        label_c = np.where(corrupted == 0, label, label_c)
    else:
      raise ValueError('Unknown noisy type: {}'.format(noise_type))
    labels_noisy[count:count + len(label)] = label_c
    count += len(label)
  labels_noisy = labels_noisy[:count]
  if noise_type == 'random_flip_asym' and num_classes == 10:
    # cifar-10 classes: airplane : 0, automobile : 1, bird : 2, cat : 3,
    # deer : 4, dog : 5, frog : 6, horse : 7, ship : 8, truck : 9
    # noise: truck → automobile, bird → airplane, deer → horse, cat ↔ dog
    # Also, only the subset of classes that can be mapped to other classes
    # are noisified. This corresponds to 50% of the total training examples.
    # Therefore, noisy_frac is divided by 2 to be consistent with CIFAR-100.
    # ref: https://arxiv.org/pdf/2006.13554.pdf
    noisy_frac = noisy_frac / 2.0
    noisy_examples_idx = np.arange(len(labels_noisy))[np.in1d(
        labels_noisy, [2, 3, 4, 5, 9])]
    noisy_examples_idx = noisy_examples_idx[rng.permutation(
        len(noisy_examples_idx))]
    noisy_examples_idx = noisy_examples_idx[:int(noisy_frac *
                                                 len(labels_noisy))]
    label_c = labels_noisy[noisy_examples_idx]
    label_c[label_c == 2] = 0  # bird → airplane
    label_c[label_c == 4] = 7  # deer → horse
    idx_cat = label_c == 3  # remember cats before dogs are mapped onto 3
    label_c[label_c == 5] = 3  # dog → cat
    label_c[idx_cat] = 5  # cat → dog
    label_c[label_c == 9] = 1  # truck → automobile
    labels_noisy[noisy_examples_idx] = label_c
  print('total examples:', count)
  return labels_noisy
def get_dataset(batch_size,
                data='cifar10',
                num_classes=10,
                image_size=32,
                noise_type='none',
                noisy_frac=0.,
                train_on_full=False):
  r"""Create Tensorflow dataset object for CIFAR.

  Args:
    batch_size: Size of the minibatches.
    data: A string specifying the dataset (cifar10/cifar100).
    num_classes: Number of classes (10 for cifar10, 100 for cifar100).
    image_size: Spatial height (=width) of the image.
    noise_type: A string specifying Noise type
      (none/random/random_flip/random_flip_asym/random_flip_next).
    noisy_frac: A float specifying the fraction of noisy examples.
    train_on_full: A `Boolean` specifying whether to train on full dataset
      (True) or 90% of the dataset (False).

  Returns:
    Tensorflow dataset objects for train, validation, and test.

  Raises:
    ValueError: If `data` is not one of cifar10/cifar100.
  """
  if data == 'cifar10':
    mean = tf.constant(
        np.reshape([0.4914, 0.4822, 0.4465], [1, 1, 1, 3]), dtype=tf.float32)
    std = tf.constant(
        np.reshape([0.2023, 0.1994, 0.2010], [1, 1, 1, 3]), dtype=tf.float32)
  elif data == 'cifar100':
    mean = tf.constant(
        np.reshape([0.5071, 0.4865, 0.4409], [1, 1, 1, 3]), dtype=tf.float32)
    std = tf.constant(
        np.reshape([0.2673, 0.2564, 0.2762], [1, 1, 1, 3]), dtype=tf.float32)
  else:
    # Fix: previously an unsupported `data` left `mean`/`std` unbound and
    # crashed later with a NameError; fail fast with a clear message instead.
    raise ValueError('Unknown dataset: {}'.format(data))
  preproc_fn_train = functools.partial(
      preprocess_fn,
      mean=mean,
      std=std,
      image_size=image_size,
      augment=True,
      noise_type=noise_type)
  if train_on_full:
    ds = tfds.load(data, split='train', with_info=False)
  else:
    ds = tfds.load(data, split='train[:90%]', with_info=False)
  if noise_type != 'none':
    # Pair every example with its (fixed) corrupted label.
    labels_noisy = get_corrupted_labels(ds, noise_type, noisy_frac, num_classes)
    labels_noisy = tf.data.Dataset.from_tensor_slices(labels_noisy)
    ds = tf.data.Dataset.zip((ds, labels_noisy))
  ds = ds.repeat().shuffle(
      batch_size * 4, seed=1).batch(
          batch_size,
          drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
  ds = ds.map(preproc_fn_train)

  # Validation split: last 10% of train, corrupted with a different seed.
  ds_valid = tfds.load(data, split='train[90%:]', with_info=False)
  if noise_type != 'none':
    labels_noisy = get_corrupted_labels(
        ds_valid, noise_type, noisy_frac, num_classes, seed=1338)
    labels_noisy = tf.data.Dataset.from_tensor_slices(labels_noisy)
    ds_valid = tf.data.Dataset.zip((ds_valid, labels_noisy))
  ds_valid = ds_valid.shuffle(
      10000, seed=1).batch(
          batch_size,
          drop_remainder=False).prefetch(tf.data.experimental.AUTOTUNE)
  ds_valid = ds_valid.map(
      functools.partial(
          preprocess_fn,
          mean=mean,
          std=std,
          image_size=image_size,
          noise_type=noise_type))

  # Test split is always clean (no label corruption, no augmentation).
  ds_tst = tfds.load(data, split='test', with_info=False)
  ds_tst = ds_tst.shuffle(
      10000, seed=1).batch(
          batch_size,
          drop_remainder=False).prefetch(tf.data.experimental.AUTOTUNE)
  ds_tst = ds_tst.map(
      functools.partial(
          preprocess_fn, mean=mean, std=std, image_size=image_size))
  return ds, ds_valid, ds_tst