lama
33 lines · 1.2 KB
import logging

import torch

from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore


def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs):
    """Build an inpainting evaluator with the requested set of metrics.

    Args:
        kind: Evaluator flavour. Only ``'default'`` is supported; it builds an
            ``InpaintingEvaluatorOnline``.
        ssim: If true, register an ``SSIMScore`` under the ``'ssim'`` key.
        lpips: If true, register an ``LPIPSScore`` under the ``'lpips'`` key.
        fid: If true, register an ``FIDScore`` under the ``'fid'`` key, moved to
            ``'cuda'`` when ``torch.cuda.is_available()`` and to ``'cpu'``
            otherwise.
        integral_kind: Optional name of the aggregate ("integral") function
            handed to the evaluator: ``None``, ``'ssim_fid100_f1'`` or
            ``'lpips_fid100_f1'``. The same string is passed as the
            evaluator's ``integral_title``.
        **kwargs: Forwarded unchanged to ``InpaintingEvaluatorOnline``.

    Returns:
        The constructed ``InpaintingEvaluatorOnline``.

    Raises:
        ValueError: If ``kind`` or ``integral_kind`` is not one of the
            supported values.
    """
    logging.info('Make evaluator %s', kind)

    # Validate every argument before building any metric object, so that a
    # call which is going to raise does not first construct the metrics (and
    # move FIDScore onto the device).
    if kind != 'default':
        # Reject unknown kinds explicitly instead of falling through to an
        # implicit ``return None`` that only fails later at the call site.
        raise ValueError(f'Unexpected evaluator kind={kind}')

    if integral_kind is None:
        integral_func = None
    elif integral_kind == 'ssim_fid100_f1':
        integral_func = ssim_fid100_f1
    elif integral_kind == 'lpips_fid100_f1':
        integral_func = lpips_fid100_f1
    else:
        raise ValueError(f'Unexpected integral_kind={integral_kind}')

    device = "cuda" if torch.cuda.is_available() else "cpu"
    metrics = {}
    if ssim:
        metrics['ssim'] = SSIMScore()
    if lpips:
        metrics['lpips'] = LPIPSScore()
    if fid:
        # Only the FID score is explicitly placed on the selected device.
        metrics['fid'] = FIDScore().to(device)

    return InpaintingEvaluatorOnline(scores=metrics,
                                     integral_func=integral_func,
                                     integral_title=integral_kind,
                                     **kwargs)