google-research
67 строк · 3.2 Кб
1--- external/eg3d/training/volumetric_rendering/ray_marcher.py 2023-03-14 23:34:32.199358815 +0000
2+++ external_reference/eg3d/training/volumetric_rendering/ray_marcher.py 2023-03-14 23:28:05.416774738 +0000
3@@ -22,11 +22,12 @@
4super().__init__()
5
6
7- def run_forward(self, colors, densities, depths, rendering_options):
8+ def run_forward(self, colors, densities, depths, rendering_options, noise):
9deltas = depths[:, :, 1:] - depths[:, :, :-1]
10colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
11densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
12depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
13+ noise_mid = (noise[:, :, :-1] + noise[:, :, 1:]) / 2
14
15
16if rendering_options['clamp_mode'] == 'softplus':
17@@ -34,12 +35,23 @@
18else:
19assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!"
20
21+ # multiply weights by a decay factor that depends on depths
22+ # adjusts for popup effect at horizon
23+ if rendering_options['decay_start'] is not None:
24+ decay_start = rendering_options['decay_start']
25+ ray_end = rendering_options['ray_end']
26+ decay_weight = torch.clamp((depths_mid - decay_start) / (ray_end - decay_start), 0, 1)
27+ decay_weight = 1-decay_weight
28+ else:
29+ decay_weight = torch.ones_like(depths_mid)
30+
31density_delta = densities_mid * deltas
32
33alpha = 1 - torch.exp(-density_delta)
34
35alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
36weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
37+ weights = weights * decay_weight
38
39composite_rgb = torch.sum(weights * colors_mid, -2)
40weight_total = weights.sum(2)
41@@ -49,15 +61,21 @@
42composite_depth = torch.nan_to_num(composite_depth, float('inf'))
43composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
44
45+ # disparity map
46+ composite_disp = torch.sum(weights * 1/depths_mid, -2)
47+ composite_disp = torch.nan_to_num(composite_disp, float('inf'))
48+ composite_disp = torch.clamp(composite_disp, 0., 1.)
49+
50if rendering_options.get('white_back', False):
51composite_rgb = composite_rgb + 1 - weight_total
52
53composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
54
55- return composite_rgb, composite_depth, weights
56+ composite_noise = torch.sum(weights * noise_mid, -2)
57
58+ return composite_rgb, composite_depth, composite_disp, weights, composite_noise
59
60- def forward(self, colors, densities, depths, rendering_options):
61- composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
62
63- return composite_rgb, composite_depth, weights
64\ No newline at end of file
65+ def forward(self, colors, densities, depths, rendering_options, noise):
66+ composite_rgb, composite_depth, composite_disp, weights, composite_noise = self.run_forward(colors, densities, depths, rendering_options, noise)
67+ return composite_rgb, composite_depth, composite_disp, weights, composite_noise
68