google-research

Форк
0
218 строк · 13.2 Кб
1
--- external/eg3d/training/volumetric_rendering/renderer.py	2023-03-14 23:32:13.458508000 +0000
2
+++ external_reference/eg3d/training/volumetric_rendering/renderer.py	2023-03-14 23:28:05.439774713 +0000
3
@@ -17,8 +17,9 @@
4
 import torch
5
 import torch.nn as nn
6
 
7
-from training.volumetric_rendering.ray_marcher import MipRayMarcher2
8
-from training.volumetric_rendering import math_utils
9
+from external.eg3d.training.volumetric_rendering.ray_marcher import MipRayMarcher2
10
+from external.eg3d.training.volumetric_rendering import math_utils
11
+from utils import noise_util
12
 
13
 def generate_planes():
14
     """
15
@@ -26,15 +27,15 @@
16
     plane. Should work with arbitrary number of planes and planes of
17
     arbitrary orientation.
18
     """
19
-    return torch.tensor([[[1, 0, 0],
20
+    return torch.tensor([[[1, 0, 0], # XY
21
                             [0, 1, 0],
22
                             [0, 0, 1]],
23
-                            [[1, 0, 0],
24
+                            [[1, 0, 0], #XZ
25
                             [0, 0, 1],
26
                             [0, 1, 0]],
27
-                            [[0, 0, 1],
28
-                            [1, 0, 0],
29
-                            [0, 1, 0]]], dtype=torch.float32)
30
+                            [[0, 1, 0], # YZ
31
+                            [0, 0, 1],
32
+                            [1, 0, 0]]], dtype=torch.float32)
33
 
34
 def project_onto_planes(planes, coordinates):
35
     """
36
@@ -50,16 +51,17 @@
37
     coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
38
     inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
39
     projections = torch.bmm(coordinates, inv_planes)
40
-    return projections[..., :2]
41
+    return projections[..., :2] # projections are ordered (0,1,2) corresp to batch [0]
42
 
43
-def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
44
-    assert padding_mode == 'zeros'
45
-    N, n_planes, C, H, W = plane_features.shape
46
-    _, M, _ = coordinates.shape
47
+def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='reflection', box_warp=None):
48
+    assert padding_mode == 'reflection' 
49
+    N, n_planes, C, H, W = plane_features.shape # bs x num_samples x C x H x W
50
+    _, M, _ = coordinates.shape # bs x (render_h * render_w * samples) x 3
51
     plane_features = plane_features.view(N*n_planes, C, H, W)
52
 
53
     coordinates = (2/box_warp) * coordinates # TODO: add specific box bounds
54
 
55
+    # shape = [bs*n_planes, 1, render_h * render_w * samples, 2]
56
     projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
57
     output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
58
     return output_features
59
@@ -85,7 +87,7 @@
60
         self.ray_marcher = MipRayMarcher2()
61
         self.plane_axes = generate_planes()
62
 
63
-    def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
64
+    def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options, noise_input=None):
65
         self.plane_axes = self.plane_axes.to(ray_origins.device)
66
 
67
         if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
68
@@ -94,10 +96,10 @@
69
             if torch.any(is_ray_valid).item():
70
                 ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
71
                 ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
72
-            depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
73
+            depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'], rendering_options['sample_deterministic'])
74
         else:
75
             # Create stratified depth samples
76
-            depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
77
+            depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'], rendering_options['sample_deterministic'])
78
 
79
         batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape
80
 
81
@@ -105,19 +107,30 @@
82
         sample_coordinates = (ray_origins.unsqueeze(-2) + depths_coarse * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
83
         sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
84
 
85
+        # prevent rays from exceeding Y clip value
86
+        # removes sky blobs when extending far bound at inference
87
+        if rendering_options['y_clip'] is not None:
88
+            # limit the depth of the ray so that it does not surpass y_clip
89
+            y_clip = rendering_options['y_clip']
90
+            max_depth = (y_clip - ray_origins[..., 1]) / ray_directions[..., 1]
91
+            max_depth = max_depth[..., None, None] # B, HW, num_samples, 1)
92
+            depths_clip = torch.where(max_depth > 0, torch.minimum(depths_coarse, max_depth), depths_coarse)
93
+            depths_coarse = depths_clip # replace coarse depths with clipped depth
94
+            sample_coordinates = (ray_origins.unsqueeze(-2) + depths_clip * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
95
 
96
         out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
97
         colors_coarse = out['rgb']
98
         densities_coarse = out['sigma']
99
         colors_coarse = colors_coarse.reshape(batch_size, num_rays, samples_per_ray, colors_coarse.shape[-1])
100
         densities_coarse = densities_coarse.reshape(batch_size, num_rays, samples_per_ray, 1)
101
+        noise_coarse = noise_util.sample_noise(noise_input, sample_coordinates).reshape(batch_size, num_rays, samples_per_ray, 1)
102
 
103
         # Fine Pass
104
         N_importance = rendering_options['depth_resolution_importance']
105
         if N_importance > 0:
106
-            _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
107
+            _, _, _, weights, _ = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, noise_coarse)
108
 
109
-            depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
110
+            depths_fine = self.sample_importance(depths_coarse, weights, N_importance, rendering_options['sample_deterministic'])
111
 
112
             sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, N_importance, -1).reshape(batch_size, -1, 3)
113
             sample_coordinates = (ray_origins.unsqueeze(-2) + depths_fine * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
114
@@ -127,20 +140,19 @@
115
             densities_fine = out['sigma']
116
             colors_fine = colors_fine.reshape(batch_size, num_rays, N_importance, colors_fine.shape[-1])
117
             densities_fine = densities_fine.reshape(batch_size, num_rays, N_importance, 1)
118
+            noise_fine = noise_util.sample_noise(noise_input, sample_coordinates).reshape(batch_size, num_rays, N_importance, 1)
119
 
120
-            all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
121
-                                                                  depths_fine, colors_fine, densities_fine)
122
+            all_depths, all_colors, all_densities, all_noise = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, noise_coarse, depths_fine, colors_fine, densities_fine, noise_fine)
123
 
124
             # Aggregate
125
-            rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
126
+            rgb_final, depth_final, disp_final, weights, noise_final = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options, all_noise)
127
         else:
128
-            rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
129
-
130
+            rgb_final, depth_final, disp_final, weights, noise_final = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, noise_coarse)
131
 
132
-        return rgb_final, depth_final, weights.sum(2)
133
+        return rgb_final, depth_final, disp_final, weights.sum(2), noise_final
134
 
135
     def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
136
-        sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
137
+        sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='reflection', box_warp=options['box_warp'])
138
 
139
         out = decoder(sampled_features, sample_directions)
140
         if options.get('density_noise', 0) > 0:
141
@@ -154,19 +166,22 @@
142
         all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
143
         return all_depths, all_colors, all_densities
144
 
145
-    def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
146
+    def unify_samples(self, depths1, colors1, densities1, noise1,
147
+                      depths2, colors2, densities2, noise2):
148
         all_depths = torch.cat([depths1, depths2], dim = -2)
149
         all_colors = torch.cat([colors1, colors2], dim = -2)
150
         all_densities = torch.cat([densities1, densities2], dim = -2)
151
+        all_noise = torch.cat([noise1, noise2], dim = -2)
152
 
153
         _, indices = torch.sort(all_depths, dim=-2)
154
         all_depths = torch.gather(all_depths, -2, indices)
155
         all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
156
         all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
157
+        all_noise = torch.gather(all_noise, -2, indices.expand(-1, -1, -1, 1))
158
 
159
-        return all_depths, all_colors, all_densities
160
+        return all_depths, all_colors, all_densities, all_noise
161
 
162
-    def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
163
+    def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False, sample_deterministic=False):
164
         """
165
         Return depths of approximately uniformly spaced samples along rays.
166
         """
167
@@ -177,21 +192,30 @@
168
                                     depth_resolution,
169
                                     device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
170
             depth_delta = 1/(depth_resolution - 1)
171
-            depths_coarse += torch.rand_like(depths_coarse) * depth_delta
172
+            if not sample_deterministic:
173
+                depths_coarse += torch.rand_like(depths_coarse) * depth_delta
174
+            else:
175
+                depths_coarse += 0.5 * depth_delta
176
             depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
177
         else:
178
             if type(ray_start) == torch.Tensor:
179
                 depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
180
                 depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
181
-                depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
182
+                if not sample_deterministic:
183
+                    depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
184
+                else:
185
+                    depths_coarse += 0.5 * depth_delta[..., None]
186
             else:
187
                 depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
188
                 depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
189
-                depths_coarse += torch.rand_like(depths_coarse) * depth_delta
190
+                if not sample_deterministic:
191
+                    depths_coarse += torch.rand_like(depths_coarse) * depth_delta
192
+                else:
193
+                    depths_coarse += 0.5 * depth_delta
194
 
195
         return depths_coarse
196
 
197
-    def sample_importance(self, z_vals, weights, N_importance):
198
+    def sample_importance(self, z_vals, weights, N_importance, sample_deterministic=False):
199
         """
200
         Return depths of importance sampled points along rays. See NeRF importance sampling for more.
201
         """
202
@@ -207,8 +231,7 @@
203
             weights = weights + 0.01
204
 
205
             z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
206
-            importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
207
-                                             N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
208
+            importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], N_importance, det=sample_deterministic).detach().reshape(batch_size, num_rays, N_importance, 1)
209
         return importance_z_vals
210
 
211
     def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
212
@@ -250,4 +273,4 @@
213
                              # anyway, therefore any value for it is fine (set to 1 here)
214
 
215
         samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
216
-        return samples
217
\ No newline at end of file
218
+        return samples
219

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.