google-research
209 lines · 6.3 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Sequential CIFAR10 dataset."""
17
18import math
19from PIL import Image
20import torch.nn.functional as F
21from torchvision import transforms
22from torchvision.datasets import CIFAR10
23from gradient_coresets_replay.backbone.resnet18 import resnet18
24from gradient_coresets_replay.datasets.transforms import DeNormalize
25from gradient_coresets_replay.datasets.utils.continual_dataset import ContinualDataset
26from gradient_coresets_replay.datasets.utils.continual_dataset import get_previous_train_loader
27from gradient_coresets_replay.datasets.utils.continual_dataset import store_masked_loaders
28from gradient_coresets_replay.datasets.utils.validation import get_train_val
29from gradient_coresets_replay.utils.conf import base_path
30
31
class MyCIFAR10(CIFAR10):
  """CIFAR10 whose __getitem__ also yields the un-augmented image tensor.

  Replay-based continual learners need the raw sample (tensor-converted but
  not augmented) to store in the buffer; the standard CIFAR10 __getitem__
  only returns the augmented view.
  """

  def __init__(
      self,
      root,
      train=True,
      transform=None,
      target_transform=None,
      download=False,
  ):
    # Tensor conversion only — applied to the raw image, never augmentation.
    self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
    super().__init__(root, train, transform, target_transform, download)

  def __getitem__(self, index):
    """Returns (aug_img, target, not_aug_img[, logits]) for `index`."""
    raw, label = self.data[index], self.targets[index]

    # Raw numpy array -> PIL so torchvision transforms can consume it.
    pil_img = Image.fromarray(raw, mode='RGB')
    not_aug_img = self.not_aug_transform(pil_img.copy())

    aug_img = pil_img if self.transform is None else self.transform(pil_img)
    if self.target_transform is not None:
      label = self.target_transform(label)

    # `logits` is attached externally (e.g. for distillation-style replay).
    if hasattr(self, 'logits'):
      return aug_img, label, not_aug_img, self.logits[index]
    return aug_img, label, not_aug_img
67
68
class SequentialCIFAR10(ContinualDataset):
  """Sequential CIFAR10: 10 classes split into 5 two-class tasks (class-IL)."""

  name = 'seq-cifar10'
  setting = 'class-il'
  n_classes_per_task = 2
  n_tasks = 5
  # Per-channel CIFAR10 train-set statistics. Defined once so every
  # transform below (train, eval, Barlow) stays in sync — previously these
  # tuples were duplicated four times across the class.
  MEAN = (0.4914, 0.4822, 0.4465)
  STD = (0.2470, 0.2435, 0.2615)
  transform = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(MEAN, STD),
  ])

  def __init__(self, args):
    """Builds the train/test datasets and, if streaming, the stream schedule.

    Args:
      args: experiment arguments; reads `validation`, `streaming` and
        `stream_batch_size`.
    """
    super().__init__(args)
    self.train_dataset = MyCIFAR10(
        base_path() + 'CIFAR10',
        train=True,
        download=True,
        transform=self.transform,
    )
    test_transform = self._test_transform()
    if self.args.validation:
      self.train_dataset, self.test_dataset = get_train_val(
          self.train_dataset, test_transform, self.name
      )
    else:
      self.test_dataset = CIFAR10(
          base_path() + 'CIFAR10',
          train=False,
          download=True,
          transform=test_transform,
      )

    if self.args.streaming:
      self.current_pos = 0
      # NOTE(review): `stream_indices` is provided by ContinualDataset (or a
      # subclass) — not visible in this file.
      self.stream_train_indices = self.stream_indices()
      self.num_streams = math.ceil(
          len(self.stream_train_indices) / self.args.stream_batch_size
      )
      for _ in range(self.n_tasks):
        _ = self.get_data_loaders()  # to store self.test_dataloaders

  @staticmethod
  def _test_transform():
    """Deterministic eval transform: tensor conversion + normalization."""
    return transforms.Compose([
        transforms.ToTensor(),
        SequentialCIFAR10.get_normalization_transform(),
    ])

  def get_data_loaders(self):
    """Creates masked train/test loaders for the next task.

    Returns:
      A (train_loader, test_loader) pair produced by store_masked_loaders.
    """
    test_transform = self._test_transform()

    train_dataset = MyCIFAR10(
        base_path() + 'CIFAR10',
        train=True,
        download=True,
        transform=self.transform,
    )
    if self.args.validation:
      train_dataset, test_dataset = get_train_val(
          train_dataset, test_transform, self.name
      )
    else:
      test_dataset = CIFAR10(
          base_path() + 'CIFAR10',
          train=False,
          download=True,
          transform=test_transform,
      )

    train, test = store_masked_loaders(train_dataset, test_dataset, self)
    return train, test

  def not_aug_dataloader(self, batch_size):
    """Returns a loader over the current task's data without augmentation.

    Args:
      batch_size: batch size of the returned loader.
    """
    # Only tensor conversion + normalization — no random augmentation.
    transform = self._test_transform()

    train_dataset = MyCIFAR10(
        base_path() + 'CIFAR10', train=True, download=True, transform=transform
    )
    train_loader = get_previous_train_loader(train_dataset, batch_size, self)

    return train_loader

  @staticmethod
  def get_transform():
    """Train-time augmentation preceded by PIL conversion (for buffer data)."""
    transform = transforms.Compose(
        [transforms.ToPILImage(), SequentialCIFAR10.transform]
    )
    return transform

  @staticmethod
  def _barlow_augmentation():
    """One Barlow-Twins-style augmentation pipeline for 32x32 images."""
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply(
            [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(SequentialCIFAR10.MEAN, SequentialCIFAR10.STD),
    ])

  @staticmethod
  def get_barlow_transform():
    """Returns the two augmentation views used by Barlow Twins.

    The two pipelines are identical in configuration but are distinct
    Compose objects, matching the original implementation.
    """
    return (
        SequentialCIFAR10._barlow_augmentation(),
        SequentialCIFAR10._barlow_augmentation(),
    )

  @staticmethod
  def get_backbone():
    """ResNet-18 with one output unit per class over all tasks (10 total)."""
    return resnet18(
        SequentialCIFAR10.n_classes_per_task * SequentialCIFAR10.n_tasks
    )

  @staticmethod
  def get_loss():
    """Standard cross-entropy classification loss."""
    return F.cross_entropy

  @staticmethod
  def get_normalization_transform():
    """Normalization using the shared CIFAR10 channel statistics."""
    return transforms.Normalize(SequentialCIFAR10.MEAN, SequentialCIFAR10.STD)

  @staticmethod
  def get_denormalization_transform():
    """Inverse of get_normalization_transform, for visualization."""
    return DeNormalize(SequentialCIFAR10.MEAN, SequentialCIFAR10.STD)
210