UltraChat

Форк
0
/
split_long.py 
110 строк · 3.4 Кб
1
"""
2
Modified based on https://github.com/lm-sys/FastChat/blob/main/fastchat/data/split_long_conversation.py
3
Split long conversations based on certain max length.
4

5
Usage:
6
1. download json data files to `./raw`
7
2. run command below for each file
8
python -u split_long.py --in-file ./raw/input.json --out-file ./processed/output.json --begin 0 --model-name-or-path /path/to/huggingface/llama --max-length 2048
9
"""
10
import argparse
11
from concurrent.futures import ProcessPoolExecutor
12
import json
13
from typing import Dict, Sequence, Optional
14

15
import transformers
16
from tqdm import tqdm
17

18

19
def make_sample(sample, start_idx, end_idx):
    """Build a sub-conversation record from sample["data"][start_idx:end_idx].

    The slice must contain complete (user, assistant) pairs, hence the
    even-length requirement.  The new record's id is the original id
    suffixed with the starting offset so split chunks stay traceable.
    """
    assert (end_idx - start_idx) % 2 == 0
    return {
        "id": f"{sample['id']}_{start_idx}",
        "data": sample["data"][start_idx:end_idx],
    }
25

26

27
# Module-level state read by split_one_sample; populated by split_all()
# before the worker pool maps over the samples.
tokenizer = max_length = None
28

29

30
def split_one_sample(sample):
    """Split one conversation into chunks that fit the token budget.

    Greedily packs consecutive (user, assistant) pairs into a chunk until
    adding the next pair would exceed the module-global ``max_length``,
    then flushes the chunk and starts a new one.  Uses the module-global
    ``tokenizer`` (set by ``split_all``) to measure turn lengths.

    Returns a list of new sample dicts produced by ``make_sample``.
    """
    conversations = sample["data"]
    assert len(conversations) % 2 == 0, f"odd number of turns in: {conversations}"

    # +6 is headroom per turn for role/separator tokens added at training time.
    tokenized_lens = [len(tokenizer(c).input_ids) + 6 for c in conversations]

    new_samples = []
    start_idx = 0
    cur_len = 0
    last = len(conversations) - 2
    for i in range(0, len(conversations), 2):
        pair_len = tokenized_lens[i] + tokenized_lens[i + 1]
        if cur_len + pair_len > max_length:
            # Flush what has accumulated so far.  Skip the empty flush that
            # occurs when a single pair already exceeds max_length (the
            # original emitted an empty sample here and relied on
            # check_content to filter it out later).
            if i > start_idx:
                new_samples.append(make_sample(sample, start_idx, i))
            start_idx = i
            cur_len = 0
            # BUG FIX: the original used `elif` below, so when the overflow
            # fired on the last pair the final (user, assistant) pair was
            # silently dropped.  Emit it as its own chunk instead — this
            # matches how oversized pairs in the middle were already kept.
            if i == last:
                new_samples.append(make_sample(sample, start_idx, i + 2))
        elif i == last:
            new_samples.append(make_sample(sample, start_idx, i + 2))
        cur_len += pair_len

    return new_samples
58

59

60
def split_all(content, begin, end, tokenizer_, max_length_):
    """
    Keep the maximum round of conversations within the max token length constraint.

    Processes content[begin:end] (begin/end may be None for the full range).
    Stashes the tokenizer and max length in module globals so the worker
    processes spawned by the pool can use them in split_one_sample.
    """
    global tokenizer, max_length
    tokenizer, max_length = tokenizer_, max_length_

    selected = content[begin:end]
    splits = []

    with ProcessPoolExecutor() as pool:
        for chunk in tqdm(pool.map(split_one_sample, selected), total=len(selected)):
            splits.extend(chunk)

    return splits
76

77
def check_content(content):
    """Filter out malformed samples.

    Keeps only samples whose "data" turn list is non-empty and has an even
    length (i.e. consists of complete user/assistant pairs).
    """
    return [c for c in content if c["data"] and len(c["data"]) % 2 == 0]
83

84

85
def main(args):
    """Load JSON-lines conversations, split them to fit the token budget,
    and write the result back out as JSON lines.

    Args (from argparse): in_file, out_file, begin, end,
    model_name_or_path, max_length.
    """
    # Input is one JSON object per line (JSON lines), not a single array.
    # BUG FIX: the original leaked the file handle (open() with no close);
    # use a context manager and an explicit encoding.
    with open(args.in_file, "r", encoding="utf-8") as f:
        content = [json.loads(line) for line in f]

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.max_length,
        padding_side="right",
        use_fast=False,
    )
    new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length)
    new_content = check_content(new_content)

    print(f"total: {len(content)}, new: {len(new_content)}")
    # Write back as JSON lines (newline-separated, no trailing newline,
    # matching the original output format).
    with open(args.out_file, "w", encoding="utf-8") as f:
        f.write("\n".join(json.dumps(sample) for sample in new_content))
99

100

101
if __name__ == "__main__":
102
    parser = argparse.ArgumentParser()
103
    parser.add_argument("--in-file", type=str, required=True)
104
    parser.add_argument("--out-file", type=str, default="sharegpt_split.json")
105
    parser.add_argument("--begin", type=int)
106
    parser.add_argument("--end", type=int)
107
    parser.add_argument("--model-name-or-path", type=str, required=True)
108
    parser.add_argument("--max-length", type=int, default=2048)
109
    args = parser.parse_args()
110
    main(args)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.