# NOTE(review): the two lines below were code-viewer metadata ("353 lines ·
# 12.4 KB", originally in Russian) accidentally pasted in; kept as comments.
# google-research
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Different model implementation plus a general port for all the models."""
from typing import Any, Callable

from flax import linen as nn
from jax import random
import jax.numpy as jnp

from snerg.nerf import model_utils
from snerg.nerf import utils
24
25
def get_model(key, example_batch, args):
  """A helper function that wraps around a 'model zoo'.

  Args:
    key: jnp.ndarray. Random number generator used to initialize the model.
    example_batch: dict, an example of a batch of data (passed through to the
      selected constructor).
    args: FLAGS class. Hyperparameters; `args.model` selects the constructor.

  Returns:
    Whatever the selected constructor returns; for "nerf" this is
    `(model, init_variables)` from `construct_nerf`.

  Raises:
    KeyError: if `args.model` is not a key of the model zoo below.
  """
  model_dict = {
      "nerf": construct_nerf,
  }
  return model_dict[args.model](key, example_batch, args)
32
33
class NerfModel(nn.Module):
  """Nerf NN Model with both coarse and fine MLPs."""
  num_coarse_samples: int  # The number of samples for the coarse nerf.
  num_fine_samples: int  # The number of samples for the fine nerf.
  use_viewdirs: bool  # If True, use viewdirs as an input.
  near: float  # The distance to the near plane
  far: float  # The distance to the far plane
  noise_std: float  # The std dev of noise added to raw sigma.
  net_depth: int  # The depth of the first part of MLP.
  net_width: int  # The width of the first part of MLP.
  num_viewdir_channels: int  # The number of extra channels for view-dependence.
  viewdir_net_depth: int  # The depth of the view-dependence MLP.
  viewdir_net_width: int  # The width of the view-dependence MLP.
  net_activation: Callable[Ellipsis, Any]  # MLP activation
  skip_layer: int  # How often to add skip connections.
  num_rgb_channels: int  # The number of RGB channels.
  num_sigma_channels: int  # The number of density channels.
  white_bkgd: bool  # If True, use a white background.
  min_deg_point: int  # The minimum degree of positional encoding for positions.
  max_deg_point: int  # The maximum degree of positional encoding for positions.
  deg_view: int  # The degree of positional encoding for viewdirs.
  lindisp: bool  # If True, sample linearly in disparity rather than in depth.
  rgb_activation: Callable[Ellipsis, Any]  # Output RGB activation.
  sigma_activation: Callable[Ellipsis, Any]  # Output sigma activation.
  legacy_posenc_order: bool  # Keep the same ordering as the original tf code.

  @nn.compact
  def __call__(self, rng_0, rng_1, rays, randomized):
    """Nerf Model.

    Args:
      rng_0: jnp.ndarray, random number generator for coarse model sampling.
      rng_1: jnp.ndarray, random number generator for fine model sampling.
      rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
      randomized: bool, use randomized stratified sampling.

    Returns:
      ret: list of two 6-tuples, one per level:
        [(rgb_coarse, disp_coarse, acc_coarse, sigma_coarse, features_coarse,
          specular_coarse),
         (rgb, disp, acc, reg_sigma, features, specular)]
        where the fine tuple's density term is `reg_sigma`, the fine MLP's
        density re-evaluated at the coarse sample locations.
    """
    # Stratified sampling along rays.
    key, rng_0 = random.split(rng_0)
    z_vals, coarse_samples = model_utils.sample_along_rays(
        key,
        rays.origins,
        rays.directions,
        self.num_coarse_samples,
        self.near,
        self.far,
        randomized,
        self.lindisp,
    )
    coarse_samples_enc = model_utils.posenc(
        coarse_samples,
        self.min_deg_point,
        self.max_deg_point,
        self.legacy_posenc_order,
    )

    # Construct the "coarse" MLP.  Its "rgb" head is widened so it also emits
    # num_viewdir_channels extra feature channels for view-dependence.
    coarse_mlp = model_utils.MLP(
        net_depth=self.net_depth,
        net_width=self.net_width,
        net_activation=self.net_activation,
        skip_layer=self.skip_layer,
        num_rgb_channels=self.num_rgb_channels + self.num_viewdir_channels,
        num_sigma_channels=self.num_sigma_channels)

    # Point attribute predictions.
    if self.use_viewdirs:
      viewdirs_enc = model_utils.posenc(
          rays.viewdirs,
          0,
          self.deg_view,
          self.legacy_posenc_order,
      )
      raw_features_and_rgb, raw_sigma = coarse_mlp(coarse_samples_enc)
    else:
      raw_rgb, raw_sigma = coarse_mlp(coarse_samples_enc)

    # Add noises to regularize the density predictions if needed.
    key, rng_0 = random.split(rng_0)
    raw_sigma = model_utils.add_gaussian_noise(
        key,
        raw_sigma,
        self.noise_std,
        randomized,
    )
    sigma = self.sigma_activation(raw_sigma)

    if self.use_viewdirs:
      coarse_viewdir_mlp = model_utils.MLP(
          net_depth=self.viewdir_net_depth,
          net_width=self.viewdir_net_width,
          net_activation=self.net_activation,
          skip_layer=self.skip_layer,
          num_rgb_channels=self.num_rgb_channels,
          num_sigma_channels=self.num_sigma_channels)

      # Overcomposite the features to get an encoding for the features.
      comp_features, _, _, _ = model_utils.volumetric_rendering(
          raw_features_and_rgb[Ellipsis, self.num_rgb_channels:(
              self.num_rgb_channels + self.num_viewdir_channels)],
          sigma,
          z_vals,
          rays.directions,
          white_bkgd=False,
      )
      features = comp_features[Ellipsis, 0:self.num_rgb_channels]

      # Diffuse color: activated RGB channels, composited along the ray.
      diffuse_rgb = self.rgb_activation(
          raw_features_and_rgb[Ellipsis, 0:self.num_rgb_channels])
      comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
          diffuse_rgb,
          sigma,
          z_vals,
          rays.directions,
          white_bkgd=self.white_bkgd,
      )

      # View-dependence MLP runs once per ray (not per sample) on the encoded
      # view direction plus the composited color and features.
      viewdirs_enc_features = jnp.concatenate(
          [viewdirs_enc, comp_rgb, comp_features], axis=-1)
      viewdirs_enc_features = jnp.expand_dims(viewdirs_enc_features, -2)
      raw_comp_rgb_residual, _ = coarse_viewdir_mlp(viewdirs_enc_features)

      output_shape = list(comp_features.shape)
      output_shape[-1] = 3
      raw_comp_rgb_residual = raw_comp_rgb_residual.reshape(output_shape)
      # The specular residual is added on top of the composited diffuse color.
      rgb_residual = self.rgb_activation(raw_comp_rgb_residual)
      comp_rgb += rgb_residual
    else:
      rgb = self.rgb_activation(raw_rgb)
      # Volumetric rendering.
      comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
          rgb,
          sigma,
          z_vals,
          rays.directions,
          white_bkgd=self.white_bkgd,
      )
      # No view-dependence: emit zero features / specular placeholders so the
      # returned tuples keep the same arity in both branches.
      features = jnp.zeros_like(comp_rgb)
      rgb_residual = jnp.zeros_like(comp_rgb)

    ret = [
        (comp_rgb, disp, acc, sigma, features, rgb_residual),
    ]
    # Hierarchical sampling based on coarse predictions.
    if self.num_fine_samples > 0:
      z_vals_mid = .5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1])
      key, rng_1 = random.split(rng_1)
      z_vals, fine_samples = model_utils.sample_pdf(
          key,
          z_vals_mid,
          weights[Ellipsis, 1:-1],
          rays.origins,
          rays.directions,
          z_vals,
          self.num_fine_samples,
          randomized,
      )
      fine_samples_enc = model_utils.posenc(
          fine_samples,
          self.min_deg_point,
          self.max_deg_point,
          self.legacy_posenc_order,
      )

      # Construct the "fine" MLP (same widened rgb head as the coarse one).
      fine_mlp = model_utils.MLP(
          net_depth=self.net_depth,
          net_width=self.net_width,
          net_activation=self.net_activation,
          skip_layer=self.skip_layer,
          num_rgb_channels=self.num_rgb_channels + self.num_viewdir_channels,
          num_sigma_channels=self.num_sigma_channels)

      if self.use_viewdirs:
        raw_features_and_rgb, raw_sigma = fine_mlp(fine_samples_enc)
      else:
        raw_rgb, raw_sigma = fine_mlp(fine_samples_enc)

      key, rng_1 = random.split(rng_1)
      raw_sigma = model_utils.add_gaussian_noise(
          key,
          raw_sigma,
          self.noise_std,
          randomized,
      )
      sigma = self.sigma_activation(raw_sigma)

      # Also evaluate the fine MLP's density at the coarse sample locations;
      # this `reg_sigma` is what the fine tuple reports as its density term.
      _, raw_reg_sigma = fine_mlp(coarse_samples_enc)
      reg_sigma = self.sigma_activation(raw_reg_sigma)

      if self.use_viewdirs:
        fine_viewdir_mlp = model_utils.MLP(
            net_depth=self.viewdir_net_depth,
            net_width=self.viewdir_net_width,
            net_activation=self.net_activation,
            skip_layer=self.skip_layer,
            num_rgb_channels=self.num_rgb_channels,
            num_sigma_channels=self.num_sigma_channels)

        # Overcomposite the features to get an encoding for the features.
        # Note: unlike the coarse level, the activation is applied to the
        # whole feature+rgb vector before slicing.
        features_and_rgb = self.rgb_activation(raw_features_and_rgb)
        features = features_and_rgb[Ellipsis, self.num_rgb_channels:(
            self.num_rgb_channels + self.num_viewdir_channels)]

        comp_features, _, _, _ = model_utils.volumetric_rendering(
            features,
            sigma,
            z_vals,
            rays.directions,
            white_bkgd=False,
        )
        features = comp_features[Ellipsis, 0:self.num_rgb_channels]

        diffuse_rgb = features_and_rgb[Ellipsis, 0:self.num_rgb_channels]
        comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
            diffuse_rgb,
            sigma,
            z_vals,
            rays.directions,
            white_bkgd=self.white_bkgd,
        )

        # Per-ray view-dependence pass, as in the coarse branch.
        viewdirs_enc_features = jnp.concatenate(
            [viewdirs_enc, comp_rgb, comp_features], axis=-1)
        viewdirs_enc_features = jnp.expand_dims(viewdirs_enc_features, -2)
        raw_comp_rgb_residual, _ = fine_viewdir_mlp(viewdirs_enc_features)

        output_shape = list(comp_features.shape)
        output_shape[-1] = 3
        raw_comp_rgb_residual = raw_comp_rgb_residual.reshape(output_shape)
        rgb_residual = self.rgb_activation(raw_comp_rgb_residual)
        comp_rgb += rgb_residual
      else:
        rgb = self.rgb_activation(raw_rgb)
        # Volumetric rendering.
        comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
            rgb,
            sigma,
            z_vals,
            rays.directions,
            white_bkgd=self.white_bkgd,
        )
        features = jnp.zeros_like(comp_rgb)
        rgb_residual = jnp.zeros_like(comp_rgb)

      ret.append((comp_rgb, disp, acc, reg_sigma, features, rgb_residual))
    return ret
284
285
def construct_nerf(key, example_batch, args):
  """Construct a Neural Radiance Field.

  Args:
    key: jnp.ndarray. Random number generator.
    example_batch: dict, an example of a batch of data.
    args: FLAGS class. Hyperparameters of nerf.

  Returns:
    model: nn.Model. Nerf model with parameters.
    state: flax.Module.state. Nerf model state for stateful parameters.

  Raises:
    NotImplementedError: if the chosen activations produce colors outside
      [0, 1] or negative densities.
  """
  net_activation = nn.relu
  rgb_activation = nn.sigmoid
  sigma_activation = nn.relu

  # Assert that rgb_activation always produces outputs in [0, 1], and
  # sigma_activation always produce non-negative outputs.  The probe covers
  # a huge symmetric range (roughly +/- e^90) around zero.
  x = jnp.exp(jnp.linspace(-90, 90, 1024))
  x = jnp.concatenate([-x[::-1], x], 0)

  rgb = rgb_activation(x)
  if jnp.any(rgb < 0) or jnp.any(rgb > 1):
    raise NotImplementedError(
        "Choice of rgb_activation `{}` produces colors outside of [0, 1]"
        .format(args.rgb_activation))

  sigma = sigma_activation(x)
  if jnp.any(sigma < 0):
    raise NotImplementedError(
        "Choice of sigma_activation `{}` produces negative densities".format(
            args.sigma_activation))

  model = NerfModel(
      min_deg_point=args.min_deg_point,
      max_deg_point=args.max_deg_point,
      deg_view=args.deg_view,
      num_coarse_samples=args.num_coarse_samples,
      num_fine_samples=args.num_fine_samples,
      use_viewdirs=args.use_viewdirs,
      near=args.near,
      far=args.far,
      noise_std=args.noise_std,
      white_bkgd=args.white_bkgd,
      net_depth=args.net_depth,
      net_width=args.net_width,
      num_viewdir_channels=args.num_viewdir_channels,
      viewdir_net_depth=args.viewdir_net_depth,
      viewdir_net_width=args.viewdir_net_width,
      skip_layer=args.skip_layer,
      num_rgb_channels=args.num_rgb_channels,
      num_sigma_channels=args.num_sigma_channels,
      lindisp=args.lindisp,
      net_activation=net_activation,
      rgb_activation=rgb_activation,
      sigma_activation=sigma_activation,
      legacy_posenc_order=args.legacy_posenc_order)
  rays = example_batch["rays"]
  key1, key2, key3 = random.split(key, num=3)

  # Initialize with the first ray of the example batch only.
  init_variables = model.init(
      key1,
      rng_0=key2,
      rng_1=key3,
      rays=utils.namedtuple_map(lambda x: x[0], rays),
      randomized=args.randomized)

  return model, init_variables
354