google-research

Форк
0
112 строк · 5.4 Кб
1
--- external/eg3d/training/training_loop.py	2023-04-06 03:56:36.237864870 +0000
2
+++ external_reference/eg3d/training/training_loop.py	2023-04-06 03:41:04.630726517 +0000
3
@@ -26,9 +26,9 @@
4
 from torch_utils.ops import grid_sample_gradfix
5
 
6
 import legacy
7
-from metrics import metric_main
8
-from camera_utils import LookAtPoseSampler
9
-from training.crosssection_utils import sample_cross_section
10
+from external.stylegan.metrics import metric_main
11
+# from camera_utils import LookAtPoseSampler
12
+# from training.crosssection_utils import sample_cross_section
13
 
14
 #----------------------------------------------------------------------------
15
 
16
@@ -72,23 +72,33 @@
17
 
18
 #----------------------------------------------------------------------------
19
 
20
-def save_image_grid(img, fname, drange, grid_size):
21
+# modified to save RGB image and depth channel separately
22
+def save_image_grid(image_and_depth, fname, drange, grid_size):
23
     lo, hi = drange
24
-    img = np.asarray(img, dtype=np.float32)
25
-    img = (img - lo) * (255 / (hi - lo))
26
-    img = np.rint(img).clip(0, 255).astype(np.uint8)
27
-
28
-    gw, gh = grid_size
29
-    _N, C, H, W = img.shape
30
-    img = img.reshape([gh, gw, C, H, W])
31
-    img = img.transpose(0, 3, 1, 4, 2)
32
-    img = img.reshape([gh * H, gw * W, C])
33
-
34
-    assert C in [1, 3]
35
-    if C == 1:
36
-        PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
37
-    if C == 3:
38
-        PIL.Image.fromarray(img, 'RGB').save(fname)
39
+
40
+    depth = image_and_depth[:, 3:4]
41
+    img = image_and_depth[:, :3]
42
+    for idx, img in enumerate((image_and_depth[:, :3], image_and_depth[:, 3:4])):
43
+        if idx == 1: # depth image
44
+            lo, hi = 0, 1
45
+        else:
46
+            lo, hi = drange
47
+
48
+        img = np.asarray(img, dtype=np.float32)
49
+        img = (img - lo) * (255 / (hi - lo))
50
+        img = np.rint(img).clip(0, 255).astype(np.uint8)
51
+
52
+        gw, gh = grid_size
53
+        _N, C, H, W = img.shape
54
+        img = img.reshape([gh, gw, C, H, W])
55
+        img = img.transpose(0, 3, 1, 4, 2)
56
+        img = img.reshape([gh * H, gw * W, C])
57
+
58
+        assert C in [1, 3]
59
+        if C == 1:
60
+            PIL.Image.fromarray(img[:, :, 0], 'L').save(fname.replace('.png', '-disp.png'))
61
+        if C == 3:
62
+            PIL.Image.fromarray(img, 'RGB').save(fname.replace('.png', '-rgb.png'))
63
 
64
 #----------------------------------------------------------------------------
65
 
66
@@ -138,6 +148,9 @@
67
     conv2d_gradfix.enabled = True                       # Improves training speed. # TODO: ENABLE
68
     grid_sample_gradfix.enabled = False                  # Avoids errors with the augmentation pipe.
69
 
70
+    # added to prevent data_loader pin_memory to load to device 0 for every process
71
+    torch.cuda.set_device(device)
72
+
73
     # Load training set.
74
     if rank == 0:
75
         print('Loading training set...')
76
@@ -262,7 +275,9 @@
77
         # Fetch training data.
78
         with torch.autograd.profiler.record_function('data_fetch'):
79
             phase_real_img, phase_real_c = next(training_set_iterator)
80
-            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
81
+            # convert rgb img from [0, 255] to [-1, 1]
82
+            phase_real_img[:, :3] = phase_real_img[:, :3] / 127.5 - 1
83
+            phase_real_img = (phase_real_img.to(device).to(torch.float32)).split(batch_gpu)
84
             phase_real_c = phase_real_c.to(device).split(batch_gpu)
85
             all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
86
             all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
87
@@ -361,10 +376,10 @@
88
             out = [G_ema(z=z, c=c, noise_mode='const') for z, c in zip(grid_z, grid_c)]
89
             images = torch.cat([o['image'].cpu() for o in out]).numpy()
90
             images_raw = torch.cat([o['image_raw'].cpu() for o in out]).numpy()
91
-            images_depth = -torch.cat([o['image_depth'].cpu() for o in out]).numpy()
92
+            # images_depth = -torch.cat([o['image_depth'].cpu() for o in out]).numpy()
93
             save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
94
             save_image_grid(images_raw, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_raw.png'), drange=[-1,1], grid_size=grid_size)
95
-            save_image_grid(images_depth, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_depth.png'), drange=[images_depth.min(), images_depth.max()], grid_size=grid_size)
96
+            # save_image_grid(images_depth, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_depth.png'), drange=[images_depth.min(), images_depth.max()], grid_size=grid_size)
97
 
98
             #--------------------
99
             # # Log forward-conditioned images
100
@@ -414,8 +429,10 @@
101
                 print(run_dir)
102
                 print('Evaluating metrics...')
103
             for metric in metrics:
104
-                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
105
-                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
106
+                result_dict = metric_main.calc_metric(
107
+                    metric=metric, G=snapshot_data['G_ema'],
108
+                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus,
109
+                    rank=rank, device=device, training_mode='triplane')
110
                 if rank == 0:
111
                     metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
112
                 stats_metrics.update(result_dict.results)
113

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.