lmops

process_data_pretrain.py · 101 lines · 3.7 KB
import multiprocessing
import os
import time
import torch
import sys
from numerize.numerize import numerize
import numpy as np
from data_utils.indexed_dataset import make_builder
from transformers import AutoTokenizer
from arguments import get_args


# 1. Implement an Encoder: feed it one line of input data and it returns the tokenized result.
class Encoder(object):
    def __init__(self, args):
        self.args = args

    def initializer(self):
        # Each worker process loads its own tokenizer once; it is stored as a
        # class attribute so that encode() can reach it after forking.
        Encoder.tokenizer = AutoTokenizer.from_pretrained(self.args.model_path)

    def encode(self, line):
        # "<@x(x!>" is the serialized newline marker: restore real newlines,
        # tokenize, then append eos to mark the document boundary.
        line = line.replace("<@x(x!>", "\n")
        token_ids = Encoder.tokenizer.encode(line, add_special_tokens=False) + [Encoder.tokenizer.eos_token_id]
        # len(line) is a character count; the MB/s log in main() treats it as bytes.
        return token_ids, len(line)
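# Illustrative sketch (not part of the original file): how a multi-line
# document round-trips through the "<@x(x!>" marker, as inferred from encode()
# above -- data.txt stores one document per physical line:
#
#     doc = "first line\nsecond line"
#     stored = doc.replace("\n", "<@x(x!>")         # one physical line in data.txt
#     restored = stored.replace("<@x(x!>", "\n")    # recovered inside encode()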


def main():
    args = get_args()

    # Keep the data processed for each training-set size in its own
    # subdirectory, named with a compact form of the count.
    args.processed_data_dir = os.path.join(args.processed_data_dir, numerize(args.train_num))
    os.makedirs(args.processed_data_dir, exist_ok=True)
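    # Note (added commentary): numerize renders large counts compactly, e.g.
    # numerize(1000000) == "1M", so the processed data for a one-million-instance
    # training set would land in <processed_data_dir>/1M.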

    file_name = os.path.join(args.data_dir, "data.txt")
    fin = open(file_name, "r", encoding="utf-8")
    # The encoder uses the tokenizer to encode the data.
    encoder = Encoder(args)

    # 2. Map all the data with the Encoder, with the help of multiprocessing.
    pool = multiprocessing.Pool(processes=args.data_process_workers, initializer=encoder.initializer)
    encoded_docs = pool.imap_unordered(encoder.encode, fin, chunksize=50)
    proc_start = time.time()
    total_bytes_processed = 0
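    # Note (added commentary): imap_unordered hands lines of `fin` to the
    # workers in chunks of 50 and yields results as soon as any worker
    # finishes, without preserving input order. Order does not matter here,
    # because documents are concatenated and re-cut into fixed-length
    # instances below.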

    # 3. The `indexed_dataset` tool compresses the tokenized data into the binary format `bin_file`.
    # It also generates a small `idx_file` that stores the meta information needed to decode `bin_file`.
    # The shard index is fixed at 0.
    train_bin_file = os.path.join(args.processed_data_dir, f"train_{0}.bin")
    train_idx_file = os.path.join(args.processed_data_dir, f"train_{0}.idx")

    valid_bin_file = os.path.join(args.processed_data_dir, f"valid_{0}.bin")
    valid_idx_file = os.path.join(args.processed_data_dir, f"valid_{0}.idx")

    if args.model_type != "qwen":
        train_binary_builder = make_builder(train_bin_file, impl="mmap", dtype=np.uint16)
        valid_binary_builder = make_builder(valid_bin_file, impl="mmap", dtype=np.uint16)
    else:
        train_binary_builder = make_builder(train_bin_file, impl="mmap", dtype=np.uint32)
        valid_binary_builder = make_builder(valid_bin_file, impl="mmap", dtype=np.uint32)
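    # Note (added commentary): np.uint16 can only represent token ids up to
    # 65535, which fits most vocabularies; Qwen tokenizers use a much larger
    # vocabulary (on the order of 150K entries), so their ids need np.uint32.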

    # Put the tokenized data into the binary builders.
    buffer = []
    inst_num = 0
    for lid, (input_ids, bytes_processed) in enumerate(encoded_docs):
        total_bytes_processed += bytes_processed
        if input_ids is None:
            continue

        # Cut the running token stream into instances of exactly max_length tokens.
        buffer.extend(input_ids)
        while len(buffer) >= args.max_length:
            inst = buffer[:args.max_length]
            buffer = buffer[args.max_length:]

            # The first args.dev_num instances form the validation split.
            if inst_num < args.dev_num:
                valid_binary_builder.add_item(torch.IntTensor(inst))
            else:
                train_binary_builder.add_item(torch.IntTensor(inst))

            inst_num += 1

        if lid % 10000 == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed / elapsed / 1024 / 1024
            print(f"Processed {lid} documents. {inst_num} instances.",
                  f"({lid / elapsed:.2f} docs/s, {mbs:.2f} MB/s).",
                  file=sys.stderr)

        # Stop once enough training instances have been collected.
        if inst_num - args.dev_num >= args.train_num:
            break
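    # Note (added commentary): documents are separated only by the eos token
    # appended in encode(), so a single document may be split across two
    # consecutive instances. This is the usual packed-pretraining layout: it
    # trades exact document boundaries for padding-free, fixed-length batches.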

    # Finish compressing the tokenized data into `bin_file` and write the meta information to `idx_file`.
    train_binary_builder.finalize(train_idx_file)
    valid_binary_builder.finalize(valid_idx_file)

    # Close the multiprocessing pool and the input file.
    pool.close()
    pool.join()
    fin.close()


if __name__ == '__main__':
    main()
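
# Example invocation (hypothetical flag names, inferred from the attributes
# this script reads off get_args(); check arguments.py for the real spellings):
#
#     python3 process_data_pretrain.py \
#         --data-dir ./data/raw \
#         --processed-data-dir ./data/processed \
#         --model-path <path_to_tokenizer> \
#         --model-type llama \
#         --max-length 1024 \
#         --train-num 1000000 \
#         --dev-num 1000 \
#         --data-process-workers 32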
