paddlenlp

Форк
0
175 строк · 7.1 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
from __future__ import division, print_function
16

17
import math
18

19
import paddle
20
from paddle.io import DistributedBatchSampler
21
from ppfleetx.distributed.apis import env
22

23
__all__ = ["GPTBatchSampler", "DistributedBatchSampler"]
24

25

26
class GPTBatchSampler(paddle.io.BatchSampler):
    """Sampler that restricts data loading to a subset of the dataset.

    In distributed training, each process can pass a ``GPTBatchSampler``
    instance as a DataLoader sampler and load the subset of the original
    dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset(paddle.io.Dataset): this could be a ``paddle.io.Dataset``
                     implementation or any other python object which
                     implements ``__len__``, used by the sampler to get the
                     sample number of the data source.
        batch_size(int): number of sample indices in a mini-batch.
        num_replicas(int, optional): process number in distributed training.
            If :attr:`num_replicas` is None, it is retrieved from the
            distributed environment via ``env.get_data_world_size()``.
            Default None.
        rank(int, optional): the rank of the current process among
            :attr:`num_replicas` processes. If :attr:`rank` is None, it is
            retrieved via ``env.get_data_world_rank()``. Default None.
        shuffle(bool): whether to shuffle indices order before generating
            batch indices. Default False.
        drop_last(bool): whether to drop the last incomplete batch when the
            dataset size is not divisible by the batch size. Default False.
        consumed_samples(int): number of samples already consumed, used to
            resume iteration mid-epoch. Must be divisible by the number of
            replicas. Default 0.

    Examples:
        .. code-block:: python

            import numpy as np
            from paddle.io import Dataset

            # init with dataset
            class RandomDataset(Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([784]).astype('float32')
                    label = np.random.randint(0, 9, (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            dataset = RandomDataset(100)
            sampler = GPTBatchSampler(dataset, batch_size=64)
            for data in sampler:
                # do something
                break
    """

    def __init__(
        self, dataset, batch_size, num_replicas=None, rank=None, shuffle=False, drop_last=False, consumed_samples=0
    ):
        self.dataset = dataset

        assert isinstance(batch_size, int) and batch_size > 0, "batch_size should be a positive integer"
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), "shuffle should be a boolean value"
        # NOTE(review): `shuffle` is validated and stored but never applied in
        # __iter__ below — indices are always yielded in sequential order.
        # Confirm whether shuffling is handled elsewhere (e.g. by the dataset).
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), "drop_last should be a boolean number"

        if num_replicas is not None:
            assert isinstance(num_replicas, int) and num_replicas > 0, "num_replicas should be a positive integer"
            self.nranks = num_replicas
        else:
            self.nranks = env.get_data_world_size()

        if rank is not None:
            assert isinstance(rank, int) and rank >= 0, "rank should be a non-negative integer"
            self.local_rank = rank
        else:
            self.local_rank = env.get_data_world_rank()

        self.drop_last = drop_last
        self.epoch = 0

        self.consumed_samples = consumed_samples
        # Round up so every rank draws the same number of samples per epoch;
        # __iter__ pads by wrapping indices past the end of the dataset.
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
        self.total_size = self.num_samples * self.nranks

    def get_start_end_idx(self):
        """Return the [start, end) slice of a global batch owned by this rank."""
        start_idx = self.local_rank * self.batch_size
        end_idx = start_idx + self.batch_size
        return start_idx, end_idx

    def __iter__(self):
        """Yield this rank's slice of each global batch of indices.

        Iterates sequentially from ``consumed_samples`` to ``total_size``,
        wrapping indices past the dataset end back to the beginning so that
        every rank receives an equal share. Yields lists of ``batch_size``
        indices (the last one may be shorter unless ``drop_last`` is set).
        """
        assert (
            self.consumed_samples % self.nranks == 0
        ), "The consumed_samples should be divided by nranks. consumed_samples=%d, nranks=%s" % (
            self.consumed_samples,
            self.nranks,
        )
        self.remain_num_samples = int(math.ceil((len(self.dataset) - self.consumed_samples) * 1.0 / self.nranks))
        self.remain_total_size = self.remain_num_samples * self.nranks
        self.batch_size_times_rank_size = self.batch_size * self.nranks

        num_samples = len(self.dataset)
        batch_indices = []
        for idx in range(self.consumed_samples, self.total_size):
            if idx >= num_samples:
                # Wrap around: pad the final global batch with indices from
                # the start of the dataset.
                batch_indices.append(idx - num_samples)
            else:
                batch_indices.append(idx)
            if len(batch_indices) == self.batch_size_times_rank_size:
                # One full global batch collected; yield this rank's slice.
                start_idx, end_idx = self.get_start_end_idx()
                yield batch_indices[start_idx:end_idx]
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices

    def __len__(self):
        """Number of batches this sampler yields per epoch (ignores consumed_samples)."""
        num_samples = self.num_samples
        # When not dropping the last batch, round the division up.
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size

    def set_epoch(self, epoch=0, consumed_samples=0):
        """
        Sets the epoch number. When :attr:`shuffle=True`, this number is used
        as seeds of random numbers. By default, users may not set this, all
        replicas (workers) use a different random ordering for each epoch.
        If set same number at each epoch, this sampler will yield the same
        ordering at all epoches.

        Arguments:
            epoch (int): Epoch number.
            consumed_samples (int): number of samples already consumed in the
                new epoch; resets progress when restarting an epoch. Default 0.

        Examples:
            .. code-block:: python

                from paddle.io import Dataset

                # init with dataset
                class RandomDataset(Dataset):
                    def __init__(self, num_samples):
                        self.num_samples = num_samples

                    def __getitem__(self, idx):
                        image = np.random.random([784]).astype('float32')
                        label = np.random.randint(0, 9, (1, )).astype('int64')
                        return image, label

                    def __len__(self):
                        return self.num_samples

                dataset = RandomDataset(100)
                sampler = GPTBatchSampler(dataset, batch_size=64)

                for epoch in range(10):
                    sampler.set_epoch(epoch)
        """
        self.epoch = epoch
        # if we reset the epoch, the consumed_samples should be set to 0.
        self.consumed_samples = consumed_samples
176

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.