# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sequential TinyImageNet dataset."""

import math
import os

from google_drive_downloader import GoogleDriveDownloader as gdd
import numpy as np
from PIL import Image
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms

from gradient_coresets_replay.backbone.resnet18 import resnet18
from gradient_coresets_replay.datasets.transforms import DeNormalize
from gradient_coresets_replay.datasets.utils.continual_dataset import ContinualDataset
from gradient_coresets_replay.datasets.utils.continual_dataset import get_previous_train_loader
from gradient_coresets_replay.datasets.utils.continual_dataset import store_masked_loaders
from gradient_coresets_replay.datasets.utils.validation import get_train_val
from gradient_coresets_replay.utils.conf import base_path


class TinyImagenet(Dataset):
  """Tiny ImageNet exposed like the other pytorch datasets.

  Loads the pre-processed numpy shards (20 files per split) from `root`,
  optionally downloading and unzipping them from Google Drive first.
  """

  def __init__(
      self,
      root,
      train = True,
      transform = None,
      target_transform = None,
      download = False,
  ):
    """Initializes the dataset.

    Args:
      root: directory holding (or to receive) the `processed/` numpy shards.
      train: if True loads the train split, otherwise the val split.
      transform: optional transform applied to each PIL image.
      target_transform: optional transform applied to each label.
      download: if True, fetch the archive from Google Drive when `root`
        is missing or empty.
    """
    self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
    self.root = root
    self.train = train
    self.transform = transform
    self.target_transform = target_transform
    self.download = download

    if download:
      if os.path.isdir(root) and len(os.listdir(root)):
        print('Download not needed, files already on disk.')
      else:
        # https://drive.google.com/file/d/1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj/view
        print('Downloading dataset')
        gdd.download_file_from_google_drive(
            file_id='1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj',
            dest_path=os.path.join(root, 'tiny-imagenet-processed.zip'),
            unzip=True,
        )

    split = 'train' if self.train else 'val'
    # Each split is stored as 20 numpy shards; concatenate them directly.
    # (The previous code wrapped the shard list in np.array() before
    # concatenating, which built a redundant intermediate ndarray.)
    self.data = np.concatenate([
        np.load(
            os.path.join(root, 'processed/x_%s_%02d.npy' % (split, num + 1))
        )
        for num in range(20)
    ])
    self.targets = np.concatenate([
        np.load(
            os.path.join(root, 'processed/y_%s_%02d.npy' % (split, num + 1))
        )
        for num in range(20)
    ])

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index):
    """Returns (img, target) — plus the raw PIL image and logits if set."""
    img, target = self.data[index], self.targets[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image; stored arrays are floats in [0, 1].
    img = Image.fromarray(np.uint8(255 * img))
    original_img = img.copy()

    if self.transform is not None:
      img = self.transform(img)

    if self.target_transform is not None:
      target = self.target_transform(target)

    # NOTE(review): unlike MyTinyImagenet, this branch returns the raw PIL
    # image (not a tensor) as the third element — preserved as-is.
    if hasattr(self, 'logits'):
      return img, target, original_img, self.logits[index]

    return img, target


class MyTinyImagenet(TinyImagenet):
  """Tiny ImageNet variant that also returns the un-augmented image tensor.

  `__getitem__` additionally yields the original image passed through
  `not_aug_transform` (ToTensor only), as rehearsal buffers require.
  """

  def __init__(
      self,
      root,
      train = True,
      transform = None,
      target_transform = None,
      download = False,
  ):
    # `self.root` is assigned by the parent constructor; the duplicate
    # assignment that used to precede this call was redundant.
    super().__init__(root, train, transform, target_transform, download)

  def __getitem__(self, index):
    """Returns (img, target, not_aug_img) — plus logits when present."""
    img, target = self.data[index], self.targets[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image; stored arrays are floats in [0, 1].
    img = Image.fromarray(np.uint8(255 * img))
    original_img = img.copy()

    # Tensor view of the image before any augmentation is applied.
    not_aug_img = self.not_aug_transform(original_img)

    if self.transform is not None:
      img = self.transform(img)

    if self.target_transform is not None:
      target = self.target_transform(target)

    if hasattr(self, 'logits'):
      return img, target, not_aug_img, self.logits[index]

    return img, target, not_aug_img


class SequentialTinyImagenet(ContinualDataset):
  """Sequential TinyImageNet dataset.

  Splits the 200 TinyImageNet classes into 10 tasks of 20 classes each
  (class-incremental setting).
  """

  name = 'seq-tinyimg'
  setting = 'class-il'
  n_classes_per_task = 20
  n_tasks = 10
  # Train-time augmentation; the Normalize stats are the TinyImageNet
  # per-channel mean/std used consistently throughout this class.
  transform = transforms.Compose([
      transforms.RandomCrop(64, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize((0.4802, 0.4480, 0.3975), (0.2770, 0.2691, 0.2821)),
  ])

  def __init__(self, args):
    """Builds train/test datasets and, if requested, streaming state.

    Args:
      args: experiment arguments; `validation`, `streaming` and
        `stream_batch_size` are read here.
    """
    super().__init__(args)
    self.train_dataset = MyTinyImagenet(
        base_path() + 'TINYIMG',
        train=True,
        download=True,
        transform=self.transform,
    )
    test_transform = self._test_transform()
    if self.args.validation:
      self.train_dataset, self.test_dataset = get_train_val(
          self.train_dataset, test_transform, self.name
      )
    else:
      self.test_dataset = MyTinyImagenet(
          base_path() + 'TINYIMG',
          train=False,
          download=True,
          transform=test_transform,
      )

    if self.args.streaming:
      self.current_pos = 0
      self.stream_train_indices = self.stream_indices()
      self.num_streams = math.ceil(
          len(self.stream_train_indices) / self.args.stream_batch_size
      )
      for _ in range(self.n_tasks):
        _ = self.get_data_loaders()  # to store self.test_dataloaders

  def _test_transform(self):
    """Returns the deterministic (augmentation-free) eval transform."""
    return transforms.Compose(
        [transforms.ToTensor(), self.get_normalization_transform()]
    )

  def get_data_loaders(self):
    """Creates and stores the masked train/test loaders for the next task."""
    transform = self.transform
    test_transform = self._test_transform()

    train_dataset = MyTinyImagenet(
        base_path() + 'TINYIMG', train=True, download=True, transform=transform
    )
    if self.args.validation:
      train_dataset, test_dataset = get_train_val(
          train_dataset, test_transform, self.name
      )
    else:
      # NOTE(review): __init__ uses MyTinyImagenet for the test split, while
      # this method uses TinyImagenet (no not-aug tensor) — confirm this
      # asymmetry is intentional before unifying.
      test_dataset = TinyImagenet(
          base_path() + 'TINYIMG',
          train=False,
          download=True,
          transform=test_transform,
      )

    train, test = store_masked_loaders(train_dataset, test_dataset, self)
    return train, test

  def not_aug_dataloader(self, batch_size):
    """Returns a loader over previously seen data with de-normalized images."""
    transform = transforms.Compose(
        [transforms.ToTensor(), self.get_denormalization_transform()]
    )

    train_dataset = MyTinyImagenet(
        base_path() + 'TINYIMG', train=True, download=True, transform=transform
    )
    return get_previous_train_loader(train_dataset, batch_size, self)

  @staticmethod
  def get_backbone():
    """ResNet-18 with one output unit per class across all tasks (200)."""
    return resnet18(
        SequentialTinyImagenet.n_classes_per_task
        * SequentialTinyImagenet.n_tasks
    )

  @staticmethod
  def get_loss():
    return F.cross_entropy

  def get_transform(self):
    """Train augmentation applied to already-tensorized (buffer) images."""
    return transforms.Compose([transforms.ToPILImage(), self.transform])

  @staticmethod
  def _barlow_augmentation():
    """Builds one Barlow-Twins-style augmentation pipeline."""
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(size=64, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply(
            [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.4802, 0.4480, 0.3975), (0.2770, 0.2691, 0.2821)
        ),
    ])

  @staticmethod
  def get_barlow_transform():
    """Returns the two augmentation views for Barlow Twins.

    The two pipelines were previously duplicated verbatim; each view is now
    built from the same helper (still two distinct Compose objects).
    """
    return (
        SequentialTinyImagenet._barlow_augmentation(),
        SequentialTinyImagenet._barlow_augmentation(),
    )

  @staticmethod
  def get_normalization_transform():
    return transforms.Normalize(
        (0.4802, 0.4480, 0.3975), (0.2770, 0.2691, 0.2821)
    )

  @staticmethod
  def get_denormalization_transform():
    return DeNormalize((0.4802, 0.4480, 0.3975), (0.2770, 0.2691, 0.2821))