HairFastGAN
109 строк · 3.7 Кб
1import os2import numpy as np3import torch4import torch.nn as nn5import torch.nn.functional as F6import torch.utils.data as data7
8from PIL import Image9from torch.autograd import grad10
11
def clip_img(x):
    """Map the first image of a generator output batch from [-1, 1] to [0, 1].

    Args:
        x: batch tensor; only ``x[0]`` is used. Values are presumably in
            roughly [-1, 1] (typical StyleGAN output) -- anything outside is
            clamped after the affine map.

    Returns:
        A one-element list holding the rescaled, clamped first image,
        detached from the autograd graph and moved to the CPU.
    """
    # Only image 0 is needed, so slice before doing any work instead of
    # cloning the whole batch. The arithmetic below already allocates a new
    # tensor, so ``x`` is never modified.
    img = torch.clamp((x[0].detach() + 1) / 2, 0, 1)
    return [img.cpu()]
def tensor_byte(x):
    """Return how many bytes the elements of tensor ``x`` occupy.

    Computed as (bytes per element) * (number of elements in ``x``).
    """
    bytes_per_element = x.element_size()
    element_count = x.nelement()
    return bytes_per_element * element_count
def count_parameters(net):
    """Print and return the total number of scalar parameters of ``net``.

    Args:
        net: an ``nn.Module`` (anything exposing ``.parameters()``).

    Returns:
        The parameter count as a Python ``int``. Callers that ignore the
        return value are unaffected; the count is still printed.
    """
    # ``numel`` counts elements directly. ``np.prod`` over an empty size list
    # (a 0-dim parameter) would return the float 1.0 and turn the sum into a float.
    total = sum(p.numel() for p in net.parameters())
    print(total)
    return total
def stylegan_to_classifier(x, out_size=(224, 224),
                           mean=(0.485, 0.456, 0.406),
                           std=(0.229, 0.224, 0.225)):
    """Turn a generator output batch into an ImageNet-normalized classifier input.

    Steps:
      1. map values from [-1, 1] to [0, 1] via ``0.5 * x + 0.5`` and clamp;
      2. bilinearly resize to ``out_size``;
      3. normalize the first ``len(mean)`` channels as ``(v - mean) / std``.

    Args:
        x: 4-D batch tensor (B, C, H, W); it is not modified.
        out_size: target (H, W) for the resize.
        mean, std: per-channel normalization statistics. The defaults are the
            torchvision ImageNet values (RGB order).

    Returns:
        A new tensor of shape (B, C, *out_size).
    """
    # ``0.5 * x`` already allocates a new tensor, so no defensive clone is needed.
    img = torch.clamp(0.5 * x + 0.5, 0, 1)
    # align_corners=False is F.interpolate's default for bilinear; passed explicitly.
    img = F.interpolate(img, size=out_size, mode='bilinear', align_corners=False)
    mean_t = torch.tensor(mean, dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
    std_t = torch.tensor(std, dtype=img.dtype, device=img.device).view(1, -1, 1, 1)
    n = mean_t.shape[1]
    # Only the first ``n`` channels are normalized; any extra channels pass through.
    img[:, :n] = (img[:, :n] - mean_t) / std_t
    return img
def downscale(x, scale_times=1, mode='bilinear'):
    """Halve the spatial size of ``x`` repeatedly, ``scale_times`` times.

    Each step is ``F.interpolate(..., scale_factor=0.5, mode=mode)``. When
    ``scale_times`` is 0 (or negative) the input object is returned as-is.
    """
    result = x
    for _step in range(scale_times):
        result = F.interpolate(result, mode=mode, scale_factor=0.5)
    return result
def upscale(x, scale_times=1, mode='bilinear'):
    """Double the spatial size of ``x`` repeatedly, ``scale_times`` times.

    Each step is ``F.interpolate(..., scale_factor=2, mode=mode)``. When
    ``scale_times`` is 0 (or negative) the input object is returned as-is.
    """
    upsampled = x
    for _ in range(scale_times):
        upsampled = F.interpolate(upsampled, scale_factor=2, mode=mode)
    return upsampled
def hist_transform(source_tensor, target_tensor):
    """Match each channel's value distribution of the source to the target's.

    Within every channel the source keeps its rank order, but its values are
    replaced by the target channel's sorted values: the k-th smallest source
    element receives the k-th smallest target value.

    Args:
        source_tensor: 3-D tensor (C, H, W). It is NOT modified.
        target_tensor: tensor that reshapes to (C, H * W), i.e. the same
            number of elements per channel as the source.

    Returns:
        A new (C, H, W) tensor with the source's dtype.

    Raises:
        ValueError: if the target does not have the same per-channel element
            count as the source.
    """
    c, h, w = source_tensor.size()
    # ``reshape`` (not ``view``) also accepts non-contiguous inputs; results are
    # written into a fresh tensor, so the caller's source is never mutated.
    src_flat = source_tensor.reshape(c, -1)
    tgt_flat = target_tensor.reshape(c, -1)
    if tgt_flat.shape != src_flat.shape:
        raise ValueError(
            f"target has {tgt_flat.shape[1]} elements per channel, "
            f"source has {src_flat.shape[1]}"
        )
    _, src_order = torch.sort(src_flat)    # per-channel ranks of the source
    tgt_sorted, _ = torch.sort(tgt_flat)   # per-channel sorted target values
    matched = torch.empty_like(src_flat)
    # matched[i, src_order[i, k]] = tgt_sorted[i, k] for every channel i at once.
    matched.scatter_(1, src_order, tgt_sorted.to(src_flat.dtype))
    return matched.view(c, h, w)
def init_weights(m):
    """Re-initialize a single layer in place (presumably used via ``net.apply``).

    - ``nn.Conv2d``: weight gets Xavier (Glorot) uniform init; bias untouched.
    - ``nn.Linear``: weight is drawn from U(0, 1); bias, if present, is set
      to 0.01.
    - Any other module is left unchanged.

    The checks compare exact types, so subclasses of these layers are skipped.
    """
    layer_type = type(m)
    if layer_type is nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
    elif layer_type is nn.Linear:
        # NOTE(review): U(0, 1) makes every Linear weight non-negative, which is
        # unusual -- confirm this is intended rather than Xavier as well.
        nn.init.uniform_(m.weight, 0.0, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
def total_variation(x, delta=1):
    """Total variation of a 4-D tensor ``x`` of shape (B, C, H, W).

    Returns the mean absolute difference between elements ``delta`` apart
    along the width axis, plus the same quantity along the height axis.
    """
    diff_w = x[:, :, :, :-delta] - x[:, :, :, delta:]
    diff_h = x[:, :, :-delta, :] - x[:, :, delta:, :]
    horizontal = torch.mean(torch.abs(diff_w))
    vertical = torch.mean(torch.abs(diff_h))
    return horizontal + vertical
def vgg_transform(x, out_size=(224, 224)):
    """Convert an RGB batch in [0, 1] to BGR, resize it, and scale it to [0, 255].

    Steps: reorder channels RGB -> BGR, bilinearly resize to ``out_size``,
    multiply by 255. No mean subtraction is performed here.
    NOTE(review): the BGR / 0-255 layout suggests Caffe-style VGG weights.
    Confirm that the consumer subtracts the ImageNet mean itself.

    Args:
        x: tensor (B, 3, H, W) with values presumably in [0, 1].
        out_size: target (H, W); defaults to the historical 224 x 224.

    Returns:
        Tensor (B, 3, *out_size) in BGR channel order, scaled by 255.
    """
    r, g, b = torch.split(x, 1, dim=1)
    bgr = torch.cat((b, g, r), dim=1)
    # align_corners=False is F.interpolate's default for bilinear; passed explicitly.
    bgr = F.interpolate(bgr, size=out_size, mode='bilinear', align_corners=False)
    return bgr * 255.
# Warp an image with a dense flow (pixel-displacement) field.
def normalize_axis(x, L):
    """Map 1-based pixel coordinate(s) ``x`` on an axis of length ``L`` to [-1, 1].

    Coordinate 1 maps to -1 and coordinate ``L`` maps to +1. ``L`` must be
    greater than 1 because ``L - 1`` is used as a divisor.
    """
    center = (L - 1) / 2
    offset = x - 1 - center      # signed distance from the axis centre
    return offset * 2 / (L - 1)
def unnormalize_axis(x, L):
    """Inverse of ``normalize_axis``: map [-1, 1] back to 1-based coords in [1, L].

    -1 maps to 1 and +1 maps to ``L``.
    """
    half_extent = (L - 1) / 2
    scaled = x * (L - 1) / 2
    return scaled + 1 + half_extent
def torch_flow_to_th_sampling_grid(flow, h_src, w_src, use_cuda=False):
    """Convert a dense pixel-displacement field into an ``F.grid_sample`` grid.

    Args:
        flow: 4-D tensor (B, C, H_tgt, W_tgt). Channel 0 is the x (column)
            displacement and channel 1 the y (row) displacement, in pixels.
        h_src, w_src: height and width of the image that will be sampled.
        use_cuda: if True, the grid is moved with ``.cuda()`` before returning.

    Returns:
        Tensor (B, H_tgt, W_tgt, 2); the last dim is (x, y), normalized with
        ``normalize_axis`` so that -1 / +1 are the centres of the first / last
        source pixel (grid_sample's ``align_corners=True`` convention).
    """
    _, _, h_tgt, w_tgt = flow.size()
    # 1-based target pixel coordinates, shaped to broadcast against the
    # (B, H_tgt, W_tgt) flow channels. Building x along the width and y along
    # the height keeps this correct for non-square flows. (The previous
    # meshgrid(w_range, h_range) produced (W, H) grids and only worked when H == W.)
    cols = torch.arange(1, w_tgt + 1, device=flow.device).to(flow.dtype)
    rows = torch.arange(1, h_tgt + 1, device=flow.device).to(flow.dtype)
    source_x = cols.view(1, 1, w_tgt) + flow[:, 0, :, :]
    source_y = rows.view(1, h_tgt, 1) + flow[:, 1, :, :]
    source_x_norm = normalize_axis(source_x, w_src)
    source_y_norm = normalize_axis(source_y, h_src)
    sampling_grid = torch.stack((source_x_norm, source_y_norm), dim=3)
    if use_cuda:
        sampling_grid = sampling_grid.cuda()
    return sampling_grid
def warp_image_torch(image, flow):
    """Warp ``image`` with a dense displacement field ``flow``.

    Args:
        image: tensor (B, C, H_src, W_src).
        flow: tensor (B, 2, H_tgt, W_tgt). Channel 0 is the x (column)
            displacement and channel 1 the y (row) displacement, in pixels.

    Returns:
        Tensor (B, C, H_tgt, W_tgt). Output pixel (row i, col j) is sampled
        bilinearly from source location (col j + flow[:, 0, i, j],
        row i + flow[:, 1, i, j]). Out-of-range samples are zero
        (grid_sample's default padding).
    """
    _, _, h_src, w_src = image.size()
    sampling_grid_torch = torch_flow_to_th_sampling_grid(flow, h_src, w_src)
    # The grid maps the centres of the first/last source pixels to -1/+1, which
    # is grid_sample's align_corners=True convention. Since PyTorch 1.3 the
    # default is False, under which a zero flow would not reproduce the image.
    return F.grid_sample(image, sampling_grid_torch, align_corners=True)