google-research

Форк
0
446 строк · 27.0 Кб
1
--- external/stylegan/training/training_loop.py	2023-04-06 03:45:49.515150152 +0000
2
+++ external_reference/stylegan/training/training_loop.py	2023-04-06 03:41:03.318609421 +0000
3
@@ -24,14 +24,17 @@
4
 from torch_utils.ops import grid_sample_gradfix
5
 
6
 import legacy
7
-from metrics import metric_main
8
+from external.stylegan.metrics import metric_main
9
+
10
+from utils import camera_util
11
+from collections import defaultdict
12
 
13
 #----------------------------------------------------------------------------
14
 
15
 def setup_snapshot_image_grid(training_set, random_seed=0):
16
     rnd = np.random.RandomState(random_seed)
17
-    gw = np.clip(7680 // training_set.image_shape[2], 7, 32)
18
-    gh = np.clip(4320 // training_set.image_shape[1], 4, 32)
19
+    gw = np.clip(1024// training_set.image_shape[2], 7, 32)
20
+    gh = np.clip(1024// training_set.image_shape[1], 4, 32)
21
 
22
     # No labels => show random subset of training samples.
23
     if not training_set.has_labels:
24
@@ -61,29 +64,57 @@
25
             grid_indices += [indices[x % len(indices)] for x in range(gw)]
26
             label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))]
27
 
28
-    # Load data.
29
-    images, labels = zip(*[training_set[i] for i in grid_indices])
30
-    return (gw, gh), np.stack(images), np.stack(labels)
31
+    # Load data -- modified for rgb/depth/mask loader
32
+    rgbs, depths, origs, accs, Rts, labels = [], [], [], [], [], []
33
+    for i in grid_indices:
34
+        image_info, label = training_set[i]
35
+        # (rgb, depth, acc, K, Rt) = image_info
36
+        rgbs.append(image_info['rgb'])
37
+        depths.append(image_info['depth'])
38
+        origs.append(image_info['orig'])
39
+        accs.append(image_info['acc'])
40
+        Rts.append(image_info['Rt'])
41
+        labels.append(label)
42
+    rgbs = np.stack(rgbs)
43
+    depths = np.stack(depths)
44
+    origs = np.stack(origs)
45
+    accs = np.stack(accs)
46
+    Rts = np.stack(Rts)
47
+    # images, labels = zip(*[training_set[i] for i in grid_indices])
48
+    return (gw, gh), {'rgb': rgbs, 'depth': depths, 'orig': origs, 'acc': accs, 'Rt': Rts}, np.stack(labels)
49
 
50
 #----------------------------------------------------------------------------
51
 
52
-def save_image_grid(img, fname, drange, grid_size):
53
-    lo, hi = drange
54
-    img = np.asarray(img, dtype=np.float32)
55
-    img = (img - lo) * (255 / (hi - lo))
56
-    img = np.rint(img).clip(0, 255).astype(np.uint8)
57
-
58
-    gw, gh = grid_size
59
-    _N, C, H, W = img.shape
60
-    img = img.reshape([gh, gw, C, H, W])
61
-    img = img.transpose(0, 3, 1, 4, 2)
62
-    img = img.reshape([gh * H, gw * W, C])
63
-
64
-    assert C in [1, 3]
65
-    if C == 1:
66
-        PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
67
-    if C == 3:
68
-        PIL.Image.fromarray(img, 'RGB').save(fname)
69
+def save_image_grid(img_infos, fname_base, drange, grid_size):
70
+    # modified to save rgb, depth, mask outputs
71
+    if not isinstance(img_infos, dict):
72
+        img_infos = {'rgb': img_infos}
73
+
74
+    rgb_drange=drange
75
+    for key, img in img_infos.items():
76
+        if key not in ['rgb', 'depth', 'acc', 'orig']:
77
+            continue
78
+        if key in ['rgb', 'orig']:
79
+            drange = rgb_drange
80
+        else:
81
+            drange = [0, 1]
82
+        fname = fname_base.replace('.png', '-%s.png' % key)
83
+        lo, hi = drange
84
+        img = np.asarray(img, dtype=np.float32)
85
+        img = (img - lo) * (255 / (hi - lo))
86
+        img = np.rint(img).clip(0, 255).astype(np.uint8)
87
+
88
+        gw, gh = grid_size
89
+        _N, C, H, W = img.shape
90
+        img = img.reshape([gh, gw, C, H, W])
91
+        img = img.transpose(0, 3, 1, 4, 2)
92
+        img = img.reshape([gh * H, gw * W, C])
93
+
94
+        assert C in [1, 3]
95
+        if C == 1:
96
+            PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
97
+        if C == 3:
98
+            PIL.Image.fromarray(img, 'RGB').save(fname)
99
 
100
 #----------------------------------------------------------------------------
101
 
102
@@ -120,7 +151,13 @@
103
     cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
104
     abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
105
     progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
106
+    training_mode           = None,     # which training mode to use
107
+    wrapper_kwargs          = {},       # model wrapper arguments
108
+    decoder_kwargs          = {},       # additional arguments for layout decoder 
109
+    torgb_kwargs            = {},       # additional arguments for layout torgb layer
110
+
111
 ):
112
+
113
     # Initialize.
114
     start_time = time.time()
115
     device = torch.device('cuda', rank)
116
@@ -132,6 +169,9 @@
117
     conv2d_gradfix.enabled = True                       # Improves training speed.
118
     grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.
119
 
120
+    # added to prevent data_loader pin_memory to load to device 0 for every process
121
+    torch.cuda.set_device(device)
122
+
123
     # Load training set.
124
     if rank == 0:
125
         print('Loading training set...')
126
@@ -148,9 +188,60 @@
127
     # Construct networks.
128
     if rank == 0:
129
         print('Constructing networks...')
130
-    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
131
-    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
132
-    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
133
+    # common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
134
+    # G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
135
+    # D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
136
+    if training_mode == 'layout':
137
+        from models.layout import model_layout
138
+        from external.gsn.models.model_utils import TrajectorySampler
139
+        voxel_size = wrapper_kwargs.voxel_size
140
+        voxel_res = wrapper_kwargs.voxel_res
141
+        image_infos,  _ = training_set[0]
142
+        trajectory_sampler = TrajectorySampler(real_Rts=torch.from_numpy(training_set.Rt).float(), mode='sample').to(device)
143
+        # decoder part common kwargs
144
+        G_kwargs.c_dim = training_set.label_dim
145
+        # discriminator common kwargs
146
+        D_kwargs.c_dim = training_set.label_dim
147
+        D_kwargs.img_resolution=training_set.resolution
148
+        D_kwargs.img_channels=training_set.num_channels+int(loss_kwargs.loss_layout_kwargs.concat_acc)+int(loss_kwargs.loss_layout_kwargs.concat_depth)
149
+        G = model_layout.ModelLayout(G_kwargs, decoder_kwargs, torgb_kwargs, **wrapper_kwargs).train().requires_grad_(False).to(device)
150
+        G.set_trajectory_sampler(trajectory_sampler=trajectory_sampler)
151
+        if not loss_kwargs.loss_layout_kwargs.use_wrapped_discriminator:
152
+            D = dnnlib.util.construct_class_by_name(**D_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
153
+        else:
154
+            # use two discriminators: one on RGBD, and one on sky mask
155
+            assert loss_kwargs.loss_layout_kwargs.concat_acc
156
+            D_kwargs.img_channels=training_set.num_channels+int(loss_kwargs.loss_layout_kwargs.concat_depth)
157
+            D_img = dnnlib.util.construct_class_by_name(**D_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
158
+            D_kwargs_acc = copy.deepcopy(D_kwargs)
159
+            D_kwargs_acc.img_channels=1
160
+            D_acc = dnnlib.util.construct_class_by_name(**D_kwargs_acc).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
161
+            from models.misc.networks import WrappedDiscriminator
162
+            wrappedD = WrappedDiscriminator(D_img, D_acc).train().requires_grad_(False).to(device)
163
+            D = wrappedD
164
+    elif training_mode == 'upsampler':
165
+        common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution)
166
+        G_kwargs.img_channels = training_set.num_channels + G_kwargs.num_additional_feature_channels
167
+        # G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
168
+        D_img_channels = G_kwargs.img_channels
169
+        D_kwargs.img_channels = G_kwargs.img_channels
170
+        D_kwargs.recon = False # no reconstruction part
171
+        D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
172
+        from models.layout.model_terrain import ModelTerrain
173
+        wrapper = ModelTerrain(**G_kwargs, **common_kwargs, **wrapper_kwargs)
174
+        G = wrapper.train().requires_grad_(False).to(device)
175
+    elif training_mode == 'sky':
176
+        common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
177
+        G_kwargs.enc_dim = 512
178
+        G_kwargs.training_mode = 'global-360'
179
+        G_kwargs.fov = int(training_set.fov_mean)
180
+        G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
181
+        D_kwargs.recon = False # no reconstruction part
182
+        D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
183
+        from models.sky.model_sky import ModelSky
184
+        wrapper = ModelSky(G)
185
+        G = wrapper.train().requires_grad_(False).to(device)
186
+
187
     G_ema = copy.deepcopy(G).eval()
188
 
189
     # Resume from existing pickle.
190
@@ -158,15 +249,33 @@
191
         print(f'Resuming from "{resume_pkl}"')
192
         with dnnlib.util.open_url(resume_pkl) as f:
193
             resume_data = legacy.load_network_pkl(f)
194
-        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
195
-            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
196
+        if training_mode == 'layout' and loss_kwargs.loss_layout_kwargs.use_wrapped_discriminator:
197
+            for name, module in [('G', G), ('D', D.D_img), ('G_ema', G_ema)]:
198
+                # assume that resume is from a checkpoint without wrapped discriminator
199
+                misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
200
+        else:
201
+            for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
202
+                misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
203
 
204
     # Print network summary tables.
205
     if rank == 0:
206
         z = torch.empty([batch_gpu, G.z_dim], device=device)
207
         c = torch.empty([batch_gpu, G.c_dim], device=device)
208
-        img = misc.print_module_summary(G, [z, c])
209
-        misc.print_module_summary(D, [img, c])
210
+        if training_mode == 'layout': 
211
+            camera_params = camera_util.get_full_image_parameters(
212
+                G, G.layout_decoder_kwargs.nerf_out_res,
213
+                batch_size=batch_gpu, device=z.device, Rt=None)
214
+            img, _ = misc.print_module_summary(G, [z, c, camera_params])
215
+            img = torch.empty([batch_gpu, D.img_channels, D.img_resolution, D.img_resolution], device=device)
216
+            misc.print_module_summary(D, [img, c])
217
+        elif training_mode == 'upsampler':
218
+            img, thumb = misc.print_module_summary(G, [z, c])
219
+            img_for_D = torch.empty([batch_gpu, D.img_channels, D.img_resolution, D.img_resolution], device=device)
220
+            misc.print_module_summary(D, [img_for_D, c])
221
+        elif training_mode == 'sky':
222
+            ref_img = torch.empty([batch_gpu, G.img_channels, G.img_resolution, G.img_resolution], device=device)
223
+            img = misc.print_module_summary(G, [z, c, ref_img, ref_img])
224
+            misc.print_module_summary(D, [img, c])
225
 
226
     # Setup augmentation.
227
     if rank == 0:
228
@@ -183,17 +292,24 @@
229
     if rank == 0:
230
         print(f'Distributing across {num_gpus} GPUs...')
231
     for module in [G, D, G_ema, augment_pipe]:
232
-        if module is not None:
233
+        if module is not None and num_gpus > 1:
234
             for param in misc.params_and_buffers(module):
235
-                if param.numel() > 0 and num_gpus > 1:
236
-                    torch.distributed.broadcast(param, src=0)
237
+                torch.distributed.broadcast(param, src=0)
238
 
239
     # Setup training phases.
240
     if rank == 0:
241
         print('Setting up training phases...')
242
+    # loss kwargs needs training mode and loss kwargs by mode
243
+    loss_kwargs.training_mode = training_mode
244
     loss = dnnlib.util.construct_class_by_name(device=device, G=G, D=D, augment_pipe=augment_pipe, **loss_kwargs) # subclass of training.loss.Loss
245
     phases = []
246
     for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
247
+        if training_mode == 'sky' and name == 'G':
248
+            optim_params = module.G.parameters()
249
+        elif training_mode == 'upsampler' and name == 'G':
250
+            optim_params = module.upsampler.parameters()
251
+        else:
252
+            optim_params = module.parameters()
253
         if reg_interval is None:
254
             opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
255
             phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]
256
@@ -222,8 +338,49 @@
257
         save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
258
         grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
259
         grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
260
-        images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
261
-        save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
262
+        # custom image saving for each training mode
263
+        if training_mode == 'layout':
264
+            images = defaultdict(list)
265
+            for z, c in zip(grid_z, grid_c):
266
+                camera_params = camera_util.get_full_image_parameters(
267
+                    G, G.layout_decoder_kwargs.nerf_out_res,
268
+                    batch_size=z.shape[0], device=z.device, Rt=None)
269
+                _, infos = G_ema(z=z, c=c, camera_params=camera_params, noise_mode='const')
270
+                for k, v in infos.items():
271
+                    if k in ['rgb', 'depth', 'acc']:
272
+                        images[k].append(v.cpu())
273
+            images = {k: torch.cat(v).numpy() for k, v in images.items()}
274
+            save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
275
+        elif training_mode == 'upsampler':
276
+            images_fake, images_fake_thumb = [], []
277
+            for z, c in zip(grid_z, grid_c):
278
+                im_fake, im_fake_thumb = G_ema(z=z, c=c, noise_mode='const')
279
+                images_fake.append(im_fake.cpu())
280
+                images_fake_thumb.append(im_fake_thumb.cpu())
281
+            images_fake = torch.cat(images_fake).numpy()
282
+            images_fake_thumb = torch.cat(images_fake_thumb).numpy()
283
+            images_fake_infos = {'rgb': images_fake[:, :3]}
284
+            images_fake_thumb_infos = {'rgb': images_fake_thumb[:, :3]}
285
+            if images_fake.shape[1] > 3:
286
+                images_fake_infos['depth'] = images_fake[:, 3:4]
287
+                images_fake_thumb_infos['depth'] = images_fake_thumb[:, 3:4]
288
+            if images_fake.shape[1] > 4:
289
+                images_fake_infos['acc'] = images_fake[:, 4:5]
290
+                images_fake_thumb_infos['acc'] = images_fake_thumb[:, 4:5]
291
+            save_image_grid(images_fake_infos, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
292
+            save_image_grid(images_fake_thumb_infos, os.path.join(run_dir, 'fakes_init_thumb.png'), drange=[-1,1], grid_size=grid_size)
293
+        elif training_mode == 'sky':
294
+            images_masked = torch.from_numpy((images['rgb'] / 127.5) - 1).to(device).split(batch_gpu)
295
+            images_acc = torch.from_numpy(images['acc']).to(device).split(batch_gpu)
296
+            images_fake = torch.cat([G_ema(z=z, c=c, img=im_masked, acc=im_acc, noise_mode='const').cpu()
297
+                                     for z, c, im_masked, im_acc in zip(grid_z, grid_c, images_masked, images_acc)]).numpy()
298
+            save_image_grid(images_fake, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
299
+            images_fake_nomask = torch.cat([G_ema(z=z, c=c, img=im_masked,
300
+                                                  acc=im_acc, multiply=False, noise_mode='const').cpu()
301
+                                     for z, c, im_masked, im_acc in zip(grid_z, grid_c, images_masked, images_acc)]).numpy()
302
+            save_image_grid(images_fake_nomask, os.path.join(run_dir, 'fakes_init_nomask.png'), drange=[-1,1], grid_size=grid_size)
303
+        # images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
304
+        # save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
305
 
306
     # Initialize logs.
307
     if rank == 0:
308
@@ -256,8 +413,30 @@
309
 
310
         # Fetch training data.
311
         with torch.autograd.profiler.record_function('data_fetch'):
312
-            phase_real_img, phase_real_c = next(training_set_iterator)
313
-            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
314
+            phase_real_infos, phase_real_c = next(training_set_iterator)
315
+            # masked image without sky
316
+            phase_real_img = (phase_real_infos['rgb'].to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
317
+            # full image with sky
318
+            phase_real_orig = (phase_real_infos['orig'].to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
319
+            # phase_real_depth is actually inverse depth if use_disp=True 
320
+            phase_real_depth = phase_real_infos['depth'].to(device).to(torch.float32).split(batch_gpu)
321
+            phase_real_acc = phase_real_infos['acc'].to(device).to(torch.float32).split(batch_gpu)
322
+            # zeroes the sky out, s.t. the sky maps exactly to zero for rgb and disparity
323
+            phase_real_img = [img * acc for img, acc in zip(phase_real_img, phase_real_acc)]
324
+            phase_real_depth = [depth * acc for depth, acc in zip(phase_real_depth, phase_real_acc)]
325
+            phase_real_K = phase_real_infos['K'].to(device).to(torch.float32).split(batch_gpu)
326
+            phase_real_Rt = phase_real_infos['Rt'].to(device).to(torch.float32).split(batch_gpu)
327
+            phase_real_size = phase_real_infos['global_size'].to(device).split(batch_gpu)
328
+            phase_real_fov = phase_real_infos['fov'].to(device).split(batch_gpu)
329
+            phase_real_infos = [dict(rgb=rgb, depth=depth, acc=acc, orig=orig,
330
+                                     camera_params=dict(K=K, Rt=Rt,
331
+                                     global_size=size, fov=fov))
332
+                                for (rgb, depth, acc, orig, K, Rt, size, fov)
333
+                                in zip(phase_real_img, phase_real_depth,
334
+                                       phase_real_acc, phase_real_orig,
335
+                                       phase_real_K, phase_real_Rt,
336
+                                       phase_real_size, phase_real_fov)]
337
+
338
             phase_real_c = phase_real_c.to(device).split(batch_gpu)
339
             all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
340
             all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
341
@@ -275,13 +454,13 @@
342
             # Accumulate gradients.
343
             phase.opt.zero_grad(set_to_none=True)
344
             phase.module.requires_grad_(True)
345
-            for real_img, real_c, gen_z, gen_c in zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c):
346
-                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg)
347
+            for real_img_infos, real_c, gen_z, gen_c in zip(phase_real_infos, phase_real_c, phase_gen_z, phase_gen_c):
348
+                loss.accumulate_gradients(phase=phase.name, real_img_infos=real_img_infos, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg)
349
             phase.module.requires_grad_(False)
350
 
351
             # Update weights.
352
             with torch.autograd.profiler.record_function(phase.name + '_opt'):
353
-                params = [param for param in phase.module.parameters() if param.numel() > 0 and param.grad is not None]
354
+                params = [param for param in phase.module.parameters() if param.grad is not None]
355
                 if len(params) > 0:
356
                     flat = torch.cat([param.grad.flatten() for param in params])
357
                     if num_gpus > 1:
358
@@ -351,21 +530,66 @@
359
 
360
         # Save image snapshot.
361
         if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
362
-            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
363
-            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
364
+            # images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
365
+            # save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
366
+            # custom image saving for each training mode
367
+            if training_mode == 'layout':
368
+                images = defaultdict(list)
369
+                for z, c in zip(grid_z, grid_c):
370
+                    camera_params = camera_util.get_full_image_parameters(
371
+                        G, G.layout_decoder_kwargs.nerf_out_res,
372
+                        batch_size=z.shape[0], device=z.device, Rt=None)
373
+                    _, infos = G_ema(z=z, c=c, camera_params=camera_params, noise_mode='const')
374
+                    for k, v in infos.items():
375
+                        if k in ['rgb', 'depth', 'acc']:
376
+                            images[k].append(v.cpu())
377
+                images = {k: torch.cat(v).numpy() for k, v in images.items()}
378
+                save_image_grid(images,  os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'),drange=[-1,1], grid_size=grid_size)
379
+            elif training_mode == 'upsampler':
380
+                images_fake, images_fake_thumb = [], []
381
+                for z, c in zip(grid_z, grid_c):
382
+                    im_fake, im_fake_thumb = G_ema(z=z, c=c, noise_mode='const')
383
+                    images_fake.append(im_fake.cpu())
384
+                    images_fake_thumb.append(im_fake_thumb.cpu())
385
+                images_fake = torch.cat(images_fake).numpy()
386
+                images_fake_thumb = torch.cat(images_fake_thumb).numpy()
387
+                images_fake_infos = {'rgb': images_fake[:, :3]}
388
+                images_fake_thumb_infos = {'rgb': images_fake_thumb[:, :3]}
389
+                if images_fake.shape[1] > 3:
390
+                    images_fake_infos['depth'] = images_fake[:, 3:4]
391
+                    images_fake_thumb_infos['depth'] = images_fake_thumb[:, 3:4]
392
+                if images_fake.shape[1] > 4:
393
+                    images_fake_infos['acc'] = images_fake[:, 4:5]
394
+                    images_fake_thumb_infos['acc'] = images_fake_thumb[:, 4:5]
395
+                save_image_grid(images_fake_infos,  os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'),drange=[-1,1], grid_size=grid_size)
396
+                save_image_grid(images_fake_thumb_infos,  os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_thumb.png'),drange=[-1,1], grid_size=grid_size)
397
+            elif training_mode == 'sky':
398
+                images_masked = torch.from_numpy((images['rgb'] / 127.5) - 1).to(device).split(batch_gpu)
399
+                images_acc = torch.from_numpy(images['acc']).to(device).split(batch_gpu)
400
+                images_fake = torch.cat([G_ema(z=z, c=c, img=im_masked, acc=im_acc, noise_mode='const').cpu()
401
+                                         for z, c, im_masked, im_acc in zip(grid_z, grid_c, images_masked, images_acc)]).numpy()
402
+                save_image_grid(images_fake,  os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'),drange=[-1,1], grid_size=grid_size)
403
+                images_fake_nomask = torch.cat([G_ema(z=z, c=c, img=im_masked,
404
+                                                      acc=im_acc, multiply=False, noise_mode='const').cpu()
405
+                                         for z, c, im_masked, im_acc in zip(grid_z, grid_c, images_masked, images_acc)]).numpy()
406
+                save_image_grid(images_fake,  os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_nomask.png'),drange=[-1,1], grid_size=grid_size)
407
 
408
         # Save network snapshot.
409
         snapshot_pkl = None
410
         snapshot_data = None
411
         if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
412
-            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
413
-            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
414
-                if module is not None:
415
+            snapshot_data = dict(G=G, D=D, G_ema=G_ema, augment_pipe=augment_pipe, training_set_kwargs=dict(training_set_kwargs))
416
+            for key, value in snapshot_data.items():
417
+                if isinstance(value, torch.nn.Module):
418
+                    value = copy.deepcopy(value).eval().requires_grad_(False)
419
+                    if training_mode == 'sky' and 'G' in key:
420
+                        value.encoder = None
421
                     if num_gpus > 1:
422
-                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.[^.]+_(avg|ema)')
423
-                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
424
-                snapshot_data[name] = module
425
-                del module # conserve memory
426
+                        misc.check_ddp_consistency(value, ignore_regex=r'.*\.[^.]+_(avg|ema)')
427
+                        for param in misc.params_and_buffers(value):
428
+                            torch.distributed.broadcast(param, src=0)
429
+                    snapshot_data[key] = value.cpu()
430
+                del value # conserve memory
431
             snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
432
             if rank == 0:
433
                 with open(snapshot_pkl, 'wb') as f:
434
@@ -376,8 +600,10 @@
435
             if rank == 0:
436
                 print('Evaluating metrics...')
437
             for metric in metrics:
438
-                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
439
-                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
440
+                result_dict = metric_main.calc_metric(
441
+                    metric=metric, G=G_ema, # snapshot_data['G_ema'],
442
+                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus,
443
+                    rank=rank, device=device, training_mode=training_mode)
444
                 if rank == 0:
445
                     metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
446
                 stats_metrics.update(result_dict.results)
447

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.