google-research

Форк
0
323 строки · 20.5 Кб
1
--- external/stylegan/training/loss.py	2023-04-06 03:45:26.165066348 +0000
2
+++ external_reference/stylegan/training/loss.py	2023-04-06 03:41:03.250603352 +0000
3
@@ -14,6 +14,13 @@
4
 from torch_utils.ops import conv2d_gradfix
5
 from torch_utils.ops import upfirdn2d
6
 
7
+import copy
8
+import random
9
+import torch.nn.functional as F
10
+from utils import camera_util, losses, regularizers
11
+from external.gsn.models.diff_augment import DiffAugment
12
+from utils.utils import interpolate
13
+
14
 #----------------------------------------------------------------------------
15
 
16
 class Loss:
17
@@ -23,7 +30,11 @@
18
 #----------------------------------------------------------------------------
19
 
20
 class StyleGAN2Loss(Loss):
21
-    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0):
22
+    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10,
23
+                 style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2,
24
+                 pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0,
25
+                 blur_fade_kimg=0, training_mode=None, loss_layout_kwargs=None,
26
+                 loss_upsampler_kwargs=None, loss_sky_kwargs=None):
27
         super().__init__()
28
         self.device             = device
29
         self.G                  = G
30
@@ -38,18 +49,73 @@
31
         self.pl_mean            = torch.zeros([], device=device)
32
         self.blur_init_sigma    = blur_init_sigma
33
         self.blur_fade_kimg     = blur_fade_kimg
34
+        self.loss_layout_kwargs = loss_layout_kwargs
35
+        self.loss_upsampler_kwargs = loss_upsampler_kwargs
36
+        self.loss_sky_kwargs = loss_sky_kwargs
37
+        self.training_mode = training_mode
38
+        self.loss_l1 = losses.L1_Loss().to(device)
39
+
40
+    def run_G(self, z, c, camera_params, real_img_masked=None, real_acc=None, update_emas=False):
41
+        assert(self.style_mixing_prob == 0) # not implemented
42
+        if self.training_mode == 'layout':
43
+            img, infos = self.G(z=z, c=c, camera_params=camera_params,
44
+                                update_emas=update_emas, extras=['opacity_regularization'])
45
+            return img, infos
46
+        elif self.training_mode == 'upsampler':
47
+            upsampler_ws, feature, thumb, extras = self.G.mapping(z, c, update_emas=update_emas)
48
+            img = self.G.synthesis(upsampler_ws, feature, thumb, extras=extras, update_emas=update_emas)
49
+            return img, dict(thumb=thumb, img=img, ws=upsampler_ws, extras=extras)
50
+        elif self.training_mode == 'sky':
51
+            ws = self.G.mapping(z, c, update_emas=update_emas)
52
+            # taken from anyres 360 loss
53
+            multiply = True if random.uniform(0, 1) < self.loss_sky_kwargs.mask_prob else False
54
+            input_layer = self.G.G.synthesis.input # extract input tensor from synthesis network
55
+            crop_start = random.randint(0, 360 // input_layer.fov * input_layer.frame_size[0] - 1)
56
+            crop_fn = lambda grid : grid[:, :, crop_start:crop_start+input_layer.size[0], :]
57
+            img_base = self.G.synthesis(ws, real_img_masked, real_acc,
58
+                                        multiply=False, crop_fn=crop_fn, update_emas=update_emas)
59
+            crop_shift = crop_start + input_layer.frame_size[0]
60
+            # generate shifted frame for cross-frame discriminator
61
+            crop_fn_shift = lambda grid : grid[:, :, crop_shift:crop_shift+input_layer.size[0], :]
62
+            img_shifted = self.G.synthesis(ws, real_img_masked, real_acc,
63
+                                           multiply=False, crop_fn=crop_fn_shift,
64
+                                           update_emas=update_emas)
65
+            img_splice = torch.cat([img_base, img_shifted], dim=3)
66
+            img_size = img_base.shape[-1]
67
+            splice_start = random.randint(0, img_size)
68
+            img = img_splice[:, :, :, splice_start:splice_start+img_size]
69
+            # multiply real img only after generating both splices
70
+            if multiply:
71
+                img = img * (1-real_acc) + real_img_masked * real_acc
72
+            # img = self.G.synthesis(ws, real_img_masked, real_acc, update_emas=update_emas)
73
+            return img, ws
74
+
75
+    def run_D(self, infos, c, blur_sigma=0, update_emas=False):
76
+        if self.training_mode == 'layout':
77
+            img = infos['rgb']
78
+            if self.loss_layout_kwargs.concat_depth:
79
+                depth = infos['depth']
80
+                # D_shape = self.D.img_resolution
81
+                # depth = F.interpolate(depth, size=D_shape, mode='bilinear', align_corners=False)
82
+                img = torch.cat([img, depth], dim=1)
83
+            if self.loss_layout_kwargs.concat_acc:
84
+                acc = infos['acc']
85
+                # D_shape = self.D.img_resolution
86
+                # acc = F.interpolate(acc, size=D_shape, mode='bilinear', align_corners=False)
87
+                img = torch.cat([img, acc], dim=1)
88
+            if self.loss_layout_kwargs.aug_policy:
89
+                img = DiffAugment(img, normalize=True, policy=self.loss_layout_kwargs.aug_policy)
90
+            assert(blur_sigma == 0) # using GSN DiffAugment module
91
+        elif self.training_mode == 'upsampler':
92
+            img = infos['img']
93
+            if self.loss_upsampler_kwargs.d_ignore_depth_acc:
94
+                # use 0.5 constant as input into depth and acc channels
95
+                b, c, h, w = img.shape
96
+                ignore_tensor = torch.ones(b, c-3, h, w).to(img.device) * 0.5
97
+                img = torch.cat([img[:, :3], ignore_tensor], dim=1)
98
+        elif self.training_mode == 'sky':
99
+            img = infos # output of sky generator is img itself
100
 
101
-    def run_G(self, z, c, update_emas=False):
102
-        ws = self.G.mapping(z, c, update_emas=update_emas)
103
-        if self.style_mixing_prob > 0:
104
-            with torch.autograd.profiler.record_function('style_mixing'):
105
-                cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
106
-                cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
107
-                ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
108
-        img = self.G.synthesis(ws, update_emas=update_emas)
109
-        return img, ws
110
-
111
-    def run_D(self, img, c, blur_sigma=0, update_emas=False):
112
         blur_size = np.floor(blur_sigma * 3)
113
         if blur_size > 0:
114
             with torch.autograd.profiler.record_function('blur'):
115
@@ -57,10 +123,17 @@
116
                 img = upfirdn2d.filter2d(img, f / f.sum())
117
         if self.augment_pipe is not None:
118
             img = self.augment_pipe(img)
119
+
120
+        if self.D.recon:
121
+            assert(self.training_mode == 'layout')
122
+            # handle return elements with reconstruction discriminator
123
+            logits, recon = self.D(img, c, update_emas=update_emas)
124
+            return logits, img, recon
125
+
126
         logits = self.D(img, c, update_emas=update_emas)
127
         return logits
128
 
129
-    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
130
+    def accumulate_gradients(self, phase, real_img_infos, real_c, gen_z, gen_c, gain, cur_nimg):
131
         assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
132
         if self.pl_weight == 0:
133
             phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
134
@@ -68,15 +141,72 @@
135
             phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
136
         blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0
137
 
138
+        camera_params = copy.deepcopy(real_img_infos['camera_params'])
139
+        del camera_params['Rt']
140
+        real_img = real_img_infos['rgb'] # img with sky masked
141
+        real_orig = real_img_infos['orig'] # img with sky 
142
+        real_depth = real_img_infos['depth']
143
+        real_acc = real_img_infos['acc']
144
+
145
         # Gmain: Maximize logits for generated images.
146
         if phase in ['Gmain', 'Gboth']:
147
             with torch.autograd.profiler.record_function('Gmain_forward'):
148
-                gen_img, _gen_ws = self.run_G(gen_z, gen_c)
149
-                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
150
-                training_stats.report('Loss/scores/fake', gen_logits)
151
-                training_stats.report('Loss/signs/fake', gen_logits.sign())
152
-                loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
153
-                training_stats.report('Loss/G/loss', loss_Gmain)
154
+                if self.training_mode == 'layout':
155
+                    gen_img, gen_infos = self.run_G(gen_z, gen_c, camera_params)
156
+                    gen_logits, _, _ = self.run_D(gen_infos, gen_c, blur_sigma=blur_sigma)
157
+                    training_stats.report('Loss/scores/fake', gen_logits)
158
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
159
+                    loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
160
+                    training_stats.report('Loss/G/loss', loss_Gmain)
161
+                    if self.loss_layout_kwargs.lambda_finite_difference > 0:
162
+                        if self.loss_layout_kwargs.lambda_ramp_end > 0:
163
+                            ramp_multiplier = cur_nimg / self.loss_layout_kwargs.lambda_ramp_end
164
+                        else:
165
+                            ramp_multiplier = 1
166
+                        ramp_multiplier = np.clip(ramp_multiplier, 0, 1)
167
+                        training_stats.report('Loss/G/reg_ramp', ramp_multiplier)
168
+                        finite_diff = regularizers.ray_finite_difference(gen_infos['extra_outputs'])
169
+                        loss_Gmain = loss_Gmain + ramp_multiplier * self.loss_layout_kwargs.lambda_finite_difference * finite_diff
170
+                        training_stats.report('Loss/G/reg_ray', finite_diff)
171
+                    training_stats.report('Loss/G/loss_total', loss_Gmain)
172
+                elif self.training_mode == 'upsampler':
173
+                    gen_img, infos = self.run_G(gen_z, gen_c, camera_params)
174
+                    thumb = infos['thumb']
175
+                    gen_logits = self.run_D(infos, gen_c, blur_sigma=blur_sigma)
176
+                    training_stats.report('Loss/scores/fake', gen_logits)
177
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
178
+                    loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
179
+                    training_stats.report('Loss/G/loss', loss_Gmain)
180
+                    if self.loss_upsampler_kwargs.lambda_rec > 0:
181
+                        # downsample gen img, it should match thumb
182
+                        ch = thumb.shape[1]
183
+                        gen_down = interpolate(gen_img, thumb.shape[-2:])
184
+                        l1_loss = self.loss_l1(gen_down, thumb)
185
+                        loss_Gmain = (loss_Gmain + self.loss_upsampler_kwargs.lambda_rec * l1_loss)
186
+                        training_stats.report('Loss/G/loss_rec_l1', l1_loss)
187
+                    if self.loss_upsampler_kwargs.lambda_up > 0:
188
+                        # upsample the depth and acc maps
189
+                        depth_and_acc_thumb_up = interpolate(thumb[:, 3:], gen_img.shape[-2:])
190
+                        depth_and_acc_gen = gen_img[:, 3:]
191
+                        l1_loss = self.loss_l1(depth_and_acc_gen, depth_and_acc_thumb_up)
192
+                        loss_Gmain = (loss_Gmain + self.loss_upsampler_kwargs.lambda_up * l1_loss)
193
+                        training_stats.report('Loss/G/loss_up_l1', l1_loss)
194
+                    if self.loss_upsampler_kwargs.lambda_gray_pixel > 0:
195
+                        acc_mask = (gen_img[:, -1:] > 0.5).float().detach()
196
+                        pixel_sum = torch.sum(torch.abs(gen_img[:, :3]), dim=1, keepdim=True)
197
+                        penalty = torch.exp(-self.loss_upsampler_kwargs.lambda_gray_pixel_falloff * pixel_sum) * acc_mask
198
+                        reg_loss = torch.mean(penalty, dim=(1, 2, 3))
199
+                        loss_Gmain = (loss_Gmain + self.loss_upsampler_kwargs.lambda_gray_pixel * reg_loss)
200
+                        training_stats.report('Loss/G/loss_reg_gray', reg_loss)
201
+                    training_stats.report('Loss/G/loss_total', loss_Gmain)
202
+                elif self.training_mode == 'sky':
203
+                    gen_img, _gen_ws = self.run_G(gen_z, gen_c, camera_params, real_img, real_acc)
204
+                    gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
205
+                    training_stats.report('Loss/scores/fake', gen_logits)
206
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
207
+                    loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
208
+                    training_stats.report('Loss/G/loss', loss_Gmain)
209
+
210
             with torch.autograd.profiler.record_function('Gmain_backward'):
211
                 loss_Gmain.mean().mul(gain).backward()
212
 
213
@@ -84,7 +214,24 @@
214
         if phase in ['Greg', 'Gboth']:
215
             with torch.autograd.profiler.record_function('Gpl_forward'):
216
                 batch_size = gen_z.shape[0] // self.pl_batch_shrink
217
-                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size])
218
+                camera_params_batch = {k: v[:batch_size] for k, v in camera_params.items()}
219
+                if self.training_mode == 'layout':
220
+                    gen_img, gen_infos = self.run_G(gen_z[:batch_size],
221
+                                                    gen_c[:batch_size],
222
+                                                    camera_params_batch)
223
+                    gen_ws = gen_infos['ws']
224
+                if self.training_mode == 'upsampler':
225
+                    gen_img, infos = self.run_G(gen_z[:batch_size],
226
+                                                gen_c[:batch_size],
227
+                                                camera_params_batch)
228
+                    gen_ws = infos['ws']
229
+                if self.training_mode == 'sky':
230
+                    gen_img, gen_ws = self.run_G(gen_z[:batch_size],
231
+                                                 gen_c[:batch_size],
232
+                                                 camera_params_batch,
233
+                                                 real_img[:batch_size],
234
+                                                 real_acc[:batch_size])
235
+                # gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size])
236
                 pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
237
                 with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad):
238
                     pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
239
@@ -102,11 +249,24 @@
240
         loss_Dgen = 0
241
         if phase in ['Dmain', 'Dboth']:
242
             with torch.autograd.profiler.record_function('Dgen_forward'):
243
-                gen_img, _gen_ws = self.run_G(gen_z, gen_c, update_emas=True)
244
-                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
245
-                training_stats.report('Loss/scores/fake', gen_logits)
246
-                training_stats.report('Loss/signs/fake', gen_logits.sign())
247
-                loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
248
+                if self.training_mode == 'layout':
249
+                    gen_img, gen_infos = self.run_G(gen_z, gen_c, camera_params, update_emas=True)
250
+                    gen_logits, _, _ = self.run_D(gen_infos, gen_c, blur_sigma=blur_sigma, update_emas=True)
251
+                    training_stats.report('Loss/scores/fake', gen_logits)
252
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
253
+                    loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
254
+                elif self.training_mode == 'upsampler':
255
+                    gen_img, infos = self.run_G(gen_z, gen_c, camera_params, update_emas=True)
256
+                    gen_logits = self.run_D(infos, gen_c, blur_sigma=blur_sigma, update_emas=True)
257
+                    training_stats.report('Loss/scores/fake', gen_logits)
258
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
259
+                    loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
260
+                elif self.training_mode == 'sky':
261
+                    gen_img, _gen_ws = self.run_G(gen_z, gen_c, camera_params, real_img, real_acc, update_emas=True)
262
+                    gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
263
+                    training_stats.report('Loss/scores/fake', gen_logits)
264
+                    training_stats.report('Loss/signs/fake', gen_logits.sign())
265
+                    loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
266
             with torch.autograd.profiler.record_function('Dgen_backward'):
267
                 loss_Dgen.mean().mul(gain).backward()
268
 
269
@@ -115,8 +275,29 @@
270
         if phase in ['Dmain', 'Dreg', 'Dboth']:
271
             name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
272
             with torch.autograd.profiler.record_function(name + '_forward'):
273
-                real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
274
-                real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
275
+                if self.training_mode == 'layout':
276
+                    real_img_tmp = {
277
+                        'rgb': real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth']),
278
+                        'depth': real_depth.detach(), # .requires_grad_(phase in ['Dreg', 'Dboth']),
279
+                        'acc': real_acc.detach(), # .requires_grad_(phase in ['Dreg', 'Dboth']),
280
+                    }
281
+                    r1_grads_input = real_img_tmp['rgb']
282
+                    real_logits, real_disc_in, real_recon = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
283
+
284
+                elif self.training_mode == 'upsampler':
285
+                    real_input = real_img
286
+                    if self.G.upsampler.synthesis.num_additional_feature_channels > 0:
287
+                        real_input = torch.cat([real_input, real_depth], dim=1)
288
+                    if self.G.upsampler.synthesis.num_additional_feature_channels > 1:
289
+                        real_input = torch.cat([real_input, real_acc], dim=1)
290
+                    real_img_tmp = real_input.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
291
+                    r1_grads_input = real_img_tmp
292
+                    infos_D = dict(img=real_img_tmp)
293
+                    real_logits = self.run_D(infos_D, real_c, blur_sigma=blur_sigma)
294
+                elif self.training_mode == 'sky':
295
+                    real_img_tmp = real_orig.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
296
+                    r1_grads_input = real_img_tmp
297
+                    real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
298
                 training_stats.report('Loss/scores/real', real_logits)
299
                 training_stats.report('Loss/signs/real', real_logits.sign())
300
 
301
@@ -128,13 +309,20 @@
302
                 loss_Dr1 = 0
303
                 if phase in ['Dreg', 'Dboth']:
304
                     with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
305
-                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
306
+                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[r1_grads_input], create_graph=True, only_inputs=True)[0]
307
                     r1_penalty = r1_grads.square().sum([1,2,3])
308
                     loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
309
                     training_stats.report('Loss/r1_penalty', r1_penalty)
310
                     training_stats.report('Loss/D/reg', loss_Dr1)
311
 
312
+                loss_Drecon = 0
313
+                if self.D.recon:
314
+                    assert(self.training_mode == 'layout')
315
+                    if phase in ['Dmain', 'Dboth']:
316
+                        loss_Drecon = F.mse_loss(real_disc_in, real_recon) * self.loss_layout_kwargs.recon_weight
317
+                        training_stats.report('Loss/D/recon_loss', loss_Drecon)
318
+
319
             with torch.autograd.profiler.record_function(name + '_backward'):
320
-                (loss_Dreal + loss_Dr1).mean().mul(gain).backward()
321
+                (loss_Dreal + loss_Dr1 + loss_Drecon).mean().mul(gain).backward()
322
 
323
 #----------------------------------------------------------------------------
324

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.