google-research

Форк
0
101 строка · 3.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Box mask generation."""
17
import jax
18
import jax.numpy as jnp
19

20

21
def generate_boxes(n_boxes, mask_size, scale, random_aspect_ratio, rng):
  """Generate boxes for box masking.

  Used for CutOut or CutMix regularizers.

  Args:
    n_boxes: number of boxes to generate
    mask_size: image size as a `(height, width)` tuple
    scale: either a real number (float or int) in the range (0,1] for a box
      whose area is a fixed proportion of that of the mask
      or 'random_area' for area proportion drawn from U(0,1)
      or 'random_size' for edge length proportion drawn from U(0,1)
    random_aspect_ratio: if True, randomly select aspect ratio, if False fix
      to same AR as image
    rng: a PRNGKey

  Returns:
    boxes as a [n_boxes, [y0, x0, y1, x1]] float32 jnp.array

  Raises:
    TypeError: if `scale` is neither a real number nor one of the strings
      'random_area' / 'random_size'.
  """
  k1, k2, k3 = jax.random.split(rng, num=3)

  # Draw area scales (proportion of the mask area covered by each box).
  # `bool` is a subclass of `int`, so it is excluded explicitly: True/False is
  # never a meaningful scale.
  if isinstance(scale, (int, float)) and not isinstance(scale, bool):
    area_scales = jnp.ones((n_boxes,), dtype=jnp.float32) * scale
  elif scale == 'random_size':
    # Edge-length proportion ~ U(0,1), so the area proportion is its square.
    area_scales = jax.random.uniform(k1, (n_boxes,), dtype=jnp.float32,
                                     minval=0.0, maxval=1.0) ** 2
  elif scale == 'random_area':
    area_scales = jax.random.uniform(k1, (n_boxes,), dtype=jnp.float32,
                                     minval=0.0, maxval=1.0)
  else:
    raise TypeError('Invalid scale {}'.format(scale))

  j_mask_size = jnp.array(mask_size, dtype=jnp.float32)

  if random_aspect_ratio:
    # The log aspect ratio is drawn uniformly from
    # [log(area_scale), -log(area_scale)]. With an area scale s <= 1 each edge
    # proportion sqrt(s * ar) or sqrt(s / ar) then stays within [s, 1], so the
    # box never exceeds the mask. The clamp avoids log(0) for zero-area boxes.
    log_scale = jnp.log(jnp.maximum(area_scales, 1e-8))
    log_aspect_ratios = (jax.random.uniform(
        k2, (n_boxes,), dtype=jnp.float32) * 2 - 1) * log_scale
    aspect_ratios = jnp.exp(log_aspect_ratios)
    root_scale = jnp.sqrt(area_scales)
    root_aspect = jnp.sqrt(aspect_ratios)
    # Per-box (height, width) proportions whose product is the area scale.
    box_props = jnp.stack([root_scale * root_aspect, root_scale / root_aspect],
                          axis=1)
    box_sizes = j_mask_size[None, :] * box_props
  else:
    # Same aspect ratio as the mask: scale both edges by sqrt(area scale).
    box_sizes = j_mask_size[None, :] * jnp.sqrt(area_scales)[:, None]

  # Top-left corner drawn uniformly from the positions that keep the whole box
  # inside the mask.
  box_pos = jax.random.uniform(k3, (n_boxes, 2), dtype=jnp.float32) * (
      j_mask_size[None, :] - box_sizes)

  boxes = jnp.concatenate([box_pos, box_pos + box_sizes], axis=1)

  return boxes
74

75

76
def box_masks(boxes, mask_size):
  """Generate box masks, given boxes from `generate_boxes`.

  Used for CutOut or CutMix regularizers.

  Args:
      boxes: bounding boxes as a [n_boxes, [y0, x0, y1, x1]] jnp.array (or
        anything convertible to one, e.g. a nested list or numpy array)
      mask_size: image size as a `(height, width)` tuple

  Returns:
      Cut Masks as a [n_boxes, height, width, 1] float32 jnp.array, with
      value 0.0 inside each box and 1.0 outside it
  """
  # Pixel-centre coordinates: pixel i spans [i, i + 1) and is sampled at
  # i + 0.5.
  y = jnp.arange(0, mask_size[0], dtype=jnp.float32) + 0.5
  x = jnp.arange(0, mask_size[1], dtype=jnp.float32) + 0.5

  # asarray (rather than .astype) also accepts plain sequences.
  boxes = jnp.asarray(boxes, dtype=jnp.float32)

  # Per-axis membership tests, each of shape [n_boxes, height] / [n_boxes,
  # width]; the column slices keep a trailing axis for broadcasting.
  y_mask = (y[None, :] >= boxes[:, 0:1]) & (y[None, :] <= boxes[:, 2:3])
  x_mask = (x[None, :] >= boxes[:, 1:2]) & (x[None, :] <= boxes[:, 3:4])

  # Outer product of the row and column indicators gives the 2-D box
  # indicator, shape [n_boxes, height, width, 1].
  masks = (y_mask.astype(jnp.float32)[:, :, None, None] *
           x_mask.astype(jnp.float32)[:, None, :, None])

  # Invert so the box region is zeroed out.
  return 1.0 - masks
102

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.