google-research
218 строк · 13.2 Кб
1--- external/eg3d/training/volumetric_rendering/renderer.py 2023-03-14 23:32:13.458508000 +0000
2+++ external_reference/eg3d/training/volumetric_rendering/renderer.py 2023-03-14 23:28:05.439774713 +0000
3@@ -17,8 +17,9 @@
4import torch
5import torch.nn as nn
6
7-from training.volumetric_rendering.ray_marcher import MipRayMarcher2
8-from training.volumetric_rendering import math_utils
9+from external.eg3d.training.volumetric_rendering.ray_marcher import MipRayMarcher2
10+from external.eg3d.training.volumetric_rendering import math_utils
11+from utils import noise_util
12
13def generate_planes():
14"""
15@@ -26,15 +27,15 @@
16plane. Should work with arbitrary number of planes and planes of
17arbitrary orientation.
18"""
19- return torch.tensor([[[1, 0, 0],
20+ return torch.tensor([[[1, 0, 0], # XY
21[0, 1, 0],
22[0, 0, 1]],
23- [[1, 0, 0],
24+ [[1, 0, 0], #XZ
25[0, 0, 1],
26[0, 1, 0]],
27- [[0, 0, 1],
28- [1, 0, 0],
29- [0, 1, 0]]], dtype=torch.float32)
30+ [[0, 1, 0], # YZ
31+ [0, 0, 1],
32+ [1, 0, 0]]], dtype=torch.float32)
33
34def project_onto_planes(planes, coordinates):
35"""
36@@ -50,16 +51,17 @@
37coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
38inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
39projections = torch.bmm(coordinates, inv_planes)
40- return projections[..., :2]
41+ return projections[..., :2] # projections are ordered (0,1,2) corresp to batch [0]
42
43-def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
44- assert padding_mode == 'zeros'
45- N, n_planes, C, H, W = plane_features.shape
46- _, M, _ = coordinates.shape
47+def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='reflection', box_warp=None):
48+ assert padding_mode == 'reflection'
49+ N, n_planes, C, H, W = plane_features.shape # bs x num_samples x C x H x W
50+ _, M, _ = coordinates.shape # bs x (render_h * render_w * samples) x 3
51plane_features = plane_features.view(N*n_planes, C, H, W)
52
53coordinates = (2/box_warp) * coordinates # TODO: add specific box bounds
54
55+ # shape = [bs*n_planes, 1, render_h * render_w * samples, 2]
56projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
57output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
58return output_features
59@@ -85,7 +87,7 @@
60self.ray_marcher = MipRayMarcher2()
61self.plane_axes = generate_planes()
62
63- def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
64+ def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options, noise_input=None):
65self.plane_axes = self.plane_axes.to(ray_origins.device)
66
67if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
68@@ -94,10 +96,10 @@
69if torch.any(is_ray_valid).item():
70ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
71ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
72- depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
73+ depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'], rendering_options['sample_deterministic'])
74else:
75# Create stratified depth samples
76- depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
77+ depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'], rendering_options['sample_deterministic'])
78
79batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape
80
81@@ -105,19 +107,30 @@
82sample_coordinates = (ray_origins.unsqueeze(-2) + depths_coarse * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
83sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
84
85+ # prevent rays from exceeding Y clip value
86+ # removes sky blobs when extending far bound at inference
87+ if rendering_options['y_clip'] is not None:
88+ # limit the depth of the ray so that it does not surpass y_clip
89+ y_clip = rendering_options['y_clip']
90+ max_depth = (y_clip - ray_origins[..., 1]) / ray_directions[..., 1]
91+ max_depth = max_depth[..., None, None] # B, HW, num_samples, 1)
92+ depths_clip = torch.where(max_depth > 0, torch.minimum(depths_coarse, max_depth), depths_coarse)
93+ depths_coarse = depths_clip # replace coarse depths with clipped depth
94+ sample_coordinates = (ray_origins.unsqueeze(-2) + depths_clip * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
95
96out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
97colors_coarse = out['rgb']
98densities_coarse = out['sigma']
99colors_coarse = colors_coarse.reshape(batch_size, num_rays, samples_per_ray, colors_coarse.shape[-1])
100densities_coarse = densities_coarse.reshape(batch_size, num_rays, samples_per_ray, 1)
101+ noise_coarse = noise_util.sample_noise(noise_input, sample_coordinates).reshape(batch_size, num_rays, samples_per_ray, 1)
102
103# Fine Pass
104N_importance = rendering_options['depth_resolution_importance']
105if N_importance > 0:
106- _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
107+ _, _, _, weights, _ = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, noise_coarse)
108
109- depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
110+ depths_fine = self.sample_importance(depths_coarse, weights, N_importance, rendering_options['sample_deterministic'])
111
112sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, N_importance, -1).reshape(batch_size, -1, 3)
113sample_coordinates = (ray_origins.unsqueeze(-2) + depths_fine * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
114@@ -127,20 +140,19 @@
115densities_fine = out['sigma']
116colors_fine = colors_fine.reshape(batch_size, num_rays, N_importance, colors_fine.shape[-1])
117densities_fine = densities_fine.reshape(batch_size, num_rays, N_importance, 1)
118+ noise_fine = noise_util.sample_noise(noise_input, sample_coordinates).reshape(batch_size, num_rays, N_importance, 1)
119
120- all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
121- depths_fine, colors_fine, densities_fine)
122+ all_depths, all_colors, all_densities, all_noise = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, noise_coarse, depths_fine, colors_fine, densities_fine, noise_fine)
123
124# Aggregate
125- rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
126+ rgb_final, depth_final, disp_final, weights, noise_final = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options, all_noise)
127else:
128- rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
129-
130+ rgb_final, depth_final, disp_final, weights, noise_final = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options, noise_coarse)
131
132- return rgb_final, depth_final, weights.sum(2)
133+ return rgb_final, depth_final, disp_final, weights.sum(2), noise_final
134
135def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
136- sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
137+ sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='reflection', box_warp=options['box_warp'])
138
139out = decoder(sampled_features, sample_directions)
140if options.get('density_noise', 0) > 0:
141@@ -154,19 +166,22 @@
142all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
143return all_depths, all_colors, all_densities
144
145- def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
146+ def unify_samples(self, depths1, colors1, densities1, noise1,
147+ depths2, colors2, densities2, noise2):
148all_depths = torch.cat([depths1, depths2], dim = -2)
149all_colors = torch.cat([colors1, colors2], dim = -2)
150all_densities = torch.cat([densities1, densities2], dim = -2)
151+ all_noise = torch.cat([noise1, noise2], dim = -2)
152
153_, indices = torch.sort(all_depths, dim=-2)
154all_depths = torch.gather(all_depths, -2, indices)
155all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
156all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
157+ all_noise = torch.gather(all_noise, -2, indices.expand(-1, -1, -1, 1))
158
159- return all_depths, all_colors, all_densities
160+ return all_depths, all_colors, all_densities, all_noise
161
162- def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
163+ def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False, sample_deterministic=False):
164"""
165Return depths of approximately uniformly spaced samples along rays.
166"""
167@@ -177,21 +192,30 @@
168depth_resolution,
169device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
170depth_delta = 1/(depth_resolution - 1)
171- depths_coarse += torch.rand_like(depths_coarse) * depth_delta
172+ if not sample_deterministic:
173+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
174+ else:
175+ depths_coarse += 0.5 * depth_delta
176depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
177else:
178if type(ray_start) == torch.Tensor:
179depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
180depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
181- depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
182+ if not sample_deterministic:
183+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
184+ else:
185+ depths_coarse += 0.5 * depth_delta[..., None]
186else:
187depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
188depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
189- depths_coarse += torch.rand_like(depths_coarse) * depth_delta
190+ if not sample_deterministic:
191+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
192+ else:
193+ depths_coarse += 0.5 * depth_delta
194
195return depths_coarse
196
197- def sample_importance(self, z_vals, weights, N_importance):
198+ def sample_importance(self, z_vals, weights, N_importance, sample_deterministic=False):
199"""
200Return depths of importance sampled points along rays. See NeRF importance sampling for more.
201"""
202@@ -207,8 +231,7 @@
203weights = weights + 0.01
204
205z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
206- importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
207- N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
208+ importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], N_importance, det=sample_deterministic).detach().reshape(batch_size, num_rays, N_importance, 1)
209return importance_z_vals
210
211def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
212@@ -250,4 +273,4 @@
213# anyway, therefore any value for it is fine (set to 1 here)
214
215samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
216- return samples
217\ No newline at end of file
218+ return samples
219