# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Noise generators."""
import numpy as np
from scipy import ndimage
import scipy.stats
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_kernel(size=3, bounds=3):
  """Create Gaussian kernel."""
  kernel_basis = np.linspace(-bounds, bounds, size+1)

  # Create gaussian kernel
  kernel_1d = np.diff(scipy.stats.norm.cdf(kernel_basis))
  kernel = np.outer(kernel_1d, kernel_1d)

  # Normalize kernel
  kernel = kernel / kernel.sum()

  # Reshape to dim for pytorch conv2d and repeat
  kernel = torch.tensor(kernel).float()
  kernel = kernel.reshape(1, 1, *kernel.size())
  kernel = kernel.repeat(3, *[1] * (kernel.dim() - 1))
  return kernel


def add_gaussian_blur(x, k_size=3):
  """Add Gaussian blur to image.

  Adapted from
  https://github.com/kechan/FastaiPlayground/blob/master/Quick%20Tour%20of%20Data%20Augmentation.ipynb

  Args:
    x: source image.
    k_size: kernel size.

  Returns:
    x: Gaussian blurred image.
  """
  kernel = make_kernel(k_size)
  padding = (k_size - 1) // 2

  x = x.unsqueeze(dim=0)
  padded_x = F.pad(x, [padding] * x.dim(), mode='reflect')
  x = F.conv2d(padded_x, kernel, groups=3)
  return x.squeeze()


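# Example (illustrative sketch, not part of the original module): blur a random
# 3-channel image with the default 3x3 Gaussian kernel.
#
#   img = torch.rand(3, 32, 32)
#   blurred = add_gaussian_blur(img)  # same (3, 32, 32) shape as the input

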
def add_patch(tensor,
              noise_location,
              patch_type=None,
              min_size=16,
              max_size=32):
  """Add focus/occluding patch."""
  _, h, w = tensor.shape
  if noise_location == 'random':
    w_size = np.random.randint(min_size, max_size+1)
    h_size = w_size
    x1 = np.random.randint(0, w - w_size + 1)
    y1 = np.random.randint(0, h - h_size + 1)
  elif noise_location == 'center':
    w_size = min_size
    h_size = min_size
    # Center
    x1 = (w - w_size) // 2
    y1 = (h - h_size) // 2
  else:
    raise ValueError(f'{noise_location} not implemented!')

  x2 = x1 + w_size
  y2 = y1 + h_size

  if patch_type == 'focus':
    blurred_tensor = add_gaussian_blur(tensor.clone())
    blurred_tensor[:, y1:y2, x1:x2] = tensor[:, y1:y2, x1:x2]
    tensor = blurred_tensor.clone()
  elif patch_type == 'occlusion':
    tensor[:, y1:y2, x1:x2] = 0
  else:
    raise NotImplementedError(f'{patch_type} not implemented!')
  return tensor


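# Example (illustrative sketch, not part of the original module): zero out a
# random 16x16 square of a CHW image.
#
#   img = torch.rand(3, 64, 64)
#   occluded = add_patch(img.clone(), noise_location='random',
#                        patch_type='occlusion', min_size=16, max_size=16)

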
def pad_image(img, padding=32 * 2):
  """Pad image by replicating its border pixels."""
  c, h, w = img.shape

  x1 = padding
  x2 = padding + w
  y1 = padding
  y2 = padding + h

  # Base
  x_padded = torch.zeros((c, h + padding * 2, w + padding * 2))
  # Left
  x_padded[:, y1:y2, :padding] = img[:, :, 0:1].repeat(1, 1, padding)
  # Right
  x_padded[:, y1:y2, x2:] = img[:, :, w - 1:w].repeat(1, 1, padding)
  # Top
  x_padded[:, :padding, x1:x2] = img[:, 0:1, :].repeat(1, padding, 1)
  # Bottom
  x_padded[:, y2:, x1:x2] = img[:, h - 1:h, :].repeat(1, padding, 1)
  # Top-left corner
  x_padded[:, :padding, :padding] = img[:, 0:1, 0:1].repeat(1, padding, padding)
  # Bottom-left corner
  x_padded[:, y2:, :padding] = img[:, h - 1:h, 0:1].repeat(1, padding, padding)
  # Top-right corner
  x_padded[:, :padding, x2:] = img[:, 0:1, w - 1:w].repeat(1, padding, padding)
  # Bottom-right corner
  x_padded[:, y2:, x2:] = img[:, h - 1:h, w - 1:w].repeat(1, padding, padding)
  # Fill in source image
  x_padded[:, y1:y2, x1:x2] = img

  return x_padded, (x1, y1)


def crop_image(img, top_left, offset=(0, 0), dim=32):
  """Crop a dim x dim window, clamping the offset to the image bounds."""
  _, h, w = img.shape
  x_offset, y_offset = offset
  x1, y1 = top_left

  x1 += x_offset
  x1 = min(max(x1, 0), w - dim)
  x2 = x1 + dim

  y1 += y_offset
  y1 = min(max(y1, 0), h - dim)
  y2 = y1 + dim
  return img[:, y1:y2, x1:x2]


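# Example (illustrative sketch, not part of the original module): pad a 32x32
# image by 64 pixels on every side, then crop a shifted 32x32 window back out.
#
#   img = torch.rand(3, 32, 32)
#   padded, (x1, y1) = pad_image(img, padding=64)
#   shifted = crop_image(padded, top_left=(x1, y1), offset=(5, -3), dim=32)

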
def shift_image(img, shift_at_t, dim=32):
  """Shift image."""
  # Pad image
  padding = dim * 2
  padded_img, (x1, y1) = pad_image(img, padding=padding)

  # Crop with offset
  cropped_img = crop_image(padded_img,
                           top_left=(x1, y1),
                           offset=shift_at_t,
                           dim=dim)
  return cropped_img


def rotate_image(img, max_rot_angle, dim=32):
  """Rotate image."""
  # Pad image
  padding = int(dim * 1.5)
  padded_img, (x1, y1) = pad_image(img, padding=padding)

  # Rotate image
  rotation_deg = np.random.uniform(-max_rot_angle, max_rot_angle)
  x_np = padded_img.permute(1, 2, 0).numpy()
  x_np = ndimage.rotate(x_np, rotation_deg, reshape=False)
  rotated_img = torch.tensor(x_np).permute(2, 0, 1)

  # Crop image
  cropped_img = crop_image(rotated_img,
                           top_left=(x1, y1),
                           offset=(0, 0),
                           dim=dim)
  return cropped_img


def translate_image(img, shift_at_t, dim=32):
  """Translate image."""
  # Pad image
  padding = dim * 2
  padded_img, (x1, y1) = pad_image(img, padding=padding)

  # Crop with offset
  cropped_img = crop_image(padded_img,
                           top_left=(x1, y1),
                           offset=shift_at_t,
                           dim=dim)
  return cropped_img


def change_resolution(img):
  """Change resolution of image."""
  scale_factor = np.random.choice(list(range(0, 6, 2)))
  if scale_factor == 0:
    return img
  downsample = nn.AvgPool2d(scale_factor)
  upsample = nn.UpsamplingNearest2d(scale_factor=scale_factor)
  new_res_img = upsample(downsample(img.unsqueeze(dim=1))).squeeze()
  return new_res_img


class RandomWalkGenerator:
  """Random walk handler."""

  def __init__(self, n_timesteps, n_total_samples):
    """Initializes random walk."""
    self.n_timesteps = n_timesteps if n_timesteps > 0 else 5
    self.n_total_samples = n_total_samples
    self._setup_random_walk()

  def _generate(self, max_vals=(8, 8), move_prob=(1, 1)):
    """Generate random walk."""
    init_loc = (0, 0)
    max_x, max_y = max_vals
    move_x_prob, move_y_prob = move_prob
    locations = [init_loc]
    for _ in range(self.n_timesteps - 1):
      prev_x, prev_y = locations[-1]
      new_x, new_y = prev_x, prev_y
      if np.random.uniform() < move_x_prob:
        new_x = prev_x + np.random.choice([-1, 1])
      if np.random.uniform() < move_y_prob:
        new_y = prev_y + np.random.choice([-1, 1])
      new_x = max(min(new_x, max_x), -max_x)
      new_y = max(min(new_y, max_y), -max_y)
      loc_i = (new_x, new_y)
      locations.append(loc_i)
    return locations

  def _setup_random_walk(self):
    self._sample_shift_schedules = [
        self._generate() for _ in range(self.n_total_samples)
    ]
    np.random.shuffle(self._sample_shift_schedules)

  def __call__(self, img, sample_i=None, t=None):
    if sample_i is None:
      sample_i = np.random.randint(len(self._sample_shift_schedules))
      n_ts = self._sample_shift_schedules[sample_i]
      t = np.random.randint(len(n_ts))

    shift_at_t = self._sample_shift_schedules[sample_i][t]
    noised_img = translate_image(img, shift_at_t)
    return noised_img


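# Example (illustrative sketch, not part of the original module): generate
# per-sample shift schedules and apply the shift of sample 0 at timestep 2.
#
#   walker = RandomWalkGenerator(n_timesteps=5, n_total_samples=100)
#   shifted = walker(torch.rand(3, 32, 32), sample_i=0, t=2)

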
class PerlinNoise:
  """Perlin noise handler."""

  def __init__(self,
               half=False,
               half_dim='height',
               frequency=5,
               proportion=0.4,
               b_w=True):
    """Initializes PerlinNoise generator."""

    self.half = half
    self.half_dim = half_dim
    self.frequency = frequency
    self.proportion = proportion
    self.b_w = b_w

  def _perlin(self, x, y, seed=0):
    """Perlin noise."""
    def lerp(a, b, x):
      return a + x * (b - a)

    def fade(t):
      return 6 * t**5 - 15 * t**4 + 10 * t**3

    def gradient(h, x, y):
      vectors = torch.tensor([[0, 1], [0, -1], [1, 0], [-1, 0]])
      g = vectors[h % 4].float()
      return g[:, :, 0] * x + g[:, :, 1] * y

    # Permutation table. Use a seeded local generator so the permutation is
    # reproducible for a given seed without touching global RNG state.
    gen = torch.Generator()
    gen.manual_seed(int(seed))

    p = torch.randperm(256, generator=gen)
    p = torch.stack([p, p]).flatten()

    # coordinates of the top-left
    xi = x.long()
    yi = y.long()

    # internal coordinates
    xf = x - xi.float()
    yf = y - yi.float()

    # fade factors
    u = fade(xf)
    v = fade(yf)

    x00 = p[p[xi]   + yi]
    x01 = p[p[xi]   + yi+1]
    x11 = p[p[xi+1] + yi+1]
    x10 = p[p[xi+1] + yi]

    n00 = gradient(x00, xf, yf)
    n01 = gradient(x01, xf, yf-1)
    n11 = gradient(x11, xf-1, yf-1)
    n10 = gradient(x10, xf-1, yf)

    # combine noises
    x1 = lerp(n00, n10, u)
    x2 = lerp(n01, n11, u)

    return lerp(x1, x2, v)

  def _create_mask(self, dim, seed=None):
    """Create mask."""
    t_lin = torch.linspace(0, self.frequency, dim)
    y, x = torch.meshgrid([t_lin, t_lin])

    if seed is None:
      seed = np.random.randint(1, 1000000)

    mask = self._perlin(x, y, seed)

    if self.b_w:
      sorted_vals = np.sort(mask.numpy().flatten())
      idx = int(np.round(len(sorted_vals) * (1 - self.proportion)))
      threshold = sorted_vals[idx]
      mask = (mask < threshold)*1.0

    return mask

  def __call__(self, img):
    img_shape = img.shape

    mask = torch.zeros_like(img)
    dim = mask.shape[1]
    perlin_mask = self._create_mask(dim)
    for i in range(mask.shape[0]):
      mask[i] = perlin_mask

    if self.half:
      half = img_shape[1]//2
      if self.half_dim == 'height':
        mask[:, :half, :] = 1
      else:
        mask[:, :, :half] = 1

    noisy_image = img * mask

    return noisy_image


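# Example (illustrative sketch, not part of the original module): zero out
# roughly 40% of the pixels with a black-and-white Perlin mask.
#
#   noisy = PerlinNoise(proportion=0.4)(torch.rand(3, 32, 32))

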
class FocusBlur:
  """Average blurring noise handler."""

  def __init__(self):
    """Initializes average blurring."""
    self._factor_step = 2
    self._max_factor = 6
    self.res_range = range(0, self._max_factor, self._factor_step)

  def __call__(self, img):
    scale_factor = np.random.choice(list(self.res_range))
    if scale_factor == 0:
      return img

    downsample_op = nn.AvgPool2d(scale_factor)
    upsample_op = nn.UpsamplingNearest2d(scale_factor=scale_factor)
    new_res_img = upsample_op(downsample_op(img.unsqueeze(dim=1))).squeeze()
    return new_res_img


class NoiseHandler:
  """Noise handler."""

  def __init__(self,
               noise_type,
               n_total_samples=1000,
               n_total_timesteps=0,
               n_timesteps_per_item=0,
               n_transition_steps=0):
    """Initializes noise handler."""
    self.noise_type = noise_type
    self.n_total_samples = n_total_samples
    self.n_total_timesteps = n_total_timesteps
    self.n_timesteps_per_item = n_timesteps_per_item
    self.n_transition_steps = n_transition_steps

    self._min_size = 16
    self._max_size = 16
    self._max_rot_angle = 60

    self._random_walker = None
    if noise_type == 'translation':
      self._random_walker = RandomWalkGenerator(n_total_timesteps,
                                                n_total_samples)

  def __call__(self, x_src, sample_i=None, t=None):
    x = x_src.clone()
    if self.noise_type in ['occlusion', 'focus']:
      x_noised = add_patch(x,
                           noise_location='random',
                           patch_type=self.noise_type,
                           min_size=self._min_size,
                           max_size=self._max_size)
    elif self.noise_type == 'resolution':
      x_noised = FocusBlur()(x)
    elif self.noise_type == 'Perlin':
      x_noised = PerlinNoise()(x)
    elif self.noise_type == 'translation':
      x_noised = self._random_walker(x, sample_i, t)
    elif self.noise_type == 'rotation':
      x_noised = rotate_image(x, max_rot_angle=self._max_rot_angle)
    else:
      raise NotImplementedError(f'{self.noise_type} not implemented!')

    return x_noised
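

# Minimal smoke test (illustrative sketch, not part of the original file):
# apply each noise type to a random CIFAR-sized image and print the output
# shape. The sample count and timestep values below are arbitrary examples.
if __name__ == '__main__':
  example_img = torch.rand(3, 32, 32)
  for example_noise_type in ['occlusion', 'focus', 'resolution', 'Perlin',
                             'rotation', 'translation']:
    example_handler = NoiseHandler(example_noise_type,
                                   n_total_samples=10,
                                   n_total_timesteps=5)
    example_out = example_handler(example_img)
    print(example_noise_type, tuple(example_out.shape))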