from numerize.numerize import numerize
from data_utils.indexed_dataset import make_builder
from transformers import AutoTokenizer
from arguments import get_args
def __init__(self, args):
    """Hold the run configuration; tokenizer loading is deferred to
    initializer() so it happens once inside each pool worker.

    NOTE(review): the original body was lost in extraction; `self.args = args`
    is reconstructed from the `self.args.model_path` read in initializer().
    """
    self.args = args
def initializer(self):
    """Per-worker setup hook for multiprocessing.Pool.

    Loads the tokenizer once per worker process and caches it on the
    Encoder class so encode() can reach it without re-loading.
    """
    tokenizer = AutoTokenizer.from_pretrained(self.args.model_path)
    Encoder.tokenizer = tokenizer
def encode(self, line):
    """Tokenize one corpus line.

    Returns a tuple ``(token_ids, n_chars)`` where ``token_ids`` ends with
    the tokenizer's EOS id and ``n_chars`` is the character count of the
    unescaped text (used upstream for throughput reporting).
    """
    # "<@x(x!>" is the corpus's escaped-newline marker — restore real newlines.
    text = line.replace("<@x(x!>", "\n")
    token_ids = Encoder.tokenizer.encode(text, add_special_tokens=False)
    token_ids.append(Encoder.tokenizer.eos_token_id)
    return token_ids, len(text)
# Suffix the output directory with a human-readable train size (e.g. "1M")
# so different-sized runs don't collide.
args.processed_data_dir = os.path.join(
    args.processed_data_dir, numerize(args.train_num))
os.makedirs(args.processed_data_dir, exist_ok=True)

# Raw corpus: one document per line.
file_name = os.path.join(args.data_dir, "data.txt")
fin = open(file_name, "r", encoding="utf-8")
# Fan the tokenization out across worker processes; each worker loads its
# own tokenizer via encoder.initializer.
encoder = Encoder(args)
pool = multiprocessing.Pool(
    processes=args.data_process_workers, initializer=encoder.initializer)
# Results may arrive out of order — fine, since instances are packed from a
# flat token stream anyway.
encoded_docs = pool.imap_unordered(encoder.encode, fin, chunksize=50)

proc_start = time.time()
total_bytes_processed = 0
train_bin_file = os.path.join(args.processed_data_dir, f"train_{0}.bin")
train_idx_file = os.path.join(args.processed_data_dir, f"train_{0}.idx")
valid_bin_file = os.path.join(args.processed_data_dir, f"valid_{0}.bin")
valid_idx_file = os.path.join(args.processed_data_dir, f"valid_{0}.idx")

# Token-id storage width: uint16 holds ids < 2**16; qwen's vocabulary is
# larger, so it needs uint32. As written, the uint32 builders were created
# unconditionally and overwrote the uint16 ones (the `else:` was lost) —
# the dtype is now chosen once, up front.
token_dtype = np.uint16 if args.model_type != "qwen" else np.uint32
train_binary_builder = make_builder(train_bin_file, impl="mmap", dtype=token_dtype)
valid_binary_builder = make_builder(valid_bin_file, impl="mmap", dtype=token_dtype)
# Pack the flat token stream into fixed-length instances of args.max_length.
# The first args.dev_num instances feed the validation split; the rest go to
# train, until args.train_num training instances have been written.
buffer = []    # tokens carried across document boundaries
inst_num = 0   # instances emitted so far (valid + train)
for lid, (input_ids, bytes_processed) in enumerate(encoded_docs):
    total_bytes_processed += bytes_processed
    buffer.extend(input_ids)
    while len(buffer) >= args.max_length:
        inst = buffer[:args.max_length]
        buffer = buffer[args.max_length:]
        if inst_num < args.dev_num:
            valid_binary_builder.add_item(torch.IntTensor(inst))
        else:
            train_binary_builder.add_item(torch.IntTensor(inst))
        inst_num += 1

    # NOTE(review): the periodic-log guard was lost in extraction; the
    # 10000-document cadence is reconstructed — confirm against upstream.
    if lid % 10000 == 0:
        current = time.time()
        elapsed = current - proc_start
        mbs = total_bytes_processed / elapsed / 1024 / 1024
        print(f"Processed {lid} documents. {inst_num} instances.",
              f"({lid/elapsed} docs/s, {mbs} MB/s).")

    # Stop once enough training instances exist (dev instances don't count).
    if inst_num - args.dev_num >= args.train_num:
        break
# Flush both builders: write the .idx index files alongside the .bin data
# accumulated by add_item().
train_binary_builder.finalize(train_idx_file)
valid_binary_builder.finalize(valid_idx_file)
if __name__ == '__main__':