google-research
589 lines · 21.1 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Evaluation script for mipNeRF360."""
17
18import functools
19import gc
20from os import path
21import sys
22import time
23
24from absl import app
25from absl import logging
26import chex
27import flax
28from flax.metrics import tensorboard
29from flax.training import checkpoints
30import gin
31from internal import alignment
32from internal import camera_utils
33from internal import configs
34from internal import datasets
35from internal import image_io
36from internal import image_utils
37from internal import models
38from internal import ref_utils
39from internal import train_utils
40from internal import utils
41from internal import vis
42import jax
43from jax import random
44import jax.numpy as jnp
45import jaxcam
46import numpy as np
47
48
# Register the flags shared by all mipNeRF360 binaries and let absl parse any
# JAX configuration flags passed on the command line.
configs.define_common_flags()
jax.config.parse_flags_with_absl()
51
52
def plot_camera_metrics(
    *,
    summary_writer,
    camera_params,
    train_cameras,
    train_cameras_gt,
    config,
    step,
    tag,
):
  """Plots camera statistics to TensorBoard.

  Applies the learned camera correction (`camera_params`) to `train_cameras`,
  compares the corrected cameras against `train_cameras_gt`, and logs the
  mean/max/std of every error statistic under
  `eval_train_camera_{tag}_{reducer}/{stat}`.

  Returns:
    A list with one dict per training camera, holding that camera's scalar
    error statistics.
  """
  delta = config.camera_delta_cls()
  corrected_cameras: jaxcam.Camera = delta.apply(camera_params, train_cameras)
  errors = camera_utils.compute_camera_metrics(
      train_cameras_gt, corrected_cameras
  )

  # Log each aggregate of each per-camera statistic as its own scalar.
  reducers = (('mean', np.mean), ('max', np.max), ('std', np.std))
  for agg_name, agg_fn in reducers:
    for stat_name, stat_values in errors.items():
      summary_writer.scalar(
          f'eval_train_camera_{tag}_{agg_name}/{stat_name}',
          agg_fn(np.array(stat_values)),
          step=step,
      )

  def _stats_for_camera(index):
    # Slice out the scalar statistics belonging to a single camera.
    return jax.tree_util.tree_map(lambda leaf: float(leaf[index]), errors)

  return [_stats_for_camera(i) for i in range(len(train_cameras))]
90
91
def main(unused_argv):
  """Continuously evaluates training checkpoints on the test split.

  Loads the train/test datasets and the model, then loops: restore the latest
  checkpoint, render every test image, compute image metrics (and optionally
  camera-alignment metrics), write results to TensorBoard and/or disk, and
  repeat until the final training step is reached — or exit after a single
  pass when `config.eval_only_once` is set.
  """
  config = configs.load_config(save_config=False)

  train_dataset = datasets.load_dataset('train', config.data_dir, config)
  test_dataset = datasets.load_dataset('test', config.data_dir, config)

  key = random.PRNGKey(20200823)
  model, state, render_eval_pfn, _, _ = train_utils.setup_model(
      config, key, dataset=train_dataset
  )
  # In rawnerf mode the dataset supplies a postprocessing function; otherwise
  # postprocessing is the identity.
  if config.rawnerf_mode:
    postprocess_fn = test_dataset.metadata['postprocess_fn']
  else:
    postprocess_fn = lambda z: z

  metric_harness = image_utils.MetricHarness(
      **config.metric_harness_eval_config
  )

  last_step = 0
  out_dir = path.join(
      config.checkpoint_dir,
      'path_renders' if config.render_path else 'test_preds',
  )
  path_fn = lambda x: path.join(out_dir, x)

  # TensorBoard output is only needed when eval runs alongside training.
  if not config.eval_only_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(config.checkpoint_dir, 'eval')
    )

  jnp_cameras = None
  if config.cast_rays_in_eval_step:
    # Convert camera arrays to device arrays once, so per-step ray casting
    # does not repeatedly transfer them.
    np_to_jax = lambda x: jnp.array(x) if isinstance(x, np.ndarray) else x
    jnp_cameras = jax.tree_util.tree_map(np_to_jax, test_dataset.cameras)

  jnp_cameras_replicated = flax.jax_utils.replicate(jnp_cameras)

  last_eval_time = time.time()
  while True:
    state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)
    step = int(state.step)
    state_params_replicated = flax.jax_utils.replicate(state.params)

    if step <= last_step:
      # No new checkpoint yet: wait (bounded by a timeout) for training to
      # produce one.
      if time.time() - last_eval_time > config.eval_checkpoint_wait_timeout_sec:
        raise RuntimeError(
            'Waited for a new checkpoint for'
            f' {config.eval_checkpoint_wait_timeout_sec} seconds, got no new'
            ' checkpoint. This likely means that the training script has died.'
            ' Exiting. If this is expected, increase'
            ' config.eval_checkpoint_wait_timeout_sec.'
        )
      logging.info(
          'Checkpoint step %d <= last step %d, sleeping.', step, last_step
      )
      time.sleep(10)
      continue

    last_eval_time = time.time()

    logging.info('Evaluating checkpoint at step %d.', step)
    if config.eval_save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)

    num_eval = min(test_dataset.size, config.eval_dataset_limit)
    # Pick which images get full TensorBoard visualization ("showcases").
    key = random.PRNGKey(0 if config.deterministic_showcase else step)
    perm = random.permutation(key, num_eval)
    showcase_indices = np.sort(perm[: config.num_showcase_images])

    metrics = []
    metrics_aligned_optimized = []
    metrics_aligned_procrustes = []
    metrics_cameras = []
    metrics_cameras_procrustes = []
    showcases = []
    render_times = []

    # NOTE(review): state.params was already replicated right after the
    # checkpoint restore above; this repeats that work.
    state_params_replicated = flax.jax_utils.replicate(state.params)

    # Per-test-camera re-optimization is only evaluated at the final step.
    compute_aligned_metric = (
        config.optimize_test_cameras and step == config.max_steps
    )

    procrustes_cameras = None
    if config.compute_procrustes_metric and config.optimize_cameras:
      # (width, height) per test image.
      test_image_sizes = np.array(
          [(x.shape[1], x.shape[0]) for x in test_dataset.images]
      )
      test_jax_cameras = jax.vmap(test_dataset.jax_camera_from_tuple_fn)(
          test_dataset.cameras, test_image_sizes
      )
      train_jax_cameras = train_dataset.get_train_cameras(
          config, return_jax_cameras=True
      )
      train_jax_cameras_gt = train_dataset.jax_cameras
      camera_params = state.params['camera_params']
      camera_delta = config.camera_delta_cls()
      train_jax_cameras_opt = camera_delta.apply(
          camera_params, train_jax_cameras
      )
      # Procrustes-align the optimized train cameras to the GT cameras and
      # apply the same transform to the test cameras.
      train_jax_cameras_procrustes, test_jax_cameras_procrustes = (
          alignment.compute_procrusted_aligned_cameras(
              train_jax_cameras_gt=train_jax_cameras_gt,
              train_jax_cameras_opt=train_jax_cameras_opt,
              test_jax_cameras=test_jax_cameras,
          )
      )
      metrics_cameras = plot_camera_metrics(
          summary_writer=summary_writer,
          camera_params=camera_params,
          train_cameras=train_jax_cameras,
          train_cameras_gt=train_jax_cameras_gt,
          config=config,
          step=step,
          tag='error',
      )
      metrics_cameras_procrustes = plot_camera_metrics(
          summary_writer=summary_writer,
          camera_params=camera_params,
          train_cameras=train_jax_cameras,
          train_cameras_gt=train_jax_cameras_procrustes,
          config=config,
          step=step,
          tag='error_procrustes',
      )
      # Convert to tuples.
      procrustes_cameras = jax.vmap(camera_utils.tuple_from_jax_camera)(
          test_jax_cameras_procrustes
      )
      # Keep any extra per-camera entries beyond the first three tuple fields.
      procrustes_cameras = (*procrustes_cameras, *test_dataset.cameras[3:])

      procrustes_cameras_replicated = flax.jax_utils.replicate(
          procrustes_cameras
      )

    raybatcher = datasets.RayBatcher(test_dataset)
    for idx in range(test_dataset.size):
      gc.collect()
      with jax.profiler.StepTraceAnnotation('eval', step_num=idx):
        eval_start_time = time.time()
        batch = next(raybatcher)
        if idx >= num_eval:
          logging.info('Skipping image %d/%d', idx + 1, test_dataset.size)
          continue
        logging.info('Evaluating image %d/%d', idx + 1, test_dataset.size)
        rays = batch.rays
        train_frac = state.step / config.max_steps

        def _render_image(cameras, rays, train_frac):
          # Renders one full image using the given (replicated) cameras.
          return models.render_image(  # pytype: disable=wrong-arg-types  # jnp-array
              functools.partial(
                  render_eval_pfn,
                  state_params_replicated,
                  train_frac,
                  cameras,
              ),
              rays=rays,
              rng=None,
              config=config,
              return_all_levels=True,
          )

        if compute_aligned_metric:
          # Re-optimize this test camera against the frozen model, then render
          # from the optimized pose.
          jnp_camera_optimized = alignment.align_test_camera(
              model, state, idx, test_dataset, config
          )
          jnp_camera_optimized_replicated = flax.jax_utils.replicate(
              jnp_camera_optimized
          )
          rendering_aligned_optimized = _render_image(
              jnp_camera_optimized_replicated, rays, train_frac
          )
          rendering_aligned_optimized = jax.tree_util.tree_map(
              np.asarray, rendering_aligned_optimized
          )

        if procrustes_cameras is not None:
          rendering_aligned_procrustes = _render_image(
              procrustes_cameras_replicated, rays, train_frac
          )
          rendering_aligned_procrustes = jax.tree_util.tree_map(
              np.asarray, rendering_aligned_procrustes
          )

        rendering = _render_image(jnp_cameras_replicated, rays, train_frac)
        rendering = jax.tree_util.tree_map(np.asarray, rendering)
        rays = jax.tree_util.tree_map(np.asarray, rays)

        if jax.host_id() != 0:  # Only record via host 0.
          continue

        render_times.append((time.time() - eval_start_time))
        logging.info('Rendered in %0.3fs', render_times[-1])

        # Cast to 64-bit to ensure high precision for color correction function.
        gt_rgb = np.array(batch.rgb, dtype=np.float64)
        rendering['rgb'] = np.array(rendering['rgb'], dtype=np.float64)
        if compute_aligned_metric:
          rendering['rgb_aligned_optimized'] = np.array(
              rendering_aligned_optimized['rgb'], dtype=np.float64
          )
        if procrustes_cameras is not None:
          rendering['rgb_aligned_procrustes'] = np.array(
              rendering_aligned_procrustes['rgb'], dtype=np.float64
          )

        if not config.eval_only_once and idx in showcase_indices:
          showcase_idx = (
              idx if config.deterministic_showcase else len(showcases)
          )
          showcases.append((showcase_idx, rendering, batch))
        if not config.render_path:
          rgb = postprocess_fn(rendering['rgb'])
          if compute_aligned_metric:
            rgb_aligned_optimized = postprocess_fn(
                rendering['rgb_aligned_optimized']
            )

          if procrustes_cameras is not None:
            rgb_aligned_procrustes = postprocess_fn(
                rendering['rgb_aligned_procrustes']
            )
          rgb_gt = postprocess_fn(gt_rgb)

          if config.eval_quantize_metrics:
            # Ensures that the images written to disk reproduce the metrics.
            rgb = np.round(rgb * 255) / 255

          if config.eval_crop_borders > 0:
            # Exclude a border of `eval_crop_borders` pixels from all metrics.
            crop_fn = lambda x, c=config.eval_crop_borders: x[c:-c, c:-c]
            rgb = crop_fn(rgb)
            rgb_gt = crop_fn(rgb_gt)
            if compute_aligned_metric:
              rgb_aligned_optimized = crop_fn(rgb_aligned_optimized)
            if 'rgb_aligned_procrustes' in rendering:
              rgb_aligned_procrustes = crop_fn(rgb_aligned_procrustes)

          metric = metric_harness(rgb, rgb_gt)
          if compute_aligned_metric:
            metric_aligned_optimized = metric_harness(
                rgb_aligned_optimized, rgb_gt
            )
          if procrustes_cameras is not None:
            metric_aligned_procrustes = metric_harness(
                rgb_aligned_procrustes, rgb_gt
            )

          if config.compute_disp_metrics:
            for tag in ['mean', 'median']:
              # NOTE(review): `key` shadows the PRNG key above; the PRNG key
              # is not used again in this iteration, so this is harmless.
              key = f'distance_{tag}'
              if key in rendering:
                # Disparity derived from the final-level distance map.
                disparity = 1 / (1 + rendering[key][-1])
                metric[f'disparity_{tag}_mse'] = float(
                    ((disparity - batch.disps) ** 2).mean()
                )

          if config.compute_normal_metrics:
            weights = rendering['acc'][-1] * batch.alphas
            normalized_normals_gt = ref_utils.l2_normalize(batch.normals)
            for key, val in rendering.items():
              if key.startswith('normals') and val is not None:
                normalized_normals = ref_utils.l2_normalize(val[-1])
                metric[key + '_mae'] = ref_utils.compute_weighted_mae(
                    weights, normalized_normals, normalized_normals_gt
                )

          for m, v in metric.items():
            logging.info('%s = %0.4f', m, v)

          metrics.append(metric)
          if compute_aligned_metric:
            metrics_aligned_optimized.append(metric_aligned_optimized)
          if procrustes_cameras is not None:
            metrics_aligned_procrustes.append(metric_aligned_procrustes)

        # Save every `eval_render_interval`-th rendering to disk.
        if config.eval_save_output and (config.eval_render_interval > 0):
          if (idx % config.eval_render_interval) == 0:
            image_io.save_img_u8(
                postprocess_fn(rendering['rgb']),
                path_fn(f'color_{idx:03d}.png'),
            )
            if compute_aligned_metric:
              image_io.save_img_u8(
                  postprocess_fn(rendering['rgb_aligned_optimized']),
                  path_fn(f'color_aligned_optimized_{idx:03d}.png'),
              )
            if procrustes_cameras is not None:
              image_io.save_img_u8(
                  postprocess_fn(rendering['rgb_aligned_procrustes']),
                  path_fn(f'color_aligned_procrustes_{idx:03d}.png'),
              )

            for key in ['distance_mean', 'distance_median']:
              if key in rendering:
                image_io.save_img_f32(
                    rendering[key][-1], path_fn(f'{key}_{idx:03d}.tiff')
                )

            for key in ['normals']:
              if key in rendering:
                # Map normals from [-1, 1] to [0, 1] for image output.
                image_io.save_img_u8(
                    rendering[key][-1] / 2.0 + 0.5,
                    path_fn(f'{key}_{idx:03d}.png'),
                )

            if 'acc' in rendering:
              image_io.save_img_f32(
                  rendering['acc'][-1], path_fn(f'acc_{idx:03d}.tiff')
              )

            if batch.masks is not None:
              image_io.save_img_u8(
                  batch.rgb * batch.masks,
                  path_fn(f'masked_input_{idx:03d}.png'),
              )

    # TensorBoard summaries (host 0 only, and only in continuous-eval mode).
    if (not config.eval_only_once) and (jax.host_id() == 0):
      summary_writer.scalar(
          'eval_median_render_time', np.median(render_times), step
      )

      def summarize_metrics(metrics, metrics_suffix):
        # Writes the mean and the per-image histogram of each metric.
        for name in metrics[0]:
          scores = [m[name] for m in metrics]
          prefix = f'eval_metrics{metrics_suffix}/'
          summary_writer.scalar(prefix + name, np.mean(scores), step)
          summary_writer.histogram(prefix + 'perimage_' + name, scores, step)

      summarize_metrics(metrics, '')
      if compute_aligned_metric:
        summarize_metrics(metrics_aligned_optimized, '_aligned_optimized')
      if procrustes_cameras is not None:
        summarize_metrics(metrics_aligned_procrustes, '_aligned_procrustes')

      if config.multiscale_train_factors is not None:
        factors = [1] + list(config.multiscale_train_factors)
        n_images = len(metrics) // len(factors)
        # Split metrics into chunks of n_images (each downsampling level).
        for i, f in enumerate(factors):
          i0 = i * n_images
          i1 = (i + 1) * n_images
          image_shapes = np.array([z.shape for z in test_dataset.images[i0:i1]])
          if not np.all(image_shapes == image_shapes[0]):
            raise ValueError(
                'Not all image shapes match for downsampling '
                f'factor {f}x in evaluation'
            )
          summarize_metrics(metrics[i0:i1], f'_{f}x')
          if compute_aligned_metric:
            summarize_metrics(
                metrics_aligned_optimized[i0:i1], f'_{f}x_aligned_optimized'
            )
          if procrustes_cameras is not None:
            summarize_metrics(
                metrics_aligned_procrustes[i0:i1], f'_{f}x_aligned_procrustes'
            )

      for i, r, b in showcases:
        # Optionally decimate showcase images to keep summaries small.
        if config.vis_decimate > 1:
          d = config.vis_decimate
          decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d]
        else:
          decimate_fn = lambda x: x
        r = jax.tree_util.tree_map(decimate_fn, r)
        b = jax.tree_util.tree_map(decimate_fn, b)
        for k, v in vis.visualize_suite(r).items():
          if isinstance(v, list):
            for ii, vv in enumerate(v):
              summary_writer.image(f'output_{k}_{i}/{ii}', vv, step)
          else:
            summary_writer.image(f'output_{k}_{i}', v, step)
        if b.masks is not None:
          mask_float_array = jax.numpy.asarray(b.masks, dtype=jax.numpy.float32)
          summary_writer.image(f'mask_{i}', mask_float_array, step)
          summary_writer.image(
              f'masked_image_{i}', b.rgb * mask_float_array, step
          )
        if not config.render_path:
          target = postprocess_fn(b.rgb)
          pred = postprocess_fn(r['rgb'])
          if compute_aligned_metric:
            pred_aligned_optimized = postprocess_fn(r['rgb_aligned_optimized'])
          if procrustes_cameras is not None:
            pred_aligned_procrustes = postprocess_fn(
                r['rgb_aligned_procrustes']
            )
          summary_writer.image(f'output_color_{i}', pred, step)
          if compute_aligned_metric:
            summary_writer.image(
                f'output_color_aligned_optimized_{i}',
                pred_aligned_optimized,
                step,
            )
          if procrustes_cameras is not None:
            summary_writer.image(
                f'output_color_aligned_procrustes_{i}',
                pred_aligned_procrustes,
                step,
            )
          summary_writer.image(f'true_color_{i}', target, step)
          residual = pred - target
          summary_writer.image(
              f'output_residual_{i}', np.clip(residual + 0.5, 0, 1), step
          )
          if compute_aligned_metric:
            residual_aligned_optimized = pred_aligned_optimized - target
            summary_writer.image(
                f'output_residual_aligned_{i}',
                np.clip(residual_aligned_optimized + 0.5, 0, 1),
                step,
            )
          if procrustes_cameras is not None:
            residual_aligned_procrustes = pred_aligned_procrustes - target
            # NOTE(review): this tag collides with 'output_residual_aligned_'
            # written for the optimized residual above, so one overwrites the
            # other in TensorBoard; it likely should be
            # 'output_residual_aligned_procrustes_{i}' — confirm with owners.
            summary_writer.image(
                f'output_residual_aligned_{i}',
                np.clip(residual_aligned_procrustes + 0.5, 0, 1),
                step,
            )
          residual_hist = image_utils.render_histogram(
              np.array(residual).reshape([-1, 3]),
              bins=32,
              range=(-1, 1),
              log=True,
              color=('r', 'g', 'b'),
          )
          summary_writer.image(f'output_residual_hist_{i}', residual_hist, step)
          if config.compute_normal_metrics:
            summary_writer.image(
                f'true_normals_{i}', b.normals / 2.0 + 0.5, step
            )

    # Dump per-image timings and metrics as plain-text files (host 0 only).
    if (
        config.eval_save_output
        and (not config.render_path)
        and (jax.host_id() == 0)
    ):
      with utils.open_file(path_fn(f'render_times_{step}.txt'), 'w') as f:
        f.write(' '.join([str(r) for r in render_times]))
      for name in metrics[0]:
        with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f:
          f.write(' '.join([str(m[name]) for m in metrics]))
      if compute_aligned_metric:
        for name in metrics_aligned_optimized[0]:
          with utils.open_file(
              path_fn(f'metric_aligned_optimized_{name}_{step}.txt'), 'w'
          ) as f:
            f.write(' '.join([str(m[name]) for m in metrics_aligned_optimized]))
      if procrustes_cameras is not None:
        for name in metrics_aligned_procrustes[0]:
          with utils.open_file(
              path_fn(f'metric_aligned_procrustes_{name}_{step}.txt'), 'w'
          ) as f:
            f.write(
                ' '.join([str(m[name]) for m in metrics_aligned_procrustes])
            )
      if metrics_cameras:
        for name in metrics_cameras[0]:
          with utils.open_file(
              path_fn(f'metric_cameras_{name}_{step}.txt'), 'w'
          ) as f:
            f.write(' '.join([str(m[name]) for m in metrics_cameras]))
      if metrics_cameras_procrustes:
        for name in metrics_cameras_procrustes[0]:
          with utils.open_file(
              path_fn(f'metric_cameras_procrustes_{name}_{step}.txt'), 'w'
          ) as f:
            f.write(
                ' '.join([str(m[name]) for m in metrics_cameras_procrustes])
            )
      if config.eval_save_ray_data:
        for i, r, b in showcases:
          rays = {k: v for k, v in r.items() if 'ray_' in k}
          np.set_printoptions(threshold=sys.maxsize)
          with utils.open_file(path_fn(f'ray_data_{step}_{i}.txt'), 'w') as f:
            f.write(repr(rays))

    # A hack that forces Jax to keep all TPUs alive until every TPU is finished.
    x = jnp.ones([jax.local_device_count()])
    x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
    print(x)

    if config.eval_only_once:
      logging.info('Eval only once enabled, shutting down.')
      break
    if config.early_exit_steps is not None:
      num_steps = config.early_exit_steps
    else:
      num_steps = config.max_steps
    if int(step) >= num_steps:
      logging.info('Termination num steps reached (%d).', num_steps)
      break
    last_step = step
585
586
if __name__ == '__main__':
  # Run inside the 'eval' gin scope so eval-specific gin bindings apply.
  with gin.config_scope('eval'):
    app.run(main)
590