google-research

Форк
0
151 строка · 4.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
""""Datasets prepared for SG-MCMC methods.
17
"""
18

19
from __future__ import absolute_import
20
from __future__ import division
21
from __future__ import print_function
22

23
import tensorflow.compat.v2 as tf
24
import tensorflow_datasets as tfds
25

26
from cold_posterior_bnn.imdb import imdb_data
27

28

29
IMDB_NUM_WORDS = 20000
30
IMDB_SEQUENCE_LENGTH = 100
31

32

33
def load_imdb(with_info=False, subsample_n=0):
34
  """Load IMDB dataset.
35

36
  Args:
37
    with_info: bool, whether to return info dictionary.
38
    subsample_n: int, if >0, subsample training set to this size.
39

40
  Returns:
41
    Tuple of (dataset dict, info dict) if with_info else only
42
    the dataset.
43
  """
44
  (x_train, y_train), (x_val, y_val), (x_test, y_test) = imdb_data.load_data(
45
      num_words=IMDB_NUM_WORDS, maxlen=IMDB_SEQUENCE_LENGTH)
46

47
  original_train_size = x_train.shape[0]
48
  if subsample_n > 0:
49
    x_train = x_train[0:subsample_n, :]
50
    y_train = y_train[0:subsample_n]
51

52
  dataset = {
53
      'x_train': x_train,
54
      'y_train': y_train,
55
      'x_val': x_val,
56
      'y_val': y_val,
57
      'x_test': x_test,
58
      'y_test': y_test
59
  }
60

61
  if with_info:
62
    info = {
63
        'input_shape': x_train.shape[1:],
64
        'train_num_examples': x_train.shape[0],
65
        'train_num_examples_orig': original_train_size,
66
        'test_num_examples': x_test.shape[0],
67
        'num_classes': 2,
68
        'num_words': IMDB_NUM_WORDS,
69
        'sequence_length': IMDB_SEQUENCE_LENGTH,
70
    }
71
    return dataset, info
72

73
  return dataset
74

75

76
def load_cifar10(split, with_info=False, data_augmentation=True,
77
                 subsample_n=0):
78
  """This is a fork of edward2.utils.load_dataset.
79

80
  Returns a tf.data.Dataset with <image, label> pairs.
81

82
  Args:
83
    split: tfds.Split.
84
    with_info: bool.
85
    data_augmentation: bool, if True perform simple data augmentation on the
86
      TRAIN split with random left/right flips and random cropping.  If False,
87
      do not perform any data augmentation.
88
    subsample_n: int, if >0, subsample training set to this size.
89

90
  Returns:
91
    Tuple of (tf.data.Dataset, tf.data.DatasetInfo) if with_info else only
92
    the dataset.
93
  """
94
  dataset, ds_info = tfds.load('cifar10',
95
                               split=split,
96
                               with_info=True,
97
                               batch_size=-1)
98
  image_shape = ds_info.features['image'].shape
99

100
  numpy_ds = tfds.as_numpy(dataset)
101
  numpy_images = numpy_ds['image']
102
  numpy_labels = numpy_ds['label']
103

104
  # Perform subsampling if requested
105
  original_train_size = numpy_images.shape[0]
106
  if subsample_n > 0:
107
    subsample_n = min(numpy_images.shape[0], subsample_n)
108
    numpy_images = numpy_images[0:subsample_n, :, :, :]
109
    numpy_labels = numpy_labels[0:subsample_n]
110

111
  dataset = tf.data.Dataset.from_tensor_slices((numpy_images, numpy_labels))
112

113
  def preprocess(image, label):
114
    """Image preprocessing function."""
115
    if data_augmentation and split == tfds.Split.TRAIN:
116
      image = tf.image.random_flip_left_right(image)
117
      image = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
118
      image = tf.image.random_crop(image, image_shape)
119

120
    image = tf.image.convert_image_dtype(image, tf.float32)
121
    return image, label
122

123
  dataset = dataset.map(preprocess)
124

125
  if with_info:
126
    info = {
127
        'train_num_examples': ds_info.splits['train'].num_examples,
128
        'train_num_examples_orig': original_train_size,
129
        'test_num_examples': ds_info.splits['test'].num_examples,
130
        'input_shape': ds_info.features['image'].shape,
131
        'num_classes': ds_info.features['label'].num_classes,
132
    }
133
    return dataset, info
134
  return dataset
135

136

137
def get_generators_from_ds(dataset):
138
  """Returns generators for efficient training.
139

140
  Args:
141
    dataset: dataset dictionary.
142

143
  Returns:
144
    tfds generators for training and test data.
145
  """
146
  data_train = tf.data.Dataset.from_tensor_slices(
147
      (dataset['x_train'], dataset['y_train']))
148
  data_test = tf.data.Dataset.from_tensor_slices(
149
      (dataset['x_test'], dataset['y_test']))
150

151
  return data_train, data_test
152

153

154

155

156

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.