"""Batch samplers that work with either random or sequential data samplers."""

import torch
from torch.utils import data


class RandomSampler(data.sampler.Sampler):
    """Based off of pytorch RandomSampler and DistributedSampler. Essentially
    a RandomSampler, but this class lets the user set an epoch like
    DistributedSampler. Samples elements randomly. If without replacement, then
    sample from a shuffled dataset. If with replacement, then user can
    specify ``num_samples`` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from
        num_samples (int): number of samples to draw, default=len(dataset)
        replacement (bool): samples are drawn with replacement if ``True``,
            default=False
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.epoch = -1

        if self._num_samples is not None and replacement is False:
            raise ValueError(
                "With replacement=False, num_samples should not "
                "be specified, since a random permute will be "
                "performed."
            )

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "num_samples should be a positive integer "
                "value, but got num_samples={}".format(self.num_samples)
            )
        if not isinstance(self.replacement, bool):
            raise ValueError(
                "replacement should be a boolean value, but got "
                "replacement={}".format(self.replacement)
            )

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        g = torch.Generator()
        if self.epoch >= 0:
            # seeding with the epoch makes the shuffle deterministic per epoch
            g.manual_seed(self.epoch)
        if self.replacement:
            return iter(
                torch.randint(
                    high=n, size=(self.num_samples,), dtype=torch.int64, generator=g
                ).tolist()
            )
        return iter(torch.randperm(n, generator=g).tolist())

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

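
# A minimal usage sketch for the sampler above. ``my_dataset`` and ``num_epochs``
# are hypothetical placeholders: any map-style dataset (implementing __len__ and
# __getitem__) works. Calling ``set_epoch`` before each epoch reproduces the same
# permutation for a given epoch number while still varying the order across epochs:
#
#     sampler = RandomSampler(my_dataset, replacement=False)
#     for epoch in range(num_epochs):
#         sampler.set_epoch(epoch)
#         for idx in sampler:
#             sample = my_dataset[idx]
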
class DistributedBatchSampler(data.sampler.BatchSampler):
    """Similar to normal implementation of distributed sampler, except
    implementation is at the batch sampler level, instead of just the
    sampler level. This allows wrapping of arbitrary data samplers
    (sequential, random, WeightedRandomSampler, etc.) with this batch
    sampler.

    The `interleave` argument specifies how to distribute a batch. A value
    of True combined with the above random sampler is equivalent to pytorch's
    torch.utils.data.distributed.DistributedSampler.

    For the following batch [0,1,2,3,4,5,6,7] and data parallelism of 2,
    specifying True will result in the following samples for each gpu:
        GPU0: [0,2,4,6] GPU1: [1,3,5,7]
    specifying False will result in the following samples:
        GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""

    def __init__(
        self,
        sampler,
        batch_size,
        drop_last,
        rank=-1,
        world_size=2,
        wrap_last=False,
        interleave=False,
    ):
        super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
        if rank == -1:
            # callers are expected to pass the data-parallel rank explicitly
            assert False, "should not be here"
            rank = torch.distributed.get_rank()
        self.rank = rank
        self.world_size = world_size
        self.sampler.wrap_around = 0
        self.wrap_around = 0
        self.wrap_last = wrap_last
        self.start_iter = 0
        self.interleave = interleave

    def __iter__(self):
        batch = []
        i = 0
        for idx in self.data_iterator(self.sampler, wrap_around=False):
            batch.append(idx)
            if len(batch) == self.batch_size:
                tbatch = self._batch(batch)
                # skip batches already consumed when resuming mid-epoch
                if i >= self.start_iter:
                    yield tbatch
                    self.start_iter = 0
                i += 1
                batch = []
        batch_len = len(batch)
        if batch_len > 0 and not self.drop_last:
            if self.wrap_last:
                self.sampler.wrap_around -= self.batch_size
                self.wrap_around += len(batch)
                self.wrap_around %= self.batch_size
            yield self._batch(batch)
        if self.wrap_last:
            self.sampler.wrap_around += self.batch_size

    def data_iterator(self, _iter, wrap_around=False):
        """iterates through data and handles wrap around"""
        for i, idx in enumerate(_iter):
            # skip indices already consumed by the previous wrap-around
            if i < self.wrap_around % self.batch_size:
                continue
            if wrap_around:
                self.wrap_around += 1
                self.wrap_around %= self.batch_size
            yield idx

    def _batch(self, batch):
        """extracts samples only pertaining to this worker's batch"""
        if self.interleave:
            return batch[self.rank : self.batch_size : self.world_size]
        start = self.rank * self.batch_size // self.world_size
        end = (self.rank + 1) * self.batch_size // self.world_size
        return batch[start:end]
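

# A minimal, self-contained sketch of how the two samplers compose with a
# PyTorch DataLoader through its ``batch_sampler`` argument. The dataset,
# batch size, and 2-way data-parallel world size below are hypothetical
# values chosen only for illustration; each rank is simulated in a loop.
if __name__ == "__main__":
    dataset = list(range(16))  # stand-in for any map-style dataset
    sampler = RandomSampler(dataset, replacement=False)
    sampler.set_epoch(0)  # the same permutation is reproduced on every rank
    for rank in range(2):
        batch_sampler = DistributedBatchSampler(
            sampler,
            batch_size=8,
            drop_last=True,
            rank=rank,
            world_size=2,
            interleave=True,
        )
        loader = data.DataLoader(dataset, batch_sampler=batch_sampler)
        print("rank", rank, [batch.tolist() for batch in loader])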