google-research
130 строк · 5.0 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Evaluation script for flare removal."""
17
18import os.path
19
20from absl import app
21from absl import flags
22from absl import logging
23import tensorflow as tf
24
25from flare_removal.python import data_provider
26from flare_removal.python import losses
27from flare_removal.python import models
28from flare_removal.python import synthesis
29
30
# Command-line flags. Most of them mirror the training configuration so that
# the evaluated model and the synthetic data match what was trained.
flags.DEFINE_string(
    'eval_dir', '/tmp/eval',
    'Directory where evaluation summaries and outputs are written.')
flags.DEFINE_string(
    'train_dir', '/tmp/train',
    'Directory where training checkpoints are written. This script will '
    'repeatedly poll and evaluate the latest checkpoint.')
flags.DEFINE_string('scene_dir', None,
                    'Full path to the directory containing scene images.')
flags.DEFINE_string('flare_dir', None,
                    'Full path to the directory containing flare images.')
flags.DEFINE_enum(
    'data_source', 'jpg', ['tfrecord', 'jpg'],
    'Source of evaluation data. Use "jpg" for individual image files, such as '
    'JPG and PNG images. Use "tfrecord" for pre-baked sharded TFRecord files.')
flags.DEFINE_string('model', 'unet', 'the name of the training model')
flags.DEFINE_string('loss', 'percep', 'the name of the loss for training')
# A batch must contain at least one example.
flags.DEFINE_integer('batch_size', 2, 'Evaluation batch size.', lower_bound=1)
flags.DEFINE_float(
    'learning_rate', 1e-4,
    'Unused placeholder. The flag has to be defined to satisfy parameter sweep '
    'requirements.')
# A noise sigma (standard deviation) cannot be negative.
flags.DEFINE_float(
    'scene_noise', 0.01,
    'Gaussian noise sigma added in the scene in synthetic data. The actual '
    'Gaussian variance for each image will be drawn from a Chi-squared '
    'distribution with a scale of scene_noise.',
    lower_bound=0.0)
flags.DEFINE_float(
    'flare_max_gain', 10.0,
    'Max digital gain applied to the flare patterns during synthesis.')
flags.DEFINE_float('flare_loss_weight', 1.0,
                   'Weight added on the flare loss (scene loss is 1).')
# An image resolution must be a positive number of pixels.
flags.DEFINE_integer('training_res', 512,
                     'Image resolution at which the network is trained.',
                     lower_bound=1)
FLAGS = flags.FLAGS
66
67
def _restore_checkpoint(ckpt, ckpt_path):
  """Restores `ckpt` from `ckpt_path`, tolerating extra checkpoint entries.

  Every object tracked by `ckpt` must be matched by a value in the checkpoint.
  The checkpoint may contain additional values with no matching object (for
  example, optimizer state, which evaluation does not need).

  Args:
    ckpt: The `tf.train.Checkpoint` to restore into.
    ckpt_path: Path prefix of the checkpoint to restore.

  Returns:
    True if the restore succeeded. False if the checkpoint was not found or did
    not match the tracked objects; the failure is logged with its traceback.
  """
  try:
    status = ckpt.restore(ckpt_path)
    status.assert_existing_objects_matched()
  except (tf.errors.NotFoundError, AssertionError):
    logging.exception('Failed to restore checkpoint from %s.', ckpt_path)
    return False
  # Silence warnings about checkpoint values that have no matching object.
  status.expect_partial()
  return True


def _evaluate_checkpoint(model, loss_fn, scenes, flares, summary_writer, step):
  """Runs one pass over the evaluation data and writes one set of summaries.

  The reported loss is the mean of the per-batch losses. Only the first batch's
  prediction image is written. Both are written under `step`, so each
  checkpoint produces exactly one scalar and one image.

  Args:
    model: The model to evaluate (already restored from a checkpoint).
    loss_fn: Loss function forwarded to `synthesis.run_step`.
    scenes: Batched scene dataset. It is iterated once per call, so it should
      be finite.
    flares: Batched flare dataset, zipped with `scenes`. Iteration stops when
      the shorter of the two is exhausted.
    summary_writer: `tf.summary` writer for the evaluation summaries.
    step: Integer training step the summaries are recorded under.

  Returns:
    The number of batches evaluated (0 if the zipped dataset was empty).
  """
  total_loss = 0.0
  num_batches = 0
  first_prediction = None
  for scene, flare in tf.data.Dataset.zip((scenes, flares)):
    loss_value, summary = synthesis.run_step(
        scene,
        flare,
        model,
        loss_fn,
        noise=FLAGS.scene_noise,
        flare_max_gain=FLAGS.flare_max_gain,
        flare_loss_weight=FLAGS.flare_loss_weight,
        training_res=FLAGS.training_res)
    # reduce_mean is a no-op for a scalar loss and averages a per-example one.
    total_loss += tf.reduce_mean(loss_value)
    num_batches += 1
    if first_prediction is None:
      first_prediction = summary

  if num_batches == 0:
    logging.warning('Evaluation dataset is empty; nothing written for step %d.',
                    step)
    return 0

  with summary_writer.as_default():
    tf.summary.image('prediction', first_prediction, max_outputs=1, step=step)
    tf.summary.scalar('loss', total_loss / num_batches, step=step)
  # Make the results visible right away instead of waiting for the writer's
  # periodic flush while the iterator blocks on the next checkpoint.
  summary_writer.flush()
  return num_batches


def main(_):
  """Polls `--train_dir` for checkpoints and evaluates each one.

  For every new checkpoint, this restores the model, runs one pass over the
  scene/flare evaluation data, and writes a mean loss and a sample prediction
  image to `<eval_dir>/summary`.

  Raises:
    app.UsageError: If `--eval_dir` or `--train_dir` is empty.
  """
  # Use explicit raises rather than `assert`, which `python -O` strips.
  eval_dir = FLAGS.eval_dir
  if not eval_dir:
    raise app.UsageError('Flag --eval_dir must not be empty.')
  train_dir = FLAGS.train_dir
  if not train_dir:
    raise app.UsageError('Flag --train_dir must not be empty.')
  summary_dir = os.path.join(eval_dir, 'summary')

  # Load data. The scene dataset is built with repeat=0 because it is
  # re-iterated once for every checkpoint.
  scenes = data_provider.get_scene_dataset(
      FLAGS.scene_dir, FLAGS.data_source, FLAGS.batch_size, repeat=0)
  flares = data_provider.get_flare_dataset(FLAGS.flare_dir, FLAGS.data_source,
                                           FLAGS.batch_size)

  # Make a model.
  model = models.build_model(FLAGS.model, FLAGS.batch_size)
  loss_fn = losses.get_loss(FLAGS.loss)

  ckpt = tf.train.Checkpoint(
      step=tf.Variable(0, dtype=tf.int64),
      training_finished=tf.Variable(False, dtype=tf.bool),
      model=model)

  summary_writer = tf.summary.create_file_writer(summary_dir)

  # checkpoints_iterator yields each new checkpoint in `train_dir`. After
  # `timeout` seconds without a new checkpoint it calls `timeout_fn`, and it
  # stops iterating when that returns True. Here, that happens once the most
  # recently restored checkpoint has `training_finished` set.
  for ckpt_path in tf.train.checkpoints_iterator(
      train_dir,
      timeout=30,
      timeout_fn=lambda: bool(ckpt.training_finished.numpy())):
    if not _restore_checkpoint(ckpt, ckpt_path):
      continue
    step = int(ckpt.step.numpy())
    logging.info('Restored checkpoint %s @ step %d.', ckpt_path, step)
    _evaluate_checkpoint(model, loss_fn, scenes, flares, summary_writer, step)

  summary_writer.close()
  logging.info('Done!')
127
128
if __name__ == '__main__':
  # absl's app.run parses the command-line flags and then calls main() with
  # the remaining (unparsed) argv.
  app.run(main=main)
131