google-research
271 строка · 9.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Runs pose embedding model inference.
17
18Currently, we support loading model inputs from a CSV file. The CSV file is
19expected to have:
20
211. The first row as header, including the following values:
22
23image/width,
24image/height,
25image/object/part/NOSE_TIP/center/x,
26image/object/part/NOSE_TIP/center/y,
27image/object/part/NOSE_TIP/score,
28image/object/part/LEFT_SHOULDER/center/x,
29image/object/part/LEFT_SHOULDER/center/y,
30image/object/part/LEFT_SHOULDER/score,
31image/object/part/RIGHT_SHOULDER/center/x,
32image/object/part/RIGHT_SHOULDER/center/y,
33image/object/part/RIGHT_SHOULDER/score,
34image/object/part/LEFT_ELBOW/center/x,
35image/object/part/LEFT_ELBOW/center/y,
36image/object/part/LEFT_ELBOW/score,
37image/object/part/RIGHT_ELBOW/center/x,
38image/object/part/RIGHT_ELBOW/center/y,
39image/object/part/RIGHT_ELBOW/score,
40image/object/part/LEFT_WRIST/center/x,
41image/object/part/LEFT_WRIST/center/y,
42image/object/part/LEFT_WRIST/score,
43image/object/part/RIGHT_WRIST/center/x,
44image/object/part/RIGHT_WRIST/center/y,
45image/object/part/RIGHT_WRIST/score,
46image/object/part/LEFT_HIP/center/x,
47image/object/part/LEFT_HIP/center/y,
48image/object/part/LEFT_HIP/score,
49image/object/part/RIGHT_HIP/center/x,
50image/object/part/RIGHT_HIP/center/y,
51image/object/part/RIGHT_HIP/score,
52image/object/part/LEFT_KNEE/center/x,
53image/object/part/LEFT_KNEE/center/y,
54image/object/part/LEFT_KNEE/score,
55image/object/part/RIGHT_KNEE/center/x,
56image/object/part/RIGHT_KNEE/center/y,
57image/object/part/RIGHT_KNEE/score,
58image/object/part/LEFT_ANKLE/center/x,
59image/object/part/LEFT_ANKLE/center/y,
60image/object/part/LEFT_ANKLE/score,
61image/object/part/RIGHT_ANKLE/center/x,
62image/object/part/RIGHT_ANKLE/center/y,
63image/object/part/RIGHT_ANKLE/score
64
652. The following rows are CSVs according to the header, one sample per row.
66
67Note: The input 2D keypoint coordinate values are required to be normalized by
68image sizes to within [0, 1].
69
70The outputs will be written to `output_dir` in the format of CSV, with file
71base names being the corresponding tensor keys, such as
72`unnormalized_embeddings.csv`, `embedding_stddevs.csv`, etc.
73
74In an output CSV file, each row corresponds to an input sample (the same row in
75the input CSV file).
76
77"""
78
79import os
80
81from absl import app
82from absl import flags
83import numpy as np
84import pandas as pd
85import tensorflow.compat.v1 as tf
86
87from poem.core import common
88from poem.core import input_generator
89from poem.core import keypoint_profiles
90from poem.core import keypoint_utils
91from poem.core import models
92from poem.core import pipeline_utils
# This script builds a TF1-style graph and session, so TF2 behavior is disabled.
tf.disable_v2_behavior()

FLAGS = flags.FLAGS

# Expose the key flags declared by `poem.core.common` as key flags of this
# module too.
flags.adopt_module_key_flags(common)

# --- Input / output locations. ---
flags.DEFINE_string(
    name='input_csv', default=None, help='Path to input CSV file.')

flags.DEFINE_string(
    name='output_dir', default=None, help='Path to output directory.')

# --- Input keypoint handling. ---
flags.DEFINE_string(
    name='input_keypoint_profile_name_2d',
    default='LEGACY_2DCOCO13',
    help=('Profile name for 2D keypoints from input sources. Use None to '
          'ignore input 2D keypoints.'))

flags.DEFINE_string(
    name='model_input_keypoint_mask_type',
    default='NO_USE',
    help='Usage type of model input keypoint masks.')

flags.DEFINE_float(
    name='min_input_keypoint_score_2d',
    default=-1.0,
    help=('Minimum threshold for input keypoint score binarization. Use '
          'negative value to ignore. Only used if 2D keypoint masks are '
          'used.'))

# --- Embedding configuration. ---
# See `common.SUPPORTED_EMBEDDING_TYPES`.
flags.DEFINE_string(
    name='embedding_type', default='GAUSSIAN', help='Type of embeddings.')

flags.DEFINE_integer(
    name='embedding_size', default=16, help='Size of predicted embeddings.')

flags.DEFINE_integer(
    name='num_embedding_components',
    default=1,
    help=('Number of embedding components, e.g., the number of Gaussians in '
          'mixture.'))

flags.DEFINE_integer(
    name='num_embedding_samples',
    default=20,
    help='Number of samples from embedding distributions.')

# --- Model architecture. ---
# See `common.SUPPORTED_BASE_MODEL_TYPES`.
flags.DEFINE_string(
    name='base_model_type', default='SIMPLE', help='Type of base model.')

flags.DEFINE_integer(
    name='num_fc_blocks', default=2, help='Number of fully connected blocks.')

flags.DEFINE_integer(
    name='num_fcs_per_block',
    default=2,
    help='Number of fully connected layers per block.')

flags.DEFINE_integer(
    name='num_hidden_nodes',
    default=1024,
    help='Number of nodes in each hidden fully connected layer.')

flags.DEFINE_integer(
    name='num_bottleneck_nodes',
    default=0,
    help=('Number of nodes in the bottleneck layer before the output '
          'layer(s). Ignored if non-positive.'))

flags.DEFINE_float(
    name='weight_max_norm',
    default=0.0,
    help=('Maximum norm of fully connected layer weights. Only used if '
          'positive.'))

# --- Checkpoint restoring and runtime. ---
flags.DEFINE_string(
    name='checkpoint_path',
    default=None,
    help='Path to checkpoint to initialize from.')

flags.DEFINE_bool(
    name='use_moving_average',
    default=True,
    help='Whether to use exponential moving average.')

flags.DEFINE_string(
    name='master', default='', help='BNS name of the TensorFlow master to use.')

# Flags without a usable default must be supplied on the command line.
flags.mark_flags_as_required(['input_csv', 'output_dir', 'checkpoint_path'])
158
159
def read_inputs(keypoint_profile_2d):
  """Loads 2D keypoints and keypoint masks from the CSV at `FLAGS.input_csv`.

  Args:
    keypoint_profile_2d: Keypoint profile object for the input 2D keypoints.
      Its `keypoint_names`, `keypoint_num` and `keypoint_dim` attributes are
      read here.

  Returns:
    A tuple `(keypoints_2d, keypoint_masks_2d)`:
      keypoints_2d: Tensor of keypoints reshaped to
        [-1, keypoint_num, keypoint_dim] and passed through
        `keypoint_utils.denormalize_points_by_image_size`.
      keypoint_masks_2d: A float32 tensor shaped like the keypoint score
        columns. All ones if `FLAGS.min_input_keypoint_score_2d` is negative;
        otherwise 1.0 where the score is greater than or equal to that
        threshold and 0.0 elsewhere.
  """
  prefix = common.TFE_KEY_PREFIX_KEYPOINT_2D
  names = list(keypoint_profile_2d.keypoint_names)

  # Coordinate columns are interleaved per keypoint (first suffix, then second
  # suffix) so that the flat values can later be reshaped per keypoint.
  coordinate_columns = [
      prefix + name + common.TFE_KEY_SUFFIX_KEYPOINT_2D[axis]
      for name in names
      for axis in (0, 1)
  ]
  score_columns = [
      prefix + name + common.TFE_KEY_SUFFIX_KEYPOINT_SCORE for name in names
  ]
  size_columns = [common.TFE_KEY_IMAGE_HEIGHT, common.TFE_KEY_IMAGE_WIDTH]

  # Only the columns that are actually consumed are parsed from the CSV.
  with tf.gfile.GFile(FLAGS.input_csv, 'r') as csv_file:
    table = pd.read_csv(
        csv_file, usecols=size_columns + coordinate_columns + score_columns)

  image_sizes = tf.constant(table[size_columns].to_numpy(dtype=np.float32))
  flat_keypoints = tf.constant(
      table[coordinate_columns].to_numpy(dtype=np.float32))
  scores = tf.constant(table[score_columns].to_numpy(dtype=np.float32))

  keypoints = tf.reshape(
      flat_keypoints,
      [-1, keypoint_profile_2d.keypoint_num, keypoint_profile_2d.keypoint_dim])
  keypoints = keypoint_utils.denormalize_points_by_image_size(
      keypoints, image_sizes=image_sizes)

  threshold = FLAGS.min_input_keypoint_score_2d
  if threshold < 0.0:
    # A negative threshold disables score-based masking: keep every keypoint.
    return keypoints, tf.ones_like(scores, dtype=tf.float32)

  # Binarize the scores against the threshold.
  masks = tf.cast(tf.math.greater_equal(scores, threshold), dtype=tf.float32)
  return keypoints, masks
202
203
def main(_):
  """Runs inference and writes the resulting embeddings as CSV files.

  Builds the inference graph from the CSV inputs, restores the weights from
  `FLAGS.checkpoint_path` (using the moving-average variables if
  `FLAGS.use_moving_average` is set), evaluates the embedder outputs once, and
  writes each available output under `FLAGS.output_dir` as `<key>.csv`, with
  one row per input sample.

  Args:
    _: Unused positional command-line arguments passed by `app.run`.
  """
  keypoint_profile_2d = (
      keypoint_profiles.create_keypoint_profile_or_die(
          FLAGS.input_keypoint_profile_name_2d))

  g = tf.Graph()
  with g.as_default():
    keypoints_2d, keypoint_masks_2d = read_inputs(keypoint_profile_2d)

    model_inputs, _ = input_generator.create_model_input(
        keypoints_2d,
        keypoint_masks_2d=keypoint_masks_2d,
        keypoints_3d=None,
        model_input_keypoint_type=common.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT,
        model_input_keypoint_mask_type=FLAGS.model_input_keypoint_mask_type,
        keypoint_profile_2d=keypoint_profile_2d,
        # Fix seed for determinism.
        seed=1)

    embedder_fn = models.get_embedder(
        base_model_type=FLAGS.base_model_type,
        embedding_type=FLAGS.embedding_type,
        num_embedding_components=FLAGS.num_embedding_components,
        embedding_size=FLAGS.embedding_size,
        num_embedding_samples=FLAGS.num_embedding_samples,
        is_training=False,
        num_fc_blocks=FLAGS.num_fc_blocks,
        num_fcs_per_block=FLAGS.num_fcs_per_block,
        num_hidden_nodes=FLAGS.num_hidden_nodes,
        num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
        weight_max_norm=FLAGS.weight_max_norm)

    outputs, _ = embedder_fn(model_inputs)

    # Choose which variables the saver restores from the checkpoint.
    if FLAGS.use_moving_average:
      variables_to_restore = (
          pipeline_utils.get_moving_average_variables_to_restore())
      saver = tf.train.Saver(variables_to_restore)
    else:
      saver = tf.train.Saver()

    scaffold = tf.train.Scaffold(
        init_op=tf.global_variables_initializer(), saver=saver)
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold,
        master=FLAGS.master,
        checkpoint_filename_with_path=FLAGS.checkpoint_path)

    # All inputs are graph constants, so a single run evaluates every sample.
    with tf.train.MonitoredSession(
        session_creator=session_creator, hooks=None) as sess:
      outputs_result = sess.run(outputs)

  tf.gfile.MakeDirs(FLAGS.output_dir)
  for key in [
      common.KEY_EMBEDDING_MEANS, common.KEY_EMBEDDING_STDDEVS,
      common.KEY_EMBEDDING_SAMPLES
  ]:
    # Not every key is necessarily present in the outputs; skip missing ones.
    if key in outputs_result:
      output = outputs_result[key]
      output_path = os.path.join(FLAGS.output_dir, key + '.csv')
      # Write through `tf.gfile`, like the input read and `MakeDirs` above, so
      # that the output location is handled by the same file I/O layer rather
      # than by a plain local-filesystem open inside `np.savetxt`.
      with tf.gfile.GFile(output_path, 'w') as f:
        # Flatten everything but the leading (sample) axis: one row per sample.
        np.savetxt(f, output.reshape([output.shape[0], -1]), delimiter=',')


if __name__ == '__main__':
  app.run(main)
272