google-research
200 строк · 7.3 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Loads the SYMSOL dataset."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23import numpy as np
24from scipy.spatial.transform import Rotation
25import tensorflow as tf
26import tensorflow_datasets as tfds
27
28
# Names of the eight shapes. The position of a name in this list is the
# integer id that load_symsol matches against each example's 'label_shape'.
SHAPE_NAMES = [
    'tet',      # 0
    'cube',     # 1
    'icosa',    # 2
    'cone',     # 3
    'cyl',      # 4
    'tetX',     # 5
    'cylO',     # 6
    'sphereX',  # 7
]
32
33
def load_symsol(shapes, mode='train', downsample_continuous_gt=0, mock=False):
  """Loads the symmetric_solids dataset.

  Args:
    shapes: The shapes to keep. Either a single string or an iterable of
      strings. Each string is a name from SHAPE_NAMES, or 'symsol1', which is
      shorthand for the first five entries of SHAPE_NAMES.
    mode: 'train' or 'test', determining the split of the dataset.
    downsample_continuous_gt: An integer, the amount to downsample the
      continuous symmetry ground truths, if any. The gt rotations for the cone
      and cyl have been discretized to 1 degree increments, but this can be
      overkill for evaluation during training. If 0, use the full annotation.
    mock: Make random data to avoid downloading it.

  Returns:
    tf.data.Dataset of (image, rotations) pairs. Images are converted to
    float32. For mode == 'train' the annotation is the 'rotation' feature; for
    any other mode it is the 'rotations_equivalent' feature.

  Raises:
    ValueError: If a requested shape is not in SHAPE_NAMES.
  """
  # A bare string would otherwise be iterated character by character.
  if isinstance(shapes, str):
    shapes = [shapes]

  shape_names = []
  for shape in shapes:
    if shape == 'symsol1':
      shape_names.extend(SHAPE_NAMES[:5])
    else:
      shape_names.append(shape)

  unknown = [shape for shape in shape_names if shape not in SHAPE_NAMES]
  if unknown:
    raise ValueError(
        'Unknown shape(s) {}; expected "symsol1" or names from {}.'.format(
            unknown, SHAPE_NAMES))
  shape_inds = [SHAPE_NAMES.index(shape) for shape in shape_names]

  if mock:
    with tfds.testing.mock_data(num_examples=100):
      dataset = tfds.load('symmetric_solids', split=mode)
  else:
    dataset = tfds.load('symmetric_solids', split=mode)

  # Keep only the examples whose shape label is one of the requested shapes.
  dataset = dataset.filter(
      lambda x: tf.reduce_any(tf.equal(x['label_shape'], shape_inds)))

  # The full set of equivalent rotations is only used outside of training.
  annotation_key = 'rotation' if mode == 'train' else 'rotations_equivalent'

  dataset = dataset.map(
      lambda example: (
          tf.image.convert_image_dtype(example['image'], tf.float32),
          example[annotation_key]),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # NOTE: only the first requested shape decides whether to downsample, and
  # the stride is then applied to every example in the dataset.
  continuous_inds = (SHAPE_NAMES.index('cone'), SHAPE_NAMES.index('cyl'))
  if (mode == 'test' and downsample_continuous_gt and shape_inds and
      shape_inds[0] in continuous_inds):
    # Downsample the full set of equivalent rotations for the cone and cyl.
    dataset = dataset.map(
        lambda im, rots: (im, rots[::downsample_continuous_gt]),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return dataset
80
81
def _with_cyclic_row_shifts(seeds):
  """Stacks each seed matrix followed by its two cyclic row shifts."""
  expanded = []
  for seed in seeds:
    expanded.extend(
        (seed, np.roll(seed, 1, axis=0), np.roll(seed, -1, axis=0)))
  return np.stack(expanded, 0)


def _tet_symmetries():
  """Returns the [12, 3, 3] rotation symmetries of the tetrahedron."""
  seeds = [np.eye(3)]
  seeds.extend(np.diag(np.roll([-1, -1, 1], i)) for i in range(3))
  syms = _with_cyclic_row_shifts(seeds)
  # The syms are specific to the object coordinate axes used during rendering,
  # and for the tet the canonical frames were 45 deg from corners of a cube.
  correction_rot = Rotation.from_euler(
      'xyz', np.float32([0, 0, np.pi / 4.0])).as_matrix()
  # So we rotate to the cube frame, where the syms computed above are valid,
  # and then rotate back.
  return correction_rot @ syms @ correction_rot.T


def _cube_symmetries():
  """Returns the [24, 3, 3] rotation symmetries of the cube."""
  swap_xz = np.float32([[0, 0, 1], [0, 1, 0], [1, 0, 0]])
  seeds = [np.eye(3), np.float32([[0, 0, -1], [0, -1, 0], [-1, 0, 0]])]
  for i in range(3):
    seeds.append(np.diag(np.roll([-1, -1, 1], i)))
    seeds.append(np.diag(np.roll([-1, 1, 1], i)) @ swap_xz)
  return _with_cyclic_row_shifts(seeds)


def _icosa_symmetries():
  """Returns the rotation symmetries of the icosahedron, identity first."""
  golden_ratio = (1 + np.sqrt(5)) / 2.
  a, b = np.float32([1, golden_ratio]) / np.sqrt(1 + golden_ratio**2)
  verts = np.float32([[-a, b, 0],
                      [a, b, 0],
                      [-a, -b, 0],
                      [a, -b, 0],
                      [0, -a, b],
                      [0, a, b],
                      [0, -a, -b],
                      [0, a, -b],
                      [b, 0, -a],
                      [b, 0, a],
                      [-b, 0, -a],
                      [-b, 0, a]])
  first_angles = np.arange(3) * 2 * np.pi / 5
  second_angles = np.arange(1, 3) * 2 * np.pi / 5

  # Generate candidates as products of rotations about pairs of vertex axes.
  candidates = [np.eye(3)]
  for ind1 in range(12):
    for ind2 in range(ind1 + 1, 12):
      vert1 = verts[ind1]
      vert2 = verts[ind2]
      # Skip vertex pairs that lie on the same axis.
      if np.abs(np.dot(vert1, vert2)) == 1:
        continue
      for angle1 in first_angles:
        rot1 = Rotation.from_rotvec(angle1 * vert1).as_matrix()
        for angle2 in second_angles:
          rot2 = Rotation.from_rotvec(angle2 * vert2).as_matrix()
          candidates.append(rot1 @ rot2)
  candidates = np.stack(candidates, 0)

  # Remove duplicates, keeping the first occurrence of each rotation.
  # trs[i, j] is trace(candidates[j] @ candidates[i].T), which reaches 3 only
  # when the two rotations coincide.
  trs = np.einsum('jab,iab->ij', candidates, candidates)
  eps = 1e-9
  keep_inds = []
  duplicate_inds = set()
  for i in range(candidates.shape[0]):
    if i in duplicate_inds:
      continue
    keep_inds.append(i)
    duplicate_inds.update(np.flatnonzero(trs[i] > (3 - eps)).tolist())
  return candidates[keep_inds]


def _axial_symmetries(num_steps, x_rotations):
  """Returns rotations about z, each combined with the given x rotations.

  The z angles are evenly spaced over [0, 2*pi), so the full turn is not
  repeated as a duplicate of the identity.

  Args:
    num_steps: The number of z angles.
    x_rotations: The x rotation angles applied alongside every z angle.

  Returns:
    A [num_steps * len(x_rotations), 3, 3] array of rotation matrices, ordered
    by z angle first and then by x rotation.
  """
  z_angles = np.linspace(0, 2 * np.pi, num_steps, endpoint=False)
  euler_angles = np.float32(
      [[x_rot, 0, z_rot] for z_rot in z_angles for x_rot in x_rotations])
  return Rotation.from_euler('xyz', euler_angles).as_matrix()


def compute_symsol_symmetries(num_steps_around_continuous=360):
  """Return the GT rotation matrices for the symmetric solids.

  We provide this primarily for the ability to generate the symmetry rotations
  for the cone and cylinder at arbitrary resolutions.

  The first matrix returned for each is the identity.

  Args:
    num_steps_around_continuous: The number of steps taken around each great
      circle of equivalent poses for the cylinder and cone. The steps are
      evenly spaced and do not repeat the starting pose. Must be at least 1.

  Returns:
    A dictionary, indexed by shape name, for the five solids of the SYMSOL
    dataset. The values in the dictionary are [N, 3, 3] rotation matrices,
    where N is 12 for tet, 24 for cube, 60 for icosa,
    num_steps_around_continuous for cone, and 2*num_steps_around_continuous for
    cyl.

  Raises:
    ValueError: If num_steps_around_continuous is less than 1.
  """
  if num_steps_around_continuous < 1:
    raise ValueError(
        'num_steps_around_continuous must be at least 1, got {}.'.format(
            num_steps_around_continuous))

  return dict(
      tet=_tet_symmetries(),
      cube=_cube_symmetries(),
      icosa=_icosa_symmetries(),
      # The cylinder is additionally symmetric under a half turn about x.
      cyl=_axial_symmetries(num_steps_around_continuous, (0., np.pi)),
      cone=_axial_symmetries(num_steps_around_continuous, (0.,)))
201
202