# Source repository: google-research
# File size as reported by the code viewer: 415 lines, 11.4 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Noise generators."""
17import numpy as np18from scipy import ndimage19import scipy.stats20import torch21import torch.nn as nn22import torch.nn.functional as F23
24
def make_kernel(size=3, bounds=3):
    """Build a normalized 2-D Gaussian kernel replicated over three channels.

    Args:
      size: number of taps along each side of the kernel.
      bounds: the taps cover the standard-normal range [-bounds, bounds].

    Returns:
      A float32 tensor of shape (3, 1, size, size). All three channels hold
      the same kernel, whose entries sum to one.
    """
    # Each tap is the standard-normal probability mass of its bin.
    bin_edges = np.linspace(-bounds, bounds, size + 1)
    taps = np.diff(scipy.stats.norm.cdf(bin_edges))

    # Separable 2-D kernel, rescaled so that it sums to one.
    weights = np.outer(taps, taps)
    weights = weights / weights.sum()

    # Lay out as (out_channels, in_channels / groups, kH, kW) for a
    # depthwise conv2d over three channels.
    as_tensor = torch.tensor(weights).float()
    return as_tensor[None, None].repeat(3, 1, 1, 1)
42
def add_gaussian_blur(x, k_size=3):
    """Add Gaussian blur to image.

    Adapted from
    https://github.com/kechan/FastaiPlayground/blob/master/Quick%20Tour%20of%20Data%20Augmentation.ipynb

    Args:
      x: source image of shape (3, H, W); the depthwise convolution below is
        built for exactly three channels.
      k_size: kernel size.

    Returns:
      x: Gaussian blurred image of shape (3, H', W'). H' == H and W' == W when
        k_size is odd.
    """
    kernel = make_kernel(k_size)
    padding = (k_size - 1) // 2

    # conv2d wants a batch axis: (1, 3, H, W).
    batched = x.unsqueeze(dim=0)
    # Four entries pad (left, right, top, bottom) of the two spatial dims.
    padded = F.pad(batched, [padding] * 4, mode='reflect')
    blurred = F.conv2d(padded, kernel, groups=3)
    # Drop only the batch axis added above; a bare squeeze() would also
    # collapse a spatial dimension of size one.
    return blurred.squeeze(dim=0)
63
def add_patch(tensor,
              noise_location,
              patch_type=False,
              min_size=16,
              max_size=32):
    """Add focus/occluding patch.

    Args:
      tensor: image of shape (C, H, W).
      noise_location: 'random' places a square patch with a side drawn from
        [min_size, max_size] at a uniformly random position; 'center' places a
        min_size square in the middle of the image.
      patch_type: 'focus' blurs everything except the patch; 'occlusion' zeroes
        the patch.
      min_size: smallest patch side ('random') or the patch side ('center').
      max_size: largest patch side; only used for 'random'.

    Returns:
      The noised image. For 'occlusion' this is the input tensor itself,
      modified in place; for 'focus' it is a new tensor.

    Raises:
      ValueError: if noise_location is not 'random' or 'center'.
      AssertionError: if patch_type is not 'focus' or 'occlusion'.
    """
    _, h, w = tensor.shape
    if noise_location == 'random':
        side = np.random.randint(min_size, max_size + 1)
        x1 = np.random.randint(0, w - side + 1)
        y1 = np.random.randint(0, h - side + 1)
    elif noise_location == 'center':
        side = min_size
        x1 = (w - side) // 2
        y1 = (h - side) // 2
    else:
        # Previously this fell through and failed later on an unbound x1.
        raise ValueError(
            f"noise_location must be 'random' or 'center', got {noise_location!r}")

    x2 = x1 + side
    y2 = y1 + side

    if patch_type == 'focus':
        blurred = add_gaussian_blur(tensor.clone())
        # Keep the patch sharp; everything around it stays blurred.
        blurred[:, y1:y2, x1:x2] = tensor[:, y1:y2, x1:x2]
        return blurred
    if patch_type == 'occlusion':
        tensor[:, y1:y2, x1:x2] = 0
        return tensor
    # Explicit raise (same type and message as before) so that the check is
    # not stripped when Python runs with -O.
    raise AssertionError(f'{patch_type} not implemented!')
96
def pad_image(img, padding=32 * 2):
    """Pad an image on all four sides by repeating its border pixels.

    Args:
      img: image of shape (C, H, W).
      padding: number of pixels added on each side.

    Returns:
      A tuple (x_padded, (x1, y1)): x_padded is a float32 tensor of shape
      (C, H + 2 * padding, W + 2 * padding), and (x1, y1) is the column/row at
      which the source image starts inside it.
    """
    c, h, w = img.shape

    # For every output row/column, the source row/column it copies from.
    # Clamping makes out-of-range positions reuse the nearest edge pixel,
    # which covers the four borders and the four corners in one gather.
    src_rows = torch.arange(-padding, h + padding).clamp(0, h - 1)
    src_cols = torch.arange(-padding, w + padding).clamp(0, w - 1)

    x_padded = torch.zeros((c, h + padding * 2, w + padding * 2))
    x_padded[:] = img[:, src_rows][:, :, src_cols]

    return x_padded, (padding, padding)
129
def crop_image(img, top_left, offset=(0, 0), dim=32):
    """Cut a dim x dim window out of an image.

    Args:
      img: image of shape (C, H, W).
      top_left: (x, y) of the window's top-left corner before the offset.
      offset: (dx, dy) added to top_left.
      dim: side length of the window.

    Returns:
      The slice img[:, y:y + dim, x:x + dim], where the shifted corner is
      clamped to [0, W - dim] horizontally and [0, H - dim] vertically.
    """
    _, h, w = img.shape
    (left, top), (dx, dy) = top_left, offset

    # Clamp the shifted corner so the window stays inside the image.
    left = min(max(left + dx, 0), w - dim)
    top = min(max(top + dy, 0), h - dim)

    return img[:, top:top + dim, left:left + dim]
145
def shift_image(img, shift_at_t, dim=32):
    """Shift an image by cropping an offset window from a padded copy.

    Args:
      img: image of shape (C, H, W).
      shift_at_t: (dx, dy) offset applied to the crop window.
      dim: side length of the returned crop; the image is padded by 2 * dim
        on every side before cropping.

    Returns:
      The dim x dim crop taken at the source image's origin plus the offset.
    """
    # pad_image reports where the source image starts inside the padded one.
    padded_img, origin = pad_image(img, padding=dim * 2)
    return crop_image(padded_img, top_left=origin, offset=shift_at_t, dim=dim)
159
def rotate_image(img, max_rot_angle, dim=32):
    """Rotate an image by a random angle.

    Args:
      img: image of shape (C, H, W).
      max_rot_angle: the angle in degrees is drawn uniformly from
        [-max_rot_angle, max_rot_angle].
      dim: side length of the returned crop; the image is padded by
        int(dim * 1.5) on every side before rotating.

    Returns:
      The dim x dim crop of the rotated, padded image, taken at the position
      where the source image was placed.
    """
    padded_img, origin = pad_image(img, padding=int(dim * 1.5))

    angle = np.random.uniform(-max_rot_angle, max_rot_angle)
    # ndimage.rotate turns the plane of the first two axes, so go to (H, W, C).
    as_hwc = padded_img.permute(1, 2, 0).numpy()
    turned = ndimage.rotate(as_hwc, angle, reshape=False)
    rotated_img = torch.tensor(turned).permute(2, 0, 1)

    return crop_image(rotated_img, top_left=origin, offset=(0, 0), dim=dim)
179
def translate_image(img, shift_at_t, dim=32):
    """Translate image.

    Translation is the same operation as shift_image (pad by 2 * dim, then
    crop a dim x dim window at the given offset), so this delegates to it
    instead of keeping a second copy of the body.

    Args:
      img: image of shape (C, H, W).
      shift_at_t: (dx, dy) offset applied to the crop window.
      dim: side length of the returned crop.

    Returns:
      The translated dim x dim image.
    """
    return shift_image(img, shift_at_t, dim=dim)
193
def change_resolution(img):
    """Change resolution of image.

    A factor is drawn uniformly from {0, 2, 4}. Factor 0 returns the input
    unchanged; otherwise the image is average-pooled by the factor and
    upsampled back with nearest-neighbour interpolation.

    Args:
      img: image of shape (C, H, W).

    Returns:
      The input itself (factor 0) or a new (C, H', W') tensor. H' and W' equal
      H and W when those are divisible by the drawn factor.
    """
    scale_factor = int(np.random.choice(list(range(0, 6, 2))))
    if scale_factor == 0:
        return img

    downsample = nn.AvgPool2d(scale_factor)
    upsample = nn.UpsamplingNearest2d(scale_factor=scale_factor)
    # Channels ride on the batch axis: (C, 1, H, W).
    resampled = upsample(downsample(img.unsqueeze(dim=1)))
    # Remove only the axis inserted above; a bare squeeze() would also drop
    # the channel axis of single-channel images.
    return resampled.squeeze(dim=1)
204
class RandomWalkGenerator:
    """Random walk handler.

    Pre-generates n_total_samples random-walk schedules of (x, y) offsets,
    each n_timesteps long, and applies one offset to an image on call.
    """

    def __init__(self, n_timesteps, n_total_samples):
        """Initializes random walk.

        Args:
          n_timesteps: length of every schedule; non-positive values fall back
            to 5.
          n_total_samples: number of schedules to generate.
        """
        self.n_timesteps = n_timesteps if n_timesteps > 0 else 5
        self.n_total_samples = n_total_samples
        self._setup_random_walk()

    def _generate(self, max_vals=(8, 8), move_prob=(1, 1)):
        """Generate one random walk starting at (0, 0).

        Args:
          max_vals: (max_x, max_y); coordinates are clamped to
            [-max_x, max_x] and [-max_y, max_y].
          move_prob: per-step probability of moving along x and along y.

        Returns:
          A list of n_timesteps (x, y) tuples of Python ints.
        """
        max_x, max_y = max_vals
        move_x_prob, move_y_prob = move_prob
        locations = [(0, 0)]
        for _ in range(self.n_timesteps - 1):
            new_x, new_y = locations[-1]
            if np.random.uniform() < move_x_prob:
                new_x += int(np.random.choice([-1, 1]))
            if np.random.uniform() < move_y_prob:
                new_y += int(np.random.choice([-1, 1]))
            new_x = max(min(new_x, max_x), -max_x)
            new_y = max(min(new_y, max_y), -max_y)
            locations.append((new_x, new_y))
        return locations

    def _setup_random_walk(self):
        """Generate and shuffle the per-sample shift schedules."""
        self._sample_shift_schedules = [
            self._generate() for _ in range(self.n_total_samples)
        ]
        np.random.shuffle(self._sample_shift_schedules)

    def __call__(self, img, sample_i=None, t=None):
        """Translate img by one offset of a random-walk schedule.

        Args:
          img: image passed on to translate_image.
          sample_i: index of the schedule to use. When None, both the schedule
            and the timestep are drawn at random and t is ignored.
          t: timestep within the schedule. When None, it is drawn at random.

        Returns:
          The translated image.
        """
        if sample_i is None:
            sample_i = np.random.randint(len(self._sample_shift_schedules))
            t = np.random.randint(len(self._sample_shift_schedules[sample_i]))
        elif t is None:
            # Previously a fixed sample with no timestep crashed on
            # schedule[None]; pick a random timestep instead.
            t = np.random.randint(len(self._sample_shift_schedules[sample_i]))

        shift_at_t = self._sample_shift_schedules[sample_i][t]
        return translate_image(img, shift_at_t)
249
class PerlinNoise(object):
    """Perlin noise handler.

    Multiplies an image by a 2-D Perlin-noise mask that is shared across
    channels.
    """

    def __init__(self,
                 half=False,
                 half_dim='height',
                 frequency=5,
                 proportion=0.4,
                 b_w=True):
        """Initializes PerlinNoise generator.

        Args:
          half: if True, the first half of the image along half_dim is left
            unmasked.
          half_dim: 'height' protects the top half; any other value protects
            the left half.
          frequency: the noise is sampled on coordinates spanning
            [0, frequency] across the image.
          proportion: with b_w, the threshold is the (1 - proportion) quantile
            of the noise, and pixels at or above it are zeroed.
          b_w: if True, the mask is binarized to {0, 1}; otherwise the raw
            noise values are used as the mask.
        """
        self.half = half
        self.half_dim = half_dim
        self.frequency = frequency
        self.proportion = proportion
        self.b_w = b_w

    def _perlin(self, x, y, seed=0):
        """Perlin noise.

        Args:
          x: 2-D tensor of non-negative x coordinates.
          y: 2-D tensor of non-negative y coordinates, same shape as x.
          seed: seed of the permutation table; equal seeds give equal noise.

        Returns:
          A tensor of noise values with the same shape as x.
        """
        def lerp(a, b, w):
            return a + w * (b - a)

        def fade(t):
            return 6 * t**5 - 15 * t**4 + 10 * t**3

        directions = torch.tensor([[0, 1], [0, -1], [1, 0], [-1, 0]]).float()

        def gradient(h, dx, dy):
            g = directions[h % 4]
            return g[:, :, 0] * dx + g[:, :, 1] * dy

        # permutation table
        # Seed a local torch generator: the earlier np.random.seed(seed) call
        # had no effect on torch.randperm, so `seed` never controlled the
        # table, and it reset NumPy's global RNG as a side effect.
        generator = torch.Generator()
        generator.manual_seed(int(seed))
        p = torch.randperm(256, generator=generator)
        # Doubled so that p[p[i] + j] and the +1 neighbours stay in range.
        p = torch.cat([p, p])

        # coordinates of the top-left
        xi = x.long()
        yi = y.long()

        # internal coordinates
        xf = x - xi.float()
        yf = y - yi.float()

        # fade factors
        u = fade(xf)
        v = fade(yf)

        # hashed corner values of the enclosing lattice cell
        x00 = p[p[xi] + yi]
        x01 = p[p[xi] + yi + 1]
        x11 = p[p[xi + 1] + yi + 1]
        x10 = p[p[xi + 1] + yi]

        n00 = gradient(x00, xf, yf)
        n01 = gradient(x01, xf, yf - 1)
        n11 = gradient(x11, xf - 1, yf - 1)
        n10 = gradient(x10, xf - 1, yf)

        # combine noises
        x1 = lerp(n00, n10, u)
        x2 = lerp(n01, n11, u)

        return lerp(x1, x2, v)

    def _create_mask(self, dim, seed=None):
        """Create mask.

        Args:
          dim: side length of the square mask.
          seed: seed for the noise; drawn from NumPy's global RNG when None.

        Returns:
          A (dim, dim) float tensor: a {0, 1} mask when b_w is set, raw noise
          values otherwise.
        """
        t_lin = torch.linspace(0, self.frequency, dim)
        # Same grids as torch.meshgrid([t_lin, t_lin]) with 'ij' indexing:
        # y varies along rows, x along columns.
        y = t_lin.unsqueeze(1).expand(dim, dim)
        x = t_lin.unsqueeze(0).expand(dim, dim)

        if seed is None:
            seed = np.random.randint(1, 1000000)

        mask = self._perlin(x, y, seed)

        if self.b_w:
            sorted_vals = torch.sort(mask.flatten()).values
            idx = int(round(sorted_vals.numel() * (1 - self.proportion)))
            # proportion == 0 would otherwise index one past the end.
            idx = min(max(idx, 0), sorted_vals.numel() - 1)
            mask = (mask < sorted_vals[idx]) * 1.0

        return mask

    def __call__(self, img):
        """Apply a freshly generated Perlin mask to img.

        Args:
          img: image of shape (C, H, W). The mask is built with side H, so the
            assignment below expects H == W.

        Returns:
          img multiplied elementwise by the mask.
        """
        mask = torch.zeros_like(img)
        # Broadcast the same 2-D mask over every channel.
        mask[:] = self._create_mask(mask.shape[1])

        if self.half:
            if self.half_dim == 'height':
                mask[:, :img.shape[1] // 2, :] = 1
            else:
                mask[:, :, :img.shape[2] // 2] = 1

        return img * mask
352
class FocusBlur:
    """Average blurring noise handler.

    Blurs an image by average-pooling it with a randomly drawn factor and
    upsampling it back with nearest-neighbour interpolation.
    """

    def __init__(self):
        """Initializes average blurring."""
        self._factor_step = 2
        self._max_factor = 6
        # Candidate factors: 0 (leave the image untouched), 2 and 4.
        self.res_range = range(0, self._max_factor, self._factor_step)

    def __call__(self, img):
        """Blur img with a factor drawn uniformly from res_range.

        Args:
          img: image of shape (C, H, W).

        Returns:
          The input itself when the drawn factor is 0, otherwise a new
          (C, H', W') tensor. H' and W' equal H and W when those are divisible
          by the factor.
        """
        scale_factor = int(np.random.choice(list(self.res_range)))
        if scale_factor == 0:
            return img

        downsample_op = nn.AvgPool2d(scale_factor)
        upsample_op = nn.UpsamplingNearest2d(scale_factor=scale_factor)
        # Channels ride on the batch axis: (C, 1, H, W).
        resampled = upsample_op(downsample_op(img.unsqueeze(dim=1)))
        # Remove only the axis inserted above; a bare squeeze() would also
        # drop the channel axis of single-channel images.
        return resampled.squeeze(dim=1)
372
class NoiseHandler:
    """Noise handler.

    Dispatches to one noise generator selected by noise_type: 'occlusion',
    'focus', 'resolution', 'Perlin', 'translation' or 'rotation'.
    """

    def __init__(self,
                 noise_type,
                 n_total_samples=1000,
                 n_total_timesteps=0,
                 n_timesteps_per_item=0,
                 n_transition_steps=0):
        """Initializes noise handler.

        Args:
          noise_type: name of the noise to apply on call.
          n_total_samples: number of random-walk schedules ('translation').
          n_total_timesteps: length of each random-walk schedule
            ('translation').
          n_timesteps_per_item: stored on the instance; not used in this class.
          n_transition_steps: stored on the instance; not used in this class.
        """
        self.noise_type = noise_type
        self.n_total_samples = n_total_samples
        self.n_total_timesteps = n_total_timesteps
        self.n_timesteps_per_item = n_timesteps_per_item
        self.n_transition_steps = n_transition_steps

        # Patch side range (equal bounds give a fixed 16-pixel patch) and the
        # maximum rotation angle in degrees.
        self._min_size = 16
        self._max_size = 16
        self._max_rot_angle = 60

        self._random_walker = None
        if noise_type == 'translation':
            self._random_walker = RandomWalkGenerator(n_total_timesteps,
                                                      n_total_samples)

    def __call__(self, x_src, sample_i=None, t=None):
        """Return a noised copy of x_src.

        Args:
          x_src: source image; it is cloned before any noise is applied.
          sample_i: random-walk schedule index, used by 'translation' only.
          t: timestep within the schedule, used by 'translation' only.

        Returns:
          The noised image.

        Raises:
          ValueError: if noise_type is not one of the supported names.
        """
        x = x_src.clone()
        if self.noise_type in ('occlusion', 'focus'):
            return add_patch(x,
                             noise_location='random',
                             patch_type=self.noise_type,
                             min_size=self._min_size,
                             max_size=self._max_size)
        if self.noise_type == 'resolution':
            return FocusBlur()(x)
        if self.noise_type == 'Perlin':
            return PerlinNoise()(x)
        if self.noise_type == 'translation':
            return self._random_walker(x, sample_i, t)
        if self.noise_type == 'rotation':
            return rotate_image(x, max_rot_angle=self._max_rot_angle)

        # Previously an unknown type fell through to an unbound local.
        raise ValueError(f'Unknown noise type: {self.noise_type!r}')