google-research

Форк
0
/
networks_stylegan2_terrain.patch 
199 строк · 9.8 Кб
1
--- external/stylegan/training/networks_stylegan2_terrain.py	2023-03-09 18:11:00.963870634 +0000
2
+++ external_reference/stylegan/training/networks_stylegan2_terrain.py	2023-03-09 18:07:58.572064994 +0000
3
@@ -20,6 +20,9 @@
4
 from torch_utils.ops import bias_act
5
 from torch_utils.ops import fma
6
 
7
+from external.gsn.models.discriminator import ConvDecoder
8
+from utils.utils import interpolate
9
+
10
 #----------------------------------------------------------------------------
11
 
12
 @misc.profiled_function
13
@@ -306,17 +309,23 @@
14
             self.noise_strength = torch.nn.Parameter(torch.zeros([]))
15
         self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
16
 
17
-    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
18
-        assert noise_mode in ['random', 'const', 'none']
19
+    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1, noise_input=None):
20
+        assert noise_mode in ['random', 'const', 'none', '3dnoise']
21
         in_resolution = self.resolution // self.up
22
-        misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution])
23
+        # CHANGED: layout SOAT noise may have different dimensions
24
+        # misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution])
25
         styles = self.affine(w)
26
 
27
         noise = None
28
         if self.use_noise and noise_mode == 'random':
29
-            noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
30
+            # CHANGED: layout SOAT noise may be spatially larger than self.resolution
31
+            noise = torch.randn([x.shape[0], 1, x.shape[2]*self.up, x.shape[3]*self.up], device=x.device) * self.noise_strength
32
+            # noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
33
         if self.use_noise and noise_mode == 'const':
34
             noise = self.noise_const * self.noise_strength
35
+        if self.use_noise and noise_mode == '3dnoise':
36
+            # CHANGED: support 3d projected noise input in upsampler
37
+            noise = interpolate(noise_input, x.shape[2]*self.up)
38
 
39
         flip_weight = (self.up == 1) # slightly faster
40
         x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
41
@@ -463,31 +472,47 @@
42
 
43
 #----------------------------------------------------------------------------
44
 
45
+# CHANGED: add support for truncated generator (upsampler)
46
 @persistence.persistent_class
47
 class SynthesisNetwork(torch.nn.Module):
48
     def __init__(self,
49
         w_dim,                      # Intermediate latent (W) dimensionality.
50
         img_resolution,             # Output image resolution.
51
         img_channels,               # Number of color channels.
52
+        input_resolution = 4,        # Input resolution for truncated generator
53
         channel_base    = 32768,    # Overall multiplier for the number of channels.
54
         channel_max     = 512,      # Maximum number of channels in any layer.
55
         num_fp16_res    = 4,        # Use FP16 for the N highest resolutions.
56
+        num_additional_feature_channels = 0, # Additional feature channels for input layer
57
+        default_noise_mode = 'random', 
58
         **block_kwargs,             # Arguments for SynthesisBlock.
59
     ):
60
         assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
61
         super().__init__()
62
         self.w_dim = w_dim
63
+        self.input_resolution = input_resolution
64
+        self.input_resolution_log2 = int(np.log2(input_resolution))
65
         self.img_resolution = img_resolution
66
         self.img_resolution_log2 = int(np.log2(img_resolution))
67
         self.img_channels = img_channels
68
         self.num_fp16_res = num_fp16_res
69
-        self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
70
+        self.block_resolutions = [2 ** i for i in range(self.input_resolution_log2, self.img_resolution_log2 + 1)]
71
         channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
72
         fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
73
 
74
+        self.input_channels = channels_dict[self.block_resolutions[0]]
75
+        self.num_additional_feature_channels = num_additional_feature_channels
76
+        self.default_noise_mode = default_noise_mode
77
+
78
         self.num_ws = 0
79
-        for res in self.block_resolutions:
80
+        for num_res, res in enumerate(self.block_resolutions):
81
+            if num_res == 0 and res > 4:
82
+                # for upsampler network, skip the first entry (used for input layer)
83
+                continue
84
             in_channels = channels_dict[res // 2] if res > 4 else 0
85
+            # CHANGED: concatenate additional feature channels as input
86
+            if num_res == 1:
87
+                in_channels += num_additional_feature_channels
88
             out_channels = channels_dict[res]
89
             use_fp16 = (res >= fp16_resolution)
90
             is_last = (res == self.img_resolution)
91
@@ -498,19 +523,27 @@
92
                 self.num_ws += block.num_torgb
93
             setattr(self, f'b{res}', block)
94
 
95
-    def forward(self, ws, **block_kwargs):
96
+    def forward(self, ws, x=None, img=None, **block_kwargs):
97
         block_ws = []
98
+        block_res = self.block_resolutions
99
+        if self.input_resolution > 4:
100
+            # skip input block for upsampler network
101
+            block_res = self.block_resolutions[1:]
102
+
103
         with torch.autograd.profiler.record_function('split_ws'):
104
             misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
105
             ws = ws.to(torch.float32)
106
             w_idx = 0
107
-            for res in self.block_resolutions:
108
+            for res in block_res:
109
                 block = getattr(self, f'b{res}')
110
                 block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
111
                 w_idx += block.num_conv
112
 
113
-        x = img = None
114
-        for res, cur_ws in zip(self.block_resolutions, block_ws):
115
+        if 'noise_mode' not in block_kwargs:
116
+            block_kwargs['noise_mode'] = self.default_noise_mode
117
+
118
+        # x = img = None
119
+        for res, cur_ws in zip(block_res, block_ws):
120
             block = getattr(self, f'b{res}')
121
             x, img = block(x, img, cur_ws, **block_kwargs)
122
         return img
123
@@ -523,6 +556,7 @@
124
 
125
 #----------------------------------------------------------------------------
126
 
127
+# CHANGED: add support for truncated generator (upsampler)
128
 @persistence.persistent_class
129
 class Generator(torch.nn.Module):
130
     def __init__(self,
131
@@ -531,6 +565,7 @@
132
         w_dim,                      # Intermediate latent (W) dimensionality.
133
         img_resolution,             # Output resolution.
134
         img_channels,               # Number of output color channels.
135
+        input_resolution    = 4,    # Input resolution for truncated generator (upsampler)
136
         mapping_kwargs      = {},   # Arguments for MappingNetwork.
137
         **synthesis_kwargs,         # Arguments for SynthesisNetwork.
138
     ):
139
@@ -540,13 +575,14 @@
140
         self.w_dim = w_dim
141
         self.img_resolution = img_resolution
142
         self.img_channels = img_channels
143
-        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
144
+        self.input_resolution = input_resolution
145
+        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, input_resolution=input_resolution, **synthesis_kwargs)
146
         self.num_ws = self.synthesis.num_ws
147
         self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
148
 
149
-    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
150
+    def forward(self, z, c, x=None, img=None, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
151
         ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
152
-        img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
153
+        img = self.synthesis(ws, x=x, img=img, update_emas=update_emas, **synthesis_kwargs)
154
         return img
155
 
156
 #----------------------------------------------------------------------------
157
@@ -730,6 +766,7 @@
158
 
159
 #----------------------------------------------------------------------------
160
 
161
+# CHANGED: add reconstruction decoder to discriminator
162
 @persistence.persistent_class
163
 class Discriminator(torch.nn.Module):
164
     def __init__(self,
165
@@ -745,6 +782,7 @@
166
         block_kwargs        = {},       # Arguments for DiscriminatorBlock.
167
         mapping_kwargs      = {},       # Arguments for MappingNetwork.
168
         epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
169
+        recon = True,
170
     ):
171
         super().__init__()
172
         self.c_dim = c_dim
173
@@ -774,6 +812,11 @@
174
         if c_dim > 0:
175
             self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
176
         self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
177
+        self.recon = recon
178
+        if self.recon:
179
+            self.decoder = ConvDecoder(in_channel=channels_dict[4],
180
+                                       out_channel=img_channels, in_res=4,
181
+                                       out_res=img_resolution)
182
 
183
     def forward(self, img, c, update_emas=False, **block_kwargs):
184
         _ = update_emas # unused
185
@@ -785,8 +828,13 @@
186
         cmap = None
187
         if self.c_dim > 0:
188
             cmap = self.mapping(None, c)
189
+        if self.recon:
190
+            recon = self.decoder(x.float())
191
         x = self.b4(x, img, cmap)
192
-        return x
193
+        if self.recon:
194
+            return x, recon
195
+        else:
196
+            return x
197
 
198
     def extra_repr(self):
199
         return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
200

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.