# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation script for Nerf."""
import functools
from os import path

from absl import app
from absl import flags
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
import jax
from jax import random
import numpy as np
import tensorflow as tf
import tensorflow_hub as tf_hub
import pickle
import optax

from jaxbarf.src import datasets
from jaxbarf.src import models
from jaxbarf.src import utils
from jaxbarf.src import camera

FLAGS = flags.FLAGS
utils.define_flags()
LPIPS_TFHUB_PATH = "@neural-rendering/lpips/distance/1"
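# LPIPS ("Learned Perceptual Image Patch Similarity") is a learned perceptual
# distance; lower values mean the two images look more alike. The TF-Hub module
# above is assumed to map a pair of image batches to one distance per pair.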

def compute_lpips(image1, image2, model):
  """Compute the LPIPS metric."""
  # The LPIPS model expects a batch dimension.
  return model(
      tf.convert_to_tensor(image1[None, Ellipsis]),
      tf.convert_to_tensor(image2[None, Ellipsis]))[0]


def main(unused_argv):
  """Entry point for evaluation binary."""
  tf.config.experimental.set_visible_devices([], "GPU")
  tf.config.experimental.set_visible_devices([], "TPU")
  rng = random.PRNGKey(20200823)

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set; it is currently None.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set; it is currently None.")

  # Load the train dataset to get the ground-truth and initial camera poses.
  #with utils.open_file(FLAGS.init_poses_file, "rb") as f: # load init poses
  #  poses_train_init = pickle.load(f)
  dataset_train = datasets.get_dataset("train", FLAGS, train_mode=False)
  poses_train = utils.to_device(dataset_train.get_all_poses())

  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset_train.peek(), FLAGS)
  # Set up separate LR schedules for the pose and MLP parameters.
  params = variables["params"]
  learning_rate_fn_mlp = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)
  learning_rate_fn_pose = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init_pose,
      lr_final=FLAGS.lr_final_pose,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps_pose,
      lr_delay_mult=FLAGS.lr_delay_mult_pose)
  pose_params = flax.traverse_util.ModelParamTraversal(
      lambda path, _: "POSE" in path)
  mlp_params = flax.traverse_util.ModelParamTraversal(
      lambda path, _: "MLP" in path)
  all_false = jax.tree_util.tree_map(lambda _: False, params)
  pose_mask = pose_params.update(lambda _: True, all_false)
  mlp_mask = mlp_params.update(lambda _: True, all_false)
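  # optax.chain applies the transforms in order: Adam statistics are computed
  # for all parameters, the two masked scale_by_schedule transforms apply their
  # own learning-rate schedule to the pose ("POSE") and MLP ("MLP") subsets,
  # and scale(-1) turns the scaled updates into gradient-descent steps.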
  optimizer = optax.chain(
      optax.scale_by_adam(),
      optax.masked(optax.scale_by_schedule(learning_rate_fn_pose), pose_mask),
      optax.masked(optax.scale_by_schedule(learning_rate_fn_mlp), mlp_mask),
      optax.scale(-1),
  )
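  # The optimizer is never stepped during evaluation; its state is only needed
  # so the TrainState template below matches the structure of the training
  # checkpoints that restore_checkpoint loads.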
  optimizer_state = optimizer.init(params)
  state = utils.TrainState(
      optimizer_state=optimizer_state, params=params, step=0)
  del params, optimizer_state

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates "speckle" artifacts.
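  # render_fn gathers per-device results over the "batch" axis so every device
  # ends up with the full rendered output. The `step` argument is forwarded to
  # the model; presumably it drives BARF's coarse-to-fine positional-encoding
  # annealing.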
  def render_fn(variables, key_0, key_1, rays, step):
    """Render function (no learned pose refinement when train_mode=False)."""
    return jax.lax.all_gather(
        model.apply({"params": variables}, key_0, key_1, rays,
                    False, train_mode=False, step=step),
        axis_name="batch")

  # pmap over only the data input.
  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0, None),
      donate_argnums=(3,),
      axis_name="batch",
  )

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")
  lpips_model = tf_hub.load(LPIPS_TFHUB_PATH)

  last_step = 0
  out_dir = path.join(FLAGS.train_dir, "test_preds")
  if not utils.isdir(out_dir):
    utils.makedirs(out_dir)

  summary_writer = tensorboard.SummaryWriter(
      path.join(FLAGS.train_dir, "eval"))
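
  # Poll train_dir for new checkpoints: restore the latest state and skip
  # (busy-wait) when it is not newer than the last checkpoint evaluated.
  # state.step appears to be stored as a fraction of training progress, hence
  # the rescaling by max_steps to recover an integer step count.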
  while True:
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    step = int(state.step * FLAGS.max_steps)
    if step <= last_step:
      continue
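
    # Pose evaluation: map the learned se(3) refinements to SE(3) with the
    # exponential map, compose them with the initial training poses, align the
    # result to the ground-truth poses with a similarity transform (sim3), and
    # measure rotation/translation errors. The sim3 transform is reused below
    # to bring the test cameras into the refined coordinate frame.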
    poses_refine_se3 = state.params["POSE_0"]["delta_se3"]
    poses_refine_se3exp = camera.se3_exp(poses_refine_se3)
    poses_train_pred = camera.compose([poses_refine_se3exp,
                                       poses_train["poses_init"]])
    poses_train_aligned, sim3 = camera.prealign_cameras(poses_train_pred,
                                                        poses_train["poses_gt"])
    r_error, t_error = camera.evaluate_camera(poses_train_pred,
                                              poses_train["poses_gt"])

    psnr_values = []
    ssim_values = []
    lpips_values = []

    # Every time a new checkpoint is loaded, the test dataset is rebuilt with
    # the current sim3 calibration so its poses match the refined frame.
    dataset = datasets.get_dataset("test", FLAGS,
                                   calib_matrix=sim3,
                                   train_mode=False)
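    # Only the first 8 test views are rendered and scored for each checkpoint;
    # predictions and ground truth are also saved as PNGs under test_preds/.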
    for idx in range(8):
      print(f"Evaluating {idx+1}/{dataset.size}")
      batch = next(dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, state.params),
          batch["rays"],
          rng,
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk,
          step=step/FLAGS.max_steps)
      if jax.host_id() != 0:  # Only record via host 0.
        continue

      psnr = utils.compute_psnr(((pred_color - batch["pixels"])**2).mean())
      ssim = ssim_fn(pred_color, batch["pixels"])
      lpips = compute_lpips(pred_color, batch["pixels"], lpips_model)
      psnr_values.append(float(psnr))
      ssim_values.append(float(ssim))
      lpips_values.append(float(lpips))

      utils.save_img(pred_color, path.join(
          out_dir, "pred_{:03d}_{}.png".format(idx, step)))
      utils.save_img(batch["pixels"], path.join(
          out_dir, "gt_{:03d}_{}.png".format(idx, step)))
      summary_writer.image("val_pred_color", pred_color, step)
      summary_writer.image("val_gt_color", batch["pixels"], step)
    summary_writer.scalar("val_psnr", np.mean(np.array(psnr_values)), step)
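
    # Write a per-checkpoint text summary: training-pose errors (rotation is
    # converted from radians to degrees) plus the mean image metrics.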
    with utils.open_file(path.join(out_dir, f"{step}.txt"), "w") as f:
      f.write("Trainset: num {}, R_error: {:.3f}, t_error: {:.3f}\n".format(
          len(r_error), np.mean(r_error)*180/np.pi, np.mean(t_error)))
      f.write("Average over {} validation images\n".format(len(psnr_values)))
      f.write("Mean PSNR: {:.2f}\n".format(np.mean(np.array(psnr_values))))
      f.write("Mean SSIM: {:.2f}\n".format(np.mean(np.array(ssim_values))))
      f.write("Mean LPIPS: {:.2f}\n".format(np.mean(np.array(lpips_values))))
      f.write("Mean PSNR (first 8): {:.2f}\n".format(
          np.mean(np.array(psnr_values)[:8])))
      f.write("Mean SSIM (first 8): {:.2f}\n".format(
          np.mean(np.array(ssim_values)[:8])))
      f.write("Mean LPIPS (first 8): {:.2f}\n".format(
          np.mean(np.array(lpips_values)[:8])))

    if int(step) >= FLAGS.max_steps:
      break
    last_step = step


if __name__ == "__main__":
  app.run(main)