google-research
172 lines · 6.6 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Training script for flare removal.

This script trains a model that outputs a flare-free image from a flare-polluted
image.
"""
import os.path
import time

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

from flare_removal.python import data_provider
from flare_removal.python import losses
from flare_removal.python import models
from flare_removal.python import synthesis
33
34
# Paths and data source.
flags.DEFINE_string(
    'train_dir', '/tmp/train',
    'Directory where training checkpoints and summaries are written.')
flags.DEFINE_string('scene_dir', None,
                    'Full path to the directory containing scene images.')
flags.DEFINE_string('flare_dir', None,
                    'Full path to the directory containing flare images.')
flags.DEFINE_enum(
    'data_source', 'jpg', ['tfrecord', 'jpg'],
    'Source of training data. Use "jpg" for individual image files, such as '
    'JPG and PNG images. Use "tfrecord" for pre-baked sharded TFRecord files.')

# Model and optimization.
flags.DEFINE_string('model', 'unet', 'the name of the training model')
flags.DEFINE_string('loss', 'percep', 'the name of the loss for training')
flags.DEFINE_integer('batch_size', 2, 'Training batch size.')
flags.DEFINE_integer('epochs', 100, 'Training config: epochs.')
flags.DEFINE_integer(
    'ckpt_period', 1000,
    'Write model checkpoint and summary to disk every ckpt_period steps.')
flags.DEFINE_float('learning_rate', 1e-4, 'Initial learning rate.')

# Data synthesis parameters.
flags.DEFINE_float(
    'scene_noise', 0.01,
    'Gaussian noise sigma added in the scene in synthetic data. The actual '
    'Gaussian variance for each image will be drawn from a Chi-squared '
    'distribution with a scale of scene_noise.')
flags.DEFINE_float(
    'flare_max_gain', 10.0,
    'Max digital gain applied to the flare patterns during synthesis.')
flags.DEFINE_float('flare_loss_weight', 1.0,
                   'Weight added on the flare loss (scene loss is 1).')
flags.DEFINE_integer('training_res', 512, 'Training resolution.')

FLAGS = flags.FLAGS
66
67
@tf.function
def train_step(model, scene, flare, loss_fn, optimizer):
  """Executes one step of gradient descent.

  Synthesizes a flare-polluted input from `scene` and `flare`, runs the model,
  computes the loss, and applies one clipped-gradient optimizer update.

  Args:
    model: The Keras model being trained.
    scene: Batch of flare-free scene images.
    flare: Batch of flare-only images.
    loss_fn: Loss function comparing prediction and ground truth.
    optimizer: A `tf.keras.optimizers.Optimizer` instance.

  Returns:
    A `(loss_value, summary)` tuple, where `summary` is an image tensor
    suitable for `tf.summary.image`.
  """
  with tf.GradientTape() as tape:
    loss_value, summary = synthesis.run_step(
        scene,
        flare,
        model,
        loss_fn,
        noise=FLAGS.scene_noise,
        flare_max_gain=FLAGS.flare_max_gain,
        flare_loss_weight=FLAGS.flare_loss_weight,
        training_res=FLAGS.training_res)
  gradients = tape.gradient(loss_value, model.trainable_weights)
  # Clip to a global norm of 5.0 to guard against exploding gradients.
  clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
  optimizer.apply_gradients(zip(clipped_gradients, model.trainable_weights))
  return loss_value, summary
85
86
def main(_):
  """Runs the flare-removal training loop.

  Builds the datasets, model, optimizer, and loss from flags; restores the
  latest checkpoint if one exists; then trains, periodically writing
  checkpoints, a saved model, and TensorBoard summaries.

  Raises:
    ValueError: If the --train_dir flag is empty.
  """
  train_dir = FLAGS.train_dir
  # Explicit check instead of `assert`: asserts are stripped under `python -O`,
  # which would silently allow an empty train_dir.
  if not train_dir:
    raise ValueError('Flag --train_dir must not be empty.')
  summary_dir = os.path.join(train_dir, 'summary')
  model_dir = os.path.join(train_dir, 'model')

  # Load data. Scenes repeat for `epochs` passes; flares repeat indefinitely,
  # so the zipped dataset below ends when the scene dataset is exhausted.
  scenes = data_provider.get_scene_dataset(
      FLAGS.scene_dir, FLAGS.data_source, FLAGS.batch_size, repeat=FLAGS.epochs)
  flares = data_provider.get_flare_dataset(FLAGS.flare_dir, FLAGS.data_source,
                                           FLAGS.batch_size)

  # Make a model.
  model = models.build_model(FLAGS.model, FLAGS.batch_size)
  optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
  loss_fn = losses.get_loss(FLAGS.loss)

  # Model checkpoints. Checkpoints don't contain model architecture, but
  # weights only. We use checkpoints to keep track of the training progress.
  ckpt = tf.train.Checkpoint(
      step=tf.Variable(0, dtype=tf.int64),
      training_finished=tf.Variable(False, dtype=tf.bool),
      optimizer=optimizer,
      model=model)
  ckpt_mgr = tf.train.CheckpointManager(
      ckpt, train_dir, max_to_keep=3, keep_checkpoint_every_n_hours=3)

  # Restore the latest checkpoint (model weights), if any. This is helpful if
  # the training job gets restarted from an unexpected termination.
  latest_ckpt = ckpt_mgr.latest_checkpoint
  restore_status = None
  if latest_ckpt is not None:
    # Note that due to lazy initialization, not all checkpointed variables can
    # be restored at this point. Hence 'expect_partial()'. Full restoration is
    # checked in the first training step below.
    restore_status = ckpt.restore(latest_ckpt).expect_partial()
    logging.info('Restoring latest checkpoint @ step %d from: %s', ckpt.step,
                 latest_ckpt)
  else:
    logging.info('Previous checkpoints not found. Starting afresh.')

  summary_writer = tf.summary.create_file_writer(summary_dir)

  step_time_metric = tf.keras.metrics.Mean('step_time')
  step_start_time = time.time()
  for scene, flare in tf.data.Dataset.zip((scenes, flares)):
    # Perform one training step.
    loss_value, summary = train_step(model, scene, flare, loss_fn, optimizer)

    # By this point, all lazily initialized variables should have been
    # restored by the checkpoint if one was available.
    if restore_status is not None:
      restore_status.assert_consumed()
      restore_status = None

    # Write training summaries and checkpoints to disk.
    ckpt.step.assign_add(1)
    if ckpt.step % FLAGS.ckpt_period == 0:
      # Write model checkpoint to disk.
      ckpt_mgr.save()

      # Also save the full model using the latest weights. To restore previous
      # weights, you'd have to load the model and restore a previously saved
      # checkpoint.
      tf.keras.models.save_model(model, model_dir, save_format='tf')

      # Write summaries to disk, which can be visualized with TensorBoard.
      with summary_writer.as_default():
        tf.summary.image('prediction', summary, max_outputs=1, step=ckpt.step)
        tf.summary.scalar('loss', loss_value, step=ckpt.step)
        tf.summary.scalar(
            'step_time', step_time_metric.result(), step=ckpt.step)
      step_time_metric.reset_state()

    # Record elapsed time in this training step.
    step_end_time = time.time()
    step_time_metric.update_state(step_end_time - step_start_time)
    step_start_time = step_end_time

  # Mark the run finished and persist the final state.
  ckpt.training_finished.assign(True)
  ckpt_mgr.save()
  logging.info('Done!')
169
170
if __name__ == '__main__':
  app.run(main)
173