lmops
215 строк · 7.0 Кб
1# coding=utf-8
2# Copyright 2020 The OpenBMB team. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import os
17import struct
18import shutil
19
20from itertools import accumulate
21
22import numpy as np
23import torch
24import torch.distributed as dist
25from utils import print_rank, save_rank
26
27
# Table mapping the one-byte dtype code read from an index file header to the
# NumPy scalar type used for the tokens. Codes are consecutive and start at 1,
# so the position in the tuple below determines the code.
dtypes = {
    type_code: np_type
    for type_code, np_type in enumerate(
        (
            np.uint8,    # 1
            np.int8,     # 2
            np.int16,    # 3
            np.int32,    # 4
            np.int64,    # 5
            np.float32,  # 6
            np.double,   # 7
            np.uint16,   # 8
            np.uint32,   # 9
        ),
        start=1,
    )
}
39
40
def code(dtype):
    """Return the integer code under which ``dtype`` is registered in ``dtypes``.

    The lookup uses ``==``, so anything that compares equal to a registered
    NumPy type is accepted. The first matching entry, in the table's
    iteration order, wins.

    Raises:
        ValueError: if no entry of ``dtypes`` compares equal to ``dtype``.
    """
    for type_code, candidate in dtypes.items():
        if candidate == dtype:
            return type_code
    raise ValueError(dtype)
46
47
def index_file_path(prefix_path):
    """Return the index file name for a shard prefix (``prefix_path`` + '.idx')."""
    suffix = '.idx'
    return prefix_path + suffix
50
51
def data_file_path(prefix_path):
    """Return the data file name for a shard prefix (``prefix_path`` + '.bin')."""
    suffix = '.bin'
    return prefix_path + suffix
54
55
class DistributedMMapIndexedDataset(torch.utils.data.Dataset):
    """Map-style dataset over a numbered sequence of memory-mapped shards.

    Shard ``k`` lives at the prefix ``path + name + "_k"`` and consists of an
    index file (prefix + '.idx') and a data file (prefix + '.bin'). Shards are
    discovered at construction time by probing ``k = 0, 1, ...`` until one is
    missing. Global sample indices run contiguously across shards; only one
    shard is mapped at a time and ``__getitem__`` switches shards on demand.

    NOTE: switching shards closes the previous memory map explicitly. Arrays
    handed out earlier by ``__getitem__`` are views into that map, so callers
    that keep samples across a shard switch should copy them first.
    """

    class Index(object):
        """Reader for one shard's index file.

        File layout, as parsed by ``__init__`` (little-endian):
        9-byte magic, uint64 version (must be 1), uint8 dtype code,
        uint64 sample count, uint64 document count, then ``int32`` sizes
        (one per sample), ``int64`` pointers (one per sample) and ``int64``
        document indices (one per document).
        """

        _HDR_MAGIC = b'MMIDIDX\x00\x00'

        def __init__(self, path):
            """Parse the header of ``path`` and map its arrays.

            Raises:
                AssertionError: if the magic bytes or the version do not match.
                KeyError: if the dtype code is not present in ``dtypes``.
            """
            # Assigned before anything can fail so that __del__ always finds
            # the attribute, even when header validation raises.
            self._bin_buffer_mmap = None

            with open(path, 'rb') as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    'Index file doesn\'t match expected format. '
                    'Make sure that --dataset-impl is configured properly.'
                )
                version = struct.unpack('<Q', stream.read(8))
                assert (1,) == version

                dtype_code, = struct.unpack('<B', stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack('<Q', stream.read(8))[0]
                self._doc_count = struct.unpack('<Q', stream.read(8))[0]
                # Everything after the header is array payload.
                offset = stream.tell()

            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            self._sizes = np.frombuffer(
                self._bin_buffer,
                dtype=np.int32,
                count=self._len,
                offset=offset)
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes)
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes)

        def __del__(self):
            # getattr guards against instances whose __init__ never ran.
            mmap_array = getattr(self, '_bin_buffer_mmap', None)
            if mmap_array is not None:
                mmap_array._mmap.close()
                self._bin_buffer_mmap = None

        @property
        def dtype(self):
            """NumPy scalar type of the tokens in the matching data file."""
            return self._dtype

        @property
        def sizes(self):
            """``int32`` array holding one size per sample."""
            return self._sizes

        @property
        def doc_idx(self):
            """``int64`` array of document indices."""
            return self._doc_idx

        def __getitem__(self, i):
            """Return ``(pointer, size)`` for sample ``i`` of this shard."""
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            """Number of samples recorded in the header."""
            return self._len

    def __init__(self, path, name, rank_number, rank_total, cache=None):
        """Discover all shards and map the first one.

        Args:
            path: prefix prepended verbatim to ``name`` (plain string
                concatenation, no separator is inserted).
            name: dataset name; shard ``k`` is ``path + name + "_k"``.
            rank_number: stored on the instance; not otherwise used here.
            rank_total: stored on the instance and forwarded to
                ``_probe_data_path``.
            cache: optional directory; it is created if given. It is stored
                and forwarded to ``_do_init``, which does not read it.

        Raises:
            FileNotFoundError: if shard 0 does not exist.
        """
        super().__init__()

        self._path = path
        self._name = name
        self._state = 0
        if cache is not None:
            self._cache = cache
            os.makedirs(self._cache, exist_ok=True)
        else:
            self._cache = None
        self._rank_total = rank_total
        self._rank_number = rank_number
        self._index = None
        self._bin_buffer = None
        self._bin_buffer_mmap = None
        self.max_state, self.history = self._probe_data_path(
            self._path, self._name, self._rank_total)
        if self.max_state == 0:
            # Fail here with a clear message instead of later inside open().
            raise FileNotFoundError(
                f"No dataset shard found for prefix {path + name}_0 "
                "(both the .idx and the .bin file are required)")
        self.total_length = self.history[self.max_state - 1][1]

        self._do_init(self._path, self._name, self._cache, self._state)

    def _probe_data_path(self, path, name, rank_total):
        """Count consecutive shards and record their global index ranges.

        Returns:
            ``(n_shards, history)`` where ``history[k]`` is the half-open
            range ``(start, stop)`` of global indices served by shard ``k``.
            ``history[-1]`` is the sentinel ``(0, 0)``.
        """
        print_rank("Probing Dataset")

        state = 0
        history = {-1: (0, 0)}
        for state in range(np.iinfo(np.int32).max):
            source_file = path + name + f"_{state}"
            if self.exists(source_file):
                index = self.Index(index_file_path(source_file))
                history[state] = (history[state-1][1], history[state-1][1] + len(index))
            else:
                break

        print_rank(f"Probing end. Max data state {state}, total length {history[state-1][1]}")

        return state, history

    def __getstate__(self):
        """Return picklable state: all attributes except the open mappings."""
        state = self.__dict__.copy()
        # Memory maps and the views built on them cannot be pickled; they are
        # re-created from the stored path, name and shard number.
        state['_index'] = None
        state['_bin_buffer'] = None
        state['_bin_buffer_mmap'] = None
        return state

    def __setstate__(self, state):
        """Restore attributes and re-open the shard that was active."""
        self.__dict__.update(state)
        self._do_init(self._path, self._name, self._cache, self._state)

    def _release(self):
        """Close the currently mapped shard, if any, and reset the handles."""
        mmap_array = getattr(self, '_bin_buffer_mmap', None)
        self._bin_buffer = None
        if mmap_array is not None:
            mmap_array._mmap.close()
        # Reset to None (rather than deleting) so the attributes stay defined
        # even if re-opening fails afterwards.
        self._bin_buffer_mmap = None
        self._index = None

    def _do_init(self, path, name, cache, state):
        """Make shard ``state`` the active one (``cache`` is not used)."""
        self._release()

        self._state = state

        source_file = path + name + f"_{self._state}"
        self._index = self.Index(index_file_path(source_file))
        self._bin_buffer_mmap = np.memmap(data_file_path(source_file), mode='r', order='C')
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._release()

    def __len__(self):
        """Total number of samples across all shards."""
        return self.total_length

    def _next_file(self):
        """Advance to the next shard, wrapping around to shard 0."""
        self._state += 1
        if self._state >= self.max_state:
            self._state = 0
        self._do_init(self._path, self._name, self._cache, self._state)

    def __relative_idx(self, idx):
        """Translate a global index into an index within the active shard."""
        res = idx - self.history[self._state][0]
        return res

    def __slice_item(self, start, stop):
        """Return samples ``start..stop`` (global indices) of the active shard.

        Not called anywhere in this class; slices passed to ``__getitem__``
        raise ``NotImplementedError``.
        """
        ptr = self._index._pointers[self.__relative_idx(start)]
        sizes = self._index._sizes[self.__relative_idx(start):self.__relative_idx(stop)]
        offsets = list(accumulate(sizes))
        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=sum(sizes), offset=ptr)
        return np.split(np_array, offsets[:-1])

    def __getitem__(self, idx):
        """Return sample ``idx`` as a NumPy view over the mapped data file.

        Args:
            idx: global sample index, a Python or NumPy integer in
                ``[0, len(self))``.

        Raises:
            NotImplementedError: if ``idx`` is a slice.
            IndexError: if ``idx`` is outside ``[0, len(self))``.
            TypeError: if ``idx`` is neither an integer nor a slice.
        """
        if isinstance(idx, slice):
            raise NotImplementedError()
        if not isinstance(idx, (int, np.integer)):
            raise TypeError(
                f"dataset indices must be integers, not {type(idx).__name__}")

        idx = int(idx)
        # Without this check the shard search below would never terminate.
        if idx < 0 or idx >= self.total_length:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {self.total_length}")

        while idx >= self.history[self._state][1] or idx < self.history[self._state][0]:
            self._next_file()
        ptr, size = self._index[self.__relative_idx(idx)]
        return np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)

    @property
    def sizes(self):
        """Per-sample sizes of the active shard only."""
        return self._index.sizes

    def exists(self, path):
        """Return True if both the index and the data file exist for ``path``."""
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )
216