google-research
323 строки · 20.5 Кб
1--- external/stylegan/training/loss.py 2023-04-06 03:45:26.165066348 +0000
2+++ external_reference/stylegan/training/loss.py 2023-04-06 03:41:03.250603352 +0000
3@@ -14,6 +14,13 @@
4from torch_utils.ops import conv2d_gradfix
5from torch_utils.ops import upfirdn2d
6
7+import copy
8+import random
9+import torch.nn.functional as F
10+from utils import camera_util, losses, regularizers
11+from external.gsn.models.diff_augment import DiffAugment
12+from utils.utils import interpolate
13+
14#----------------------------------------------------------------------------
15
16class Loss:
17@@ -23,7 +30,11 @@
18#----------------------------------------------------------------------------
19
20class StyleGAN2Loss(Loss):
21- def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0):
22+ def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10,
23+ style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2,
24+ pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0,
25+ blur_fade_kimg=0, training_mode=None, loss_layout_kwargs=None,
26+ loss_upsampler_kwargs=None, loss_sky_kwargs=None):
27super().__init__()
28self.device = device
29self.G = G
30@@ -38,18 +49,73 @@
31self.pl_mean = torch.zeros([], device=device)
32self.blur_init_sigma = blur_init_sigma
33self.blur_fade_kimg = blur_fade_kimg
34+ self.loss_layout_kwargs = loss_layout_kwargs
35+ self.loss_upsampler_kwargs = loss_upsampler_kwargs
36+ self.loss_sky_kwargs = loss_sky_kwargs
37+ self.training_mode = training_mode
38+ self.loss_l1 = losses.L1_Loss().to(device)
39+
40+ def run_G(self, z, c, camera_params, real_img_masked=None, real_acc=None, update_emas=False):
41+ """Run the generator for self.training_mode; returns (img, infos) for 'layout', (img, dict with thumb/img/ws/extras) for 'upsampler', (img, ws) for 'sky'; any other mode falls through and returns None."""
42+ assert(self.style_mixing_prob == 0) # not implemented
43+ if self.training_mode == 'layout':
44+ img, infos = self.G(z=z, c=c, camera_params=camera_params,
45+ update_emas=update_emas, extras=['opacity_regularization'])
46+ return img, infos
47+ elif self.training_mode == 'upsampler':
48+ upsampler_ws, feature, thumb, extras = self.G.mapping(z, c, update_emas=update_emas)
49+ img = self.G.synthesis(upsampler_ws, feature, thumb, extras=extras, update_emas=update_emas)
50+ return img, dict(thumb=thumb, img=img, ws=upsampler_ws, extras=extras)
51+ elif self.training_mode == 'sky':
52+ ws = self.G.mapping(z, c, update_emas=update_emas)
53+ # taken from anyres 360 loss
54+ multiply = random.uniform(0, 1) < self.loss_sky_kwargs.mask_prob # decides whether the real image is composited in at the end
55+ input_layer = self.G.G.synthesis.input # extract input tensor from synthesis network
56+ crop_start = random.randint(0, 360 // input_layer.fov * input_layer.frame_size[0] - 1)
57+ crop_fn = lambda grid : grid[:, :, crop_start:crop_start+input_layer.size[0], :]
58+ img_base = self.G.synthesis(ws, real_img_masked, real_acc,
59+ multiply=False, crop_fn=crop_fn, update_emas=update_emas)
60+ crop_shift = crop_start + input_layer.frame_size[0]
61+ # generate shifted frame for cross-frame discriminator
62+ crop_fn_shift = lambda grid : grid[:, :, crop_shift:crop_shift+input_layer.size[0], :]
63+ img_shifted = self.G.synthesis(ws, real_img_masked, real_acc,
64+ multiply=False, crop_fn=crop_fn_shift,
65+ update_emas=update_emas)
66+ img_splice = torch.cat([img_base, img_shifted], dim=3)
67+ img_size = img_base.shape[-1]
68+ splice_start = random.randint(0, img_size)
69+ img = img_splice[:, :, :, splice_start:splice_start+img_size]
70+ # multiply real img only after generating both splices
71+ if multiply:
72+ img = img * (1-real_acc) + real_img_masked * real_acc
73+ return img, ws
74+
75+ def run_D(self, infos, c, blur_sigma=0, update_emas=False):
76+ """Build the discriminator input for self.training_mode and score it with self.D.
77+
78+ `c` is forwarded as-is to self.D; returns logits, or (logits, img, recon) when self.D.recon is set.
79+ """
80+ if self.training_mode == 'layout':
81+ img = infos['rgb']
82+ if self.loss_layout_kwargs.concat_depth:
83+ depth = infos['depth']
84+ img = torch.cat([img, depth], dim=1)
85+ if self.loss_layout_kwargs.concat_acc:
86+ acc = infos['acc']
87+ img = torch.cat([img, acc], dim=1)
88+ if self.loss_layout_kwargs.aug_policy:
89+ img = DiffAugment(img, normalize=True, policy=self.loss_layout_kwargs.aug_policy)
90+ assert(blur_sigma == 0) # using GSN DiffAugment module
91+ elif self.training_mode == 'upsampler':
92+ img = infos['img']
93+ if self.loss_upsampler_kwargs.d_ignore_depth_acc:
94+ # use 0.5 constant as input into depth and acc channels
95+ b, ch, h, w = img.shape # `ch`, not `c`: `c` is the label argument still passed to self.D below
96+ ignore_tensor = torch.ones(b, ch-3, h, w).to(img.device) * 0.5
97+ img = torch.cat([img[:, :3], ignore_tensor], dim=1)
98+ elif self.training_mode == 'sky':
99+ img = infos # output of sky generator is img itself
100
101- def run_G(self, z, c, update_emas=False):
102- ws = self.G.mapping(z, c, update_emas=update_emas)
103- if self.style_mixing_prob > 0:
104- with torch.autograd.profiler.record_function('style_mixing'):
105- cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
106- cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
107- ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
108- img = self.G.synthesis(ws, update_emas=update_emas)
109- return img, ws
110-
111- def run_D(self, img, c, blur_sigma=0, update_emas=False):
112blur_size = np.floor(blur_sigma * 3)
113if blur_size > 0:
114with torch.autograd.profiler.record_function('blur'):
115@@ -57,10 +123,17 @@
116img = upfirdn2d.filter2d(img, f / f.sum())
117if self.augment_pipe is not None:
118img = self.augment_pipe(img)
119+
120+ if self.D.recon:
121+ assert(self.training_mode == 'layout')
122+ # handle return elements with reconstruction discriminator
123+ logits, recon = self.D(img, c, update_emas=update_emas)
124+ return logits, img, recon
125+
126logits = self.D(img, c, update_emas=update_emas)
127return logits
128
129- def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
130+ def accumulate_gradients(self, phase, real_img_infos, real_c, gen_z, gen_c, gain, cur_nimg):
131assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
132if self.pl_weight == 0:
133phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
134@@ -68,15 +141,72 @@
135phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
136blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0
137
138+ camera_params = copy.deepcopy(real_img_infos['camera_params'])
139+ del camera_params['Rt']
140+ real_img = real_img_infos['rgb'] # img with sky masked
141+ real_orig = real_img_infos['orig'] # img with sky
142+ real_depth = real_img_infos['depth']
143+ real_acc = real_img_infos['acc']
144+
145# Gmain: Maximize logits for generated images.
146if phase in ['Gmain', 'Gboth']:
147with torch.autograd.profiler.record_function('Gmain_forward'):
148- gen_img, _gen_ws = self.run_G(gen_z, gen_c)
149- gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
150- training_stats.report('Loss/scores/fake', gen_logits)
151- training_stats.report('Loss/signs/fake', gen_logits.sign())
152- loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
153- training_stats.report('Loss/G/loss', loss_Gmain)
154+ if self.training_mode == 'layout':
155+ gen_img, gen_infos = self.run_G(gen_z, gen_c, camera_params)
156+ gen_logits, _, _ = self.run_D(gen_infos, gen_c, blur_sigma=blur_sigma)
157+ training_stats.report('Loss/scores/fake', gen_logits)
158+ training_stats.report('Loss/signs/fake', gen_logits.sign())
159+ loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
160+ training_stats.report('Loss/G/loss', loss_Gmain)
161+ if self.loss_layout_kwargs.lambda_finite_difference > 0:
162+ if self.loss_layout_kwargs.lambda_ramp_end > 0:
163+ ramp_multiplier = cur_nimg / self.loss_layout_kwargs.lambda_ramp_end
164+ else:
165+ ramp_multiplier = 1
166+ ramp_multiplier = np.clip(ramp_multiplier, 0, 1)
167+ training_stats.report('Loss/G/reg_ramp', ramp_multiplier)
168+ finite_diff = regularizers.ray_finite_difference(gen_infos['extra_outputs'])
169+ loss_Gmain = loss_Gmain + ramp_multiplier * self.loss_layout_kwargs.lambda_finite_difference * finite_diff
170+ training_stats.report('Loss/G/reg_ray', finite_diff)
171+ training_stats.report('Loss/G/loss_total', loss_Gmain)
172+ elif self.training_mode == 'upsampler':
173+ gen_img, infos = self.run_G(gen_z, gen_c, camera_params)
174+ thumb = infos['thumb']
175+ gen_logits = self.run_D(infos, gen_c, blur_sigma=blur_sigma)
176+ training_stats.report('Loss/scores/fake', gen_logits)
177+ training_stats.report('Loss/signs/fake', gen_logits.sign())
178+ loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
179+ training_stats.report('Loss/G/loss', loss_Gmain)
180+ if self.loss_upsampler_kwargs.lambda_rec > 0:
181+ # downsample gen img, it should match thumb
182+ ch = thumb.shape[1]
183+ gen_down = interpolate(gen_img, thumb.shape[-2:])
184+ l1_loss = self.loss_l1(gen_down, thumb)
185+ loss_Gmain = (loss_Gmain + self.loss_upsampler_kwargs.lambda_rec * l1_loss)
186+ training_stats.report('Loss/G/loss_rec_l1', l1_loss)
187+ if self.loss_upsampler_kwargs.lambda_up > 0:
188+ # upsample the depth and acc maps
189+ depth_and_acc_thumb_up = interpolate(thumb[:, 3:], gen_img.shape[-2:])
190+ depth_and_acc_gen = gen_img[:, 3:]
191+ l1_loss = self.loss_l1(depth_and_acc_gen, depth_and_acc_thumb_up)
192+ loss_Gmain = (loss_Gmain + self.loss_upsampler_kwargs.lambda_up * l1_loss)
193+ training_stats.report('Loss/G/loss_up_l1', l1_loss)
194+ if self.loss_upsampler_kwargs.lambda_gray_pixel > 0:
195+ acc_mask = (gen_img[:, -1:] > 0.5).float().detach()
196+ pixel_sum = torch.sum(torch.abs(gen_img[:, :3]), dim=1, keepdim=True)
197+ penalty = torch.exp(-self.loss_upsampler_kwargs.lambda_gray_pixel_falloff * pixel_sum) * acc_mask
198+ reg_loss = torch.mean(penalty, dim=(1, 2, 3))
199+ loss_Gmain = (loss_Gmain + self.loss_upsampler_kwargs.lambda_gray_pixel * reg_loss)
200+ training_stats.report('Loss/G/loss_reg_gray', reg_loss)
201+ training_stats.report('Loss/G/loss_total', loss_Gmain)
202+ elif self.training_mode == 'sky':
203+ gen_img, _gen_ws = self.run_G(gen_z, gen_c, camera_params, real_img, real_acc)
204+ gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
205+ training_stats.report('Loss/scores/fake', gen_logits)
206+ training_stats.report('Loss/signs/fake', gen_logits.sign())
207+ loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
208+ training_stats.report('Loss/G/loss', loss_Gmain)
209+
210with torch.autograd.profiler.record_function('Gmain_backward'):
211loss_Gmain.mean().mul(gain).backward()
212
213@@ -84,7 +214,24 @@
214if phase in ['Greg', 'Gboth']:
215with torch.autograd.profiler.record_function('Gpl_forward'):
216batch_size = gen_z.shape[0] // self.pl_batch_shrink
217- gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size])
218+ camera_params_batch = {k: v[:batch_size] for k, v in camera_params.items()}
219+ if self.training_mode == 'layout':
220+ gen_img, gen_infos = self.run_G(gen_z[:batch_size],
221+ gen_c[:batch_size],
222+ camera_params_batch)
223+ gen_ws = gen_infos['ws']
224+ if self.training_mode == 'upsampler':
225+ gen_img, infos = self.run_G(gen_z[:batch_size],
226+ gen_c[:batch_size],
227+ camera_params_batch)
228+ gen_ws = infos['ws']
229+ if self.training_mode == 'sky':
230+ gen_img, gen_ws = self.run_G(gen_z[:batch_size],
231+ gen_c[:batch_size],
232+ camera_params_batch,
233+ real_img[:batch_size],
234+ real_acc[:batch_size])
235+ # gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size])
236pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
237with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad):
238pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
239@@ -102,11 +249,24 @@
240loss_Dgen = 0
241if phase in ['Dmain', 'Dboth']:
242with torch.autograd.profiler.record_function('Dgen_forward'):
243- gen_img, _gen_ws = self.run_G(gen_z, gen_c, update_emas=True)
244- gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
245- training_stats.report('Loss/scores/fake', gen_logits)
246- training_stats.report('Loss/signs/fake', gen_logits.sign())
247- loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
248+ if self.training_mode == 'layout':
249+ gen_img, gen_infos = self.run_G(gen_z, gen_c, camera_params, update_emas=True)
250+ gen_logits, _, _ = self.run_D(gen_infos, gen_c, blur_sigma=blur_sigma, update_emas=True)
251+ training_stats.report('Loss/scores/fake', gen_logits)
252+ training_stats.report('Loss/signs/fake', gen_logits.sign())
253+ loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
254+ elif self.training_mode == 'upsampler':
255+ gen_img, infos = self.run_G(gen_z, gen_c, camera_params, update_emas=True)
256+ gen_logits = self.run_D(infos, gen_c, blur_sigma=blur_sigma, update_emas=True)
257+ training_stats.report('Loss/scores/fake', gen_logits)
258+ training_stats.report('Loss/signs/fake', gen_logits.sign())
259+ loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
260+ elif self.training_mode == 'sky':
261+ gen_img, _gen_ws = self.run_G(gen_z, gen_c, camera_params, real_img, real_acc, update_emas=True)
262+ gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
263+ training_stats.report('Loss/scores/fake', gen_logits)
264+ training_stats.report('Loss/signs/fake', gen_logits.sign())
265+ loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
266with torch.autograd.profiler.record_function('Dgen_backward'):
267loss_Dgen.mean().mul(gain).backward()
268
269@@ -115,8 +275,29 @@
270if phase in ['Dmain', 'Dreg', 'Dboth']:
271name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
272with torch.autograd.profiler.record_function(name + '_forward'):
273- real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
274- real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
275+ if self.training_mode == 'layout':
276+ real_img_tmp = {
277+ 'rgb': real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth']),
278+ 'depth': real_depth.detach(), # .requires_grad_(phase in ['Dreg', 'Dboth']),
279+ 'acc': real_acc.detach(), # .requires_grad_(phase in ['Dreg', 'Dboth']),
280+ }
281+ r1_grads_input = real_img_tmp['rgb']
282+ real_logits, real_disc_in, real_recon = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
283+
284+ elif self.training_mode == 'upsampler':
285+ real_input = real_img
286+ if self.G.upsampler.synthesis.num_additional_feature_channels > 0:
287+ real_input = torch.cat([real_input, real_depth], dim=1)
288+ if self.G.upsampler.synthesis.num_additional_feature_channels > 1:
289+ real_input = torch.cat([real_input, real_acc], dim=1)
290+ real_img_tmp = real_input.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
291+ r1_grads_input = real_img_tmp
292+ infos_D = dict(img=real_img_tmp)
293+ real_logits = self.run_D(infos_D, real_c, blur_sigma=blur_sigma)
294+ elif self.training_mode == 'sky':
295+ real_img_tmp = real_orig.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
296+ r1_grads_input = real_img_tmp
297+ real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
298training_stats.report('Loss/scores/real', real_logits)
299training_stats.report('Loss/signs/real', real_logits.sign())
300
301@@ -128,13 +309,20 @@
302loss_Dr1 = 0
303if phase in ['Dreg', 'Dboth']:
304with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
305- r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
306+ r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[r1_grads_input], create_graph=True, only_inputs=True)[0]
307r1_penalty = r1_grads.square().sum([1,2,3])
308loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
309training_stats.report('Loss/r1_penalty', r1_penalty)
310training_stats.report('Loss/D/reg', loss_Dr1)
311
312+ loss_Drecon = 0
313+ if self.D.recon:
314+ assert(self.training_mode == 'layout')
315+ if phase in ['Dmain', 'Dboth']:
316+ loss_Drecon = F.mse_loss(real_disc_in, real_recon) * self.loss_layout_kwargs.recon_weight
317+ training_stats.report('Loss/D/recon_loss', loss_Drecon)
318+
319with torch.autograd.profiler.record_function(name + '_backward'):
320- (loss_Dreal + loss_Dr1).mean().mul(gain).backward()
321+ (loss_Dreal + loss_Dr1 + loss_Drecon).mean().mul(gain).backward()
322
323#----------------------------------------------------------------------------
324