lama

Форк
0
/
gen_mask_dataset_hydra.py 
124 строки · 5.2 Кб
1
#!/usr/bin/env python3
2

3
import glob
4
import os
5
import shutil
6
import traceback
7
import hydra
8
from omegaconf import OmegaConf
9

10
import PIL.Image as Image
11
import numpy as np
12
from joblib import Parallel, delayed
13

14
from saicinpainting.evaluation.masks.mask import SegmentationMask, propose_random_square_crop
15
from saicinpainting.evaluation.utils import load_yaml, SmallMode
16
from saicinpainting.training.data.masks import MixedMaskGenerator
17

18

19
class MakeManyMasksWrapper:
    """Adapt a single-mask generator to the ``get_masks(img)`` interface.

    ``impl`` is called once per requested variant on the image converted to a
    channels-first array; the first element of each call's result is kept.

    Attributes:
        impl: callable taking a (C, H, W) array; its result must be indexable,
            and only ``result[0]`` is used.
        variants_n: number of times ``impl`` is invoked per image.
    """

    def __init__(self, impl, variants_n=2):
        self.impl = impl
        self.variants_n = variants_n

    def get_masks(self, img):
        """Return a list of ``variants_n`` masks generated for *img*.

        *img* is anything ``np.array`` accepts that yields an (H, W, C) array
        (e.g. a PIL RGB image); it is transposed to (C, H, W) before being
        handed to ``impl``.
        """
        chw_image = np.array(img).transpose(2, 0, 1)
        masks = []
        for _ in range(self.variants_n):
            generated = self.impl(chw_image)
            masks.append(generated[0])
        return masks
27

28

29
def _make_mask_generator(config):
    """Build the mask generator selected by ``config.generator_kind``.

    Returns an object exposing ``get_masks(image)``.

    Raises:
        ValueError: if ``config.generator_kind`` is neither 'segmentation' nor 'random'.
    """
    if config.generator_kind == 'segmentation':
        return SegmentationMask(**config.mask_generator_kwargs)
    if config.generator_kind == 'random':
        mask_generator_kwargs = OmegaConf.to_container(config.mask_generator_kwargs, resolve=True)
        # 'variants_n' configures the wrapper (masks per image); it is not a
        # MixedMaskGenerator argument, so strip it before forwarding the rest.
        variants_n = mask_generator_kwargs.pop('variants_n', 2)
        return MakeManyMasksWrapper(MixedMaskGenerator(**mask_generator_kwargs),
                                    variants_n=variants_n)
    raise ValueError(f'Unexpected generator kind: {config.generator_kind}')


def _resize_to_min_side(image, out_min_size):
    """Return *image* resized (bicubic) so that its shorter side equals *out_min_size*."""
    factor = out_min_size / min(image.size)
    # Pass plain Python ints to PIL (round() is half-to-even, same as numpy's round).
    out_size = tuple(int(round(side * factor)) for side in image.size)
    return image.resize(out_size, resample=Image.BICUBIC)


def process_images(src_images, indir, outdir, config):
    """Generate masks for each image in *src_images* and save image/mask pairs under *outdir*.

    For every input file:
      * the image is converted to RGB and rescaled so its shorter side equals
        ``config.cropping.out_min_size``; images smaller than that are dropped or
        upscaled according to ``config.cropping.handle_small_mode``;
      * masks are produced by the generator selected via ``config.generator_kind``
        and, if ``config.cropping.out_square_crop`` is set, cropped (together with
        the image) to a square proposed by ``propose_random_square_crop``;
      * masks with no pixels, with nothing masked, or whose mean exceeds
        ``config.max_tamper_area`` (default 1) are discarded;
      * up to ``config.max_masks_per_image`` of the remaining pairs are sampled at
        random without replacement and written as
        ``<outdir>/<relpath stem>_crop{i:03d}.png`` (image) and
        ``<outdir>/<relpath stem>_crop{i:03d}_mask{i:03d}.png`` (8-bit grayscale mask).

    Args:
        src_images: paths of input images; each is expected to lie under *indir*.
        indir: input root; output paths mirror each file's path relative to it.
        outdir: output root.
        config: Hydra/OmegaConf config providing the fields referenced above.

    Raises:
        ValueError: if ``config.generator_kind`` is not 'segmentation' or 'random'.

    Errors for an individual file are printed with a traceback and that file is
    skipped; a KeyboardInterrupt stops processing and returns quietly.
    """
    mask_generator = _make_mask_generator(config)

    # Loop-invariant config values.
    cropping = config.cropping
    max_tamper_area = config.get('max_tamper_area', 1)
    max_masks_per_image = config.max_masks_per_image

    for infile in src_images:
        try:
            # relpath, unlike slicing off len(indir), does not depend on indir ending
            # with a separator (a leading '/' would make os.path.join drop outdir).
            file_relpath = os.path.relpath(infile, indir or os.curdir)

            with Image.open(infile) as src_image:
                image = src_image.convert('RGB')

            # scale input image to output resolution and filter smaller images
            if min(image.size) < cropping.out_min_size:
                handle_small_mode = SmallMode(cropping.handle_small_mode)
                if handle_small_mode == SmallMode.DROP:
                    continue
                if handle_small_mode == SmallMode.UPSCALE:
                    image = _resize_to_min_side(image, cropping.out_min_size)
            else:
                image = _resize_to_min_side(image, cropping.out_min_size)

            # generate and filter masks
            filtered_image_mask_pairs = []
            for cur_mask in mask_generator.get_masks(image):
                if cropping.out_square_crop:
                    crop_left, crop_top, crop_right, crop_bottom = propose_random_square_crop(
                        cur_mask, min_overlap=cropping.crop_min_overlap)
                    cur_mask = cur_mask[crop_top:crop_bottom, crop_left:crop_right]
                    # crop() already returns a new image; no copy needed.
                    cur_image = image.crop((crop_left, crop_top, crop_right, crop_bottom))
                else:
                    cur_image = image

                # Drop degenerate masks (zero-size, or nothing masked) and masks whose
                # mean (the masked fraction for a 0/1 mask) exceeds max_tamper_area.
                # The size check comes first so mean() is never taken of an empty array.
                if cur_mask.size == 0 or not cur_mask.any() or cur_mask.mean() > max_tamper_area:
                    continue

                filtered_image_mask_pairs.append((cur_image, cur_mask))

            if not filtered_image_mask_pairs:
                continue

            mask_indices = np.random.choice(len(filtered_image_mask_pairs),
                                            size=min(len(filtered_image_mask_pairs), max_masks_per_image),
                                            replace=False)

            # save selected masks together with the corresponding (cropped) images
            mask_basename = os.path.join(outdir, os.path.splitext(file_relpath)[0])
            os.makedirs(os.path.dirname(mask_basename), exist_ok=True)
            for i, idx in enumerate(mask_indices):
                cur_image, cur_mask = filtered_image_mask_pairs[idx]
                cur_basename = mask_basename + f'_crop{i:03d}'
                Image.fromarray(np.clip(cur_mask * 255, 0, 255).astype('uint8'),
                                mode='L').save(cur_basename + f'_mask{i:03d}.png')
                cur_image.save(cur_basename + '.png')
        except KeyboardInterrupt:
            # Deliberately stop quietly on Ctrl+C instead of dumping a traceback.
            return
        except Exception as ex:
            # Best effort: report the failing file and continue with the next one.
            print(f'Could not make masks for {infile} due to {ex}:\n{traceback.format_exc()}')
101

102

103
@hydra.main(config_path='../configs/data_gen/whydra', config_name='random_medium_256.yaml')
def main(config: OmegaConf):
    """Generate a mask dataset for every ``*.{config.location.extension}`` file under ``config.indir``.

    Input files are discovered recursively and passed to ``process_images``, either
    in-process (``config.n_jobs == 0``) or split into contiguous chunks, one per
    worker, run through joblib ``Parallel(n_jobs=config.n_jobs)``. Negative
    ``n_jobs`` follows joblib's convention (e.g. -1 = all CPUs).
    """
    # Normalise indir to end with '/': process_images derives each output path
    # from the file's path relative to indir.
    if not config.indir.endswith('/'):
        config.indir += '/'

    os.makedirs(config.outdir, exist_ok=True)

    # glob order is filesystem-dependent; sorting keeps the chunking reproducible.
    in_files = sorted(glob.glob(os.path.join(config.indir, '**', f'*.{config.location.extension}'),
                                recursive=True))
    if not in_files:
        # Without this guard the chunked path below would compute a zero step for range().
        print(f'No *.{config.location.extension} files found under {config.indir}')
        return

    if config.n_jobs == 0:
        process_images(in_files, config.indir, config.outdir, config)
        return

    # Local import: joblib is already a dependency of this script.
    from joblib import effective_n_jobs

    # config.n_jobs may be negative (joblib convention), so it cannot be used as a
    # divisor directly; resolve it to the actual worker count first.
    n_workers = effective_n_jobs(config.n_jobs)
    chunk_size = (len(in_files) + n_workers - 1) // n_workers  # ceil division, >= 1
    Parallel(n_jobs=config.n_jobs)(
        delayed(process_images)(in_files[start:start + chunk_size], config.indir, config.outdir, config)
        for start in range(0, len(in_files), chunk_size)
    )
121

122

123
if __name__ == '__main__':
    # Script entry point: the @hydra.main decorator on main() builds and injects `config`.
    main()
125

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.