google-research

Форк
0
84 строки · 4.8 Кб
1
--- external/eg3d/training/loss.py	2023-04-06 03:55:57.421400822 +0000
2
+++ external_reference/eg3d/training/loss.py	2023-04-06 03:41:04.420707774 +0000
3
@@ -15,7 +15,7 @@
4
 from torch_utils import training_stats
5
 from torch_utils.ops import conv2d_gradfix
6
 from torch_utils.ops import upfirdn2d
7
-from training.dual_discriminator import filtered_resizing
8
+from external.eg3d.training.dual_discriminator import filtered_resizing
9
 
10
 #----------------------------------------------------------------------------
11
 
12
@@ -26,7 +26,17 @@
13
 #----------------------------------------------------------------------------
14
 
15
 class StyleGAN2Loss(Loss):
16
-    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0, neural_rendering_resolution_initial=64, neural_rendering_resolution_final=None, neural_rendering_resolution_fade_kimg=0, gpc_reg_fade_kimg=1000, gpc_reg_prob=None, dual_discrimination=False, filter_mode='antialiased'):
17
+    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10,
18
+                 style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2,
19
+                 pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0,
20
+                 blur_fade_kimg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0,
21
+                 neural_rendering_resolution_initial=64,
22
+                 neural_rendering_resolution_final=None,
23
+                 neural_rendering_resolution_fade_kimg=0,
24
+                 gpc_reg_fade_kimg=1000, gpc_reg_prob=None,
25
+                 dual_discrimination=False, filter_mode='antialiased', 
26
+                 ignore_LR_disp=False, ignore_HR_disp=True,
27
+                 lambda_sky_pixel=0., lambda_ramp_end=0):
28
         super().__init__()
29
         self.device             = device
30
         self.G                  = G
31
@@ -52,6 +62,10 @@
32
         self.filter_mode = filter_mode
33
         self.resample_filter = upfirdn2d.setup_filter([1,3,3,1], device=device)
34
         self.blur_raw_target = True
35
+        self.ignore_LR_disp = ignore_LR_disp
36
+        self.ignore_HR_disp = ignore_HR_disp
37
+        self.lambda_sky_pixel = lambda_sky_pixel
38
+        self.lambda_ramp_end = lambda_ramp_end
39
         assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1)
40
 
41
     def run_G(self, z, c, swapping_prob, neural_rendering_resolution, update_emas=False):
42
@@ -84,6 +98,21 @@
43
             img['image'] = augmented_pair[:, :img['image'].shape[1]]
44
             img['image_raw'] = torch.nn.functional.interpolate(augmented_pair[:, img['image'].shape[1]:], size=img['image_raw'].shape[2:], mode='bilinear', antialias=True)
45
 
46
+        if self.ignore_LR_disp:
47
+            # mask LR disp channel to discriminator
48
+            dummy = torch.zeros_like(img['image_raw'])
49
+            img['image_raw'] = torch.cat([img['image_raw'][:, :3],
50
+                                          dummy[:, :1],
51
+                                          img['image_raw'][:, 4:]
52
+                                         ], dim=1)
53
+        if self.ignore_HR_disp:
54
+            # mask HR disp channel to discriminator
55
+            dummy = torch.zeros_like(img['image'])
56
+            img['image'] = torch.cat([img['image'][:, :3],
57
+                                      dummy[:, :1],
58
+                                      img['image'][:, 4:]
59
+                                     ], dim=1)
60
+
61
         logits = self.D(img, c, update_emas=update_emas)
62
         return logits
63
 
64
@@ -124,6 +153,21 @@
65
                 training_stats.report('Loss/signs/fake', gen_logits.sign())
66
                 loss_Gmain = torch.nn.functional.softplus(-gen_logits)
67
                 training_stats.report('Loss/G/loss', loss_Gmain)
68
+                if self.lambda_sky_pixel > 0:
69
+                    if self.lambda_ramp_end > 0:
70
+                        ramp_multiplier = cur_nimg / self.lambda_ramp_end
71
+                        ramp_multiplier = np.clip(ramp_multiplier, 0, 1)
72
+                        training_stats.report('Loss/G/reg_ramp', ramp_multiplier)
73
+                    else:
74
+                        ramp_multiplier = 1
75
+                    weights_detach = gen_img['weights_raw'].detach()
76
+                    pixel_sum = gen_img['image_raw'][:, :3].sum(1, keepdim=True)
77
+                    # add penalty on white pixels when weights > 0
78
+                    penalty = torch.exp(5*(pixel_sum-3)) * weights_detach
79
+                    sky_penalty = penalty.mean(dim=(2,3)) # Nx1
80
+                    loss_Gmain = (loss_Gmain + ramp_multiplier * self.lambda_sky_pixel * sky_penalty)
81
+                    training_stats.report('Loss/G/loss_sky_pixel', sky_penalty)
82
+                training_stats.report('Loss/G/loss_total', loss_Gmain)
83
             with torch.autograd.profiler.record_function('Gmain_backward'):
84
                 loss_Gmain.mean().mul(gain).backward()
85
 
86

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.