analyze_errors.py

#!/usr/bin/env python3
import cv2
import numpy as np
import sklearn.svm
import torch
import os
import pickle
import pandas as pd
import matplotlib.pyplot as plt
from joblib import Parallel, delayed

from saicinpainting.evaluation.data import PrecomputedInpaintingResultsDataset, load_image
from saicinpainting.evaluation.losses.fid.inception import InceptionV3
from saicinpainting.evaluation.utils import load_yaml
from saicinpainting.training.visualizers.base import visualize_mask_and_images


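# Overall flow:
#  1. Extract InceptionV3 pooled features for every real image and each of its inpainted
#     versions (one per mask).
#  2. Fit a linear SVM separating real from inpainted features; its decision function is
#     used as a per-image "realism" score.
#  3. Dump the latents and scores, then render reports: globally best/worst inpaintings,
#     best/worst results per real image, and score differences between overlapping masks.


# Overlay the score value near the top-left corner of a CHW float image and return it.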
def draw_score(img, score):
    # CHW -> HWC; cv2 needs a contiguous array to draw on, so annotate a contiguous copy
    # and return it (callers must use the return value).
    img = np.ascontiguousarray(np.transpose(img, (1, 2, 0)))
    cv2.putText(img, f'{score:.2f}',
                (40, 40),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 1, 0),
                thickness=3)
    return np.transpose(img, (2, 0, 1))


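# For each mask, render a grid with the original and inpainted images (scores drawn on top)
# and save it as <mask_name>.jpg in out_dir.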
def save_global_samples(global_mask_fnames, mask2real_fname, mask2fake_fname, out_dir, real_scores_by_fname, fake_scores_by_fname):
    for cur_mask_fname in global_mask_fnames:
        cur_real_fname = mask2real_fname[cur_mask_fname]
        orig_img = load_image(cur_real_fname, mode='RGB')
        fake_img = load_image(mask2fake_fname[cur_mask_fname], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
        mask = load_image(cur_mask_fname, mode='L')[None, ...]

        orig_img = draw_score(orig_img, real_scores_by_fname.loc[cur_real_fname, 'real_score'])
        fake_img = draw_score(fake_img, fake_scores_by_fname.loc[cur_mask_fname, 'fake_score'])

        cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=mask, fake=fake_img),
                                             keys=['image', 'fake'],
                                             last_without_mask=True)
        cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8')
        cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(out_dir, os.path.splitext(os.path.basename(cur_mask_fname))[0] + '.jpg'),
                    cur_grid)


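# For every real image in worst_best_by_real, render the original together with its worst and
# best inpaintings (with masks and scores), plus histograms of the per-mask scores.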
def save_samples_by_real(worst_best_by_real, mask2fake_fname, fake_info, out_dir):
    for real_fname in worst_best_by_real.index:
        worst_mask_path = worst_best_by_real.loc[real_fname, 'worst']
        best_mask_path = worst_best_by_real.loc[real_fname, 'best']
        orig_img = load_image(real_fname, mode='RGB')
        worst_mask_img = load_image(worst_mask_path, mode='L')[None, ...]
        worst_fake_img = load_image(mask2fake_fname[worst_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]
        best_mask_img = load_image(best_mask_path, mode='L')[None, ...]
        best_fake_img = load_image(mask2fake_fname[best_mask_path], mode='RGB')[:, :orig_img.shape[1], :orig_img.shape[2]]

        orig_img = draw_score(orig_img, worst_best_by_real.loc[real_fname, 'real_score'])
        worst_fake_img = draw_score(worst_fake_img, worst_best_by_real.loc[real_fname, 'worst_score'])
        best_fake_img = draw_score(best_fake_img, worst_best_by_real.loc[real_fname, 'best_score'])

        cur_grid = visualize_mask_and_images(dict(image=orig_img, mask=np.zeros_like(worst_mask_img),
                                                  worst_mask=worst_mask_img, worst_img=worst_fake_img,
                                                  best_mask=best_mask_img, best_img=best_fake_img),
                                             keys=['image', 'worst_mask', 'worst_img', 'best_mask', 'best_img'],
                                             rescale_keys=['worst_mask', 'best_mask'],
                                             last_without_mask=True)
        cur_grid = np.clip(cur_grid * 255, 0, 255).astype('uint8')
        cur_grid = cv2.cvtColor(cur_grid, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(out_dir,
                                 os.path.splitext(os.path.basename(real_fname))[0] + '.jpg'),
                    cur_grid)

        fig, (ax1, ax2) = plt.subplots(1, 2)
        cur_stat = fake_info[fake_info['real_fname'] == real_fname]
        cur_stat['fake_score'].hist(ax=ax1)
        cur_stat['real_score'].hist(ax=ax2)
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir,
                                 os.path.splitext(os.path.basename(real_fname))[0] + '_scores.png'))
        plt.close(fig)


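# For the mask at index cur_i, find up to max_overlaps_n subsequent masks that overlap it and
# return the (mask_a, mask_b) pairs together with their score differences (score_b - score_a).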
def extract_overlapping_masks(mask_fnames, cur_i, fake_scores_table, max_overlaps_n=2):
    result_pairs = []
    result_scores = []
    mask_fname_a = mask_fnames[cur_i]
    mask_a = load_image(mask_fname_a, mode='L')[None, ...] > 0.5
    cur_score_a = fake_scores_table.loc[mask_fname_a, 'fake_score']
    for mask_fname_b in mask_fnames[cur_i + 1:]:
        mask_b = load_image(mask_fname_b, mode='L')[None, ...] > 0.5
        if not np.any(mask_a & mask_b):
            continue
        cur_score_b = fake_scores_table.loc[mask_fname_b, 'fake_score']
        result_pairs.append((mask_fname_a, mask_fname_b))
        result_scores.append(cur_score_b - cur_score_a)
        if len(result_pairs) >= max_overlaps_n:
            break
    return result_pairs, result_scores


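# Build all reports: create the output directory tree, compute or reload per-image scores,
# and write global and per-real-image visualizations under args.outpath.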
def main(args):
    config = load_yaml(args.config)

    latents_dir = os.path.join(args.outpath, 'latents')
    os.makedirs(latents_dir, exist_ok=True)
    global_worst_dir = os.path.join(args.outpath, 'global_worst')
    os.makedirs(global_worst_dir, exist_ok=True)
    global_best_dir = os.path.join(args.outpath, 'global_best')
    os.makedirs(global_best_dir, exist_ok=True)
    worst_best_by_best_worst_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'best_worst_score_diff_max')
    os.makedirs(worst_best_by_best_worst_score_diff_max_dir, exist_ok=True)
    worst_best_by_best_worst_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'best_worst_score_diff_min')
    os.makedirs(worst_best_by_best_worst_score_diff_min_dir, exist_ok=True)
    worst_best_by_real_best_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_best_score_diff_max')
    os.makedirs(worst_best_by_real_best_score_diff_max_dir, exist_ok=True)
    worst_best_by_real_best_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_best_score_diff_min')
    os.makedirs(worst_best_by_real_best_score_diff_min_dir, exist_ok=True)
    worst_best_by_real_worst_score_diff_max_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_worst_score_diff_max')
    os.makedirs(worst_best_by_real_worst_score_diff_max_dir, exist_ok=True)
    worst_best_by_real_worst_score_diff_min_dir = os.path.join(args.outpath, 'worst_best_by_real', 'real_worst_score_diff_min')
    os.makedirs(worst_best_by_real_worst_score_diff_min_dir, exist_ok=True)

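    # Score computation (unless --only-report): extract InceptionV3 features for real and
    # inpainted images, fit a linear SVM to tell them apart, and use its decision function
    # as the per-image score; otherwise reload the previously dumped latents and scores.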
    if not args.only_report:
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        inception_model = InceptionV3([block_idx]).eval().cuda()

        dataset = PrecomputedInpaintingResultsDataset(args.datadir, args.predictdir, **config.dataset_kwargs)

        real2vector_cache = {}

        real_features = []
        fake_features = []

        orig_fnames = []
        mask_fnames = []
        mask2real_fname = {}
        mask2fake_fname = {}

        for batch_i, batch in enumerate(dataset):
            orig_img_fname = dataset.img_filenames[batch_i]
            mask_fname = dataset.mask_filenames[batch_i]
            fake_fname = dataset.pred_filenames[batch_i]
            mask2real_fname[mask_fname] = orig_img_fname
            mask2fake_fname[mask_fname] = fake_fname

            cur_real_vector = real2vector_cache.get(orig_img_fname, None)
            if cur_real_vector is None:
                with torch.no_grad():
                    in_img = torch.from_numpy(batch['image'][None, ...]).cuda()
                    cur_real_vector = inception_model(in_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()
                real2vector_cache[orig_img_fname] = cur_real_vector

            pred_img = torch.from_numpy(batch['inpainted'][None, ...]).cuda()
            cur_fake_vector = inception_model(pred_img)[0].squeeze(-1).squeeze(-1).cpu().numpy()

            real_features.append(cur_real_vector)
            fake_features.append(cur_fake_vector)

            orig_fnames.append(orig_img_fname)
            mask_fnames.append(mask_fname)

        ids_features = np.concatenate(real_features + fake_features, axis=0)
        ids_labels = np.array(([1] * len(real_features)) + ([0] * len(fake_features)))

        with open(os.path.join(latents_dir, 'features.pkl'), 'wb') as f:
            pickle.dump(ids_features, f, protocol=3)
        with open(os.path.join(latents_dir, 'labels.pkl'), 'wb') as f:
            pickle.dump(ids_labels, f, protocol=3)
        with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'wb') as f:
            pickle.dump(orig_fnames, f, protocol=3)
        with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'wb') as f:
            pickle.dump(mask_fnames, f, protocol=3)
        with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'wb') as f:
            pickle.dump(mask2real_fname, f, protocol=3)
        with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'wb') as f:
            pickle.dump(mask2fake_fname, f, protocol=3)

        svm = sklearn.svm.LinearSVC(dual=False)
        svm.fit(ids_features, ids_labels)

        pred_scores = svm.decision_function(ids_features)
        real_scores = pred_scores[:len(real_features)]
        fake_scores = pred_scores[len(real_features):]

        with open(os.path.join(latents_dir, 'pred_scores.pkl'), 'wb') as f:
            pickle.dump(pred_scores, f, protocol=3)
        with open(os.path.join(latents_dir, 'real_scores.pkl'), 'wb') as f:
            pickle.dump(real_scores, f, protocol=3)
        with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'wb') as f:
            pickle.dump(fake_scores, f, protocol=3)
    else:
        with open(os.path.join(latents_dir, 'orig_fnames.pkl'), 'rb') as f:
            orig_fnames = pickle.load(f)
        with open(os.path.join(latents_dir, 'mask_fnames.pkl'), 'rb') as f:
            mask_fnames = pickle.load(f)
        with open(os.path.join(latents_dir, 'mask2real_fname.pkl'), 'rb') as f:
            mask2real_fname = pickle.load(f)
        with open(os.path.join(latents_dir, 'mask2fake_fname.pkl'), 'rb') as f:
            mask2fake_fname = pickle.load(f)
        with open(os.path.join(latents_dir, 'real_scores.pkl'), 'rb') as f:
            real_scores = pickle.load(f)
        with open(os.path.join(latents_dir, 'fake_scores.pkl'), 'rb') as f:
            fake_scores = pickle.load(f)

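    # Build per-image score tables: real_info is indexed by real filename; fake_info has one
    # row per inpainted result, joined with the corresponding real score and per-real stats.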
    real_info = pd.DataFrame(data=[dict(real_fname=fname,
                                        real_score=score)
                                   for fname, score
                                   in zip(orig_fnames, real_scores)])
    real_info.set_index('real_fname', drop=True, inplace=True)

    fake_info = pd.DataFrame(data=[dict(mask_fname=fname,
                                        fake_fname=mask2fake_fname[fname],
                                        real_fname=mask2real_fname[fname],
                                        fake_score=score)
                                   for fname, score
                                   in zip(mask_fnames, fake_scores)])
    fake_info = fake_info.join(real_info, on='real_fname', how='left')
    fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)

    fake_stats_by_real = fake_info.groupby('real_fname')['fake_score'].describe()[['mean', 'std']].rename(
        {'mean': 'mean_fake_by_real', 'std': 'std_fake_by_real'}, axis=1)
    fake_info = fake_info.join(fake_stats_by_real, on='real_fname', rsuffix='stat_by_real')
    fake_info.drop_duplicates(['fake_fname', 'real_fname'], inplace=True)
    fake_info.to_csv(os.path.join(latents_dir, 'join_scores_table.csv'), sep='\t', index=False)

    fake_scores_table = fake_info.set_index('mask_fname')['fake_score'].to_frame()
    real_scores_table = fake_info.set_index('real_fname')['real_score'].drop_duplicates().to_frame()

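    # Global report: score histograms over all images, then the top-N worst and best
    # inpaintings by fake_score.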
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.hist(fake_scores)
    ax2.hist(real_scores)
    fig.tight_layout()
    fig.savefig(os.path.join(args.outpath, 'global_scores_hist.png'))
    plt.close(fig)

    global_worst_masks = fake_info.sort_values('fake_score', ascending=True)['mask_fname'].iloc[:config.take_global_top].to_list()
    global_best_masks = fake_info.sort_values('fake_score', ascending=False)['mask_fname'].iloc[:config.take_global_top].to_list()
    save_global_samples(global_worst_masks, mask2real_fname, mask2fake_fname, global_worst_dir, real_scores_table, fake_scores_table)
    save_global_samples(global_best_masks, mask2real_fname, mask2fake_fname, global_best_dir, real_scores_table, fake_scores_table)

    # grouped by real
    worst_samples_by_real = fake_info.groupby('real_fname').apply(
        lambda d: d.set_index('mask_fname')['fake_score'].idxmin()).to_frame().rename({0: 'worst'}, axis=1)
    best_samples_by_real = fake_info.groupby('real_fname').apply(
        lambda d: d.set_index('mask_fname')['fake_score'].idxmax()).to_frame().rename({0: 'best'}, axis=1)
    worst_best_by_real = pd.concat([worst_samples_by_real, best_samples_by_real], axis=1)

    worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'worst_score'}, axis=1),
                                                 on='worst')
    worst_best_by_real = worst_best_by_real.join(fake_scores_table.rename({'fake_score': 'best_score'}, axis=1),
                                                 on='best')
    worst_best_by_real = worst_best_by_real.join(real_scores_table)

    worst_best_by_real['best_worst_score_diff'] = worst_best_by_real['best_score'] - worst_best_by_real['worst_score']
    worst_best_by_real['real_best_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['best_score']
    worst_best_by_real['real_worst_score_diff'] = worst_best_by_real['real_score'] - worst_best_by_real['worst_score']

    worst_best_by_best_worst_score_diff_min = worst_best_by_real.sort_values('best_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
    worst_best_by_best_worst_score_diff_max = worst_best_by_real.sort_values('best_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
    save_samples_by_real(worst_best_by_best_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_min_dir)
    save_samples_by_real(worst_best_by_best_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_best_worst_score_diff_max_dir)

    worst_best_by_real_best_score_diff_min = worst_best_by_real.sort_values('real_best_score_diff', ascending=True).iloc[:config.take_worst_best_top]
    worst_best_by_real_best_score_diff_max = worst_best_by_real.sort_values('real_best_score_diff', ascending=False).iloc[:config.take_worst_best_top]
    save_samples_by_real(worst_best_by_real_best_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_min_dir)
    save_samples_by_real(worst_best_by_real_best_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_best_score_diff_max_dir)

    worst_best_by_real_worst_score_diff_min = worst_best_by_real.sort_values('real_worst_score_diff', ascending=True).iloc[:config.take_worst_best_top]
    worst_best_by_real_worst_score_diff_max = worst_best_by_real.sort_values('real_worst_score_diff', ascending=False).iloc[:config.take_worst_best_top]
    save_samples_by_real(worst_best_by_real_worst_score_diff_min, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_min_dir)
    save_samples_by_real(worst_best_by_real_worst_score_diff_max, mask2fake_fname, fake_info, worst_best_by_real_worst_score_diff_max_dir)

    # analyze which mask changes cause the biggest changes of the score
    overlapping_mask_fname_pairs = []
    overlapping_mask_fname_score_diffs = []
    for cur_real_fname in orig_fnames:
        cur_fakes_info = fake_info[fake_info['real_fname'] == cur_real_fname]
        cur_mask_fnames = sorted(cur_fakes_info['mask_fname'].unique())

        cur_mask_pairs_and_scores = Parallel(args.n_jobs)(
            delayed(extract_overlapping_masks)(cur_mask_fnames, i, fake_scores_table)
            for i in range(len(cur_mask_fnames) - 1)
        )
        for cur_pairs, cur_scores in cur_mask_pairs_and_scores:
            overlapping_mask_fname_pairs.extend(cur_pairs)
            overlapping_mask_fname_score_diffs.extend(cur_scores)

    overlapping_mask_fname_pairs = np.asarray(overlapping_mask_fname_pairs)
    overlapping_mask_fname_score_diffs = np.asarray(overlapping_mask_fname_score_diffs)
    overlapping_sort_idx = np.argsort(overlapping_mask_fname_score_diffs)
    overlapping_mask_fname_pairs = overlapping_mask_fname_pairs[overlapping_sort_idx]
    overlapping_mask_fname_score_diffs = overlapping_mask_fname_score_diffs[overlapping_sort_idx]
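    # The sorted pairs and score differences are not used further in this script; dump them
    # so the mined data can be inspected offline (these two output file names are an added
    # assumption, not part of the original outputs).
    with open(os.path.join(latents_dir, 'overlapping_mask_pairs.pkl'), 'wb') as f:
        pickle.dump(overlapping_mask_fname_pairs, f, protocol=3)
    with open(os.path.join(latents_dir, 'overlapping_mask_score_diffs.pkl'), 'wb') as f:
        pickle.dump(overlapping_mask_fname_score_diffs, f, protocol=3)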


if __name__ == '__main__':
    import argparse

    aparser = argparse.ArgumentParser()
    aparser.add_argument('config', type=str, help='Path to config for dataset generation')
    aparser.add_argument('datadir', type=str,
                         help='Path to folder with images and masks (output of gen_mask_dataset.py)')
    aparser.add_argument('predictdir', type=str,
                         help='Path to folder with predictions (e.g. produced by predict_hifill_baseline.py)')
    aparser.add_argument('outpath', type=str, help='Where to put results')
    aparser.add_argument('--only-report', action='store_true',
                         help='Skip prediction and feature extraction, '
                              'load the precomputed latents and only rebuild the report')
    aparser.add_argument('--n-jobs', type=int, default=8, help='How many processes to use for overlapping-mask pair mining')

    main(aparser.parse_args())