google-research
104 строки · 3.7 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Cow mask generation."""
17import math
18import jax
19import jax.numpy as jnp
20
# sqrt(2): scales erfinv(2p - 1) into a standard-normal quantile in cow_masks.
_ROOT_2 = math.sqrt(2.0)
# sqrt(2 * pi): the Gaussian density normaliser used by gaussian_kernels.
_ROOT_2_PI = math.sqrt(2.0 * math.pi)
23
24
def gaussian_kernels(sigmas, max_sigma):
  """Build one 1-D Gaussian blur kernel per sigma.

  Every kernel is the Gaussian density with the given sigma evaluated at the
  integer offsets `-size, ..., size`, where
  `size = round(max_sigma * 3) * 2 + 1`. All kernels therefore share the same
  width, `2 * size + 1`, regardless of their individual sigma.

  Args:
    sigmas: kernel sigmas as a [N] jax.numpy array
    max_sigma: sigma upper limit as a float (this is used to determine
      the size of kernel required to fit all kernels)

  Returns:
    a (N, kernel_width) jax.numpy array
  """
  # The sampling grid depends only on max_sigma, so it is shared by all rows.
  half_extent = round(max_sigma * 3) * 2 + 1
  offsets = jnp.arange(-half_extent, half_extent + 1).astype(jnp.float32)

  # Column vector of sigmas broadcasts against the row vector of offsets.
  sigma_col = sigmas[:, None]
  unscaled = jnp.exp(-0.5 * offsets[None, :] ** 2 / sigma_col ** 2)

  # Divide by sigma * sqrt(2 * pi) to obtain Gaussian density values.
  return unscaled / (sigma_col * _ROOT_2_PI)
41
42
def cow_masks(n_masks, mask_size, log_sigma_range, max_sigma,
              prop_range, rng_key):
  """Generate Cow Masks by thresholding Gaussian-smoothed noise.

  For each mask, normal noise is blurred with a Gaussian kernel whose sigma is
  drawn log-uniformly from `log_sigma_range`. The blurred noise is then
  thresholded at the `p`-quantile of a normal distribution fitted to it (using
  its per-mask mean and std-dev), so approximately a proportion `p` of the
  pixels in each mask are 1.

  Args:
    n_masks: number of masks to generate as an int
    mask_size: image size as a `(height, width)` tuple (a list is also
      accepted and converted to a tuple)
    log_sigma_range: the range of the sigma (smoothing kernel)
      parameter in log-space `(log(sigma_min), log(sigma_max))`
    max_sigma: smoothing sigma upper limit; passed to `gaussian_kernels`,
      where it determines the kernel width
    prop_range: range from which to draw the proportion `p` that
      controls the proportion of pixels in a mask that are 1 vs 0
    rng_key: a `jax.random.PRNGKey`

  Returns:
    Cow Masks as a [n_masks, height, width, 1] float32 jax.numpy array whose
    entries are 0 or 1
  """
  # Accept any (height, width) sequence; the shape arithmetic below needs a
  # tuple because a list cannot be concatenated with a tuple.
  mask_size = tuple(mask_size)

  # Three independent keys: proportions, sigmas and noise.
  rng_k1, rng_k2 = jax.random.split(rng_key)
  rng_k2, rng_k3 = jax.random.split(rng_k2)

  # Draw the per-mask proportion p
  p = jax.random.uniform(
      rng_k1, (n_masks,), minval=prop_range[0], maxval=prop_range[1],
      dtype=jnp.float32)
  # Compute threshold factors: sqrt(2) * erfinv(2p - 1) is the p-quantile of
  # the standard normal distribution.
  threshold_factors = jax.scipy.special.erfinv(2 * p - 1) * _ROOT_2

  # Sample sigma log-uniformly: uniform in log-space, then exponentiate.
  sigmas = jnp.exp(jax.random.uniform(
      rng_k2, (n_masks,), minval=log_sigma_range[0],
      maxval=log_sigma_range[1]))

  # Create initial noise with the batch and channel axes swapped, i.e. as
  # [1, height, width, n_masks], so that a grouped (depthwise) convolution
  # can blur every mask with its own Gaussian kernel in a single call.
  noise = jax.random.normal(rng_k3, (1,) + mask_size + (n_masks,))

  # Generate a kernel for each sigma
  kernels = gaussian_kernels(sigmas, max_sigma)
  # kernels: [batch, width] -> [width, batch]
  kernels = kernels.transpose((1, 0))
  # Kernels in y and x, laid out as HWIO with one input channel per group:
  # [width, 1, 1, batch] blurs vertically, [1, width, 1, batch] horizontally.
  krn_y = kernels[:, None, None, :]
  krn_x = kernels[None, :, None, :]

  # Apply kernels in y and x separately. feature_group_count=n_masks puts each
  # channel in its own group, so channel i is only convolved with kernel i.
  smooth_noise = jax.lax.conv_general_dilated(
      noise, krn_y, (1, 1), 'SAME',
      dimension_numbers=('NHWC', 'HWIO', 'NHWC'), feature_group_count=n_masks)
  smooth_noise = jax.lax.conv_general_dilated(
      smooth_noise, krn_x, (1, 1), 'SAME',
      dimension_numbers=('NHWC', 'HWIO', 'NHWC'), feature_group_count=n_masks)

  # [1, height, width, batch] -> [batch, height, width, 1]
  smooth_noise = smooth_noise.transpose((3, 1, 2, 0))

  # Compute the per-mask mean and std-dev of the smoothed noise
  noise_mu = smooth_noise.mean(axis=(1, 2, 3), keepdims=True)
  noise_sigma = smooth_noise.std(axis=(1, 2, 3), keepdims=True)
  # Compute thresholds: rescale the standard-normal quantile to each mask's
  # own mean and std-dev.
  thresholds = threshold_factors[:, None, None, None] * noise_sigma + noise_mu
  # Apply threshold: pixels at or below the threshold become 1, the rest 0.
  masks = (smooth_noise <= thresholds).astype(jnp.float32)
  return masks
105