google-research

Fork
0
165 lines · 6.2 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generates synthetic scenes containing lens flare."""
import math

import tensorflow as tf

from flare_removal.python import utils

def add_flare(scene,
              flare,
              noise,
              flare_max_gain: float = 10.0,
              apply_affine: bool = True,
              training_res: int = 512):
  """Adds flare to natural images.

  Here the natural images are in sRGB. They are first linearized before flare
  patterns are added. The result is then converted back to sRGB.

  Args:
    scene: Natural image batch in sRGB. Assumed shape is
      [batch, height, width, 3] — TODO confirm against callers.
    flare: Lens flare image batch in sRGB, with a statically known shape
      [batch, height, width, 3] (the shape is unpacked eagerly below).
    noise: Strength of the additive Gaussian noise. For each image, the
      Gaussian variance is drawn from a scaled Chi-squared distribution, where
      the scale is defined by `noise`.
    flare_max_gain: Maximum gain applied to the flare images in the linear
      domain. RGB gains are applied randomly and independently, not exceeding
      this maximum.
    apply_affine: Whether to apply a random affine transformation to the flare
      pattern (rotation, shift, shear, and scale).
    training_res: Resolution of training images. Images must be square, and
      this value specifies the side length.

  Returns:
    - Flare-free scene in sRGB, quantized to 8 bits.
    - Flare-only image in sRGB, quantized to 8 bits.
    - Scene with flare in sRGB, quantized to 8 bits.
    - Gamma value used during synthesis.
  """
  batch_size, flare_input_height, flare_input_width, _ = flare.shape

  # Since the gamma encoding is unknown, we use a random value so that the
  # model will hopefully generalize to a reasonable range of gammas.
  gamma = tf.random.uniform([], 1.8, 2.2)
  flare_linear = tf.image.adjust_gamma(flare, gamma)

  # Remove DC background in flare.
  flare_linear = utils.remove_background(flare_linear)

  if apply_affine:
    rotation = tf.random.uniform([batch_size], minval=-math.pi, maxval=math.pi)
    shift = tf.random.normal([batch_size, 2], mean=0.0, stddev=10.0)
    shear = tf.random.uniform([batch_size, 2],
                              minval=-math.pi / 9,
                              maxval=math.pi / 9)
    scale = tf.random.uniform([batch_size, 2], minval=0.9, maxval=1.2)

    flare_linear = utils.apply_affine_transform(
        flare_linear,
        rotation=rotation,
        shift_x=shift[:, 0],
        shift_y=shift[:, 1],
        shear_x=shear[:, 0],
        shear_y=shear[:, 1],
        scale_x=scale[:, 0],
        scale_y=scale[:, 1])

  flare_linear = tf.clip_by_value(flare_linear, 0.0, 1.0)
  # Center-crop the flare down to the training resolution.
  flare_linear = tf.image.crop_to_bounding_box(
      flare_linear,
      offset_height=(flare_input_height - training_res) // 2,
      offset_width=(flare_input_width - training_res) // 2,
      target_height=training_res,
      target_width=training_res)
  flare_linear = tf.image.random_flip_left_right(
      tf.image.random_flip_up_down(flare_linear))

  # First normalize the white balance. Then apply random white balance.
  flare_linear = utils.normalize_white_balance(flare_linear)
  rgb_gains = tf.random.uniform([3], 0, flare_max_gain, dtype=tf.float32)
  flare_linear *= rgb_gains

  # Further augmentation on flare patterns: random blur and DC offset.
  blur_size = tf.random.uniform([], 0.1, 3)
  flare_linear = utils.apply_blur(flare_linear, blur_size)
  offset = tf.random.uniform([], -0.02, 0.02)
  flare_linear = tf.clip_by_value(flare_linear + offset, 0.0, 1.0)

  flare_srgb = tf.image.adjust_gamma(flare_linear, 1.0 / gamma)

  # Scene augmentation: random crop and flips.
  scene_linear = tf.image.adjust_gamma(scene, gamma)
  scene_linear = tf.image.random_crop(scene_linear, flare_linear.shape)
  scene_linear = tf.image.random_flip_left_right(
      tf.image.random_flip_up_down(scene_linear))

  # Additive Gaussian noise. The Gaussian's variance is drawn from a
  # Chi-squared distribution. This is equivalent to drawing the Gaussian's
  # standard deviation from a truncated normal distribution, as shown below.
  # (Use a distinct name so the `noise` argument is not shadowed, as it was
  # previously.)
  sigma = tf.abs(tf.random.normal([], 0, noise))
  gaussian_noise = tf.random.normal(scene_linear.shape, 0, sigma)
  scene_linear += gaussian_noise

  # Random digital gain.
  gain = tf.random.uniform([], 0, 1.2)  # varying the intensity scale
  scene_linear = tf.clip_by_value(gain * scene_linear, 0.0, 1.0)

  scene_srgb = tf.image.adjust_gamma(scene_linear, 1.0 / gamma)

  # Combine the flare-free scene with a flare pattern to produce a synthetic
  # training example.
  combined_linear = scene_linear + flare_linear
  combined_srgb = tf.image.adjust_gamma(combined_linear, 1.0 / gamma)
  combined_srgb = tf.clip_by_value(combined_srgb, 0.0, 1.0)

  return (utils.quantize_8(scene_srgb), utils.quantize_8(flare_srgb),
          utils.quantize_8(combined_srgb), gamma)


def run_step(scene,
             flare,
             model,
             loss_fn,
             noise=0.0,
             flare_max_gain=10.0,
             flare_loss_weight=0.0,
             training_res=512):
  """Executes a forward step.

  Synthesizes a flare-corrupted batch, runs the model, and evaluates the
  loss against the flare-free ground truth.

  Args:
    scene: Natural image batch in sRGB.
    flare: Lens flare image batch in sRGB.
    model: Callable mapping a flare-corrupted batch to a flare-free
      prediction.
    loss_fn: Callable computing a scalar loss from (ground truth, prediction).
    noise: Additive Gaussian noise strength, forwarded to `add_flare`.
    flare_max_gain: Maximum linear-domain flare gain, forwarded to
      `add_flare`.
    flare_loss_weight: Weight of the flare-reconstruction loss term; the term
      is skipped entirely when this is not positive.
    training_res: Side length of the square training images.

  Returns:
    A (loss, image_summary) tuple, where the summary concatenates
    [input, predicted scene, true scene, predicted flare, true flare]
    side by side along the width axis.
  """
  scene_gt, flare_gt, combined, gamma = add_flare(
      scene,
      flare,
      flare_max_gain=flare_max_gain,
      noise=noise,
      training_res=training_res)

  pred_scene = model(combined)
  pred_flare = utils.remove_flare(combined, pred_scene, gamma)

  mask = utils.get_highlight_mask(flare_gt)
  inverse_mask = 1 - mask
  # Fill the saturated region with the ground truth so it contributes no
  # L1/L2 loss, and blends with the surroundings for the perceptual loss.
  scene_for_loss = pred_scene * inverse_mask + scene_gt * mask
  loss_value = loss_fn(scene_gt, scene_for_loss)
  if flare_loss_weight > 0:
    flare_for_loss = pred_flare * inverse_mask + flare_gt * mask
    loss_value += flare_loss_weight * loss_fn(flare_gt, flare_for_loss)

  image_summary = tf.concat(
      [combined, pred_scene, scene_gt, pred_flare, flare_gt], axis=2)

  return loss_value, image_summary
Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.