# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Dataset Handler."""
17import os
18import torch
19from cascaded_networks.datasets import cifar_handler
20
21
22class DataHandler:
23"""Handler for datasets."""
24
25def __init__(self,
26dataset_name,
27data_root,
28val_split=0.1,
29split_idxs_root='/tmp/split_idxs',
30noise_type=None,
31ood_mode=None,
32load_previous_splits=True):
33"""Initialize dataset handler."""
34self.dataset_name = dataset_name
35self.data_root = data_root
36self.val_split = val_split
37self.noise_type = noise_type
38self._ood_mode = ood_mode
39self.load_previous_splits = load_previous_splits
40
41# Set idx with dataset_name
42split_idxs_root = os.path.join(split_idxs_root, dataset_name)
43if not os.path.exists(split_idxs_root):
44os.makedirs(split_idxs_root)
45
46if split_idxs_root and val_split:
47self.split_idxs_root = self._build_split_idx_root(split_idxs_root,
48dataset_name)
49else:
50self.split_idxs_root = None
51
52self.datasets = self._build_datasets()
53self._set_num_classes(dataset_name)
54
  def _set_num_classes(self, dataset_name):
    """Set number of classes in dataset."""
    if dataset_name == 'CIFAR10':
      self.num_classes = 10
    elif dataset_name == 'CIFAR100':
      self.num_classes = 100
    elif dataset_name == 'TinyImageNet':
      self.num_classes = 200

  def get_transform(self, dataset_key=None):
    """Build dataset transform."""
    if dataset_key is None:
      dataset_key = list(self.datasets.keys())[0]

    normalize_transform = None
    # Grab transforms - location varies depending on base dataset: a wrapped
    # dataset (e.g. a split Subset) keeps its transform one level deeper.
    try:
      transforms = self.datasets[dataset_key].transform.transforms
      found = True
    except AttributeError:
      found = False

    if not found:
      try:
        transforms = self.datasets[dataset_key].dataset.transform.transforms
        found = True
      except AttributeError:
        found = False

    if not found:
      print('Transform list not found!')
    else:
      found = False
      for xform in transforms:
        if 'normalize' in str(xform).lower():
          normalize_transform = xform
          found = True
          break

    if not found:
      print('Normalization transform not found!')
    return normalize_transform

  def _build_split_idx_root(self, split_idxs_root, dataset_name):
    """Build directory for split idxs."""
    if '.json' in split_idxs_root and not os.path.exists(split_idxs_root):
      split_idxs_root = os.path.join(split_idxs_root, dataset_name)
    print(f'Setting split idxs root to {split_idxs_root}')
    if not os.path.exists(split_idxs_root):
      print(f'{split_idxs_root} does not exist!')
      os.makedirs(split_idxs_root)
      print('Complete.')
    return split_idxs_root

  def _build_datasets(self):
    """Build dataset."""
    if 'cifar' in self.dataset_name.lower():
      dataset_dict = cifar_handler.create_datasets(
          self.data_root,
          dataset_name=self.dataset_name,
          val_split=self.val_split,
          split_idxs_root=self.split_idxs_root,
          noise_type=self.noise_type,
          load_previous_splits=self.load_previous_splits)
    elif self.dataset_name.lower() == 'tinyimagenet':
      assert False, 'Not Implemented.'
    return dataset_dict

  def build_loader(self,
                   dataset_key,
                   flags,
                   dont_shuffle_train=False):
    """Build dataset loader."""

    # Get dataset source
    dataset_src = self.datasets[dataset_key]

    # Specify shuffling: only the training split is shuffled, unless disabled.
    if dont_shuffle_train:
      shuffle = False
    else:
      shuffle = dataset_key == 'train'

    # Creates dataloaders, which load data in batches
    loader = torch.utils.data.DataLoader(
        dataset=dataset_src,
        batch_size=flags.batch_size,
        shuffle=shuffle,
        num_workers=flags.num_workers,
        drop_last=flags.drop_last,
        pin_memory=True)

    return loader
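

# Illustrative usage sketch (not part of the original module): a minimal
# example of wiring DataHandler to a training loop, assuming a CIFAR-10
# setup. The `flags` namespace stands in for whatever flags object the
# caller normally passes; only the batch_size, num_workers, and drop_last
# attributes read by build_loader above are required. The data_root path
# and the 'train' dataset key are hypothetical placeholders.
#
#   import types
#
#   flags = types.SimpleNamespace(batch_size=128, num_workers=4,
#                                 drop_last=False)
#   handler = DataHandler('CIFAR10', data_root='/tmp/data', val_split=0.1)
#   train_loader = handler.build_loader('train', flags)
#   normalize = handler.get_transform('train')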