# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data loading and processing.

Contains data loaders and augmentation functions for MNIST,
CIFAR-10, CIFAR-100 and Tiny-ImageNet.
Also contains the dataset interface used for prototypical learning.
"""
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.keras.datasets import cifar10
from tensorflow.compat.v1.keras.datasets import cifar100
from tensorflow.compat.v1.keras.datasets import mnist
from dble import tiny_imagenet


def load_mnist():
  """Loads MNIST; images are converted to float32 and scaled to [0, 1]."""
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  # A single float32 conversion suffices; the original double conversion
  # (float32, then float64) was redundant.
  x_train = x_train.astype('float32') / 255.
  x_test = x_test.astype('float32') / 255.
  fields = x_train, np.squeeze(y_train)
  fields_test = x_test, np.squeeze(y_test)

  return fields, fields_test


def load_cifar10():
  """Loads CIFAR-10; pixel values stay in [0, 255] as float32."""
  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  fields = x_train, np.squeeze(y_train)
  fields_test = x_test, np.squeeze(y_test)

  return fields, fields_test
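
# For reference (from the Keras dataset docs): cifar10.load_data() yields
# x_train of shape (50000, 32, 32, 3) and, after np.squeeze, labels of shape
# (50000,); the test split is (10000, 32, 32, 3) and (10000,). CIFAR-100 has
# the same shapes with 100 classes.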


def load_cifar100():
  """Loads CIFAR-100; pixel values stay in [0, 255] as float32."""
  (x_train, y_train), (x_test, y_test) = cifar100.load_data()
  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  fields = x_train, np.squeeze(y_train)
  fields_test = x_test, np.squeeze(y_test)

  return fields, fields_test


def load_tiny_imagenet(data_dir, val_data_dir):
  """Loads the training and validation data of Tiny-ImageNet.

  Args:
    data_dir: The directory of the raw training data of Tiny-ImageNet.
    val_data_dir: The directory of the raw validation data of Tiny-ImageNet.

  Returns:
    fields: The tuple (x_train, y_train). x_train is a numpy array with
      shape (num_samples, height, width, num_channels); y_train is a numpy
      array with shape (num_samples,).
    fields_test: The tuple (x_test, y_test), with the same structure.
  """
  (x_train, y_train, _), annotations = tiny_imagenet.load_training_images(
      data_dir)
  (x_test, y_test, _) = tiny_imagenet.load_validation_images(
      val_data_dir, annotations)
  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  # Resample with replacement to exactly 100,000 training and 10,000 test
  # samples, then sort by label. (The original code hardcoded the available
  # index ranges as 98,179 train and 9,832 test images.)
  indx = np.random.choice(len(x_train), size=100000, replace=True)
  x_train = x_train[indx]
  y_train = np.squeeze(y_train)[indx]
  indx_2 = np.argsort(y_train)
  fields = x_train[indx_2], y_train[indx_2]
  indx_test = np.random.choice(len(x_test), size=10000, replace=True)
  y_test = np.squeeze(y_test)[indx_test]
  x_test = x_test[indx_test]
  indx_test2 = np.argsort(y_test)
  fields_test = x_test[indx_test2], y_test[indx_test2]

  return fields, fields_test


def augment_cifar(batch_data, is_training=False):
  """Standard CIFAR augmentation: pad-and-crop plus random horizontal flip."""
  image = batch_data
  if is_training:
    # Pad the batch to 40x40, then take a random 32x32 crop. The batch
    # dimension must be statically known for tf.random_crop.
    image = tf.image.resize_image_with_crop_or_pad(batch_data, 32 + 8, 32 + 8)
    batch_size = image.get_shape().as_list()[0]
    image = tf.random_crop(image, [batch_size, 32, 32, 3])
    image = tf.image.random_flip_left_right(image)
  image = tf.image.per_image_standardization(image)

  return image


def augment_tinyimagenet(batch_data, is_training=False):
  """Tiny-ImageNet augmentation: random horizontal flip only."""
  image = batch_data
  if is_training:
    image = tf.image.random_flip_left_right(image)
  image = tf.image.per_image_standardization(image)

  return image
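
# A minimal usage sketch (not part of the original file): the augmentation
# functions above build graph-mode ops, so in TF1 style they are applied to a
# batch tensor with a static batch dimension and evaluated inside a session.
# The placeholder shape below is an assumption for illustration.
#
#   images = tf.placeholder(tf.float32, [128, 32, 32, 3])
#   train_batch = augment_cifar(images, is_training=True)
#   with tf.Session() as sess:
#     out = sess.run(train_batch, feed_dict={images: x_batch})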


def get_image_size(dataset_name):
  """Returns the image side length; assumes 64 (Tiny-ImageNet) if not CIFAR."""
  if dataset_name == 'cifar10' or dataset_name == 'cifar100':
    image_size = 32
  else:
    image_size = 64
  return image_size


def uniform(n):
  """Returns a sampler that draws indices uniformly from range(n)."""

  def sampler(n_samples, rng=np.random):
    return rng.choice(n, n_samples)

  return sampler
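
# Example: uniform(5) returns a sampler; sampler(3) draws three indices
# uniformly with replacement from {0, ..., 4}, e.g. np.array([2, 0, 4]).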


class Dataset(object):
  """Basic dataset interface for prototypical learning."""

  def __init__(self, fields):
    """Stores a tuple of fields, accessed through the next_batch interface.

    Args:
      fields: fields[0] and fields[1] are considered to be x and y.
    """
    self.n_samples = len(fields[0])
    self.fields = fields
    self.sampler = uniform(self.n_samples)

  @property
  def x(self):
    return self.fields[0]

  @property
  def y(self):
    return self.fields[1]

  def next_batch(self, n, rng=np.random):
    idx = self.sampler(n, rng)
    return tuple(field[idx] for field in self.fields)

  def get_few_shot_idxs(self, classes, num_supports):
    """Samples supports and queries for given classes, returning their indices.

    Args:
      classes: The list of indices of the classes in the episode.
      num_supports: A scalar describing the number of support samples needed
        for every class.

    Returns:
      np.array(support_idxs): Indices of the sampled supports.
      np.array(query_idxs): Indices of the sampled queries.
    """
    support_idxs, query_idxs = [], []
    idxs = np.arange(len(self.y))
    for cl in classes:
      class_idxs = idxs[self.y == cl]
      class_idxs_support = np.random.choice(
          class_idxs, size=num_supports, replace=False)
      # The queries of a class are all of its samples not used as supports.
      class_idxs_query = np.setxor1d(class_idxs, class_idxs_support)

      support_idxs.extend(class_idxs_support)
      query_idxs.extend(class_idxs_query)

    return np.array(support_idxs), np.array(query_idxs)

  def next_few_shot_batch(self, query_batch_size_per_task, num_classes_per_task,
                          num_supports_per_class, num_tasks):
    """Samples a few-shot batch for prototypical training.

    Args:
      query_batch_size_per_task: The number of queries required for every
        task (episode).
      num_classes_per_task: The number of classes for every task.
      num_supports_per_class: The number of support samples required for every
        class.
      num_tasks: The number of tasks in the batch.

    Returns:
      np.concatenate(query_images, axis=0): numpy array of query images with
        shape (query_batch_size_per_task*num_tasks, height, width,
        num_channels).
      np.concatenate(query_labels, axis=0): numpy array of query labels with
        shape (query_batch_size_per_task*num_tasks,).
      np.concatenate(support_images, axis=0): numpy array of support images
        with shape (num_classes_per_task*num_supports_per_class*num_tasks,
        height, width, num_channels).
      np.concatenate(support_labels, axis=0): numpy array of support labels
        with shape (num_classes_per_task*num_supports_per_class*num_tasks,).
    """
    labels = self.y
    classes = np.unique(labels)

    query_images = []
    query_labels = []
    support_images = []
    support_labels = []
    for _ in range(num_tasks):
      task_classes = np.random.choice(
          classes, size=num_classes_per_task, replace=False)

      support_idxs, query_idxs = self.get_few_shot_idxs(
          classes=task_classes, num_supports=num_supports_per_class)
      query_idxs = np.random.choice(
          query_idxs, size=query_batch_size_per_task, replace=False)

      labels_query = labels[query_idxs]
      labels_support = labels[support_idxs]

      # Remap the original class ids to 0..num_classes_per_task-1 within the
      # episode.
      class_map = {c: i for i, c in enumerate(task_classes)}
      # pylint: disable=cell-var-from-loop
      class_map_fn = np.vectorize(lambda t: class_map[t])

      query_images.append(self.x[query_idxs])
      query_labels.append(class_map_fn(labels_query))
      support_images.append(self.x[support_idxs])
      support_labels.append(class_map_fn(labels_support))

    return (np.concatenate(query_images, axis=0),
            np.concatenate(query_labels, axis=0),
            np.concatenate(support_images, axis=0),
            np.concatenate(support_labels, axis=0))
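

# A minimal end-to-end sketch (an addition, not part of the original file):
# wrap a loaded split in Dataset and draw one episodic batch. The episode
# sizes below are illustrative assumptions.
if __name__ == '__main__':
  train_fields, _ = load_cifar10()
  train_set = Dataset(train_fields)
  # 5-way episodes, 5 supports per class, 15 queries per task, 2 tasks.
  q_x, q_y, s_x, s_y = train_set.next_few_shot_batch(
      query_batch_size_per_task=15,
      num_classes_per_task=5,
      num_supports_per_class=5,
      num_tasks=2)
  print(q_x.shape, q_y.shape, s_x.shape, s_y.shape)
  # Expected: (30, 32, 32, 3) (30,) (50, 32, 32, 3) (50,)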