google-research
101 строка · 3.5 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Box mask generation."""
17import jax
18import jax.numpy as jnp
19
20
def generate_boxes(n_boxes, mask_size, scale, random_aspect_ratio, rng):
  """Generate boxes for box masking.

  Used for CutOut or CutMix regularizers.

  Args:
    n_boxes: number of boxes to generate
    mask_size: image size as a `(height, width)` tuple
    scale: either a real number (Python `float` or `int`, but not `bool`) for
      a box whose area is a fixed proportion of that of the mask (intended
      range is 0 to 1)
      or 'random_area' for area proportion drawn from U(0,1)
      or 'random_size' for edge length proportion drawn from U(0,1)
    random_aspect_ratio: if True, randomly select aspect ratio, if False fix
      to same AR as image
    rng: a PRNGKey

  Returns:
    boxes as a [n_boxes, [y0, x0, y1, x1]] float32 jnp.array

  Raises:
    TypeError: if `scale` is neither a number nor one of the recognised
      strings.
  """
  # One independent key per random draw: area, aspect ratio, position.
  area_key, aspect_key, position_key = jax.random.split(rng, num=3)

  # Draw area scales.
  # `bool` is a subclass of `int`; exclude it explicitly so that True/False
  # are still rejected as invalid scales.
  if isinstance(scale, (int, float)) and not isinstance(scale, bool):
    area_scales = jnp.ones((n_boxes,), dtype=jnp.float32) * scale
  elif scale == 'random_size':
    # Edge proportion ~ U(0,1), so the area proportion is its square.
    area_scales = jax.random.uniform(
        area_key, (n_boxes,), dtype=jnp.float32, minval=0.0, maxval=1.0) ** 2
  elif scale == 'random_area':
    area_scales = jax.random.uniform(
        area_key, (n_boxes,), dtype=jnp.float32, minval=0.0, maxval=1.0)
  else:
    raise TypeError('Invalid scale {}'.format(scale))

  j_mask_size = jnp.array(mask_size, dtype=jnp.float32)

  if random_aspect_ratio:
    # Clamp before the log so that a zero area does not give -inf.
    log_scale = jnp.log(jnp.maximum(area_scales, 1e-8))
    # Log aspect ratio is uniform in [log_scale, -log_scale]; at either
    # extreme one edge proportion reaches exactly 1, so neither edge can
    # exceed the mask.
    unit_noise = jax.random.uniform(aspect_key, (n_boxes,), dtype=jnp.float32)
    log_aspect_ratios = (unit_noise * 2 - 1) * log_scale
    aspect_ratios = jnp.exp(log_aspect_ratios)
    root_scale = jnp.sqrt(area_scales)
    root_aspect = jnp.sqrt(aspect_ratios)
    # The product of the two edge proportions equals the area proportion.
    box_props = jnp.stack(
        [root_scale * root_aspect, root_scale / root_aspect], axis=1)
    box_sizes = j_mask_size[None, :] * box_props
  else:
    # Same aspect ratio as the mask: both edges scale by sqrt(area).
    box_sizes = j_mask_size[None, :] * jnp.sqrt(area_scales)[:, None]

  # Place the top-left corner uniformly over the range that keeps the box
  # inside the mask.
  free_space = j_mask_size[None, :] - box_sizes
  box_pos = jax.random.uniform(
      position_key, (n_boxes, 2), dtype=jnp.float32) * free_space

  boxes = jnp.concatenate([box_pos, box_pos + box_sizes], axis=1)

  return boxes
74
75
def box_masks(boxes, mask_size):
  """Generate box masks, given boxes from `generate_boxes`.

  Used for CutOut or CutMix regularizers.

  A pixel is treated as inside a box when its centre (index + 0.5) lies
  within the box bounds, inclusive on both edges.

  Args:
    boxes: bounding boxes as a [n_boxes, [y0, x0, y1, x1]] jnp.array, or any
      array-like (e.g. a nested list) convertible to one
    mask_size: image size as a `(height, width)` tuple

  Returns:
    Cut Masks as a [n_boxes, height, width, 1] float32 jnp.array; 0.0 inside
    the box and 1.0 outside it
  """
  # `jnp.asarray` matches `.astype` for array inputs and additionally accepts
  # plain Python sequences, which have no `.astype` method.
  boxes = jnp.asarray(boxes, dtype=jnp.float32)

  # Pixel-centre coordinates along each axis.
  row_centres = jnp.arange(0, mask_size[0], dtype=jnp.float32) + 0.5
  col_centres = jnp.arange(0, mask_size[1], dtype=jnp.float32) + 0.5

  # Width-1 slices keep a trailing axis so that each box broadcasts against
  # every pixel centre, giving [n_boxes, height] and [n_boxes, width].
  rows_inside = ((row_centres[None, :] >= boxes[:, 0:1]) &
                 (row_centres[None, :] <= boxes[:, 2:3]))
  cols_inside = ((col_centres[None, :] >= boxes[:, 1:2]) &
                 (col_centres[None, :] <= boxes[:, 3:4]))

  # Outer product of the row and column indicators gives the 2D box
  # indicator, with a trailing channel axis of size 1.
  inside = (rows_inside.astype(jnp.float32)[:, :, None, None] *
            cols_inside.astype(jnp.float32)[:, None, :, None])

  return 1.0 - inside
102