"""
Modified based on https://github.com/lm-sys/FastChat/blob/main/fastchat/data/split_long_conversation.py

Split long conversations based on certain max length.

Usage:
1. download json data files to `./raw`
2. run command below for each file
python -u split_long.py --in-file ./raw/input.json --out-file ./processed/output.json --begin 0 --model-name-or-path /path/to/huggingface/llama --max-length 2048
"""
import argparse
import json
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, Optional, Sequence

import transformers
from tqdm import tqdm
def make_sample(sample, start_idx, end_idx):
    """Build a new (shorter) sample from a slice of an existing conversation.

    Args:
        sample: dict with an "id" string and a "data" list of alternating
            human/assistant turns.
        start_idx: first turn to keep (inclusive).
        end_idx: one past the last turn to keep (exclusive).

    Returns:
        A new sample dict whose id is suffixed with the slice start so split
        pieces of one conversation stay distinguishable.
    """
    # Conversations are human/assistant pairs, so every slice must be even.
    assert (end_idx - start_idx) % 2 == 0
    return {
        "id": sample["id"] + "_" + str(start_idx),
        "data": sample["data"][start_idx:end_idx],
    }
# Module-level state shared with the worker processes spawned by split_all();
# split_all() assigns both before any worker calls split_one_sample().
tokenizer = max_length = None
def split_one_sample(sample):
    """Split one conversation into pieces that fit within ``max_length`` tokens.

    Greedily packs consecutive human/assistant round pairs until adding the
    next pair would exceed the global ``max_length``, then starts a new piece.

    Args:
        sample: dict with "id" and "data" (even-length list of turns).

    Returns:
        List of sample dicts produced by make_sample().
    """
    tokenized_lens = []
    conversations = sample["data"]
    # Turns must come in complete human/assistant pairs.
    assert len(conversations) % 2 == 0, print(conversations)
    for c in conversations:
        # +6 leaves headroom for role/separator tokens added at training time
        # — TODO confirm against the training-side prompt template.
        length = len(tokenizer(c).input_ids) + 6
        tokenized_lens.append(length)

    start_idx = 0
    cur_len = 0
    new_samples = []
    # Walk the conversation two turns (one round) at a time.
    for i in range(0, len(conversations), 2):
        tmp_len = tokenized_lens[i] + tokenized_lens[i + 1]
        if cur_len + tmp_len > max_length:
            # Current piece is full: emit it and start a new piece at round i.
            new_samples.append(make_sample(sample, start_idx, i))
            start_idx = i
            cur_len = 0
        elif i == len(conversations) - 2:
            # Last round and it still fits: emit the final piece.
            new_samples.append(make_sample(sample, start_idx, i + 2))
        cur_len += tmp_len
    return new_samples
def split_all(content, begin, end, tokenizer_, max_length_):
    """
    Keep the maximum round of conversations within the max token length constraint

    Args:
        content: list of conversation samples to split.
        begin, end: slice bounds selecting which samples to process
            (either may be None to keep the respective end open).
        tokenizer_: HuggingFace tokenizer used to measure turn lengths.
        max_length_: token budget per split piece.

    Returns:
        Flat list of split samples from all processed conversations.
    """
    # Publish tokenizer/max_length as module globals so the forked workers
    # inherit them (split_one_sample reads them as globals).
    global tokenizer, max_length
    tokenizer = tokenizer_
    max_length = max_length_

    content = content[begin:end]
    new_content = []

    with ProcessPoolExecutor() as executor:
        # executor.map preserves input order; each result is a list of pieces.
        for result in tqdm(executor.map(split_one_sample, content), total=len(content)):
            new_content.extend(result)

    return new_content
def check_content(content):
    """Filter out malformed samples.

    Keeps only samples whose "data" list is non-empty and has an even number
    of turns (complete human/assistant pairs).

    Args:
        content: list of sample dicts.

    Returns:
        New list containing only the valid samples, order preserved.
    """
    new_content = []
    for c in content:
        if len(c["data"]) > 0 and len(c["data"]) % 2 == 0:
            new_content.append(c)
    return new_content
# NOTE(review): the `def` line for this body was lost in the corrupted span
# before it; the original line numbering (this body precedes the argparse
# block) shows it was a function taking the parsed args — reconstructed here
# as `main(args)`. Confirm the call site at the end of the file matches.
def main(args):
    """Read jsonl conversations, split them to fit the token budget, write jsonl.

    Args:
        args: parsed CLI namespace (in_file, out_file, begin, end,
            model_name_or_path, max_length).
    """
    # Input is one JSON object per line (jsonl), not a single JSON array.
    content = [json.loads(l) for l in open(args.in_file, "r")]
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.max_length,
    )
    new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length)
    new_content = check_content(new_content)

    print(f"total: {len(content)}, new: {len(new_content)}")
    with open(args.out_file, "w") as f:
        # Output is jsonl as well: one sample per line.
        f.writelines("\n".join([json.dumps(l) for l in new_content]))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-file", type=str, required=True)
    parser.add_argument("--out-file", type=str, default="sharegpt_split.json")
    # --begin/--end default to None, i.e. process the whole file.
    parser.add_argument("--begin", type=int)
    parser.add_argument("--end", type=int)
    parser.add_argument("--model-name-or-path", type=str, required=True)
    parser.add_argument("--max-length", type=int, default=2048)
    args = parser.parse_args()