google-research
165 lines · 6.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Generates synthetic scenes containing lens flare."""
17import math
18
19import tensorflow as tf
20
21from flare_removal.python import utils
22
23
def add_flare(scene,
              flare,
              noise,
              flare_max_gain=10.0,
              apply_affine=True,
              training_res=512):
  """Adds flare to natural images.

  Here the natural images are in sRGB. They are first linearized before flare
  patterns are added. The result is then converted back to sRGB.

  Args:
    scene: Natural image batch in sRGB.
    flare: Lens flare image batch in sRGB. Spatial dimensions are assumed to
      be at least `training_res` in each direction (the flare is
      center-cropped to `training_res`) — TODO confirm against callers.
    noise: Strength of the additive Gaussian noise. For each image, the
      Gaussian variance is drawn from a scaled Chi-squared distribution,
      where the scale is defined by `noise`.
    flare_max_gain: Maximum gain applied to the flare images in the linear
      domain. RGB gains are applied randomly and independently, not exceeding
      this maximum.
    apply_affine: Whether to apply a random affine transformation (rotation,
      shift, shear, scale) to the flare pattern.
    training_res: Resolution of training images. Images must be square, and
      this value specifies the side length.

  Returns:
    A 4-tuple of:
    - Flare-free scene in sRGB.
    - Flare-only image in sRGB.
    - Scene with flare in sRGB.
    - Gamma value used during synthesis.
  """
  batch_size, flare_input_height, flare_input_width, _ = flare.shape

  # Since the gamma encoding is unknown, we use a random value so that the
  # model will hopefully generalize to a reasonable range of gammas.
  gamma = tf.random.uniform([], 1.8, 2.2)
  flare_linear = tf.image.adjust_gamma(flare, gamma)

  # Remove DC background in flare.
  flare_linear = utils.remove_background(flare_linear)

  if apply_affine:
    rotation = tf.random.uniform([batch_size], minval=-math.pi, maxval=math.pi)
    shift = tf.random.normal([batch_size, 2], mean=0.0, stddev=10.0)
    shear = tf.random.uniform([batch_size, 2],
                              minval=-math.pi / 9,
                              maxval=math.pi / 9)
    scale = tf.random.uniform([batch_size, 2], minval=0.9, maxval=1.2)

    flare_linear = utils.apply_affine_transform(
        flare_linear,
        rotation=rotation,
        shift_x=shift[:, 0],
        shift_y=shift[:, 1],
        shear_x=shear[:, 0],
        shear_y=shear[:, 1],
        scale_x=scale[:, 0],
        scale_y=scale[:, 1])

  flare_linear = tf.clip_by_value(flare_linear, 0.0, 1.0)
  # Center-crop the (possibly transformed) flare down to the training size.
  flare_linear = tf.image.crop_to_bounding_box(
      flare_linear,
      offset_height=(flare_input_height - training_res) // 2,
      offset_width=(flare_input_width - training_res) // 2,
      target_height=training_res,
      target_width=training_res)
  flare_linear = tf.image.random_flip_left_right(
      tf.image.random_flip_up_down(flare_linear))

  # First normalize the white balance. Then apply random white balance.
  flare_linear = utils.normalize_white_balance(flare_linear)
  rgb_gains = tf.random.uniform([3], 0, flare_max_gain, dtype=tf.float32)
  flare_linear *= rgb_gains

  # Further augmentation on flare patterns: random blur and DC offset.
  blur_size = tf.random.uniform([], 0.1, 3)
  flare_linear = utils.apply_blur(flare_linear, blur_size)
  offset = tf.random.uniform([], -0.02, 0.02)
  flare_linear = tf.clip_by_value(flare_linear + offset, 0.0, 1.0)

  flare_srgb = tf.image.adjust_gamma(flare_linear, 1.0 / gamma)

  # Scene augmentation: random crop and flips.
  scene_linear = tf.image.adjust_gamma(scene, gamma)
  scene_linear = tf.image.random_crop(scene_linear, flare_linear.shape)
  scene_linear = tf.image.random_flip_left_right(
      tf.image.random_flip_up_down(scene_linear))

  # Additive Gaussian noise. The Gaussian's variance is drawn from a
  # Chi-squared distribution. This is equivalent to drawing the Gaussian's
  # standard deviation from a truncated normal distribution, as shown below.
  # NOTE: the sampled tensor gets its own name (`noise_image`) so it does not
  # shadow the scalar `noise` argument above.
  sigma = tf.abs(tf.random.normal([], 0, noise))
  noise_image = tf.random.normal(scene_linear.shape, 0, sigma)
  scene_linear += noise_image

  # Random digital gain, varying the intensity scale.
  gain = tf.random.uniform([], 0, 1.2)
  scene_linear = tf.clip_by_value(gain * scene_linear, 0.0, 1.0)

  scene_srgb = tf.image.adjust_gamma(scene_linear, 1.0 / gamma)

  # Combine the flare-free scene with a flare pattern to produce a synthetic
  # training example.
  combined_linear = scene_linear + flare_linear
  combined_srgb = tf.image.adjust_gamma(combined_linear, 1.0 / gamma)
  combined_srgb = tf.clip_by_value(combined_srgb, 0.0, 1.0)

  return (utils.quantize_8(scene_srgb), utils.quantize_8(flare_srgb),
          utils.quantize_8(combined_srgb), gamma)
132
133
def run_step(scene,
             flare,
             model,
             loss_fn,
             noise=0.0,
             flare_max_gain=10.0,
             flare_loss_weight=0.0,
             training_res=512):
  """Executes a forward step.

  Synthesizes a flare-corrupted training example, runs the model on it, and
  evaluates the loss against the flare-free ground truth.

  Args:
    scene: Natural image batch in sRGB.
    flare: Lens flare image batch in sRGB.
    model: Callable mapping a flare-corrupted batch to a flare-free estimate.
    loss_fn: Callable computing a scalar loss from (ground truth, prediction).
    noise: Additive Gaussian noise strength, forwarded to `add_flare`.
    flare_max_gain: Maximum linear-domain flare gain, forwarded to
      `add_flare`.
    flare_loss_weight: If positive, the loss on the recovered flare image is
      added with this weight.
    training_res: Side length of the square training images.

  Returns:
    A (loss_value, image_summary) tuple, where `image_summary` stacks the
    combined input, predicted scene, ground-truth scene, predicted flare, and
    ground-truth flare side by side along the width axis.
  """
  scene, flare, combined, gamma = add_flare(
      scene,
      flare,
      flare_max_gain=flare_max_gain,
      noise=noise,
      training_res=training_res)

  scene_estimate = model(combined)
  flare_estimate = utils.remove_flare(combined, scene_estimate, gamma)

  saturation_mask = utils.get_highlight_mask(flare)
  # In saturated regions, substitute the ground truth for the prediction so
  # that no L1/L2 penalty is incurred there, and the perceptual loss sees
  # content consistent with the surrounding scene.
  scene_for_loss = (scene_estimate * (1 - saturation_mask) +
                    scene * saturation_mask)
  loss_value = loss_fn(scene, scene_for_loss)
  if flare_loss_weight > 0:
    flare_for_loss = (flare_estimate * (1 - saturation_mask) +
                      flare * saturation_mask)
    loss_value += flare_loss_weight * loss_fn(flare, flare_for_loss)

  summary_panels = [combined, scene_estimate, scene, flare_estimate, flare]
  image_summary = tf.concat(summary_panels, axis=2)

  return loss_value, image_summary
166