llava

Форк
0
/
convert_seed_for_submission.py 
74 строки · 2.5 Кб
1
import os
2
import json
3
import argparse
4

5

6
def get_args():
7
    parser = argparse.ArgumentParser()
8
    parser.add_argument("--annotation-file", type=str)
9
    parser.add_argument("--result-file", type=str)
10
    parser.add_argument("--result-upload-file", type=str)
11
    return parser.parse_args()
12

13

14
def eval_single(result_file, eval_only_type=None):
15
    results = {}
16
    for line in open(result_file):
17
        row = json.loads(line)
18
        results[row['question_id']] = row
19

20
    type_counts = {}
21
    correct_counts = {}
22
    for question_data in data['questions']:
23
        if eval_only_type is not None and question_data['data_type'] != eval_only_type: continue
24
        data_type = question_data['question_type_id']
25
        type_counts[data_type] = type_counts.get(data_type, 0) + 1
26
        try:
27
            question_id = int(question_data['question_id'])
28
        except:
29
            question_id = question_data['question_id']
30
        if question_id not in results:
31
            correct_counts[data_type] = correct_counts.get(data_type, 0)
32
            continue
33
        row = results[question_id]
34
        if row['text'] == question_data['answer']:
35
            correct_counts[data_type] = correct_counts.get(data_type, 0) + 1
36

37
    total_count = 0
38
    total_correct = 0
39
    for data_type in sorted(type_counts.keys()):
40
        accuracy = correct_counts[data_type] / type_counts[data_type] * 100
41
        if eval_only_type is None:
42
            print(f"{ques_type_id_to_name[data_type]}: {accuracy:.2f}%")
43

44
        total_count += type_counts[data_type]
45
        total_correct += correct_counts[data_type]
46

47
    total_accuracy = total_correct / total_count * 100
48
    if eval_only_type is None:
49
        print(f"Total accuracy: {total_accuracy:.2f}%")
50
    else:
51
        print(f"{eval_only_type} accuracy: {total_accuracy:.2f}%")
52

53
    return results
54

55
if __name__ == "__main__":
56
    args = get_args()
57
    data = json.load(open(args.annotation_file))
58
    ques_type_id_to_name = {id:n for n,id in data['question_type'].items()}
59

60
    results = eval_single(args.result_file)
61
    eval_single(args.result_file, eval_only_type='image')
62
    eval_single(args.result_file, eval_only_type='video')
63

64
    with open(args.result_upload_file, 'w') as fp:
65
        for question in data['questions']:
66
            qid = question['question_id']
67
            if qid in results:
68
                result = results[qid]
69
            else:
70
                result = results[int(qid)]
71
            fp.write(json.dumps({
72
                'question_id': qid,
73
                'prediction': result['text']
74
            }) + '\n')
75

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.