# google-research — flare removal inference script (210 lines, 7.9 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Remove flares from RGB images.

This script exercises the following paper on RGB images:
Yicheng Wu, Qiurui He, Tianfan Xue, Rahul Garg, Jiawen Chen, Ashok
Veeraraghavan, and Jonathan T. Barron. How to train neural networks for flare
removal. ICCV, 2021.

Input images:

- Images larger than 512 x 512 will be center-cropped to 512 x 512 before
  being passed to the model.

- Images larger than 2048 x 2048 will be center-cropped to 2048 x 2048 first.
  Next, they will be downsampled to 512 x 512 and passed into the model. The
  inferred flare-free images will be upsampled back to 2048 x 2048.
  (Section 6.4 of the paper.)

- Images smaller than 512 x 512 are not supported.

Output images:

- By default, output images will be written to separate directories:
  - Preprocessed input
  - Inferred scene
  - Inferred flare
  - Inferred scene with blended light source (Section 5.2 of the paper)

- Alternatively, use `--separate_out_dirs=0` to write output images to the
  same directory as the input. The output images will have different suffixes.
"""
46
import os.path
from typing import Optional

from absl import app
from absl import flags
import tensorflow as tf
import tqdm

from flare_removal.python import models
from flare_removal.python import utils

FLAGS = flags.FLAGS

# Default checkpoint location; `None` forces the user to pass --ckpt.
_DEFAULT_CKPT = None
flags.DEFINE_string(
    'ckpt', _DEFAULT_CKPT,
    'Location of the model checkpoint. May be a SavedModel dir, in which case '
    'the model architecture & weights are both loaded, and "--model" is '
    'ignored. May also be a TF checkpoint path, in which case only the latest '
    'model weights are loaded (this is much faster), and "--model" is '
    'required. To load a specific checkpoint, use the checkpoint prefix '
    'instead of the checkpoint directory for this argument.')
flags.DEFINE_string(
    'model', None,
    'Only required when "--ckpt" points to a TF checkpoint or checkpoint dir. '
    'Must be one of "unet" or "can".')
flags.DEFINE_integer(
    'batch_size', 1,
    'Number of images in each batch. Some networks (e.g., the rain removal '
    'network) can only accept predefined batch sizes.')
flags.DEFINE_string('input_dir', None,
                    'The directory contains all input images.')
flags.DEFINE_string('out_dir', None, 'Output directory.')
flags.DEFINE_boolean(
    'separate_out_dirs', True,
    'Whether the results are saved in separate folders under different names '
    '(True), or the same folder under different names (False).')
85
def center_crop(image, width, height):
  """Crops the central `height` x `width` region out of an image.

  Args:
    image: An image array of shape [H, W, C]; anything that supports `.shape`
      and 2-D slicing (e.g., a NumPy array or TF tensor) works.
    width: Desired output width, in pixels.
    height: Desired output height, in pixels.

  Returns:
    The [height, width, C] center crop of `image`.

  Raises:
    ValueError: If the requested crop is larger than the image.
  """
  full_height, full_width, _ = image.shape
  # Leading margins on each axis; negative means the crop doesn't fit.
  left = (full_width - width) // 2
  top = (full_height - height) // 2
  if left < 0 or top < 0:
    raise ValueError('The specified output size is bigger than the image size.')
  return image[top:top + height, left:left + width, :]
95
def write_outputs_same_dir(out_dir,
                           name_prefix,
                           input_image=None,
                           pred_scene=None,
                           pred_flare=None,
                           pred_blend=None):
  """Writes various outputs to the same directory on disk.

  Each provided image is written as `<out_dir>/<name_prefix><suffix>.png`;
  arguments left as `None` are skipped.

  Args:
    out_dir: Existing output directory.
    name_prefix: File name stem shared by all outputs.
    input_image: Optional preprocessed input image.
    pred_scene: Optional inferred flare-free scene.
    pred_flare: Optional inferred flare.
    pred_blend: Optional scene with the light source blended back in.

  Raises:
    ValueError: If `out_dir` is not a directory.
  """
  if not tf.io.gfile.isdir(out_dir):
    raise ValueError(f'{out_dir} is not a directory.')
  path_prefix = os.path.join(out_dir, name_prefix)
  # (image, suffix) pairs; only non-None images are written.
  outputs = (
      (input_image, '_input.png'),
      (pred_scene, '_output.png'),
      (pred_flare, '_output_flare.png'),
      (pred_blend, '_output_blend.png'),
  )
  for image, suffix in outputs:
    if image is not None:
      utils.write_image(image, path_prefix + suffix)
115
def write_outputs_separate_dir(out_dir,
                               file_name,
                               input_image=None,
                               pred_scene=None,
                               pred_flare=None,
                               pred_blend=None):
  """Writes various outputs to separate subdirectories on disk.

  Each provided image is written as `<out_dir>/<subdir>/<file_name>`, where
  the subdirectory identifies the output type; arguments left as `None` are
  skipped.

  Args:
    out_dir: Existing output directory.
    file_name: File name (including extension) shared by all outputs.
    input_image: Optional preprocessed input image.
    pred_scene: Optional inferred flare-free scene.
    pred_flare: Optional inferred flare.
    pred_blend: Optional scene with the light source blended back in.

  Raises:
    ValueError: If `out_dir` is not a directory.
  """
  if not tf.io.gfile.isdir(out_dir):
    raise ValueError(f'{out_dir} is not a directory.')
  # (image, subdirectory) pairs; only non-None images are written.
  outputs = (
      (input_image, 'input'),
      (pred_scene, 'output'),
      (pred_flare, 'output_flare'),
      (pred_blend, 'output_blend'),
  )
  for image, subdir in outputs:
    if image is not None:
      utils.write_image(image, os.path.join(out_dir, subdir, file_name))
136
def process_one_image(model, image_path, out_dir, separate_out_dirs):
  """Reads one image and writes inference results to disk.

  Implements the two resolution paths from the module docstring: images with
  both sides >= 2048 px go through the crop-2048 / infer-at-512 / upsample
  path (Section 6.4 of the paper); everything else is center-cropped to
  512 x 512 and inferred directly.

  Args:
    model: Callable mapping a [B, 512, 512, 3] float32 batch to a same-shaped
      scene prediction (a Keras model loaded by `load_model`).
    image_path: Path to an image file readable by `tf.image.decode_image`.
    out_dir: Existing output directory.
    separate_out_dirs: If True, outputs go to per-type subdirectories of
      `out_dir`; otherwise to `out_dir` with per-type filename suffixes.
  """
  with tf.io.gfile.GFile(image_path, 'rb') as f:
    blob = f.read()
  # Keep only the first 3 channels (drops alpha, if present).
  input_u8 = tf.image.decode_image(blob)[Ellipsis, :3]
  input_f32 = tf.image.convert_image_dtype(input_u8, tf.float32, saturate=True)
  h, w, _ = input_f32.shape

  if min(h, w) >= 2048:
    # High-resolution path: infer the flare at 512 x 512, upsample the flare
    # back to 2048 x 2048, and subtract it from the full-resolution input.
    input_image = center_crop(input_f32, 2048, 2048)[None, Ellipsis]
    input_low = tf.image.resize(
        input_image, [512, 512], method=tf.image.ResizeMethod.AREA)
    pred_scene_low = tf.clip_by_value(model(input_low), 0.0, 1.0)
    pred_flare_low = utils.remove_flare(input_low, pred_scene_low)
    pred_flare = tf.image.resize(pred_flare_low, [2048, 2048], antialias=True)
    pred_scene = utils.remove_flare(input_image, pred_flare)
  else:
    # Standard path: crop to the model's native 512 x 512 input size.
    input_image = center_crop(input_f32, 512, 512)[None, Ellipsis]
    # Tile the single image to fill the batch, for networks that only accept
    # a fixed batch size; only the first element is used below.
    input_image = tf.concat([input_image] * FLAGS.batch_size, axis=0)
    pred_scene = tf.clip_by_value(model(input_image), 0.0, 1.0)
    pred_flare = utils.remove_flare(input_image, pred_scene)
  # Re-blend the (clipped) light source into the scene (Section 5.2).
  pred_blend = utils.blend_light_source(input_image[0, Ellipsis],
                                        pred_scene[0, Ellipsis])

  # Output files reuse the input file's name stem.
  out_filename_stem = os.path.splitext(os.path.basename(image_path))[0]
  if separate_out_dirs:
    write_outputs_separate_dir(
        out_dir,
        out_filename_stem + '.png',
        input_image=input_image[0, Ellipsis],
        pred_scene=pred_scene[0, Ellipsis],
        pred_flare=pred_flare[0, Ellipsis],
        pred_blend=pred_blend)
  else:
    write_outputs_same_dir(
        out_dir,
        out_filename_stem,
        input_image=input_image[0, Ellipsis],
        pred_scene=pred_scene[0, Ellipsis],
        pred_flare=pred_flare[0, Ellipsis],
        pred_blend=pred_blend)
178
def load_model(path,
               model_type = None,
               batch_size = None):
  """Loads a model from SavedModel or standard TF checkpoint.

  First attempts to load `path` as a full SavedModel (architecture +
  weights). If that fails, builds the architecture named by `model_type`
  and restores weights from the latest checkpoint under `path` (or from
  `path` itself when it is a checkpoint prefix).

  Args:
    path: SavedModel directory, checkpoint directory, or checkpoint prefix.
    model_type: Architecture name passed to `models.build_model`; required
      only for the checkpoint fallback path.
    batch_size: Batch size passed to `models.build_model`; required only for
      the checkpoint fallback path.

  Returns:
    A Keras model with restored weights.
  """
  try:
    return tf.keras.models.load_model(path)
  except (ImportError, IOError):
    # Not a SavedModel; fall through to the weights-only checkpoint path.
    print(f'Didn\'t find SavedModel at "{path}". '
          'Trying latest checkpoint next.')
  restored = models.build_model(model_type, batch_size)
  checkpoint = tf.train.Checkpoint(model=restored)
  # `path` may be a checkpoint dir (take the latest) or an explicit prefix.
  checkpoint_path = tf.train.latest_checkpoint(path) or path
  checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
  return restored
194
def main(_):
  """Runs flare-removal inference over every image in `--input_dir`."""
  # Default the output location to a subdirectory of the input directory.
  out_dir = FLAGS.out_dir or os.path.join(FLAGS.input_dir, 'model_output')
  tf.io.gfile.makedirs(out_dir)

  model = load_model(FLAGS.ckpt, FLAGS.model, FLAGS.batch_size)

  # The following grep works for both png and jpg.
  input_files = sorted(tf.io.gfile.glob(os.path.join(FLAGS.input_dir, '*.*g')))
  for image_path in tqdm.tqdm(input_files):
    process_one_image(model, image_path, out_dir, FLAGS.separate_out_dirs)

  print('done')


if __name__ == '__main__':
  app.run(main)