gpt-neox — samplers.py
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Batch samplers that work with either random or sequential data samplers."""

import torch
from torch.utils import data


class RandomSampler(data.sampler.Sampler):
    """Based on PyTorch's RandomSampler and DistributedSampler. Essentially
    a RandomSampler, but this class lets the user set an epoch the way
    DistributedSampler does. Samples elements randomly. If without replacement,
    samples are drawn from a shuffled dataset. If with replacement, the user
    can specify ``num_samples`` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from
        num_samples (int): number of samples to draw, default=len(dataset)
        replacement (bool): samples are drawn with replacement if ``True``,
            default=False
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.epoch = -1

        if self._num_samples is not None and replacement is False:
            raise ValueError(
                "With replacement=False, num_samples should not "
                "be specified, since a random permute will be "
                "performed."
            )

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(
                "num_samples should be a positive integer "
                "value, but got num_samples={}".format(self.num_samples)
            )
        if not isinstance(self.replacement, bool):
            raise ValueError(
                "replacement should be a boolean value, but got "
                "replacement={}".format(self.replacement)
            )

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        g = torch.Generator()
        if self.epoch >= 0:
            g.manual_seed(self.epoch)
        if self.replacement:
            return iter(
                torch.randint(
                    high=n, size=(self.num_samples,), dtype=torch.int64, generator=g
                ).tolist()
            )
        return iter(torch.randperm(n, generator=g).tolist())

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
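# Illustrative sketch (not part of the upstream file): the epoch set via
# set_epoch() seeds the generator in __iter__, so the shuffle order is
# reproducible per epoch. The helper name below is hypothetical.
def _random_sampler_epoch_sketch():
    sampler = RandomSampler(range(8))
    sampler.set_epoch(0)
    assert list(sampler) == list(sampler)  # same epoch -> identical permutation
    sampler.set_epoch(1)  # a new epoch reseeds, giving a fresh but reproducible order
    return list(sampler)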


class DistributedBatchSampler(data.sampler.BatchSampler):
    """Similar to the usual distributed sampler, except the implementation is
    at the batch-sampler level instead of the sampler level. This allows
    wrapping of arbitrary data samplers (sequential, random,
    WeightedRandomSampler, etc.) with this batch sampler.

    The `interleave` argument specifies how to distribute a batch. A value
    of True combined with the above random sampler is equivalent to pytorch's
    torch.utils.data.distributed.DistributedSampler.

    For the batch [0,1,2,3,4,5,6,7] and data parallelism of 2,
    specifying True will result in the following samples for each gpu:
        GPU0: [0,2,4,6] GPU1: [1,3,5,7]
    specifying False will result in the following samples:
        GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""

    def __init__(
        self,
        sampler,
        batch_size,
        drop_last,
        rank=-1,
        world_size=2,
        wrap_last=False,
        interleave=False,
    ):
        super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
        if rank == -1:
            assert False, "should not be here"
            rank = torch.distributed.get_rank()
        self.rank = rank
        self.world_size = world_size
        self.sampler.wrap_around = 0
        self.wrap_around = 0
        self.wrap_last = wrap_last
        self.start_iter = 0
        self.interleave = interleave

    def __iter__(self):
        batch = []
        i = 0
        for idx in self.data_iterator(self.sampler, wrap_around=False):
            batch.append(idx)
            if len(batch) == self.batch_size:
                tbatch = self._batch(batch)
                if i >= self.start_iter:
                    yield tbatch
                    self.start_iter = 0
                i += 1
                batch = []
        batch_len = len(batch)
        if batch_len > 0 and not self.drop_last:
            if self.wrap_last:
                self.sampler.wrap_around -= self.batch_size
                self.wrap_around += len(batch)
                self.wrap_around %= self.batch_size
            yield self._batch(batch)
        if self.wrap_last:
            self.sampler.wrap_around += self.batch_size

    def data_iterator(self, _iter, wrap_around=False):
        """Iterates through the data and handles wrap-around."""
        for i, idx in enumerate(_iter):
            if i < self.wrap_around % self.batch_size:
                continue
            if wrap_around:
                self.wrap_around += 1
                self.wrap_around %= self.batch_size
            yield idx

    def _batch(self, batch):
        """Extracts only the samples pertaining to this worker's sub-batch."""
        if self.interleave:
            return batch[self.rank : self.batch_size : self.world_size]
        start = self.rank * self.batch_size // self.world_size
        end = (self.rank + 1) * self.batch_size // self.world_size
        return batch[start:end]
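# Illustrative sketch (not part of the upstream file): a minimal, single-process
# simulation of how DistributedBatchSampler splits one global batch of 8 indices
# across two data-parallel ranks, matching the docstring example above. The
# __main__ guard keeps it from running on import; all names are local to this sketch.
if __name__ == "__main__":
    for interleave in (True, False):
        per_rank = []
        for rank in range(2):
            batch_sampler = DistributedBatchSampler(
                data.SequentialSampler(range(8)),
                batch_size=8,
                drop_last=True,
                rank=rank,
                world_size=2,
                interleave=interleave,
            )
            per_rank.append(list(batch_sampler)[0])
        print(f"interleave={interleave}: GPU0={per_rank[0]} GPU1={per_rank[1]}")
    # Expected output:
    # interleave=True: GPU0=[0, 2, 4, 6] GPU1=[1, 3, 5, 7]
    # interleave=False: GPU0=[0, 1, 2, 3] GPU1=[4, 5, 6, 7]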
