google-research

Форк
0
353 строки · 12.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Different model implementation plus a general port for all the models."""
17
from typing import Any, Callable
18
from flax import linen as nn
19
from jax import random
20
import jax.numpy as jnp
21

22
from snerg.nerf import model_utils
23
from snerg.nerf import utils
24

25

26
def get_model(key, example_batch, args):
  """Build the model selected by `args.model`.

  The name in `args.model` is looked up in a small registry of constructor
  functions and the matching constructor is called with
  `(key, example_batch, args)`. A name that is not registered raises
  `KeyError` from the dictionary lookup.
  """
  constructors = {"nerf": construct_nerf}
  build_fn = constructors[args.model]
  return build_fn(key, example_batch, args)
32

33

34
class NerfModel(nn.Module):
  """Nerf NN Model with both coarse and fine MLPs.

  A coarse pass samples points along each ray, evaluates an MLP on their
  positional encodings and volume-renders the result. If `num_fine_samples`
  is positive, a second pass resamples along the ray using the coarse
  rendering weights and repeats the procedure with a separate MLP.

  When `use_viewdirs` is True, each pass additionally composites
  `num_viewdir_channels` feature channels along the ray and feeds them,
  together with the composited diffuse color and the encoded view direction,
  through a small per-ray MLP whose activated output is added to the
  composited color as a view-dependent residual.

  Submodules are created inside the compact `__call__`, so the order in which
  the MLPs are constructed is kept fixed (coarse, coarse view-dependence,
  fine, fine view-dependence).
  """
  num_coarse_samples: int  # The number of samples for the coarse nerf.
  num_fine_samples: int  # The number of samples for the fine nerf.
  use_viewdirs: bool  # If True, use viewdirs as an input.
  near: float  # The distance to the near plane.
  far: float  # The distance to the far plane.
  noise_std: float  # The std dev of noise added to raw sigma.
  net_depth: int  # The depth of the first part of MLP.
  net_width: int  # The width of the first part of MLP.
  num_viewdir_channels: int  # The number of extra channels for view-dependence.
  viewdir_net_depth: int  # The depth of the view-dependence MLP.
  viewdir_net_width: int  # The width of the view-dependence MLP.
  net_activation: Callable[..., Any]  # MLP activation
  skip_layer: int  # How often to add skip connections.
  num_rgb_channels: int  # The number of RGB channels.
  num_sigma_channels: int  # The number of density channels.
  white_bkgd: bool  # If True, use a white background.
  min_deg_point: int  # The minimum degree of positional encoding for positions.
  max_deg_point: int  # The maximum degree of positional encoding for positions.
  deg_view: int  # The degree of positional encoding for viewdirs.
  lindisp: bool  # If True, sample linearly in disparity rather than in depth.
  rgb_activation: Callable[..., Any]  # Output RGB activation.
  sigma_activation: Callable[..., Any]  # Output sigma activation.
  legacy_posenc_order: bool  # Keep the same ordering as the original tf code.

  @nn.compact
  def __call__(self, rng_0, rng_1, rays, randomized):
    """Nerf Model.

    Args:
      rng_0: jnp.ndarray, random number generator for coarse model sampling.
      rng_1: jnp.ndarray, random number generator for fine model sampling.
      rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
      randomized: bool, use randomized stratified sampling.

    Returns:
      ret: list of tuples
        `(comp_rgb, disp, acc, sigma, features, rgb_residual)`.
        The first entry is the coarse pass. A second entry for the fine pass
        is appended only when `num_fine_samples > 0`; its `sigma` slot holds
        the activated density of the fine MLP evaluated at the *coarse*
        sample locations, not at the fine samples. When `use_viewdirs` is
        False, `features` and `rgb_residual` are zero arrays shaped like
        `comp_rgb`.
    """
    # Stratified sampling along rays
    key, rng_0 = random.split(rng_0)
    z_vals, coarse_samples = model_utils.sample_along_rays(
        key,
        rays.origins,
        rays.directions,
        self.num_coarse_samples,
        self.near,
        self.far,
        randomized,
        self.lindisp,
    )
    coarse_samples_enc = model_utils.posenc(
        coarse_samples,
        self.min_deg_point,
        self.max_deg_point,
        self.legacy_posenc_order,
    )

    # Construct the "coarse" MLP. Its color head is widened so that it emits
    # the diffuse RGB channels followed by the view-dependence features.
    coarse_mlp = model_utils.MLP(
        net_depth=self.net_depth,
        net_width=self.net_width,
        net_activation=self.net_activation,
        skip_layer=self.skip_layer,
        num_rgb_channels=self.num_rgb_channels + self.num_viewdir_channels,
        num_sigma_channels=self.num_sigma_channels)

    # Point attribute predictions
    if self.use_viewdirs:
      # The encoded view directions are reused by the fine pass below.
      viewdirs_enc = model_utils.posenc(
          rays.viewdirs,
          0,
          self.deg_view,
          self.legacy_posenc_order,
      )
      raw_features_and_rgb, raw_sigma = coarse_mlp(coarse_samples_enc)
    else:
      raw_rgb, raw_sigma = coarse_mlp(coarse_samples_enc)

    # Add noises to regularize the density predictions if needed
    key, rng_0 = random.split(rng_0)
    raw_sigma = model_utils.add_gaussian_noise(
        key,
        raw_sigma,
        self.noise_std,
        randomized,
    )
    sigma = self.sigma_activation(raw_sigma)

    if self.use_viewdirs:
      coarse_viewdir_mlp = model_utils.MLP(
          net_depth=self.viewdir_net_depth,
          net_width=self.viewdir_net_width,
          net_activation=self.net_activation,
          skip_layer=self.skip_layer,
          num_rgb_channels=self.num_rgb_channels,
          num_sigma_channels=self.num_sigma_channels)

      # Overcomposite the features to get an encoding for the features.
      # NOTE(review): the feature channels are composited here without
      # rgb_activation, whereas the fine pass applies rgb_activation to them
      # first -- confirm this asymmetry is intentional.
      comp_features, _, _, _ = model_utils.volumetric_rendering(
          raw_features_and_rgb[..., self.num_rgb_channels:(
              self.num_rgb_channels + self.num_viewdir_channels)],
          sigma,
          z_vals,
          rays.directions,
          white_bkgd=False,
      )
      features = comp_features[..., 0:self.num_rgb_channels]

      diffuse_rgb = self.rgb_activation(
          raw_features_and_rgb[..., 0:self.num_rgb_channels])
      comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
          diffuse_rgb,
          sigma,
          z_vals,
          rays.directions,
          white_bkgd=self.white_bkgd,
      )

      # One "sample" per ray: the view-dependence MLP runs on the composited
      # quantities rather than on every point along the ray.
      viewdirs_enc_features = jnp.concatenate(
          [viewdirs_enc, comp_rgb, comp_features], axis=-1)
      viewdirs_enc_features = jnp.expand_dims(viewdirs_enc_features, -2)
      raw_comp_rgb_residual, _ = coarse_viewdir_mlp(viewdirs_enc_features)

      # The residual is added to comp_rgb, so it must carry as many channels
      # as the view-dependence MLP was asked to emit (previously a literal 3).
      output_shape = list(comp_features.shape)
      output_shape[-1] = self.num_rgb_channels
      raw_comp_rgb_residual = raw_comp_rgb_residual.reshape(output_shape)
      rgb_residual = self.rgb_activation(raw_comp_rgb_residual)
      comp_rgb += rgb_residual
    else:
      rgb = self.rgb_activation(raw_rgb)
      # Volumetric rendering.
      comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
          rgb,
          sigma,
          z_vals,
          rays.directions,
          white_bkgd=self.white_bkgd,
      )
      features = jnp.zeros_like(comp_rgb)
      rgb_residual = jnp.zeros_like(comp_rgb)

    ret = [
        (comp_rgb, disp, acc, sigma, features, rgb_residual),
    ]
    # Hierarchical sampling based on coarse predictions
    if self.num_fine_samples > 0:
      z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
      key, rng_1 = random.split(rng_1)
      z_vals, fine_samples = model_utils.sample_pdf(
          key,
          z_vals_mid,
          weights[..., 1:-1],
          rays.origins,
          rays.directions,
          z_vals,
          self.num_fine_samples,
          randomized,
      )
      fine_samples_enc = model_utils.posenc(
          fine_samples,
          self.min_deg_point,
          self.max_deg_point,
          self.legacy_posenc_order,
      )

      # Construct the "fine" MLP.
      fine_mlp = model_utils.MLP(
          net_depth=self.net_depth,
          net_width=self.net_width,
          net_activation=self.net_activation,
          skip_layer=self.skip_layer,
          num_rgb_channels=self.num_rgb_channels + self.num_viewdir_channels,
          num_sigma_channels=self.num_sigma_channels)

      if self.use_viewdirs:
        raw_features_and_rgb, raw_sigma = fine_mlp(fine_samples_enc)
      else:
        raw_rgb, raw_sigma = fine_mlp(fine_samples_enc)

      key, rng_1 = random.split(rng_1)
      raw_sigma = model_utils.add_gaussian_noise(
          key,
          raw_sigma,
          self.noise_std,
          randomized,
      )
      sigma = self.sigma_activation(raw_sigma)

      # Density of the fine MLP at the coarse sample locations, without the
      # added noise. This (not `sigma`) is what the fine tuple returns.
      _, raw_reg_sigma = fine_mlp(coarse_samples_enc)
      reg_sigma = self.sigma_activation(raw_reg_sigma)

      if self.use_viewdirs:
        fine_viewdir_mlp = model_utils.MLP(
            net_depth=self.viewdir_net_depth,
            net_width=self.viewdir_net_width,
            net_activation=self.net_activation,
            skip_layer=self.skip_layer,
            num_rgb_channels=self.num_rgb_channels,
            num_sigma_channels=self.num_sigma_channels)

        # Overcomposite the features to get an encoding for the features.
        features_and_rgb = self.rgb_activation(raw_features_and_rgb)
        features = features_and_rgb[..., self.num_rgb_channels:(
            self.num_rgb_channels + self.num_viewdir_channels)]

        comp_features, _, _, _ = model_utils.volumetric_rendering(
            features,
            sigma,
            z_vals,
            rays.directions,
            white_bkgd=False,
        )
        features = comp_features[..., 0:self.num_rgb_channels]

        diffuse_rgb = features_and_rgb[..., 0:self.num_rgb_channels]
        comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
            diffuse_rgb,
            sigma,
            z_vals,
            rays.directions,
            white_bkgd=self.white_bkgd,
        )

        viewdirs_enc_features = jnp.concatenate(
            [viewdirs_enc, comp_rgb, comp_features], axis=-1)
        viewdirs_enc_features = jnp.expand_dims(viewdirs_enc_features, -2)
        raw_comp_rgb_residual, _ = fine_viewdir_mlp(viewdirs_enc_features)

        # Same channel count as comp_rgb (previously a literal 3).
        output_shape = list(comp_features.shape)
        output_shape[-1] = self.num_rgb_channels
        raw_comp_rgb_residual = raw_comp_rgb_residual.reshape(output_shape)
        rgb_residual = self.rgb_activation(raw_comp_rgb_residual)
        comp_rgb += rgb_residual
      else:
        rgb = self.rgb_activation(raw_rgb)
        # Volumetric rendering.
        comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
            rgb,
            sigma,
            z_vals,
            rays.directions,
            white_bkgd=self.white_bkgd,
        )
        features = jnp.zeros_like(comp_rgb)
        rgb_residual = jnp.zeros_like(comp_rgb)

      ret.append((comp_rgb, disp, acc, reg_sigma, features, rgb_residual))
    return ret
284

285

286
def construct_nerf(key, example_batch, args):
  """Construct a Neural Radiance Field.

  The MLP, RGB and density activations are fixed here (relu, sigmoid, relu)
  and sanity-checked on a wide range of inputs before the model is built.

  Args:
    key: jnp.ndarray. Random number generator.
    example_batch: dict, an example of a batch of data. Only the "rays" entry
      is read.
    args: FLAGS class. Hyperparameters of nerf.

  Returns:
    model: NerfModel, the module configured from `args`.
    init_variables: the variables returned by `model.init` for the example
      rays.

  Raises:
    NotImplementedError: if the RGB activation yields values outside [0, 1]
      or the density activation yields negative values on the probe inputs.
  """
  net_activation = nn.relu
  rgb_activation = nn.sigmoid
  sigma_activation = nn.relu

  # Assert that rgb_activation always produces outputs in [0, 1], and
  # sigma_activation always produce non-negative outputs.
  # The probe covers exponentially spaced magnitudes of both signs.
  probe = jnp.exp(jnp.linspace(-90, 90, 1024))
  probe = jnp.concatenate([-probe[::-1], probe], 0)

  rgb = rgb_activation(probe)
  if jnp.any(rgb < 0) or jnp.any(rgb > 1):
    # Report the activation that was actually checked; it is chosen above,
    # not read from `args`.
    raise NotImplementedError(
        "Choice of rgb_activation `{}` produces colors outside of [0, 1]"
        .format(getattr(rgb_activation, "__name__", rgb_activation)))

  sigma = sigma_activation(probe)
  if jnp.any(sigma < 0):
    raise NotImplementedError(
        "Choice of sigma_activation `{}` produces negative densities".format(
            getattr(sigma_activation, "__name__", sigma_activation)))

  model = NerfModel(
      min_deg_point=args.min_deg_point,
      max_deg_point=args.max_deg_point,
      deg_view=args.deg_view,
      num_coarse_samples=args.num_coarse_samples,
      num_fine_samples=args.num_fine_samples,
      use_viewdirs=args.use_viewdirs,
      near=args.near,
      far=args.far,
      noise_std=args.noise_std,
      white_bkgd=args.white_bkgd,
      net_depth=args.net_depth,
      net_width=args.net_width,
      num_viewdir_channels=args.num_viewdir_channels,
      viewdir_net_depth=args.viewdir_net_depth,
      viewdir_net_width=args.viewdir_net_width,
      skip_layer=args.skip_layer,
      num_rgb_channels=args.num_rgb_channels,
      num_sigma_channels=args.num_sigma_channels,
      lindisp=args.lindisp,
      net_activation=net_activation,
      rgb_activation=rgb_activation,
      sigma_activation=sigma_activation,
      legacy_posenc_order=args.legacy_posenc_order)
  rays = example_batch["rays"]
  # Separate keys for parameter init and for the two sampling streams.
  key1, key2, key3 = random.split(key, num=3)

  # Initialize on index 0 along the leading axis of every ray field
  # (presumably one device's shard of the batch -- confirm against caller).
  init_variables = model.init(
      key1,
      rng_0=key2,
      rng_1=key3,
      rays=utils.namedtuple_map(lambda field: field[0], rays),
      randomized=args.randomized)

  return model, init_variables
354

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.