google-research
151 строка · 4.5 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Datasets prepared for SG-MCMC methods.
"""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import tensorflow.compat.v2 as tf
24import tensorflow_datasets as tfds
25
26from cold_posterior_bnn.imdb import imdb_data
27
28
29IMDB_NUM_WORDS = 20000
30IMDB_SEQUENCE_LENGTH = 100
31
32
def load_imdb(with_info=False, subsample_n=0):
  """Load the IMDB dataset as a dictionary of train/val/test arrays.

  The arrays are obtained from `imdb_data.load_data`, called with
  `num_words=IMDB_NUM_WORDS` and `maxlen=IMDB_SEQUENCE_LENGTH`.

  Args:
    with_info: bool, if True also return a dictionary describing the data.
    subsample_n: int, if >0, keep only the first `subsample_n` training
      examples.  Validation and test data are never subsampled.

  Returns:
    The dataset dictionary with keys 'x_train', 'y_train', 'x_val', 'y_val',
    'x_test' and 'y_test'.  If `with_info` is True, a tuple of
    (dataset dict, info dict) is returned instead.
  """
  train_split, val_split, test_split = imdb_data.load_data(
      num_words=IMDB_NUM_WORDS, maxlen=IMDB_SEQUENCE_LENGTH)
  x_train, y_train = train_split
  x_val, y_val = val_split
  x_test, y_test = test_split

  # Remember the full training set size before any subsampling so that it can
  # still be reported in the info dictionary.
  full_train_size = x_train.shape[0]
  if subsample_n > 0:
    head = slice(0, subsample_n)
    x_train = x_train[head, :]
    y_train = y_train[head]

  dataset = dict(
      x_train=x_train,
      y_train=y_train,
      x_val=x_val,
      y_val=y_val,
      x_test=x_test,
      y_test=y_test,
  )

  if not with_info:
    return dataset

  # 'train_num_examples' reflects the (possibly subsampled) training set,
  # 'train_num_examples_orig' the size before subsampling.
  info = dict(
      input_shape=x_train.shape[1:],
      train_num_examples=x_train.shape[0],
      train_num_examples_orig=full_train_size,
      test_num_examples=x_test.shape[0],
      num_classes=2,
      num_words=IMDB_NUM_WORDS,
      sequence_length=IMDB_SEQUENCE_LENGTH,
  )
  return dataset, info
74
75
def load_cifar10(split, with_info=False, data_augmentation=True,
                 subsample_n=0):
  """Load a CIFAR-10 split as a tf.data.Dataset of <image, label> pairs.

  This is a fork of edward2.utils.load_dataset.  The requested split is loaded
  in one batch through tfds, converted to numpy arrays, optionally truncated
  to its first `subsample_n` examples and wrapped into a tf.data.Dataset whose
  images are converted to float32 with `tf.image.convert_image_dtype`.

  Args:
    split: tfds.Split.
    with_info: bool, if True also return an info dictionary.
    data_augmentation: bool, if True and `split` is the TRAIN split, apply a
      random left/right flip, pad the first two image axes by 4 on each side
      and take a random crop of the original image shape.  If False, no data
      augmentation is performed.
    subsample_n: int, if >0, keep only the first `subsample_n` examples of the
      loaded split (capped at the number of available examples).

  Returns:
    The tf.data.Dataset, or a tuple of (tf.data.Dataset, info dict) if
    `with_info` is True.  In the info dict, 'train_num_examples' and
    'test_num_examples' are taken from the tfds metadata and are therefore not
    affected by subsampling, while 'train_num_examples_orig' is the number of
    examples in the loaded `split` before subsampling.
  """
  loaded, tfds_info = tfds.load(
      'cifar10', split=split, with_info=True, batch_size=-1)
  features = tfds_info.features
  crop_shape = features['image'].shape

  arrays = tfds.as_numpy(loaded)
  images = arrays['image']
  labels = arrays['label']

  # Perform subsampling if requested; the pre-subsampling size is kept for the
  # info dictionary.
  size_before_subsampling = images.shape[0]
  if subsample_n > 0:
    keep = min(images.shape[0], subsample_n)
    images = images[:keep, :, :, :]
    labels = labels[:keep]

  # Augmentation is only ever applied to the training split.
  augment = data_augmentation and split == tfds.Split.TRAIN

  def preprocess(image, label):
    """Optionally augment one image and convert it to float32."""
    if augment:
      flipped = tf.image.random_flip_left_right(image)
      padded = tf.pad(flipped, [[4, 4], [4, 4], [0, 0]])
      image = tf.image.random_crop(padded, crop_shape)
    return tf.image.convert_image_dtype(image, tf.float32), label

  dataset = tf.data.Dataset.from_tensor_slices((images, labels))
  dataset = dataset.map(preprocess)

  if not with_info:
    return dataset

  info = {
      'train_num_examples': tfds_info.splits['train'].num_examples,
      'train_num_examples_orig': size_before_subsampling,
      'test_num_examples': tfds_info.splits['test'].num_examples,
      'input_shape': crop_shape,
      'num_classes': features['label'].num_classes,
  }
  return dataset, info
135
136
def get_generators_from_ds(dataset):
  """Wrap the train and test arrays of a dataset dict into tf.data.Datasets.

  Args:
    dataset: dataset dictionary providing the keys 'x_train', 'y_train',
      'x_test' and 'y_test'.

  Returns:
    Tuple (data_train, data_test) of tf.data.Dataset objects, each built with
    `tf.data.Dataset.from_tensor_slices` from the corresponding (x, y) pair.
  """
  to_tf_dataset = tf.data.Dataset.from_tensor_slices

  train_arrays = (dataset['x_train'], dataset['y_train'])
  train_ds = to_tf_dataset(train_arrays)

  test_arrays = (dataset['x_test'], dataset['y_test'])
  test_ds = to_tf_dataset(test_arrays)

  return train_ds, test_ds
152
153
154
155
156