# google-research
# 205 lines · 7.7 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation script for Nerf."""
import functools
from os import path

from absl import app
from absl import flags
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
import jax
from jax import random
import numpy as np
import tensorflow as tf
import tensorflow_hub as tf_hub
import pickle  # Used by the (currently commented-out) init-pose loading code.
import optax

# Project-local modules: dataset loading, NeRF model, misc utilities, and
# camera-pose math (SE(3) exp/compose/alignment).
from jaxbarf.src import datasets
from jaxbarf.src import models
from jaxbarf.src import utils
from jaxbarf.src import camera
38FLAGS = flags.FLAGS39utils.define_flags()40LPIPS_TFHUB_PATH = "@neural-rendering/lpips/distance/1"41
def compute_lpips(image1, image2, model):
  """Return the LPIPS perceptual distance between two images.

  Args:
    image1: First image array (no batch dimension).
    image2: Second image array (no batch dimension).
    model: Loaded TF-Hub LPIPS distance model.

  Returns:
    The LPIPS distance for this image pair (batch dimension stripped).
  """
  # The LPIPS model operates on batches, so prepend a batch axis to each
  # image before calling it, then take the single result back out.
  batched1 = tf.convert_to_tensor(image1[None, Ellipsis])
  batched2 = tf.convert_to_tensor(image2[None, Ellipsis])
  return model(batched1, batched2)[0]
def main(unused_argv):
  """Entry point for evaluation binary.

  Polls FLAGS.train_dir for new checkpoints; for each new one, evaluates the
  refined camera poses against ground truth and renders test views, logging
  PSNR/SSIM/LPIPS to TensorBoard and a per-step text file, until
  FLAGS.max_steps is reached.
  """
  # Hide accelerators from TensorFlow so it does not compete with JAX.
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")
  rng = random.PRNGKey(20200823)  # Fixed seed for reproducible evaluation.

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  # Load train dataset to get GT poses (and the initial poses they refine).
  # with utils.open_file(FLAGS.init_poses_file, "rb") as f:  # load init poses
  #   poses_train_init = pickle.load(f)
  dataset_train = datasets.get_dataset("train", FLAGS, train_mode=False)
  poses_train = utils.to_device(dataset_train.get_all_poses())

  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset_train.peek(), FLAGS)
  # Set up separate optimizer and LR schedule for pose and MLP parameters.
  # The optimizer is rebuilt here only so the TrainState pytree matches the
  # training-time structure when restoring checkpoints below; no gradient
  # updates are performed in this script.
  params = variables["params"]
  learning_rate_fn_mlp = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)
  learning_rate_fn_pose = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init_pose,
      lr_final=FLAGS.lr_final_pose,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps_pose,
      lr_delay_mult=FLAGS.lr_delay_mult_pose)
  # Traversals selecting parameters by name: pose deltas vs. MLP weights.
  pose_params = flax.traverse_util.ModelParamTraversal(
      lambda path, _: "POSE" in path)
  mlp_params = flax.traverse_util.ModelParamTraversal(
      lambda path, _: "MLP" in path)
  # Boolean masks (same pytree structure as params) for optax.masked.
  all_false = jax.tree_util.tree_map(lambda _: False, params)
  pose_mask = pose_params.update(lambda _: True, all_false)
  mlp_mask = mlp_params.update(lambda _: True, all_false)
  optimizer = optax.chain(
      optax.scale_by_adam(),
      optax.masked(optax.scale_by_schedule(learning_rate_fn_pose), pose_mask),
      optax.masked(optax.scale_by_schedule(learning_rate_fn_mlp), mlp_mask),
      optax.scale(-1),  # Gradient descent: negate the scaled update.
  )
  optimizer_state = optimizer.init(params)
  state = utils.TrainState(
      optimizer_state=optimizer_state, params=params, step=0)
  del params, optimizer_state

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
  def render_fn(variables, key_0, key_1, rays, step):
    """Render function (no learned pose refinement if train_mode=False.)"""
    return jax.lax.all_gather(
        model.apply({"params": variables}, key_0, key_1, rays,
                    False, train_mode=False, step=step),
        axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0, None),
      donate_argnums=(3,),
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")
  lpips_model = tf_hub.load(LPIPS_TFHUB_PATH)

  last_step = 0
  out_dir = path.join(FLAGS.train_dir, "test_preds")
  if not utils.isdir(out_dir):
    utils.makedirs(out_dir)

  summary_writer = tensorboard.SummaryWriter(
      path.join(FLAGS.train_dir, "eval"))

  # NOTE(review): this loop busy-polls for new checkpoints with no sleep
  # between attempts — consider adding a short time.sleep when no new
  # checkpoint is found.
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    # state.step appears to be stored as a fraction of max_steps — TODO
    # confirm against the training script's checkpointing convention.
    step = int(state.step * FLAGS.max_steps)
    if step <= last_step:
      continue

    # Recover refined training poses: apply the learned SE(3) deltas (via
    # exponential map) on top of the initial poses, then pre-align to the
    # ground-truth trajectory to get the sim3 calibration for the test set.
    poses_refine_se3 = state.params["POSE_0"]["delta_se3"]
    poses_refine_se3exp = camera.se3_exp(poses_refine_se3)
    poses_train_pred = camera.compose([poses_refine_se3exp,
                                       poses_train["poses_init"]])
    poses_train_aligned, sim3 = camera.prealign_cameras(
        poses_train_pred, poses_train["poses_gt"])
    r_error, t_error = camera.evaluate_camera(poses_train_pred,
                                              poses_train["poses_gt"])

    psnr_values = []
    ssim_values = []
    lpips_values = []

    # Every time we load a new checkpoint, we need to update poses.
    dataset = datasets.get_dataset("test", FLAGS,
                                   calib_matrix=sim3,
                                   train_mode=False)
    # NOTE(review): only the first 8 test images are evaluated even though
    # the progress message reports dataset.size — confirm this truncation
    # is intentional (the summary file also reports "first 8" averages).
    for idx in range(8):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, state.params),
          batch["rays"],
          rng,
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk,
          step=step/FLAGS.max_steps)
      if jax.host_id() != 0:  # Only record via host 0.
        continue

      psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean())
      ssim = ssim_fn(pred_color, batch["pixels"])
      lpips = compute_lpips(pred_color, batch["pixels"], lpips_model)
      psnr_values.append(float(psnr))
      ssim_values.append(float(ssim))
      lpips_values.append(float(lpips))

      utils.save_img(pred_color, path.join(
          out_dir, "pred_{:03d}_{}.png".format(idx, step)))
      utils.save_img(batch["pixels"], path.join(
          out_dir, "gt_{:03d}_{}.png".format(idx, step)))
      summary_writer.image("val_pred_color", pred_color, step)
      summary_writer.image("val_gt_color", batch["pixels"], step)
      summary_writer.scalar("val_psnr", np.mean(np.array(psnr_values)), step)

    # Write the per-checkpoint metrics summary (rotation error in degrees).
    with utils.open_file(path.join(out_dir, f"{step}.txt"), "w") as f:
      f.write("Trainset: num {}, R_error: {:.3f}, t_error: {:.3f}\n".format(
          len(r_error), np.mean(r_error)*180/np.pi, np.mean(t_error)))
      f.write("Average over {} validation images\n".format(len(psnr_values)))
      f.write("Mean PSNR: {:.2f}\n".format(np.mean(np.array(psnr_values))))
      f.write("Mean SSIM: {:.2f}\n".format(np.mean(np.array(ssim_values))))
      f.write("Mean LPIPS: {:.2f}\n".format(np.mean(np.array(lpips_values))))
      f.write("Mean PSNR (first 8): {:.2f}\n".format(
          np.mean(np.array(psnr_values)[:8])))
      f.write("Mean SSIM (first 8): {:.2f}\n".format(
          np.mean(np.array(ssim_values)[:8])))
      f.write("Mean LPIPS (first 8): {:.2f}\n".format(
          np.mean(np.array(lpips_values)[:8])))

    if int(step) >= FLAGS.max_steps:
      break
    last_step = step
203
204if __name__ == "__main__":205app.run(main)206