# lama — refinement module (file listing header: 314 lines · 11.5 KB)
1import torch
2import torch.nn as nn
3from torch.optim import Adam, SGD
4from kornia.filters import gaussian_blur2d
5from kornia.geometry.transform import resize
6from kornia.morphology import erosion
7from torch.nn import functional as F
8import numpy as np
9import cv2
10
11from saicinpainting.evaluation.data import pad_tensor_to_modulo
12from saicinpainting.evaluation.utils import move_to_device
13from saicinpainting.training.modules.ffc import FFCResnetBlock
14from saicinpainting.training.modules.pix2pixhd import ResnetBlock
15
16from tqdm import tqdm
17
18
def _pyrdown(im : torch.Tensor, downsize : tuple=None):
    """Downscale an RGB image: 5x5 Gaussian blur, then bilinear resize.

    When ``downsize`` is None the image is halved along both spatial dims.
    Expects input of shape (n, 3, H, W).
    """
    if downsize is None:
        downsize = (im.shape[2] // 2, im.shape[3] // 2)
    assert im.shape[1] == 3, "Expected shape for the input to be (n,3,height,width)"
    # blur before resampling to suppress aliasing
    smoothed = gaussian_blur2d(im, kernel_size=(5, 5), sigma=(1.0, 1.0))
    return F.interpolate(smoothed, size=downsize, mode='bilinear', align_corners=False)
27
28def _pyrdown_mask(mask : torch.Tensor, downsize : tuple=None, eps : float=1e-8, blur_mask : bool=True, round_up : bool=True):
29"""downscale the mask tensor
30
31Parameters
32----------
33mask : torch.Tensor
34mask of size (B, 1, H, W)
35downsize : tuple, optional
36size to downscale to. If None, image is downscaled to half, by default None
37eps : float, optional
38threshold value for binarizing the mask, by default 1e-8
39blur_mask : bool, optional
40if True, apply gaussian filter before downscaling, by default True
41round_up : bool, optional
42if True, values above eps are marked 1, else, values below 1-eps are marked 0, by default True
43
44Returns
45-------
46torch.Tensor
47downscaled mask
48"""
49
50if downsize is None:
51downsize = (mask.shape[2]//2, mask.shape[3]//2)
52assert mask.shape[1] == 1, "Expected shape for the input to be (n,1,height,width)"
53if blur_mask == True:
54mask = gaussian_blur2d(mask, kernel_size=(5,5), sigma=(1.0,1.0))
55mask = F.interpolate(mask, size=downsize, mode='bilinear', align_corners=False)
56else:
57mask = F.interpolate(mask, size=downsize, mode='bilinear', align_corners=False)
58if round_up:
59mask[mask>=eps] = 1
60mask[mask<eps] = 0
61else:
62mask[mask>=1.0-eps] = 1
63mask[mask<1.0-eps] = 0
64return mask
65
66def _erode_mask(mask : torch.Tensor, ekernel : torch.Tensor=None, eps : float=1e-8):
67"""erode the mask, and set gray pixels to 0"""
68if ekernel is not None:
69mask = erosion(mask, ekernel)
70mask[mask>=1.0-eps] = 1
71mask[mask<1.0-eps] = 0
72return mask
73
74
75def _l1_loss(
76pred : torch.Tensor, pred_downscaled : torch.Tensor, ref : torch.Tensor,
77mask : torch.Tensor, mask_downscaled : torch.Tensor,
78image : torch.Tensor, on_pred : bool=True
79):
80"""l1 loss on src pixels, and downscaled predictions if on_pred=True"""
81loss = torch.mean(torch.abs(pred[mask<1e-8] - image[mask<1e-8]))
82if on_pred:
83loss += torch.mean(torch.abs(pred_downscaled[mask_downscaled>=1e-8] - ref[mask_downscaled>=1e-8]))
84return loss
85
def _infer(
    image : torch.Tensor, mask : torch.Tensor,
    forward_front : nn.Module, forward_rears : nn.Module,
    ref_lower_res : torch.Tensor, orig_shape : tuple, devices : list,
    scale_ind : int, n_iters : int=15, lr : float=0.002):
    """Performs inference with refinement at a given scale.

    The front of the network is run once; the latent features (z1, z2) are
    then optimized with Adam so that the downscaled prediction matches the
    inpainting from the previous (coarser) scale.

    Parameters
    ----------
    image : torch.Tensor
        input image to be inpainted, of size (1,3,H,W)
    mask : torch.Tensor
        input inpainting mask, of size (1,1,H,W)
    forward_front : nn.Module
        the front part of the inpainting network
    forward_rears : nn.Module
        the rear part of the inpainting network
    ref_lower_res : torch.Tensor
        the inpainting at previous scale, used as reference image
    orig_shape : tuple
        shape of the original input image before padding
    devices : list
        list of available devices
    scale_ind : int
        the scale index
    n_iters : int, optional
        number of iterations of refinement, by default 15
    lr : float, optional
        learning rate, by default 0.002

    Returns
    -------
    torch.Tensor
        inpainted image
    """
    # network input: masked image concatenated with the mask channel
    masked_image = image * (1 - mask)
    masked_image = torch.cat([masked_image, mask], dim=1)

    # broadcast mask to 3 channels for compositing with the RGB prediction
    mask = mask.repeat(1,3,1,1)
    if ref_lower_res is not None:
        # reference is fixed — only z1/z2 are optimized
        ref_lower_res = ref_lower_res.detach()
    with torch.no_grad():
        # front pass is run once; its outputs become the optimization variables
        z1,z2 = forward_front(masked_image)
    # Inference
    mask = mask.to(devices[-1])
    # elliptical 15x15 structuring element used to erode the downscaled mask
    ekernel = torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15)).astype(bool)).float()
    ekernel = ekernel.to(devices[-1])
    image = image.to(devices[-1])
    z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0])
    # the latent features are the only trainable parameters here
    z1.requires_grad, z2.requires_grad = True, True

    optimizer = Adam([z1,z2], lr=lr)

    pbar = tqdm(range(n_iters), leave=False)
    for idi in pbar:
        optimizer.zero_grad()
        input_feat = (z1,z2)
        # rear parts are pipelined across devices; features hop device-to-device
        for idd, forward_rear in enumerate(forward_rears):
            output_feat = forward_rear(input_feat)
            if idd < len(devices) - 1:
                midz1, midz2 = output_feat
                midz1, midz2 = midz1.to(devices[idd+1]), midz2.to(devices[idd+1])
                input_feat = (midz1, midz2)
            else:
                pred = output_feat

        if ref_lower_res is None:
            # coarsest scale has no reference: single plain forward pass
            break
        losses = {}
        ######################### multi-scale #############################
        # scaled loss with downsampler
        pred_downscaled = _pyrdown(pred[:,:,:orig_shape[0],:orig_shape[1]])
        mask_downscaled = _pyrdown_mask(mask[:,:1,:orig_shape[0],:orig_shape[1]], blur_mask=False, round_up=False)
        # erode so that boundary (gray) pixels don't contribute to the loss
        mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel)
        mask_downscaled = mask_downscaled.repeat(1,3,1,1)
        losses["ms_l1"] = _l1_loss(pred, pred_downscaled, ref_lower_res, mask, mask_downscaled, image, on_pred=True)

        loss = sum(losses.values())
        pbar.set_description("Refining scale {} using scale {} ...current loss: {:.4f}".format(scale_ind+1, scale_ind, loss.item()))
        # skip the last backward/step so `pred` from the final iteration survives
        if idi < n_iters - 1:
            loss.backward()
            optimizer.step()
            del pred_downscaled
            del loss
            del pred
    # "pred" is the prediction after Plug-n-Play module
    inpainted = mask * pred + (1 - mask) * image
    inpainted = inpainted.detach().cpu()
    return inpainted
175
176def _get_image_mask_pyramid(batch : dict, min_side : int, max_scales : int, px_budget : int):
177"""Build the image mask pyramid
178
179Parameters
180----------
181batch : dict
182batch containing image, mask, etc
183min_side : int
184minimum side length to limit the number of scales of the pyramid
185max_scales : int
186maximum number of scales allowed
187px_budget : int
188the product H*W cannot exceed this budget, because of resource constraints
189
190Returns
191-------
192tuple
193image-mask pyramid in the form of list of images and list of masks
194"""
195
196assert batch['image'].shape[0] == 1, "refiner works on only batches of size 1!"
197
198h, w = batch['unpad_to_size']
199h, w = h[0].item(), w[0].item()
200
201image = batch['image'][...,:h,:w]
202mask = batch['mask'][...,:h,:w]
203if h*w > px_budget:
204#resize
205ratio = np.sqrt(px_budget / float(h*w))
206h_orig, w_orig = h, w
207h,w = int(h*ratio), int(w*ratio)
208print(f"Original image too large for refinement! Resizing {(h_orig,w_orig)} to {(h,w)}...")
209image = resize(image, (h,w),interpolation='bilinear', align_corners=False)
210mask = resize(mask, (h,w),interpolation='bilinear', align_corners=False)
211mask[mask>1e-8] = 1
212breadth = min(h,w)
213n_scales = min(1 + int(round(max(0,np.log2(breadth / min_side)))), max_scales)
214ls_images = []
215ls_masks = []
216
217ls_images.append(image)
218ls_masks.append(mask)
219
220for _ in range(n_scales - 1):
221image_p = _pyrdown(ls_images[-1])
222mask_p = _pyrdown_mask(ls_masks[-1])
223ls_images.append(image_p)
224ls_masks.append(mask_p)
225# reverse the lists because we want the lowest resolution image as index 0
226return ls_images[::-1], ls_masks[::-1]
227
def refine_predict(
    batch : dict, inpainter : nn.Module, gpu_ids : str,
    modulo : int, n_iters : int, lr : float, min_side : int,
    max_scales : int, px_budget : int
):
    """Refines the inpainting of the network

    Parameters
    ----------
    batch : dict
        image-mask batch, currently we assume the batchsize to be 1
    inpainter : nn.Module
        the inpainting neural network
    gpu_ids : str
        the GPU ids of the machine to use. If only single GPU, use: "0,"
    modulo : int
        pad the image to ensure dimension % modulo == 0
    n_iters : int
        number of iterations of refinement for each scale
    lr : float
        learning rate
    min_side : int
        all sides of image on all scales should be >= min_side / sqrt(2)
    max_scales : int
        max number of downscaling scales for the image-mask pyramid
    px_budget : int
        pixels budget. Any image will be resized to satisfy height*width <= px_budget

    Returns
    -------
    torch.Tensor
        inpainted image of size (1,3,H,W)
    """

    assert not inpainter.training
    assert not inpainter.add_noise_kwargs
    assert inpainter.concat_mask

    # NOTE(review): only "cuda:N" devices are built here — CPU-only execution
    # is not supported by this path; confirm against caller expectations.
    gpu_ids = [f'cuda:{gpuid}' for gpuid in gpu_ids.replace(" ","").split(",") if gpuid.isdigit()]
    # locate the run of resnet blocks inside the generator so the model can be
    # split into a front (pre-resblock) part and rear (resblock onwards) parts
    n_resnet_blocks = 0
    first_resblock_ind = 0
    found_first_resblock = False
    for idl in range(len(inpainter.generator.model)):
        if isinstance(inpainter.generator.model[idl], FFCResnetBlock) or isinstance(inpainter.generator.model[idl], ResnetBlock):
            n_resnet_blocks += 1
            found_first_resblock = True
        elif not found_first_resblock:
            first_resblock_ind += 1
    # resblocks are distributed evenly; the last GPU also takes the remainder
    resblocks_per_gpu = n_resnet_blocks // len(gpu_ids)

    devices = [torch.device(gpu_id) for gpu_id in gpu_ids]

    # split the model into front, and rear parts
    forward_front = inpainter.generator.model[0:first_resblock_ind]
    forward_front.to(devices[0])
    forward_rears = []
    for idd in range(len(gpu_ids)):
        if idd < len(gpu_ids) - 1:
            forward_rears.append(inpainter.generator.model[first_resblock_ind + resblocks_per_gpu*(idd):first_resblock_ind+resblocks_per_gpu*(idd+1)])
        else:
            # last slice runs to the end of the model (includes the decoder)
            forward_rears.append(inpainter.generator.model[first_resblock_ind + resblocks_per_gpu*(idd):])
        forward_rears[idd].to(devices[idd])

    ls_images, ls_masks = _get_image_mask_pyramid(
        batch,
        min_side,
        max_scales,
        px_budget
    )
    image_inpainted = None

    # coarse-to-fine: each scale's result seeds the next as the reference
    for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)):
        orig_shape = image.shape[2:]
        image = pad_tensor_to_modulo(image, modulo)
        mask = pad_tensor_to_modulo(mask, modulo)
        # re-binarize after padding/resampling
        mask[mask >= 1e-8] = 1.0
        mask[mask < 1e-8] = 0.0
        image, mask = move_to_device(image, devices[0]), move_to_device(mask, devices[0])
        if image_inpainted is not None:
            image_inpainted = move_to_device(image_inpainted, devices[-1])
        image_inpainted = _infer(image, mask, forward_front, forward_rears, image_inpainted, orig_shape, devices, ids, n_iters, lr)
        # crop back to the unpadded size of this scale
        image_inpainted = image_inpainted[:,:,:orig_shape[0], :orig_shape[1]]
        # detach everything to save resources
        image = image.detach().cpu()
        mask = mask.detach().cpu()

    return image_inpainted
315