""" Loader Factory, Fast Collate, CUDA Prefetcher
2

3
Prefetcher and Fast Collate inspired by NVIDIA APEX example at
4
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
5

6
Hacked together by / Copyright 2019, Ross Wightman
7
"""
8
import logging
9
import random
10
from contextlib import suppress
11
from functools import partial
12
from itertools import repeat
13
from typing import Callable, Optional, Tuple, Union
14

15
import torch
16
import torch.utils.data
17
import numpy as np
18

19
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
20
from .dataset import IterableImageDataset, ImageDataset
21
from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
22
from .random_erasing import RandomErasing
23
from .mixup import FastCollateMixup
24
from .transforms_factory import create_transform
25

26
_logger = logging.getLogger(__name__)
27

28

29
def fast_collate(batch):
    """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    if isinstance(batch[0][0], tuple):
        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor, ordered by position
        # such that all tuple entries at position n end up in the nth chunk of a torch.split(tensor, batch_size)
        inner_tuple_size = len(batch[0][0])
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
            for j in range(inner_tuple_size):
                targets[i + j * batch_size] = batch[i][1]
                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
        return tensor, targets
    elif isinstance(batch[0][0], np.ndarray):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i] += torch.from_numpy(batch[i][0])
        return tensor, targets
    elif isinstance(batch[0][0], torch.Tensor):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i].copy_(batch[i][0])
        return tensor, targets
    else:
        assert False
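
# Illustrative usage (a sketch; the example batch and shapes are assumptions,
# not from the source):
#
#   batch = [(np.zeros((3, 224, 224), dtype=np.uint8), 0),
#            (np.ones((3, 224, 224), dtype=np.uint8), 1)]
#   images, targets = fast_collate(batch)
#   # images: uint8 tensor of shape (2, 3, 224, 224); targets: int64 tensor of shape (2,)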


def adapt_to_chs(x, n):
    if not isinstance(x, (tuple, list)):
        x = tuple(repeat(x, n))
    elif len(x) != n:
        x_mean = np.mean(x).item()
        x = (x_mean,) * n
        _logger.warning(f'Pretrained mean/std different shape than model, using avg value {x}.')
    else:
        assert len(x) == n, 'normalization stats must match image channels'
    return x
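
# For example (a sketch of the behavior above):
#
#   adapt_to_chs(0.5, 3)         # -> (0.5, 0.5, 0.5)
#   adapt_to_chs((0.4, 0.5), 3)  # -> (0.45, 0.45, 0.45), with a warning logged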


class PrefetchLoader:

    def __init__(
            self,
            loader,
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD,
            channels=3,
            device=torch.device('cuda'),
            img_dtype=torch.float32,
            fp16=False,
            re_prob=0.,
            re_mode='const',
            re_count=1,
            re_num_splits=0):

        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        normalization_shape = (1, channels, 1, 1)

        self.loader = loader
        self.device = device
        if fp16:
            # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
            img_dtype = torch.float16
        self.img_dtype = img_dtype
        self.mean = torch.tensor(
            [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
        self.std = torch.tensor(
            [x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device=device,
            )
        else:
            self.random_erasing = None
        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'

    def __iter__(self):
        first = True
        if self.is_cuda:
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        else:
            stream = None
            stream_context = suppress

        for next_input, next_target in self.loader:

            with stream_context():
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_target = next_target.to(device=self.device, non_blocking=True)
                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if not first:
                yield input, target
            else:
                first = False

            if stream is not None:
                torch.cuda.current_stream().wait_stream(stream)

            input = next_input
            target = next_target

        yield input, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x
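
# Typical wiring (a sketch; assumes the wrapped DataLoader uses fast_collate so
# batches arrive as uint8 tensors in the 0-255 range, which is why mean/std are
# scaled by 255 in __init__):
#
#   base = torch.utils.data.DataLoader(dataset, batch_size=64, collate_fn=fast_collate)
#   loader = PrefetchLoader(base, device=torch.device('cuda'))
#   for images, targets in loader:
#       ...  # images come out float32, normalized, already on the GPU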


def _worker_init(worker_id, worker_seeding='all'):
    worker_info = torch.utils.data.get_worker_info()
    assert worker_info.id == worker_id
    if isinstance(worker_seeding, Callable):
        seed = worker_seeding(worker_info)
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed % (2 ** 32 - 1))
    else:
        assert worker_seeding in ('all', 'part')
        # random / torch seed already called in dataloader iter class w/ worker_info.seed
        # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
        if worker_seeding == 'all':
            np.random.seed(worker_info.seed % (2 ** 32 - 1))
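
# A callable may also be passed to derive per-worker seeds (hypothetical example):
#
#   create_loader(..., worker_seeding=lambda info: 42 + info.id)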


def create_loader(
        dataset: Union[ImageDataset, IterableImageDataset],
        input_size: Union[int, Tuple[int, int], Tuple[int, int, int]],
        batch_size: int,
        is_training: bool = False,
        no_aug: bool = False,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_split: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: float = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        num_aug_repeats: int = 0,
        num_aug_splits: int = 0,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        num_workers: int = 1,
        distributed: bool = False,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
        pin_memory: bool = False,
        fp16: bool = False,  # deprecated, use img_dtype
        img_dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device('cuda'),
        use_prefetcher: bool = True,
        use_multi_epochs_loader: bool = False,
        persistent_workers: bool = True,
        worker_seeding: str = 'all',
        tf_preprocessing: bool = False,
):
    """Create a DataLoader for the given dataset, wiring up transforms, samplers,
    collation, and the optional device prefetcher.

    Args:
        dataset: The image dataset to load.
        input_size: Target input size (channels, height, width) tuple or size scalar.
        batch_size: Number of samples in a batch.
        is_training: Return training (random) transforms.
        no_aug: Disable augmentation for training (useful for debug).
        re_prob: Random erasing probability.
        re_mode: Random erasing fill mode.
        re_count: Number of random erasing regions.
        re_split: Control split of random erasing across batch size.
        scale: Random resize scale range (crop area, < 1.0 => zoom in).
        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
            Scalar is applied as (scalar,) * 3 (no hue).
        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
        auto_augment: Auto augment configuration string (see auto_augment.py).
        num_aug_repeats: Enable special sampler to repeat same augmentation across distributed GPUs.
        num_aug_splits: Enable mode where augmentations can be split across the batch.
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        num_workers: Num worker processes per DataLoader.
        distributed: Enable dataloading for distributed training.
        crop_pct: Inference crop percentage (output size / resize size).
        crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
        crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
        collate_fn: Override default collate_fn.
        pin_memory: Pin memory for device transfer.
        fp16: Deprecated argument for half-precision input dtype. Use img_dtype.
        img_dtype: Data type for input image.
        device: Device to transfer inputs and targets to.
        use_prefetcher: Use efficient pre-fetcher to load samples onto device.
        use_multi_epochs_loader: Use a MultiEpochsDataLoader that reuses worker processes across epochs.
        persistent_workers: Enable persistent worker processes.
        worker_seeding: Control worker random seeding at init.
        tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.

    Returns:
        DataLoader
    """
    re_num_splits = 0
    if re_split:
        # apply RE to second half of batch if no aug split otherwise line up with aug split
        re_num_splits = num_aug_splits or 2
    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        no_aug=no_aug,
        train_crop_mode=train_crop_mode,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        color_jitter_prob=color_jitter_prob,
        grayscale_prob=grayscale_prob,
        gaussian_blur_prob=gaussian_blur_prob,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        crop_mode=crop_mode,
        crop_border_pixels=crop_border_pixels,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        tf_preprocessing=tf_preprocessing,
        use_prefetcher=use_prefetcher,
        separate=num_aug_splits > 0,
    )

    if isinstance(dataset, IterableImageDataset):
        # give Iterable datasets early knowledge of num_workers so that sample estimates
        # are correct before worker processes are launched
        dataset.set_loader_cfg(num_workers=num_workers)

    sampler = None
    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            if num_aug_repeats:
                sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
            else:
                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)
    else:
        assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader_class = torch.utils.data.DataLoader
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        persistent_workers=persistent_workers
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop('persistent_workers')  # only in PyTorch 1.7+
        loader = loader_class(dataset, **loader_args)
    if use_prefetcher:
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        loader = PrefetchLoader(
            loader,
            mean=mean,
            std=std,
            channels=input_size[0],
            device=device,
            fp16=fp16,  # deprecated, use img_dtype
            img_dtype=img_dtype,
            re_prob=prefetch_re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits
        )

    return loader
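
# Example call (a sketch; dataset construction is elided and values are assumptions):
#
#   loader = create_loader(
#       dataset,
#       input_size=(3, 224, 224),
#       batch_size=32,
#       is_training=True,
#       num_workers=4,
#   )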


class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """ DataLoader that keeps its worker processes alive across epochs by reusing one iterator."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # temporarily unlock DataLoader attribute setting to swap in the repeating sampler
        self._DataLoader__initialized = False
        if self.batch_sampler is None:
            self.sampler = _RepeatSampler(self.sampler)
        else:
            self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.sampler) if self.batch_sampler is None else len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
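
# Enabled via create_loader(..., use_multi_epochs_loader=True). Because the
# iterator is created once in __init__ and the sampler never raises StopIteration,
# workers are not torn down and re-spawned at each epoch boundary; __len__ bounds
# each epoch to one pass over the underlying sampler.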