lama

Форк
0
314 строк · 11.5 Кб
1
import torch
2
import torch.nn as nn
3
from torch.optim import Adam, SGD 
4
from kornia.filters import gaussian_blur2d
5
from kornia.geometry.transform import resize
6
from kornia.morphology import erosion
7
from torch.nn import functional as F
8
import numpy as np
9
import cv2
10

11
from saicinpainting.evaluation.data import pad_tensor_to_modulo
12
from saicinpainting.evaluation.utils import move_to_device
13
from saicinpainting.training.modules.ffc import FFCResnetBlock
14
from saicinpainting.training.modules.pix2pixhd import ResnetBlock
15

16
from tqdm import tqdm
17

18

19
def _pyrdown(im : torch.Tensor, downsize : tuple=None):
    """Blur-then-resize an RGB tensor to ``downsize`` (half size when None)."""
    if downsize is None:
        downsize = (im.shape[2] // 2, im.shape[3] // 2)
    assert im.shape[1] == 3, "Expected shape for the input to be (n,3,height,width)"
    # The gaussian pre-filter acts as anti-aliasing before the bilinear resample.
    smoothed = gaussian_blur2d(im, kernel_size=(5, 5), sigma=(1.0, 1.0))
    return F.interpolate(smoothed, size=downsize, mode='bilinear', align_corners=False)

def _pyrdown_mask(mask : torch.Tensor, downsize : tuple=None, eps : float=1e-8, blur_mask : bool=True, round_up : bool=True):
29
    """downscale the mask tensor
30

31
    Parameters
32
    ----------
33
    mask : torch.Tensor
34
        mask of size (B, 1, H, W)
35
    downsize : tuple, optional
36
        size to downscale to. If None, image is downscaled to half, by default None
37
    eps : float, optional
38
        threshold value for binarizing the mask, by default 1e-8
39
    blur_mask : bool, optional
40
        if True, apply gaussian filter before downscaling, by default True
41
    round_up : bool, optional
42
        if True, values above eps are marked 1, else, values below 1-eps are marked 0, by default True
43

44
    Returns
45
    -------
46
    torch.Tensor
47
        downscaled mask
48
    """
49

50
    if downsize is None:
51
        downsize = (mask.shape[2]//2, mask.shape[3]//2)
52
    assert mask.shape[1] == 1, "Expected shape for the input to be (n,1,height,width)"
53
    if blur_mask == True:
54
        mask = gaussian_blur2d(mask, kernel_size=(5,5), sigma=(1.0,1.0))
55
        mask = F.interpolate(mask, size=downsize,  mode='bilinear', align_corners=False)
56
    else:
57
        mask = F.interpolate(mask, size=downsize,  mode='bilinear', align_corners=False)
58
    if round_up:
59
        mask[mask>=eps] = 1
60
        mask[mask<eps] = 0
61
    else:
62
        mask[mask>=1.0-eps] = 1
63
        mask[mask<1.0-eps] = 0
64
    return mask
65

66
def _erode_mask(mask : torch.Tensor, ekernel : torch.Tensor=None, eps : float=1e-8):
67
    """erode the mask, and set gray pixels to 0"""
68
    if ekernel is not None:
69
        mask = erosion(mask, ekernel)
70
        mask[mask>=1.0-eps] = 1
71
        mask[mask<1.0-eps] = 0
72
    return mask
73

74

75
def _l1_loss(
76
    pred : torch.Tensor, pred_downscaled : torch.Tensor, ref : torch.Tensor, 
77
    mask : torch.Tensor, mask_downscaled : torch.Tensor, 
78
    image : torch.Tensor, on_pred : bool=True
79
    ):
80
    """l1 loss on src pixels, and downscaled predictions if on_pred=True"""
81
    loss = torch.mean(torch.abs(pred[mask<1e-8] - image[mask<1e-8]))
82
    if on_pred: 
83
        loss += torch.mean(torch.abs(pred_downscaled[mask_downscaled>=1e-8] - ref[mask_downscaled>=1e-8]))                
84
    return loss
85

86
def _infer(
    image : torch.Tensor, mask : torch.Tensor, 
    forward_front : nn.Module, forward_rears : nn.Module, 
    ref_lower_res : torch.Tensor, orig_shape : tuple, devices : list, 
    scale_ind : int, n_iters : int=15, lr : float=0.002):
    """Performs inference with refinement at a given scale.

    Parameters
    ----------
    image : torch.Tensor
        input image to be inpainted, of size (1,3,H,W)
    mask : torch.Tensor
        input inpainting mask, of size (1,1,H,W) 
    forward_front : nn.Module
        the front part of the inpainting network
    forward_rears : nn.Module
        the rear part of the inpainting network (iterated as per-device segments)
    ref_lower_res : torch.Tensor
        the inpainting at previous scale, used as reference image; None at the
        coarsest scale (then no refinement is performed)
    orig_shape : tuple
        shape of the original input image before padding
    devices : list
        list of available devices
    scale_ind : int
        the scale index
    n_iters : int, optional
        number of iterations of refinement, by default 15
    lr : float, optional
        learning rate, by default 0.002

    Returns
    -------
    torch.Tensor
        inpainted image
    """
    # The network takes the masked image with the mask appended as a 4th channel.
    masked_image = image * (1 - mask)
    masked_image = torch.cat([masked_image, mask], dim=1)

    # Broadcast the 1-channel mask to 3 channels for compositing and the loss.
    mask = mask.repeat(1,3,1,1)
    if ref_lower_res is not None:
        # The reference is a fixed target; never backprop through it.
        ref_lower_res = ref_lower_res.detach()
    with torch.no_grad():
        # Front pass is frozen; only the intermediate features get optimized.
        z1,z2 = forward_front(masked_image)
    # Inference
    # Loss inputs live on the last device, where the final prediction lands.
    mask = mask.to(devices[-1])
    # Elliptical structuring element used to erode the downscaled mask.
    ekernel = torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15)).astype(bool)).float()
    ekernel = ekernel.to(devices[-1])
    image = image.to(devices[-1])
    # The intermediate features are the optimization variables.
    z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0])
    z1.requires_grad, z2.requires_grad = True, True

    optimizer = Adam([z1,z2], lr=lr)

    pbar = tqdm(range(n_iters), leave=False)
    for idi in pbar:
        optimizer.zero_grad()
        input_feat = (z1,z2)
        # Chain the rear segments, handing features over device-to-device.
        for idd, forward_rear in enumerate(forward_rears):
            output_feat = forward_rear(input_feat)
            if idd < len(devices) - 1:
                midz1, midz2 = output_feat
                midz1, midz2 = midz1.to(devices[idd+1]), midz2.to(devices[idd+1])
                input_feat = (midz1, midz2)
            else:        
                pred = output_feat

        if ref_lower_res is None:
            # Coarsest scale: no reference to refine against; keep first pred.
            break
        losses = {}
        ######################### multi-scale #############################
        # scaled loss with downsampler
        # Compare the prediction (cropped to the unpadded area, then downscaled)
        # against the previous-scale result on the eroded hole region.
        pred_downscaled = _pyrdown(pred[:,:,:orig_shape[0],:orig_shape[1]])
        mask_downscaled = _pyrdown_mask(mask[:,:1,:orig_shape[0],:orig_shape[1]], blur_mask=False, round_up=False)
        mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel)
        mask_downscaled = mask_downscaled.repeat(1,3,1,1)
        losses["ms_l1"] = _l1_loss(pred, pred_downscaled, ref_lower_res, mask, mask_downscaled, image, on_pred=True)

        loss = sum(losses.values())
        pbar.set_description("Refining scale {} using scale {} ...current loss: {:.4f}".format(scale_ind+1, scale_ind, loss.item()))
        # Skip the update on the last iteration so the final `pred` (computed
        # from the already-updated features) survives; `del` frees GPU memory.
        if idi < n_iters - 1:
            loss.backward()
            optimizer.step()
            del pred_downscaled
            del loss
            del pred
    # "pred" is the prediction after Plug-n-Play module
    # Composite: predicted content inside the hole, original pixels outside.
    inpainted = mask * pred + (1 - mask) * image
    inpainted = inpainted.detach().cpu()
    return inpainted

def _get_image_mask_pyramid(batch : dict, min_side : int, max_scales : int, px_budget : int):
177
    """Build the image mask pyramid
178

179
    Parameters
180
    ----------
181
    batch : dict
182
        batch containing image, mask, etc
183
    min_side : int
184
        minimum side length to limit the number of scales of the pyramid 
185
    max_scales : int
186
        maximum number of scales allowed
187
    px_budget : int
188
        the product H*W cannot exceed this budget, because of resource constraints
189

190
    Returns
191
    -------
192
    tuple
193
        image-mask pyramid in the form of list of images and list of masks
194
    """
195

196
    assert batch['image'].shape[0] == 1, "refiner works on only batches of size 1!"
197

198
    h, w = batch['unpad_to_size']
199
    h, w = h[0].item(), w[0].item()
200

201
    image = batch['image'][...,:h,:w]
202
    mask = batch['mask'][...,:h,:w]
203
    if h*w > px_budget:
204
        #resize 
205
        ratio = np.sqrt(px_budget / float(h*w))
206
        h_orig, w_orig = h, w
207
        h,w = int(h*ratio), int(w*ratio)
208
        print(f"Original image too large for refinement! Resizing {(h_orig,w_orig)} to {(h,w)}...")
209
        image = resize(image, (h,w),interpolation='bilinear', align_corners=False)
210
        mask = resize(mask, (h,w),interpolation='bilinear', align_corners=False)
211
        mask[mask>1e-8] = 1        
212
    breadth = min(h,w)
213
    n_scales = min(1 + int(round(max(0,np.log2(breadth / min_side)))), max_scales)        
214
    ls_images = []
215
    ls_masks = []
216
    
217
    ls_images.append(image)
218
    ls_masks.append(mask)
219
    
220
    for _ in range(n_scales - 1):
221
        image_p = _pyrdown(ls_images[-1])
222
        mask_p = _pyrdown_mask(ls_masks[-1])
223
        ls_images.append(image_p)
224
        ls_masks.append(mask_p)
225
    # reverse the lists because we want the lowest resolution image as index 0
226
    return ls_images[::-1], ls_masks[::-1]
227

228
def refine_predict(
    batch : dict, inpainter : nn.Module, gpu_ids : str, 
    modulo : int, n_iters : int, lr : float, min_side : int, 
    max_scales : int, px_budget : int
    ):
    """Refines the inpainting of the network

    Parameters
    ----------
    batch : dict
        image-mask batch, currently we assume the batchsize to be 1
    inpainter : nn.Module
        the inpainting neural network
    gpu_ids : str
        the GPU ids of the machine to use. If only single GPU, use: "0,"
    modulo : int
        pad the image to ensure dimension % modulo == 0
    n_iters : int
        number of iterations of refinement for each scale
    lr : float
        learning rate
    min_side : int
        all sides of image on all scales should be >= min_side / sqrt(2)
    max_scales : int
        max number of downscaling scales for the image-mask pyramid
    px_budget : int
        pixels budget. Any image will be resized to satisfy height*width <= px_budget

    Returns
    -------
    torch.Tensor
        inpainted image of size (1,3,H,W)
    """
    # Refinement optimizes intermediate features of a frozen network: it must
    # be in eval mode, add no noise, and take the mask as an extra channel.
    assert not inpainter.training
    assert not inpainter.add_noise_kwargs
    assert inpainter.concat_mask

    gpu_ids = [f'cuda:{gpuid}' for gpuid in gpu_ids.replace(" ","").split(",") if gpuid.isdigit()]

    # Count the resnet blocks and locate the first one, so the generator can
    # be split into a front part (before the blocks) and a rear part.
    n_resnet_blocks = 0
    first_resblock_ind = 0
    found_first_resblock = False
    for module in inpainter.generator.model:
        if isinstance(module, (FFCResnetBlock, ResnetBlock)):
            n_resnet_blocks += 1
            found_first_resblock = True
        elif not found_first_resblock:
            first_resblock_ind += 1
    resblocks_per_gpu = n_resnet_blocks // len(gpu_ids)

    devices = [torch.device(gpu_id) for gpu_id in gpu_ids]

    # split the model into front, and rear parts
    forward_front = inpainter.generator.model[0:first_resblock_ind]
    forward_front.to(devices[0])
    # Distribute the rear across GPUs; the last segment also gets any leftover
    # blocks plus everything after them.
    forward_rears = []
    for idd in range(len(gpu_ids)):
        if idd < len(gpu_ids) - 1:
            forward_rears.append(inpainter.generator.model[first_resblock_ind + resblocks_per_gpu*(idd):first_resblock_ind+resblocks_per_gpu*(idd+1)]) 
        else:
            forward_rears.append(inpainter.generator.model[first_resblock_ind + resblocks_per_gpu*(idd):]) 
        forward_rears[idd].to(devices[idd]) 

    ls_images, ls_masks = _get_image_mask_pyramid(
        batch, 
        min_side, 
        max_scales, 
        px_budget
        )
    image_inpainted = None

    # Coarse-to-fine: each scale's output becomes the reference for the next.
    for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)):
        orig_shape = image.shape[2:]
        image = pad_tensor_to_modulo(image, modulo)
        mask = pad_tensor_to_modulo(mask, modulo)
        # Padding softens the mask edges; force it back to strict {0, 1}.
        mask[mask >= 1e-8] = 1.0
        mask[mask < 1e-8] = 0.0
        image, mask = move_to_device(image, devices[0]), move_to_device(mask, devices[0])
        if image_inpainted is not None:
            # The reference is consumed on the last device inside _infer.
            image_inpainted = move_to_device(image_inpainted, devices[-1])
        image_inpainted = _infer(image, mask, forward_front, forward_rears, image_inpainted, orig_shape, devices, ids, n_iters, lr)
        # Crop the padding back off before using this scale as a reference.
        image_inpainted = image_inpainted[:,:,:orig_shape[0], :orig_shape[1]]
        # detach everything to save resources
        image = image.detach().cpu()
        mask = mask.detach().cpu()

    return image_inpainted

315

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.