# Source repository: google-research
# (Code-viewer summary: 258 lines · 8.2 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for scene computation."""
import numpy as np
import tensorflow as tf
from osf import geo_utils
21
# Maps an integer object ID to the object's name.
OID2NAME = {
    0: 'couch',
    1: 'chair',
    2: 'table',
}
# Inverse lookup: object name -> integer object ID.
NAME2OID = {v: k for k, v in OID2NAME.items()}
BKGD_ID = len(OID2NAME)  # Background is the last index.
30
# Objects whose stored z translation may refer to the bottom of the object
# rather than its center. For these, the box center is raised by half of the
# box's z extent in `extract_box_for_scene_object`.
OBJECT2SHIFT_Z_CENTER = [
    'antique_couch',
    'armchair',
    'cash_register',
    'commode',
    'drill',
    'ottoman',
    'rocking_chair',
    'side_table',
    'sign',
    'ukulele',
    'wooden_table',
]


def extract_box_for_scene_object(scene_info, sid, name, padding=0.0,
                                 swap_yz=False, box_delta_t=None,
                                 convert_to_tf=False):
  """Extracts the bounding box for a scene's object.

  Note: If the object does not exist in the scene, we return a dummy bounding
  box such that rays will never intersect with this box.

  Args:
    scene_info: Dict. `scene_info[sid]['objects'][name]` is read for the keys
      'R' (rotation), 'T' (translation), 'dims' and 'scale'.
    sid: int. Key of the scene inside `scene_info`.
    name: str. Name of the object to look up in the scene.
    padding: float. Amount added to every box dimension.
    swap_yz: bool. Whether to swap the y and z box dimensions.
    box_delta_t: List of floats, or None. Optional offset added to the box
      center.
    convert_to_tf: bool. Whether to convert the resulting box tensors into tf
      tensors.

  Returns:
    box_dims: [3,] float32, either numpy or tf depending on `convert_to_tf`.
    box_center: [3,] float32, either numpy or tf depending on `convert_to_tf`.
    box_rotation: [3, 3] float32, either numpy or tf depending on
      `convert_to_tf`.
  """
  scene_objects = scene_info[sid]['objects']
  if name in scene_objects:
    object_info = scene_objects[name]
  else:
    # Create a dummy box such that rays will never intersect it.
    object_info = {
        'R': np.eye(3),
        'T': [np.inf, np.inf, np.inf],
        'dims': [0.0, 0.0, 0.0],
        'scale': [0.0, 0.0, 0.0]
    }

  # `np.array` always copies, so the arrays below can be modified without
  # touching the data stored in `scene_info`.
  box_center = np.array(object_info['T'], dtype=np.float32)
  box_rotation = np.array(object_info['R'], dtype=np.float32)

  # Apply the optional translation offset to the box center.
  if box_delta_t is not None:
    box_center = box_center + np.array(box_delta_t, dtype=np.float32)

  # Use an explicit float dtype: with integer `dims`/`scale` the product would
  # be an integer array and adding a float padding in place would fail.
  box_dims = (np.array(object_info['dims'], dtype=np.float64) *
              np.array(object_info['scale'], dtype=np.float64))

  # Swap y and z if requested.
  if swap_yz:
    box_dims[[1, 2]] = box_dims[[2, 1]]

  # The z transformation may be referring to the bottom of the object instead
  # of the center. If that is the case, we add z_dim / 2 to the z
  # transformation. This uses the unpadded (and possibly swapped) z dimension.
  if name in OBJECT2SHIFT_Z_CENTER:
    box_center[2] += box_dims[2] / 2

  # Apply padding to the box and return the documented float32 dtype.
  box_dims = (box_dims + padding).astype(np.float32)

  if convert_to_tf:
    box_dims = tf.constant(box_dims, dtype=tf.float32)
    box_center = tf.constant(box_center, dtype=tf.float32)
    box_rotation = tf.constant(box_rotation, dtype=tf.float32)
  return box_dims, box_center, box_rotation
110
def extract_boxes_for_all_scenes(scene_info, name, padding, swap_yz,
                                 box_delta_t):
  """Stacks the bounding box of one object across every scene.

  Scenes are visited in the iteration order of `scene_info`, so row `i` of each
  output belongs to the i-th key of `scene_info`.

  Note: For a scene that does not contain the object, the dummy box returned by
  `extract_box_for_scene_object` is used, so rays never intersect it.

  Args:
    scene_info: Dict keyed by scene ID.
    name: str. Object name.
    padding: float. Padding added to every box dimension.
    swap_yz: bool. Whether to swap the y and z box dimensions.
    box_delta_t: List of floats, or None. Offset added to each box center.

  Returns:
    all_box_dims: [N, 3] tf.float32.
    all_box_center: [N, 3] tf.float32.
    all_box_rotation: [N, 3, 3] tf.float32.

    where N is the number of scenes.
  """
  # One (dims, center, rotation) triple of tf tensors per scene.
  per_scene_boxes = [
      extract_box_for_scene_object(
          scene_info=scene_info, sid=sid, name=name, padding=padding,
          swap_yz=swap_yz, box_delta_t=box_delta_t, convert_to_tf=True)
      for sid in scene_info
  ]
  all_box_dims = tf.stack(
      [box[0] for box in per_scene_boxes], axis=0)  # [N, 3]
  all_box_center = tf.stack(
      [box[1] for box in per_scene_boxes], axis=0)  # [N, 3]
  all_box_rotation = tf.stack(
      [box[2] for box in per_scene_boxes], axis=0)  # [N, 3, 3]
  return all_box_dims, all_box_center, all_box_rotation
147
def extract_object_boxes_for_scenes(name, scene_info, sids, padding, swap_yz,
                                    box_delta_t):
  """Looks up an object's box for each requested scene ID.

  The boxes of all N scenes are built first and then indexed with `sids`, so
  each scene ID is used as a row index into the stacked per-scene tensors.

  Args:
    name: The object name.
    scene_info: The scene information.
    sids: [R, 1] tf.int32. Scene IDs.
    padding: float32. The amount of padding to apply in all dimensions.
    swap_yz: bool. Whether to swap y and z box dimensions.
    box_delta_t: List of floats.

  Returns:
    sid_box_dims: [R, 3] tf.float32.
    sid_box_center: [R, 3] tf.float32.
    sid_box_rotation: [R, 3, 3] tf.float32.
  """
  # Tuple of ([N, 3], [N, 3], [N, 3, 3]) tensors, N being the scene count.
  boxes_for_all_scenes = extract_boxes_for_all_scenes(
      scene_info=scene_info, name=name, padding=padding, swap_yz=swap_yz,
      box_delta_t=box_delta_t)

  # Gather the rows selected by `sids` ([R, 1]) from each per-scene tensor.
  sid_box_dims, sid_box_center, sid_box_rotation = (
      tf.gather_nd(params=per_scene, indices=sids)
      for per_scene in boxes_for_all_scenes)
  return sid_box_dims, sid_box_center, sid_box_rotation
184
def extract_w2o_transformations_per_scene(name, scene_info, box_delta_t):
  """Builds per-scene world-to-object transformations for one object.

  For every scene containing the object, the box rotation and center are
  passed to `geo_utils.construct_rt` with `inverse=True`: once with the
  translation and once with rotation only.

  Args:
    name: str. Object name.
    scene_info: dict.
    box_delta_t: List of floats.

  Returns:
    w2o_rt_per_scene: [N_scenes, 4, 4] tf.float32.
    w2o_r_per_scene: [N_scenes, 4, 4] tf.float32.
  """
  rt_matrices = []
  r_matrices = []
  for sid, info in scene_info.items():
    if name in info['objects']:
      _, center, rotation = extract_box_for_scene_object(
          scene_info=scene_info, sid=sid, name=name, box_delta_t=box_delta_t)
      rt_matrices.append(
          geo_utils.construct_rt(r=rotation, t=center, inverse=True))
      r_matrices.append(
          geo_utils.construct_rt(r=rotation, t=None, inverse=True))
    else:
      # The object does not exist in the scene. We will not end up selecting
      # this scene in the parent function, `create_w2o_transformations_tensors`
      # anyway, so placeholder transformations are stored.
      # NOTE(review): presumably `construct_rt(r=None, t=None)` yields an
      # identity transform -- confirm in `geo_utils`.
      rt_matrices.append(geo_utils.construct_rt(r=None, t=None))
      r_matrices.append(geo_utils.construct_rt(r=None, t=None))

  w2o_rt_per_scene = tf.constant(
      np.array(rt_matrices), dtype=tf.float32)  # [N_scenes, 4, 4]
  w2o_r_per_scene = tf.constant(
      np.array(r_matrices), dtype=tf.float32)  # [N_scenes, 4, 4]
  return w2o_rt_per_scene, w2o_r_per_scene
220
def extract_light_positions_for_all_scenes(scene_info, light_pos=None):
  """Collects one light position per scene.

  Args:
    scene_info: Dict keyed by scene ID. Each entry's 'light_pos' is read unless
      `light_pos` is given.
    light_pos: Hardcoded light pos to override with. When not None, it is used
      for every scene instead of the per-scene value.

  Returns:
    light_positions: [N_scenes, 3] tf.float32.
  """
  if light_pos is None:
    positions = [scene_info[sid]['light_pos'] for sid in scene_info]
  else:
    # Repeat the override once per scene.
    positions = [light_pos for _ in scene_info]
  return tf.constant(positions, dtype=tf.float32)  # [N, 3]
240
def extract_light_positions_for_sids(sids, scene_info, light_pos):
  """Looks up the light position for each requested scene ID.

  Each scene ID is used as a row index into the light positions stacked over
  all scenes.

  Args:
    sids: [N, 1] tf.int32.
    scene_info: Dict.
    light_pos: Light position. Passed through as the override for
      `extract_light_positions_for_all_scenes`.

  Returns:
    light_positions: [N, 3] tf.float32.
  """
  per_scene_positions = extract_light_positions_for_all_scenes(
      scene_info=scene_info, light_pos=light_pos)  # [S, 3]
  # Select the rows addressed by `sids` ([N, 1]) -> [N, 3].
  return tf.gather_nd(params=per_scene_positions, indices=sids)