# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NeRF helpers."""
import math
import tensorflow as tf
from osf import box_utils
from osf import ray_utils
from osf import scene_utils


def default_ray_sampling(ray_batch, n_samples, perturb, lindisp):
  """Default NeRF ray sampling.

  This function takes a batch of rays and returns points along each ray that
  should be evaluated by the coarse NeRF model.

  Args:
    ray_batch: Array of shape [batch_size, ...]. All information necessary
      for sampling along a ray, including: ray origin, ray direction, min dist,
      max dist, and unit-magnitude viewing direction.
    n_samples: Number of samples to take on each ray.
    perturb: float. If > 0, apply stratified jitter to the sample positions
      within each depth interval.
    lindisp: bool. If True, sample linearly in inverse depth rather than in
      depth.

  Returns:
    z_vals: Positions of the sampled points on each ray as scalar depths:
      [n_rays, n_samples].
    pts: Actual sampled points in 3D: [n_rays, n_samples, 3].
  """
  # batch size
  n_rays = tf.shape(ray_batch)[0]

  # Extract ray origin, direction.
  rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [n_rays, 3] each

  # Extract lower, upper bound for ray distance.
  bounds = tf.reshape(ray_batch[Ellipsis, 6:8], [-1, 1, 2])
  near, far = bounds[Ellipsis, 0], bounds[Ellipsis, 1]  # [R, 1]

  # Decide where to sample along each ray. Unless perturbed below, all rays
  # are sampled at the same depths.
  t_vals = tf.linspace(0., 1., n_samples)
  if not lindisp:
    # Space integration times linearly between 'near' and 'far'. Same
    # integration points will be used for all rays.
    z_vals = near * (1. - t_vals) + far * (t_vals)
  else:
    tf.debugging.assert_greater(near, 0.)
    tf.debugging.assert_greater(far, 0.)
    # Sample linearly in inverse depth (disparity).
    z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals))
  z_vals = tf.broadcast_to(z_vals, [n_rays, n_samples])

  # Perturb sampling time along each ray.
  if perturb > 0.:
    # get intervals between samples
    mids = .5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1])
    upper = tf.concat([mids, z_vals[Ellipsis, -1:]], -1)
    lower = tf.concat([z_vals[Ellipsis, :1], mids], -1)
    # stratified samples in those intervals
    t_rand = tf.random.uniform(tf.shape(z_vals))
    z_vals = lower + (upper - lower) * t_rand
  # Points in space to evaluate model at.
  pts = rays_o[Ellipsis, None, :] + rays_d[Ellipsis, None, :] * z_vals[Ellipsis, :, None]
  tf.debugging.assert_equal(tf.shape(z_vals)[0], tf.shape(pts)[0])
  return z_vals, pts
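

# Example (not part of the original file): a minimal sketch of calling
# `default_ray_sampling`. Only the first 8 ray-batch columns are read here
# (origin, direction, near, far); the full column packing used elsewhere in
# this codebase is assumed to be handled by `ray_utils.create_ray_batch`.
def _example_default_ray_sampling():
  rays_o = tf.zeros([4, 3])  # 4 rays at the origin
  rays_d = tf.tile(tf.constant([[0., 0., 1.]]), [4, 1])  # looking down +z
  near = tf.fill([4, 1], 2.0)
  far = tf.fill([4, 1], 6.0)
  ray_batch = tf.concat([rays_o, rays_d, near, far], axis=-1)  # [4, 8]
  z_vals, pts = default_ray_sampling(
      ray_batch, n_samples=64, perturb=1., lindisp=False)
  return z_vals, pts  # [4, 64] depths in [2, 6], [4, 64, 3] points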


def raw2alpha(raw, dists, act_fn=tf.nn.relu):
  """Maps raw density and sample spacing to opacity: 1 - exp(-sigma * delta)."""
  return 1.0 - tf.exp(-act_fn(raw) * dists)


def compute_alpha(z_vals,
                  raw_alpha,
                  raw_noise_std,
                  last_dist_method='infinity'):
  """Normalizes raw sigma predictions from the network into alpha values."""
  # Compute 'distance' (in time) between each integration time along a ray.
  dists = z_vals[Ellipsis, 1:] - z_vals[Ellipsis, :-1]

  # The 'distance' from the last integration time is infinity.
  if last_dist_method == 'infinity':
    dists = tf.concat(
        [dists, tf.broadcast_to([1e10], tf.shape(dists[Ellipsis, :1]))],
        axis=-1)  # [n_rays, n_samples]
  elif last_dist_method == 'last':
    dists = tf.concat([dists, dists[Ellipsis, -1:]], axis=-1)

  # Multiply each distance by the norm of its corresponding direction ray
  # to convert to real world distance (accounts for non-unit directions).
  # dists = dists * tf.linalg.norm(rays_d[..., None, :], axis=-1)

  raw_alpha = tf.squeeze(raw_alpha, axis=-1)  # [n_rays, n_samples]

  # Add noise to the model's density predictions. Can be used to
  # regularize the network during training (prevents floater artifacts).
  noise = 0.
  if raw_noise_std > 0.:
    noise = tf.random.normal(tf.shape(raw_alpha)) * raw_noise_std

  # Convert from raw values to alpha values in [0, 1], i.e. the density of
  # each sample along each ray. Higher values imply a higher likelihood of
  # the ray being absorbed at this point.
  alpha = raw2alpha(raw_alpha + noise, dists)  # [n_rays, n_samples]
  return alpha[Ellipsis, None]
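

# Example (not part of the original file): a toy `compute_alpha` call. With
# sigma = 1 and unit spacing, alpha = 1 - exp(-1) ~ 0.632 for every sample
# except the last, whose distance is treated as infinite (alpha -> 1).
def _example_compute_alpha():
  z_vals = tf.linspace(0., 3., 4)[None, :]  # [1, 4]
  raw_alpha = tf.ones([1, 4, 1])  # raw sigma predictions
  return compute_alpha(z_vals, raw_alpha, raw_noise_std=0.)  # [1, 4, 1]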


def broadcast_samples_dim(x, target):
  """Broadcasts the shape of `x` to match `target`.

  Given `target` of shape [N, S, M] and `x` of shape [N, K],
  broadcast `x` to have shape [N, S, K].

  Args:
    x: array to broadcast.
    target: array to match.

  Returns:
    x broadcast to shape [N, S, K].
  """
  s = target.shape[1]
  result = tf.expand_dims(x, axis=1)  # [N, 1, K]
  result_tile = tf.tile(result, [1, s, 1])  # [N, S, K]
  return result_tile
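

# Example (not part of the original file): broadcasting per-ray viewing
# directions across the sample dimension so they align with sampled points.
def _example_broadcast_samples_dim():
  viewdirs = tf.ones([4, 3])  # [N, K] per-ray directions
  pts = tf.zeros([4, 64, 3])  # [N, S, M] per-sample points
  return broadcast_samples_dim(viewdirs, pts)  # [4, 64, 3]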


def sample_pdf(bins, weights, n_samples, det=False):
  """Draws samples from a piecewise-constant PDF via inverse-CDF sampling."""
  # Get pdf
  weights += 1e-5  # prevent nans
  pdf = weights / tf.reduce_sum(weights, -1, keepdims=True)
  cdf = tf.cumsum(pdf, -1)
  cdf = tf.concat([tf.zeros_like(cdf[Ellipsis, :1]), cdf], -1)

  # Take uniform samples
  u_shape = tf.concat([tf.shape(cdf)[:-1], tf.constant([n_samples])], axis=0)
  if det:
    u = tf.linspace(0., 1., n_samples)
    u = tf.broadcast_to(u, u_shape)
  else:
    u = tf.random.uniform(u_shape)

  # Invert CDF
  inds = tf.searchsorted(cdf, u, side='right')
  below = tf.maximum(0, inds - 1)
  above = tf.minimum(tf.shape(cdf)[-1] - 1, inds)
  inds_g = tf.stack([below, above], -1)
  cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(tf.shape(inds_g)) - 2)
  bins_g = tf.gather(
      bins, inds_g, axis=-1, batch_dims=len(tf.shape(inds_g)) - 2)

  denom = (cdf_g[Ellipsis, 1] - cdf_g[Ellipsis, 0])
  denom = tf.where(denom < 1e-5, tf.ones_like(denom), denom)
  t = (u - cdf_g[Ellipsis, 0]) / denom
  samples = bins_g[Ellipsis, 0] + t * (bins_g[Ellipsis, 1] - bins_g[Ellipsis, 0])

  return samples
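

# Example (not part of the original file): inverse-CDF sampling from a
# piecewise-constant PDF. Nearly all of the weight sits in the middle bin, so
# the deterministic samples concentrate between bins[1] and bins[2].
def _example_sample_pdf():
  bins = tf.constant([[0., 1., 2., 3.]])  # [1, 4] bin edges
  weights = tf.constant([[0., 1., 0.]])  # [1, 3] per-bin weights
  return sample_pdf(bins, weights, n_samples=8, det=True)  # [1, 8]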


def default_ray_sampling_fine(ray_batch,
                              z_vals,
                              weights,
                              n_samples,
                              perturb,
                              keep_original_points=True,
                              compute_fine_indices=False):
  """Ray sampling of fine points."""
  n_orig_samples = z_vals.shape[1]

  # Extract ray origin, direction.
  rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [R, 3] each

  # Obtain additional integration times to evaluate based on the weights
  # assigned to colors in the coarse model.
  z_vals_mid = .5 * (z_vals[Ellipsis, 1:] + z_vals[Ellipsis, :-1])
  z_samples = sample_pdf(
      z_vals_mid, weights[Ellipsis, 1:-1], n_samples, det=(perturb == 0.))
  z_samples = tf.stop_gradient(z_samples)

  # Obtain all points to evaluate color, density at.
  if keep_original_points:
    z_list = [z_vals, z_samples]  # [R, S], [R, I]
  else:
    z_list = [z_samples]
  z_vals = tf.sort(tf.concat(z_list, -1), -1)  # [R, S + I]

  fine_indices = None
  if compute_fine_indices:
    # The last `n_samples` values represent the indices of the fine samples
    # in the final sorted set of `z_samples`.
    z_argsort_indices = tf.argsort(tf.concat(z_list, -1), -1)  # [R, S + I]
    fine_indices = z_argsort_indices[:, -n_samples:]  # [R, I]
    fine_indices = tf.reshape(fine_indices, [-1, n_samples])  # [R, I]
  pts = rays_o[Ellipsis, None, :] + \
      rays_d[Ellipsis, None, :] * z_vals[Ellipsis, :, None]  # [R, S + I, 3]

  # The inputs may have an unknown batch size, which leads to results where
  # the first two dimensions are unknown. Reshape so that the sample dimension
  # has a known, static size.
  n_total_samples = n_orig_samples + n_samples

  z_vals = tf.reshape(z_vals, [-1, n_total_samples])  # [R, S + I]
  z_samples = tf.reshape(z_samples, [-1, n_samples])  # [R, I]
  pts = tf.reshape(pts, [-1, n_total_samples, 3])  # [R, S + I, 3]

  tf.debugging.assert_equal(tf.shape(z_vals)[0], tf.shape(pts)[0])
  return z_vals, z_samples, pts, fine_indices
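

# Example (not part of the original file): a coarse-to-fine pass, run eagerly.
# The coarse weights peak mid-ray, so the fine samples concentrate around
# that depth.
def _example_coarse_to_fine():
  rays_o = tf.zeros([2, 3])
  rays_d = tf.tile(tf.constant([[0., 0., 1.]]), [2, 1])
  near = tf.zeros([2, 1])
  far = tf.fill([2, 1], 4.0)
  ray_batch = tf.concat([rays_o, rays_d, near, far], axis=-1)  # [2, 8]
  z_vals, _ = default_ray_sampling(
      ray_batch, n_samples=8, perturb=0., lindisp=False)
  weights = tf.one_hot(tf.fill([2], 4), depth=8)  # [2, 8], peak at sample 4
  z_all, z_fine, pts, _ = default_ray_sampling_fine(
      ray_batch, z_vals, weights, n_samples=16, perturb=0.)
  return z_all, z_fine, pts  # [2, 24], [2, 16], [2, 24, 3]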


def apply_intersect_mask_to_tensors(intersect_mask, tensors):
  """Keeps only the intersecting rays in each of the given tensors."""
  intersect_tensors = []
  for t in tensors:
    intersect_t = tf.boolean_mask(tensor=t, mask=intersect_mask)  # [Ro, ...]
    intersect_tensors.append(intersect_t)
  return intersect_tensors


def compute_object_intersect_tensors(name, ray_batch, scene_info, far,
                                     object2padding, swap_object_yz, **kwargs):
  """Computes intersecting rays."""
  rays_o = ray_utils.extract_slice_from_ray_batch(  # [R, 3]
      ray_batch=ray_batch,  # [R, M]
      key='origin')
  rays_d = ray_utils.extract_slice_from_ray_batch(  # [R, 3]
      ray_batch=ray_batch,  # [R, M]
      key='direction')
  rays_far = ray_utils.extract_slice_from_ray_batch(  # [R, 1]
      ray_batch=ray_batch,  # [R, M]
      key='far')
  rays_sid = ray_utils.extract_slice_from_ray_batch(  # [R, 1]
      ray_batch=ray_batch,  # [R, M]
      key='metadata')

  (
      box_dims,  # [R, 3], [R, 3], [R, 3, 3]
      box_center,
      box_rotation) = scene_utils.extract_object_boxes_for_scenes(
          name=name,
          scene_info=scene_info,
          sids=rays_sid,  # [R, 1]
          padding=object2padding[name],
          swap_yz=swap_object_yz,
          box_delta_t=kwargs['box_delta_t'])

  # Compute ray-bbox intersections.
  intersect_bounds, intersect_indices, intersect_mask = (  # [R', 2],[R',],[R,]
      box_utils.compute_ray_bbox_bounds_pairwise(
          rays_o=rays_o,  # [R, 3]
          rays_d=rays_d,  # [R, 3]
          rays_far=rays_far,  # [R, 1]
          box_length=box_dims[:, 0],  # [R,]
          box_width=box_dims[:, 1],  # [R,]
          box_height=box_dims[:, 2],  # [R,]
          box_center=box_center,  # [R, 3]
          box_rotation=box_rotation,  # [R, 3, 3]
          far_limit=far))

  # Apply the intersection mask to the ray batch.
  intersect_ray_batch = apply_intersect_mask_to_tensors(  # [R', M]
      intersect_mask=intersect_mask,  # [R,]
      tensors=[ray_batch])[0]  # [R, M]

  # Update the near and far bounds of the ray batch with the intersect bounds.
  intersect_ray_batch = ray_utils.update_ray_batch_bounds(  # [R', M]
      ray_batch=intersect_ray_batch,  # [R', M]
      bounds=intersect_bounds)  # [R', 2]
  return intersect_ray_batch, intersect_indices  # [R', M], [R', 1]


def compute_lightdirs(pts,
                      metadata,
                      scene_info,
                      use_lightdir_norm,
                      use_random_lightdirs,
                      light_pos,
                      light_origins=None):
  """Computes light directions."""
  n_rays = tf.shape(pts)[0]
  n_samples = tf.shape(pts)[1]
  if use_random_lightdirs:
    # Sample directions uniformly on the unit sphere; see
    # https://stackoverflow.com/questions/5408276/sampling-uniformly-distributed-random-points-inside-a-spherical-volume
    phi = tf.random.uniform(
        shape=[n_rays, n_samples], minval=0, maxval=2 * math.pi)  # [R, S]
    cos_theta = tf.random.uniform(
        shape=[n_rays, n_samples], minval=-1, maxval=1)  # [R, S]
    theta = tf.math.acos(cos_theta)  # [R, S]

    x = tf.math.sin(theta) * tf.math.cos(phi)
    y = tf.math.sin(theta) * tf.math.sin(phi)
    z = tf.math.cos(theta)

    light_origins = tf.zeros([n_rays, n_samples, 3],
                             dtype=tf.float32)  # [R, S, 3]
    light_dst = tf.concat([x[Ellipsis, None], y[Ellipsis, None], z[Ellipsis, None]],
                          axis=-1)  # [R, S, 3]
    # Transform ray origin/dst to points.
    light_origins = light_origins + pts
    light_dst = light_dst + pts
    light_origins_flat = tf.reshape(light_origins, [-1, 3])  # [RS, 3]
    light_dst_flat = tf.reshape(light_dst, [-1, 3])  # [RS, 3]
    metadata_tiled = tf.tile(metadata[:, None, :],
                             [1, n_samples, 1])  # [R, S, 1]
    metadata_flat = tf.reshape(metadata_tiled, [-1, 1])  # [RS, 1]
    # We want points to be the origin of the light rays in the batch because
    # we will be computing radiance along that ray.
    light_ray_batch = ray_utils.create_ray_batch(  # [RS, M]
        rays_o=light_origins_flat,  # [RS, 3]
        rays_dst=light_dst_flat,  # [RS, 3]
        rays_sid=metadata_flat)  # [RS, 1]
    lightdirs = light_dst
    return light_ray_batch, lightdirs

  # Compute light origins using scene info and control metadata.
  if light_origins is None:
    light_origins = scene_utils.extract_light_positions_for_sids(
        sids=metadata, scene_info=scene_info, light_pos=light_pos)

    # Make sure to use tf.shape instead of light_origins.shape:
    # https://github.com/tensorflow/models/issues/6245
    light_origins = tf.reshape(  # [n_rays, 1, 3]
        light_origins,
        [tf.shape(light_origins)[0], 1,
         tf.shape(light_origins)[1]])

  # Compute the incoming light direction for each point.
  # Note that we use points in the world coordinate space because we currently
  # assume that the light position is in world coordinate space.
  lightdirs = pts - light_origins  # [n_rays, n_samples, 3]

  # Make all directions unit magnitude.
  if use_lightdir_norm:
    lightdirs_norm = tf.linalg.norm(
        lightdirs, axis=-1, keepdims=True)  # [n_rays, n_samples, 1]
    lightdirs = tf.math.divide_no_nan(lightdirs,
                                      lightdirs_norm)  # [n_rays, n_samples, 3]
    lightdirs = tf.reshape(lightdirs, [n_rays, n_samples, 3])
    lightdirs = tf.cast(lightdirs, dtype=tf.float32)

  light_origins = tf.tile(light_origins, [1, n_samples, 1])
  metadata = tf.tile(metadata[:, None, :], [1, n_samples, 1])
  light_origins_flat = tf.reshape(light_origins, [-1, 3])  # [RS, 3]
  pts_flat = tf.reshape(pts, [-1, 3])  # [RS, 3]
  metadata_flat = tf.reshape(metadata, [-1, 1])  # [RS, 1]
  light_ray_batch = ray_utils.create_ray_batch(  # [RS, M]
      rays_o=light_origins_flat,  # [RS, 3]
      rays_dst=pts_flat,  # [RS, 3]
      rays_sid=metadata_flat)  # [RS, 1]
  return light_ray_batch, lightdirs
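

# Example (not part of the original file): the uniform-sphere direction
# sampling used by the random branch above, in isolation. Drawing
# phi ~ U[0, 2*pi) and cos(theta) ~ U[-1, 1] yields directions distributed
# uniformly over the unit sphere.
def _example_uniform_sphere_dirs(n=1024):
  phi = tf.random.uniform([n], 0., 2. * math.pi)
  cos_theta = tf.random.uniform([n], -1., 1.)
  sin_theta = tf.sqrt(1. - cos_theta**2)
  dirs = tf.stack(
      [sin_theta * tf.cos(phi), sin_theta * tf.sin(phi), cos_theta], axis=-1)
  return dirs  # [n, 3] unit vectors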


def compute_view_light_dirs(ray_batch, pts, scene_info, use_viewdirs,
                            use_lightdir_norm, use_random_lightdirs, light_pos):
  """Computes viewing and lighting directions."""
  viewdirs = ray_batch[:, 8:11]  # [R, 3]
  metadata = ray_utils.extract_slice_from_ray_batch(
      ray_batch, key='metadata', use_viewdirs=use_viewdirs)  # [R, 1]
  viewdirs = broadcast_samples_dim(x=viewdirs, target=pts)  # [R, S, 3]
  light_ray_batch, lightdirs = compute_lightdirs(
      pts=pts,
      metadata=metadata,
      scene_info=scene_info,
      use_lightdir_norm=use_lightdir_norm,
      use_random_lightdirs=use_random_lightdirs,
      light_pos=light_pos)  # [R, S, 3]
  return viewdirs, light_ray_batch, lightdirs


def network_query_fn_helper(pts, ray_batch, network, network_query_fn, viewdirs,
                            lightdirs, use_viewdirs, use_lightdirs):
  """Queries the NeRF network."""
  if not use_viewdirs:
    viewdirs = None
  if not use_lightdirs:
    lightdirs = None
  # Extract unit-normalized viewing direction.
  # [n_rays, 3]
  # viewdirs = ray_batch[:, 8:11] if use_viewdirs else None

  # Extract additional per-ray metadata.
  # [n_rays, metadata_channels]
  rays_data = ray_utils.extract_slice_from_ray_batch(
      ray_batch, key='example_id', use_viewdirs=use_viewdirs)

  # Query NeRF for the corresponding densities for the light points.
  raw = network_query_fn(pts, viewdirs, lightdirs, rays_data, network)
  return raw


def network_query_fn_helper_nodirs(pts,
                                   ray_batch,
                                   network,
                                   network_query_fn,
                                   use_viewdirs,
                                   use_lightdirs,
                                   use_lightdir_norm,
                                   scene_info,
                                   use_random_lightdirs,
                                   light_origins=None,
                                   **kwargs):
  """Same as network_query_fn_helper, but without input directions."""
  # Extract unit-normalized viewing direction.
  if use_viewdirs:
    viewdirs = ray_batch[:, 8:11]  # [R, 3]
    viewdirs = broadcast_samples_dim(x=viewdirs, target=pts)  # [R, S, 3]
  else:
    viewdirs = None

  # Compute the light directions.
  # if use_lightdirs:
  light_ray_batch, lightdirs = compute_lightdirs(  # [R, S, 3]
      pts=pts,
      metadata=ray_utils.extract_slice_from_ray_batch(
          ray_batch, key='metadata', use_viewdirs=use_viewdirs),
      scene_info=scene_info,
      use_lightdir_norm=use_lightdir_norm,
      use_random_lightdirs=use_random_lightdirs,
      light_pos=kwargs['light_pos'],
      light_origins=light_origins)
  # else:
  #   light_ray_batch = None
  #   lightdirs = None

  # Extract additional per-ray metadata.
  rays_data = ray_utils.extract_slice_from_ray_batch(
      ray_batch, key='example_id', use_viewdirs=use_viewdirs)

  # Query NeRF for the corresponding densities for the light points.
  raw = network_query_fn(pts, viewdirs, lightdirs, rays_data, network)
  return light_ray_batch, raw


def create_w2o_transformations_tensors(name, scene_info, ray_batch,
                                       use_viewdirs, box_delta_t):
  """Creates transformation tensors from world to object space."""
  metadata = ray_utils.extract_slice_from_ray_batch(
      ray_batch, key='metadata', use_viewdirs=use_viewdirs)  # [R, 1]
  w2o_rt_per_scene, w2o_r_per_scene = (
      scene_utils.extract_w2o_transformations_per_scene(
          name=name, scene_info=scene_info, box_delta_t=box_delta_t))
  w2o_rt = tf.gather_nd(  # [R, 4, 4]
      params=w2o_rt_per_scene,  # [N_scenes, 4, 4]
      indices=metadata)  # [R, 1]
  w2o_r = tf.gather_nd(  # [R, 4, 4]
      params=w2o_r_per_scene,  # [N_scenes, 4, 4]
      indices=metadata)  # [R, 1]
  return w2o_rt, w2o_r


def apply_batched_transformations(inputs, transformations):
  """Batched transformation of inputs.

  Args:
    inputs: List of [R, S, 3] tensors.
    transformations: [R, 4, 4] homogeneous transformation matrices.

  Returns:
    transformed_inputs: List of [R, S, 3] transformed tensors.
  """
  transformed_inputs = []
  for x in inputs:
    n_samples = tf.shape(x)[1]
    homog_transformations = tf.expand_dims(
        input=transformations, axis=1)  # [R, 1, 4, 4]
    homog_transformations = tf.tile(homog_transformations,
                                    [1, n_samples, 1, 1])  # [R, S, 4, 4]
    homog_component = tf.ones_like(x)[Ellipsis, 0:1]  # [R, S, 1]
    homog_x = tf.concat([x, homog_component], axis=-1)  # [R, S, 4]
    homog_x = tf.expand_dims(input=homog_x, axis=2)  # [R, S, 1, 4]
    transformed_x = tf.matmul(homog_x,
                              tf.transpose(homog_transformations,
                                           (0, 1, 3, 2)))  # [R, S, 1, 4]
    transformed_x = transformed_x[Ellipsis, 0, :3]  # [R, S, 3]
    transformed_inputs.append(transformed_x)
  return transformed_inputs
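

# Example (not part of the original file): applying a pure translation with
# `apply_batched_transformations`. Every input point sits at the origin, so
# every output point should equal the translation (1, 2, 3).
def _example_apply_translation():
  pts = tf.zeros([2, 5, 3])  # [R, S, 3]
  translation = tf.constant([[1.], [2.], [3.]])  # [3, 1]
  transform = tf.concat([
      tf.concat([tf.eye(3), translation], axis=1),  # [3, 4]
      tf.constant([[0., 0., 0., 1.]]),  # homogeneous row
  ], axis=0)  # [4, 4]
  transforms = tf.tile(transform[None], [2, 1, 1])  # [R, 4, 4]
  return apply_batched_transformations([pts], transforms)[0]  # [2, 5, 3]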


def compute_object_inputs(name, ray_batch, pts, scene_info,
                          use_random_lightdirs, **kwargs):
  """Computes inputs to object networks."""
  # Extract viewing and lighting directions.
  # [Ro, S, 3]
  object_viewdirs, light_ray_batch, object_lightdirs = compute_view_light_dirs(
      ray_batch=ray_batch,
      pts=pts,
      scene_info=scene_info,
      use_viewdirs=kwargs['use_viewdirs'],
      use_lightdir_norm=kwargs['use_lightdir_norm'],
      use_random_lightdirs=use_random_lightdirs,
      light_pos=kwargs['light_pos'])

  # Transform points and optionally directions from world to canonical
  # coordinate frame.
  w2o_rt, w2o_r = create_w2o_transformations_tensors(  # [Ro, 4, 4]
      name=name,
      scene_info=scene_info,
      ray_batch=ray_batch,
      use_viewdirs=kwargs['use_viewdirs'],
      box_delta_t=kwargs['box_delta_t'])
  object_pts = apply_batched_transformations(
      inputs=[pts],
      transformations=w2o_rt,
  )[0]
  if kwargs['use_transform_dirs']:
    # pylint: disable=unbalanced-tuple-unpacking
    [object_viewdirs, object_lightdirs] = apply_batched_transformations(
        inputs=[object_viewdirs, object_lightdirs], transformations=w2o_r)
    # pylint: enable=unbalanced-tuple-unpacking
  return object_pts, object_viewdirs, light_ray_batch, object_lightdirs


def normalize_rgb(raw_rgb, scaled_sigmoid):
  """Extracts RGB of each sample position along each ray."""
  rgb = tf.math.sigmoid(raw_rgb)  # [n_rays, n_samples, 3]
  if scaled_sigmoid:
    rgb = 1.2 * (rgb - 0.5) + 0.5  # [n_rays, n_samples, 3]
  return rgb


def normalize_raw(raw, z_vals, scaled_sigmoid, raw_noise_std, last_dist_method):
  """Normalizes raw outputs of the network."""
  # Compute the alpha of each sample along each ray. Downstream, a cumprod()
  # over (1 - alpha) is used to express the idea of the ray not having
  # reflected up to a given sample yet.
  # [n_rays, n_samples]
  alpha = compute_alpha(
      z_vals=z_vals,
      raw_alpha=raw['alpha'],
      raw_noise_std=raw_noise_std,
      last_dist_method=last_dist_method)
  normalized = {
      'rgb': normalize_rgb(raw_rgb=raw['rgb'], scaled_sigmoid=scaled_sigmoid),
      'alpha': alpha
  }
  return normalized


def run_sparse_network(name, network, intersect_z_vals, intersect_pts,
                       intersect_ray_batch, use_random_lightdirs, **kwargs):
  """Runs a single network."""
  if name.startswith('bkgd'):
    intersect_light_ray_batch, intersect_raw = network_query_fn_helper_nodirs(
        network=network,
        pts=intersect_pts,  # [R, S, 3]
        ray_batch=intersect_ray_batch,  # [R, M]
        use_random_lightdirs=use_random_lightdirs,
        **kwargs)
  else:
    (object_intersect_pts, object_intersect_viewdirs, intersect_light_ray_batch,
     object_intersect_lightdirs) = compute_object_inputs(
         name=name,
         ray_batch=intersect_ray_batch,
         pts=intersect_pts,
         use_random_lightdirs=use_random_lightdirs,
         **kwargs)

    # Query the object NeRF.
    intersect_raw = network_query_fn_helper(
        pts=object_intersect_pts,
        ray_batch=intersect_ray_batch,
        network=network,
        network_query_fn=kwargs['network_query_fn'],
        viewdirs=object_intersect_viewdirs,
        lightdirs=object_intersect_lightdirs,
        use_viewdirs=kwargs['use_viewdirs'],
        use_lightdirs=kwargs['use_lightdirs'])

  # Normalize the raw network outputs for the intersecting points.
  normalized_raw = normalize_raw(  # [Ro, S, 4]
      raw=intersect_raw,  # [Ro, S, 4]
      z_vals=intersect_z_vals,  # [Ro, S]
      scaled_sigmoid=kwargs['scaled_sigmoid'],
      raw_noise_std=kwargs['raw_noise_std'],
      last_dist_method=kwargs['last_dist_method'])
  return intersect_light_ray_batch, normalized_raw


def compute_transmittance(alpha):
  """Computes transmittance from (normalized) alpha values.

  Args:
    alpha: [R, S]

  Returns:
    t: [R, S]
  """
  # Compute the accumulated transmittance along the ray at each point.
  print(f'[compute_transmittance] alpha.shape: {alpha.shape}')
  # TODO(guom): fix this
  t = 1. - alpha
  return t


def compute_weights(normalized_alpha):
  """Computes per-sample compositing weights alpha_i * T_i."""
  trans = compute_transmittance(normalized_alpha)
  weights = normalized_alpha * trans
  return trans, weights
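

# Note (not part of the original file): `compute_transmittance` above returns
# 1 - alpha (see the TODO). The standard NeRF transmittance that the cumprod
# comment in `normalize_raw` alludes to would be an exclusive cumulative
# product over the sample axis -- a sketch, not the author's intended fix:
def _standard_nerf_transmittance(alpha):
  # T_i = prod_{j < i} (1 - alpha_j); exclusive=True gives the first sample
  # a transmittance of 1. The 1e-10 guards against zeros in the product.
  return tf.math.cumprod(1. - alpha + 1e-10, axis=-1, exclusive=True)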


def run_single_object(name, ray_batch, use_random_lightdirs, **kwargs):
  """Runs and generates predictions for a single object.

  Args:
    name: The name of the object to run.
    ray_batch: [R, M] tf.float32. A batch of rays.
    use_random_lightdirs: Whether to sample random light directions.
    **kwargs: Additional arguments.

  Returns:
    intersect_0: Dict. Coarse-stage results.
    intersect: Dict. Fine-stage results.
  """
  # Compute intersection rays and indices.
  if name.startswith('bkgd'):
    intersect_ray_batch = ray_batch
    intersect_indices = None
  else:
    intersect_ray_batch, intersect_indices = compute_object_intersect_tensors(
        name=name, ray_batch=ray_batch, **kwargs)
  # Run coarse stage.
  intersect_z_vals_0, intersect_pts_0 = default_ray_sampling(
      ray_batch=intersect_ray_batch,
      n_samples=kwargs['n_samples'],
      perturb=kwargs['perturb'],
      lindisp=kwargs['lindisp'])

  intersect_light_ray_batch_0, normalized_raw_0 = run_sparse_network(
      name=name,
      network=kwargs['name2model'][name],
      intersect_z_vals=intersect_z_vals_0,
      intersect_pts=intersect_pts_0,
      intersect_ray_batch=intersect_ray_batch,
      use_random_lightdirs=use_random_lightdirs,
      **kwargs)

  # Run fine stage.
  if kwargs['N_importance'] > 0:
    _, intersect_weights_0 = compute_weights(
        normalized_alpha=normalized_raw_0['alpha'][Ellipsis, 0])  # [Ro, S]
    # Generate fine samples using weights from the coarse stage.
    intersect_z_vals, _, intersect_pts, _ = default_ray_sampling_fine(
        ray_batch=intersect_ray_batch,  # [Ro, M]
        z_vals=intersect_z_vals_0,  # [Ro, S]
        weights=intersect_weights_0,  # [Ro, S]
        n_samples=kwargs['N_importance'],
        perturb=kwargs['perturb'])

    # Run the networks for all the objects.
    intersect_light_ray_batch, normalized_raw = run_sparse_network(
        name=name,
        network=kwargs['name2model'][name],
        intersect_z_vals=intersect_z_vals,
        intersect_pts=intersect_pts,
        intersect_ray_batch=intersect_ray_batch,
        use_random_lightdirs=use_random_lightdirs,
        **kwargs)
  else:
    # Without fine sampling, fall back to the coarse results so the dicts
    # below are always well-defined.
    intersect_light_ray_batch = intersect_light_ray_batch_0
    intersect_z_vals, intersect_pts = intersect_z_vals_0, intersect_pts_0
    normalized_raw = normalized_raw_0

  intersect_0 = {
      'ray_batch': intersect_ray_batch,  # [R, M]
      'light_ray_batch': intersect_light_ray_batch_0,  # [R, S, M]
      'indices': intersect_indices,
      'z_vals': intersect_z_vals_0,
      'pts': intersect_pts_0,
      'normalized_rgb': normalized_raw_0['rgb'],
      'normalized_alpha': normalized_raw_0['alpha'],
  }
  intersect = {
      'ray_batch': intersect_ray_batch,  # [R, M]
      'light_ray_batch': intersect_light_ray_batch,  # [R, S, M]
      'indices': intersect_indices,
      'z_vals': intersect_z_vals,
      'pts': intersect_pts,
      'normalized_rgb': normalized_raw['rgb'],
      'normalized_alpha': normalized_raw['alpha'],
  }
  return intersect_0, intersect


def create_scatter_indices_for_dim(dim, shape, indices=None):
  """Creates scatter indices for a given dimension."""
  dim_size = shape[dim]
  n_dims = len(shape)
  reshape = [1] * n_dims
  reshape[dim] = -1

  if indices is None:
    indices = tf.range(dim_size, dtype=tf.int32)  # [dim_size,]

  indices = tf.reshape(indices, reshape)  # [1, ..., dim_size, ..., 1]

  tf.debugging.assert_equal(tf.shape(indices)[dim], shape[dim])
  indices = tf.broadcast_to(indices, shape)  # [Ro, S, 1] or [Ro, S, C, 1]

  indices = tf.cast(indices, dtype=tf.int32)
  return indices


def create_scatter_indices(updates, dim2known_indices):
  """Creates scatter indices."""
  updates_expanded = tf.expand_dims(updates, -1)  # [Ro, S, 1] or [Ro, S, C, 1]
  target_shape = tf.shape(updates_expanded)
  n_dims = len(tf.shape(updates))  # 2 or 3

  dim_indices_list = []
  for dim in range(n_dims):
    indices = None
    if dim in dim2known_indices:
      indices = dim2known_indices[dim]
    dim_indices = create_scatter_indices_for_dim(  # [Ro, S, C, 1]
        dim=dim,
        shape=target_shape,  # [Ro, S, 1] or [Ro, S, C, 1]
        indices=indices)  # [Ro,]
    dim_indices_list.append(dim_indices)
  scatter_indices = tf.concat(dim_indices_list, axis=-1)  # [Ro, S, C, 3]
  return scatter_indices
739

740

741
def scatter_nd(tensor, updates, dim2known_indices):
742
  scatter_indices = create_scatter_indices(  # [Ro, S, C, 3]
743
      updates=updates,  # [Ro, S]
744
      dim2known_indices=dim2known_indices)  # [Ro,]
745
  scattered_tensor = tf.tensor_scatter_nd_update(
746
      tensor=tensor,  # [R, S, C]
747
      indices=scatter_indices,  # [Ro, S, C, 3]
748
      updates=updates)  # [Ro, S, C]
749
  return scattered_tensor
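

# Example (not part of the original file): scattering values for two
# intersecting rays (indices 1 and 3) back into a zero tensor over all rays.
def _example_scatter_nd():
  full = tf.zeros([4, 2, 3])  # [R, S, C]
  updates = tf.ones([2, 2, 3])  # [Ro, S, C] values for the intersecting rays
  intersect_indices = tf.constant([[1], [3]])  # [Ro, 1]
  return scatter_nd(full, updates, {0: intersect_indices})  # rows 1, 3 -> 1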


def scatter_results(intersect, n_rays, keys):
  """Scatters intersecting ray results into the original set of rays.

  Args:
    intersect: Dict. Intersecting values.
    n_rays: int or tf.int32. Total number of rays.
    keys: [str]. List of keys to scatter.

  Returns:
    scattered_results: Dict. Scattered results.
  """
  # We use `None` to indicate that the intersecting set of rays is equivalent
  # to the full set of rays, so we are done.
  intersect_indices = intersect['indices']
  if intersect_indices is None:
    return {k: intersect[k] for k in keys}

  scattered_results = {}
  n_samples = intersect['z_vals'].shape[1]
  dim2known_indices = {0: intersect_indices}  # [R', 1]
  for k in keys:
    if k == 'z_vals':
      tensor = tf.random.uniform((n_rays, n_samples),
                                 dtype=tf.float32)  # [R, S]
    elif k == 'pts':
      tensor = tf.cast(  # [R, S, 3]
          tf.fill((n_rays, n_samples, 3), 1000.0),
          dtype=tf.float32)
    elif 'rgb' in k:
      tensor = tf.zeros((n_rays, n_samples, 3), dtype=tf.float32)  # [R, S, 3]
    elif 'alpha' in k:
      tensor = tf.zeros((n_rays, n_samples, 1), dtype=tf.float32)  # [R, S, 1]
    else:
      raise ValueError(f'Invalid key: {k}')
    scattered_v = scatter_nd(  # [R, S, K]
        tensor=tensor,
        updates=intersect[k],  # [Ro, S]
        dim2known_indices=dim2known_indices)
    # Convert the batch dimension to a known dimension.
    # For some reason `scattered_z_vals` becomes [R, ?]. We need to explicitly
    # reshape it with `n_samples`.
    if k == 'z_vals':
      scattered_v = tf.reshape(scattered_v, (n_rays, n_samples))  # [R, S]
    else:
      scattered_v = tf.reshape(
          scattered_v, (n_rays, n_samples, tensor.shape[2]))  # [R, S, K]
    scattered_results[k] = scattered_v
  return scattered_results


def combine_results(name2results, keys):
  """Combines network outputs.

  Args:
    name2results: Dict. For each object's results, `z_vals` is required.
    keys: [str]. A list of keys to combine results over.

  Returns:
    results: Dict. Combined results.
  """
  # Collect z values across all objects.
  z_vals_list = []
  for _, results in name2results.items():
    z_vals_list.append(results['z_vals'])

  # Concatenate lists of object results into a single tensor.
  z_vals = tf.concat(z_vals_list, axis=-1)  # [R, S*O]

  # Compute the argsort indices.
  z_argsort_indices = tf.argsort(z_vals, -1)  # [R, S*O]
  n_rays, n_samples = tf.shape(z_vals)[0], tf.shape(z_vals)[1]
  gather_indices = tf.range(n_rays)[:, None]  # [R, 1]
  gather_indices = tf.tile(gather_indices, [1, n_samples])  # [R, S]
  gather_indices = tf.concat(
      [gather_indices[Ellipsis, None], z_argsort_indices[Ellipsis, None]], axis=-1)

  results = {}
  for k in keys:
    if k == 'z_vals':
      v_combined = z_vals
    else:
      v_list = [r[k] for r in name2results.values()]
      v_combined = tf.concat(v_list, axis=1)  # [R, S*O, K]

    # Sort the tensors.
    v_sorted = tf.gather_nd(  # [R, S, K]
        params=v_combined,  # [R, S, K]
        indices=gather_indices)  # [R, S, 2]
    results[k] = v_sorted
  return results
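

# Example (not part of the original file): merging two objects' per-ray
# samples into a single depth-sorted set with `combine_results`.
def _example_combine_results():
  name2results = {
      'obj1': {'z_vals': tf.constant([[1., 3.]]), 'rgb': tf.ones([1, 2, 3])},
      'obj2': {'z_vals': tf.constant([[2., 4.]]), 'rgb': tf.zeros([1, 2, 3])},
  }
  combined = combine_results(name2results, keys=['z_vals', 'rgb'])
  return combined  # z_vals: [[1., 2., 3., 4.]]; rgb rows interleaved to match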


def compose_outputs(results, light_rgb, white_bkgd):
  """Placeholder for output composition; currently unimplemented."""
  del results
  del light_rgb
  del white_bkgd
  return -1