# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Training script for flare removal.

This script trains a model that outputs a flare-free image from a flare-polluted
image.
"""
import os.path
import time

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

from flare_removal.python import data_provider
from flare_removal.python import losses
from flare_removal.python import models
from flare_removal.python import synthesis


flags.DEFINE_string(
    'train_dir', '/tmp/train',
    'Directory where training checkpoints and summaries are written.')
flags.DEFINE_string('scene_dir', None,
                    'Full path to the directory containing scene images.')
flags.DEFINE_string('flare_dir', None,
                    'Full path to the directory containing flare images.')
flags.DEFINE_enum(
    'data_source', 'jpg', ['tfrecord', 'jpg'],
    'Source of training data. Use "jpg" for individual image files, such as '
    'JPG and PNG images. Use "tfrecord" for pre-baked sharded TFRecord files.')
flags.DEFINE_string('model', 'unet', 'The name of the model to train.')
flags.DEFINE_string('loss', 'percep', 'The name of the training loss.')
flags.DEFINE_integer('batch_size', 2, 'Training batch size.')
flags.DEFINE_integer('epochs', 100, 'Number of training epochs.')
flags.DEFINE_integer(
    'ckpt_period', 1000,
    'Write model checkpoint and summary to disk every ckpt_period steps.')
flags.DEFINE_float('learning_rate', 1e-4, 'Initial learning rate.')
flags.DEFINE_float(
    'scene_noise', 0.01,
    'Gaussian noise sigma added to the scene in synthetic data. The actual '
    'Gaussian variance for each image is drawn from a Chi-squared '
    'distribution scaled by scene_noise.')
flags.DEFINE_float(
    'flare_max_gain', 10.0,
    'Max digital gain applied to the flare patterns during synthesis.')
flags.DEFINE_float('flare_loss_weight', 1.0,
                   'Weight applied to the flare loss (the scene loss has '
                   'weight 1).')
flags.DEFINE_integer('training_res', 512, 'Training resolution.')
FLAGS = flags.FLAGS


@tf.function
def train_step(model, scene, flare, loss_fn, optimizer):
  """Executes one step of gradient descent."""
  with tf.GradientTape() as tape:
    loss_value, summary = synthesis.run_step(
        scene,
        flare,
        model,
        loss_fn,
        noise=FLAGS.scene_noise,
        flare_max_gain=FLAGS.flare_max_gain,
        flare_loss_weight=FLAGS.flare_loss_weight,
        training_res=FLAGS.training_res)
  grads = tape.gradient(loss_value, model.trainable_weights)
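  # Clip gradients to a maximum global norm of 5 to guard against occasional
  # exploding gradients and keep training stable.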
  grads, _ = tf.clip_by_global_norm(grads, 5.0)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  return loss_value, summary


def main(_):
  train_dir = FLAGS.train_dir
  assert train_dir, 'Flag --train_dir must not be empty.'
  summary_dir = os.path.join(train_dir, 'summary')
  model_dir = os.path.join(train_dir, 'model')

  # Load data.
  scenes = data_provider.get_scene_dataset(
      FLAGS.scene_dir, FLAGS.data_source, FLAGS.batch_size, repeat=FLAGS.epochs)
  flares = data_provider.get_flare_dataset(FLAGS.flare_dir, FLAGS.data_source,
                                           FLAGS.batch_size)

  # Make a model.
  model = models.build_model(FLAGS.model, FLAGS.batch_size)
  optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
  loss_fn = losses.get_loss(FLAGS.loss)

  # Model checkpoints. Checkpoints contain only the model weights, not the
  # model architecture. We use checkpoints to keep track of training progress.
  ckpt = tf.train.Checkpoint(
      step=tf.Variable(0, dtype=tf.int64),
      training_finished=tf.Variable(False, dtype=tf.bool),
      optimizer=optimizer,
      model=model)
  ckpt_mgr = tf.train.CheckpointManager(
      ckpt, train_dir, max_to_keep=3, keep_checkpoint_every_n_hours=3)

  # Restore the latest checkpoint (model weights), if any. This is helpful if
  # the training job is restarted after an unexpected termination.
  latest_ckpt = ckpt_mgr.latest_checkpoint
  restore_status = None
  if latest_ckpt is not None:
    # Note that due to lazy initialization, not all checkpointed variables can
    # be restored at this point. Hence 'expect_partial()'. Full restoration is
    # checked in the first training step below.
    restore_status = ckpt.restore(latest_ckpt).expect_partial()
    logging.info('Restoring latest checkpoint @ step %d from: %s', ckpt.step,
                 latest_ckpt)
  else:
    logging.info('No previous checkpoints found. Starting afresh.')

  summary_writer = tf.summary.create_file_writer(summary_dir)

  step_time_metric = tf.keras.metrics.Mean('step_time')
  step_start_time = time.time()
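  # tf.data.Dataset.zip stops once the shorter of the two inputs is exhausted;
  # since the scene dataset repeats FLAGS.epochs times, it bounds this loop.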
  for scene, flare in tf.data.Dataset.zip((scenes, flares)):
    # Perform one training step.
    loss_value, summary = train_step(model, scene, flare, loss_fn, optimizer)

    # By this point, all lazily initialized variables should have been
    # restored by the checkpoint if one was available.
    if restore_status is not None:
      restore_status.assert_consumed()
      restore_status = None

    # Write training summaries and checkpoints to disk.
    ckpt.step.assign_add(1)
    if ckpt.step % FLAGS.ckpt_period == 0:
      # Write model checkpoint to disk.
      ckpt_mgr.save()

      # Also save the full model using the latest weights. To restore previous
      # weights, you'd have to load the model and restore a previously saved
      # checkpoint.
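      # For example (a sketch, not part of this script; the checkpoint path
      # is a placeholder):
      #   restored = tf.keras.models.load_model(model_dir)
      #   tf.train.Checkpoint(model=restored).restore(
      #       os.path.join(train_dir, 'ckpt-100')).expect_partial()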
      tf.keras.models.save_model(model, model_dir, save_format='tf')

      # Write summaries to disk, which can be visualized with TensorBoard.
      with summary_writer.as_default():
        tf.summary.image('prediction', summary, max_outputs=1, step=ckpt.step)
        tf.summary.scalar('loss', loss_value, step=ckpt.step)
        tf.summary.scalar(
            'step_time', step_time_metric.result(), step=ckpt.step)
        step_time_metric.reset_state()

    # Record the elapsed time in this training step.
    step_end_time = time.time()
    step_time_metric.update_state(step_end_time - step_start_time)
    step_start_time = step_end_time

  ckpt.training_finished.assign(True)
  ckpt_mgr.save()
  logging.info('Done!')


if __name__ == '__main__':
  app.run(main)
