# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division, print_function

import math

import paddle
from paddle.io import DistributedBatchSampler

from ppfleetx.distributed.apis import env

__all__ = ["GPTBatchSampler", "DistributedBatchSampler"]


class GPTBatchSampler(paddle.io.BatchSampler):
    """Sampler that restricts data loading to a subset of the dataset.

    In such case, each process can pass a DistributedBatchSampler instance
    as a DataLoader sampler, and load a subset of the original dataset that
    is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset(paddle.io.Dataset): this could be an instance of
            `paddle.io.Dataset` or any other python object which implements
            `__len__`, for BatchSampler to get the sample count of the
            data source.
        batch_size(int): number of sample indices in each mini-batch.
        num_replicas(int, optional): process number in distributed training.
            If :attr:`num_replicas` is None, :attr:`num_replicas` will be
            retrieved from :code:`paddle.distributed.ParallelEnv`.
            Default None.
        rank(int, optional): the rank of the current process among :attr:`num_replicas`
            processes. If :attr:`rank` is None, :attr:`rank` is retrieved from
            :code:`paddle.distributed.ParallelEnv`. Default None.
        shuffle(bool): whether to shuffle indices order before generating
            batch indices. Default False.
        drop_last(bool): whether to drop the last incomplete batch when the
            dataset size is not divisible by the batch size. Default False.
        consumed_samples(int, optional): number of samples already consumed in
            the current epoch, used to resume training; must be divisible by
            the number of data-parallel ranks. Default 0.

    Examples:
        .. code-block:: python

            import numpy as np

            from paddle.io import Dataset, DistributedBatchSampler

            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            dataset = RandomDataset(100)
            sampler = DistributedBatchSampler(dataset, batch_size=64)
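
        GPTBatchSampler additionally accepts ``consumed_samples`` (which must
        be divisible by the number of data-parallel ranks) so that an
        interrupted run can resume mid-epoch. A minimal sketch, reusing the
        ``dataset`` above with illustrative numbers:

        .. code-block:: python

            # resume as if 128 samples of this epoch were already consumed
            sampler = GPTBatchSampler(dataset, batch_size=64, consumed_samples=128)
    """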

    def __init__(
        self, dataset, batch_size, num_replicas=None, rank=None, shuffle=False, drop_last=False, consumed_samples=0
    ):
        self.dataset = dataset

        assert isinstance(batch_size, int) and batch_size > 0, "batch_size should be a positive integer"
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), "shuffle should be a boolean value"
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), "drop_last should be a boolean value"

        if num_replicas is not None:
            assert isinstance(num_replicas, int) and num_replicas > 0, "num_replicas should be a positive integer"
            self.nranks = num_replicas
        else:
            # fall back to the data-parallel world size from the fleet env
            self.nranks = env.get_data_world_size()

        if rank is not None:
            assert isinstance(rank, int) and rank >= 0, "rank should be a non-negative integer"
            self.local_rank = rank
        else:
            self.local_rank = env.get_data_world_rank()

        self.drop_last = drop_last
        self.epoch = 0

        self.consumed_samples = consumed_samples
        # pad so that every rank sees the same number of samples
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
        self.total_size = self.num_samples * self.nranks

    def get_start_end_idx(self):
        start_idx = self.local_rank * self.batch_size
        end_idx = start_idx + self.batch_size
        return start_idx, end_idx
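
    # Worked example for __iter__ below (illustrative numbers): with nranks=2,
    # batch_size=2 and len(dataset)=10, every iteration step gathers
    # batch_size * nranks = 4 global indices, e.g. [0, 1, 2, 3]; rank 0 then
    # yields its slice [0, 1] and rank 1 yields [2, 3]. Padding indices past
    # the dataset length wrap back to the start (idx - num_samples).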

    def __iter__(self):
        assert (
            self.consumed_samples % self.nranks == 0
        ), "The consumed_samples should be divisible by nranks. consumed_samples=%d, nranks=%s" % (
            self.consumed_samples,
            self.nranks,
        )
        self.remain_num_samples = int(math.ceil((len(self.dataset) - self.consumed_samples) * 1.0 / self.nranks))
        self.remain_total_size = self.remain_num_samples * self.nranks
        self.batch_size_times_rank_size = self.batch_size * self.nranks

        batch_indices = []
        num_samples = len(self.dataset)
        for idx in range(self.consumed_samples, self.total_size):
            # wrap padding indices past the dataset length back to the start
            if idx >= num_samples:
                batch_indices.append(idx - num_samples)
            else:
                batch_indices.append(idx)
            # a full global batch is ready; yield this rank's slice of it
            if len(batch_indices) == self.batch_size_times_rank_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch_indices[start_idx:end_idx]
                batch_indices = []

        if not self.drop_last and len(batch_indices) > 0:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch_indices[start_idx:end_idx]

    def __len__(self):
        num_samples = self.num_samples
        # when keeping the last partial batch, round the batch count up,
        # e.g. num_samples=10, batch_size=4 -> (10 + 3) // 4 = 3 batches
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size

    def set_epoch(self, epoch=0, consumed_samples=0):
        """
        Sets the epoch number. When :attr:`shuffle=True`, this number is used
        as the seed for random numbers. By default users need not set this;
        all replicas (workers) use a different random ordering for each epoch.
        If the same number is set at each epoch, this sampler will yield the
        same ordering across all epochs.

        Arguments:
            epoch (int): Epoch number.
            consumed_samples (int): Number of samples already consumed in this
                epoch; reset to 0 when starting a new epoch.

        Examples:
            .. code-block:: python

                import numpy as np

                from paddle.io import Dataset, DistributedBatchSampler

                class RandomDataset(Dataset):
                    def __init__(self, num_samples):
                        self.num_samples = num_samples

                    def __getitem__(self, idx):
                        image = np.random.random([784]).astype('float32')
                        label = np.random.randint(0, 9, (1, )).astype('int64')
                        return image, label

                    def __len__(self):
                        return self.num_samples

                dataset = RandomDataset(100)
                sampler = DistributedBatchSampler(dataset, batch_size=64)

                for epoch in range(10):
                    sampler.set_epoch(epoch)
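
                # Resuming mid-epoch instead: pass the number of samples
                # already consumed (illustrative value, must be divisible
                # by the number of data-parallel ranks).
                sampler.set_epoch(epoch=3, consumed_samples=128)
        """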
        self.epoch = epoch
        # if we reset the epoch, the consumed_samples should be set to 0.
        self.consumed_samples = consumed_samples
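

# A minimal, hypothetical smoke test of the resume behaviour; `RangeDataset`
# below is illustrative and not part of ppfleetx.
if __name__ == "__main__":

    class RangeDataset:
        """Trivial __len__-only data source; the sampler yields indices."""

        def __len__(self):
            return 16

    dataset = RangeDataset()

    # Fresh epoch on a single rank: yields [0, 1, 2, 3], [4, 5, 6, 7], ...
    sampler = GPTBatchSampler(dataset, batch_size=4, num_replicas=1, rank=0)
    print(list(sampler))

    # Resume as if 8 samples were already consumed: starts from index 8.
    sampler.set_epoch(epoch=0, consumed_samples=8)
    print(list(sampler))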