# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataset Handler."""
import os
import torch
from cascaded_networks.datasets import cifar_handler


class DataHandler:
  """Handler for datasets."""

  def __init__(self,
               dataset_name,
               data_root,
               val_split=0.1,
               split_idxs_root='/tmp/split_idxs',
               noise_type=None,
               ood_mode=None,
               load_previous_splits=True):
    """Initialize dataset handler."""
    self.dataset_name = dataset_name
    self.data_root = data_root
    self.val_split = val_split
    self.noise_type = noise_type
    self._ood_mode = ood_mode
    self.load_previous_splits = load_previous_splits

    # Nest the split idxs directory under the dataset name.
    split_idxs_root = os.path.join(split_idxs_root, dataset_name)
    os.makedirs(split_idxs_root, exist_ok=True)

    if split_idxs_root and val_split:
      self.split_idxs_root = self._build_split_idx_root(split_idxs_root,
                                                        dataset_name)
    else:
      self.split_idxs_root = None

    self.datasets = self._build_datasets()
    self._set_num_classes(dataset_name)

  def _set_num_classes(self, dataset_name):
    """Set number of classes in dataset."""
    if dataset_name == 'CIFAR10':
      self.num_classes = 10
    elif dataset_name == 'CIFAR100':
      self.num_classes = 100
    elif dataset_name == 'TinyImageNet':
      self.num_classes = 200
    else:
      raise ValueError(f'Unknown dataset: {dataset_name}')

  def get_transform(self, dataset_key=None):
    """Retrieve the normalization transform of a dataset."""
    if dataset_key is None:
      dataset_key = list(self.datasets.keys())[0]

    normalize_transform = None
    # Grab transforms - location varies depending on base dataset.
    try:
      transforms = self.datasets[dataset_key].transform.transforms
      found = True
    except AttributeError:
      found = False

    # Wrapped datasets (e.g. torch.utils.data.Subset) nest the transform
    # one level deeper.
    if not found:
      try:
        transforms = self.datasets[dataset_key].dataset.transform.transforms
        found = True
      except AttributeError:
        found = False

    if not found:
      print('Transform list not found!')
    else:
      found = False
      for xform in transforms:
        if 'normalize' in str(xform).lower():
          normalize_transform = xform
          found = True
          break

    if not found:
      print('Normalization transform not found!')
    return normalize_transform

  def _build_split_idx_root(self, split_idxs_root, dataset_name):
    """Build directory for split idxs."""
    if '.json' in split_idxs_root and not os.path.exists(split_idxs_root):
      split_idxs_root = os.path.join(split_idxs_root, dataset_name)
    print(f'Setting split idxs root to {split_idxs_root}')
    if not os.path.exists(split_idxs_root):
      print(f'{split_idxs_root} does not exist!')
      os.makedirs(split_idxs_root)
      print('Complete.')
    return split_idxs_root

  def _build_datasets(self):
    """Build dataset."""
    if 'cifar' in self.dataset_name.lower():
      dataset_dict = cifar_handler.create_datasets(
          self.data_root,
          dataset_name=self.dataset_name,
          val_split=self.val_split,
          split_idxs_root=self.split_idxs_root,
          noise_type=self.noise_type,
          load_previous_splits=self.load_previous_splits)
    elif self.dataset_name.lower() == 'tinyimagenet':
      raise NotImplementedError('TinyImageNet is not implemented.')
    else:
      raise ValueError(f'Unknown dataset: {self.dataset_name}')
    return dataset_dict

  def build_loader(self,
                   dataset_key,
                   flags,
                   dont_shuffle_train=False):
    """Build dataset loader."""

    # Get dataset source
    dataset_src = self.datasets[dataset_key]

    # Shuffle only the train set, unless explicitly disabled.
    if dont_shuffle_train:
      shuffle = False
    else:
      shuffle = dataset_key == 'train'

    # Create the dataloader, which loads data in batches.
    loader = torch.utils.data.DataLoader(
        dataset=dataset_src,
        batch_size=flags.batch_size,
        shuffle=shuffle,
        num_workers=flags.num_workers,
        drop_last=flags.drop_last,
        pin_memory=True)

    return loader
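
# ----------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). Assumes the
# CIFAR-10 data lives under './data' and that `cifar_handler.create_datasets`
# returns a dict of splits that includes a 'train' key, as implied by
# `build_loader`'s shuffle logic. `flags` only needs the attributes
# `build_loader` reads: batch_size, num_workers, and drop_last.
if __name__ == '__main__':
  from types import SimpleNamespace

  handler = DataHandler(dataset_name='CIFAR10', data_root='./data')
  print(f'num_classes: {handler.num_classes}')
  print(f'normalize: {handler.get_transform()}')

  # Hypothetical flags stand-in; real runs would pass the project's flag
  # object instead.
  flags = SimpleNamespace(batch_size=128, num_workers=4, drop_last=False)
  train_loader = handler.build_loader('train', flags)
  images, labels = next(iter(train_loader))
  print(images.shape, labels.shape)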
