google-research

Форк
0
240 строк · 9.3 Кб
1
--- external/gsn/models/generator.py	2023-03-28 16:58:33.300084813 +0000
2
+++ external_reference/gsn/models/generator.py	2023-03-28 16:00:32.533475741 +0000
3
@@ -5,10 +5,11 @@
4
 import torch
5
 from torch import nn
6
 
7
-from .layers import *
8
-from utils.utils import instantiate_from_config
9
-from .nerf_utils import get_sample_points, volume_render_radiance_field, sample_pdf_2
10
-
11
+# remove relative imports -- causes persistence error
12
+from external.gsn.models.layers import *
13
+from external.gsn.models.nerf_utils import get_sample_points, volume_render_radiance_field, sample_pdf_2
14
+from torch_utils import persistence
15
+import dnnlib
16
 
17
 class StyleGenerator2D(nn.Module):
18
     def __init__(self, out_res, out_ch, z_dim, ch_mul=1, ch_max=512, skip_conn=True):
19
@@ -140,7 +141,7 @@
20
             skip = self.out_rgb(out, z[i])
21
         return skip
22
 
23
-
24
+@persistence.persistent_class
25
 class NerfStyleGenerator(nn.Module):
26
     """NeRF MLP with style modulation.
27
 
28
@@ -175,17 +176,19 @@
29
 
30
         self.skips = skips
31
 
32
-        self.from_coords = PositionalEncoding(in_dim=3, frequency_bands=omega_coord)
33
-        self.from_dirs = PositionalEncoding(in_dim=3, frequency_bands=omega_dir)
34
+        self.shift_y = 1 # shift_y=1 prevents discontinuity on ground plane
35
+        viewdir_dim = 0 # ignore view directions
36
+        coord_dim = 3 # xyz
37
+
38
         self.n_layers = n_layers
39
 
40
         self.layers = nn.ModuleList(
41
-            [ModulationLinear(in_channel=self.from_coords.out_dim, out_channel=channels, z_dim=z_dim)]
42
+            [ModulationLinear(in_channel=coord_dim, out_channel=channels, z_dim=z_dim)]
43
         )
44
 
45
         for i in range(1, n_layers):
46
             if i in skips:
47
-                in_channels = channels + self.from_coords.out_dim
48
+                in_channels = channels + coord_dim
49
             else:
50
                 in_channels = channels
51
             self.layers.append(ModulationLinear(in_channel=in_channels, out_channel=channels, z_dim=z_dim))
52
@@ -195,7 +198,7 @@
53
         )
54
         self.fc_feat = ModulationLinear(in_channel=channels, out_channel=channels, z_dim=z_dim)
55
         self.fc_viewdir = ModulationLinear(
56
-            in_channel=channels + self.from_dirs.out_dim, out_channel=channels, z_dim=z_dim
57
+            in_channel=channels + viewdir_dim, out_channel=channels, z_dim=z_dim
58
         )
59
         self.fc_out = ModulationLinear(
60
             in_channel=channels, out_channel=out_channel, z_dim=z_dim, demodulate=False, activate=False
61
@@ -214,6 +217,14 @@
62
             z = [z[:, i] for i in range(n_latents)]
63
         return z
64
 
65
+    def extract_height(self, coords):
66
+         # CHANGED: removed positional encoding since it causes repeating
67
+         # patterns, and return height above the ground plane
68
+        encoding = coords.clone()
69
+        encoding[:, 0::3] = 0
70
+        encoding[:, 2::3] = 0
71
+        return encoding
72
+
73
     def forward(self, z, coords, viewdirs=None):
74
         """Forward pass.
75
 
76
@@ -234,7 +245,8 @@
77
             Occupancy values of shape [B, 1].
78
 
79
         """
80
-        coords = self.from_coords(coords)
81
+        coords[:, 1] += self.shift_y
82
+        coords = self.extract_height(coords)
83
         z = self.process_latents(z)
84
 
85
         h = coords
86
@@ -251,15 +263,14 @@
87
 
88
         h = self.fc_feat(h, z[i + 2])
89
 
90
-        viewdirs = self.from_dirs(viewdirs)
91
-        h = torch.cat([h, viewdirs], dim=-1)
92
+        # viewdirs = self.from_dirs(viewdirs)
93
+        # h = torch.cat([h, viewdirs], dim=-1)
94
 
95
         h = self.fc_viewdir(h, z[i + 3])
96
         out = self.fc_out(h, z[i + 4])
97
 
98
         return out, alpha
99
 
100
-
101
 class NerfSimpleGenerator(nn.Module):
102
     """NeRF MLP with with standard latent concatenation.
103
 
104
@@ -420,7 +431,8 @@
105
         rgb = torch.sigmoid(rgb)
106
         return rgb
107
 
108
-
109
+# CHANGED: keep track of inference feature resolution and adjust feature sampling
110
+@persistence.persistent_class
111
 class SceneGenerator(nn.Module):
112
     """NeRF scene generator.
113
 
114
@@ -466,6 +478,7 @@
115
         local_coordinates=True,
116
         hierarchical_sampling=False,
117
         density_bias=0,
118
+        use_disp=True,
119
         **kwargs
120
     ):
121
         super().__init__()
122
@@ -478,9 +491,12 @@
123
         self.local_coordinates = local_coordinates
124
         self.hierarchical_sampling = hierarchical_sampling
125
         self.density_bias = density_bias
126
-        self.out_dim = nerf_mlp_config.params.out_channel
127
+        self.out_dim = nerf_mlp_config.out_channel
128
+        self.use_disp = use_disp
129
+
130
+        self.local_generator = dnnlib.util.construct_class_by_name(**nerf_mlp_config)
131
+        self.inference_feat_res = None
132
 
133
-        self.local_generator = instantiate_from_config(nerf_mlp_config)
134
 
135
     def get_local_coordinates(self, global_coords, local_grid_length, preserve_y=True):
136
         local_coords = global_coords.clone()
137
@@ -490,7 +506,16 @@
138
         # scale so that each grid cell in the local_latent grid is 1x1 in size
139
         local_coords = local_coords * local_grid_length
140
         # subtract integer from each coordinate so that they are all in range [0, 1]
141
-        local_coords = local_coords - (local_coords - 0.5).round()
142
+        if self.inference_feat_res is not None:
143
+            offset = (self.inference_feat_res - self.global_feat_res) // 2
144
+            local_coords = local_coords + offset
145
+            # now local_coords should be in range [0, inference_feat_res]
146
+            local_coords = local_coords - local_coords.clip(0, self.inference_feat_res).floor()
147
+        else:
148
+            # clip st. input if different if local coords goes off grid
149
+            local_coords = local_coords - local_coords.clip(0, local_grid_length).floor()
150
+            # local_coords = local_coords - (local_coords - 0.5).round()
151
+
152
         # return to [-1, 1] scale
153
         local_coords = (local_coords * 2) - 1
154
 
155
@@ -516,6 +541,11 @@
156
 
157
         samples_per_ray = xyz.shape[2]
158
 
159
+        if self.inference_feat_res is not None:
160
+            # adjust the coordinates for grid sampling the local latents
161
+            assert(self.inference_feat_res == local_latents.shape[-1])
162
+            xyz = xyz * self.global_feat_res / self.inference_feat_res
163
+
164
         # all samples get the most detailed latent codes
165
         sampled_local_latents = nn.functional.grid_sample(
166
             input=local_latents,
167
@@ -542,6 +572,12 @@
168
             # this tries to get all input coordinates to lie within [-1, 1]
169
             xyz = xyz / (self.coordinate_scale / 2)
170
 
171
+        if local_latents.shape[-1] != self.global_feat_res:
172
+            # coordinate adjustment for inference time sampling
173
+            self.inference_feat_res = local_latents.shape[-1]
174
+        else:
175
+            self.inference_feat_res = None
176
+
177
         B, n_samples, samples_per_ray, _ = xyz.shape  # n_samples = H * W
178
         sampled_local_latents, local_latents = self.sample_local_latents(local_latents, xyz=xyz)
179
 
180
@@ -591,10 +627,8 @@
181
             H, W = render_params.nerf_out_res, render_params.nerf_out_res
182
 
183
             # if using feature-NeRF, need to adjust camera intrinsics to account for lower sampling resolution
184
-            if self.img_res is not None:
185
-                downsampling_ratio = render_params.nerf_out_res / self.img_res
186
-            else:
187
-                downsampling_ratio = 1
188
+            # note: intrinsics K should be adjusted to match output resolution of nerf
189
+            downsampling_ratio = 1
190
             fx, fy = render_params.K[0, 0, 0] * downsampling_ratio, render_params.K[0, 1, 1] * downsampling_ratio
191
             xyz, viewdirs, z_vals, rd, ro = get_sample_points(
192
                 tform_cam2world=render_params.Rt.inverse(),
193
@@ -618,7 +652,7 @@
194
             return alpha_coarse
195
 
196
         if self.hierarchical_sampling:
197
-            _, _, _, weights, _, occupancy_prior = volume_render_radiance_field(
198
+            _, _, _, weights, _, occupancy_prior, _ = volume_render_radiance_field(
199
                 rgb=rgb_coarse,
200
                 occupancy=alpha_coarse,
201
                 depth_values=z_vals,
202
@@ -627,6 +661,7 @@
203
                 alpha_activation=self.alpha_activation,
204
                 activate_rgb=not self.feature_nerf,
205
                 density_bias=self.density_bias,
206
+                use_disp=self.use_disp
207
             )
208
 
209
             z_vals_fine = self.importance_sampling(z_vals, weights, render_params.samples_per_ray)
210
@@ -648,7 +683,8 @@
211
             rgb, alpha = rgb_coarse, alpha_coarse
212
             z_vals = z_vals
213
 
214
-        rgb, _, _, _, depth, occupancy_prior = volume_render_radiance_field(
215
+        # note: if use_disp=True, then the depth output is inverse depth
216
+        rgb, disp, acc, weights, depth, occupancy_prior, extras = volume_render_radiance_field(
217
             rgb=rgb,
218
             occupancy=alpha,
219
             depth_values=z_vals,
220
@@ -657,14 +693,20 @@
221
             alpha_activation=self.alpha_activation,
222
             activate_rgb=not self.feature_nerf,
223
             density_bias=self.density_bias,
224
+            use_disp=self.use_disp
225
         )
226
 
227
         out = {
228
             'rgb': rgb,
229
             'depth': depth,
230
+            'acc': acc,
231
             'Rt': render_params.Rt,
232
             'K': render_params.K,
233
             'local_latents': local_latents,
234
             'occupancy_prior': occupancy_prior,
235
+            'xyz': xyz,  # also return the sampled points
236
+            'weights': weights, # and weights at each point
237
+            'alpha': extras['alpha'], # for opacity regularization
238
+            'dists': extras['dists'], # for opacity regularization
239
         }
240
         return out
241

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.