google-research

Форк
0
/
scene_utils.py 
258 строк · 8.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Utility functions for scene computation."""
17
import numpy as np
18
import tensorflow as tf
19
from osf import geo_utils
20

21

22
# Object ID -> object category name; IDs are assigned in tuple order.
OID2NAME = dict(enumerate(('couch', 'chair', 'table')))
# Reverse lookup: object category name -> object ID.
NAME2OID = {obj_name: oid for oid, obj_name in OID2NAME.items()}
# The background takes the index right after the last object ID.
BKGD_ID = len(OID2NAME)


# Object names whose stored z translation may refer to the bottom of the
# object rather than its center; `extract_box_for_scene_object` raises the box
# center by half of the z extent for these names.
OBJECT2SHIFT_Z_CENTER = [
    'antique_couch',
    'armchair',
    'cash_register',
    'commode',
    'drill',
    'ottoman',
    'rocking_chair',
    'side_table',
    'sign',
    'ukulele',
    'wooden_table',
]
44

45

46
def extract_box_for_scene_object(scene_info, sid, name, padding=0.0,
                                 swap_yz=False, box_delta_t=None,
                                 convert_to_tf=False):
  """Extracts the bounding box for a scene's object.

  Note: If the object does not exist in the scene, we return a dummy bounding
  box such that rays will never intersect with this box: it is centered at
  +inf in every coordinate and its dims equal `padding`.

  Args:
    scene_info: Dict keyed by scene ID. `scene_info[sid]['objects']` maps
      object names to dicts with keys 'R' (3x3 rotation), 'T' (box center
      translation), 'dims' and 'scale' (multiplied together to give the box
      dimensions).
    sid: Scene ID key into `scene_info`.
    name: str. Object name.
    padding: float. Added to every box dimension.
    swap_yz: bool. Whether to swap the y and z box dimensions.
    box_delta_t: Optional list of 3 floats added to the box center.
    convert_to_tf: bool. Whether to convert the resulting box tensors into tf
      tensors.

  Returns:
    box_dims: [3,] float32, either numpy or tf depending on `convert_to_tf`.
    box_center: [3,] float32, either numpy or tf depending on `convert_to_tf`.
    box_rotation: [3, 3] float32, either numpy or tf depending on
      `convert_to_tf`.
  """
  objects = scene_info[sid]['objects']
  if name in objects:
    object_info = objects[name]
  else:
    # Create a dummy box such that rays will never intersect it.
    object_info = {
        'R': np.eye(3),
        'T': [np.inf, np.inf, np.inf],
        'dims': [0.0, 0.0, 0.0],
        'scale': [0.0, 0.0, 0.0]
    }

  # np.array copies its input by default, so the in-place edits below never
  # modify the arrays stored in `scene_info`.
  box_center = np.array(object_info['T'], dtype=np.float32)
  box_rotation = np.array(object_info['R'], dtype=np.float32)

  # Optionally offset the box center by a full 3D translation.
  if box_delta_t is not None:
    box_center = box_center + np.array(box_delta_t, dtype=np.float32)

  # Compute in float64 (as before) so values are unchanged, but force a float
  # dtype: with integer 'dims'/'scale' the in-place `+= padding` below would
  # otherwise fail numpy's same_kind casting rule.
  box_dims = (np.asarray(object_info['dims'], dtype=np.float64) *
              np.asarray(object_info['scale'], dtype=np.float64))

  # Swap y and z if requested (the fancy-indexed RHS is a copy, so this is
  # safe).
  if swap_yz:
    box_dims[[1, 2]] = box_dims[[2, 1]]

  # The z transformation may be referring to the bottom of the object instead
  # of the center. If that is the case, we add z_dim / 2 to the z
  # transformation.
  if name in OBJECT2SHIFT_Z_CENTER:
    box_center[2] += box_dims[2] / 2

  # Apply padding to the box, then honor the documented float32 dtype on the
  # numpy path as well (the tf path casts to float32 in the same way).
  box_dims += padding
  box_dims = box_dims.astype(np.float32)

  if convert_to_tf:
    box_dims = tf.constant(box_dims, dtype=tf.float32)
    box_center = tf.constant(box_center, dtype=tf.float32)
    box_rotation = tf.constant(box_rotation, dtype=tf.float32)
  return box_dims, box_center, box_rotation
109

110

111
def extract_boxes_for_all_scenes(scene_info, name, padding, swap_yz,
                                 box_delta_t):
  """Stacks the bounding box of object `name` across every scene.

  Scenes are visited in `scene_info` iteration order, so row i of each output
  belongs to the i-th key of `scene_info`. Scenes that do not contain the
  object contribute the dummy (never-intersected) box returned by
  `extract_box_for_scene_object`.

  Args:
    scene_info: Dict.
    name: str.
    padding: float.
    swap_yz: bool.
    box_delta_t: List of floats.

  Returns:
    all_box_dims: [N, 3,] tf.float32.
    all_box_center: [N, 3,] tf.float32.
    all_box_rotation: [N, 3, 3] tf.float32.

    where N is the number of scenes.
  """
  per_scene_boxes = [
      extract_box_for_scene_object(
          scene_info=scene_info, sid=sid, name=name, padding=padding,
          swap_yz=swap_yz, box_delta_t=box_delta_t, convert_to_tf=True)
      for sid in scene_info
  ]
  dims_rows = [dims for dims, _, _ in per_scene_boxes]  # each [3]
  center_rows = [center for _, center, _ in per_scene_boxes]  # each [3]
  rotation_rows = [rot for _, _, rot in per_scene_boxes]  # each [3, 3]
  return (
      tf.stack(dims_rows, axis=0),  # [N, 3]
      tf.stack(center_rows, axis=0),  # [N, 3]
      tf.stack(rotation_rows, axis=0),  # [N, 3, 3]
  )
146

147

148
def extract_object_boxes_for_scenes(name, scene_info, sids, padding, swap_yz,
                                    box_delta_t):
  """Extracts object boxes given scene IDs.

  Boxes for object `name` are first built for all N scenes (via
  `extract_boxes_for_all_scenes`), then gathered row-wise with `sids`.
  NOTE(review): `tf.gather_nd` indexes rows by position, so this presumably
  assumes the keys of `scene_info` iterate as 0..N-1 in order — confirm
  against callers.

  Args:
    name: The object name.
    scene_info: The scene information.
    sids: [R, 1] tf.int32. Scene IDs.
    padding: float32. The amount of padding to apply in all dimensions.
    swap_yz: bool. Whether to swap y and z box dimensions.
    box_delta_t: List of floats.

  Returns:
    sid_box_dims: [R, 3] tf.float32.
    sid_box_center: [R, 3] tf.float32.
    sid_box_rotation: [R, 3, 3] tf.float32.
  """
  # ([N, 3], [N, 3], [N, 3, 3]) for all N scenes.
  all_boxes = extract_boxes_for_all_scenes(
      scene_info=scene_info, name=name, padding=padding, swap_yz=swap_yz,
      box_delta_t=box_delta_t)

  # Each [sid] row of `sids` selects one scene row, turning the leading N axis
  # into R: [R, 3], [R, 3], [R, 3, 3].
  sid_box_dims, sid_box_center, sid_box_rotation = (
      tf.gather_nd(params=per_scene, indices=sids) for per_scene in all_boxes)
  return sid_box_dims, sid_box_center, sid_box_rotation
183

184

185
def extract_w2o_transformations_per_scene(name, scene_info, box_delta_t):
  """Builds world-to-object transforms for object `name` in every scene.

  For scenes containing the object, the transforms come from
  `geo_utils.construct_rt(..., inverse=True)` applied to the box rotation and
  center returned by `extract_box_for_scene_object` (no padding, no y/z swap).
  Presumably these are the inverses of the object-to-world transforms; confirm
  in `geo_utils`. Scenes without the object get
  `geo_utils.construct_rt(r=None, t=None)` placeholders.

  Args:
    name: str. Object name.
    scene_info: dict.
    box_delta_t: List of floats.

  Returns:
    w2o_rt_per_scene: [N_scenes, 4, 4] tf.float32. Rotation + translation.
    w2o_r_per_scene: [N_scenes, 4, 4] tf.float32. Rotation only (t=None).
  """
  rt_mats = []
  r_mats = []
  for sid, info in scene_info.items():
    if name in info['objects']:
      _, center, rotation = extract_box_for_scene_object(
          scene_info=scene_info, sid=sid, name=name, box_delta_t=box_delta_t)
      rt_mat = geo_utils.construct_rt(r=rotation, t=center, inverse=True)
      r_mat = geo_utils.construct_rt(r=rotation, t=None, inverse=True)
    else:
      # Object absent: this scene is not expected to be selected by the parent
      # function, `create_w2o_transformations_tensors`, so a placeholder is
      # enough here.
      rt_mat = geo_utils.construct_rt(r=None, t=None)
      r_mat = geo_utils.construct_rt(r=None, t=None)
    rt_mats.append(rt_mat)
    r_mats.append(r_mat)
  w2o_rt_per_scene = tf.constant(
      np.array(rt_mats), dtype=tf.float32)  # [N_scenes, 4, 4]
  w2o_r_per_scene = tf.constant(
      np.array(r_mats), dtype=tf.float32)  # [N_scenes, 4, 4]
  return w2o_rt_per_scene, w2o_r_per_scene
219

220

221
def extract_light_positions_for_all_scenes(scene_info, light_pos=None):
  """Collects one light position per scene, in `scene_info` iteration order.

  Args:
    scene_info: Dict keyed by scene ID. When `light_pos` is None, each entry's
      'light_pos' value is used.
    light_pos: Optional hardcoded light position. When not None, it is used for
      every scene and the per-scene 'light_pos' entries are not read.

  Returns:
    light_positions: [N_scenes, 3] tf.float32.
  """
  use_override = light_pos is not None
  per_scene = [
      light_pos if use_override else scene_info[sid]['light_pos']
      for sid in scene_info
  ]
  return tf.constant(per_scene, dtype=tf.float32)  # [N, 3]
239

240

241
def extract_light_positions_for_sids(sids, scene_info, light_pos):
  """Looks up the light position for each requested scene ID.

  Args:
    sids: [N, 1] tf.int32. Scene IDs, one per row.
    scene_info: Dict.
    light_pos: Light position override forwarded to
      `extract_light_positions_for_all_scenes`; None means each scene's own
      'light_pos' entry is used.

  Returns:
    light_positions: [N, 3] tf.float32.
  """
  per_scene_lights = extract_light_positions_for_all_scenes(
      scene_info=scene_info, light_pos=light_pos)  # [S, 3]
  # Each [sid] index row selects one scene's light: [S, 3] -> [N, 3].
  return tf.gather_nd(params=per_scene_lights, indices=sids)
259

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.