google-research

Форк
0
137 строк · 4.6 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Library for constructing MNIST and distorted MNIST datasets."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import functools
23

24
from absl import flags
25
from absl import logging
26

27
import attr
28
import numpy as np
29
import scipy.ndimage
30
from six.moves import range
31
from six.moves import zip
32

33
import tensorflow.compat.v2 as tf
34
import tensorflow_datasets as tfds
35

36
# Size of the 'train' split; presumably matches the 50K-row 'train_small'
# table read for split='train' below.
NUM_TRAIN_EXAMPLES = 50000
# Number of random examples returned by build_dataset(..., fake_data=True).
DUMMY_DATA_SIZE = 99
# Per-example image shape: (height, width, channels).
MNIST_IMAGE_SHAPE = (28, 28, 1)

# Test-set variants shifted with np.roll by 2, 4, ..., 26 pixels.
DATA_OPTS_ROLL = [{'split': 'test', 'roll_pixels': shift}
                  for shift in range(2, 28, 2)]
# Test-set variants rotated by 15, 30, ..., 180 degrees.
DATA_OPTS_ROTATE = [{'split': 'test', 'rotate_degs': angle}
                    for angle in range(15, 181, 15)]
# Out-of-distribution test sets.
DATA_OPTS_OOD = [{'split': 'test', 'dataset_name': name}
                 for name in ('fashion_mnist', 'not_mnist')]

# All configurations above: the three standard MNIST splits first, then the
# rolled, rotated and out-of-distribution test variants.
DATA_OPTIONS_LIST = (
    [{'split': split_name} for split_name in ('train', 'valid', 'test')]
    + DATA_OPTS_ROLL
    + DATA_OPTS_ROTATE
    + DATA_OPTS_OOD)
50

51

52
FLAGS = flags.FLAGS

# Both templates are %-formatted with a split key to obtain the TFRecord path
# that is read, so each must contain a single %s placeholder.
flags.DEFINE_string(
    name='mnist_path_tmpl',
    default=None,
    help='Template path to MNIST data tables.')
flags.DEFINE_string(
    name='not_mnist_path_tmpl',
    default=None,
    help='Template path to NotMNIST data tables.')
57

58

59
@attr.s
class MnistDataOptions(object):
  """Options selecting which dataset variant build_dataset produces.

  Attributes:
    split: name of the split to load (required).
    dataset_name: which dataset to read; 'mnist' unless overridden.
    roll_pixels: shift passed to np.roll by build_dataset; 0 disables it.
    rotate_degs: rotation angle in degrees used by build_dataset; 0 disables
      it.
  """
  # Field order matters: attrs generates __init__ parameters in this order.
  split = attr.ib()
  dataset_name = attr.ib(default='mnist')
  roll_pixels = attr.ib(default=0)
  rotate_degs = attr.ib(default=0)
65

66

67
def _crop_center(images, size):
  """Crops a centered `size` x `size` window from a batch of images.

  Args:
    images: array indexed as [batch, height, width, ...]; any trailing axes
      (e.g. channels) are kept unchanged.
    size: side length of the square crop, in pixels.

  Returns:
    A slice of `images` with shape [batch, size, size, ...].

  Raises:
    ValueError: if `size` exceeds the image height or width. Without this
      check the start offset would go negative and the slice would silently
      wrap around, returning a crop of the wrong shape.
  """
  height, width = images.shape[1:3]
  if size > height or size > width:
    raise ValueError(
        'Crop size %d exceeds image size %dx%d.' % (size, height, width))
  top = height // 2 - size // 2
  left = width // 2 - size // 2
  return images[:, top:top + size, left:left + size]
72

73

74
def _tfr_parse_fn(serialized, img_bytes_key='image/encoded'):
  """Parses one serialized tf.train.Example into an (image, label) pair.

  Args:
    serialized: scalar string tensor holding one serialized example.
    img_bytes_key: feature key under which the raw image bytes are stored.

  Returns:
    A tuple `(image, label)`: `image` is a uint8 tensor reshaped to
    MNIST_IMAGE_SHAPE, and `label` is the int64 'image/class/label' feature.
  """
  features = {
      'image/class/label': tf.io.FixedLenFeature((), tf.int64),
      # Raw pixel bytes, stored as a length-1 list of strings.
      img_bytes_key: tf.io.FixedLenFeature([1], tf.string),
  }
  parsed = tf.io.parse_single_example(serialized, features)
  # decode_raw reads the bytes as a flat uint8 buffer; the reshape below
  # requires it to hold exactly prod(MNIST_IMAGE_SHAPE) values.
  image = tf.io.decode_raw(parsed[img_bytes_key], tf.uint8)
  image = tf.reshape(image, list(MNIST_IMAGE_SHAPE))
  return image, parsed['image/class/label']
81

82

83
def _mnist_dataset_from_tfr(split_name):
  """Reads an MNIST split from the TFRecord file given by --mnist_path_tmpl.

  Args:
    split_name: split to read, e.g. 'train', 'valid' or 'test'. 'train' is
      mapped to the 'train_small' table; other names are used as-is.

  Returns:
    A tf.data.Dataset of (image, label) pairs produced by _tfr_parse_fn.

  Raises:
    ValueError: if --mnist_path_tmpl was not set.
  """
  # Without this check, `None % split_key` fails with an obscure TypeError.
  if FLAGS.mnist_path_tmpl is None:
    raise ValueError(
        '--mnist_path_tmpl must be set to read MNIST from TFRecords.')
  # train_small contains the first 50K rows of the train set; valid is the
  # last 10K.
  split_key = 'train_small' if split_name == 'train' else split_name
  path = FLAGS.mnist_path_tmpl % split_key
  logging.info('Reading dataset from %s', path)
  parse_fn = functools.partial(_tfr_parse_fn, img_bytes_key='image/encoded')
  return tf.data.TFRecordDataset(path).map(parse_fn)
90

91

92
def _not_mnist_dataset_from_tfr(split_name):
  """Reads the NotMNIST test split from the file given by --not_mnist_path_tmpl.

  Args:
    split_name: must be 'test'; NotMNIST is only used as held-out test data.

  Returns:
    A tf.data.Dataset of (image, label) pairs produced by _tfr_parse_fn,
    reading image bytes from the 'image/raw' feature.

  Raises:
    ValueError: if `split_name` is not 'test', or if --not_mnist_path_tmpl
      was not set.
  """
  if split_name != 'test':
    raise ValueError('We should only use NotMNIST test data.')
  # Without this check, `None % split_name` fails with an obscure TypeError.
  if FLAGS.not_mnist_path_tmpl is None:
    raise ValueError(
        '--not_mnist_path_tmpl must be set to read NotMNIST from TFRecords.')
  path = FLAGS.not_mnist_path_tmpl % split_name
  logging.info('Reading dataset from %s', path)
  parse_fn = functools.partial(_tfr_parse_fn, img_bytes_key='image/raw')
  return tf.data.TFRecordDataset(path).map(parse_fn)
99

100

101
def _dataset_from_tfds(dataset_name, split):
  """Loads the test split of a TFDS dataset as supervised pairs.

  Args:
    dataset_name: TFDS dataset name, e.g. 'fashion_mnist'.
    split: must be 'test'; any other value is rejected.

  Returns:
    The dataset returned by tfds.load with as_supervised=True.

  Raises:
    ValueError: if `split` is not 'test'.
  """
  if split == 'test':
    return tfds.load(dataset_name, split=split, as_supervised=True)
  raise ValueError('We should only use split=test from tfds.')
105

106

107
def build_dataset(opts, fake_data=False):
  """Returns an <images, labels> dataset pair loaded into memory.

  Args:
    opts: a dict of MnistDataOptions fields (e.g. an entry of
      DATA_OPTIONS_LIST) or an MnistDataOptions instance.
    fake_data: if True, skip all I/O and return DUMMY_DATA_SIZE random
      examples; no roll or rotation is applied to them.

  Returns:
    A tuple `(images, labels)` of numpy arrays. `images` is float32 and
    indexed as [example, height, width, channels], with pixel values divided
    by 255. `labels` holds the class labels.

  Raises:
    ValueError: if the selected dataset yields no examples, or if the split
      is not permitted for the selected dataset.
  """
  if not isinstance(opts, MnistDataOptions):
    opts = MnistDataOptions(**opts)
  logging.info('Building dataset with options: %s', opts)

  if fake_data:
    # Cast so fake images have the same float32 dtype as real ones.
    images = np.random.rand(DUMMY_DATA_SIZE, *MNIST_IMAGE_SHAPE).astype(
        np.float32)
    labels = np.random.randint(0, 10, DUMMY_DATA_SIZE)
    return images, labels

  # We can't use in-distribution data from tfds due to inconsistent orderings.
  if opts.dataset_name == 'mnist':
    dataset = _mnist_dataset_from_tfr(opts.split)
  elif opts.dataset_name == 'not_mnist':
    dataset = _not_mnist_dataset_from_tfr(opts.split)
  else:
    dataset = _dataset_from_tfds(opts.dataset_name, opts.split)

  # Load the whole dataset into memory, 10K examples per batch.
  batches = list(tfds.as_numpy(dataset.batch(10**4)))
  if not batches:
    raise ValueError('Dataset is empty for options: %s' % (opts,))
  images = np.concatenate([batch_images for batch_images, _ in batches],
                          axis=0).astype(np.float32)
  labels = np.concatenate([batch_labels for _, batch_labels in batches],
                          axis=0)

  images /= 255
  if opts.rotate_degs:
    # The rotated output may be larger than the input, so crop back to the
    # original image size around the center.
    images = scipy.ndimage.rotate(images, opts.rotate_degs, axes=[-2, -3])
    images = _crop_center(images, MNIST_IMAGE_SHAPE[0])
  if opts.roll_pixels:
    images = np.roll(images, opts.roll_pixels, axis=-2)

  return images, labels
138

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.