google-research

Форк
0
367 строк · 12.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Loss and metric definitions for KeyPose models."""
17

18
import numpy as np
19
import tensorflow as tf
20

21
from keypose import nets
22

23
# num_targs = inp.MAX_TARGET_FRAMES
24
num_targs = 5
25

26
# Reordering based on symmetry.
27

28

29
def make_order(sym, num_kp):
  """Returns all rotations of the <sym> keypoints.

  Args:
    sym: list of keypoint indices forming a symmetry cycle.
    num_kp: total number of keypoints.

  Returns:
    np.ndarray of shape [len(sym), num_kp]; row k is the keypoint
    permutation corresponding to rotating the cycle by k.
  """
  num_rots = len(sym)
  # All cyclic rotations of the symmetry cycle.
  cycles = [sym[-k:] + sym[:-k] for k in range(num_rots)]
  base = np.asarray(sym)
  # Start from identity permutations, then write the cycle values back
  # into each rotated set of positions.
  perms = np.tile(np.arange(num_kp), (num_rots, 1))
  for perm, cycle in zip(perms, cycles):
    perm[cycle] = base
  return perms
36

37

38
def reorder(tensor, order):
  """Re-orders a tensor along the num_kp dimension.

  Args:
    tensor: has shape [batch, num_kp, ...]
    order: permutations of keypoints, shape [num_order, num_kp]
      (as produced by make_order).

  Returns:
    shape [batch, num_order, num_kp, ...]
  """
  # A single gather on axis 1 with 2-D indices yields
  # output[b, o, k, ...] = tensor[b, order[o][k], ...], which is exactly
  # the nested python-level tf.stack construction it replaces, without
  # shadowing the builtin `ord` and with far fewer graph ops.
  return tf.gather(tensor, order, axis=1)
52

53

54
def reduce_order(tensor, mask, mean=False):
  """Masks out non-selected orderings and sums away the order axis.

  Args:
    tensor: [batch, order, ...] values.
    mask: broadcastable weight tensor, typically one-hot over orderings.
    mean: if True, additionally reduce the result to a scalar mean.

  Returns:
    [batch, ...] masked sum, or its scalar mean when mean=True.
  """
  masked = tf.reduce_sum(tensor * mask, axis=1)
  return tf.reduce_mean(masked) if mean else masked
61

62

63
# Loss functions.
64

65

66
def project(tmat, tvec, tvec_transpose=False):
  """Projects homogeneous 3D XYZ coordinates to image uvd coordinates, or vv.

  Args:
    tmat: has shape [[N,] batch_size, 4, 4].
    tvec: has shape [[N,] batch_size, 4, num_kp] or [batch_size, num_kp, 4].
    tvec_transpose: True if tvec is to be transposed before application.

  Returns:
    Has shape [[N,] batch_size, 4, num_kp].
  """
  projected = tf.matmul(tmat, tvec, transpose_b=tvec_transpose)
  # Normalize by the homogeneous coordinate; the 3:4 slice (rather than 3)
  # keeps the rank so broadcasting works, and the epsilon avoids div-by-zero.
  return projected / (projected[..., 3:4, :] + 1.0e-10)
81

82

83
def keypoint_loss_targets(uvd, keys_uvd, mparams):
  """Computes the supervised keypoint loss between computed and gt keypoints.

  Args:
    uvd: [batch, order, num_targs, 4, num_kp] Predicted set of keypoint uv's
      (pixels).
    keys_uvd: [batch, order, num_targs, 4, num_kp] The ground-truth set of uvdw
      coords.
    mparams: model parameters.

  Returns:
    Keypoint projection loss of size [batch, order].
  """
  print('uvd shape in klt [batch, order, num_targs, 4, num_kp]:', uvd.shape)
  print('keys_uvd shape in klt [batch, order, num_targs, 4, num_kp]:',
        keys_uvd.shape)
  # Compare in normalized image coordinates.
  gt_norm = nets.to_norm(keys_uvd, mparams)
  pred_norm = nets.to_norm(uvd, mparams)

  # Squared uv error, summed over keypoints and the uv channels.
  sq_err = tf.square(pred_norm[..., :2, :] - gt_norm[..., :2, :])
  wd = tf.reduce_sum(sq_err, axis=[-1, -2])  # uv dist [batch, order, num_targs]
  print('wd shape in klt [batch, order, num_targs]:', wd.shape)
  # Average over the target frames.
  return tf.reduce_mean(wd, axis=[-1])  # [batch, order]
107

108

109
def keypose_loss_proj(uvdw_pos, labels, mparams, num_order):
  """Compute the reprojection error on the target frames.

  Lifts the predicted keypoints to world coordinates, reprojects them into
  each target frame, and compares against the ground-truth target keypoints.

  Args:
    uvdw_pos: predicted uvd, always positive; [batch, order, num_kp, 4].
    labels: sample labels (to_world_L, targets_to_uvd_L, targets_keys_uvd_L).
    mparams: model parameters.
    num_order: number of order permutations.

  Returns:
    Reprojection loss of shape [batch, order].
  """
  num_kp = mparams.num_kp
  to_world = labels['to_world_L']  # [batch, 4, 4]
  # Replicate the camera-to-world transform across symmetry orderings.
  to_world_order = tf.stack(
      [to_world] * num_order, axis=1)  # [batch, order, 4, 4]
  to_world_order = tf.ensure_shape(
      to_world_order, [None, num_order, 4, 4], name='to_world_order')
  world_coords = project(to_world_order, uvdw_pos,
                         True)  # [batch, order, 4, num_kp]
  world_coords = tf.ensure_shape(
      world_coords, [None, num_order, 4, num_kp], name='world_coords')
  print('world_coords shape [batch, order, 4, num_kp]:', world_coords.shape)

  # Target transform and keypoints.
  # [batch, num_targs, 4, 4] for transforms
  # [batch, num_targs, 4, num_kp] for keypoints (after transpose)
  targets_to_uvd = labels['targets_to_uvd_L']
  targets_keys_uvd = tf.transpose(labels['targets_keys_uvd_L'], [0, 1, 3, 2])
  targets_keys_uvd_order = tf.stack([targets_keys_uvd] * num_order, axis=1)
  print('Model fn targets_to_uvd shape [batch, num_targs, 4, 4]:',
        targets_to_uvd.shape)
  print(
      'Model fn targets_keys_uvd_order shape [batch, order, num_targs, 4, '
      'num_kp]:', targets_keys_uvd_order.shape)

  # Reproject world coords into every target frame, for every ordering.
  # [batch, order, num_targs, 4, num_kp]
  proj_uvds = project(
      tf.stack([targets_to_uvd] * num_order, axis=1),
      tf.stack([world_coords] * num_targs, axis=2))
  # Use the module constant num_targs rather than a hard-coded 5, so the
  # shape check stays consistent with the stack above.
  proj_uvds = tf.ensure_shape(
      proj_uvds, [None, num_order, num_targs, 4, num_kp], name='proj_uvds')
  print('proj_uvds shape [batch, order, num_targs, 4, num_kp]:',
        proj_uvds.shape)
  loss_proj = keypoint_loss_targets(proj_uvds, targets_keys_uvd_order, mparams)
  loss_proj = tf.ensure_shape(loss_proj, [None, num_order], name='loss_proj')
  print('loss_proj shape [batch, order]:', loss_proj.shape)
  return loss_proj
158

159

160
def keypose_loss_kp(uvdw_order, keys_uvd_order, mparams):
  """Direct keypoint loss, comparing u,v,d in normalized image coordinates.

  For some reason normalized coordinates seem to work better than pixels.

  Args:
    uvdw_order: [batch, order, num_kp, 4] predicted keypoints.
    keys_uvd_order: [batch, order, num_kp, 4] ground-truth keypoints.
    mparams: model parameters.

  Returns:
    Squared-error loss of shape [batch, order].
  """
  pred_norm = nets.to_norm_vec(uvdw_order, mparams)
  gt_norm = nets.to_norm_vec(keys_uvd_order, mparams)
  # Squared u,v,d error, summed over keypoints and channels.
  diff = pred_norm[..., :3] - gt_norm[..., :3]
  return tf.reduce_sum(tf.square(diff), axis=[-1, -2])  # [batch, order]
173

174

175
# Probability coherence loss.
176
def keypose_loss_prob(prob_order, prob_label_order):
  """Probability coherence loss between heatmaps and label heatmaps.

  Args:
    prob_order: [batch, order, num_kp, resy, resx] predicted probabilities.
    prob_label_order: [batch, order, num_kp, resy, resx] label maps.

  Returns:
    Loss of shape [batch, order].
  """
  # Inner product over the spatial dims, then mean over keypoints.
  per_kp = tf.reduce_sum(prob_order * prob_label_order, axis=[-1, -2])
  return tf.reduce_mean(per_kp, axis=-1)
180

181

182
# Adjust the gain of loss_proj; return in [0,1].
183
def adjust_proj_factor(step, loss_step, minf=0.0):
  """Adjusts the gain of loss_proj; ramps linearly to 1.0 over loss_step.

  Args:
    step: current training step.
    loss_step: (start, end) pair; gain ramps from 0 at start to 1 at end.
      An end of 0 disables the ramp (gain is always 1.0).
    minf: lower bound on the returned gain.

  Returns:
    Gain in [minf, 1.0] (python 1.0 when the ramp is disabled).
  """
  start, end = loss_step[0], loss_step[1]
  if end == 0:
    return 1.0
  # NOTE(review): assumes end != start when end != 0, else divide-by-zero
  # — confirm against mparams.loss_proj_step settings.
  fraction = (tf.cast(step, tf.float32) - start) / tf.cast(
      end - start, tf.float32)
  return tf.maximum(minf, tf.minimum(fraction, 1.0))
191

192

193
# Custom loss function for the Keras model in Estimator
194
# Args are:
195
#  Dict of tensors for labels, with keys_uvd, offsets.
196
#  Tensor of raw uvd values for preds, [batch, num_kp, 3],
197
#    order is [u,v,d].
198
# Note that the loss is batched.
199
# This loss only works with tf Estimator, not keras models.
200
@tf.function
def keypose_loss(labels, preds, step, mparams, do_print=True):
  """Custom loss function for the Keras model in Estimator.

  Evaluates the combined keypoint / reprojection / probability loss under
  every symmetry-equivalent keypoint ordering, and keeps the best (minimum)
  ordering per sample.

  Note that the loss is batched.
  This loss only works with tf Estimator, not keras models.

  Args:
    labels: dict of tensors for labels, with keys_uvd, offsets.
    preds: tensor of raw uvd values for preds, [batch, num_kp, 3], order is
      [u,v,d].
    step: training step.
    mparams: model training parameters.
    do_print: True to print loss values at every step.

  Returns:
    Tuple (loss, uvdw, xyzw): scalar loss, plus the predicted uvdw and xyzw
    tensors re-ordered according to the best-scoring symmetry permutation.
  """
  num_kp = mparams.num_kp
  sym = mparams.sym
  # Each row of `order` is one symmetry-equivalent keypoint permutation.
  order = make_order(sym, num_kp)
  num_order = len(order)

  uvdw = preds['uvdw']
  uvdw_pos = preds['uvdw_pos']
  uv_pix_raw = preds['uv_pix_raw']
  prob = preds['prob']  # [batch, num_kp, resy, resx]
  xyzw = tf.transpose(preds['xyzw'], [0, 2, 1])  # [batch, num_kp, 4]

  # Expand each prediction with an order axis, one slice per permutation.
  uvdw_order = reorder(uvdw, order)  # [batch, order, num_kp, 4]
  print('uvdw_order shape:', uvdw_order.shape)
  uvdw_pos_order = reorder(uvdw_pos, order)  # [batch, order, num_kp, 4]
  xyzw_order = reorder(xyzw, order)  # [batch, order, 4, num_kp]
  print('xyzw_order shape:', xyzw_order.shape)
  prob_order = reorder(prob, order)  # [ batch, order, num_kp, resy, resx]
  print('prob_order shape:', prob_order.shape)

  keys_uvd = labels['keys_uvd_L']  # [batch, num_kp, 4]
  # Ground truth is identical for every ordering; just tile it.
  # [batch, order, num_kp, 4]
  keys_uvd_order = tf.stack([keys_uvd] * num_order, axis=1)

  loss_kp = keypose_loss_kp(uvdw_order, keys_uvd_order,
                            mparams)  # [batch, order]
  loss_kp.set_shape([None, num_order])
  # [batch, order]
  loss_proj = keypose_loss_proj(uvdw_pos_order, labels, mparams, num_order)
  loss_proj.set_shape([None, num_order])
  # Ramp factor for the projection loss, a function of the training step.
  loss_proj_adj = adjust_proj_factor(step, mparams.loss_proj_step)

  prob_label = labels['prob_label']  # [batch, num_kp, resy, resx]
  # [batch, order, num_kp, resy, resx]
  prob_label_order = tf.stack([prob_label] * num_order, axis=1)
  loss_prob = keypose_loss_prob(prob_order, prob_label_order)

  # Weighted combination of the three loss terms, still per-ordering.
  loss_order = (
      mparams.loss_kp * loss_kp +
      mparams.loss_proj * loss_proj_adj * loss_proj +
      mparams.loss_prob * loss_prob)

  print('loss_order shape [batch, order]:', loss_order.shape)
  # Keep only the best (minimum-loss) ordering for each sample.
  loss = tf.reduce_min(loss_order, axis=1, keepdims=True)  # shape [batch, 1]
  print('loss shape [batch, 1]:', loss.shape)

  # 1.0 where an ordering achieves the per-sample minimum (ties mark
  # multiple orderings), 0.0 elsewhere.
  loss_mask = tf.cast(tf.equal(loss, loss_order), tf.float32)  # [batch, order]
  loss_mask3 = tf.expand_dims(loss_mask, -1)  # [batch, order, 1]
  loss_mask4 = tf.expand_dims(loss_mask3, -1)  # [batch, order, 1, 1]
  print('loss_mask shape [batch, order]:', loss_mask.shape)

  loss = tf.reduce_mean(loss)  # Scalar, reduction over batch.
  # Collapse the order axis of each term for reporting, using the mask.
  loss_kp = reduce_order(loss_kp, loss_mask, mean=True)
  loss_proj = reduce_order(loss_proj, loss_mask, mean=True)
  loss_prob = reduce_order(loss_prob, loss_mask, mean=True)

  # Select the best-ordering predictions to return to the caller.
  uvdw = reduce_order(uvdw_order, loss_mask4)  # [batch, num_kp, 4]
  xyzw = reduce_order(xyzw_order, loss_mask4)  # [batch, num_kp, 4]
  print('xyzw shape:', xyzw.shape)

  if do_print:
    tf.print(
        '  ',
        step,
        'Keypose loss:',
        loss,
        mparams.loss_kp * loss_kp,
        mparams.loss_proj * loss_proj_adj * loss_proj,
        mparams.loss_prob * loss_prob,
        '  ',
        loss_proj_adj,
        uv_pix_raw[0, 0, :3],
        uvdw_pos[0, 0, :3],
        keys_uvd[0, 0, :3],
        summarize=-1)
  return loss, uvdw, xyzw
293

294

295
#
296
# Metrics and visualization.
297
#
298

299

300
def add_keypoints(img, uv, colors=None):
  """Add keypoint markers to an image, using draw_bounding_boxes.

  Args:
    img: [batch, vh, vw, 3]
    uv: [batch, num_kp, 2], in normalized coords [-1,1], xy order.
    colors: color palette for keypoints.

  Returns:
    tf images with drawn keypoints.
  """
  # Default to a single green marker color.
  palette = tf.constant([[0.0, 1.0, 0.0, 1.0]] if colors is None else colors)
  xy = uv[:, :, :2] * 0.5 + 0.5  # [-1,1] -> [0,1]
  # draw_bounding_boxes wants [y_min, x_min, y_max, x_max] in [0,1].
  upper_left = tf.stack([xy[:, :, 1], xy[:, :, 0]], axis=2)
  marker_size = 3.0 / tf.cast(tf.shape(img)[1], dtype=tf.float32)
  lower_right = upper_left + marker_size
  boxes = tf.concat([upper_left, lower_right], axis=2)
  print('Bounding box shape:', boxes.shape)
  return tf.image.draw_bounding_boxes(img, tf.cast(boxes, dtype=tf.float32),
                                      palette)
322

323

324
def add_keypoints_uv(img, uv, colors=None):
  """Add keypoint markers to an image, using draw_bounding_boxes.

  Args:
    img: [batch, vh, vw, 3]
    uv: [batch, num_kp, 2], in image coords, xy order.
    colors: color palette for drawing keypoints.

  Returns:
    tf images with drawn keypoints.
  """
  height = img.shape[1]
  width = img.shape[2]
  # Scale pixel coords to [0,1], then delegate in [-1,1] normalized coords.
  scaled = tf.stack([uv[:, :, 0] / width, uv[:, :, 1] / height], axis=2)
  return add_keypoints(img, (scaled - 0.5) * 2.0, colors)
341

342

343
def uv_error(labels, uvdw, _):
  """Euclidean pixel distance between gt and predicted uv; [batch, num_kp]."""
  delta = labels['keys_uvd_L'][..., :2] - uvdw[..., :2]
  squared = tf.reduce_sum(tf.square(delta), axis=-1)
  return tf.sqrt(squared)
346

347

348
def disp_error(labels, uvdw, _):
  """Absolute disparity error per keypoint; [batch, num_kp]."""
  gt_disp = labels['keys_uvd_L'][..., 2]
  return tf.abs(gt_disp - uvdw[..., 2])
350

351

352
def world_error(labels, xyzw):
  """Euclidean distance in world coordinates, per keypoint; [batch, num_kp]."""
  pred = tf.transpose(xyzw, [0, 2, 1])  # [batch, 4, num_kp]
  # Lift the gt keypoints to world coordinates. [batch, 4, num_kp]
  gt_world = project(labels['to_world_L'], labels['keys_uvd_L'], True)
  # Distance over xyz only; the homogeneous row is dropped.
  diff = pred[:, :3, :] - gt_world[:, :3, :]
  return tf.sqrt(tf.reduce_sum(tf.square(diff), axis=-2))  # [batch, num_kp]
361

362

363
def lt_2cm_error(labels, xyzw):
  """Percentage of keypoints whose world error is under 2 cm; [batch]."""
  err = world_error(labels, xyzw)
  hits = tf.cast(
      tf.math.count_nonzero(tf.less(err, 0.02), axis=[-1]), tf.float32)
  num_kp = tf.cast(tf.shape(err)[1], tf.float32)
  return 100.0 * hits / num_kp  # [batch]
368

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.