google-research

Форк
0
/
model_utils.py 
290 строк · 10.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Helper functions/classes for model definition."""
17

18
import functools
19
from typing import Any, Callable
20

21
from flax import linen as nn
22
import jax
23
from jax import lax
24
from jax import random
25
import jax.numpy as jnp
26

27

28
class MLP(nn.Module):
  """A plain fully-connected network with periodic input skip connections.

  The network input is re-concatenated onto the hidden activations after every
  hidden layer whose index is a positive multiple of `skip_layer`. Two linear
  heads then produce raw sigma and raw RGB outputs.
  """
  net_depth: int = 8  # The depth of the first part of MLP.
  net_width: int = 256  # The width of the first part of MLP.
  net_activation: Callable[Ellipsis, Any] = nn.relu  # The activation function.
  skip_layer: int = 4  # The layer to add skip layers to.
  num_rgb_channels: int = 3  # The number of RGB channels.
  num_sigma_channels: int = 1  # The number of sigma channels.

  @nn.compact
  def __call__(self, x):
    """Run the network on a batch of per-sample features.

    Args:
      x: jnp.ndarray(float32), [batch, num_samples, feature], points.

    Returns:
      raw_rgb: jnp.ndarray(float32), [batch, num_samples, num_rgb_channels].
      raw_sigma: jnp.ndarray(float32), [batch, num_samples, num_sigma_channels].
    """
    in_features = x.shape[-1]
    samples_per_ray = x.shape[1]
    flat_inputs = x.reshape([-1, in_features])
    make_dense = functools.partial(
        nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform())

    hidden = flat_inputs
    for layer_idx in range(self.net_depth):
      hidden = self.net_activation(make_dense(self.net_width)(hidden))
      # Re-inject the raw inputs after every positive multiple of skip_layer.
      if layer_idx % self.skip_layer == 0 and layer_idx > 0:
        hidden = jnp.concatenate([hidden, flat_inputs], axis=-1)

    # Keep head creation order (sigma first, then RGB): Flax auto-names
    # compact submodules by creation order, so swapping would rename params.
    sigma_out = make_dense(self.num_sigma_channels)(hidden)
    raw_sigma = sigma_out.reshape(
        [-1, samples_per_ray, self.num_sigma_channels])
    rgb_out = make_dense(self.num_rgb_channels)(hidden)
    raw_rgb = rgb_out.reshape([-1, samples_per_ray, self.num_rgb_channels])
    return raw_rgb, raw_sigma
66

67

68
def cast_rays(z_vals, origins, directions):
  """Compute points `origin + z * direction` for every depth on every ray.

  Args:
    z_vals: jnp.ndarray, [..., num_samples], distances along each ray.
    origins: jnp.ndarray, [..., dim], ray origins (e.g. dim == 3).
    directions: jnp.ndarray, [..., dim], ray directions.

  Returns:
    jnp.ndarray, [..., num_samples, dim], the sampled points.
  """
  ray_starts = origins[..., None, :]
  ray_steps = directions[..., None, :]
  depths = z_vals[..., None]
  return ray_starts + depths * ray_steps
70

71

72
def sample_along_rays(key, origins, directions, num_samples, near, far,
                      randomized, lindisp):
  """Draw (optionally jittered) stratified depth samples along each ray.

  Args:
    key: jax PRNG key; only consumed when `randomized` is True.
    origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
    directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
    num_samples: int, number of samples per ray.
    near: float, near clip.
    far: float, far clip.
    randomized: bool, jitter each sample uniformly within its stratum.
    lindisp: bool, space samples linearly in disparity instead of depth.

  Returns:
    z_vals: jnp.ndarray, [batch_size, num_samples], sampled z values.
    points: jnp.ndarray, [batch_size, num_samples, 3], sampled points.
  """
  batch_size = origins.shape[0]
  target_shape = [batch_size, num_samples]

  fractions = jnp.linspace(0., 1., num_samples)
  if not lindisp:
    z_vals = near * (1. - fractions) + far * fractions
  else:
    z_vals = 1. / (1. / near * (1. - fractions) + 1. / far * fractions)

  if not randomized:
    # Broadcast so the returned shape is always [batch_size, num_samples].
    z_vals = jnp.broadcast_to(z_vals[None, ...], target_shape)
  else:
    # Strata are bounded by midpoints between neighbouring samples; the first
    # and last strata are clamped to the original end points.
    midpoints = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
    bin_hi = jnp.concatenate([midpoints, z_vals[..., -1:]], -1)
    bin_lo = jnp.concatenate([z_vals[..., :1], midpoints], -1)
    jitter = random.uniform(key, target_shape)
    z_vals = bin_lo + (bin_hi - bin_lo) * jitter

  return z_vals, cast_rays(z_vals, origins, directions)
110

111

112
def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
  """Concatenate x with sinusoidal features at scales 2^[min_deg, max_deg-1].

  cos(t) is computed as sin(t + pi/2), so a single vectorized sin call
  produces both the sine and cosine halves.

  Args:
    x: jnp.ndarray, [..., F], values to encode; expected to lie in [-pi, pi].
    min_deg: int, minimum (inclusive) degree of the encoding.
    max_deg: int, maximum (exclusive) degree of the encoding.
    legacy_posenc_order: bool, interleave features in the ordering used by the
      original TF implementation.

  Returns:
    jnp.ndarray, [..., F + 2 * F * (max_deg - min_deg)]. When
    min_deg == max_deg, `x` itself is returned unchanged.
  """
  if min_deg == max_deg:
    return x

  scales = jnp.array([2**deg for deg in range(min_deg, max_deg)])
  flat_shape = list(x.shape[:-1]) + [-1]
  quarter_turn = 0.5 * jnp.pi
  # [..., num_degrees, F]: every feature multiplied by every scale.
  scaled = x[..., None, :] * scales[:, None]

  if not legacy_posenc_order:
    # Layout: [all sines..., all cosines...].
    scaled = jnp.reshape(scaled, flat_shape)
    phases = jnp.concatenate([scaled, scaled + quarter_turn], axis=-1)
    four_feat = jnp.sin(phases)
  else:
    # Layout per degree: [sines of F features, cosines of F features].
    phases = jnp.stack([scaled, scaled + quarter_turn], -2)
    four_feat = jnp.reshape(jnp.sin(phases), flat_shape)

  return jnp.concatenate([x, four_feat], axis=-1)
140

141

142
def volumetric_rendering(rgb, sigma, z_vals, dirs, white_bkgd):
  """Alpha-composite per-sample colors along each ray.

  Args:
    rgb: jnp.ndarray(float32), color, [batch_size, num_samples, 3].
    sigma: jnp.ndarray(float32), density, [batch_size, num_samples, 1].
    z_vals: jnp.ndarray(float32), [batch_size, num_samples].
    dirs: jnp.ndarray(float32), [batch_size, 3].
    white_bkgd: bool, composite onto a white background instead of black.

  Returns:
    comp_rgb: jnp.ndarray(float32), [batch_size, 3], composited color.
    disp: jnp.ndarray(float32), [batch_size], acc / depth; entries that are
      non-positive, >= 1/eps, NaN, or have acc <= eps are replaced by 1/eps.
    acc: jnp.ndarray(float32), [batch_size], sum of the weights.
    weights: jnp.ndarray(float32), [batch_size, num_samples].
  """
  eps = 1e-10
  inv_eps = 1 / eps

  # Spacing between consecutive samples. The last sample has no successor,
  # so it is given a huge (1e10) interval.
  far_pad = jnp.broadcast_to(1e10, z_vals[..., :1].shape)
  deltas = jnp.concatenate([z_vals[..., 1:] - z_vals[..., :-1], far_pad], -1)
  # Scale by the ray-direction length so spacing is in world units even when
  # `dirs` is not unit length.
  deltas = deltas * jnp.linalg.norm(dirs[..., None, :], axis=-1)

  # sigma arrives as [..., 1]; index away the trailing axis.
  alpha = 1.0 - jnp.exp(-sigma[..., 0] * deltas)
  # Transmittance: running product of (1 - alpha) over the earlier samples.
  transmittance = jnp.concatenate(
      [
          jnp.ones_like(alpha[..., :1], alpha.dtype),
          jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1),
      ],
      axis=-1,
  )
  weights = alpha * transmittance

  comp_rgb = jnp.sum(weights[..., None] * rgb, axis=-2)
  depth = jnp.sum(weights * z_vals, axis=-1)
  acc = jnp.sum(weights, axis=-1)

  # Equivalent to (but slightly more efficient and stable than):
  #  disp = 1 / max(eps, where(acc > eps, depth / acc, 0))
  disp = acc / depth
  is_valid = (disp > 0) & (disp < inv_eps) & (acc > eps)
  disp = jnp.where(is_valid, disp, inv_eps)

  if white_bkgd:
    comp_rgb = comp_rgb + (1. - acc[..., None])
  return comp_rgb, disp, acc, weights
184

185

186
def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
  """Piecewise-Constant PDF sampling.

  Treats `weights` as an (unnormalized) histogram over the intervals defined
  by `bins` and draws `num_samples` positions per row by inverting its CDF.

  Args:
    key: jax PRNG key; only consumed when `randomized` is True.
    bins: jnp.ndarray(float32), [batch_size, num_bins + 1], sorted bin edges.
    weights: jnp.ndarray(float32), [batch_size, num_bins]. Not modified.
    num_samples: int, the number of samples per row.
    randomized: bool, draw uniform random samples instead of evenly spaced
      ones.

  Returns:
    z_samples: jnp.ndarray(float32), [batch_size, num_samples]. Gradients are
      stopped, so nothing backpropagates through the sample positions.
  """
  # Pad each weight vector (only if necessary) so its sum reaches `eps`. This
  # avoids NaNs when the input is all zeros or tiny, and is a no-op otherwise.
  # Rebind instead of `+=` so a caller-owned NumPy array is never mutated.
  eps = 1e-5
  weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
  padding = jnp.maximum(0, eps - weight_sum)
  weights = weights + padding / weights.shape[-1]
  weight_sum = weight_sum + padding

  # Compute the PDF and CDF for each weight vector, making sure the CDF starts
  # at exactly 0 and ends at exactly 1.
  pdf = weights / weight_sum
  interior = jnp.minimum(1, jnp.cumsum(pdf[..., :-1], axis=-1))
  lead_shape = list(interior.shape[:-1])
  cdf = jnp.concatenate(
      [jnp.zeros(lead_shape + [1]), interior, jnp.ones(lead_shape + [1])],
      axis=-1)

  # Draw the uniform samples.
  sample_shape = lead_shape + [num_samples]
  if randomized:
    # `u` lies in [0, 1): it can be zero but never exactly one.
    u = random.uniform(key, sample_shape)
  else:
    # Mirror random.uniform() by spanning [0, 1 - eps].
    u = jnp.linspace(0., 1. - jnp.finfo('float32').eps, num_samples)
    u = jnp.broadcast_to(u, sample_shape)

  # mask[..., i, j] is True when u[..., j] >= cdf[..., i]; the last True index
  # along axis -2 marks the start of the interval that contains u[..., j].
  mask = u[..., None, :] >= cdf[..., :, None]

  def find_interval(x):
    # Values of `x` at the start and end of each sample's interval. This
    # relies on `x` being sorted along its last axis.
    x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2)
    x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)
    return x0, x1

  bins_g0, bins_g1 = find_interval(bins)
  cdf_g0, cdf_g1 = find_interval(cdf)

  # Zero-width CDF intervals give 0/0; map NaN to 0 explicitly (the old
  # positional `0` was bound to `copy`, not `nan`) and clamp to [0, 1].
  ratio = (u - cdf_g0) / (cdf_g1 - cdf_g0)
  t = jnp.clip(jnp.nan_to_num(ratio, nan=0.0), 0, 1)
  samples = bins_g0 + t * (bins_g1 - bins_g0)

  # Prevent gradients from backpropagating through `samples`.
  return lax.stop_gradient(samples)
245

246

247
def sample_pdf(key, bins, weights, origins, directions, z_vals, num_samples,
               randomized):
  """Hierarchical sampling: add importance samples to the coarse ones.

  Args:
    key: jax PRNG key, forwarded to `piecewise_constant_pdf`.
    bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
    weights: jnp.ndarray(float32), [batch_size, num_bins].
    origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
    directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
    z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
    num_samples: int, the number of fine samples to draw.
    randomized: bool, use randomized samples.

  Returns:
    z_vals: jnp.ndarray(float32),
      [batch_size, num_coarse_samples + num_fine_samples], sorted per ray.
    points: jnp.ndarray(float32),
      [batch_size, num_coarse_samples + num_fine_samples, 3].
  """
  fine_z = piecewise_constant_pdf(key, bins, weights, num_samples, randomized)
  # Merge coarse and fine depths, keeping each ray's samples in sorted order.
  merged = jnp.concatenate([z_vals, fine_z], axis=-1)
  z_vals = jnp.sort(merged, axis=-1)
  return z_vals, cast_rays(z_vals, origins, directions)
273

274

275
def add_gaussian_noise(key, raw, noise_std, randomized):
  """Adds Gaussian noise to `raw`, which can be used to regularize it.

  Args:
    key: jax PRNG key used to draw the noise.
    raw: jnp.ndarray(float32), arbitrary shape.
    noise_std: float or None, standard deviation of the noise; None disables
      the noise.
    randomized: bool, noise is only added when this is truthy.

  Returns:
    `raw + noise` with the same shape and dtype as `raw`, or `raw` itself when
    no noise is added.
  """
  if noise_std is None or not randomized:
    return raw
  noise = random.normal(key, raw.shape, dtype=raw.dtype)
  return raw + noise * noise_std
291

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.