google-research
227 строк · 8.5 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Evaluation script for RegNeRF."""
17import functools
18from os import path
19import time
20
21from absl import app
22import flax
23from flax.metrics import tensorboard
24from flax.training import checkpoints
25from internal import configs, datasets, math, models, utils, vis # pylint: disable=g-multiple-import
26import jax
27from jax import random
28import numpy as np
29from skimage.metrics import structural_similarity
30import tensorflow as tf
31
# Tolerance used by the census error metric; guards against quantization of
# the ground-truth images (1/256 is roughly one 8-bit intensity level).
CENSUS_EPSILON = 1.0 / 256.0

# Module-level flag setup: project flags (presumably the shared absl flag
# definitions) are registered first, then JAX reads its own flags via absl.
configs.define_common_flags()
jax.config.parse_flags_with_absl()
def _restore_state(state, checkpoint_dir):
  """Restores `state` from the latest checkpoint in `checkpoint_dir`.

  If restoring directly into `state` fails, the checkpoint is treated as a
  pre-trained model with a different `MLP_0` layout. It is loaded as a raw
  state dict, remapped, and then loaded into `state`.

  Args:
    state: TrainState whose structure the checkpoint is restored into.
    checkpoint_dir: directory holding the checkpoint(s).

  Returns:
    The restored TrainState.
  """
  try:
    return checkpoints.restore_checkpoint(checkpoint_dir, state)
  except Exception:  # pylint: disable=broad-except
    # Fix for loading pre-trained models. `Exception` (not a bare `except:`)
    # so that Ctrl-C / SystemExit are not swallowed by this fallback.
    print('Using pre-trained model.')
    state_dict = checkpoints.restore_checkpoint(checkpoint_dir, None)
    # Alias of the nested dict; mutations below edit `state_dict` in place.
    mlp_params = state_dict['optimizer']['target']['params']['MLP_0']
    for i in [9, 17]:
      del mlp_params[f'Dense_{i}']
    # Copy Dense_18..Dense_20 into the Dense_9..Dense_11 slots.
    mlp_params['Dense_9'] = mlp_params['Dense_18']
    mlp_params['Dense_10'] = mlp_params['Dense_19']
    mlp_params['Dense_11'] = mlp_params['Dense_20']
    # Drop the extra top-level 'optimizerd' entry (NOTE(review): presumably
    # not part of TrainState). Tolerate its absence instead of raising.
    state_dict.pop('optimizerd', None)
    return flax.serialization.from_state_dict(state, state_dict)


def _compute_metrics(rendering, batch, config, ssim_fn, lpips_fn, census_fn):
  """Computes image-quality metrics for one rendered test view.

  Args:
    rendering: dict of rendered outputs from `models.render_image`.
    batch: ground-truth batch; needs 'rgb', plus 'disps' / 'normals' / 'mask'
      when the corresponding metrics are enabled in `config`.
    config: experiment config selecting which extra metrics to compute.
    ssim_fn: callable (prediction, target) -> SSIM.
    lpips_fn: callable (prediction, target) -> LPIPS.
    census_fn: callable (prediction, target) -> census error.

  Returns:
    Dict mapping metric name to a Python float, in a fixed insertion order.
  """
  metric = {}
  metric['psnr'] = float(
      math.mse_to_psnr(((rendering['rgb'] - batch['rgb'])**2).mean()))
  metric['ssim'] = float(ssim_fn(rendering['rgb'], batch['rgb']))
  metric['lpips'] = float(lpips_fn(rendering['rgb'], batch['rgb']))
  metric['avg_err'] = float(
      math.compute_avg_error(
          psnr=metric['psnr'],
          ssim=metric['ssim'],
          lpips=metric['lpips'],
      ))
  metric['census_err'] = float(census_fn(rendering['rgb'], batch['rgb']))

  if config.compute_disp_metrics:
    # Rendered distance is mapped to 1 / (1 + d) before comparing to 'disps'.
    disp = 1 / (1 + rendering['distance_mean'])
    metric['disp_mse'] = float(((disp - batch['disps'])**2).mean())

  if config.compute_normal_metrics:
    # Clip just inside [-1, 1] so arccos stays finite under rounding error.
    one_eps = 1 - np.finfo(np.float32).eps
    metric['normal_mae'] = float(
        np.arccos(
            np.clip(
                np.sum(batch['normals'] * rendering['normals'], axis=-1),
                -one_eps, one_eps)).mean())

  if config.dataset_loader == 'dtu':
    # Masked metrics. PSNR uses only pixels where mask == 1. SSIM/LPIPS
    # compare images whose masked-out region is filled with 1.
    rgb = batch['rgb']
    rgb_hat = rendering['rgb']
    mask = batch['mask']
    mask_bin = (mask == 1.)

    rgb_fg = rgb * mask + (1 - mask)
    rgb_hat_fg = rgb_hat * mask + (1 - mask)

    metric['psnr_masked'] = float(
        math.mse_to_psnr(((rgb - rgb_hat)[mask_bin]**2).mean()))
    metric['ssim_masked'] = float(ssim_fn(rgb_hat_fg, rgb_fg))
    metric['lpips_masked'] = float(lpips_fn(rgb_hat_fg, rgb_fg))
    metric['avg_err_masked'] = float(
        math.compute_avg_error(
            psnr=metric['psnr_masked'],
            ssim=metric['ssim_masked'],
            lpips=metric['lpips_masked'],
        ))
  return metric


def _save_rendering(rendering, idx, path_fn):
  """Writes the color, normal, distance and accumulation maps of view `idx`.

  Args:
    rendering: dict with 'rgb', 'normals', 'distance_mean', 'distance_median'
      and 'acc' entries.
    idx: index of the test view, used in the output file names.
    path_fn: maps a file name to its path inside the output directory.
  """
  utils.save_img_u8(rendering['rgb'], path_fn(f'color_{idx:03d}.png'))
  # Affine map /2 + 0.5 (presumably normals in [-1, 1] -> [0, 1]).
  utils.save_img_u8(rendering['normals'] / 2. + 0.5,
                    path_fn(f'normals_{idx:03d}.png'))
  utils.save_img_f32(rendering['distance_mean'],
                     path_fn(f'distance_mean_{idx:03d}.tiff'))
  utils.save_img_f32(rendering['distance_median'],
                     path_fn(f'distance_median_{idx:03d}.tiff'))
  utils.save_img_f32(rendering['acc'], path_fn(f'acc_{idx:03d}.tiff'))


def main(unused_argv):
  """Evaluates checkpoints from `config.checkpoint_dir` on the test split.

  Unless `config.eval_only_once` is set, this polls the checkpoint directory.
  Every checkpoint whose step exceeds the last evaluated step is evaluated and
  summarized to TensorBoard, until a checkpoint reaches `config.max_steps`.
  Per-image metrics are computed unless `config.render_path` is set.
  Renderings and metric files are optionally written under
  `config.checkpoint_dir`.

  Args:
    unused_argv: positional command-line arguments from absl (ignored).
  """
  # Hide accelerators from TensorFlow (presumably so TF does not reserve
  # device memory needed by JAX).
  tf.config.experimental.set_visible_devices([], 'GPU')
  tf.config.experimental.set_visible_devices([], 'TPU')

  config = configs.load_config(save_config=False)

  dataset = datasets.load_dataset('test', config.data_dir, config)
  model, init_variables = models.construct_mipnerf(
      random.PRNGKey(20200823), dataset.peek()['rays'], config)
  optimizer = flax.optim.Adam(config.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates 'speckle' artifacts.
  def render_eval_fn(variables, _, rays):
    return jax.lax.all_gather(
        model.apply(
            variables,
            None,  # Deterministic.
            rays,
            resample_padding=config.resample_padding_final,
            compute_extras=True),
        axis_name='batch')

  # pmap over only the data input.
  render_eval_pfn = jax.pmap(
      render_eval_fn,
      in_axes=(None, None, 0),
      donate_argnums=2,
      axis_name='batch',
  )

  def ssim_fn(x, y):
    return structural_similarity(x, y, multichannel=True)

  census_fn = jax.jit(
      functools.partial(math.compute_census_err, epsilon=CENSUS_EPSILON))

  # LPIPS is not implemented here; it is reported as NaN whether or not
  # `config.eval_disable_lpips` is set.
  print('WARNING: LPIPS calculation not supported. NaN values used instead.')
  lpips_fn = lambda x, y: np.nan

  last_step = 0
  out_dir = path.join(config.checkpoint_dir,
                      'path_renders' if config.render_path else 'test_preds')
  path_fn = lambda x: path.join(out_dir, x)
  is_host_0 = jax.host_id() == 0  # Only host 0 records metrics and outputs.

  if not config.eval_only_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(config.checkpoint_dir, 'eval'))
  while True:
    state = _restore_state(state, config.checkpoint_dir)

    step = int(state.optimizer.state.step)
    if step <= last_step:
      print(f'Checkpoint step {step} <= last step {last_step}, sleeping.')
      time.sleep(10)
      continue
    print(f'Evaluating checkpoint at step {step}.')
    if config.eval_save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)

    key = random.PRNGKey(0 if config.deterministic_showcase else step)
    perm = random.permutation(key, dataset.size)
    showcase_indices = np.sort(perm[:config.num_showcase_images])

    metrics = []
    showcases = []
    for idx in range(dataset.size):
      print(f'Evaluating image {idx+1}/{dataset.size}')
      eval_start_time = time.time()
      batch = next(dataset)
      rendering = models.render_image(
          functools.partial(render_eval_pfn, state.optimizer.target),
          batch['rays'], None, config)
      print(f'Rendered in {(time.time() - eval_start_time):0.3f}s')

      if not is_host_0:
        continue
      if not config.eval_only_once and idx in showcase_indices:
        showcase_idx = idx if config.deterministic_showcase else len(showcases)
        showcases.append((showcase_idx, rendering, batch))
      if not config.render_path:
        metric = _compute_metrics(rendering, batch, config, ssim_fn, lpips_fn,
                                  census_fn)
        for m, v in metric.items():
          print(f'{m:10s} = {v:.4f}')
        metrics.append(metric)

      if (config.eval_save_output and config.eval_render_interval > 0 and
          idx % config.eval_render_interval == 0):
        _save_rendering(rendering, idx, path_fn)

    # `metrics` stays empty in render_path mode (no ground truth) or when the
    # dataset is empty; indexing metrics[0] there would raise IndexError.
    metric_names = list(metrics[0].keys()) if metrics else []

    if (not config.eval_only_once) and is_host_0:
      for name in metric_names:
        summary_writer.scalar(name, np.mean([m[name] for m in metrics]), step)
      for i, r, b in showcases:
        for k, v in vis.visualize_suite(r, b['rays'], config).items():
          summary_writer.image(f'pred_{k}_{i}', v, step)
        if not config.render_path:
          summary_writer.image(f'target_{i}', b['rgb'], step)
    if config.eval_save_output and (not config.render_path) and is_host_0:
      for name in metric_names:
        with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f:
          f.write(' '.join([str(m[name]) for m in metrics]))
    if config.eval_only_once:
      break
    if int(step) >= config.max_steps:
      break
    last_step = step
if __name__ == '__main__':
  # Hand control to absl, which parses command-line flags and then calls
  # main() with the remaining argv.
  app.run(main)
228