google-research
290 lines · 8.6 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sequential TinyImageNet dataset."""

import math
import os

from google_drive_downloader import GoogleDriveDownloader as gdd
import numpy as np
from PIL import Image
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms

from gradient_coresets_replay.backbone.resnet18 import resnet18
from gradient_coresets_replay.datasets.transforms import DeNormalize
from gradient_coresets_replay.datasets.utils.continual_dataset import ContinualDataset
from gradient_coresets_replay.datasets.utils.continual_dataset import get_previous_train_loader
from gradient_coresets_replay.datasets.utils.continual_dataset import store_masked_loaders
from gradient_coresets_replay.datasets.utils.validation import get_train_val
from gradient_coresets_replay.utils.conf import base_path
34
class TinyImagenet(Dataset):
  """Tiny ImageNet, mirroring the interface of the other pytorch datasets.

  Each split is stored on disk as 20 pre-processed ``.npy`` chunks
  (``processed/x_<split>_NN.npy`` / ``processed/y_<split>_NN.npy``); they are
  loaded eagerly and concatenated into single arrays at construction time.
  """

  def __init__(
      self,
      root,
      train=True,
      transform=None,
      target_transform=None,
      download=False,
  ):
    """Initializes the dataset, optionally downloading it first.

    Args:
      root: directory holding (or receiving) the processed dataset.
      train: load the training split when True, the validation split otherwise.
      transform: optional transform applied to each PIL image.
      target_transform: optional transform applied to each label.
      download: when True and `root` is empty/missing, fetch and unzip the
        processed archive from Google Drive.
    """
    self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
    self.root = root
    self.train = train
    self.transform = transform
    self.target_transform = target_transform
    self.download = download

    if download:
      if os.path.isdir(root) and len(os.listdir(root)):
        print('Download not needed, files already on disk.')
      else:
        # https://drive.google.com/file/d/1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj/view
        print('Downloading dataset')
        gdd.download_file_from_google_drive(
            file_id='1Sy3ScMBr0F4se8VZ6TAwDYF-nNGAAdxj',
            dest_path=os.path.join(root, 'tiny-imagenet-processed.zip'),
            unzip=True,
        )

    self.data = self._load_chunks('x')
    self.targets = self._load_chunks('y')

  def _load_chunks(self, prefix):
    """Loads and concatenates the 20 ``.npy`` chunks of the current split.

    Args:
      prefix: 'x' for images, 'y' for labels.

    Returns:
      A single np.ndarray with all chunks concatenated along axis 0.
    """
    split = 'train' if self.train else 'val'
    chunks = [
        np.load(
            os.path.join(
                self.root,
                'processed/%s_%s_%02d.npy' % (prefix, split, num + 1),
            )
        )
        for num in range(20)
    ]
    # Concatenate the list directly: the previous np.concatenate(np.array(...))
    # made an extra full copy and fails on unequal chunk lengths with modern
    # NumPy (ragged np.array() construction raises).
    return np.concatenate(chunks)

  def __len__(self):
    """Returns the number of samples in the loaded split."""
    return len(self.data)

  def __getitem__(self, index):
    """Returns (img, target) — plus (original_img, logits) when logits exist.

    Args:
      index: integer sample index.
    """
    img, target = self.data[index], self.targets[index]

    # Stored images appear to be floats in [0, 1] (hence the 255 rescale);
    # convert back to a uint8 PIL image so this is consistent with all other
    # datasets.
    img = Image.fromarray(np.uint8(255 * img))
    original_img = img.copy()

    if self.transform is not None:
      img = self.transform(img)

    if self.target_transform is not None:
      target = self.target_transform(target)

    # `logits` is not set here; it may be attached externally (e.g. for
    # distillation-style replay) — presence changes the returned tuple.
    if hasattr(self, 'logits'):
      return img, target, original_img, self.logits[index]

    return img, target
113
class MyTinyImagenet(TinyImagenet):
  """Tiny ImageNet variant that also returns the un-augmented image tensor."""

  def __init__(
      self,
      root,
      train=True,
      transform=None,
      target_transform=None,
      download=False,
  ):
    """Delegates all setup to TinyImagenet.

    The parent constructor already stores `root` (and every other argument),
    so the redundant `self.root = root` assignment was removed.
    """
    super().__init__(root, train, transform, target_transform, download)

  def __getitem__(self, index):
    """Returns (img, target, not_aug_img) — plus logits when they exist.

    Args:
      index: integer sample index.
    """
    img, target = self.data[index], self.targets[index]

    # Convert back to a uint8 PIL image so this is consistent with all other
    # datasets (data appears to be stored as floats in [0, 1]).
    img = Image.fromarray(np.uint8(255 * img))
    original_img = img.copy()

    # Tensor version of the raw image, without the training-time augmentations.
    not_aug_img = self.not_aug_transform(original_img)

    if self.transform is not None:
      img = self.transform(img)

    if self.target_transform is not None:
      target = self.target_transform(target)

    # `logits` may be attached externally (e.g. for distillation-style replay).
    if hasattr(self, 'logits'):
      return img, target, not_aug_img, self.logits[index]

    return img, target, not_aug_img
151
class SequentialTinyImagenet(ContinualDataset):
  """Sequential TinyImageNet dataset: 10 tasks of 20 classes each."""

  name = 'seq-tinyimg'
  setting = 'class-il'
  n_classes_per_task = 20
  n_tasks = 10
  # Normalization constants below are TinyImageNet channel means / stds.
  transform = transforms.Compose([
      transforms.RandomCrop(64, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize((0.4802, 0.4480, 0.3975), (0.2770, 0.2691, 0.2821)),
  ])

  def __init__(self, args):
    """Builds train/test datasets and, if streaming, the stream schedule.

    Args:
      args: experiment arguments; reads `validation`, `streaming` and
        `stream_batch_size` here.
    """
    super().__init__(args)
    self.train_dataset = MyTinyImagenet(
        base_path() + 'TINYIMG',
        train=True,
        download=True,
        transform=self.transform,
    )
    test_transform = transforms.Compose(
        [transforms.ToTensor(), self.get_normalization_transform()]
    )
    if self.args.validation:
      self.train_dataset, self.test_dataset = get_train_val(
          self.train_dataset, test_transform, self.name
      )
    else:
      self.test_dataset = MyTinyImagenet(
          base_path() + 'TINYIMG',
          train=False,
          download=True,
          transform=test_transform,
      )

    if self.args.streaming:
      self.current_pos = 0
      # NOTE(review): stream_indices() is presumably provided by
      # ContinualDataset — confirm against the parent class.
      self.stream_train_indices = self.stream_indices()
      self.num_streams = math.ceil(
          len(self.stream_train_indices) / self.args.stream_batch_size
      )
      for _ in range(self.n_tasks):
        _ = self.get_data_loaders()  # to store self.test_dataloaders

  def get_data_loaders(self):
    """Builds the masked train/test loaders for the next task.

    Returns:
      A (train_loader, test_loader) pair produced by store_masked_loaders.
    """
    transform = self.transform

    test_transform = transforms.Compose(
        [transforms.ToTensor(), self.get_normalization_transform()]
    )

    train_dataset = MyTinyImagenet(
        base_path() + 'TINYIMG', train=True, download=True, transform=transform
    )
    if self.args.validation:
      train_dataset, test_dataset = get_train_val(
          train_dataset, test_transform, self.name
      )
    else:
      test_dataset = TinyImagenet(
          base_path() + 'TINYIMG',
          train=False,
          download=True,
          transform=test_transform,
      )

    train, test = store_masked_loaders(train_dataset, test_dataset, self)
    return train, test

  def not_aug_dataloader(self, batch_size):
    """Returns a loader over previous-task data with de-normalized images.

    Args:
      batch_size: batch size of the returned loader.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), self.get_denormalization_transform()]
    )

    train_dataset = MyTinyImagenet(
        base_path() + 'TINYIMG', train=True, download=True, transform=transform
    )
    train_loader = get_previous_train_loader(train_dataset, batch_size, self)

    return train_loader

  @staticmethod
  def get_backbone():
    """Returns a ResNet-18 sized for all 200 (= 20 x 10) classes."""
    return resnet18(
        SequentialTinyImagenet.n_classes_per_task
        * SequentialTinyImagenet.n_tasks
    )

  @staticmethod
  def get_loss():
    """Returns the classification loss function."""
    return F.cross_entropy

  def get_transform(self):
    """Returns the train augmentation, prefixed with ToPILImage for tensors."""
    transform = transforms.Compose([transforms.ToPILImage(), self.transform])
    return transform

  @staticmethod
  def _make_barlow_transform():
    """Builds one branch of the Barlow-Twins-style augmentation pipeline."""
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(size=64, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply(
            [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.4802, 0.4480, 0.3975), (0.2770, 0.2691, 0.2821)
        ),
    ])

  @staticmethod
  def get_barlow_transform():
    """Returns the (transform, transform_prime) augmentation pair.

    Both branches use identical augmentations; the duplicated inline
    definitions were folded into _make_barlow_transform, which is called
    twice so callers still receive two distinct transform objects.
    """
    return (
        SequentialTinyImagenet._make_barlow_transform(),
        SequentialTinyImagenet._make_barlow_transform(),
    )

  @staticmethod
  def get_normalization_transform():
    """Returns the Normalize transform with TinyImageNet statistics."""
    transform = transforms.Normalize(
        (0.4802, 0.4480, 0.3975), (0.2770, 0.2691, 0.2821)
    )
    return transform

  @staticmethod
  def get_denormalization_transform():
    """Returns the inverse of the normalization transform."""
    transform = DeNormalize((0.4802, 0.4480, 0.3975), (0.2770, 0.2691, 0.2821))
    return transform