1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Evaluation script for flare removal."""
17

18
import os.path
19

20
from absl import app
21
from absl import flags
22
from absl import logging
23
import tensorflow as tf
24

25
from flare_removal.python import data_provider
26
from flare_removal.python import losses
27
from flare_removal.python import models
28
from flare_removal.python import synthesis
29

30

31
# Command-line flags. Data-source and synthesis flags intentionally mirror the
# training script so one parameter-sweep configuration can drive both jobs.

# Output / input directories.
flags.DEFINE_string(
    'eval_dir', '/tmp/eval',
    'Directory where evaluation summaries and outputs are written.')
flags.DEFINE_string(
    'train_dir', '/tmp/train',
    'Directory where training checkpoints are written. This script will '
    'repeatedly poll and evaluate the latest checkpoint.')
flags.DEFINE_string('scene_dir', None,
                    'Full path to the directory containing scene images.')
flags.DEFINE_string('flare_dir', None,
                    'Full path to the directory containing flare images.')
flags.DEFINE_enum(
    'data_source', 'jpg', ['tfrecord', 'jpg'],
    'Source of training data. Use "jpg" for individual image files, such as '
    'JPG and PNG images. Use "tfrecord" for pre-baked sharded TFRecord files.')

# Model and loss selection (resolved by models.build_model / losses.get_loss).
flags.DEFINE_string('model', 'unet', 'the name of the training model')
flags.DEFINE_string('loss', 'percep', 'the name of the loss for training')
flags.DEFINE_integer('batch_size', 2, 'Evaluation batch size.')
flags.DEFINE_float(
    'learning_rate', 1e-4,
    'Unused placeholder. The flag has to be defined to satisfy parameter sweep '
    'requirements.')

# Synthetic-data parameters forwarded to synthesis.run_step.
flags.DEFINE_float(
    'scene_noise', 0.01,
    'Gaussian noise sigma added in the scene in synthetic data. The actual '
    'Gaussian variance for each image will be drawn from a Chi-squared '
    'distribution with a scale of scene_noise.')
flags.DEFINE_float(
    'flare_max_gain', 10.0,
    'Max digital gain applied to the flare patterns during synthesis.')
flags.DEFINE_float('flare_loss_weight', 1.0,
                   'Weight added on the flare loss (scene loss is 1).')
flags.DEFINE_integer('training_res', 512,
                     'Image resolution at which the network is trained.')

FLAGS = flags.FLAGS
66

67

68
def main(_):
  """Polls training checkpoints and evaluates each one on the full eval set.

  For every checkpoint that appears in --train_dir, runs the model over the
  entire (scene, flare) evaluation dataset and writes to --eval_dir/summary:
    * 'loss': the loss averaged over all evaluation batches.
    * 'prediction': a visualization image from the last evaluated batch.
  Polling stops once training signals completion via the checkpoint's
  `training_finished` variable.
  """
  eval_dir = FLAGS.eval_dir
  assert eval_dir, 'Flag --eval_dir must not be empty.'
  train_dir = FLAGS.train_dir
  assert train_dir, 'Flag --train_dir must not be empty.'
  summary_dir = os.path.join(eval_dir, 'summary')

  # Load data. `repeat=0` makes the scene dataset finite so one pass covers
  # the whole evaluation set exactly once per checkpoint.
  scenes = data_provider.get_scene_dataset(
      FLAGS.scene_dir, FLAGS.data_source, FLAGS.batch_size, repeat=0)
  flares = data_provider.get_flare_dataset(FLAGS.flare_dir, FLAGS.data_source,
                                           FLAGS.batch_size)

  # Make a model.
  model = models.build_model(FLAGS.model, FLAGS.batch_size)
  loss_fn = losses.get_loss(FLAGS.loss)

  # `step` and `training_finished` are restored from the training job's
  # checkpoints; `training_finished` doubles as the polling stop condition.
  ckpt = tf.train.Checkpoint(
      step=tf.Variable(0, dtype=tf.int64),
      training_finished=tf.Variable(False, dtype=tf.bool),
      model=model)

  summary_writer = tf.summary.create_file_writer(summary_dir)

  # The checkpoints_iterator keeps polling the latest training checkpoints,
  # until:
  #   1) `timeout` seconds have passed waiting for a new checkpoint; and
  #   2) `timeout_fn` (in this case, the flag indicating the last training
  #      checkpoint) evaluates to true.
  for ckpt_path in tf.train.checkpoints_iterator(
      train_dir, timeout=30, timeout_fn=lambda: ckpt.training_finished):
    try:
      status = ckpt.restore(ckpt_path)
      # Assert that all model variables are restored, but allow extra unmatched
      # variables in the checkpoint. (For example, optimizer states are not
      # needed for evaluation.)
      status.assert_existing_objects_matched()
      # Suppress warnings about unmatched variables.
      status.expect_partial()
      logging.info('Restored checkpoint %s @ step %d.', ckpt_path, ckpt.step)
    except (tf.errors.NotFoundError, AssertionError):
      logging.exception('Failed to restore checkpoint from %s.', ckpt_path)
      continue

    # Average the loss over the entire evaluation set. (Previously only the
    # last batch's loss was logged, which misrepresented the eval metric.)
    mean_loss = tf.keras.metrics.Mean()
    summary = None
    for scene, flare in tf.data.Dataset.zip((scenes, flares)):
      loss_value, summary = synthesis.run_step(
          scene,
          flare,
          model,
          loss_fn,
          noise=FLAGS.scene_noise,
          flare_max_gain=FLAGS.flare_max_gain,
          flare_loss_weight=FLAGS.flare_loss_weight,
          training_res=FLAGS.training_res)
      mean_loss.update_state(loss_value)

    # Guard against an empty evaluation dataset, in which case there is
    # nothing meaningful to log for this checkpoint.
    if summary is not None:
      with summary_writer.as_default():
        tf.summary.image('prediction', summary, max_outputs=1, step=ckpt.step)
        tf.summary.scalar('loss', mean_loss.result(), step=ckpt.step)
125

126
  logging.info('Done!')
127

128

129
if __name__ == '__main__':
  # app.run parses absl flags before dispatching to main().
  app.run(main)
131
