google-research
852 строки · 31.1 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""NeRF helpers."""
17import math
18import tensorflow as tf
19from osf import box_utils
20from osf import ray_utils
21from osf import scene_utils
22
23
def default_ray_sampling(ray_batch, n_samples, perturb, lindisp):
  """Default NeRF ray sampling.

  This function takes a batch of rays and returns points along each ray that
  should be evaluated by the coarse NeRF model.

  Args:
    ray_batch: Array of shape [batch_size, ...]. All information necessary
      for sampling along a ray, including: ray origin (cols 0:3), ray
      direction (cols 3:6), min dist and max dist (cols 6:8), and
      unit-magnitude viewing direction.
    n_samples: Number of samples to take on each ray.
    perturb: Whether to perturb the points with white noise.
    lindisp: bool. If True, sample linearly in inverse depth rather than in
      depth.

  Returns:
    z_vals: Positions of the sampled points on each ray:
      [n_rays, n_samples].
    pts: Actual sampled points in 3D: [n_rays, n_samples, 3].
  """
  # batch size
  n_rays = tf.shape(ray_batch)[0]

  # Extract ray origin, direction.
  rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [n_rays, 3] each

  # Extract lower, upper bound for ray distance.
  bounds = tf.reshape(ray_batch[Ellipsis, 6:8], [-1, 1, 2])
  near, far = bounds[Ellipsis, 0], bounds[Ellipsis, 1]  # [R, 1]

  # Decide where to sample along each ray. Under this logic, all rays are
  # sampled at the same parametric times.
  t_vals = tf.linspace(0., 1., n_samples)
  if not lindisp:
    # Space integration times linearly between 'near' and 'far'. Same
    # integration points will be used for all rays.
    z_vals = near * (1. - t_vals) + far * (t_vals)
  else:
    # Inverse-depth sampling divides by near/far, so both must be positive.
    tf.debugging.assert_greater(near, 0)
    tf.debugging.assert_greater(far, 0)
    # Sample linearly in inverse depth (disparity).
    z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals))
  z_vals = tf.broadcast_to(z_vals, [n_rays, n_samples])

  # Perturb sampling time along each ray.
  if perturb > 0.:
    # get intervals between samples
    mids = .5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1])
    upper = tf.concat([mids, z_vals[Ellipsis, -1:]], -1)
    lower = tf.concat([z_vals[Ellipsis, :1], mids], -1)
    # stratified samples in those intervals
    t_rand = tf.random.uniform(tf.shape(z_vals))
    z_vals = lower + (upper - lower) * t_rand
  # Points in space to evaluate model at.
  pts = rays_o[Ellipsis, None, :] + rays_d[Ellipsis, None, :] * z_vals[Ellipsis, :, None]
  tf.debugging.assert_equal(tf.shape(z_vals)[0], tf.shape(pts)[0])
  return z_vals, pts
81
82
def raw2alpha(raw, dists, act_fn=tf.nn.relu):
  """Converts raw density to alpha values via 1 - exp(-act_fn(raw) * dists)."""
  sigma = act_fn(raw)
  return 1.0 - tf.exp(-sigma * dists)
85
86
def compute_alpha(z_vals,
                  raw_alpha,
                  raw_noise_std,
                  last_dist_method='infinity'):
  """Normalizes raw sigma predictions from the network into normalized alpha.

  Args:
    z_vals: [n_rays, n_samples]. Depths of the samples along each ray.
    raw_alpha: [n_rays, n_samples, 1]. Raw density predictions.
    raw_noise_std: float. Std of Gaussian noise added to raw densities
      (0 disables).
    last_dist_method: str. How to set the distance past the final sample:
      'infinity' appends a huge constant (1e10); 'last' repeats the final
      interval.

  Returns:
    alpha: [n_rays, n_samples, 1]. Alpha values in [0, 1].
  """
  # Compute 'distance' (in time) between each integration time along a ray.
  dists = z_vals[Ellipsis, 1:] - z_vals[Ellipsis, :-1]

  # The 'distance' from the last integration time is infinity.
  if last_dist_method == 'infinity':
    dists = tf.concat(
        [dists, tf.broadcast_to([1e10], tf.shape(dists[Ellipsis, :1]))],
        axis=-1)  # [n_rays, n_samples]
  elif last_dist_method == 'last':
    dists = tf.concat([dists, dists[Ellipsis, -1:]], axis=-1)

  # Multiply each distance by the norm of its corresponding direction ray
  # to convert to real world distance (accounts for non-unit directions).
  # dists = dists * tf.linalg.norm(rays_d[..., None, :], axis=-1)

  raw_alpha = tf.squeeze(raw_alpha, axis=-1)  # [n_rays, n_samples]

  # Add noise to model's predictions for density. Can be used to
  # regularize network during training (prevents floater artifacts).
  noise = 0.
  if raw_noise_std > 0.:
    noise = tf.random.normal(tf.shape(raw_alpha)) * raw_noise_std

  # Convert from raw alpha to alpha values between [0, 1].
  # Predict density of each sample along each ray. Higher values imply
  # higher likelihood of being absorbed at this point.
  alpha = raw2alpha(raw_alpha + noise, dists)  # [n_rays, n_samples]
  return alpha[Ellipsis, None]
120
121
def broadcast_samples_dim(x, target):
  """Inserts and tiles a samples dimension on `x` to align with `target`.

  Given `target` of shape [N, S, M] and `x` of shape [N, K], returns a
  [N, S, K] tensor in which `x` is repeated along the new samples axis.

  Args:
    x: [N, K] array to broadcast.
    target: [N, S, M] array whose samples dimension (axis 1) is matched.

  Returns:
    [N, S, K] tensor: `x` tiled across the samples dimension of `target`.
  """
  n_samples = target.shape[1]
  expanded = tf.expand_dims(x, axis=1)  # [N, 1, K]
  return tf.tile(expanded, [1, n_samples, 1])  # [N, S, K]
139
140
def sample_pdf(bins, weights, n_samples, det=False):
  """Function for sampling a probability distribution.

  Draws `n_samples` positions per ray by inverse-transform sampling the
  piecewise-constant PDF defined by `weights` over `bins`.

  Args:
    bins: [..., B]. Bin positions (e.g. z-value midpoints).
    weights: [..., B-1]. Unnormalized weights for each bin interval.
    n_samples: int. Number of samples to draw.
    det: bool. If True, sample deterministically (evenly spaced in CDF
      space); otherwise draw uniform random samples.

  Returns:
    samples: [..., n_samples]. Sampled positions.
  """
  # Get pdf
  weights += 1e-5  # prevent nans
  pdf = weights / tf.reduce_sum(weights, -1, keepdims=True)
  cdf = tf.cumsum(pdf, -1)
  # Prepend zero so the CDF starts at 0 and has the same length as `bins`.
  cdf = tf.concat([tf.zeros_like(cdf[Ellipsis, :1]), cdf], -1)

  # Take uniform samples
  u_shape = tf.concat([tf.shape(cdf)[:-1], tf.constant([n_samples])], axis=0)
  if det:
    u = tf.linspace(0., 1., n_samples)
    u = tf.broadcast_to(u, u_shape)
  else:
    u = tf.random.uniform(u_shape)

  # Invert CDF: locate the CDF interval containing each sample.
  inds = tf.searchsorted(cdf, u, side='right')
  below = tf.maximum(0, inds - 1)
  above = tf.minimum(tf.shape(cdf)[-1] - 1, inds)
  inds_g = tf.stack([below, above], -1)
  # len() of the static [rank] shape tensor yields the rank of inds_g.
  cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(tf.shape(inds_g)) - 2)
  bins_g = tf.gather(
      bins, inds_g, axis=-1, batch_dims=len(tf.shape(inds_g)) - 2)

  # Linearly interpolate within the interval; guard against zero-width
  # intervals to avoid division by zero.
  denom = (cdf_g[Ellipsis, 1] - cdf_g[Ellipsis, 0])
  denom = tf.where(denom < 1e-5, tf.ones_like(denom), denom)
  t = (u - cdf_g[Ellipsis, 0]) / denom
  samples = bins_g[Ellipsis, 0] + t * (bins_g[Ellipsis, 1] - bins_g[Ellipsis, 0])

  return samples
172
173
def default_ray_sampling_fine(ray_batch,
                              z_vals,
                              weights,
                              n_samples,
                              perturb,
                              keep_original_points=True,
                              compute_fine_indices=False):
  """Ray sampling of fine points.

  Importance-samples `n_samples` (I) additional depths per ray from the
  coarse weights, optionally merging them with the original coarse depths.

  Args:
    ray_batch: [R, M]. Batch of rays; origin at cols 0:3, direction 3:6.
    z_vals: [R, S]. Coarse sample depths.
    weights: [R, S]. Coarse compositing weights used as the sampling PDF.
    n_samples: int. Number of fine samples (I) per ray.
    perturb: If 0, sample deterministically; otherwise randomly.
    keep_original_points: bool. If True, keep the coarse depths alongside
      the fine ones.
    compute_fine_indices: bool. If True, also return the indices of the
      fine samples within the sorted, merged depths.

  Returns:
    z_vals: [R, S + I]. Sorted, merged sample depths.
    z_samples: [R, I]. The newly drawn fine depths.
    pts: [R, S + I, 3]. Sampled 3D points.
    fine_indices: [R, I] indices of fine samples in `z_vals`, or None.
  """
  n_orig_samples = z_vals.shape[1]

  # Extract ray origin, direction.
  rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [R, 3] each

  # Obtain additional integration times to evaluate based on the weights
  # assigned to colors in the coarse model.
  z_vals_mid = .5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1])
  z_samples = sample_pdf(
      z_vals_mid, weights[Ellipsis, 1:-1], n_samples, det=(perturb == 0.))
  # Fine samples should not propagate gradients into the coarse weights.
  z_samples = tf.stop_gradient(z_samples)

  # Obtain all points to evaluate color, density at.
  if keep_original_points:
    z_list = [z_vals, z_samples]  # [R, S], [R, I]
  else:
    z_list = [z_samples]
  z_vals = tf.sort(tf.concat(z_list, -1), -1)  # [R, S + I]

  fine_indices = None
  if compute_fine_indices:
    # The last `n_samples` values represent the indices of the fine samples
    # in the final sorted set of `z_samples`.
    z_argsort_indices = tf.argsort(tf.concat(z_list, -1), -1)  # [R, S + I]
    fine_indices = z_argsort_indices[:, -n_samples:]  # [R, I]
    fine_indices = tf.reshape(fine_indices, [-1, n_samples])  # [R, I]
  pts = rays_o[Ellipsis, None, :] + \
      rays_d[Ellipsis, None, :] * z_vals[Ellipsis, :, None]  # [R, S + I, 3]

  # The inputs may contain unknown batch size, which leads to results where
  # the first two dimensions are unknown. Set dimension one to the known
  # sample count.
  # NOTE(review): these reshapes assume keep_original_points=True; with
  # keep_original_points=False, z_vals only has `n_samples` columns —
  # verify before using that path.
  n_total_samples = n_orig_samples + n_samples

  z_vals = tf.reshape(z_vals, [-1, n_total_samples])  # [R, S + I]
  z_samples = tf.reshape(z_samples, [-1, n_samples])  # [R, I]
  pts = tf.reshape(pts, [-1, n_total_samples, 3])  # [R, S + I, 3]

  tf.debugging.assert_equal(tf.shape(z_vals)[0], tf.shape(pts)[0])
  return z_vals, z_samples, pts, fine_indices
222
223
def apply_intersect_mask_to_tensors(intersect_mask, tensors):
  """Boolean-masks every tensor in `tensors` with the same ray mask.

  Args:
    intersect_mask: [R,] boolean mask over rays.
    tensors: List of tensors whose leading dimension is R.

  Returns:
    List of tensors keeping only the masked (intersecting) rays, [Ro, ...].
  """
  return [
      tf.boolean_mask(tensor=t, mask=intersect_mask) for t in tensors
  ]
230
231
def compute_object_intersect_tensors(name, ray_batch, scene_info, far,
                                     object2padding, swap_object_yz, **kwargs):
  """Compute intersecting rays.

  Intersects rays against the named object's (padded) bounding box and
  returns only the intersecting rays, with their near/far bounds replaced
  by the box entry/exit distances.

  Args:
    name: str. Object name used to look up its box in the scene.
    ray_batch: [R, M]. Batch of rays.
    scene_info: Scene metadata used to extract object boxes.
    far: Far limit for the intersection test.
    object2padding: Dict mapping object name -> box padding.
    swap_object_yz: bool. Whether to swap the y/z axes of the boxes.
    **kwargs: Must contain 'box_delta_t'.

  Returns:
    intersect_ray_batch: [R', M]. Intersecting rays with updated bounds.
    intersect_indices: [R', 1]. Indices of intersecting rays in `ray_batch`.
  """
  rays_o = ray_utils.extract_slice_from_ray_batch(  # [R, 3]
      ray_batch=ray_batch,  # [R, M]
      key='origin')
  rays_d = ray_utils.extract_slice_from_ray_batch(  # [R, 3]
      ray_batch=ray_batch,  # [R, M]
      key='direction')
  rays_far = ray_utils.extract_slice_from_ray_batch(  # [R, 1]
      ray_batch=ray_batch,  # [R, M]
      key='far')
  rays_sid = ray_utils.extract_slice_from_ray_batch(  # [R, 1]
      ray_batch=ray_batch,  # [R, M]
      key='metadata')

  # Look up the object's box parameters for each ray's scene id.
  (
      box_dims,  # [R, 3], [R, 3], [R, 3, 3]
      box_center,
      box_rotation) = scene_utils.extract_object_boxes_for_scenes(
          name=name,
          scene_info=scene_info,
          sids=rays_sid,  # [R, 1]
          padding=object2padding[name],
          swap_yz=swap_object_yz,
          box_delta_t=kwargs['box_delta_t'])

  # Compute ray-bbox intersections.
  intersect_bounds, intersect_indices, intersect_mask = (  # [R', 2],[R',],[R,]
      box_utils.compute_ray_bbox_bounds_pairwise(
          rays_o=rays_o,  # [R, 3]
          rays_d=rays_d,  # [R, 3]
          rays_far=rays_far,  # [R, 1]
          box_length=box_dims[:, 0],  # [R,]
          box_width=box_dims[:, 1],  # [R,]
          box_height=box_dims[:, 2],  # [R,]
          box_center=box_center,  # [R, 3]
          box_rotation=box_rotation,  # [R, 3, 3]
          far_limit=far))

  # Apply the intersection mask to the ray batch.
  intersect_ray_batch = apply_intersect_mask_to_tensors(  # [R', M]
      intersect_mask=intersect_mask,  # [R,]
      tensors=[ray_batch])[0]  # [R, M]

  # Update the near and far bounds of the ray batch with the intersect bounds.
  intersect_ray_batch = ray_utils.update_ray_batch_bounds(  # [R', M]
      ray_batch=intersect_ray_batch,  # [R', M]
      bounds=intersect_bounds)  # [R', 2]
  return intersect_ray_batch, intersect_indices  # [R', M], [R', 1]
282
283
def compute_lightdirs(pts,
                      metadata,
                      scene_info,
                      use_lightdir_norm,
                      use_random_lightdirs,
                      light_pos,
                      light_origins=None):
  """Compute light directions.

  Either samples random light directions uniformly on the unit sphere around
  each point, or points light rays from the scene light position(s) toward
  each sample point.

  Args:
    pts: [R, S, 3]. Sample points.
    metadata: [R, 1]. Per-ray scene ids.
    scene_info: Scene metadata used to look up light positions.
    use_lightdir_norm: bool. If True, normalize light directions to unit
      magnitude.
    use_random_lightdirs: bool. If True, sample random directions on the
      unit sphere instead of using scene light positions.
    light_pos: Light position override passed to the scene utils.
    light_origins: Optional light origins; if None, they are derived from
      `scene_info`.

  Returns:
    light_ray_batch: [R*S, M]. Ray batch from light origins to the points.
    lightdirs: [R, S, 3]. Per-point light directions.
  """
  n_rays = tf.shape(pts)[0]
  n_samples = tf.shape(pts)[1]
  if use_random_lightdirs:
    # Uniformly sample directions on the unit sphere:
    # https://stackoverflow.com/questions/5408276/sampling-uniformly-distributed-random-points-inside-a-spherical-volume
    phi = tf.random.uniform(
        shape=[n_rays, n_samples], minval=0, maxval=2 * math.pi)  # [R, S]
    cos_theta = tf.random.uniform(
        shape=[n_rays, n_samples], minval=-1, maxval=1)  # [R, S]
    theta = tf.math.acos(cos_theta)  # [R, S]

    x = tf.math.sin(theta) * tf.math.cos(phi)
    y = tf.math.sin(theta) * tf.math.sin(phi)
    z = tf.math.cos(theta)

    light_origins = tf.zeros([n_rays, n_samples, 3],
                             dtype=tf.float32)  # [R, S, 3]
    light_dst = tf.concat([x[Ellipsis, None], y[Ellipsis, None], z[Ellipsis, None]],
                          axis=-1)  # [R, S, 3]
    # Center the ray origin/dst on the sample points.
    light_origins = light_origins + pts
    light_dst = light_dst + pts
    # Fix: a duplicated (dead) `light_dst_flat` reshape was removed here.
    light_origins_flat = tf.reshape(light_origins, [-1, 3])  # [RS, 3]
    light_dst_flat = tf.reshape(light_dst, [-1, 3])  # [RS, 3]
    metadata_tiled = tf.tile(metadata[:, None, :],
                             [1, n_samples, 1])  # [R, S, 1]
    metadata_flat = tf.reshape(metadata_tiled, [-1, 1])  # [RS, 1]
    # We want points to be the origin of the light rays in the batch because
    # we will be computing radiance along that ray.
    light_ray_batch = ray_utils.create_ray_batch(  # [RS, M]
        rays_o=light_origins_flat,  # [RS, 3]
        rays_dst=light_dst_flat,  # [RS, 3]
        rays_sid=metadata_flat)  # [RS, 1]
    lightdirs = light_dst
    return light_ray_batch, lightdirs

  # Compute light origins using scene info and control metadata.
  if light_origins is None:
    light_origins = scene_utils.extract_light_positions_for_sids(
        sids=metadata, scene_info=scene_info, light_pos=light_pos)

  # Make sure to use tf.shape instead of light_origins.shape:
  # https://github.com/tensorflow/models/issues/6245
  light_origins = tf.reshape(  # [n_rays, 1, 3]
      light_origins,
      [tf.shape(light_origins)[0], 1,
       tf.shape(light_origins)[1]])

  # Compute the incoming light direction for each point.
  # Note that we use points in the world coordinate space because we currently
  # assume that the light position is in world coordinate space.
  lightdirs = pts - light_origins  # [n_rays, n_samples, 3]

  # Make all directions unit magnitude.
  if use_lightdir_norm:
    lightdirs_norm = tf.linalg.norm(
        lightdirs, axis=-1, keepdims=True)  # [n_rays, n_samples, 1]
    lightdirs = tf.math.divide_no_nan(lightdirs,
                                      lightdirs_norm)  # [n_rays, n_samples, 3]
  # Pin the static shape and dtype.
  lightdirs = tf.reshape(lightdirs, [n_rays, n_samples, 3])
  lightdirs = tf.cast(lightdirs, dtype=tf.float32)

  light_origins = tf.tile(light_origins, [1, n_samples, 1])
  metadata = tf.tile(metadata[:, None, :], [1, n_samples, 1])
  light_origins_flat = tf.reshape(light_origins, [-1, 3])  # [RS, 3]
  pts_flat = tf.reshape(pts, [-1, 3])  # [RS, 3]
  metadata_flat = tf.reshape(metadata, [-1, 1])  # [RS, 1]
  light_ray_batch = ray_utils.create_ray_batch(  # [RS, M]
      rays_o=light_origins_flat,  # [RS, 3]
      rays_dst=pts_flat,  # [RS, 3]
      rays_sid=metadata_flat)  # [RS, 1]
  return light_ray_batch, lightdirs
364
365
def compute_view_light_dirs(ray_batch, pts, scene_info, use_viewdirs,
                            use_lightdir_norm, use_random_lightdirs, light_pos):
  """Computes viewing and lighting directions for a batch of sample points.

  Args:
    ray_batch: [R, M]. Batch of rays; viewing directions at columns 8:11.
    pts: [R, S, 3]. Sample points.
    scene_info: Scene metadata passed through to `compute_lightdirs`.
    use_viewdirs: bool. Layout flag for slicing metadata from the ray batch.
    use_lightdir_norm: bool. Whether to normalize light directions.
    use_random_lightdirs: bool. Whether to sample random light directions.
    light_pos: Light position override.

  Returns:
    viewdirs: [R, S, 3]. Viewing directions tiled over samples.
    light_ray_batch: Light rays for the sample points.
    lightdirs: [R, S, 3]. Light directions.
  """
  metadata = ray_utils.extract_slice_from_ray_batch(
      ray_batch, key='metadata', use_viewdirs=use_viewdirs)  # [R, 1]
  # Tile per-ray viewing directions over the samples dimension of `pts`.
  viewdirs = broadcast_samples_dim(x=ray_batch[:, 8:11], target=pts)
  light_ray_batch, lightdirs = compute_lightdirs(
      pts=pts,
      metadata=metadata,
      scene_info=scene_info,
      use_lightdir_norm=use_lightdir_norm,
      use_random_lightdirs=use_random_lightdirs,
      light_pos=light_pos)  # [R, S, 3]
  return viewdirs, light_ray_batch, lightdirs
381
382
def network_query_fn_helper(pts, ray_batch, network, network_query_fn, viewdirs,
                            lightdirs, use_viewdirs, use_lightdirs):
  """Queries the NeRF network at the given points.

  Args:
    pts: [n_rays, n_samples, 3]. Points to query.
    ray_batch: [n_rays, M]. Batch of rays (source of per-ray metadata).
    network: The NeRF model to evaluate.
    network_query_fn: Callable evaluating `network` at the inputs.
    viewdirs: Viewing directions; ignored when `use_viewdirs` is False.
    lightdirs: Light directions; ignored when `use_lightdirs` is False.
    use_viewdirs: bool. Whether to condition on viewing directions.
    use_lightdirs: bool. Whether to condition on light directions.

  Returns:
    raw: Raw network outputs.
  """
  # Disabled direction inputs are passed to the network as None.
  viewdirs = viewdirs if use_viewdirs else None
  lightdirs = lightdirs if use_lightdirs else None

  # Extract additional per-ray metadata, [n_rays, metadata_channels].
  rays_data = ray_utils.extract_slice_from_ray_batch(
      ray_batch, key='example_id', use_viewdirs=use_viewdirs)

  # Query NeRF for the corresponding densities for the light points.
  return network_query_fn(pts, viewdirs, lightdirs, rays_data, network)
402
403
def network_query_fn_helper_nodirs(pts,
                                   ray_batch,
                                   network,
                                   network_query_fn,
                                   use_viewdirs,
                                   use_lightdirs,
                                   use_lightdir_norm,
                                   scene_info,
                                   use_random_lightdirs,
                                   light_origins=None,
                                   **kwargs):
  """Same as `network_query_fn_helper`, but computes directions itself.

  Args:
    pts: [R, S, 3]. Points to query.
    ray_batch: [R, M]. Batch of rays.
    network: The NeRF model to evaluate.
    network_query_fn: Callable evaluating `network` at the inputs.
    use_viewdirs: bool. Whether to condition on viewing directions.
    use_lightdirs: bool. Unused here; light directions are always computed
      (kept for interface parity with `network_query_fn_helper`).
    use_lightdir_norm: bool. Whether to normalize light directions.
    scene_info: Scene metadata for light-position lookup.
    use_random_lightdirs: bool. Whether to sample random light directions.
    light_origins: Optional light origins override.
    **kwargs: Must contain 'light_pos'.

  Returns:
    light_ray_batch: Light rays for the sample points.
    raw: Raw network outputs.
  """
  del use_lightdirs  # Light directions are always computed on this path.

  # Extract unit-normalized viewing direction.
  if use_viewdirs:
    viewdirs = ray_batch[:, 8:11]  # [R, 3]
    viewdirs = broadcast_samples_dim(x=viewdirs, target=pts)  # [R, S, 3]
  else:
    viewdirs = None

  # Compute the light directions.
  light_ray_batch, lightdirs = compute_lightdirs(  # [R, S, 3]
      pts=pts,
      metadata=ray_utils.extract_slice_from_ray_batch(
          ray_batch, key='metadata', use_viewdirs=use_viewdirs),
      scene_info=scene_info,
      use_lightdir_norm=use_lightdir_norm,
      use_random_lightdirs=use_random_lightdirs,
      light_pos=kwargs['light_pos'],
      light_origins=light_origins)

  # Extract additional per-ray metadata.
  rays_data = ray_utils.extract_slice_from_ray_batch(
      ray_batch, key='example_id', use_viewdirs=use_viewdirs)

  # Query NeRF for the corresponding densities for the light points.
  raw = network_query_fn(pts, viewdirs, lightdirs, rays_data, network)
  return light_ray_batch, raw
452
453
def create_w2o_transformations_tensors(name, scene_info, ray_batch,
                                       use_viewdirs, box_delta_t):
  """Create transformation tensor from world to object space.

  Args:
    name: str. Object name to look up transformations for.
    scene_info: Scene metadata holding per-scene transformations.
    ray_batch: [R, M]. Batch of rays; scene ids come from its metadata.
    use_viewdirs: bool. Layout flag for slicing the ray batch.
    box_delta_t: Box translation offset passed to the scene utils.

  Returns:
    w2o_rt: [R, 4, 4]. World-to-object rotation+translation per ray.
    w2o_r: [R, 4, 4]. World-to-object rotation-only transform per ray.
  """
  metadata = ray_utils.extract_slice_from_ray_batch(
      ray_batch, key='metadata', use_viewdirs=use_viewdirs)  # [R, 1]
  w2o_rt_per_scene, w2o_r_per_scene = (
      scene_utils.extract_w2o_transformations_per_scene(
          name=name, scene_info=scene_info, box_delta_t=box_delta_t))
  # Select each ray's transformation by its scene id.
  w2o_rt = tf.gather_nd(  # [R, 4, 4]
      params=w2o_rt_per_scene,  # [N_scenes, 4, 4]
      indices=metadata)  # [R, 1]
  w2o_r = tf.gather_nd(  # [R, 4, 4]
      params=w2o_r_per_scene,  # [N_scenes, 4, 4]
      indices=metadata)  # [R, 1]
  return w2o_rt, w2o_r
469
470
def apply_batched_transformations(inputs, transformations):
  """Batched transformation of inputs.

  Applies a per-ray 4x4 homogeneous transformation to each [R, S, 3] input.

  Args:
    inputs: List of [R, S, 3] tensors.
    transformations: [R, 4, 4] homogeneous transformation matrices.

  Returns:
    transformed_inputs: List of [R, S, 3] transformed tensors.
  """
  transformed_inputs = []
  for x in inputs:
    n_samples = tf.shape(x)[1]
    homog_transformations = tf.expand_dims(
        input=transformations, axis=1)  # [R, 1, 4, 4]
    homog_transformations = tf.tile(homog_transformations,
                                    [1, n_samples, 1, 1])  # [R, S, 4, 4]
    # Append a homogeneous coordinate of 1 to each point.
    homog_component = tf.ones_like(x)[Ellipsis, 0:1]  # [R, S, 1]
    homog_x = tf.concat([x, homog_component], axis=-1)  # [R, S, 4]
    homog_x = tf.expand_dims(input=homog_x, axis=2)  # [R, S, 1, 4]
    # Row-vector times transposed matrix == matrix times column-vector.
    transformed_x = tf.matmul(homog_x,
                              tf.transpose(homog_transformations,
                                           (0, 1, 3, 2)))  # [R, S, 1, 4]
    # Drop the homogeneous coordinate.
    transformed_x = transformed_x[Ellipsis, 0, :3]  # [R, S, 3]
    transformed_inputs.append(transformed_x)
  return transformed_inputs
497
498
def compute_object_inputs(name, ray_batch, pts, scene_info,
                          use_random_lightdirs, **kwargs):
  """Compute inputs to object networks.

  Computes view/light directions and transforms points (and optionally the
  directions) from world space into the object's canonical frame.

  Args:
    name: str. Object name.
    ray_batch: [Ro, M]. Intersecting rays for this object.
    pts: [Ro, S, 3]. Sample points in world space.
    scene_info: Scene metadata.
    use_random_lightdirs: bool. Whether to sample random light directions.
    **kwargs: Must contain 'use_viewdirs', 'use_lightdir_norm', 'light_pos',
      'box_delta_t', and 'use_transform_dirs'.

  Returns:
    object_pts: [Ro, S, 3]. Points in object space.
    object_viewdirs: [Ro, S, 3]. Viewing directions (optionally transformed).
    light_ray_batch: Light rays for the sample points.
    object_lightdirs: [Ro, S, 3]. Light directions (optionally transformed).
  """
  # Extract viewing and lighting directions.
  # [Ro, S, 3]
  object_viewdirs, light_ray_batch, object_lightdirs = compute_view_light_dirs(
      ray_batch=ray_batch,
      pts=pts,
      scene_info=scene_info,
      use_viewdirs=kwargs['use_viewdirs'],
      use_lightdir_norm=kwargs['use_lightdir_norm'],
      use_random_lightdirs=use_random_lightdirs,
      light_pos=kwargs['light_pos'])

  # Transform points and optionally directions from world to canonical
  # coordinate frame.
  w2o_rt, w2o_r = create_w2o_transformations_tensors(  # [Ro, 4, 4]
      name=name,
      scene_info=scene_info,
      ray_batch=ray_batch,
      use_viewdirs=kwargs['use_viewdirs'],
      box_delta_t=kwargs['box_delta_t'])
  object_pts = apply_batched_transformations(
      inputs=[pts],
      transformations=w2o_rt,
  )[0]
  if kwargs['use_transform_dirs']:
    # Directions transform with rotation only (no translation component).
    # pylint: disable=unbalanced-tuple-unpacking
    [object_viewdirs, object_lightdirs] = apply_batched_transformations(
        inputs=[object_viewdirs, object_lightdirs], transformations=w2o_r)
    # pylint: enable=unbalanced-tuple-unpacking
  return object_pts, object_viewdirs, light_ray_batch, object_lightdirs
531
532
def normalize_rgb(raw_rgb, scaled_sigmoid):
  """Squashes raw RGB network outputs into color values.

  Args:
    raw_rgb: [n_rays, n_samples, 3] raw network outputs.
    scaled_sigmoid: bool. If True, widen the sigmoid range by 1.2x around
      0.5 so fully saturated colors are reachable.

  Returns:
    [n_rays, n_samples, 3] normalized RGB values.
  """
  rgb = tf.math.sigmoid(raw_rgb)
  if scaled_sigmoid:
    rgb = 0.5 + 1.2 * (rgb - 0.5)
  return rgb
539
540
def normalize_raw(raw, z_vals, scaled_sigmoid, raw_noise_std, last_dist_method):
  """Normalizes raw network outputs into per-sample RGB and alpha.

  Args:
    raw: Dict with 'rgb' and 'alpha' raw network outputs.
    z_vals: [n_rays, n_samples]. Sample depths along each ray.
    scaled_sigmoid: bool. Whether to widen the RGB sigmoid range.
    raw_noise_std: float. Std of noise added to raw densities.
    last_dist_method: str. Distance rule for the final sample interval.

  Returns:
    Dict with normalized 'rgb' and 'alpha' tensors.
  """
  rgb = normalize_rgb(raw_rgb=raw['rgb'], scaled_sigmoid=scaled_sigmoid)
  # Alpha expresses the probability of the ray being absorbed at each sample.
  alpha = compute_alpha(
      z_vals=z_vals,
      raw_alpha=raw['alpha'],
      raw_noise_std=raw_noise_std,
      last_dist_method=last_dist_method)
  return {'rgb': rgb, 'alpha': alpha}
557
558
def run_sparse_network(name, network, intersect_z_vals, intersect_pts,
                       intersect_ray_batch, use_random_lightdirs, **kwargs):
  """Runs a single network.

  Background objects ('bkgd*') are queried in world space with directions
  computed internally; other objects are first transformed into their
  canonical object frame.

  Args:
    name: str. Object name; a 'bkgd' prefix selects the background path.
    network: The NeRF model to query.
    intersect_z_vals: [Ro, S]. Sample depths for intersecting rays.
    intersect_pts: [Ro, S, 3]. Sample points for intersecting rays.
    intersect_ray_batch: [Ro, M]. Intersecting rays.
    use_random_lightdirs: bool. Whether to sample random light directions.
    **kwargs: Additional network/rendering options.

  Returns:
    intersect_light_ray_batch: Light rays for the sample points.
    normalized_raw: Dict with normalized 'rgb' and 'alpha' tensors.
  """
  if name.startswith('bkgd'):
    intersect_light_ray_batch, intersect_raw = network_query_fn_helper_nodirs(
        network=network,
        pts=intersect_pts,  # [R, S, 3]
        ray_batch=intersect_ray_batch,  # [R, M]
        use_random_lightdirs=use_random_lightdirs,
        **kwargs)  # [R, 3]
  else:
    (object_intersect_pts, object_intersect_viewdirs, intersect_light_ray_batch,
     object_intersect_lightdirs) = compute_object_inputs(
         name=name,
         ray_batch=intersect_ray_batch,
         pts=intersect_pts,
         use_random_lightdirs=use_random_lightdirs,
         **kwargs)

    # Query the object NeRF.
    intersect_raw = network_query_fn_helper(
        pts=object_intersect_pts,
        ray_batch=intersect_ray_batch,
        network=network,
        network_query_fn=kwargs['network_query_fn'],
        viewdirs=object_intersect_viewdirs,
        lightdirs=object_intersect_lightdirs,
        use_viewdirs=kwargs['use_viewdirs'],
        use_lightdirs=kwargs['use_lightdirs'])

  # Compute weights of the intersecting points on normalized raw values.
  normalized_raw = normalize_raw(  # [Ro, S, 4]
      raw=intersect_raw,  # [Ro, S, 4]
      z_vals=intersect_z_vals,  # [Ro, S]
      scaled_sigmoid=kwargs['scaled_sigmoid'],
      raw_noise_std=kwargs['raw_noise_std'],
      last_dist_method=kwargs['last_dist_method'])
  return intersect_light_ray_batch, normalized_raw
597
598
def compute_transmittance(alpha):
  """Computes transmittance from (normalized) alpha values.

  Args:
    alpha: [R, S] per-sample alpha values.

  Returns:
    t: [R, S] transmittance at each sample.
  """
  # TODO(guom): fix this — the pointwise complement (1 - alpha) is a
  # placeholder; true transmittance is the cumulative product of (1 - alpha)
  # along the ray.
  t = 1. - alpha
  return t
613
614
def compute_weights(normalized_alpha):
  """Computes per-sample transmittance and compositing weights.

  Args:
    normalized_alpha: [R, S] alpha values in [0, 1].

  Returns:
    trans: [R, S] transmittance values.
    weights: [R, S] compositing weights (alpha * transmittance).
  """
  transmittance = compute_transmittance(normalized_alpha)
  return transmittance, normalized_alpha * transmittance
619
620
def run_single_object(name, ray_batch, use_random_lightdirs, **kwargs):
  """Run and generate predictions for a single object.

  Args:
    name: The name of the object to run. Names starting with 'bkgd' are
      treated as background, intersecting every ray.
    ray_batch: [R, M] tf.float32. A batch of rays.
    use_random_lightdirs: bool. Whether to sample random light directions.
    **kwargs: Additional arguments ('n_samples', 'perturb', 'lindisp',
      'name2model', 'N_importance', ...).

  Returns:
    intersect_0: Dict. Coarse-stage results.
    intersect: Dict. Fine-stage results (coarse results if N_importance <= 0).
  """
  # Compute intersection rays and indices. Background objects intersect every
  # ray, so no masking or index bookkeeping is needed.
  if name.startswith('bkgd'):
    intersect_ray_batch = ray_batch
    intersect_indices = None
  else:
    intersect_ray_batch, intersect_indices = compute_object_intersect_tensors(
        name=name, ray_batch=ray_batch, **kwargs)
  # Run coarse stage.
  intersect_z_vals_0, intersect_pts_0 = default_ray_sampling(
      ray_batch=intersect_ray_batch,
      n_samples=kwargs['n_samples'],
      perturb=kwargs['perturb'],
      lindisp=kwargs['lindisp'])

  intersect_light_ray_batch_0, normalized_raw_0 = run_sparse_network(
      name=name,
      network=kwargs['name2model'][name],
      intersect_z_vals=intersect_z_vals_0,
      intersect_pts=intersect_pts_0,
      intersect_ray_batch=intersect_ray_batch,
      use_random_lightdirs=use_random_lightdirs,
      **kwargs)

  # Fix: default the fine-stage outputs to the coarse results. Previously
  # these names were only bound inside the `N_importance > 0` branch, so a
  # configuration with N_importance == 0 raised NameError when building the
  # `intersect` dict below.
  intersect_z_vals = intersect_z_vals_0
  intersect_pts = intersect_pts_0
  intersect_light_ray_batch = intersect_light_ray_batch_0
  normalized_raw = normalized_raw_0

  # Run fine stage.
  if kwargs['N_importance'] > 0:
    _, intersect_weights_0 = compute_weights(
        normalized_alpha=normalized_raw_0['alpha'][Ellipsis, 0])  # [Ro, S]
    # Generate fine samples using weights from the coarse stage.
    intersect_z_vals, _, intersect_pts, _ = default_ray_sampling_fine(
        ray_batch=intersect_ray_batch,  # [Ro, M]
        z_vals=intersect_z_vals_0,  # [Ro, S]
        weights=intersect_weights_0,  # [Ro, S]
        n_samples=kwargs['N_importance'],
        perturb=kwargs['perturb'])

    # Run the networks for all the objects.
    intersect_light_ray_batch, normalized_raw = run_sparse_network(
        name=name,
        network=kwargs['name2model'][name],
        intersect_z_vals=intersect_z_vals,
        intersect_pts=intersect_pts,
        intersect_ray_batch=intersect_ray_batch,
        use_random_lightdirs=use_random_lightdirs,
        **kwargs)

  intersect_0 = {
      'ray_batch': intersect_ray_batch,  # [R, M]
      'light_ray_batch': intersect_light_ray_batch_0,  # [R, S, M]
      'indices': intersect_indices,
      'z_vals': intersect_z_vals_0,
      'pts': intersect_pts_0,
      'normalized_rgb': normalized_raw_0['rgb'],
      'normalized_alpha': normalized_raw_0['alpha'],
  }
  intersect = {
      'ray_batch': intersect_ray_batch,  # [R, M]
      'light_ray_batch': intersect_light_ray_batch,  # [R, S, M]
      'indices': intersect_indices,
      'z_vals': intersect_z_vals,
      'pts': intersect_pts,
      'normalized_rgb': normalized_raw['rgb'],
      'normalized_alpha': normalized_raw['alpha'],
  }
  return intersect_0, intersect
699
700
def create_scatter_indices_for_dim(dim, shape, indices=None):
  """Creates scatter indices along one dimension, broadcast to `shape`.

  Args:
    dim: int. Dimension for which to generate indices.
    shape: Target shape (1-D int tensor or sequence) of the index tensor.
    indices: Optional known indices for this dimension; defaults to a
      tf.range over the dimension size.

  Returns:
    int32 tensor of shape `shape` holding each element's index along `dim`.
  """
  rank = len(shape)

  if indices is None:
    indices = tf.range(shape[dim], dtype=tf.int32)  # [dim_size,]

  # Place the indices along `dim`, then broadcast across every other
  # dimension.
  broadcast_shape = [1] * rank
  broadcast_shape[dim] = -1
  indices = tf.reshape(indices, broadcast_shape)  # [1, ..., dim_size, ..., 1]

  tf.debugging.assert_equal(tf.shape(indices)[dim], shape[dim])
  indices = tf.broadcast_to(indices, shape)

  return tf.cast(indices, dtype=tf.int32)
719
720
def create_scatter_indices(updates, dim2known_indices):
  """Create scatter indices.

  Builds the full index tensor for `tf.tensor_scatter_nd_update`, using the
  known indices where provided and the identity range elsewhere.

  Args:
    updates: [Ro, S] or [Ro, S, C]. The update values to be scattered.
    dim2known_indices: Dict mapping dimension -> known index tensor for that
      dimension (e.g. {0: intersect_indices}).

  Returns:
    scatter_indices: int32 index tensor, one index vector per update element.
  """
  updates_expanded = tf.expand_dims(updates, -1)  # [Ro, S, 1] or [Ro, S, C, 1]
  target_shape = tf.shape(updates_expanded)
  # len() of the static [rank] shape tensor yields the rank of `updates`.
  n_dims = len(tf.shape(updates))  # 2 or 3

  dim_indices_list = []
  for dim in range(n_dims):
    indices = None
    if dim in dim2known_indices:
      indices = dim2known_indices[dim]
    dim_indices = create_scatter_indices_for_dim(  # [Ro, S, C, 1]
        dim=dim,
        shape=target_shape,  # [Ro, S, 1] or [Ro, S, C, 1]
        indices=indices)  # [Ro,]
    dim_indices_list.append(dim_indices)
  scatter_indices = tf.concat(dim_indices_list, axis=-1)  # [Ro, S, C, 3]
  return scatter_indices
739
740
def scatter_nd(tensor, updates, dim2known_indices):
  """Scatters `updates` into `tensor` using known per-dimension indices.

  Args:
    tensor: [R, S, C] destination tensor.
    updates: [Ro, S, C] values to write.
    dim2known_indices: Dict mapping dimension -> index tensor for that dim.

  Returns:
    Copy of `tensor` with `updates` written at the computed indices.
  """
  indices = create_scatter_indices(  # [Ro, S, C, 3]
      updates=updates, dim2known_indices=dim2known_indices)
  return tf.tensor_scatter_nd_update(
      tensor=tensor, indices=indices, updates=updates)
750
751
def scatter_results(intersect, n_rays, keys):
  """Scatters intersecting ray results into the original set of rays.

  Args:
    intersect: Dict. Intersecting values.
    n_rays: int or tf.int32. Total number of rays.
    keys: [str]. List of keys to scatter.

  Returns:
    scattered_results: Dict. Scattered results.
  """
  # We use `None` to indicate that the intersecting set of rays is equivalent
  # to the full set of rays, so we are done.
  intersect_indices = intersect['indices']
  if intersect_indices is None:
    return {k: intersect[k] for k in keys}

  scattered_results = {}
  n_samples = intersect['z_vals'].shape[1]
  dim2known_indices = {0: intersect_indices}  # [R?, 1]
  for k in keys:
    # Build the default (non-intersecting rays) filler tensor for this key.
    if k == 'z_vals':
      # NOTE(review): non-intersecting rays get random z values — presumably
      # placeholders ignored downstream; confirm before relying on them.
      tensor = tf.random.uniform((n_rays, n_samples),
                                 dtype=tf.float32)  # [R, S]
    elif k == 'pts':
      # Far-away dummy points for non-intersecting rays.
      tensor = tf.cast(  # [R, S, 3]
          tf.fill((n_rays, n_samples, 3), 1000.0),
          dtype=tf.float32)
    elif 'rgb' in k:
      tensor = tf.zeros((n_rays, n_samples, 3), dtype=tf.float32)  # [R, S, 3]
    elif 'alpha' in k:
      tensor = tf.zeros((n_rays, n_samples, 1), dtype=tf.float32)  # [R, S, 1]
    else:
      raise ValueError(f'Invalid key: {k}')
    scattered_v = scatter_nd(  # [R, S, K]
        tensor=tensor,
        updates=intersect[k],  # [Ro, S]
        dim2known_indices=dim2known_indices)
    # Convert the batch dimension to a known dimension.
    # For some reason `scattered_z_vals` becomes [R, ?]. We need to explicitly
    # reshape it with `n_samples`.
    if k == 'z_vals':
      scattered_v = tf.reshape(scattered_v, (n_rays, n_samples))  # [R, S]
    else:
      scattered_v = tf.reshape(
          scattered_v, (n_rays, n_samples, tensor.shape[2]))  # [R, S, K]
    scattered_results[k] = scattered_v
  return scattered_results
804
805
def combine_results(name2results, keys):
  """Combines network outputs.

  Merges per-object results along the samples dimension and sorts the merged
  samples by depth so compositing can proceed in depth order.

  Args:
    name2results: Dict. For each object results, `z_vals` is required.
    keys: [str]. A list of keys to combine results over.

  Returns:
    results: Dict. Combined results, each sorted by z value along axis 1.
  """
  # Collect z values across all objects.
  z_vals_list = []
  for _, results in name2results.items():
    z_vals_list.append(results['z_vals'])

  # Concatenate lists of object results into a single tensor.
  z_vals = tf.concat(z_vals_list, axis=-1)  # [R, S*O]

  # Compute the argsort indices.
  z_argsort_indices = tf.argsort(z_vals, -1)  # [R, S*O]
  # Pair each ray index with its sorted sample index for gather_nd.
  n_rays, n_samples = tf.shape(z_vals)[0], tf.shape(z_vals)[1]
  gather_indices = tf.range(n_rays)[:, None]  # [R, 1]
  gather_indices = tf.tile(gather_indices, [1, n_samples])  # [R, S]
  gather_indices = tf.concat(
      [gather_indices[Ellipsis, None], z_argsort_indices[Ellipsis, None]], axis=-1)

  results = {}
  for k in keys:
    if k == 'z_vals':
      v_combined = z_vals
    else:
      v_list = [r[k] for r in name2results.values()]
      v_combined = tf.concat(v_list, axis=1)  # [R, S*O, K]

    # Sort the tensors.
    v_sorted = tf.gather_nd(  # [R, S, K]
        params=v_combined,  # [R, S, K]
        indices=gather_indices)  # [R, S, 2]
    results[k] = v_sorted
  return results
846
847
def compose_outputs(results, light_rgb, white_bkgd):
  """Placeholder for output composition; currently unimplemented.

  Args:
    results: Unused.
    light_rgb: Unused.
    white_bkgd: Unused.

  Returns:
    -1, a sentinel indicating composition is not implemented.
  """
  del results, light_rgb, white_bkgd  # Unused.
  return -1
853
854
855