google-research
290 lines · 10.7 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Helper functions/classes for model definition."""
17
18import functools19from typing import Any, Callable20
21from flax import linen as nn22import jax23from jax import lax24from jax import random25import jax.numpy as jnp26
27
28class MLP(nn.Module):29"""A simple MLP."""30net_depth: int = 8 # The depth of the first part of MLP.31net_width: int = 256 # The width of the first part of MLP.32net_activation: Callable[Ellipsis, Any] = nn.relu # The activation function.33skip_layer: int = 4 # The layer to add skip layers to.34num_rgb_channels: int = 3 # The number of RGB channels.35num_sigma_channels: int = 1 # The number of sigma channels.36
37@nn.compact38def __call__(self, x):39"""Evaluate the MLP.40
41Args:
42x: jnp.ndarray(float32), [batch, num_samples, feature], points.
43
44Returns:
45raw_rgb: jnp.ndarray(float32), with a shape of
46[batch, num_samples, num_rgb_channels].
47raw_sigma: jnp.ndarray(float32), with a shape of
48[batch, num_samples, num_sigma_channels].
49"""
50feature_dim = x.shape[-1]51num_samples = x.shape[1]52x = x.reshape([-1, feature_dim])53dense_layer = functools.partial(54nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform())55inputs = x56for i in range(self.net_depth):57x = dense_layer(self.net_width)(x)58x = self.net_activation(x)59if i % self.skip_layer == 0 and i > 0:60x = jnp.concatenate([x, inputs], axis=-1)61raw_sigma = dense_layer(self.num_sigma_channels)(x).reshape(62[-1, num_samples, self.num_sigma_channels])63raw_rgb = dense_layer(self.num_rgb_channels)(x).reshape(64[-1, num_samples, self.num_rgb_channels])65return raw_rgb, raw_sigma66
67
68def cast_rays(z_vals, origins, directions):69return origins[Ellipsis, None, :] + z_vals[Ellipsis, None] * directions[Ellipsis, None, :]70
71
72def sample_along_rays(key, origins, directions, num_samples, near, far,73randomized, lindisp):74"""Stratified sampling along the rays.75
76Args:
77key: jnp.ndarray, random generator key.
78origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
79directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
80num_samples: int.
81near: float, near clip.
82far: float, far clip.
83randomized: bool, use randomized stratified sampling.
84lindisp: bool, sampling linearly in disparity rather than depth.
85
86Returns:
87z_vals: jnp.ndarray, [batch_size, num_samples], sampled z values.
88points: jnp.ndarray, [batch_size, num_samples, 3], sampled points.
89"""
90batch_size = origins.shape[0]91
92t_vals = jnp.linspace(0., 1., num_samples)93if lindisp:94z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)95else:96z_vals = near * (1. - t_vals) + far * t_vals97
98if randomized:99mids = .5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1])100upper = jnp.concatenate([mids, z_vals[Ellipsis, -1:]], -1)101lower = jnp.concatenate([z_vals[Ellipsis, :1], mids], -1)102t_rand = random.uniform(key, [batch_size, num_samples])103z_vals = lower + (upper - lower) * t_rand104else:105# Broadcast z_vals to make the returned shape consistent.106z_vals = jnp.broadcast_to(z_vals[None, Ellipsis], [batch_size, num_samples])107
108coords = cast_rays(z_vals, origins, directions)109return z_vals, coords110
111
112def posenc(x, min_deg, max_deg, legacy_posenc_order=False):113"""Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].114
115Instead of computing [sin(x), cos(x)], we use the trig identity
116cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
117
118Args:
119x: jnp.ndarray, variables to be encoded. Note that x should be in [-pi, pi].
120min_deg: int, the minimum (inclusive) degree of the encoding.
121max_deg: int, the maximum (exclusive) degree of the encoding.
122legacy_posenc_order: bool, keep the same ordering as the original tf code.
123
124Returns:
125encoded: jnp.ndarray, encoded variables.
126"""
127if min_deg == max_deg:128return x129scales = jnp.array([2**i for i in range(min_deg, max_deg)])130if legacy_posenc_order:131xb = x[Ellipsis, None, :] * scales[:, None]132four_feat = jnp.reshape(133jnp.sin(jnp.stack([xb, xb + 0.5 * jnp.pi], -2)),134list(x.shape[:-1]) + [-1])135else:136xb = jnp.reshape((x[Ellipsis, None, :] * scales[:, None]),137list(x.shape[:-1]) + [-1])138four_feat = jnp.sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1))139return jnp.concatenate([x] + [four_feat], axis=-1)140
141
142def volumetric_rendering(rgb, sigma, z_vals, dirs, white_bkgd):143"""Volumetric Rendering Function.144
145Args:
146rgb: jnp.ndarray(float32), color, [batch_size, num_samples, 3]
147sigma: jnp.ndarray(float32), density, [batch_size, num_samples, 1].
148z_vals: jnp.ndarray(float32), [batch_size, num_samples].
149dirs: jnp.ndarray(float32), [batch_size, 3].
150white_bkgd: bool.
151
152Returns:
153comp_rgb: jnp.ndarray(float32), [batch_size, 3].
154disp: jnp.ndarray(float32), [batch_size].
155acc: jnp.ndarray(float32), [batch_size].
156weights: jnp.ndarray(float32), [batch_size, num_samples]
157"""
158eps = 1e-10159dists = jnp.concatenate([160z_vals[Ellipsis, 1:] - z_vals[Ellipsis, :-1],161jnp.broadcast_to(1e10, z_vals[Ellipsis, :1].shape)162], -1)163dists = dists * jnp.linalg.norm(dirs[Ellipsis, None, :], axis=-1)164# Note that we're quietly turning sigma from [..., 0] to [...].165alpha = 1.0 - jnp.exp(-sigma[Ellipsis, 0] * dists)166accum_prod = jnp.concatenate([167jnp.ones_like(alpha[Ellipsis, :1], alpha.dtype),168jnp.cumprod(1.0 - alpha[Ellipsis, :-1] + eps, axis=-1)169],170axis=-1)171weights = alpha * accum_prod172
173comp_rgb = (weights[Ellipsis, None] * rgb).sum(axis=-2)174depth = (weights * z_vals).sum(axis=-1)175acc = weights.sum(axis=-1)176# Equivalent to (but slightly more efficient and stable than):177# disp = 1 / max(eps, where(acc > eps, depth / acc, 0))178inv_eps = 1 / eps179disp = acc / depth180disp = jnp.where((disp > 0) & (disp < inv_eps) & (acc > eps), disp, inv_eps)181if white_bkgd:182comp_rgb = comp_rgb + (1. - acc[Ellipsis, None])183return comp_rgb, disp, acc, weights184
185
186def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):187"""Piecewise-Constant PDF sampling.188
189Args:
190key: jnp.ndarray(float32), [2,], random number generator.
191bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
192weights: jnp.ndarray(float32), [batch_size, num_bins].
193num_samples: int, the number of samples.
194randomized: bool, use randomized samples.
195
196Returns:
197z_samples: jnp.ndarray(float32), [batch_size, num_samples].
198"""
199# Pad each weight vector (only if necessary) to bring its sum to `eps`. This200# avoids NaNs when the input is zeros or small, but has no effect otherwise.201eps = 1e-5202weight_sum = jnp.sum(weights, axis=-1, keepdims=True)203padding = jnp.maximum(0, eps - weight_sum)204weights += padding / weights.shape[-1]205weight_sum += padding206
207# Compute the PDF and CDF for each weight vector, while ensuring that the CDF208# starts with exactly 0 and ends with exactly 1.209pdf = weights / weight_sum210cdf = jnp.minimum(1, jnp.cumsum(pdf[Ellipsis, :-1], axis=-1))211cdf = jnp.concatenate([212jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf,213jnp.ones(list(cdf.shape[:-1]) + [1])214],215axis=-1)216
217# Draw uniform samples.218if randomized:219# Note that `u` is in [0, 1) --- it can be zero, but it can never be 1.220u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples])221else:222# Match the behavior of random.uniform() by spanning [0, 1-eps].223u = jnp.linspace(0., 1. - jnp.finfo('float32').eps, num_samples)224u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples])225
226# Identify the location in `cdf` that corresponds to a random sample.227# The final `True` index in `mask` will be the start of the sampled interval.228mask = u[Ellipsis, None, :] >= cdf[Ellipsis, :, None]229
230def find_interval(x):231# Grab the value where `mask` switches from True to False, and vice versa.232# This approach takes advantage of the fact that `x` is sorted.233x0 = jnp.max(jnp.where(mask, x[Ellipsis, None], x[Ellipsis, :1, None]), -2)234x1 = jnp.min(jnp.where(~mask, x[Ellipsis, None], x[Ellipsis, -1:, None]), -2)235return x0, x1236
237bins_g0, bins_g1 = find_interval(bins)238cdf_g0, cdf_g1 = find_interval(cdf)239
240t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)241samples = bins_g0 + t * (bins_g1 - bins_g0)242
243# Prevent gradient from backprop-ing through `samples`.244return lax.stop_gradient(samples)245
246
247def sample_pdf(key, bins, weights, origins, directions, z_vals, num_samples,248randomized):249"""Hierarchical sampling.250
251Args:
252key: jnp.ndarray(float32), [2,], random number generator.
253bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
254weights: jnp.ndarray(float32), [batch_size, num_bins].
255origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
256directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
257z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
258num_samples: int, the number of samples.
259randomized: bool, use randomized samples.
260
261Returns:
262z_vals: jnp.ndarray(float32),
263[batch_size, num_coarse_samples + num_fine_samples].
264points: jnp.ndarray(float32),
265[batch_size, num_coarse_samples + num_fine_samples, 3].
266"""
267z_samples = piecewise_constant_pdf(key, bins, weights, num_samples,268randomized)269# Compute united z_vals and sample points270z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)271coords = cast_rays(z_vals, origins, directions)272return z_vals, coords273
274
275def add_gaussian_noise(key, raw, noise_std, randomized):276"""Adds gaussian noise to `raw`, which can used to regularize it.277
278Args:
279key: jnp.ndarray(float32), [2,], random number generator.
280raw: jnp.ndarray(float32), arbitrary shape.
281noise_std: float, The standard deviation of the noise to be added.
282randomized: bool, add noise if randomized is True.
283
284Returns:
285raw + noise: jnp.ndarray(float32), with the same shape as `raw`.
286"""
287if (noise_std is not None) and randomized:288return raw + random.normal(key, raw.shape, dtype=raw.dtype) * noise_std289else:290return raw291