gpt-neox

indexed_dataset.py · 595 lines · 18.8 KB

# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
#    An empty sentence no longer separates documents.

import os
import shutil
import struct
from functools import lru_cache
from itertools import accumulate

import numpy as np
import torch

from megatron import print_rank_0


def __best_fitting_dtype(vocab_size=None):
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    else:
        return np.int32


def infer_dataset_impl(path):
    if IndexedDataset.exists(path):
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
            if magic == IndexedDataset._HDR_MAGIC:
                return "cached"
            elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
                return "mmap"
            else:
                return None
    else:
        print(f"Dataset does not exist: {path}")
        print(
            "Path should be a basename that both .idx and .bin can be appended to get full filenames."
        )
        return None


def make_builder(out_file, impl, vocab_size=None):
    if impl == "mmap":
        return MMapIndexedDatasetBuilder(
            out_file, dtype=__best_fitting_dtype(vocab_size)
        )
    else:
        return IndexedDatasetBuilder(out_file)


def make_dataset(path, impl, skip_warmup=False):
    if not IndexedDataset.exists(path):
        print(f"Dataset does not exist: {path}")
        print(
            "Path should be a basename that both .idx and .bin can be appended to get full filenames."
        )
        return None
    if impl == "infer":
        impl = infer_dataset_impl(path)
    if impl == "cached" and IndexedDataset.exists(path):
        return IndexedCachedDataset(path)
    elif impl == "mmap" and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path, skip_warmup)
    print(f"Unknown dataset implementation: {impl}")
    return None


def dataset_exists(path, impl):
    if impl == "mmap":
        return MMapIndexedDataset.exists(path)
    else:
        return IndexedDataset.exists(path)
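

# Illustrative sketch (not part of the original file): the usual entry point is
# make_dataset with a path prefix typically produced by a preprocessing step;
# the prefix below is hypothetical.
def _example_open_dataset(prefix="data/mydataset_text_document"):
    if not dataset_exists(prefix, impl="mmap"):
        return None
    ds = make_dataset(prefix, impl="mmap", skip_warmup=True)
    print(f"{len(ds)} sentences across {len(ds.doc_idx) - 1} documents")
    return ds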


def read_longs(f, n):
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    f.write(np.array(a, dtype=np.int64))


dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.float64,
    8: np.uint16,
}


def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


def index_file_path(prefix_path):
    return prefix_path + ".idx"


def data_file_path(prefix_path):
    return prefix_path + ".bin"


def create_doc_idx(sizes):
    doc_idx = [0]
    for i, s in enumerate(sizes):
        if s == 0:
            doc_idx.append(i + 1)
    return doc_idx
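

# Illustrative sketch (not part of the original file): create_doc_idx treats a
# zero-length entry in `sizes` as an end-of-document sentinel and records the
# index of the sentence that follows it.
def _example_create_doc_idx():
    # sentences of length 3 and 5 form the first document, length 2 the second;
    # the trailing entry marks the end of the last document.
    assert create_doc_idx([3, 5, 0, 2, 0]) == [0, 3, 5]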


class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset"""

    _HDR_MAGIC = b"TNTIDX\x00\x00"

    def __init__(self, path):
        super().__init__()
        self.path = path
        self.data_file = None
        self.read_index(path)

    def read_index(self, path):
        with open(index_file_path(path), "rb") as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                "Index file doesn't match expected format. "
                "Make sure that --dataset-impl is configured properly."
            )
            version = f.read(8)
            assert struct.unpack("<Q", version) == (1,)
            code, self.element_size = struct.unpack("<QQ", f.read(16))
            self.dtype = dtypes[code]
            self._len, self.s = struct.unpack("<QQ", f.read(16))
            self.doc_count = struct.unpack("<Q", f.read(8))
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)
            self.doc_idx = read_longs(f, self.doc_count)

    def read_data(self, path):
        self.data_file = open(data_file_path(path), "rb", buffering=0)

    def check_index(self, i):
        if i < 0 or i >= self._len:
            raise IndexError("index out of range")

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if not self.data_file:
            self.read_data(self.path)
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            return a
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
            size = sum(sizes)
            a = np.empty(size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[start] * self.element_size)
            self.data_file.readinto(a)
            offsets = list(accumulate(sizes))
            sents = np.split(a, offsets[:-1])
            return sents

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(
            data_file_path(path)
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory
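

# Layout of the legacy ".idx" index file (a descriptive note, not part of the
# original file), as read by IndexedDataset.read_index above and written by
# IndexedDatasetBuilder.finalize further below:
#   bytes  0-7   magic b"TNTIDX\x00\x00"
#   bytes  8-15  version, little-endian uint64 (always 1)
#   bytes 16-31  dtype code and element size (<QQ)
#   bytes 32-47  number of items and total number of sizes (<QQ)
#   bytes 48-55  number of doc_idx entries (<Q)
#   followed by four int64 arrays: dim_offsets, data_offsets, sizes, doc_idx.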


class IndexedCachedDataset(IndexedDataset):
    def __init__(self, path):
        super().__init__(path)
        self.cache = None
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            a = self.cache[ptx : ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            ptx = self.cache_index[i]
            np.copyto(a, self.cache[ptx : ptx + a.size])
            return a
        elif isinstance(idx, slice):
            # Hack just to make this work, can optimize later if necessary
            sents = []
            for i in range(*idx.indices(len(self))):
                sents.append(self[i])
            return sents
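

# Illustrative sketch (not part of the original file): with the "cached"
# implementation, prefetch() pulls the requested sentences into one in-memory
# buffer so later __getitem__ calls avoid touching the .bin file. The path is
# hypothetical.
def _example_cached_prefetch(prefix="data/mydataset_text_document"):
    ds = IndexedCachedDataset(prefix)
    wanted = [0, 1, 2]
    ds.prefetch(wanted)  # reads the data file once, then closes it
    return [ds[i] for i in wanted]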


class IndexedDatasetBuilder(object):
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float32: 4,
        np.float64: 8,
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, "wb")
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]
        self.doc_idx = [0]

    def add_item(self, np_array):
        assert isinstance(np_array, np.ndarray) and np_array.dtype == self.dtype
        bytes = self.out_file.write(np_array)
        self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
        for s in np_array.shape:
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(np_array.shape))

    def end_document(self):
        self.doc_idx.append(len(self.sizes))

    def merge_file_(self, another_file):
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        begin = self.data_offsets[-1]
        for offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        with open(data_file_path(another_file), "rb") as f:
            while True:
                data = f.read(1024)
                if data:
                    self.out_file.write(data)
                else:
                    break

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, "wb")
        index.write(b"TNTIDX\x00\x00")
        index.write(struct.pack("<Q", 1))
        index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
        index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
        index.write(struct.pack("<Q", len(self.doc_idx)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        write_longs(index, self.doc_idx)
        index.close()
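

# Illustrative sketch (not part of the original file): building a legacy
# ("cached") dataset pair. File names and token ids are hypothetical.
def _example_build_cached(prefix="tmp_dataset"):
    builder = IndexedDatasetBuilder(data_file_path(prefix), dtype=np.int32)
    for doc in [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9]]]:  # two documents
        for sentence in doc:
            builder.add_item(np.array(sentence, dtype=np.int32))
        builder.end_document()
    builder.finalize(index_file_path(prefix))
    return IndexedCachedDataset(prefix)  # read it back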


def _warmup_mmap_file(path):
    with open(path, "rb") as stream:
        while stream.read(100 * 1024 * 1024):
            pass


class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, "wb")

                    # Write Magic string so we can check the file format when opening it again.
                    self._file.write(cls._HDR_MAGIC)
                    # Write version number
                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", 1))
                    # Little endian unsigned 8 Bit integer
                    self._file.write(struct.pack("<B", code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    pointers = np.zeros(len(sizes), dtype=np.int64)
                    sizes = np.array(sizes, dtype=np.int64)

                    np.cumsum(sizes[:-1], out=pointers[1:])
                    pointers = pointers * dtype().itemsize
                    return pointers

                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)

                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", len(sizes)))
                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order="C"))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                # Little endian unsigned 64 Bit integer
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                # Little endian unsigned 8 Bit integer
                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print_rank_0("    warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print_rank_0("    reading sizes...")
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            print_rank_0("    reading pointers...")
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )
            print_rank_0("    reading document index...")
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print_rank_0("    warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print_rank_0("    creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(
            data_file_path(self._path), mode="r", order="C"
        )
        print_rank_0("    creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
            )
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(
                self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
            )
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(
            self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
        )
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and os.path.exists(
            data_file_path(path)
        )
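

# Illustrative sketch (not part of the original file): get() allows reading a
# window of tokens from one item without materialising the whole sentence,
# which is useful when slicing fixed-length samples. The path is hypothetical.
def _example_partial_read(prefix="data/mydataset_text_document"):
    ds = MMapIndexedDataset(prefix, skip_warmup=True)
    # tokens 10..41 of item 0, as a zero-copy view into the mmap'd .bin file
    return ds.get(0, offset=10, length=32)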


class MMapIndexedDatasetBuilder(object):
    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, "wb")
        self._dtype = dtype
        self._sizes = []
        self._doc_idx = [0]

    @property
    def dtype(self):
        return self._dtype

    def add_item(self, np_array):
        assert isinstance(np_array, np.ndarray) and np_array.dtype == self.dtype
        self._data_file.write(np_array.tobytes(order="C"))
        self._sizes.append(np_array.size)

    def end_document(self):
        self._doc_idx.append(len(self._sizes))

    def merge_file_(self, another_file):
        # Concatenate index
        index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert index.dtype == self._dtype

        for size in index.sizes:
            self._sizes.append(size)

        # Concatenate data
        with open(data_file_path(another_file), "rb") as f:
            shutil.copyfileobj(f, self._data_file)

    def finalize(self, index_file):
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes, self._doc_idx)
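

# Illustrative end-to-end sketch (not part of the original file): build an
# "mmap" dataset pair and read it back. The prefix is hypothetical;
# make_builder picks np.uint16 here because vocab_size < 65500.
def _example_build_and_read_mmap(prefix="tmp_mmap_dataset", vocab_size=50257):
    builder = make_builder(data_file_path(prefix), impl="mmap", vocab_size=vocab_size)
    for sentence in ([1, 2, 3], [4, 5], [6, 7, 8, 9]):
        builder.add_item(np.array(sentence, dtype=builder.dtype))
    builder.end_document()  # one document of three sentences
    builder.finalize(index_file_path(prefix))

    ds = MMapIndexedDataset(prefix, skip_warmup=True)
    assert len(ds) == 3 and list(ds[1]) == [4, 5]
    return ds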