import queue as Queue
import threading

import torch
from torch.utils.data import DataLoader
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Wraps a Python generator and produces its items from a background
    thread through a bounded queue, so item production overlaps with
    consumption.

    Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        # Bounded queue: applies back-pressure so the producer thread
        # never runs more than num_prefetch_queue items ahead.
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        # Daemon thread so a half-consumed generator does not block
        # interpreter shutdown.
        self.daemon = True
        self.start()

    def run(self):
        # Producer loop: push every item, then a None sentinel to signal
        # exhaustion to __next__.
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            # Sentinel from run(): the wrapped generator is exhausted.
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        # Wrap the standard DataLoader iterator so batches are fetched
        # ahead of time in a background thread.
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
class CPUPrefetcher():
    """CPU prefetcher.

    Thin wrapper around a dataloader exposing a `next()` that returns
    None on exhaustion instead of raising, plus a `reset()` to restart
    iteration.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        # Keep the original loader so iteration can be restarted.
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            # Exhaustion is reported as None, not as an exception.
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)
class CUDAPrefetcher():
    """CUDA prefetcher.

    Reference: https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Overlaps host-to-device copies with computation by staging the next
    batch on a dedicated CUDA stream.

    Args:
        loader: Dataloader.
        opt (dict): Options. Reads opt['num_gpu'] to pick the device.
    """

    def __init__(self, loader, opt):
        # Keep the original loader so iteration can be restarted.
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        # Side stream: batch copies run concurrently with compute on the
        # default stream.
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        """Fetch the next batch and start copying its tensors to device."""
        try:
            self.batch = next(self.loader)  # assumes batch is a dict of tensors/values
        except StopIteration:
            self.batch = None
            return None
        # Issue the copies on the side stream; non_blocking lets them
        # proceed asynchronously.
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        """Return the prefetched batch (None when exhausted) and preload the next one."""
        # Make the compute stream wait for the async copies to finish.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()  # start staging the following batch
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)