google-research

Форк
0
589 строк · 21.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Evaluation script for mipNeRF360."""
17

18
import functools
19
import gc
20
from os import path
21
import sys
22
import time
23

24
from absl import app
25
from absl import logging
26
import chex
27
import flax
28
from flax.metrics import tensorboard
29
from flax.training import checkpoints
30
import gin
31
from internal import alignment
32
from internal import camera_utils
33
from internal import configs
34
from internal import datasets
35
from internal import image_io
36
from internal import image_utils
37
from internal import models
38
from internal import ref_utils
39
from internal import train_utils
40
from internal import utils
41
from internal import vis
42
import jax
43
from jax import random
44
import jax.numpy as jnp
45
import jaxcam
46
import numpy as np
47

48

49
configs.define_common_flags()
50
jax.config.parse_flags_with_absl()
51

52

53
def plot_camera_metrics(
    *,
    summary_writer,
    camera_params,
    train_cameras,
    train_cameras_gt,
    config,
    step,
    tag,
):
  """Writes per-camera error statistics to TensorBoard.

  Applies the learned camera deltas to `train_cameras`, compares the result
  against `train_cameras_gt`, and logs the mean/max/std of every metric under
  `eval_train_camera_{tag}_*` at `step`.

  Returns:
    A list with one dict per training camera, mapping each metric name to that
    camera's scalar value (as a Python float).
  """
  delta_model = config.camera_delta_cls()
  cameras_opt: jaxcam.Camera = delta_model.apply(camera_params, train_cameras)
  stats = camera_utils.compute_camera_metrics(train_cameras_gt, cameras_opt)

  reducers = (('mean', np.mean), ('max', np.max), ('std', np.std))
  for reducer_name, reducer in reducers:
    for stat_name, stat_values in stats.items():
      summary_writer.scalar(
          f'eval_train_camera_{tag}_{reducer_name}/{stat_name}',
          reducer(np.array(stat_values)),
          step=step,
      )

  def _metrics_for_camera(index):
    # `index` is bound per call, so each entry selects one camera's scalars.
    return jax.tree_util.tree_map(lambda x: float(x[index]), stats)

  per_camera_metrics = []
  for index in range(len(train_cameras)):
    per_camera_metrics.append(_metrics_for_camera(index))
  return per_camera_metrics
90

91

92
def main(unused_argv):
  """Evaluation loop for mipNeRF360.

  Repeatedly restores the newest training checkpoint, renders every test-set
  image, computes image metrics (optionally with test cameras re-aligned to
  the optimized training cameras, both by per-camera optimization and by a
  Procrustes fit), writes TensorBoard summaries and/or image files, and exits
  once training has reached its final step (or after one pass when
  `config.eval_only_once` is set).
  """
  config = configs.load_config(save_config=False)

  train_dataset = datasets.load_dataset('train', config.data_dir, config)
  test_dataset = datasets.load_dataset('test', config.data_dir, config)

  key = random.PRNGKey(20200823)
  model, state, render_eval_pfn, _, _ = train_utils.setup_model(
      config, key, dataset=train_dataset
  )
  if config.rawnerf_mode:
    postprocess_fn = test_dataset.metadata['postprocess_fn']
  else:
    postprocess_fn = lambda z: z

  metric_harness = image_utils.MetricHarness(
      **config.metric_harness_eval_config
  )

  last_step = 0
  out_dir = path.join(
      config.checkpoint_dir,
      'path_renders' if config.render_path else 'test_preds',
  )
  path_fn = lambda x: path.join(out_dir, x)

  if not config.eval_only_once:
    # NOTE(review): `summary_writer` is also referenced below in the
    # procrustes-metric branch; that branch implicitly assumes
    # eval_only_once is False — confirm the configs never combine
    # eval_only_once with compute_procrustes_metric+optimize_cameras.
    summary_writer = tensorboard.SummaryWriter(
        path.join(config.checkpoint_dir, 'eval')
    )

  # Optionally move the test cameras onto device so rays can be cast inside
  # the (pmapped) eval step instead of on the host.
  jnp_cameras = None
  if config.cast_rays_in_eval_step:
    np_to_jax = lambda x: jnp.array(x) if isinstance(x, np.ndarray) else x
    jnp_cameras = jax.tree_util.tree_map(np_to_jax, test_dataset.cameras)

  jnp_cameras_replicated = flax.jax_utils.replicate(jnp_cameras)

  last_eval_time = time.time()
  while True:
    state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)
    step = int(state.step)

    if step <= last_step:
      # No new checkpoint yet; give up entirely if we have waited too long.
      if time.time() - last_eval_time > config.eval_checkpoint_wait_timeout_sec:
        raise RuntimeError(
            'Waited for a new checkpoint for'
            f' {config.eval_checkpoint_wait_timeout_sec} seconds, got no new'
            ' checkpoint. This likely means that the training script has died.'
            ' Exiting. If this is expected, increase'
            ' config.eval_checkpoint_wait_timeout_sec.'
        )
      logging.info(
          'Checkpoint step %d <= last step %d, sleeping.', step, last_step
      )
      time.sleep(10)
      continue

    last_eval_time = time.time()

    logging.info('Evaluating checkpoint at step %d.', step)
    if config.eval_save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)

    num_eval = min(test_dataset.size, config.eval_dataset_limit)
    key = random.PRNGKey(0 if config.deterministic_showcase else step)
    perm = random.permutation(key, num_eval)
    showcase_indices = np.sort(perm[: config.num_showcase_images])

    metrics = []
    metrics_aligned_optimized = []
    metrics_aligned_procrustes = []
    metrics_cameras = []
    metrics_cameras_procrustes = []
    showcases = []
    render_times = []

    # Replicate the (new) checkpoint's params once per accepted checkpoint.
    # (Previously this was also done before the staleness check above, which
    # was wasted work whenever the loop slept and re-restored.)
    state_params_replicated = flax.jax_utils.replicate(state.params)

    # Per-image test-camera optimization is expensive, so only do it on the
    # final checkpoint.
    compute_aligned_metric = (
        config.optimize_test_cameras and step == config.max_steps
    )

    procrustes_cameras = None
    if config.compute_procrustes_metric and config.optimize_cameras:
      # Build jaxcam cameras for the test split (sizes are (width, height)).
      test_image_sizes = np.array(
          [(x.shape[1], x.shape[0]) for x in test_dataset.images]
      )
      test_jax_cameras = jax.vmap(test_dataset.jax_camera_from_tuple_fn)(
          test_dataset.cameras, test_image_sizes
      )
      train_jax_cameras = train_dataset.get_train_cameras(
          config, return_jax_cameras=True
      )
      train_jax_cameras_gt = train_dataset.jax_cameras
      camera_params = state.params['camera_params']
      camera_delta = config.camera_delta_cls()
      train_jax_cameras_opt = camera_delta.apply(
          camera_params, train_jax_cameras
      )
      # Procrustes-align the GT train cameras to the optimized ones, and map
      # the test cameras through the same similarity transform.
      train_jax_cameras_procrustes, test_jax_cameras_procrustes = (
          alignment.compute_procrusted_aligned_cameras(
              train_jax_cameras_gt=train_jax_cameras_gt,
              train_jax_cameras_opt=train_jax_cameras_opt,
              test_jax_cameras=test_jax_cameras,
          )
      )
      metrics_cameras = plot_camera_metrics(
          summary_writer=summary_writer,
          camera_params=camera_params,
          train_cameras=train_jax_cameras,
          train_cameras_gt=train_jax_cameras_gt,
          config=config,
          step=step,
          tag='error',
      )
      metrics_cameras_procrustes = plot_camera_metrics(
          summary_writer=summary_writer,
          camera_params=camera_params,
          train_cameras=train_jax_cameras,
          train_cameras_gt=train_jax_cameras_procrustes,
          config=config,
          step=step,
          tag='error_procrustes',
      )
      # Convert to tuples, keeping any extra (non-camera) tuple entries from
      # the original dataset cameras.
      procrustes_cameras = jax.vmap(camera_utils.tuple_from_jax_camera)(
          test_jax_cameras_procrustes
      )
      procrustes_cameras = (*procrustes_cameras, *test_dataset.cameras[3:])

      procrustes_cameras_replicated = flax.jax_utils.replicate(
          procrustes_cameras
      )

    raybatcher = datasets.RayBatcher(test_dataset)
    for idx in range(test_dataset.size):
      gc.collect()
      with jax.profiler.StepTraceAnnotation('eval', step_num=idx):
        eval_start_time = time.time()
        batch = next(raybatcher)
        if idx >= num_eval:
          logging.info('Skipping image %d/%d', idx + 1, test_dataset.size)
          continue
        logging.info('Evaluating image %d/%d', idx + 1, test_dataset.size)
        rays = batch.rays
        train_frac = state.step / config.max_steps

        def _render_image(cameras, rays, train_frac):
          # Renders one full image with the pmapped eval function.
          return models.render_image(  # pytype: disable=wrong-arg-types  # jnp-array
              functools.partial(
                  render_eval_pfn,
                  state_params_replicated,
                  train_frac,
                  cameras,
              ),
              rays=rays,
              rng=None,
              config=config,
              return_all_levels=True,
          )

        if compute_aligned_metric:
          # Optimize this test camera against the trained model, then render
          # from the optimized camera.
          jnp_camera_optimized = alignment.align_test_camera(
              model, state, idx, test_dataset, config
          )
          jnp_camera_optimized_replicated = flax.jax_utils.replicate(
              jnp_camera_optimized
          )
          rendering_aligned_optimized = _render_image(
              jnp_camera_optimized_replicated, rays, train_frac
          )
          rendering_aligned_optimized = jax.tree_util.tree_map(
              np.asarray, rendering_aligned_optimized
          )

        if procrustes_cameras is not None:
          rendering_aligned_procrustes = _render_image(
              procrustes_cameras_replicated, rays, train_frac
          )
          rendering_aligned_procrustes = jax.tree_util.tree_map(
              np.asarray, rendering_aligned_procrustes
          )

        rendering = _render_image(jnp_cameras_replicated, rays, train_frac)
        rendering = jax.tree_util.tree_map(np.asarray, rendering)
        rays = jax.tree_util.tree_map(np.asarray, rays)

        # NOTE(review): jax.host_id() is deprecated in newer JAX releases
        # (jax.process_index() is the replacement); kept for compatibility
        # with the JAX version this repo pins.
        if jax.host_id() != 0:  # Only record via host 0.
          continue

        render_times.append((time.time() - eval_start_time))
        logging.info('Rendered in %0.3fs', render_times[-1])

        # Cast to 64-bit to ensure high precision for color correction function.
        gt_rgb = np.array(batch.rgb, dtype=np.float64)
        rendering['rgb'] = np.array(rendering['rgb'], dtype=np.float64)
        if compute_aligned_metric:
          rendering['rgb_aligned_optimized'] = np.array(
              rendering_aligned_optimized['rgb'], dtype=np.float64
          )
        if procrustes_cameras is not None:
          rendering['rgb_aligned_procrustes'] = np.array(
              rendering_aligned_procrustes['rgb'], dtype=np.float64
          )

        if not config.eval_only_once and idx in showcase_indices:
          showcase_idx = (
              idx if config.deterministic_showcase else len(showcases)
          )
          showcases.append((showcase_idx, rendering, batch))
        if not config.render_path:
          rgb = postprocess_fn(rendering['rgb'])
          if compute_aligned_metric:
            rgb_aligned_optimized = postprocess_fn(
                rendering['rgb_aligned_optimized']
            )

          if procrustes_cameras is not None:
            rgb_aligned_procrustes = postprocess_fn(
                rendering['rgb_aligned_procrustes']
            )
          rgb_gt = postprocess_fn(gt_rgb)

          if config.eval_quantize_metrics:
            # Ensures that the images written to disk reproduce the metrics.
            rgb = np.round(rgb * 255) / 255

          if config.eval_crop_borders > 0:
            crop_fn = lambda x, c=config.eval_crop_borders: x[c:-c, c:-c]
            rgb = crop_fn(rgb)
            rgb_gt = crop_fn(rgb_gt)
            if compute_aligned_metric:
              rgb_aligned_optimized = crop_fn(rgb_aligned_optimized)
            # Same guard as everywhere else in this loop (equivalent to the
            # previous `'rgb_aligned_procrustes' in rendering` check).
            if procrustes_cameras is not None:
              rgb_aligned_procrustes = crop_fn(rgb_aligned_procrustes)

          metric = metric_harness(rgb, rgb_gt)
          if compute_aligned_metric:
            metric_aligned_optimized = metric_harness(
                rgb_aligned_optimized, rgb_gt
            )
          if procrustes_cameras is not None:
            metric_aligned_procrustes = metric_harness(
                rgb_aligned_procrustes, rgb_gt
            )

          if config.compute_disp_metrics:
            for tag in ['mean', 'median']:
              key = f'distance_{tag}'
              if key in rendering:
                # Final-level distance -> disparity, compared against GT disps.
                disparity = 1 / (1 + rendering[key][-1])
                metric[f'disparity_{tag}_mse'] = float(
                    ((disparity - batch.disps) ** 2).mean()
                )

          if config.compute_normal_metrics:
            # Weight normal errors by accumulated opacity and GT alpha.
            weights = rendering['acc'][-1] * batch.alphas
            normalized_normals_gt = ref_utils.l2_normalize(batch.normals)
            for key, val in rendering.items():
              if key.startswith('normals') and val is not None:
                normalized_normals = ref_utils.l2_normalize(val[-1])
                metric[key + '_mae'] = ref_utils.compute_weighted_mae(
                    weights, normalized_normals, normalized_normals_gt
                )

          for m, v in metric.items():
            logging.info('%s = %0.4f', m, v)

          metrics.append(metric)
          if compute_aligned_metric:
            metrics_aligned_optimized.append(metric_aligned_optimized)
          if procrustes_cameras is not None:
            metrics_aligned_procrustes.append(metric_aligned_procrustes)

        if config.eval_save_output and (config.eval_render_interval > 0):
          if (idx % config.eval_render_interval) == 0:
            image_io.save_img_u8(
                postprocess_fn(rendering['rgb']),
                path_fn(f'color_{idx:03d}.png'),
            )
            if compute_aligned_metric:
              image_io.save_img_u8(
                  postprocess_fn(rendering['rgb_aligned_optimized']),
                  path_fn(f'color_aligned_optimized_{idx:03d}.png'),
              )
            if procrustes_cameras is not None:
              image_io.save_img_u8(
                  postprocess_fn(rendering['rgb_aligned_procrustes']),
                  path_fn(f'color_aligned_procrustes_{idx:03d}.png'),
              )

            for key in ['distance_mean', 'distance_median']:
              if key in rendering:
                image_io.save_img_f32(
                    rendering[key][-1], path_fn(f'{key}_{idx:03d}.tiff')
                )

            for key in ['normals']:
              if key in rendering:
                image_io.save_img_u8(
                    rendering[key][-1] / 2.0 + 0.5,
                    path_fn(f'{key}_{idx:03d}.png'),
                )

            if 'acc' in rendering:
              image_io.save_img_f32(
                  rendering['acc'][-1], path_fn(f'acc_{idx:03d}.tiff')
              )

            if batch.masks is not None:
              image_io.save_img_u8(
                  batch.rgb * batch.masks,
                  path_fn(f'masked_input_{idx:03d}.png'),
              )

    if (not config.eval_only_once) and (jax.host_id() == 0):
      summary_writer.scalar(
          'eval_median_render_time', np.median(render_times), step
      )

      def summarize_metrics(metrics, metrics_suffix):
        # Logs the mean and per-image histogram of every metric in `metrics`.
        for name in metrics[0]:
          scores = [m[name] for m in metrics]
          prefix = f'eval_metrics{metrics_suffix}/'
          summary_writer.scalar(prefix + name, np.mean(scores), step)
          summary_writer.histogram(prefix + 'perimage_' + name, scores, step)

      summarize_metrics(metrics, '')
      if compute_aligned_metric:
        summarize_metrics(metrics_aligned_optimized, '_aligned_optimized')
      if procrustes_cameras is not None:
        summarize_metrics(metrics_aligned_procrustes, '_aligned_procrustes')

      if config.multiscale_train_factors is not None:
        factors = [1] + list(config.multiscale_train_factors)
        n_images = len(metrics) // len(factors)
        # Split metrics into chunks of n_images (each downsampling level).
        for i, f in enumerate(factors):
          i0 = i * n_images
          i1 = (i + 1) * n_images
          image_shapes = np.array([z.shape for z in test_dataset.images[i0:i1]])
          if not np.all(image_shapes == image_shapes[0]):
            raise ValueError(
                'Not all image shapes match for downsampling '
                f'factor {f}x in evaluation'
            )
          summarize_metrics(metrics[i0:i1], f'_{f}x')
          if compute_aligned_metric:
            summarize_metrics(
                metrics_aligned_optimized[i0:i1], f'_{f}x_aligned_optimized'
            )
          if procrustes_cameras is not None:
            summarize_metrics(
                metrics_aligned_procrustes[i0:i1], f'_{f}x_aligned_procrustes'
            )

      for i, r, b in showcases:
        if config.vis_decimate > 1:
          d = config.vis_decimate
          decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d]
        else:
          decimate_fn = lambda x: x
        r = jax.tree_util.tree_map(decimate_fn, r)
        b = jax.tree_util.tree_map(decimate_fn, b)
        for k, v in vis.visualize_suite(r).items():
          if isinstance(v, list):
            for ii, vv in enumerate(v):
              summary_writer.image(f'output_{k}_{i}/{ii}', vv, step)
          else:
            summary_writer.image(f'output_{k}_{i}', v, step)
        if b.masks is not None:
          mask_float_array = jax.numpy.asarray(b.masks, dtype=jax.numpy.float32)
          summary_writer.image(f'mask_{i}', mask_float_array, step)
          summary_writer.image(
              f'masked_image_{i}', b.rgb * mask_float_array, step
          )
        if not config.render_path:
          target = postprocess_fn(b.rgb)
          pred = postprocess_fn(r['rgb'])
          if compute_aligned_metric:
            pred_aligned_optimized = postprocess_fn(r['rgb_aligned_optimized'])
          if procrustes_cameras is not None:
            pred_aligned_procrustes = postprocess_fn(
                r['rgb_aligned_procrustes']
            )
          summary_writer.image(f'output_color_{i}', pred, step)
          if compute_aligned_metric:
            summary_writer.image(
                f'output_color_aligned_optimized_{i}',
                pred_aligned_optimized,
                step,
            )
          if procrustes_cameras is not None:
            summary_writer.image(
                f'output_color_aligned_procrustes_{i}',
                pred_aligned_procrustes,
                step,
            )
          summary_writer.image(f'true_color_{i}', target, step)
          residual = pred - target
          summary_writer.image(
              f'output_residual_{i}', np.clip(residual + 0.5, 0, 1), step
          )
          if compute_aligned_metric:
            residual_aligned_optimized = pred_aligned_optimized - target
            summary_writer.image(
                f'output_residual_aligned_{i}',
                np.clip(residual_aligned_optimized + 0.5, 0, 1),
                step,
            )
          if procrustes_cameras is not None:
            residual_aligned_procrustes = pred_aligned_procrustes - target
            # Fixed: this previously reused the 'output_residual_aligned_{i}'
            # tag and silently overwrote the optimized-aligned residual above.
            summary_writer.image(
                f'output_residual_aligned_procrustes_{i}',
                np.clip(residual_aligned_procrustes + 0.5, 0, 1),
                step,
            )
          residual_hist = image_utils.render_histogram(
              np.array(residual).reshape([-1, 3]),
              bins=32,
              range=(-1, 1),
              log=True,
              color=('r', 'g', 'b'),
          )
          summary_writer.image(f'output_residual_hist_{i}', residual_hist, step)
          if config.compute_normal_metrics:
            summary_writer.image(
                f'true_normals_{i}', b.normals / 2.0 + 0.5, step
            )

    if (
        config.eval_save_output
        and (not config.render_path)
        and (jax.host_id() == 0)
    ):
      # Dump render times and every metric list as whitespace-separated text.
      with utils.open_file(path_fn(f'render_times_{step}.txt'), 'w') as f:
        f.write(' '.join([str(r) for r in render_times]))
      for name in metrics[0]:
        with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f:
          f.write(' '.join([str(m[name]) for m in metrics]))
      if compute_aligned_metric:
        for name in metrics_aligned_optimized[0]:
          with utils.open_file(
              path_fn(f'metric_aligned_optimized_{name}_{step}.txt'), 'w'
          ) as f:
            f.write(' '.join([str(m[name]) for m in metrics_aligned_optimized]))
      if procrustes_cameras is not None:
        for name in metrics_aligned_procrustes[0]:
          with utils.open_file(
              path_fn(f'metric_aligned_procrustes_{name}_{step}.txt'), 'w'
          ) as f:
            f.write(
                ' '.join([str(m[name]) for m in metrics_aligned_procrustes])
            )
      if metrics_cameras:
        for name in metrics_cameras[0]:
          with utils.open_file(
              path_fn(f'metric_cameras_{name}_{step}.txt'), 'w'
          ) as f:
            f.write(' '.join([str(m[name]) for m in metrics_cameras]))
      if metrics_cameras_procrustes:
        for name in metrics_cameras_procrustes[0]:
          with utils.open_file(
              path_fn(f'metric_cameras_procrustes_{name}_{step}.txt'), 'w'
          ) as f:
            f.write(
                ' '.join([str(m[name]) for m in metrics_cameras_procrustes])
            )
      if config.eval_save_ray_data:
        for i, r, b in showcases:
          rays = {k: v for k, v in r.items() if 'ray_' in k}
          np.set_printoptions(threshold=sys.maxsize)
          with utils.open_file(path_fn(f'ray_data_{step}_{i}.txt'), 'w') as f:
            f.write(repr(rays))

    # A hack that forces Jax to keep all TPUs alive until every TPU is finished.
    x = jnp.ones([jax.local_device_count()])
    x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
    print(x)

    if config.eval_only_once:
      logging.info('Eval only once enabled, shutting down.')
      break
    if config.early_exit_steps is not None:
      num_steps = config.early_exit_steps
    else:
      num_steps = config.max_steps
    if int(step) >= num_steps:
      logging.info('Termination num steps reached (%d).', num_steps)
      break
    last_step = step
585

586

587
if __name__ == '__main__':
  # Run under the 'eval' gin config scope so eval-specific gin bindings apply.
  with gin.config_scope('eval'):
    app.run(main)
590

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.