google-research

Форк
0
200 строк · 7.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Loads the SYMSOL dataset."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import os
22

23
import numpy as np
24
from scipy.spatial.transform import Rotation
25
import tensorflow as tf
26
import tensorflow_datasets as tfds
27

28

29
# Shape names in dataset label order: a name's position in this list is the
# integer that load_symsol compares against the 'label_shape' feature.
SHAPE_NAMES = [
    # Solids whose symmetry rotations compute_symsol_symmetries generates.
    'tet',
    'cube',
    'icosa',
    'cone',
    'cyl',
    # Remaining shapes (no symmetry set is computed for these here).
    'tetX',
    'cylO',
    'sphereX',
]
32

33

34
def load_symsol(shapes, mode='train', downsample_continuous_gt=0, mock=False):
  """Loads the symmetric_solids dataset.

  Args:
    shapes: Either 'symsol1' (the first five entries of SHAPE_NAMES: tet,
      cube, icosa, cone, cyl), 'symsol2' (the remaining entries: tetX, cylO,
      sphereX), a single shape name, or any iterable of names from
      SHAPE_NAMES.
    mode: 'train' or 'test', determining the split of the dataset. For
      'train' each example is annotated with its single 'rotation'; for any
      other value the full set of 'rotations_equivalent' is used.
    downsample_continuous_gt: An integer, the amount to downsample the
      continuous symmetry ground truths, if any.  The gt rotations for the cone
      and cyl have been discretized to 1 degree increments, but this can be
      overkill for evaluation during training. If 0, use the full annotation.
      Only applied when mode == 'test', and only to cone and cyl examples; the
      equivalent rotations of every other shape are left untouched.
    mock: Make random data to avoid downloading it.

  Returns:
    tf.data.Dataset of (image, rotations) pairs: the image converted to
    float32 with tf.image.convert_image_dtype, and the associated rotation
    matrices.

  Raises:
    ValueError: If a requested shape name is not in SHAPE_NAMES.
  """
  if isinstance(shapes, str):
    if shapes == 'symsol1':
      shapes = SHAPE_NAMES[:5]
    elif shapes == 'symsol2':
      shapes = SHAPE_NAMES[5:]
    else:
      # Iterating a bare string would yield single characters, none of which
      # is a shape name; treat it as one shape instead.
      shapes = [shapes]
  shapes = list(shapes)
  unknown = [shape for shape in shapes if shape not in SHAPE_NAMES]
  if unknown:
    raise ValueError('Unknown shape(s) {}; expected names from {}.'.format(
        unknown, SHAPE_NAMES))
  shape_inds = [SHAPE_NAMES.index(shape) for shape in shapes]

  if mock:
    with tfds.testing.mock_data(num_examples=100):
      dataset = tfds.load('symmetric_solids', split=mode)
  else:
    dataset = tfds.load('symmetric_solids', split=mode)

  # Keep only examples of the requested shapes.
  dataset = dataset.filter(
      lambda x: tf.reduce_any(tf.equal(x['label_shape'], shape_inds)))

  # Use the full set of equivalent rotations unless training.
  annotation_key = 'rotation' if mode == 'train' else 'rotations_equivalent'

  # Only shapes with a continuous symmetry carry the densely discretized
  # equivalent-rotation annotation worth thinning out.
  continuous_inds = [SHAPE_NAMES.index('cone'), SHAPE_NAMES.index('cyl')]
  downsample = (mode == 'test' and bool(downsample_continuous_gt) and
                any(ind in continuous_inds for ind in shape_inds))

  def _to_image_and_rotations(example):
    """Converts one example into an (image, rotations) pair."""
    image = tf.image.convert_image_dtype(example['image'], tf.float32)
    full_rotations = example[annotation_key]
    if not downsample:
      return image, full_rotations
    # Decide per example, so a mixed shape list only subsamples cone/cyl
    # annotations and never drops equivalents of the discrete solids.
    is_continuous = tf.reduce_any(
        tf.equal(example['label_shape'], continuous_inds))
    rotations = tf.cond(
        is_continuous,
        lambda: full_rotations[::downsample_continuous_gt],
        lambda: full_rotations)
    return image, rotations

  return dataset.map(
      _to_image_and_rotations,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
80

81

82
def _expand_by_cyclic_row_shifts(seeds):
  """Stacks each seed followed by its rows cyclically shifted by +1 and -1.

  Cyclically shifting the rows of a matrix left-multiplies it by a 3-fold
  rotation about the (1, 1, 1) axis, so every seed rotation contributes three
  rotations, in the order [seed, roll +1, roll -1].

  Args:
    seeds: Sequence of [3, 3] rotation matrices.

  Returns:
    [3 * len(seeds), 3, 3] array.
  """
  expanded = []
  for seed in seeds:
    expanded.extend(
        [seed, np.roll(seed, 1, axis=0), np.roll(seed, -1, axis=0)])
  return np.stack(expanded, 0)


def _tetrahedron_symmetries():
  """Returns the 12 tetrahedron symmetry rotations, identity first."""
  # In a cube-aligned frame: the identity and the three 180 degree flips about
  # the coordinate axes, each combined with the cyclic axis permutations.
  seeds = [np.eye(3)] + [np.diag(np.roll([-1, -1, 1], i)) for i in range(3)]
  syms = _expand_by_cyclic_row_shifts(seeds)
  # The syms are specific to the object coordinate axes used during rendering,
  # and for the tet the canonical frames were 45 deg from corners of a cube.
  # So we rotate to the cube frame, where the syms above are valid, and back.
  correction_rot = Rotation.from_euler(
      'xyz', np.float32([0, 0, np.pi / 4.0])).as_matrix()
  return correction_rot @ syms @ correction_rot.T


def _cube_symmetries():
  """Returns the 24 cube symmetry rotations, identity first."""
  swap_xz = np.float32([[0, 0, 1], [0, 1, 0], [1, 0, 0]])
  seeds = [np.eye(3), np.float32([[0, 0, -1], [0, -1, 0], [-1, 0, 0]])]
  for i in range(3):
    seeds.append(np.diag(np.roll([-1, -1, 1], i)))
    seeds.append(np.diag(np.roll([-1, 1, 1], i)) @ swap_xz)
  return _expand_by_cyclic_row_shifts(seeds)


def _icosahedron_symmetries():
  """Returns the icosahedron symmetry rotations (60 expected), identity first.

  Candidates are products of 5-fold rotations about pairs of vertex axes; the
  duplicates among them are then removed, keeping first occurrences.
  """
  golden_ratio = (1 + np.sqrt(5)) / 2.
  a, b = np.float32([1, golden_ratio]) / np.sqrt(1 + golden_ratio**2)
  verts = np.float32([[-a, b, 0],
                      [a, b, 0],
                      [-a, -b, 0],
                      [a, -b, 0],
                      [0, -a, b],
                      [0, a, b],
                      [0, -a, -b],
                      [0, a, -b],
                      [b, 0, -a],
                      [b, 0, a],
                      [-b, 0, -a],
                      [-b, 0, a]])
  first_angles = np.arange(3) * 2 * np.pi / 5
  second_angles = np.arange(1, 3) * 2 * np.pi / 5

  candidates = [np.eye(3)]
  for ind1, vert1 in enumerate(verts):
    for vert2 in verts[ind1 + 1:]:
      # Antipodal vertices share one axis, so such pairs add nothing new. This
      # exact float32 test may not fire; that is harmless, since any extra
      # candidates are removed by the deduplication below.
      if np.abs(np.dot(vert1, vert2)) == 1:
        continue
      for angle1 in first_angles:
        rot1 = Rotation.from_rotvec(angle1 * vert1).as_matrix()
        for angle2 in second_angles:
          candidates.append(
              rot1 @ Rotation.from_rotvec(angle2 * vert2).as_matrix())
  candidates = np.stack(candidates, 0)

  # traces[i, j] = trace(R_i @ R_j^T) = 1 + 2 cos(relative angle), which
  # reaches its maximum of 3 exactly when R_i == R_j.
  flat = candidates.reshape(len(candidates), 9)
  traces = flat @ flat.T
  is_duplicate = traces > 3 - 1e-9
  # Drop every candidate that matches an earlier one, so the first occurrence
  # of each rotation (in particular the leading identity) is kept.
  matches_earlier = np.triu(is_duplicate, k=1).any(axis=0)
  return candidates[~matches_earlier]


def _continuous_symmetries(num_steps, x_angles):
  """Samples rotations about z, each combined with the given x rotations.

  Args:
    num_steps: Number of evenly spaced z angles in [0, 2*pi). The endpoint
      2*pi is excluded because it is the same rotation as 0.
    x_angles: x-axis angles combined (extrinsic 'xyz' Euler order) with every
      z sample, in the order given.

  Returns:
    [num_steps * len(x_angles), 3, 3] array ordered z-major; the first entry
    is the identity when x_angles[0] == 0.
  """
  z_angles = np.linspace(0., 2. * np.pi, num_steps, endpoint=False)
  euler_angles = np.array(
      [[x_angle, 0., z_angle] for z_angle in z_angles for x_angle in x_angles])
  return Rotation.from_euler('xyz', euler_angles).as_matrix()


def compute_symsol_symmetries(num_steps_around_continuous=360):
  """Return the GT rotation matrices for the symmetric solids.

  We provide this primarily for the ability to generate the symmetry rotations
  for the cone and cylinder at arbitrary resolutions.

  The first matrix returned for each shape is the identity.

  Args:
    num_steps_around_continuous: The number of steps taken around each great
      circle of equivalent poses for the cylinder and cone. Steps are
      2*pi/num_steps_around_continuous apart and the starting pose is not
      repeated, so e.g. 360 gives exact 1 degree increments.

  Returns:
    A dictionary, indexed by shape name, for the five solids of the SYMSOL
    dataset.  The values in the dictionary are [N, 3, 3] rotation matrices,
    where N is 12 for tet, 24 for cube, 60 for icosa,
    num_steps_around_continuous for cone, and 2*num_steps_around_continuous for
    cyl (each z step followed by its 180 degree flip about x).

  Raises:
    ValueError: If num_steps_around_continuous is less than 1.
  """
  if num_steps_around_continuous < 1:
    raise ValueError(
        'num_steps_around_continuous must be at least 1, got {}.'.format(
            num_steps_around_continuous))
  return dict(
      tet=_tetrahedron_symmetries(),
      cube=_cube_symmetries(),
      icosa=_icosahedron_symmetries(),
      cyl=_continuous_symmetries(num_steps_around_continuous, (0., np.pi)),
      cone=_continuous_symmetries(num_steps_around_continuous, (0.,)))
201

202

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.