lama
54 строки · 2.0 Кб
1#!/usr/bin/env python3
2
3
4import os
5from argparse import ArgumentParser
6
7
def ssim_fid100_f1(metrics, fid_scale=100):
    """Combine mean SSIM and mean FID into a single F1-style score.

    FID ("lower is better") is first mapped to a relative score via
    ``max(0, fid_scale - fid) / fid_scale``, so any FID >= ``fid_scale``
    contributes 0. That relative score and SSIM are then merged with a
    harmonic-mean formula.

    Args:
        metrics: object indexed as ``metrics.loc['total', <metric>]['mean']``
            (presumably a pandas DataFrame with (metric, statistic) column
            levels -- confirm against the caller).
        fid_scale: FID value at (or above) which the relative FID score is 0.

    Returns:
        ``2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)``.
    """
    indexer = metrics.loc
    mean_ssim = indexer['total', 'ssim']['mean']
    mean_fid = indexer['total', 'fid']['mean']

    # Flip FID so that larger is better, clamping at zero.
    relative_fid = max(0, fid_scale - mean_fid) / fid_scale

    numerator = 2 * mean_ssim * relative_fid
    # The 1e-3 term avoids division by zero when both scores are 0.
    return numerator / (mean_ssim + relative_fid + 1e-3)
14
15
def _iter_validation_records(lines):
    """Yield ``(epoch, step, f1)`` for each complete validation record in *lines*.

    A record starts at a line containing ``'Validation metrics after epoch'``.
    The epoch is the integer between the first ``'#'`` and the following
    ``','``. The step is the token right after ``'total '`` on that line.
    The line five positions later must start with ``'total'``; its last
    whitespace-separated token is parsed as the f1 score. Records whose
    summary line is not ``'total'``-prefixed are ignored. Records whose
    summary line would lie past the end of *lines* (a truncated log) are also
    ignored instead of raising IndexError.
    """
    for line_index, line in enumerate(lines):
        if 'Validation metrics after epoch' not in line:
            continue
        summary_index = line_index + 5
        if summary_index >= len(lines):
            # Log ends before this validation's summary row was written
            # (e.g. an interrupted run); nothing to read for it.
            continue

        after_sharp = line[line.index('#') + 1:]
        epoch = int(after_sharp[:after_sharp.index(',')])
        step = int(line[line.index('total '):].split()[1])

        summary = lines[summary_index]
        if not summary.startswith('total'):
            continue
        yield epoch, step, float(summary.split()[-1])


def find_best_checkpoint(model_list, models_dir):
    """Find, for every listed model, the validation epoch with the highest f1.

    Reads model names (one per line; blank lines are skipped) from
    *model_list*. For each model it parses ``<models_dir>/<model>/train.log``,
    printing the model name and every parsed ``epoch``/``f1`` pair to stdout.
    It then appends ``model<TAB>best_epoch<TAB>best_step<TAB>best_f1`` to the
    file ``<model_list>_best``. Only f1 values strictly greater than 0 and
    strictly greater than the current best are taken, so ties keep the
    earliest epoch. A model with no such record is written as ``0 0 0``.

    Args:
        model_list: path to the text file listing model directory names.
        models_dir: directory containing one sub-directory per model.

    Raises:
        OSError: e.g. FileNotFoundError if a model's ``train.log`` is missing.
        ValueError: if a validation header or summary line cannot be parsed.
    """
    with open(model_list) as f:
        # Ignore blank lines (such as a trailing newline); otherwise an empty
        # model name would make us open ``<models_dir>/train.log``.
        models = [name for name in (raw.strip() for raw in f) if name]

    with open(f'{model_list}_best', 'w') as out:
        for model in models:
            print(model)
            with open(os.path.join(models_dir, model, 'train.log')) as log:
                lines = log.readlines()

            best_f1, best_epoch, best_step = 0, 0, 0
            for epoch, step, f1 in _iter_validation_records(lines):
                print(f'\tEpoch: {epoch}, f1={f1}')
                if f1 > best_f1:
                    best_f1, best_epoch, best_step = f1, epoch, step
            out.write(f'{model}\t{best_epoch}\t{best_step}\t{best_f1}\n')
47
48
if __name__ == '__main__':
    # Command-line entry point: document the positional arguments so that
    # ``--help`` explains the expected inputs and where results are written.
    parser = ArgumentParser(
        description='For each model listed in MODEL_LIST, find the validation '
                    'epoch with the highest f1 in MODELS_DIR/<model>/train.log '
                    'and write the results to <MODEL_LIST>_best.')
    parser.add_argument('model_list',
                        help='text file with one model directory name per line')
    parser.add_argument('models_dir',
                        help='directory containing one sub-directory per model, '
                             'each with a train.log')
    args = parser.parse_args()
    find_best_checkpoint(args.model_list, args.models_dir)
55