moe-llava

Форк
0
/
convert_sqa_to_llava.py 
88 строк · 2.9 Кб
1
import json
2
import os
3
import fire
4
import re
5
from convert_sqa_to_llava_base_prompt import build_prompt_chatbot
6

7

8
def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"):
    """Convert one split of the problem set into a LLaVA conversation file.

    Reads ``pid_splits.json`` and ``problems.json`` from ``base_dir``, builds
    the (prompt, answer) pairs with ``build_prompt_chatbot`` and writes the
    result to ``<base_dir>/llava_<split>_<prompt_format>.json``.

    Args:
        base_dir: Directory holding the two input JSON files; the output
            file is written into the same directory.
        split: Key of ``pid_splits.json`` selecting the problem ids to use.
        prompt_format: Prompt layout name forwarded unchanged to
            ``build_prompt_chatbot``.

    Raises:
        KeyError: If ``split`` is not a key of ``pid_splits.json``.
        OSError: If an input file cannot be read or the output cannot be
            written.
    """
    question_prefix = 'Question: '
    answer_prefix = 'Answer: '

    # Context managers close the input files deterministically; JSON is
    # decoded as UTF-8 instead of whatever the platform locale defaults to.
    with open(os.path.join(base_dir, "pid_splits.json"), encoding="utf-8") as f:
        split_indices = json.load(f)[split]
    with open(os.path.join(base_dir, "problems.json"), encoding="utf-8") as f:
        problems = json.load(f)

    split_problems = build_prompt_chatbot(
        problems, split_indices, prompt_format,
        use_caption=False, is_test=False)

    target_format = []
    for prob_id, (prompt, answer) in split_problems.items():
        # Drop only the leading marker. str.replace() would also delete any
        # later occurrence of the same text inside the question or answer.
        if prompt.startswith(question_prefix):
            prompt = prompt[len(question_prefix):]
        if answer.startswith(answer_prefix):
            answer = answer[len(answer_prefix):]

        sample = {"id": prob_id}
        image = problems[prob_id]['image']
        if image is not None:
            # The image path is recorded relative to a per-problem directory
            # and the prompt gets the image placeholder appended.
            sample["image"] = os.path.join(prob_id, image)
            prompt = f"{prompt}\n<image>"
        sample["conversations"] = [
            {'from': 'human', 'value': f"{prompt}"},
            {'from': 'gpt', 'value': f"{answer}"},
        ]
        target_format.append(sample)

    print(f'Number of samples: {len(target_format)}')

    out_path = os.path.join(base_dir, f"llava_{split}_{prompt_format}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(target_format, f, indent=2)
47

48

49
def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"):
    """Convert one split of the problem set into an instruction JSONL file.

    Reads ``pid_splits.json`` and ``problems.json`` from ``base_dir``, builds
    the (prompt, answer) pairs with ``build_prompt_chatbot`` and writes one
    JSON object per line to
    ``<base_dir>/scienceqa_<split>_<prompt_format>.jsonl``.

    Args:
        base_dir: Directory holding the two input JSON files; the output
            file is written into the same directory.
        split: Key of ``pid_splits.json`` selecting the problem ids to use.
        prompt_format: Prompt layout name forwarded unchanged to
            ``build_prompt_chatbot``.

    Raises:
        KeyError: If ``split`` is not a key of ``pid_splits.json``.
        OSError: If an input file cannot be read or the output cannot be
            written.
    """
    question_prefix = 'Question: '
    answer_prefix = 'Answer: '

    # Context managers close the input files deterministically; JSON is
    # decoded as UTF-8 instead of whatever the platform locale defaults to.
    with open(os.path.join(base_dir, "pid_splits.json"), encoding="utf-8") as f:
        split_indices = json.load(f)[split]
    with open(os.path.join(base_dir, "problems.json"), encoding="utf-8") as f:
        problems = json.load(f)

    split_problems = build_prompt_chatbot(
        problems, split_indices, prompt_format,
        use_caption=False, is_test=False)

    out_path = os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl")
    # "with" guarantees the output is flushed and closed even if a record
    # fails to serialise part-way through the loop.
    with open(out_path, "w", encoding="utf-8") as writer:
        for prob_id, (prompt, answer) in split_problems.items():
            # Drop only the leading marker. str.replace() would also delete
            # any later occurrence of the same text inside the body.
            if prompt.startswith(question_prefix):
                prompt = prompt[len(question_prefix):]
            if answer.startswith(answer_prefix):
                answer = answer[len(answer_prefix):]

            data = {"id": prob_id}
            image = problems[prob_id]['image']
            if image is not None:
                # The image path is recorded relative to a per-problem
                # directory and the prompt gets the image placeholder.
                data["image"] = os.path.join(prob_id, image)
                prompt = f"{prompt}\n<image>"
            data["instruction"] = f"{prompt}"
            data["output"] = f"{answer}"

            writer.write(json.dumps(data) + '\n')
81

82

83
def main(task, **kwargs):
    """Run the module-level function named ``task`` with ``kwargs``.

    Args:
        task: Name of a callable in this module's global namespace.
        **kwargs: Keyword arguments forwarded unchanged to that callable.

    Raises:
        KeyError: If no global named ``task`` exists. The message lists the
            ``convert_*`` callables currently available in the module.
    """
    try:
        handler = globals()[task]
    except KeyError:
        # Same exception type as a plain lookup failure, but with a message
        # that tells the command-line user which tasks actually exist.
        available = sorted(
            name for name, obj in globals().items()
            if name.startswith("convert_") and callable(obj)
        )
        raise KeyError(
            f"Unknown task {task!r}; available tasks: {', '.join(available)}"
        ) from None
    handler(**kwargs)


# Command-line entry point: Fire maps the CLI arguments onto main().
if __name__ == "__main__":
    fire.Fire(main)
89

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.