--- external/eg3d/training/triplane.py	2023-03-31 20:15:14.373619528 +0000
+++ external_reference/eg3d/training/triplane.py	2023-03-31 20:14:21.579907466 +0000
@@ -10,11 +10,13 @@
 
 import torch
 from torch_utils import persistence
-from training.networks_stylegan2 import Generator as StyleGAN2Backbone
-from training.volumetric_rendering.renderer import ImportanceRenderer
-from training.volumetric_rendering.ray_sampler import RaySampler
+from external.stylegan.training.networks_stylegan2_terrain import Generator as StyleGAN2Backbone
+from external.eg3d.training.volumetric_rendering.renderer import ImportanceRenderer
+from external.eg3d.training.volumetric_rendering.ray_sampler import RaySampler
 import dnnlib
 
+from utils import noise_util
+
 @persistence.persistent_class
 class TriPlaneGenerator(torch.nn.Module):
     def __init__(self,
@@ -23,6 +25,7 @@
         w_dim,                      # Intermediate latent (W) dimensionality.
         img_resolution,             # Output resolution.
         img_channels,               # Number of output color channels.
+        plane_resolution    = 256,  # Resolution of feature planes.
         sr_num_fp16_res     = 0,
         mapping_kwargs      = {},   # Arguments for MappingNetwork.
         rendering_kwargs    = {},
@@ -35,22 +38,33 @@
         self.w_dim=w_dim
         self.img_resolution=img_resolution
         self.img_channels=img_channels
+        self.plane_resolution=plane_resolution
         self.renderer = ImportanceRenderer()
         self.ray_sampler = RaySampler()
-        self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32*3, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
+        self.backbone_xy = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=plane_resolution, img_channels=32, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
+        self.backbone_xz = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=plane_resolution, img_channels=32, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
+        self.backbone_yz = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=plane_resolution, img_channels=32, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
+        # mapping network is taken only from backbone_xy
+        self.backbone_xz.mapping = None
+        self.backbone_yz.mapping = None
         self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=32, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
         self.decoder = OSGDecoder(32, {'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), 'decoder_output_dim': 32})
         self.neural_rendering_resolution = 64
         self.rendering_kwargs = rendering_kwargs
-    
+
         self._last_planes = None
-    
+        self.noise_generator = noise_util.NoiseGenerator(
+            plane_resolution, plane_resolution, plane_resolution,
+            rendering_kwargs['box_warp'], rendering_kwargs['box_warp'], rendering_kwargs['box_warp'])
+
     def mapping(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
         if self.rendering_kwargs['c_gen_conditioning_zero']:
                 c = torch.zeros_like(c)
-        return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+        return self.backbone_xy.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
 
-    def synthesis(self, ws, c, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
+    def synthesis(self, ws, c, neural_rendering_resolution=None,
+                  update_emas=False, cache_backbone=False, use_cached_backbone=False,
+                  noise_input=None, **synthesis_kwargs):
         cam2world_matrix = c[:, :16].view(-1, 4, 4)
         intrinsics = c[:, 16:25].view(-1, 3, 3)
 
@@ -67,38 +81,67 @@
         if use_cached_backbone and self._last_planes is not None:
             planes = self._last_planes
         else:
-            planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+            plane_xy = self.backbone_xy.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+            plane_xz = self.backbone_xz.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+            plane_yz = self.backbone_yz.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+            planes = torch.stack([plane_xy, plane_xz, plane_yz], dim=1) # N x 3 x 32 x H x W
         if cache_backbone:
             self._last_planes = planes
 
         # Reshape output into three 32-channel planes
-        planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
+        # planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
+
+        if noise_input is None:
+            noise_input = self.noise_generator.get_noise(ws.shape[0], ws.device)
 
         # Perform volume rendering
-        feature_samples, depth_samples, weights_samples = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs) # channels last
+        feature_samples, depth_samples, disp_samples, weights_samples, noise_samples = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs, noise_input=noise_input) # channels last
 
         # Reshape into 'raw' neural-rendered image
         H = W = self.neural_rendering_resolution
         feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
         depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+        disp_image = disp_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+        weights_image = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+        noise_image = noise_samples.permute(0, 2, 1).reshape(N, 1, H, W)
 
         # Run superresolution to get final image
         rgb_image = feature_image[:, :3]
-        sr_image = self.superresolution(rgb_image, feature_image, ws, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k:synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
-
-        return {'image': sr_image, 'image_raw': rgb_image, 'image_depth': depth_image}
+        rgb_and_disp_image = torch.cat([rgb_image, disp_image], dim=1)
+        # self.superresolution will modify rgb_and_disp_image in place
+        # when size=128, so make a copy before passing it to the SR network
+        sr_image_and_disp = self.superresolution(rgb_and_disp_image.clone(), feature_image, ws, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], noise_input=noise_image, **{k:synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
+        sr_image = sr_image_and_disp[:, :3]
+        sr_disp = sr_image_and_disp[:, 3:]
+
+        return {'image': sr_image_and_disp,
+                'disp': sr_disp,
+                'image_raw': rgb_and_disp_image, # rgb_image,
+                'depth_raw': depth_image,
+                'disp_raw': disp_image,
+                'weights_raw': weights_image,
+                'noise_raw': noise_image
+               }
 
     def sample(self, coordinates, directions, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
         # Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
         ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
-        planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
-        planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
+        # planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+        # planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
+        plane_xy = self.backbone_xy.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+        plane_xz = self.backbone_xz.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+        plane_yz = self.backbone_yz.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+        planes = torch.stack([plane_xy, plane_xz, plane_yz], dim=1) # N x 3 x 32 x H x W
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
 
     def sample_mixed(self, coordinates, directions, ws, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
         # Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
-        planes = self.backbone.synthesis(ws, update_emas = update_emas, **synthesis_kwargs)
-        planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
+        # planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+        # planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
+        plane_xy = self.backbone_xy.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+        plane_xz = self.backbone_xz.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+        plane_yz = self.backbone_yz.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+        planes = torch.stack([plane_xy, plane_xz, plane_yz], dim=1) # N x 3 x 32 x H x W
         return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
 
     def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
@@ -107,7 +150,7 @@
         return self.synthesis(ws, c, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)
 
 
-from training.networks_stylegan2 import FullyConnectedLayer
+from external.stylegan.training.networks_stylegan2_terrain import FullyConnectedLayer
 
 class OSGDecoder(torch.nn.Module):
     def __init__(self, n_features, options):

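Notes on the patch follow. The central change replaces EG3D's single backbone, which emitted all 96 plane channels at a fixed resolution of 256, with three independent 32-channel backbones (backbone_xy, backbone_xz, backbone_yz) whose outputs are stacked along a new plane axis. A minimal sketch of the two equivalent tensor layouts, using stand-in random tensors for the backbone outputs (shapes taken from the patch's own comments):

import torch

N, C, H, W = 2, 32, 256, 256  # batch, per-plane channels, plane size

# stand-ins for backbone_xy/xz/yz.synthesis(ws)
plane_xy = torch.randn(N, C, H, W)
plane_xz = torch.randn(N, C, H, W)
plane_yz = torch.randn(N, C, H, W)

# patched layout: stack along a plane axis
planes = torch.stack([plane_xy, plane_xz, plane_yz], dim=1)
assert planes.shape == (N, 3, C, H, W)

# upstream EG3D layout: one 96-channel map, then a view to the same shape
single = torch.randn(N, 3 * C, H, W)
assert single.view(N, 3, C, H, W).shape == planes.shape

Either route yields the N x 3 x 32 x H x W tensor the renderer consumes. Splitting the backbones gives each plane its own synthesis network at plane_resolution, while setting backbone_xz.mapping and backbone_yz.mapping to None keeps a single shared mapping network, reached through backbone_xy.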
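utils.noise_util is not part of this diff. Judging only from the constructor call (three grid resolutions equal to plane_resolution, three box extents equal to rendering_kwargs['box_warp']) and the later get_noise(ws.shape[0], ws.device) call, the generator appears to produce a batch of 3D noise volumes covering the rendering box. The class below is a hypothetical stand-in under that assumption, not the repository's implementation:

import torch

class NoiseGenerator:
    """Hypothetical stand-in for utils.noise_util.NoiseGenerator."""
    def __init__(self, res_x, res_y, res_z, size_x, size_y, size_z):
        self.resolution = (res_x, res_y, res_z)  # voxels per axis
        self.size = (size_x, size_y, size_z)     # world-space box extents

    def get_noise(self, batch_size, device):
        # One Gaussian noise volume per batch element, shaped N x 1 x Z x Y x X
        # so it can be sampled (e.g. trilinearly) at ray-sample positions.
        res_z, res_y, res_x = self.resolution[::-1]
        return torch.randn(batch_size, 1, res_z, res_y, res_x, device=device)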
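Finally, the rgb_and_disp_image.clone() before the superresolution call protects the tensor that is later returned as 'image_raw': per the patch comment, the SR module writes into its input in place when the raw resolution already matches its target size. A small illustration of that hazard with a hypothetical module (not the repository's SR network):

import torch

class InPlaceSR(torch.nn.Module):
    """Hypothetical SR module that edits its input in place at matching size."""
    def __init__(self, out_size):
        super().__init__()
        self.out_size = out_size

    def forward(self, x):
        if x.shape[-1] == self.out_size:
            return x.mul_(2.0)  # in place: the caller's tensor changes too
        return torch.nn.functional.interpolate(x, size=(self.out_size,) * 2)

sr = InPlaceSR(out_size=128)
raw = torch.randn(1, 4, 128, 128)
kept = raw.clone()
out = sr(raw.clone())            # pass a copy, as the patch does
assert torch.equal(raw, kept)    # the raw image survives for the return dict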