google-research
235 строк · 7.8 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Data loading and processing.
17
18Contains dataloaders and augmentation functions of MNIST,
19CIFAR-10, CIFAR-100 and Tiny-ImageNet.
20Also contains the implementation of dataset interface of prototypical learning.
21"""
22import numpy as np
23import tensorflow.compat.v1 as tf
24from tensorflow.compat.v1.keras.datasets import cifar10
25from tensorflow.compat.v1.keras.datasets import cifar100
26from tensorflow.compat.v1.keras.datasets import mnist
27from dble import tiny_imagenet
28
29
30def load_mnist():
31(x_train, y_train), (x_test, y_test) = mnist.load_data()
32x_train = x_train.astype('float32')
33x_test = x_test.astype('float32')
34x_train = x_train.astype(float) / 255.
35x_test = x_test.astype(float) / 255.
36fields = x_train, np.squeeze(y_train)
37fields_test = x_test, np.squeeze(y_test)
38
39return fields, fields_test
40
41
42def load_cifar10():
43(x_train, y_train), (x_test, y_test) = cifar10.load_data()
44x_train = x_train.astype('float32')
45x_test = x_test.astype('float32')
46fields = x_train, np.squeeze(y_train)
47fields_test = x_test, np.squeeze(y_test)
48
49return fields, fields_test
50
51
52def load_cifar100():
53(x_train, y_train), (x_test, y_test) = cifar100.load_data()
54x_train = x_train.astype('float32')
55x_test = x_test.astype('float32')
56fields = x_train, np.squeeze(y_train)
57fields_test = x_test, np.squeeze(y_test)
58
59return fields, fields_test
60
61
62def load_tiny_imagenet(data_dir, val_data_dir):
63"""Loads the training and validation data of Tiny-ImageNet.
64
65Args:
66data_dir: The directory of raw training data of Tiny-ImageNet.
67val_data_dir: The directory of raw validation data of Tiny-ImageNet.
68Returns:
69fields: the tuple of (x_train, y_train). x_train is a numpy array with
70shape (num_samples, height, width, num_channels). y_train is a numpy array
71with shape (num_samples, ).
72fields_test: the tuple of (x_test, y_test).
73"""
74(x_train, y_train, _), annotations = tiny_imagenet.load_training_images(
75data_dir)
76(x_test, y_test, _) = tiny_imagenet.load_validation_images(
77val_data_dir, annotations)
78x_train = x_train.astype('float32')
79x_test = x_test.astype('float32')
80indx = np.random.choice([i for i in range(98179)], size=100000, replace=True)
81x_train = x_train[indx]
82y_train = np.squeeze(y_train)[indx]
83indx_2 = np.argsort(y_train)
84fields = x_train[indx_2], y_train[indx_2]
85indx_test = np.random.choice([i for i in range(9832)],
86size=10000,
87replace=True)
88y_test = np.squeeze(y_test)[indx_test]
89x_test = x_test[indx_test]
90indx_test2 = np.argsort(y_test)
91fields_test = x_test[indx_test2], y_test[indx_test2]
92
93return fields, fields_test
94
95
96def augment_cifar(batch_data, is_training=False):
97image = batch_data
98if is_training:
99image = tf.image.resize_image_with_crop_or_pad(batch_data, 32 + 8, 32 + 8)
100i = image.get_shape().as_list()[0]
101image = tf.random_crop(image, [i, 32, 32, 3])
102image = tf.image.random_flip_left_right(image)
103image = tf.image.per_image_standardization(image)
104
105return image
106
107
108def augment_tinyimagenet(batch_data, is_training=False):
109image = batch_data
110if is_training:
111image = tf.image.random_flip_left_right(image)
112image = tf.image.per_image_standardization(image)
113
114return image
115
116
117def get_image_size(dataset_name):
118if dataset_name == 'cifar10' or dataset_name == 'cifar100':
119image_size = 32
120else:
121image_size = 64
122return image_size
123
124
125def uniform(n):
126
127def sampler(n_samples, rng=np.random):
128return rng.choice(n, n_samples)
129
130return sampler
131
132
133class Dataset(object):
134"""Basic dataset interface for prototypical learning."""
135
136def __init__(self, fields):
137"""Store a tuple of fields and access it through next_batch interface.
138
139Args:
140fields: field[0] and field[1] are considered to be x and y.
141"""
142self.n_samples = len(fields[0])
143self.fields = fields
144self.sampler = uniform(self.n_samples)
145
146@property
147def x(self):
148return self.fields[0]
149
150@property
151def y(self):
152return self.fields[1]
153
154def next_batch(self, n, rng=np.random):
155idx = self.sampler(n, rng)
156return tuple(field[idx] for field in self.fields)
157
158def get_few_shot_idxs(self, classes, num_supports):
159"""Samples the supports and queries given classes and return their indexs.
160
161Args:
162classes: The list of indexs of classes in the episode.
163num_supports: A scalar describing the number of support samples needed for
164every class.
165Returns:
166np.array(support_idxs): indexs of the supports sampled.
167np.array(query_idxs): indexs of the queries sampled.
168"""
169support_idxs, query_idxs = [], []
170idxs = np.arange(len(self.y))
171for cl in classes:
172class_idxs = idxs[self.y == cl]
173class_idxs_support = np.random.choice(
174class_idxs, size=num_supports, replace=False)
175class_idxs_query = np.setxor1d(class_idxs, class_idxs_support)
176
177support_idxs.extend(class_idxs_support)
178query_idxs.extend(class_idxs_query)
179
180return np.array(support_idxs), np.array(query_idxs)
181
182def next_few_shot_batch(self, query_batch_size_per_task, num_classes_per_task,
183num_supports_per_class, num_tasks):
184"""Samples the few-shot batch for prototypical training.
185
186Args:
187query_batch_size_per_task: The number of queries required
188for every task(episode).
189num_classes_per_task: The number of classes for every task.
190num_supports_per_class: The number of support samples required for every
191class.
192num_tasks: Task number of the batch.
193Returns:
194np.concatenate(query_images, axis=0): numpy array of query images with
195shape (query_batch_size_per_task*num_tasks, height, width, num_channels).
196np.concatenate(query_labels, axis=0): numpy array of query labels with
197shape (query_batch_size_per_task*num_tasks, ).
198np.concatenate(support_images, axis=0): numpy array of support images with
199shape (num_classes_per_task*num_supports_per_class*num_tasks,
200height, width, num_channels).
201np.concatenate(support_labels, axis=0): numpy array of support labels with
202shape (num_classes_per_task*num_supports_per_class*num_tasks, ).
203"""
204labels = self.y
205classes = np.unique(labels)
206
207query_images = []
208query_labels = []
209support_images = []
210support_labels = []
211for _ in range(num_tasks):
212task_classes = np.random.choice(
213classes, size=num_classes_per_task, replace=False)
214
215support_idxs, query_idxs = self.get_few_shot_idxs(
216classes=task_classes, num_supports=num_supports_per_class)
217query_idxs = np.random.choice(
218query_idxs, size=query_batch_size_per_task, replace=False)
219
220labels_query = labels[query_idxs]
221labels_support = labels[support_idxs]
222
223class_map = {c: i for i, c in enumerate(task_classes)}
224# pylint: disable=cell-var-from-loop
225class_map_fn = np.vectorize(lambda t: class_map[t])
226
227query_images.append(self.x[query_idxs])
228query_labels.append(class_map_fn(labels_query))
229support_images.append(self.x[support_idxs])
230support_labels.append(class_map_fn(labels_support))
231
232return np.concatenate(query_images, axis=0), \
233np.concatenate(query_labels, axis=0), \
234np.concatenate(support_images, axis=0), \
235np.concatenate(support_labels, axis=0)
236