google-research

Форк
0
209 строк · 6.3 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15

16
"""Sequential CIFAR10 dataset."""
17

18
import math
19
from PIL import Image
20
import torch.nn.functional as F
21
from torchvision import transforms
22
from torchvision.datasets import CIFAR10
23
from gradient_coresets_replay.backbone.resnet18 import resnet18
24
from gradient_coresets_replay.datasets.transforms import DeNormalize
25
from gradient_coresets_replay.datasets.utils.continual_dataset import ContinualDataset
26
from gradient_coresets_replay.datasets.utils.continual_dataset import get_previous_train_loader
27
from gradient_coresets_replay.datasets.utils.continual_dataset import store_masked_loaders
28
from gradient_coresets_replay.datasets.utils.validation import get_train_val
29
from gradient_coresets_replay.utils.conf import base_path
30

31

32
class MyCIFAR10(CIFAR10):
  """CIFAR10 whose items also carry an un-augmented tensor of the image.

  `__getitem__` returns `(img, target, not_aug_img)`, where `img` has been
  passed through `transform` (if any) and `not_aug_img` is the raw image
  converted with `ToTensor` only. When a `logits` attribute has been attached
  to the instance, `self.logits[index]` is appended as a fourth element.
  """

  def __init__(
      self,
      root,
      train=True,
      transform=None,
      target_transform=None,
      download=False,
  ):
    """Create the dataset; arguments are forwarded unchanged to CIFAR10.

    Args:
      root: directory holding (or receiving) the CIFAR10 files.
      train: load the training split when True, the test split otherwise.
      transform: optional transform applied to the PIL image.
      target_transform: optional transform applied to the label.
      download: download the data if it is not already present.
    """
    # Transform used for the un-augmented copy returned by __getitem__.
    self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
    super().__init__(
        root, train, transform, target_transform, download
    )

  def __getitem__(self, index):
    """Return the sample at `index` plus its un-augmented tensor.

    Args:
      index: position of the sample in `self.data` / `self.targets`.

    Returns:
      `(img, target, not_aug_img)`, or
      `(img, target, not_aug_img, self.logits[index])` when the instance has
      a `logits` attribute.
    """
    img, target = self.data[index], self.targets[index]

    # Let Pillow infer the mode from the array (an HxWx3 uint8 array maps to
    # 'RGB'), as torchvision's own CIFAR10 does; the explicit `mode` argument
    # is deprecated in recent Pillow releases.
    img = Image.fromarray(img)

    # Built before `transform` runs and ToTensor yields a new tensor, so no
    # defensive copy of the PIL image is required.
    not_aug_img = self.not_aug_transform(img)

    if self.transform is not None:
      img = self.transform(img)

    if self.target_transform is not None:
      target = self.target_transform(target)

    if hasattr(self, 'logits'):
      return img, target, not_aug_img, self.logits[index]

    return img, target, not_aug_img
67

68

69
class SequentialCIFAR10(ContinualDataset):
  """Sequential CIFAR10 dataset.

  CIFAR10 served as `n_tasks` tasks of `n_classes_per_task` classes each
  (`setting = 'class-il'`); the per-task loaders are produced by
  `store_masked_loaders`.
  """

  name = 'seq-cifar10'
  setting = 'class-il'
  n_classes_per_task = 2
  n_tasks = 5

  # Per-channel mean / std shared by every (de)normalization in this class.
  # Keeping a single copy stops the training, evaluation, Barlow and
  # denormalization pipelines from silently drifting apart.
  _NORM_MEAN = (0.4914, 0.4822, 0.4465)
  _NORM_STD = (0.2470, 0.2435, 0.2615)

  # Training-time augmentation applied to PIL images.
  transform = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(_NORM_MEAN, _NORM_STD),
  ])

  def __init__(self, args):
    """Load the train / evaluation datasets and prepare streaming state.

    Args:
      args: experiment configuration; this class reads `args.validation`,
        `args.streaming` and, when streaming, `args.stream_batch_size`.
    """
    super().__init__(args)
    self.train_dataset, self.test_dataset = self._build_datasets()

    if self.args.streaming:
      self.current_pos = 0
      # `stream_indices` is not defined here (presumably inherited from
      # ContinualDataset — confirm).
      self.stream_train_indices = self.stream_indices()
      self.num_streams = math.ceil(
          len(self.stream_train_indices) / self.args.stream_batch_size
      )
      # Called only for its side effect of storing the test dataloaders of
      # every task; the returned loaders are discarded.
      # NOTE(review): each call reloads both CIFAR10 splits from disk.
      for _ in range(self.n_tasks):
        self.get_data_loaders()

  @staticmethod
  def _data_root():
    """Return the directory CIFAR10 is downloaded to and read from."""
    return base_path() + 'CIFAR10'

  def _test_transform(self):
    """Return the evaluation transform: ToTensor followed by normalization."""
    return transforms.Compose(
        [transforms.ToTensor(), self.get_normalization_transform()]
    )

  def _build_datasets(self):
    """Build the training dataset and the matching evaluation dataset.

    Returns:
      `(train_dataset, test_dataset)`. With `args.validation` set, both are
      produced by `get_train_val` from the training split; otherwise the
      evaluation dataset is the CIFAR10 split loaded with `train=False`.
    """
    test_transform = self._test_transform()
    train_dataset = MyCIFAR10(
        self._data_root(), train=True, download=True, transform=self.transform
    )
    if self.args.validation:
      return get_train_val(train_dataset, test_transform, self.name)
    test_dataset = CIFAR10(
        self._data_root(),
        train=False,
        download=True,
        transform=test_transform,
    )
    return train_dataset, test_dataset

  def get_data_loaders(self):
    """Build fresh datasets and return the loaders from store_masked_loaders.

    Returns:
      `(train_loader, test_loader)` as produced by `store_masked_loaders`.
    """
    train_dataset, test_dataset = self._build_datasets()
    train, test = store_masked_loaders(train_dataset, test_dataset, self)
    return train, test

  def not_aug_dataloader(self, batch_size):
    """Return a training-split loader without data augmentation.

    Images only go through ToTensor and normalization; the loader itself is
    built by `get_previous_train_loader`.

    Args:
      batch_size: batch size forwarded to `get_previous_train_loader`.
    """
    train_dataset = MyCIFAR10(
        self._data_root(),
        train=True,
        download=True,
        transform=self._test_transform(),
    )
    return get_previous_train_loader(train_dataset, batch_size, self)

  @staticmethod
  def get_transform():
    """Return the training augmentation preceded by ToPILImage."""
    return transforms.Compose(
        [transforms.ToPILImage(), SequentialCIFAR10.transform]
    )

  @staticmethod
  def get_barlow_transform():
    """Return two separately-built, identically-configured augmentations.

    Presumably used for two-view (Barlow-Twins-style) training. Each
    pipeline starts with ToPILImage, so inputs are expected to be tensors or
    arrays, and ends with the dataset normalization.
    """

    def make_view():
      return transforms.Compose([
          transforms.ToPILImage(),
          transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
          transforms.RandomHorizontalFlip(),
          transforms.RandomApply(
              [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
          ),
          transforms.RandomGrayscale(p=0.2),
          transforms.ToTensor(),
          transforms.Normalize(
              SequentialCIFAR10._NORM_MEAN, SequentialCIFAR10._NORM_STD
          ),
      ])

    return (make_view(), make_view())

  @staticmethod
  def get_backbone():
    """Return a ResNet-18 with one output per class across all tasks."""
    return resnet18(
        SequentialCIFAR10.n_classes_per_task * SequentialCIFAR10.n_tasks
    )

  @staticmethod
  def get_loss():
    """Return the training loss function (cross-entropy)."""
    return F.cross_entropy

  @staticmethod
  def get_normalization_transform():
    """Return the Normalize transform with the CIFAR10 mean / std."""
    return transforms.Normalize(
        SequentialCIFAR10._NORM_MEAN, SequentialCIFAR10._NORM_STD
    )

  @staticmethod
  def get_denormalization_transform():
    """Return the DeNormalize transform with the CIFAR10 mean / std."""
    return DeNormalize(
        SequentialCIFAR10._NORM_MEAN, SequentialCIFAR10._NORM_STD
    )
210

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.