google-research

Форк
0
104 строки · 3.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Cow mask generation."""
17
import math
18
import jax
19
import jax.numpy as jnp
20

21
# Precomputed square roots: sqrt(2) scales erfinv into a standard-normal
# quantile, and sqrt(2*pi) is the Gaussian density normalizer.
_ROOT_2, _ROOT_2_PI = math.sqrt(2.0), math.sqrt(2.0 * math.pi)
23

24

25
def gaussian_kernels(sigmas, max_sigma):
  """Build a bank of sampled 1-D Gaussian kernels, one row per sigma.

  Every kernel is evaluated on the same grid of integer offsets
  `-half_width, ..., half_width`, where
  `half_width = 2 * round(3 * max_sigma) + 1`, so all rows share one width.

  Args:
      sigmas: kernel standard deviations as a [N] jax.numpy array
      max_sigma: sigma upper limit as a float; only used to size the shared
        offset grid so that every kernel fits in it

  Returns:
      a (N, 2 * half_width + 1) jax.numpy array; row i is the Gaussian
      density with standard deviation sigmas[i] sampled at the offsets
  """
  # NOTE(review): half_width is ~6 * max_sigma, about twice the usual 3-sigma
  # truncation radius; it looks like this value was meant to be the full
  # window width. Left unchanged because altering it changes the kernels.
  half_width = 2 * round(3 * max_sigma) + 1
  offsets = jnp.arange(-half_width, half_width + 1).astype(jnp.float32)

  # Broadcast [1, K] offsets against [N, 1] sigmas -> [N, K].
  grid = offsets[None, :]
  column_sigmas = sigmas[:, None]

  density = jnp.exp(-0.5 * grid ** 2 / column_sigmas ** 2)
  # Scaled by the continuous-density normalizer 1 / (sigma * sqrt(2*pi)),
  # so a row's discrete sum is not exactly 1 (notably for sigmas around a
  # pixel or smaller).
  normalizer = column_sigmas * _ROOT_2_PI
  return density / normalizer
41

42

43
def cow_masks(n_masks, mask_size, log_sigma_range, max_sigma,
              prop_range, rng_key):
  """Generate Cow Masks.

  Each mask is made by blurring Gaussian white noise with a Gaussian kernel
  whose sigma is drawn log-uniformly from `log_sigma_range`. The blurred
  noise is then thresholded at `mean + q(p) * std` (per mask), where q is the
  standard-normal quantile and `p` is drawn uniformly from `prop_range`.
  As a result, roughly a proportion `p` of the mask's pixels are 1.

  Args:
      n_masks: number of masks to generate as an int (0 yields an empty batch)
      mask_size: image size as a `(height, width)` pair; any length-2
        sequence (tuple or list) is accepted
      log_sigma_range: the range of the sigma (smoothing kernel)
          parameter in log-space `(log(sigma_min), log(sigma_max))`
      max_sigma: smoothing sigma upper limit; it fixes the kernel window
          shared by all masks, so it should presumably be at least
          `exp(log_sigma_range[1])` for every kernel to fit
      prop_range: range `(p_min, p_max)` from which to draw the proportion
        `p` that controls the proportion of pixels in a mask that are 1 vs 0
      rng_key: a `jax.random.PRNGKey`

  Returns:
      Cow Masks as a [n_masks, height, width, 1] float32 jax.numpy array
      containing only 0.0 and 1.0
  """
  # Unpacking (instead of tuple concatenation) accepts lists as well as
  # tuples and rejects sizes that are not (height, width).
  height, width = mask_size

  # Nothing to generate: the grouped convolutions below need at least one
  # feature group, so return an empty batch directly.
  if n_masks == 0:
    return jnp.zeros((0, height, width, 1), dtype=jnp.float32)

  rng_k1, rng_k2 = jax.random.split(rng_key)
  rng_k2, rng_k3 = jax.random.split(rng_k2)

  # Draw the per-mask proportion p
  p = jax.random.uniform(
      rng_k1, (n_masks,), minval=prop_range[0], maxval=prop_range[1],
      dtype=jnp.float32)
  # sqrt(2) * erfinv(2p - 1) is the standard-normal quantile of p, so a
  # Gaussian variable falls below mu + factor * sigma with probability p.
  threshold_factors = jax.scipy.special.erfinv(2 * p - 1) * _ROOT_2

  # Sigmas are log-uniform: uniform in log-space, then exponentiated.
  sigmas = jnp.exp(jax.random.uniform(
      rng_k2, (n_masks,), minval=log_sigma_range[0],
      maxval=log_sigma_range[1]))

  # Create the initial noise with the masks along the channel axis (a single
  # batch element), so that a grouped convolution with
  # feature_group_count=n_masks blurs each mask with its own kernel.
  noise = jax.random.normal(rng_k3, (1, height, width, n_masks))

  # Generate a kernel for each sigma
  kernels = gaussian_kernels(sigmas, max_sigma)
  # kernels: [n_masks, kernel_width] -> [kernel_width, n_masks]
  kernels = kernels.transpose((1, 0))
  # Separable blur: a [k, 1] kernel along y and a [1, k] kernel along x, in
  # HWIO layout with one input channel per group.
  krn_y = kernels[:, None, None, :]
  krn_x = kernels[None, :, None, :]

  # Apply kernels in y and x separately ('SAME' zero-pads the borders).
  dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
  smooth_noise = jax.lax.conv_general_dilated(
      noise, krn_y, (1, 1), 'SAME',
      dimension_numbers=dimension_numbers, feature_group_count=n_masks)
  smooth_noise = jax.lax.conv_general_dilated(
      smooth_noise, krn_x, (1, 1), 'SAME',
      dimension_numbers=dimension_numbers, feature_group_count=n_masks)

  # [1, height, width, n_masks] -> [n_masks, height, width, 1]
  smooth_noise = smooth_noise.transpose((3, 1, 2, 0))

  # Per-mask mean and std-dev of the blurred noise
  noise_mu = smooth_noise.mean(axis=(1, 2, 3), keepdims=True)
  noise_sigma = smooth_noise.std(axis=(1, 2, 3), keepdims=True)
  # Compute thresholds
  thresholds = threshold_factors[:, None, None, None] * noise_sigma + noise_mu
  # Apply threshold
  masks = (smooth_noise <= thresholds).astype(jnp.float32)
  return masks
105

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.