lmops / distributed_indexed.py
# coding=utf-8
# Copyright 2020 The OpenBMB team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import struct
import shutil

from itertools import accumulate

import numpy as np
import torch
import torch.distributed as dist
from utils import print_rank, save_rank


dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.double,
    8: np.uint16,
    9: np.uint32
}


def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


def index_file_path(prefix_path):
    return prefix_path + '.idx'


def data_file_path(prefix_path):
    return prefix_path + '.bin'


class DistributedMMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b'MMIDIDX\x00\x00'
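        # The reads in __init__ below assume a Megatron-style "MMIDIDX" index layout:
        # 9-byte magic, little-endian uint64 version (== 1), uint8 dtype code (see
        # `dtypes` above), uint64 item count, uint64 document count, followed by
        # int32 sizes[len], int64 pointers[len] (byte offsets into the .bin file),
        # and int64 doc_idx[doc_count].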
        def __init__(self, path):
            with open(path, 'rb') as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    'Index file doesn\'t match expected format. '
                    'Make sure that --dataset-impl is configured properly.'
                )
                version = struct.unpack('<Q', stream.read(8))
                assert (1,) == version

                dtype_code, = struct.unpack('<B', stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack('<Q', stream.read(8))[0]
                self._doc_count = struct.unpack('<Q', stream.read(8))[0]
                offset = stream.tell()

            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            self._sizes = np.frombuffer(
                self._bin_buffer,
                dtype=np.int32,
                count=self._len,
                offset=offset)
            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                           offset=offset + self._sizes.nbytes)
            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, name, rank_number, rank_total, cache=None):

        super().__init__()

        self._path = path
        self._name = name
        self._state = 0
        if cache is not None:
            self._cache = cache
            os.makedirs(self._cache, exist_ok=True)
        else:
            self._cache = None
        self._rank_total = rank_total
        self._rank_number = rank_number
        self._index = None
        self._bin_buffer = None
        self._bin_buffer_mmap = None
        self.max_state, self.history = self._probe_data_path(self._path, self._name, self._rank_total)
        self.total_length = self.history[self.max_state-1][1]

        self._do_init(self._path, self._name, self._cache, self._state)

    def _probe_data_path(self, path, name, rank_total):
        print_rank("Probing Dataset")

        state = 0
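        # history maps shard index -> (global start, global end) of the examples it
        # holds; the -1 entry below seeds the running offsets for shard 0.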
        history = {-1: (0, 0)}
        for state in range(np.iinfo(np.int32).max):
            source_file = path + name + f"_{state}"
            if self.exists(source_file):
                index = self.Index(index_file_path(source_file))
                history[state] = (history[state-1][1], history[state-1][1] + len(index))
            else:
                break

        print_rank(f"Probing end. Max data state {state}, total length {history[state-1][1]}")

        return state, history

    def __getstate__(self):
        return self._path + self._name + "_%d"%(self._state)

    def __setstate__(self, state):
        self._state = state
        self._do_init(self._path, self._name, self._cache, self._state)

    def _do_init(self, path, name, cache, state):
        if self._bin_buffer_mmap is not None:
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap
        if self._index is not None:
            del self._index

        self._state = state

        source_file = path + name + f"_{self._state}"
        self._index = self.Index(index_file_path(source_file))
        self._bin_buffer_mmap = np.memmap(data_file_path(source_file), mode='r', order='C')
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        if self._bin_buffer_mmap is not None:
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap
        if self._index is not None:
            del self._index

    def __len__(self):
        return self.total_length

    def _next_file(self):
        self._state += 1
        if self._state >= self.max_state:
            self._state = 0
        # print_rank(f"next_file: {self._state}")
        self._do_init(self._path, self._name, self._cache, self._state)

    def __relative_idx(self, idx):
        res = idx - self.history[self._state][0]
        return res

    def __slice_item(self, start, stop):
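        # Read examples [start, stop) from the current shard in one contiguous
        # np.frombuffer call, then split the flat array at the cumulative size
        # boundaries so each example comes back as its own array.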
        ptr = self._index._pointers[self.__relative_idx(start)]
        sizes = self._index._sizes[self.__relative_idx(start):self.__relative_idx(stop)]
        offsets = list(accumulate(sizes))
        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=sum(sizes), offset=ptr)
        return np.split(np_array, offsets[:-1])

    def __getitem__(self, idx):
        if isinstance(idx, int):
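            # Cycle through shards until the current shard's [start, end) range
            # covers idx, then read `size` items of the index dtype starting at
            # byte offset `ptr` in that shard's memory-mapped .bin buffer.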
            while idx >= self.history[self._state][1] or idx < self.history[self._state][0]:
                self._next_file()
            ptr, size = self._index[self.__relative_idx(idx)]
            return np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
        elif isinstance(idx, slice):
            raise NotImplementedError()

    @property
    def sizes(self):
        return self._index.sizes

    def exists(self, path):
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )
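

# Minimal usage sketch, assuming shard files such as /path/to/data/train_0.bin and
# /path/to/data/train_0.idx already exist (written by a Megatron-style indexed-dataset
# builder) and that `utils.print_rank` also works without torch.distributed being
# initialized. The path prefix and dataset name below are hypothetical.
if __name__ == "__main__":
    dataset = DistributedMMapIndexedDataset(
        path="/path/to/data/",   # hypothetical directory holding <name>_<state>.bin/.idx shards
        name="train",            # hypothetical dataset name prefix
        rank_number=0,
        rank_total=1,
    )
    print(len(dataset))          # total number of examples across all probed shards
    tokens = dataset[0]          # 1-D numpy array with the token ids of the first example
    print(tokens.dtype, tokens.shape)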