google-research

Форк
0
227 строк · 8.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Evaluation script for RegNeRF."""
17
import functools
18
from os import path
19
import time
20

21
from absl import app
22
import flax
23
from flax.metrics import tensorboard
24
from flax.training import checkpoints
25
from internal import configs, datasets, math, models, utils, vis  # pylint: disable=g-multiple-import
26
import jax
27
from jax import random
28
import numpy as np
29
from skimage.metrics import structural_similarity
30
import tensorflow as tf
31

32
CENSUS_EPSILON = 1 / 256  # Guard against ground-truth quantization.
33

34
configs.define_common_flags()
35
jax.config.parse_flags_with_absl()
36

37

38
def main(unused_argv):
  """Evaluates RegNeRF checkpoints: renders the test split and logs metrics.

  Runs as a daemon by default: polls `config.checkpoint_dir` for new
  checkpoints, renders every test image, computes PSNR/SSIM/census (and
  optionally disparity/normal and DTU masked variants), writes TensorBoard
  summaries and optional image/metric files, and repeats until
  `config.max_steps` is reached (or exits after one pass when
  `config.eval_only_once` is set).

  Args:
    unused_argv: Leftover command-line args from absl.app; ignored.
  """

  # Hide accelerators from TensorFlow so it does not grab device memory that
  # JAX needs; TF is only used for CPU-side dataset/IO work here.
  tf.config.experimental.set_visible_devices([], 'GPU')
  tf.config.experimental.set_visible_devices([], 'TPU')

  config = configs.load_config(save_config=False)

  dataset = datasets.load_dataset('test', config.data_dir, config)
  # Build the model and a fresh set of variables; the variables are only
  # needed to shape the optimizer/TrainState that checkpoints restore into.
  model, init_variables = models.construct_mipnerf(
      random.PRNGKey(20200823),
      dataset.peek()['rays'],
      config)
  # NOTE(review): flax.optim is the legacy (pre-optax) flax optimizer API;
  # this pins the script to an old flax version — confirm before upgrading.
  optimizer = flax.optim.Adam(config.lr_init).create(init_variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, init_variables  # Free the host copies; `state` owns them now.

  # Rendering is forced to be deterministic even if training was randomized, as
  # this eliminates 'speckle' artifacts.
  def render_eval_fn(variables, _, rays):
    # all_gather assembles the full image chunk from all devices so every
    # host sees complete renderings.
    return jax.lax.all_gather(
        model.apply(
            variables,
            None,  # Deterministic.
            rays,
            resample_padding=config.resample_padding_final,
            compute_extras=True), axis_name='batch')

  # pmap over only the data input.
  render_eval_pfn = jax.pmap(
      render_eval_fn,
      in_axes=(None, None, 0),
      donate_argnums=2,  # Donate the rays buffer to reduce device memory use.
      axis_name='batch',
  )

  def ssim_fn(x, y):
    # NOTE(review): `multichannel=True` is deprecated in newer scikit-image
    # (replaced by `channel_axis=-1`) — confirm against the pinned version.
    return structural_similarity(x, y, multichannel=True)

  census_fn = jax.jit(
      functools.partial(math.compute_census_err, epsilon=CENSUS_EPSILON))

  # LPIPS was removed from this release; both branches below intentionally
  # return NaN so downstream averaging code keeps working unchanged.
  print('WARNING: LPIPS calculation not supported. NaN values used instead.')
  if config.eval_disable_lpips:
    lpips_fn = lambda x, y: np.nan
  else:
    lpips_fn = lambda x, y: np.nan

  last_step = 0
  out_dir = path.join(config.checkpoint_dir,
                      'path_renders' if config.render_path else 'test_preds')
  path_fn = lambda x: path.join(out_dir, x)

  if not config.eval_only_once:
    summary_writer = tensorboard.SummaryWriter(
        path.join(config.checkpoint_dir, 'eval'))
  while True:
    # Fix for loading pre-trained models.
    try:
      state = checkpoints.restore_checkpoint(config.checkpoint_dir, state)
    except:  # pylint: disable=bare-except
      # Released pre-trained checkpoints have a different MLP layout; remap
      # their param dict to the layout this model expects. The surgery below
      # drops Dense_9/Dense_17 and renames Dense_18/19/20 -> Dense_9/10/11.
      print('Using pre-trained model.')
      state_dict = checkpoints.restore_checkpoint(config.checkpoint_dir, None)
      for i in [9, 17]:
        del state_dict['optimizer']['target']['params']['MLP_0'][f'Dense_{i}']
      state_dict['optimizer']['target']['params']['MLP_0'][
          'Dense_9'] = state_dict['optimizer']['target']['params']['MLP_0'][
              'Dense_18']
      state_dict['optimizer']['target']['params']['MLP_0'][
          'Dense_10'] = state_dict['optimizer']['target']['params']['MLP_0'][
              'Dense_19']
      state_dict['optimizer']['target']['params']['MLP_0'][
          'Dense_11'] = state_dict['optimizer']['target']['params']['MLP_0'][
              'Dense_20']
      # NOTE(review): 'optimizerd' looks like a stale/extra key present only
      # in the released checkpoints (possibly a typo at export time) — if it
      # is absent this raises KeyError; verify against the actual checkpoint.
      del state_dict['optimizerd']
      state = flax.serialization.from_state_dict(state, state_dict)

    step = int(state.optimizer.state.step)
    # No new checkpoint yet: sleep and poll again.
    if step <= last_step:
      print(f'Checkpoint step {step} <= last step {last_step}, sleeping.')
      time.sleep(10)
      continue
    print(f'Evaluating checkpoint at step {step}.')
    if config.eval_save_output and (not utils.isdir(out_dir)):
      utils.makedirs(out_dir)

    # Pick which test images to showcase in TensorBoard. With
    # deterministic_showcase the same images are used at every step
    # (fixed key 0); otherwise a new subset is drawn per checkpoint.
    key = random.PRNGKey(0 if config.deterministic_showcase else step)
    perm = random.permutation(key, dataset.size)
    showcase_indices = np.sort(perm[:config.num_showcase_images])

    metrics = []
    showcases = []
    for idx in range(dataset.size):
      print(f'Evaluating image {idx+1}/{dataset.size}')
      eval_start_time = time.time()
      batch = next(dataset)
      rendering = models.render_image(
          functools.partial(render_eval_pfn, state.optimizer.target),
          batch['rays'],
          None,  # No RNG: rendering is deterministic (see render_eval_fn).
          config)
      print(f'Rendered in {(time.time() - eval_start_time):0.3f}s')

      # NOTE(review): jax.host_id() is deprecated in newer JAX in favor of
      # jax.process_index() — confirm against the pinned version.
      if jax.host_id() != 0:  # Only record via host 0.
        continue
      if not config.eval_only_once and idx in showcase_indices:
        showcase_idx = idx if config.deterministic_showcase else len(showcases)
        showcases.append((showcase_idx, rendering, batch))
      # Metrics only make sense against ground truth, which a pure render
      # path (config.render_path) does not have.
      if not config.render_path:
        metric = {}
        metric['psnr'] = float(
            math.mse_to_psnr(((rendering['rgb'] - batch['rgb'])**2).mean()))
        metric['ssim'] = float(ssim_fn(rendering['rgb'], batch['rgb']))
        metric['lpips'] = float(lpips_fn(rendering['rgb'], batch['rgb']))
        # Single scalar combining PSNR/SSIM/LPIPS (geometric-mean style
        # aggregate computed by the internal math helper).
        metric['avg_err'] = float(
            math.compute_avg_error(
                psnr=metric['psnr'],
                ssim=metric['ssim'],
                lpips=metric['lpips'],
            ))
        metric['census_err'] = float(census_fn(rendering['rgb'], batch['rgb']))

        if config.compute_disp_metrics:
          # Convert expected distance to a disparity-like quantity in (0, 1].
          disp = 1 / (1 + rendering['distance_mean'])
          metric['disp_mse'] = float(((disp - batch['disps'])**2).mean())

        if config.compute_normal_metrics:
          # Clip the dot product just inside [-1, 1] so arccos never sees an
          # out-of-domain value from float rounding.
          one_eps = 1 - np.finfo(np.float32).eps
          metric['normal_mae'] = float(
              np.arccos(
                  np.clip(
                      np.sum(batch['normals'] * rendering['normals'], axis=-1),
                      -one_eps, one_eps)).mean())

        if config.dataset_loader == 'dtu':
          # DTU protocol: also report metrics restricted to the object mask,
          # compositing both images onto a white background outside the mask.
          rgb = batch['rgb']
          rgb_hat = rendering['rgb']
          mask = batch['mask']
          mask_bin = (mask == 1.)

          rgb_fg = rgb * mask + (1 - mask)
          rgb_hat_fg = rgb_hat * mask + (1 - mask)

          metric['psnr_masked'] = float(
              math.mse_to_psnr(((rgb - rgb_hat)[mask_bin]**2).mean()))
          metric['ssim_masked'] = float(ssim_fn(rgb_hat_fg, rgb_fg))
          metric['lpips_masked'] = float(lpips_fn(rgb_hat_fg, rgb_fg))
          metric['avg_err_masked'] = float(
              math.compute_avg_error(
                  psnr=metric['psnr_masked'],
                  ssim=metric['ssim_masked'],
                  lpips=metric['lpips_masked'],
              ))

        for m, v in metric.items():
          print(f'{m:10s} = {v:.4f}')
        metrics.append(metric)

      # Optionally dump every eval_render_interval-th rendering to disk.
      if config.eval_save_output and (config.eval_render_interval > 0):
        if (idx % config.eval_render_interval) == 0:
          utils.save_img_u8(rendering['rgb'], path_fn(f'color_{idx:03d}.png'))
          # Normals are in [-1, 1]; remap to [0, 1] for the PNG.
          utils.save_img_u8(rendering['normals'] / 2. + 0.5,
                            path_fn(f'normals_{idx:03d}.png'))
          utils.save_img_f32(rendering['distance_mean'],
                             path_fn(f'distance_mean_{idx:03d}.tiff'))
          utils.save_img_f32(rendering['distance_median'],
                             path_fn(f'distance_median_{idx:03d}.tiff'))
          utils.save_img_f32(rendering['acc'], path_fn(f'acc_{idx:03d}.tiff'))

    # TensorBoard logging (daemon mode only; summary_writer is not created
    # when eval_only_once is set).
    if (not config.eval_only_once) and (jax.host_id() == 0):
      for name in list(metrics[0].keys()):
        summary_writer.scalar(name, np.mean([m[name] for m in metrics]), step)
      for i, r, b in showcases:
        for k, v in vis.visualize_suite(r, b['rays'], config).items():
          summary_writer.image(f'pred_{k}_{i}', v, step)
        if not config.render_path:
          summary_writer.image(f'target_{i}', b['rgb'], step)
    # Write one space-separated file per metric with the per-image values.
    if (config.eval_save_output and (not config.render_path) and
        (jax.host_id() == 0)):
      for name in list(metrics[0].keys()):
        with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f:
          f.write(' '.join([str(m[name]) for m in metrics]))
    if config.eval_only_once:
      break
    if int(step) >= config.max_steps:
      break
    last_step = step
224

225

226
if __name__ == '__main__':
  # absl.app.run parses the flags registered above, then invokes main().
  app.run(main)
228

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.