google-research

Форк
0
271 строка · 9.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Runs pose embedding model inference.
17

18
Currently, we support loading model inputs from a CSV file. The CSV file is
19
expected to have:
20

21
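1. A header row that includes the following columns (listed here for the
default `LEGACY_2DCOCO13` 2D keypoint profile; other profiles use their own
keypoint names):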
22

23
image/width,
24
image/height,
25
image/object/part/NOSE_TIP/center/x,
26
image/object/part/NOSE_TIP/center/y,
27
image/object/part/NOSE_TIP/score,
28
image/object/part/LEFT_SHOULDER/center/x,
29
image/object/part/LEFT_SHOULDER/center/y,
30
image/object/part/LEFT_SHOULDER/score,
31
image/object/part/RIGHT_SHOULDER/center/x,
32
image/object/part/RIGHT_SHOULDER/center/y,
33
image/object/part/RIGHT_SHOULDER/score,
34
image/object/part/LEFT_ELBOW/center/x,
35
image/object/part/LEFT_ELBOW/center/y,
36
image/object/part/LEFT_ELBOW/score,
37
image/object/part/RIGHT_ELBOW/center/x,
38
image/object/part/RIGHT_ELBOW/center/y,
39
image/object/part/RIGHT_ELBOW/score,
40
image/object/part/LEFT_WRIST/center/x,
41
image/object/part/LEFT_WRIST/center/y,
42
image/object/part/LEFT_WRIST/score,
43
image/object/part/RIGHT_WRIST/center/x,
44
image/object/part/RIGHT_WRIST/center/y,
45
image/object/part/RIGHT_WRIST/score,
46
image/object/part/LEFT_HIP/center/x,
47
image/object/part/LEFT_HIP/center/y,
48
image/object/part/LEFT_HIP/score,
49
image/object/part/RIGHT_HIP/center/x,
50
image/object/part/RIGHT_HIP/center/y,
51
image/object/part/RIGHT_HIP/score,
52
image/object/part/LEFT_KNEE/center/x,
53
image/object/part/LEFT_KNEE/center/y,
54
image/object/part/LEFT_KNEE/score,
55
image/object/part/RIGHT_KNEE/center/x,
56
image/object/part/RIGHT_KNEE/center/y,
57
image/object/part/RIGHT_KNEE/score,
58
image/object/part/LEFT_ANKLE/center/x,
59
image/object/part/LEFT_ANKLE/center/y,
60
image/object/part/LEFT_ANKLE/score,
61
image/object/part/RIGHT_ANKLE/center/x,
62
image/object/part/RIGHT_ANKLE/center/y,
63
image/object/part/RIGHT_ANKLE/score
64

65
2. One data row per sample, with comma-separated values matching the header
columns.
66

67
Note: The input 2D keypoint coordinates must be normalized by the image size to
lie within [0, 1]. They are denormalized internally using the `image/height`
and `image/width` columns.
69

70
The outputs will be written to `output_dir` in the format of CSV, with file
71
base names being the corresponding tensor keys, such as
72
`unnormalized_embeddings.csv`, `embedding_stddevs.csv`, etc.
73

74
In an output CSV file, each row corresponds to an input sample (the same row in
75
the input CSV file).
76

77
"""
78

79
import os
80

81
from absl import app
82
from absl import flags
83
import numpy as np
84
import pandas as pd
85
import tensorflow.compat.v1 as tf
86

87
from poem.core import common
88
from poem.core import input_generator
89
from poem.core import keypoint_profiles
90
from poem.core import keypoint_utils
91
from poem.core import models
92
from poem.core import pipeline_utils
93
# The inference graph below is built with TF1-style graph/session APIs.
tf.disable_v2_behavior()

FLAGS = flags.FLAGS

# Expose the flags defined in `common` as key flags of this script.
flags.adopt_module_key_flags(common)

flags.DEFINE_string('input_csv', None, 'Path to input CSV file.')
flags.mark_flag_as_required('input_csv')

flags.DEFINE_string('output_dir', None, 'Path to output directory.')
flags.mark_flag_as_required('output_dir')

# This script always builds a 2D keypoint profile from this name and reads the
# CSV columns it names, so input 2D keypoints cannot be ignored here.
flags.DEFINE_string(
    'input_keypoint_profile_name_2d', 'LEGACY_2DCOCO13',
    'Profile name for 2D keypoints from input sources.')

flags.DEFINE_string('model_input_keypoint_mask_type', 'NO_USE',
                    'Usage type of model input keypoint masks.')

flags.DEFINE_float(
    'min_input_keypoint_score_2d', -1.0,
    'Minimum threshold for input keypoint score binarization. Use negative '
    'value to ignore. Only used if 2D keypoint masks are used.')

# See `common.SUPPORTED_EMBEDDING_TYPES`.
flags.DEFINE_string('embedding_type', 'GAUSSIAN', 'Type of embeddings.')

flags.DEFINE_integer('embedding_size', 16, 'Size of predicted embeddings.')

flags.DEFINE_integer(
    'num_embedding_components', 1,
    'Number of embedding components, e.g., the number of Gaussians in mixture.')

flags.DEFINE_integer('num_embedding_samples', 20,
                     'Number of samples from embedding distributions.')

# See `common.SUPPORTED_BASE_MODEL_TYPES`.
flags.DEFINE_string('base_model_type', 'SIMPLE', 'Type of base model.')

flags.DEFINE_integer('num_fc_blocks', 2, 'Number of fully connected blocks.')

flags.DEFINE_integer('num_fcs_per_block', 2,
                     'Number of fully connected layers per block.')

flags.DEFINE_integer('num_hidden_nodes', 1024,
                     'Number of nodes in each hidden fully connected layer.')

flags.DEFINE_integer(
    'num_bottleneck_nodes', 0,
    'Number of nodes in the bottleneck layer before the output layer(s). '
    'Ignored if non-positive.')

flags.DEFINE_float(
    'weight_max_norm', 0.0,
    'Maximum norm of fully connected layer weights. Only used if positive.')

# The model architecture flags above must match the ones used to produce this
# checkpoint, otherwise restoring variables will fail.
flags.DEFINE_string('checkpoint_path', None,
                    'Path to checkpoint to initialize from.')
flags.mark_flag_as_required('checkpoint_path')

flags.DEFINE_bool('use_moving_average', True,
                  'Whether to use exponential moving average.')

flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
158

159

160
def read_inputs(keypoint_profile_2d,
                input_csv=None,
                min_input_keypoint_score_2d=None):
  """Reads model inputs from a CSV file.

  Args:
    keypoint_profile_2d: A 2D keypoint profile. Its `keypoint_names` select
      which CSV columns are read, and its `keypoint_num` and `keypoint_dim`
      determine the shape of the returned keypoint tensor.
    input_csv: Path to the input CSV file. Defaults to `FLAGS.input_csv`.
    min_input_keypoint_score_2d: Threshold used to binarize keypoint scores
      into masks. A negative value disables thresholding, so all masks are one.
      Defaults to `FLAGS.min_input_keypoint_score_2d`.

  Returns:
    keypoints_2d: A float32 tensor of keypoint coordinates, reshaped to
      [num_samples, keypoint_num, keypoint_dim] and denormalized by the
      per-sample image sizes.
    keypoint_masks_2d: A float32 tensor of keypoint masks with one column per
      keypoint name, i.e. [num_samples, len(keypoint_names)].
  """
  if input_csv is None:
    input_csv = FLAGS.input_csv
  if min_input_keypoint_score_2d is None:
    min_input_keypoint_score_2d = FLAGS.min_input_keypoint_score_2d

  keypoints_2d_col_names, keypoint_scores_2d_col_names = [], []
  for keypoint_name in keypoint_profile_2d.keypoint_names:
    column_prefix = common.TFE_KEY_PREFIX_KEYPOINT_2D + keypoint_name
    # Coordinates are stored as two columns per keypoint (first and second
    # suffix), interleaved per keypoint so the reshape below groups them.
    keypoints_2d_col_names.append(
        column_prefix + common.TFE_KEY_SUFFIX_KEYPOINT_2D[0])
    keypoints_2d_col_names.append(
        column_prefix + common.TFE_KEY_SUFFIX_KEYPOINT_2D[1])
    keypoint_scores_2d_col_names.append(
        column_prefix + common.TFE_KEY_SUFFIX_KEYPOINT_SCORE)

  # Image sizes are ordered (height, width); presumably the order expected by
  # `keypoint_utils.denormalize_points_by_image_size` -- TODO confirm.
  image_size_col_names = [
      common.TFE_KEY_IMAGE_HEIGHT, common.TFE_KEY_IMAGE_WIDTH
  ]
  with tf.gfile.GFile(input_csv, 'r') as f:
    data = pd.read_csv(
        f,
        usecols=(image_size_col_names + keypoints_2d_col_names +
                 keypoint_scores_2d_col_names))

  # Only the parsed DataFrame is needed from here on, so the file is already
  # closed while the graph constants are built.
  image_sizes = tf.constant(
      data[image_size_col_names].to_numpy(dtype=np.float32))
  keypoints_2d = tf.constant(
      data[keypoints_2d_col_names].to_numpy(dtype=np.float32))
  keypoint_scores_2d = tf.constant(
      data[keypoint_scores_2d_col_names].to_numpy(dtype=np.float32))

  keypoints_2d = tf.reshape(
      keypoints_2d,
      [-1, keypoint_profile_2d.keypoint_num, keypoint_profile_2d.keypoint_dim])
  # Input coordinates are normalized to [0, 1]; convert them back to pixel
  # space using each sample's image size.
  keypoints_2d = keypoint_utils.denormalize_points_by_image_size(
      keypoints_2d, image_sizes=image_sizes)

  if min_input_keypoint_score_2d < 0.0:
    keypoint_masks_2d = tf.ones_like(keypoint_scores_2d, dtype=tf.float32)
  else:
    keypoint_masks_2d = tf.cast(
        tf.math.greater_equal(keypoint_scores_2d, min_input_keypoint_score_2d),
        dtype=tf.float32)

  return keypoints_2d, keypoint_masks_2d
202

203

204
def _write_outputs(outputs_result, output_dir):
  """Writes the selected embedding outputs to `output_dir` as CSV files.

  Each output is written to `<output_dir>/<key>.csv`. All dimensions after
  the first are flattened, so each row corresponds to one input sample.
  Files are written through `tf.gfile`, like the directory creation and the
  input reading, so any filesystem supported by gfile works for `output_dir`.

  Args:
    outputs_result: A dictionary of evaluated model outputs (NumPy arrays)
      keyed by output tensor key.
    output_dir: Path to the output directory. Created if it does not exist.
  """
  tf.gfile.MakeDirs(output_dir)
  for key in (common.KEY_EMBEDDING_MEANS, common.KEY_EMBEDDING_STDDEVS,
              common.KEY_EMBEDDING_SAMPLES):
    # Presumably not every embedding type produces every key; absent keys are
    # skipped.
    if key not in outputs_result:
      continue
    output = outputs_result[key]
    with tf.gfile.GFile(os.path.join(output_dir, key + '.csv'), 'w') as f:
      np.savetxt(f, output.reshape([output.shape[0], -1]), delimiter=',')


def main(_):
  """Runs embedding inference on `FLAGS.input_csv`.

  The function builds the model graph, restores weights from
  `FLAGS.checkpoint_path` and evaluates the embedding outputs. The results are
  written as CSV files to `FLAGS.output_dir`.
  """
  keypoint_profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
      FLAGS.input_keypoint_profile_name_2d)

  g = tf.Graph()
  with g.as_default():
    keypoints_2d, keypoint_masks_2d = read_inputs(keypoint_profile_2d)

    model_inputs, _ = input_generator.create_model_input(
        keypoints_2d,
        keypoint_masks_2d=keypoint_masks_2d,
        keypoints_3d=None,
        model_input_keypoint_type=common.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT,
        model_input_keypoint_mask_type=FLAGS.model_input_keypoint_mask_type,
        keypoint_profile_2d=keypoint_profile_2d,
        # Fix seed for determinism.
        seed=1)

    embedder_fn = models.get_embedder(
        base_model_type=FLAGS.base_model_type,
        embedding_type=FLAGS.embedding_type,
        num_embedding_components=FLAGS.num_embedding_components,
        embedding_size=FLAGS.embedding_size,
        num_embedding_samples=FLAGS.num_embedding_samples,
        is_training=False,
        num_fc_blocks=FLAGS.num_fc_blocks,
        num_fcs_per_block=FLAGS.num_fcs_per_block,
        num_hidden_nodes=FLAGS.num_hidden_nodes,
        num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
        weight_max_norm=FLAGS.weight_max_norm)

    outputs, _ = embedder_fn(model_inputs)

    if FLAGS.use_moving_average:
      # Restore the variable set provided by `pipeline_utils` for moving
      # averages (presumably EMA shadow values mapped onto model variables).
      variables_to_restore = (
          pipeline_utils.get_moving_average_variables_to_restore())
      saver = tf.train.Saver(variables_to_restore)
    else:
      saver = tf.train.Saver()

    scaffold = tf.train.Scaffold(
        init_op=tf.global_variables_initializer(), saver=saver)
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold,
        master=FLAGS.master,
        checkpoint_filename_with_path=FLAGS.checkpoint_path)

    with tf.train.MonitoredSession(
        session_creator=session_creator, hooks=None) as sess:
      outputs_result = sess.run(outputs)

  _write_outputs(outputs_result, FLAGS.output_dir)


if __name__ == '__main__':
  app.run(main)
272

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.