lama

Форк
0
220 строк · 9.5 Кб
1
import logging
2
import math
3
from typing import Dict
4

5
import numpy as np
6
import torch
7
import torch.nn as nn
8
import tqdm
9
from torch.utils.data import DataLoader
10

11
from saicinpainting.evaluation.utils import move_to_device
12

13
LOGGER = logging.getLogger(__name__)
14

15

16
class InpaintingEvaluator():
    """Offline evaluator that runs a set of scores over a whole dataset.

    Besides the overall value of every score, it can also report each score per group of
    samples, where the groups are defined by the share of the image area covered by the mask.
    """

    def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param dataset: torch.utils.data.Dataset which contains images and masks
        :param scores: dict {score_name: EvaluatorScore object}
        :param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples
            which are defined by share of area occluded by mask
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param batch_size: batch_size for the dataloader
        :param device: device to use
        :param integral_func: optional callable that receives the results dict built by ``evaluate``
            and returns one number; it is stored under ``(integral_title, 'total')`` as ``{'mean': value}``
        :param integral_title: score name used as the key of the ``integral_func`` result
        :param clamp_image_range: optional ``(min, max)`` pair; ``batch['image']`` is clamped to it
            before being passed to the scores
        """
        self.scores = scores
        self.dataset = dataset

        self.area_grouping = area_grouping
        self.bins = bins

        self.device = torch.device(device)

        # shuffle=False keeps the sample order identical between the grouping pass and every score pass,
        # so the i-th group index corresponds to the i-th scored sample
        self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

    @staticmethod
    def _make_interval_names(bin_edges, bins):
        """Return one label such as '10-20%' for each of the ``bins`` consecutive intervals of ``bin_edges``."""
        # one extra decimal digit per extra order of magnitude of bins, e.g. bins=100 -> '0.0-1.0%'
        num_digits = max(0, math.ceil(math.log10(bins)) - 1)
        names = []
        for start, end in zip(bin_edges[:-1], bin_edges[1:]):
            start_percent = '{:.{n}f}'.format(round(100 * start, num_digits), n=num_digits)
            end_percent = '{:.{n}f}'.format(round(100 * end, num_digits), n=num_digits)
            names.append("{0}-{1}%".format(start_percent, end_percent))
        return names

    def _get_bin_edges(self):
        """Assign every sample of the dataset to a mask-area bin.

        Iterates once over ``self.dataloader`` and reads only ``batch['mask']``.

        :return: tuple ``(groups, interval_names)``. ``groups`` is a 1-D integer numpy array with the
            bin index of every sample in dataloader order. ``interval_names[i]`` is the label of bin ``i``.
        """
        bin_edges = np.linspace(0, 1, self.bins + 1)
        interval_names = self._make_interval_names(bin_edges, self.bins)

        groups = []
        for batch in self.dataloader:
            mask = batch['mask']
            if not mask.is_floating_point():
                # Tensor.mean() is not defined for integer/bool dtypes
                mask = mask.float()
            batch_size = mask.shape[0]
            area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
            # side='right' makes bins half-open [start, end): a value lying on an edge goes to the upper bin
            bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
            # corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element
            bin_indices[bin_indices == self.bins] = self.bins - 1
            groups.append(bin_indices)
        # np.hstack raises on an empty list, so an empty dataset yields an empty integer array instead
        groups = np.concatenate(groups) if groups else np.empty(0, dtype=np.int64)

        return groups, interval_names

    def evaluate(self, model=None):
        """
        :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch.
            If None, every batch must already contain precomputed results under the key 'inpainted'.
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        # tqdm.auto is a submodule: a plain `import tqdm` does not guarantee that `tqdm.auto` is loaded
        from tqdm.auto import tqdm as progress_bar

        results = dict()
        if self.area_grouping:
            groups, interval_names = self._get_bin_edges()
        else:
            groups, interval_names = None, None

        for score_name, score in progress_bar(self.scores.items(), desc='scores'):
            score.to(self.device)
            with torch.no_grad():
                score.reset()
                for batch in progress_bar(self.dataloader, desc=score_name, leave=False):
                    batch = move_to_device(batch, self.device)
                    image_batch, mask_batch = batch['image'], batch['mask']
                    if self.clamp_image_range is not None:
                        image_batch = torch.clamp(image_batch,
                                                  min=self.clamp_image_range[0],
                                                  max=self.clamp_image_range[1])
                    if model is None:
                        assert 'inpainted' in batch, \
                            'Model is None, so we expected precomputed inpainting results at key "inpainted"'
                        inpainted_batch = batch['inpainted']
                    else:
                        inpainted_batch = model(image_batch, mask_batch)
                    score(inpainted_batch, image_batch, mask_batch)
                total_results, group_results = score.get_value(groups=groups)

            results[(score_name, 'total')] = total_results
            if groups is not None:
                for group_index, group_values in group_results.items():
                    group_name = interval_names[group_index]
                    results[(score_name, group_name)] = group_values

        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        return results
110

111

112
def ssim_fid100_f1(metrics, fid_scale=100):
    """Combine overall SSIM and a rescaled FID into one F1-like number.

    FID is mapped to ``max(0, fid_scale - fid) / fid_scale``, so that larger is better and any
    FID >= fid_scale maps to 0. The two values are combined as ``2ab / (a + b + 1e-3)``; the 1e-3
    term keeps the denominator non-zero.

    :param metrics: results dict containing ``('ssim', 'total')`` and ``('fid', 'total')`` entries,
        each a dict with a ``'mean'`` value
    :param fid_scale: FID value that maps to a relative score of 0
    :return: the combined score
    """
    ssim_mean = metrics['ssim', 'total']['mean']
    fid_mean = metrics['fid', 'total']['mean']
    fid_relative = max(0, fid_scale - fid_mean) / fid_scale
    numerator = 2 * ssim_mean * fid_relative
    denominator = ssim_mean + fid_relative + 1e-3
    return numerator / denominator
118

119

120
def lpips_fid100_f1(metrics, fid_scale=100):
    """Combine overall LPIPS and a rescaled FID into one F1-like number.

    LPIPS is turned into ``1 - lpips`` and FID into ``max(0, fid_scale - fid) / fid_scale``, so that
    both are "bigger is better". They are combined as ``2ab / (a + b + 1e-3)``; the 1e-3 term keeps
    the denominator non-zero.

    :param metrics: results dict containing ``('lpips', 'total')`` and ``('fid', 'total')`` entries,
        each a dict with a ``'mean'`` value
    :param fid_scale: FID value that maps to a relative score of 0
    :return: the combined score
    """
    lpips_complement = 1 - metrics['lpips', 'total']['mean']
    fid_mean = metrics['fid', 'total']['mean']
    fid_relative = max(0, fid_scale - fid_mean) / fid_scale
    numerator = 2 * lpips_complement * fid_relative
    denominator = lpips_complement + fid_relative + 1e-3
    return numerator / denominator
126

127

128

129
class InpaintingEvaluatorOnline(nn.Module):
    """Streaming evaluator: score state is accumulated batch by batch and finalized in ``evaluation_end``.

    Samples are also grouped by the share of the image area covered by the mask, so every score is
    reported both in total and per group.
    """

    def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param scores: dict {score_name: EvaluatorScore object}; wrapped into an nn.ModuleDict,
            so the values must be nn.Module instances
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param image_key: batch key of the images the inpainted results are compared against
        :param inpainted_key: batch key of the inpainted images
        :param integral_func: optional callable that receives the results dict built by ``evaluation_end``
            and returns one number; it is stored under ``(integral_title, 'total')`` as ``{'mean': value}``
        :param integral_title: score name used as the key of the ``integral_func`` result
        :param clamp_image_range: optional ``(min, max)`` pair; the ``image_key`` tensors are clamped to it
            before being passed to the scores
        """
        super().__init__()
        LOGGER.info(f'{type(self)} init called')
        self.scores = nn.ModuleDict(scores)
        self.image_key = image_key
        self.inpainted_key = inpainted_key
        self.bins_num = bins
        self.bin_edges = np.linspace(0, 1, self.bins_num + 1)
        self.interval_names = self._make_interval_names(self.bin_edges, self.bins_num)

        # bin index of every sample seen by forward() since the last evaluation_end, in arrival order
        self.groups = []

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

        LOGGER.info(f'{type(self)} init done')

    @staticmethod
    def _make_interval_names(bin_edges, bins):
        """Return one label such as '10-20%' for each of the ``bins`` consecutive intervals of ``bin_edges``."""
        # one extra decimal digit per extra order of magnitude of bins, e.g. bins=100 -> '0.0-1.0%'
        num_digits = max(0, math.ceil(math.log10(bins)) - 1)
        names = []
        for start, end in zip(bin_edges[:-1], bin_edges[1:]):
            start_percent = '{:.{n}f}'.format(round(100 * start, num_digits), n=num_digits)
            end_percent = '{:.{n}f}'.format(round(100 * end, num_digits), n=num_digits)
            names.append("{0}-{1}%".format(start_percent, end_percent))
        return names

    def _get_bins(self, mask_batch):
        """Return the mask-area bin index of every sample in ``mask_batch`` as a numpy integer array.

        Bins are half-open ``[start, end)``, so an area lying exactly on an edge goes to the upper bin.
        An area of exactly 1 is folded into the last bin.
        """
        if not mask_batch.is_floating_point():
            # Tensor.mean() is not defined for integer/bool dtypes
            mask_batch = mask_batch.float()
        batch_size = mask_batch.shape[0]
        # reshape (unlike view) also accepts non-contiguous masks
        area = mask_batch.reshape(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
        bin_indices = np.searchsorted(self.bin_edges, area, side='right') - 1
        return np.clip(bin_indices, 0, self.bins_num - 1)

    def forward(self, batch: Dict[str, torch.Tensor]):
        """
        Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end
        :param batch: batch dict with mandatory fields mask, image, inpainted (the image and inpainted keys
            can be overridden by self.image_key and self.inpainted_key)
        :return: dict {score_name: whatever the score's forward returned for this batch}
        """
        result = {}
        with torch.no_grad():
            image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key]
            if self.clamp_image_range is not None:
                image_batch = torch.clamp(image_batch,
                                          min=self.clamp_image_range[0],
                                          max=self.clamp_image_range[1])
            self.groups.extend(self._get_bins(mask_batch))

            for score_name, score in self.scores.items():
                result[score_name] = score(inpainted_batch, image_batch, mask_batch)
        return result

    def process_batch(self, batch: Dict[str, torch.Tensor]):
        """Alias for calling the module on ``batch``."""
        return self(batch)

    def evaluation_end(self, states=None):
        """Finalize all scores, then reset the accumulated groups and score state.

        :param states: optional sequence of dicts mapping score name to a score state; for every score
            the list of its states is forwarded to ``score.get_value(states=...)``.
            NOTE(review): the groups always come from this instance's own ``forward`` calls, even when
            external states are supplied; confirm that this is intended.
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        LOGGER.info(f'{type(self)}: evaluation_end called')

        # use a local array so that self.groups stays a list (forward() extends it) even if get_value raises
        groups = np.array(self.groups)

        results = {}
        for score_name, score in self.scores.items():
            LOGGER.info(f'Getting value of {score_name}')
            cur_states = [s[score_name] for s in states] if states is not None else None
            total_results, group_results = score.get_value(groups=groups, states=cur_states)
            LOGGER.info(f'Getting value of {score_name} done')
            results[(score_name, 'total')] = total_results

            for group_index, group_values in group_results.items():
                group_name = self.interval_names[group_index]
                results[(score_name, group_name)] = group_values

        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        LOGGER.info(f'{type(self)}: reset scores')
        self.groups = []
        for sc in self.scores.values():
            sc.reset()
        LOGGER.info(f'{type(self)}: reset scores done')

        LOGGER.info(f'{type(self)}: evaluation_end done')
        return results
221

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.