# LaMa inpainting evaluation: offline (InpaintingEvaluator) and online
# (InpaintingEvaluatorOnline, an nn.Module) metric evaluators.
import logging
import math
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import tqdm
import tqdm.auto  # explicit submodule import: plain `import tqdm` does not guarantee `tqdm.auto` is bound
from torch.utils.data import DataLoader

from saicinpainting.evaluation.utils import move_to_device
# Module-level logger named after this module, so log verbosity can be
# configured hierarchically by the application.
LOGGER = logging.getLogger(__name__)
15
class InpaintingEvaluator():
    """Offline evaluator: iterates a dataset of (image, mask[, inpainted])
    samples, feeds them through a set of EvaluatorScore objects, and returns
    per-score statistics, optionally broken down by masked-area bins.
    """

    def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param dataset: torch.utils.data.Dataset which contains images and masks
        :param scores: dict {score_name: EvaluatorScore object}
        :param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples
            which are defined by share of area occluded by mask
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param batch_size: batch_size for the dataloader
        :param device: device to use
        :param integral_func: optional callable mapping the results dict to a single scalar,
            stored under key (integral_title, 'total')
        :param integral_title: name under which the integral metric is reported
        :param clamp_image_range: optional (min, max) pair; input images are clamped to it
        """
        self.scores = scores
        self.dataset = dataset

        self.area_grouping = area_grouping
        self.bins = bins

        self.device = torch.device(device)

        # shuffle=False keeps sample order stable so bin indices computed in
        # _get_bin_edges line up with the order seen in evaluate().
        self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

    def _get_bin_edges(self):
        """Assign every dataset sample to a masked-area bin.

        Returns (groups, interval_names) where groups is a 1-D int array with
        one bin index per sample (in dataloader order) and interval_names are
        human-readable labels like '10-20%'.
        """
        bin_edges = np.linspace(0, 1, self.bins + 1)

        # Number of decimals needed so adjacent percentage labels stay distinct.
        num_digits = max(0, math.ceil(math.log10(self.bins)) - 1)
        interval_names = []
        for idx_bin in range(self.bins):
            start_percent, end_percent = round(100 * bin_edges[idx_bin], num_digits), \
                                         round(100 * bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        groups = []
        for batch in self.dataloader:
            mask = batch['mask']
            batch_size = mask.shape[0]
            # Mean of the mask over all non-batch dims = fraction of occluded
            # area, assuming the mask is (approximately) binary in [0, 1] —
            # TODO confirm against the dataset's mask convention.
            area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
            bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
            # corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element
            bin_indices[bin_indices == self.bins] = self.bins - 1
            groups.append(bin_indices)
        groups = np.hstack(groups)

        return groups, interval_names

    def evaluate(self, model=None):
        """
        :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch;
            if None, precomputed results are read from batch key 'inpainted'
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        results = dict()
        if self.area_grouping:
            groups, interval_names = self._get_bin_edges()
        else:
            groups = None

        for score_name, score in tqdm.auto.tqdm(self.scores.items(), desc='scores'):
            score.to(self.device)
            with torch.no_grad():
                # Scores accumulate state per batch; reset before each full pass.
                score.reset()
                for batch in tqdm.auto.tqdm(self.dataloader, desc=score_name, leave=False):
                    batch = move_to_device(batch, self.device)
                    image_batch, mask_batch = batch['image'], batch['mask']
                    if self.clamp_image_range is not None:
                        image_batch = torch.clamp(image_batch,
                                                  min=self.clamp_image_range[0],
                                                  max=self.clamp_image_range[1])
                    if model is None:
                        assert 'inpainted' in batch, \
                            'Model is None, so we expected precomputed inpainting results at key "inpainted"'
                        inpainted_batch = batch['inpainted']
                    else:
                        inpainted_batch = model(image_batch, mask_batch)
                    score(inpainted_batch, image_batch, mask_batch)
                total_results, group_results = score.get_value(groups=groups)

            results[(score_name, 'total')] = total_results
            if groups is not None:
                for group_index, group_values in group_results.items():
                    group_name = interval_names[group_index]
                    results[(score_name, group_name)] = group_values

        # Optional scalar summary computed from all per-score results.
        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        return results
111
def ssim_fid100_f1(metrics, fid_scale=100):
    """F1-style combination of SSIM and a rescaled FID.

    FID is mapped onto [0, 1] so that 0 is best and anything at or above
    ``fid_scale`` collapses to 0; the result is a smoothed harmonic mean of
    that value and the mean SSIM (bigger is better).

    :param metrics: results dict keyed by (score_name, group), as produced by
        the evaluators in this module; must contain ('ssim', 'total') and
        ('fid', 'total') entries with a 'mean' field
    :param fid_scale: FID value that is considered "worst" (maps to 0)
    :return: scalar combined metric
    """
    ssim_mean = metrics[('ssim', 'total')]['mean']
    fid_mean = metrics[('fid', 'total')]['mean']
    relative_fid = max(0, fid_scale - fid_mean) / fid_scale
    # 1e-3 smoothing keeps the division defined when both terms are ~0.
    numerator = 2 * ssim_mean * relative_fid
    denominator = ssim_mean + relative_fid + 1e-3
    return numerator / denominator
119
def lpips_fid100_f1(metrics, fid_scale=100):
    """F1-style combination of inverted LPIPS and a rescaled FID.

    LPIPS is a distance (lower is better), so it is inverted to 1 - lpips;
    FID is mapped onto [0, 1] with ``fid_scale`` or worse collapsing to 0.
    The result is a smoothed harmonic mean of the two (bigger is better).

    :param metrics: results dict keyed by (score_name, group); must contain
        ('lpips', 'total') and ('fid', 'total') entries with a 'mean' field
    :param fid_scale: FID value that is considered "worst" (maps to 0)
    :return: scalar combined metric
    """
    neg_lpips = 1 - metrics[('lpips', 'total')]['mean']  # invert, so bigger is better
    fid_mean = metrics[('fid', 'total')]['mean']
    relative_fid = max(0, fid_scale - fid_mean) / fid_scale
    # 1e-3 smoothing keeps the division defined when both terms are ~0.
    numerator = 2 * neg_lpips * relative_fid
    denominator = neg_lpips + relative_fid + 1e-3
    return numerator / denominator
127
128
class InpaintingEvaluatorOnline(nn.Module):
    """Online evaluator: accumulates metric state batch-by-batch via forward()
    during a validation loop, then produces final (optionally area-grouped)
    statistics from evaluation_end(). Implemented as an nn.Module so the
    contained scores move with .to(device) and integrate with training loops.
    """

    def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param scores: dict {score_name: EvaluatorScore object}
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param image_key: batch-dict key holding the ground-truth images
        :param inpainted_key: batch-dict key holding the inpainting results
        :param integral_func: optional callable mapping the results dict to a single scalar,
            stored under key (integral_title, 'total')
        :param integral_title: name under which the integral metric is reported
        :param clamp_image_range: optional (min, max) pair; input images are clamped to it
        """
        super().__init__()
        LOGGER.info(f'{type(self)} init called')
        # ModuleDict so score submodules are registered and follow .to(device).
        self.scores = nn.ModuleDict(scores)
        self.image_key = image_key
        self.inpainted_key = inpainted_key
        self.bins_num = bins
        self.bin_edges = np.linspace(0, 1, self.bins_num + 1)

        # Number of decimals needed so adjacent percentage labels stay distinct.
        num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1)
        self.interval_names = []
        for idx_bin in range(self.bins_num):
            start_percent, end_percent = round(100 * self.bin_edges[idx_bin], num_digits), \
                                         round(100 * self.bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            self.interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        # Per-sample bin indices, accumulated across forward() calls and
        # consumed/cleared in evaluation_end().
        self.groups = []

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

        LOGGER.info(f'{type(self)} init done')

    def _get_bins(self, mask_batch):
        """Return one masked-area bin index per sample in the batch."""
        batch_size = mask_batch.shape[0]
        # Mean of the mask over all non-batch dims = fraction of occluded area.
        area = mask_batch.view(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
        # NOTE(review): uses searchsorted's default side='left' while the
        # offline InpaintingEvaluator uses side='right'; samples whose area
        # falls exactly on a bin edge land in different bins — confirm intended.
        bin_indices = np.clip(np.searchsorted(self.bin_edges, area) - 1, 0, self.bins_num - 1)
        return bin_indices

    def forward(self, batch: Dict[str, torch.Tensor]):
        """
        Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end
        :param batch: batch dict with mandatory fields mask, image, inpainted (can be overriden by self.inpainted_key)
        :return: dict {score_name: per-batch score output}
        """
        result = {}
        with torch.no_grad():
            image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key]
            if self.clamp_image_range is not None:
                image_batch = torch.clamp(image_batch,
                                          min=self.clamp_image_range[0],
                                          max=self.clamp_image_range[1])
            self.groups.extend(self._get_bins(mask_batch))

            for score_name, score in self.scores.items():
                result[score_name] = score(inpainted_batch, image_batch, mask_batch)
        return result

    def process_batch(self, batch: Dict[str, torch.Tensor]):
        # Alias for forward(), for callers that expect a named method.
        return self(batch)

    def evaluation_end(self, states=None):
        """Finalize evaluation: aggregate accumulated score state, then reset.

        :param states: optional list of per-process state dicts (keyed by score
            name) gathered from distributed workers; passed through to each
            score's get_value — TODO confirm exact structure against the
            EvaluatorScore implementations
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        LOGGER.info(f'{type(self)}: evaluation_end called')

        self.groups = np.array(self.groups)

        results = {}
        for score_name, score in self.scores.items():
            LOGGER.info(f'Getting value of {score_name}')
            cur_states = [s[score_name] for s in states] if states is not None else None
            total_results, group_results = score.get_value(groups=self.groups, states=cur_states)
            LOGGER.info(f'Getting value of {score_name} done')
            results[(score_name, 'total')] = total_results

            for group_index, group_values in group_results.items():
                group_name = self.interval_names[group_index]
                results[(score_name, group_name)] = group_values

        # Optional scalar summary computed from all per-score results.
        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        # Clear accumulated state so the evaluator can be reused for the next
        # validation pass.
        LOGGER.info(f'{type(self)}: reset scores')
        self.groups = []
        for sc in self.scores.values():
            sc.reset()
        LOGGER.info(f'{type(self)}: reset scores done')

        LOGGER.info(f'{type(self)}: evaluation_end done')
        return results