google-research
137 строк · 4.6 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Library for constructing MNIST and distorted MNIST datasets."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23
24from absl import flags
25from absl import logging
26
27import attr
28import numpy as np
29import scipy.ndimage
30from six.moves import range
31from six.moves import zip
32
33import tensorflow.compat.v2 as tf
34import tensorflow_datasets as tfds
35
NUM_TRAIN_EXAMPLES = 50000
DUMMY_DATA_SIZE = 99
MNIST_IMAGE_SHAPE = (28, 28, 1)

# Test-split variants shifted with np.roll by 2, 4, ..., 26 pixels.
DATA_OPTS_ROLL = [{'split': 'test', 'roll_pixels': shift}
                  for shift in range(2, 28, 2)]
# Test-split variants rotated by 15, 30, ..., 180 degrees.
DATA_OPTS_ROTATE = [{'split': 'test', 'rotate_degs': angle}
                    for angle in range(15, 181, 15)]
# Out-of-distribution test sets.
DATA_OPTS_OOD = [{'split': 'test', 'dataset_name': name}
                 for name in ('fashion_mnist', 'not_mnist')]

# Every option dict accepted by build_dataset: the three clean MNIST splits
# followed by all distorted and out-of-distribution variants.
DATA_OPTIONS_LIST = (
    [{'split': split_name} for split_name in ('train', 'valid', 'test')]
    + DATA_OPTS_ROLL + DATA_OPTS_ROTATE + DATA_OPTS_OOD)
50
51
FLAGS = flags.FLAGS

# Both templates are %-formatted with a split name by the TFRecord readers.
flags.DEFINE_string(
    name='mnist_path_tmpl',
    default=None,
    help='Template path to MNIST data tables.')
flags.DEFINE_string(
    name='not_mnist_path_tmpl',
    default=None,
    help='Template path to NotMNIST data tables.')
57
58
@attr.s
class MnistDataOptions(object):
  """Selects an MNIST-family split and the distortions applied to it."""

  # Required split name, e.g. 'train', 'valid' or 'test'.
  split = attr.ib()
  # 'mnist', 'not_mnist', or any other name, which is loaded through TFDS.
  dataset_name = attr.ib(default='mnist')
  # Pixel shift applied with np.roll by build_dataset; 0 disables it.
  roll_pixels = attr.ib(default=0)
  # Rotation angle in degrees applied by build_dataset; 0 disables it.
  rotate_degs = attr.ib(default=0)
65
66
def _crop_center(images, size):
  """Cuts a centered `size` x `size` window out of every image in a batch.

  Args:
    images: array indexed as [batch, height, width, ...].
    size: side length of the square window; assumed not to exceed the
      image height/width.

  Returns:
    The view `images[:, top:top + size, left:left + size]`, where each
    offset is `dim // 2 - size // 2`.
  """
  top, left = (dim // 2 - size // 2 for dim in images.shape[1:3])
  rows = slice(top, top + size)
  cols = slice(left, left + size)
  return images[:, rows, cols]
72
73
def _tfr_parse_fn(serialized, img_bytes_key='image/encoded'):
  """Parses one serialized tf.Example into an (image, label) pair.

  Args:
    serialized: scalar string tensor holding one serialized tf.Example.
    img_bytes_key: feature key under which the raw image bytes are stored
      (differs between the MNIST and NotMNIST tables).

  Returns:
    `(image, label)`: `image` is a uint8 tensor reshaped to
    `MNIST_IMAGE_SHAPE`; `label` is the int64 'image/class/label' feature.
  """
  features = {
      'image/class/label': tf.io.FixedLenFeature((), tf.int64),
      # A single bytes value holding raw (not PNG/JPEG-encoded) pixels.
      img_bytes_key: tf.io.FixedLenFeature([1], tf.string),
  }
  parsed = tf.io.parse_single_example(serialized, features)
  image = tf.io.decode_raw(parsed[img_bytes_key], tf.uint8)
  # Use the shared module constant rather than repeating [28, 28, 1].
  image = tf.reshape(image, MNIST_IMAGE_SHAPE)
  return image, parsed['image/class/label']
81
82
def _mnist_dataset_from_tfr(split_name):
  """Reads an MNIST split from TFRecords as a dataset of (image, label) pairs.

  Args:
    split_name: split to read. 'train' is mapped to the 'train_small' table;
      any other value is substituted into the path template unchanged.

  Returns:
    A `tf.data.Dataset` whose elements are produced by `_tfr_parse_fn`.

  Raises:
    ValueError: if --mnist_path_tmpl is unset or empty. Without this check
      the `%` formatting below fails with an unhelpful TypeError.
  """
  if not FLAGS.mnist_path_tmpl:
    raise ValueError('--mnist_path_tmpl must be set to read MNIST data.')
  # train_small contains the first 50K rows of the train set; valid is the
  # last 10K.
  split_key = 'train_small' if split_name == 'train' else split_name
  path = FLAGS.mnist_path_tmpl % split_key
  logging.info('Reading dataset from %s', path)
  parse_fn = functools.partial(_tfr_parse_fn, img_bytes_key='image/encoded')
  return tf.data.TFRecordDataset(path).map(parse_fn)
90
91
def _not_mnist_dataset_from_tfr(split_name):
  """Reads the NotMNIST test split from TFRecords as (image, label) pairs.

  Args:
    split_name: must be 'test'; substituted into --not_mnist_path_tmpl.

  Returns:
    A `tf.data.Dataset` whose elements are produced by `_tfr_parse_fn`,
    reading image bytes from the 'image/raw' feature.

  Raises:
    ValueError: if `split_name` is not 'test', or if --not_mnist_path_tmpl
      is unset or empty. Without the flag check, the `%` formatting below
      fails with an unhelpful TypeError.
  """
  if split_name != 'test':
    raise ValueError('We should only use NotMNIST test data.')
  if not FLAGS.not_mnist_path_tmpl:
    raise ValueError(
        '--not_mnist_path_tmpl must be set to read NotMNIST data.')
  path = FLAGS.not_mnist_path_tmpl % split_name
  logging.info('Reading dataset from %s', path)
  parse_fn = functools.partial(_tfr_parse_fn, img_bytes_key='image/raw')
  return tf.data.TFRecordDataset(path).map(parse_fn)
99
100
def _dataset_from_tfds(dataset_name, split):
  """Loads a TFDS dataset's test split as supervised (image, label) pairs.

  Args:
    dataset_name: TFDS registry name, e.g. 'fashion_mnist'.
    split: must be 'test'.

  Returns:
    The result of `tfds.load(..., as_supervised=True)` for that split.

  Raises:
    ValueError: if `split` is anything other than 'test'.
  """
  if split == 'test':
    return tfds.load(dataset_name, split=split, as_supervised=True)
  raise ValueError('We should only use split=test from tfds.')
105
106
def build_dataset(opts, fake_data=False):
  """Loads an MNIST-family split into memory and applies requested distortions.

  Args:
    opts: dict of keyword arguments for `MnistDataOptions` (`split` and
      optionally `dataset_name`, `roll_pixels`, `rotate_degs`), e.g. one
      entry of `DATA_OPTIONS_LIST`.
    fake_data: if True, skip all I/O and return `DUMMY_DATA_SIZE` random
      examples instead.

  Returns:
    A tuple `(images, labels)` of NumPy arrays (not a tf.data.Dataset).
    `images` is float32, divided by 255, then optionally rotated (and
    center-cropped back to 28x28) and/or rolled. `labels` holds the class
    labels. With `fake_data`, images are float32 uniform in [0, 1) with
    shape `(DUMMY_DATA_SIZE,) + MNIST_IMAGE_SHAPE`, and labels are int64
    in [0, 10).

  Raises:
    ValueError: propagated from the readers for unsupported splits.
  """
  opts = MnistDataOptions(**opts)
  logging.info('Building dataset with options: %s', opts)

  if fake_data:
    # Match the real-data dtypes (float32 images below; int64 labels as
    # parsed from TFRecords) so code exercised with fake data sees the same
    # types. np.random.rand alone yields float64.
    images = np.random.rand(DUMMY_DATA_SIZE, *MNIST_IMAGE_SHAPE).astype(
        np.float32)
    labels = np.random.randint(0, 10, DUMMY_DATA_SIZE, dtype=np.int64)
    return images, labels

  # We can't use in-distribution data from tfds due to inconsistent orderings.
  if opts.dataset_name == 'mnist':
    dataset = _mnist_dataset_from_tfr(opts.split)
  elif opts.dataset_name == 'not_mnist':
    dataset = _not_mnist_dataset_from_tfr(opts.split)
  else:
    dataset = _dataset_from_tfds(opts.dataset_name, opts.split)

  # Download the dataset to memory in batches of 10K and stitch them together.
  images, labels = list(zip(*tfds.as_numpy(dataset.batch(10**4))))
  images = np.concatenate(images, axis=0).astype(np.float32)
  labels = np.concatenate(labels, axis=0)

  images /= 255
  if opts.rotate_degs:
    # scipy's rotate enlarges the canvas by default (reshape=True), so crop
    # the center back to the original image size.
    images = scipy.ndimage.rotate(images, opts.rotate_degs, axes=[-2, -3])
    images = _crop_center(images, MNIST_IMAGE_SHAPE[0])
  if opts.roll_pixels:
    images = np.roll(images, opts.roll_pixels, axis=-2)

  return images, labels
138