# pytorch-image-models
# 402 строки · 15.1 Кб  (export metadata, kept from original header)
""" Loader Factory, Fast Collate, CUDA Prefetcher

Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf

Hacked together by / Copyright 2019, Ross Wightman
"""
8import logging
9import random
10from contextlib import suppress
11from functools import partial
12from itertools import repeat
13from typing import Callable, Optional, Tuple, Union
14
15import torch
16import torch.utils.data
17import numpy as np
18
19from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
20from .dataset import IterableImageDataset, ImageDataset
21from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
22from .random_erasing import RandomErasing
23from .mixup import FastCollateMixup
24from .transforms_factory import create_transform
25
26_logger = logging.getLogger(__name__)
27
28
def fast_collate(batch):
    """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)

    Each sample in `batch` is an (input, target) tuple where input is one of:
      * a tuple of np.uint8 arrays (aug-split samples) -- flattened position-major so
        that torch.split(tensor, batch_size) recovers the nth tuple position as the
        nth chunk
      * a single np.uint8 ndarray
      * a single torch.uint8 tensor

    Returns:
        Tuple of (uint8 input tensor, int64 target tensor).

    Raises:
        TypeError: if the sample input type is not one of the supported kinds.
    """
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    if isinstance(batch[0][0], tuple):
        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
        # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
        inner_tuple_size = len(batch[0][0])
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
            for j in range(inner_tuple_size):
                # target is duplicated across each aug-split position of the same sample
                targets[i + j * batch_size] = batch[i][1]
                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
        return tensor, targets
    elif isinstance(batch[0][0], np.ndarray):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i] += torch.from_numpy(batch[i][0])
        return tensor, targets
    elif isinstance(batch[0][0], torch.Tensor):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i].copy_(batch[i][0])
        return tensor, targets
    else:
        # Was `assert False`, which is stripped under `python -O` and would make this
        # function silently return None; raise an explicit, descriptive error instead.
        raise TypeError(f'fast_collate: unsupported input type {type(batch[0][0])}')
62
63
def adapt_to_chs(x, n):
    """Adapt normalization stats `x` (scalar or per-channel sequence) to `n` channels.

    A scalar is broadcast to an n-tuple. A sequence of matching length is returned
    unchanged. A sequence of any other length is collapsed to its mean and broadcast
    (with a warning), matching pretrained stats to a model with a different channel count.
    """
    if not isinstance(x, (tuple, list)):
        # scalar -> broadcast to one value per channel
        return (x,) * n
    if len(x) == n:
        return x
    # length mismatch: fall back to the average value for every channel
    avg = np.mean(x).item()
    x = (avg,) * n
    _logger.warning(f'Pretrained mean/std different shape than model, using avg value {x}.')
    return x
74
75
class PrefetchLoader:
    """Wraps a loader to move batches to `device`, apply fused dtype conversion +
    normalization, optional on-device random erasing, and (on CUDA) overlap the
    next batch's host->device transfer with compute on the current batch via a
    side CUDA stream.
    """

    def __init__(
            self,
            loader,
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD,
            channels=3,
            device=torch.device('cuda'),
            img_dtype=torch.float32,
            fp16=False,
            re_prob=0.,
            re_mode='const',
            re_count=1,
            re_num_splits=0):
        """
        Args:
            loader: Wrapped loader yielding (input, target) batch tuples.
            mean: Normalization mean (scalar or per-channel), in 0-1 range.
            std: Normalization std (scalar or per-channel), in 0-1 range.
            channels: Number of image channels (mean/std adapted to this).
            device: Target device for batches.
            img_dtype: Floating dtype for converted inputs.
            fp16: Deprecated; if True overrides img_dtype to torch.float16.
            re_prob: Random erasing probability (0 disables).
            re_mode: Random erasing fill mode.
            re_count: Max random erasing regions.
            re_num_splits: Batch-split count for random erasing.
        """
        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        normalization_shape = (1, channels, 1, 1)

        self.loader = loader
        self.device = device
        if fp16:
            # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
            img_dtype = torch.float16
        self.img_dtype = img_dtype
        # mean/std pre-scaled by 255 so they apply directly to uint8-valued inputs
        # after the dtype cast in __iter__ (avoids a separate div-by-255 pass)
        self.mean = torch.tensor(
            [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
        self.std = torch.tensor(
            [x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
        if re_prob > 0.:
            # random erasing runs on-device, after normalization
            self.random_erasing = RandomErasing(
                probability=re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device=device,
            )
        else:
            self.random_erasing = None
        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'

    def __iter__(self):
        """Yield device-resident, normalized (input, target) batches.

        On CUDA, batch N+1 is copied/normalized on a side stream while batch N is
        consumed; statement order below (yield previous -> wait_stream -> swap) is
        what makes that overlap safe. NOTE(review): assumes the wrapped loader
        yields at least one batch, otherwise the final yield raises NameError.
        """
        first = True
        if self.is_cuda:
            # side stream for prefetch work so it can overlap the default stream
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        else:
            # no-op context on non-CUDA devices
            stream = None
            stream_context = suppress

        for next_input, next_target in self.loader:

            with stream_context():
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_target = next_target.to(device=self.device, non_blocking=True)
                # fused cast + normalize; mean/std were pre-scaled by 255 in __init__
                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if not first:
                # yield the previous batch while the current one is being prepared
                # (`input`/`target` intentionally shadow builtins; kept as-is)
                yield input, target
            else:
                first = False

            if stream is not None:
                # make default stream wait on the prefetch stream before the consumer
                # touches the just-prepared batch
                torch.cuda.current_stream().wait_stream(stream)

            input = next_input
            target = next_target

        # flush the last prepared batch
        yield input, target

    def __len__(self):
        # defer to the wrapped loader's batch count
        return len(self.loader)

    @property
    def sampler(self):
        # expose wrapped loader's sampler (e.g. for set_epoch calls)
        return self.loader.sampler

    @property
    def dataset(self):
        # expose wrapped loader's dataset
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        # mixup state lives on the collate fn when FastCollateMixup is in use
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        # silently ignored unless the collate fn is a FastCollateMixup
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x
171
172
def _worker_init(worker_id, worker_seeding='all'):
    """DataLoader worker_init_fn controlling per-worker RNG seeding.

    `worker_seeding` may be a callable mapping worker_info -> seed (all RNGs are
    then seeded from it), or one of 'all' / 'part' ('part' skips the numpy re-seed
    to reproduce some older results).
    """
    worker_info = torch.utils.data.get_worker_info()
    assert worker_info.id == worker_id
    if not isinstance(worker_seeding, Callable):
        assert worker_seeding in ('all', 'part')
        # random / torch seed already called in dataloader iter class w/ worker_info.seed
        # to reproduce some old results (same seed + hparam combo), partial seeding is
        # required (skip numpy re-seed)
        if worker_seeding == 'all':
            np.random.seed(worker_info.seed % (2 ** 32 - 1))
        return
    # custom seeding callable: derive one seed and apply it to every RNG
    custom_seed = worker_seeding(worker_info)
    random.seed(custom_seed)
    torch.manual_seed(custom_seed)
    np.random.seed(custom_seed % (2 ** 32 - 1))
187
188
def create_loader(
        dataset: Union[ImageDataset, IterableImageDataset],
        input_size: Union[int, Tuple[int, int], Tuple[int, int, int]],
        batch_size: int,
        is_training: bool = False,
        no_aug: bool = False,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_split: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: float = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        num_aug_repeats: int = 0,
        num_aug_splits: int = 0,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        num_workers: int = 1,
        distributed: bool = False,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
        pin_memory: bool = False,
        fp16: bool = False,  # deprecated, use img_dtype
        img_dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device('cuda'),
        use_prefetcher: bool = True,
        use_multi_epochs_loader: bool = False,
        persistent_workers: bool = True,
        worker_seeding: str = 'all',
        tf_preprocessing: bool = False,
):
    """Create a DataLoader: sets dataset transforms, selects sampler/collate, and
    optionally wraps the loader in a device PrefetchLoader.

    Args:
        dataset: The image dataset to load.
        input_size: Target input size (channels, height, width) tuple or size scalar.
        batch_size: Number of samples in a batch.
        is_training: Return training (random) transforms.
        no_aug: Disable augmentation for training (useful for debug).
        re_prob: Random erasing probability.
        re_mode: Random erasing fill mode.
        re_count: Number of random erasing regions.
        re_split: Control split of random erasing across batch size.
        train_crop_mode: Training crop mode, passed to create_transform
            (presumably a random-crop variant selector; see transforms_factory).
        scale: Random resize scale range (crop area, < 1.0 => zoom in).
        ratio: Random aspect ratio range (crop ratio for RRC, ratio adjustment factor for RKR).
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Random color jitter component factors (brightness, contrast, saturation, hue).
            Scalar is applied as (scalar,) * 3 (no hue).
        color_jitter_prob: Apply color jitter with this probability if not None (for SimCLR-like aug).
        grayscale_prob: Probability of converting image to grayscale (for SimCLR-like aug).
        gaussian_blur_prob: Probability of applying gaussian blur (for SimCLR-like aug).
        auto_augment: Auto augment configuration string (see auto_augment.py).
        num_aug_repeats: Enable special sampler to repeat same augmentation across distributed GPUs.
        num_aug_splits: Enable mode where augmentations can be split across the batch.
        interpolation: Image interpolation mode.
        mean: Image normalization mean.
        std: Image normalization standard deviation.
        num_workers: Num worker processes per DataLoader.
        distributed: Enable dataloading for distributed training.
        crop_pct: Inference crop percentage (output size / resize size).
        crop_mode: Inference crop mode. One of ['squash', 'border', 'center']. Defaults to 'center' when None.
        crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
        collate_fn: Override default collate_fn.
        pin_memory: Pin memory for device transfer.
        fp16: Deprecated argument for half-precision input dtype. Use img_dtype.
        img_dtype: Data type for input image.
        device: Device to transfer inputs and targets to.
        use_prefetcher: Use efficient pre-fetcher to load samples onto device.
        use_multi_epochs_loader: Use MultiEpochsDataLoader, which keeps one loader
            iterator alive across epochs instead of re-creating it each epoch.
        persistent_workers: Enable persistent worker processes.
        worker_seeding: Control worker random seeding at init.
        tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports.

    Returns:
        DataLoader
    """
    re_num_splits = 0
    if re_split:
        # apply RE to second half of batch if no aug split otherwise line up with aug split
        re_num_splits = num_aug_splits or 2
    # build and attach the transform pipeline; when use_prefetcher is True the
    # transform leaves images as uint8 and normalization happens in PrefetchLoader
    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        no_aug=no_aug,
        train_crop_mode=train_crop_mode,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        color_jitter_prob=color_jitter_prob,
        grayscale_prob=grayscale_prob,
        gaussian_blur_prob=gaussian_blur_prob,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        crop_mode=crop_mode,
        crop_border_pixels=crop_border_pixels,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        tf_preprocessing=tf_preprocessing,
        use_prefetcher=use_prefetcher,
        separate=num_aug_splits > 0,
    )

    if isinstance(dataset, IterableImageDataset):
        # give Iterable datasets early knowledge of num_workers so that sample estimates
        # are correct before worker processes are launched
        dataset.set_loader_cfg(num_workers=num_workers)

    sampler = None
    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            if num_aug_repeats:
                sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
            else:
                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)
    else:
        assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"

    if collate_fn is None:
        # fast_collate keeps images uint8 for the prefetcher; otherwise default collate
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader_class = torch.utils.data.DataLoader
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        # shuffle only for map-style training datasets without an explicit sampler
        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        persistent_workers=persistent_workers
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError as e:
        loader_args.pop('persistent_workers')  # only in Pytorch 1.7+
        loader = loader_class(dataset, **loader_args)
    if use_prefetcher:
        # random erasing is deferred to the prefetcher, and only when training w/ aug
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        loader = PrefetchLoader(
            loader,
            mean=mean,
            std=std,
            channels=input_size[0],
            device=device,
            fp16=fp16,  # deprecated, use img_dtype
            img_dtype=img_dtype,
            re_prob=prefetch_re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits
        )

    return loader
368
369
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that keeps its underlying iterator (and thus worker processes)
    alive across epochs.

    The (batch_)sampler is wrapped in an endlessly repeating sampler and a single
    iterator is created once at construction; each __iter__ call then pulls exactly
    one epoch's worth of batches from it instead of tearing down / re-spawning
    workers at every epoch boundary.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Temporarily clear DataLoader's name-mangled `__initialized` guard so the
        # sampler attributes can be swapped after base-class construction
        # (DataLoader normally forbids mutating these once initialized).
        self._DataLoader__initialized = False
        if self.batch_sampler is None:
            self.sampler = _RepeatSampler(self.sampler)
        else:
            self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # single persistent iterator, reused by every __iter__ call below
        self.iterator = super().__iter__()

    def __len__(self):
        # one epoch = length of the wrapped (non-repeating) sampler
        # NOTE(review): in the batch_sampler-is-None branch, self.sampler is a
        # _RepeatSampler which defines no __len__ — looks like this would raise;
        # verify whether that branch is ever hit in practice.
        return len(self.sampler) if self.batch_sampler is None else len(self.batch_sampler.sampler)

    def __iter__(self):
        # pull exactly one epoch's worth of batches from the persistent iterator
        for i in range(len(self)):
            yield next(self.iterator)
388
389
390class _RepeatSampler(object):
391""" Sampler that repeats forever.
392
393Args:
394sampler (Sampler)
395"""
396
397def __init__(self, sampler):
398self.sampler = sampler
399
400def __iter__(self):
401while True:
402yield from iter(self.sampler)
403