google-research

Форк
0
210 строк · 7.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
r"""Remove flares from RGB images.
17

18
This script exercises the following paper on RGB images:
19
  Yicheng Wu, Qiurui He, Tianfan Xue, Rahul Garg, Jiawen Chen, Ashok
20
  Veeraraghavan, and Jonathan T. Barron. How to train neural networks for flare
21
  removal. ICCV, 2021.
22

23
Input images:
24

25
- Images larger than 512 x 512 will be center-cropped to 512 x 512 before being
26
  passed to the model.
27

28
- Images larger than 2048 x 2048 will be center-cropped to 2048 x 2048 first.
29
  Next, they will be downsampled to 512 x 512 and passed into the model. The
30
  inferred flare-free images will be upsampled back to 2048 x 2048. (Section 6.4
31
  of the paper.)
32

33
- Images smaller than 512 x 512 are not supported.
34

35
Output images:
36

37
- By default, output images will be written to separate directories:
38
  - Preprocessed input
39
  - Inferred scene
40
  - Inferred flare
41
  - Inferred scene with blended light source (Section 5.2 of the paper)
42

43
- Alternatively, use `--separate_out_dirs=0` to write output images to the same
44
  directory as the input. The output images will have different suffixes.
45
"""
46

47
import os.path
48
from typing import Optional
49

50
from absl import app
51
from absl import flags
52
import tensorflow as tf
53
import tqdm
54

55
from flare_removal.python import models
56
from flare_removal.python import utils
57

58
FLAGS = flags.FLAGS

# No default checkpoint is bundled; the caller must supply --ckpt.
_DEFAULT_CKPT = None
flags.DEFINE_string(
    'ckpt', _DEFAULT_CKPT,
    'Location of the model checkpoint. May be a SavedModel dir, in which case '
    'the model architecture & weights are both loaded, and "--model" is '
    'ignored. May also be a TF checkpoint path, in which case only the latest '
    'model weights are loaded (this is much faster), and "--model" is '
    'required. To load a specific checkpoint, use the checkpoint prefix '
    'instead of the checkpoint directory for this argument.')
flags.DEFINE_string(
    'model', None,
    'Only required when "--ckpt" points to a TF checkpoint or checkpoint dir. '
    'Must be one of "unet" or "can".')
flags.DEFINE_integer(
    'batch_size', 1,
    'Number of images in each batch. Some networks (e.g., the rain removal '
    'network) can only accept predefined batch sizes.')
flags.DEFINE_string('input_dir', None,
                    'The directory contains all input images.')
# If --out_dir is unset, main() writes under <input_dir>/model_output.
flags.DEFINE_string('out_dir', None, 'Output directory.')
flags.DEFINE_boolean(
    'separate_out_dirs', True,
    'Whether the results are saved in separate folders under different names '
    '(True), or the same folder under different names (False).')
84

85

86
def center_crop(image, width, height):
  """Crops the central `height` x `width` window out of an HWC image.

  Args:
    image: A 3-D array or tensor in (height, width, channels) layout.
    width: Output width, in pixels.
    height: Output height, in pixels.

  Returns:
    The centered crop, same type as `image`.

  Raises:
    ValueError: If the requested crop exceeds the image in either dimension.
  """
  in_height, in_width, _ = image.shape
  # Split the excess evenly between the two sides (extra pixel goes right/bottom).
  left = (in_width - width) // 2
  top = (in_height - height) // 2
  if left < 0 or top < 0:
    raise ValueError('The specified output size is bigger than the image size.')
  return image[top:(top + height), left:(left + width), :]
94

95

96
def write_outputs_same_dir(out_dir,
                           name_prefix,
                           input_image = None,
                           pred_scene = None,
                           pred_flare = None,
                           pred_blend = None):
  """Saves all provided outputs into one directory, distinguished by suffix.

  Any output that is `None` is simply skipped.
  """
  if not tf.io.gfile.isdir(out_dir):
    raise ValueError(f'{out_dir} is not a directory.')
  path_prefix = os.path.join(out_dir, name_prefix)
  # One (image, filename suffix) pair per output type.
  suffixed_outputs = ((input_image, '_input.png'),
                      (pred_scene, '_output.png'),
                      (pred_flare, '_output_flare.png'),
                      (pred_blend, '_output_blend.png'))
  for image, suffix in suffixed_outputs:
    if image is not None:
      utils.write_image(image, path_prefix + suffix)
114

115

116
def write_outputs_separate_dir(out_dir,
                               file_name,
                               input_image = None,
                               pred_scene = None,
                               pred_flare = None,
                               pred_blend = None):
  """Saves each provided output under its own subdirectory of `out_dir`.

  Any output that is `None` is simply skipped.
  """
  if not tf.io.gfile.isdir(out_dir):
    raise ValueError(f'{out_dir} is not a directory.')
  # One (subdirectory, image) pair per output type.
  outputs_by_subdir = (('input', input_image),
                       ('output', pred_scene),
                       ('output_flare', pred_flare),
                       ('output_blend', pred_blend))
  for subdir, image in outputs_by_subdir:
    if image is not None:
      utils.write_image(image, os.path.join(out_dir, subdir, file_name))
135

136

137
def process_one_image(model, image_path, out_dir, separate_out_dirs):
  """Runs flare-removal inference on one image file and saves the results.

  Args:
    model: A callable mapping a batch of 512x512 RGB images in [0, 1] to the
      predicted flare-free scenes.
    image_path: Path to the input image on disk.
    out_dir: Directory where outputs are written.
    separate_out_dirs: If True, outputs go to per-type subdirectories of
      `out_dir`; otherwise they share `out_dir` with different name suffixes.
  """
  with tf.io.gfile.GFile(image_path, 'rb') as f:
    encoded = f.read()
  # Keep only RGB channels (drops alpha if present).
  image_u8 = tf.image.decode_image(encoded)[..., :3]
  image_f32 = tf.image.convert_image_dtype(image_u8, tf.float32, saturate=True)
  height, width, _ = image_f32.shape

  if min(height, width) >= 2048:
    # High-resolution path (Section 6.4 of the paper): run the model at
    # 512x512, then upsample the inferred flare and subtract it at full size.
    input_image = center_crop(image_f32, 2048, 2048)[None, ...]
    low_res = tf.image.resize(
        input_image, [512, 512], method=tf.image.ResizeMethod.AREA)
    scene_low = tf.clip_by_value(model(low_res), 0.0, 1.0)
    flare_low = utils.remove_flare(low_res, scene_low)
    pred_flare = tf.image.resize(flare_low, [2048, 2048], antialias=True)
    pred_scene = utils.remove_flare(input_image, pred_flare)
  else:
    input_image = center_crop(image_f32, 512, 512)[None, ...]
    # Replicate the image to fill the batch: some networks only accept a
    # fixed, predefined batch size.
    input_image = tf.concat([input_image] * FLAGS.batch_size, axis=0)
    pred_scene = tf.clip_by_value(model(input_image), 0.0, 1.0)
    pred_flare = utils.remove_flare(input_image, pred_scene)
  # Re-blend the (clipped) light source into the scene (Section 5.2).
  pred_blend = utils.blend_light_source(input_image[0, ...], pred_scene[0, ...])

  stem = os.path.splitext(os.path.basename(image_path))[0]
  if separate_out_dirs:
    writer, out_name = write_outputs_separate_dir, stem + '.png'
  else:
    writer, out_name = write_outputs_same_dir, stem
  writer(
      out_dir,
      out_name,
      input_image=input_image[0, ...],
      pred_scene=pred_scene[0, ...],
      pred_flare=pred_flare[0, ...],
      pred_blend=pred_blend)
177

178

179
def load_model(path,
               model_type = None,
               batch_size = None):
  """Restores a model, trying SavedModel first and a TF checkpoint second.

  Args:
    path: A SavedModel directory, a checkpoint directory, or a checkpoint
      prefix.
    model_type: Architecture name passed to `models.build_model`; only used
      on the checkpoint path.
    batch_size: Batch size passed to `models.build_model`; only used on the
      checkpoint path.

  Returns:
    The restored Keras model.
  """
  try:
    return tf.keras.models.load_model(path)
  except (ImportError, IOError):
    print(f'Didn\'t find SavedModel at "{path}". '
          'Trying latest checkpoint next.')
  # Fall back to restoring weights only into a freshly built architecture.
  model = models.build_model(model_type, batch_size)
  checkpoint = tf.train.Checkpoint(model=model)
  # `path` may be a directory (use its latest checkpoint) or a prefix.
  restore_path = tf.train.latest_checkpoint(path) or path
  checkpoint.restore(restore_path).assert_existing_objects_matched()
  return model
193

194

195
def main(_):
  """Loads the model and runs flare removal over every image in --input_dir."""
  output_dir = FLAGS.out_dir or os.path.join(FLAGS.input_dir, 'model_output')
  tf.io.gfile.makedirs(output_dir)

  model = load_model(FLAGS.ckpt, FLAGS.model, FLAGS.batch_size)

  # The '*.*g' pattern matches both png and jpg extensions.
  pattern = os.path.join(FLAGS.input_dir, '*.*g')
  for input_file in tqdm.tqdm(sorted(tf.io.gfile.glob(pattern))):
    process_one_image(model, input_file, output_dir, FLAGS.separate_out_dirs)

  print('done')


if __name__ == '__main__':
  app.run(main)
211

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.