# google-research
# 367 lines · 12.2 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loss and metric definitions for KeyPose models."""

import numpy as np
import tensorflow as tf

from keypose import nets

# Number of target frames used for the reprojection loss.
# num_targs = inp.MAX_TARGET_FRAMES
num_targs = 5

# Reordering based on symmetry.

def make_order(sym, num_kp):
  """Returns all rotations of the <sym> keypoints.

  Builds one keypoint permutation per cyclic rotation of the symmetry
  group <sym>; keypoints not listed in <sym> keep their position.

  Args:
    sym: list of keypoint indices forming a symmetry cycle.
    num_kp: total number of keypoints.

  Returns:
    np.ndarray of shape [len(sym), num_kp]; row i is the keypoint
    permutation for the i-th rotation (row 0 is the identity).
  """
  # All cyclic rotations of the symmetry list.
  rots = np.array([sym[-i:] + sym[:-i] for i in range(len(sym))])
  rot_list = np.array([list(range(num_kp))] * len(sym))
  for i, rot in enumerate(rot_list):
    # Scatter the identity assignment of the symmetric keypoints into
    # the rotated positions; other keypoints stay fixed.
    rot[rots[i]] = rot_list[0][rots[0]]
  return rot_list

def reorder(tensor, order):
  """Re-orders a tensor along the num_kp dimension.

  Args:
    tensor: has shape [batch, num_kp, ...]
    order: permutation of keypoints; list/array of keypoint-index lists.

  Returns:
    shape [batch, order, num_kp, ...]
  """
  # `kp` renamed from `ord` to avoid shadowing the builtin.
  return tf.stack(
      # order is list of lists, so pylint: disable=not-an-iterable
      [tf.stack([tensor[:, x, ...] for x in kp], axis=1) for kp in order],
      axis=1)

def reduce_order(tensor, mask, mean=False):
  """Selects entries of <tensor> along axis 1 using a 0/1 <mask>.

  Multiplies element-wise by the mask and sums out axis 1 (the order
  axis), which picks the masked-in permutation.

  Args:
    tensor: [batch, order, ...] values.
    mask: broadcastable 0/1 mask selecting one order per batch element.
    mean: if True, additionally reduce the result to a scalar mean.

  Returns:
    [batch, ...] tensor, or a scalar if mean=True.
  """
  res = tf.multiply(tensor, mask)
  res = tf.reduce_sum(res, axis=1)
  if mean:
    return tf.reduce_mean(res)
  else:
    return res


# Loss functions.

def project(tmat, tvec, tvec_transpose=False):
  """Projects homogeneous 3D XYZ coordinates to image uvd coordinates, or vv.

  Args:
    tmat: has shape [[N,] batch_size, 4, 4].
    tvec: has shape [[N,] batch_size, 4, num_kp] or [batch_size, num_kp, 4].
    tvec_transpose: True if tvec is to be transposed before application.

  Returns:
    Has shape [[N,] batch_size, 4, num_kp].
  """
  tp = tf.matmul(tmat, tvec, transpose_b=tvec_transpose)
  # Normalize by the homogeneous coordinate; epsilon avoids divide-by-zero.
  # Using <3:4> instead of <3> preserves shape.
  tp = tp / (tp[..., 3:4, :] + 1.0e-10)
  return tp

def keypoint_loss_targets(uvd, keys_uvd, mparams):
  """Computes the supervised keypoint loss between computed and gt keypoints.

  Args:
    uvd: [batch, order, num_targs, 4, num_kp] Predicted set of keypoint uv's
      (pixels).
    keys_uvd: [batch, order, num_targs, 4, num_kp] The ground-truth set of uvdw
      coords.
    mparams: model parameters.

  Returns:
    Keypoint projection loss of size [batch, order].
  """
  print('uvd shape in klt [batch, order, num_targs, 4, num_kp]:', uvd.shape)
  print('keys_uvd shape in klt [batch, order, num_targs, 4, num_kp]:',
        keys_uvd.shape)
  # Compare in normalized image coordinates rather than pixels.
  keys_uvd = nets.to_norm(keys_uvd, mparams)
  uvd = nets.to_norm(uvd, mparams)

  # Squared uv distance, summed over keypoints and uv channels.
  wd = tf.square(uvd[..., :2, :] - keys_uvd[..., :2, :])
  wd = tf.reduce_sum(wd, axis=[-1, -2])  # uv dist [batch, order, num_targs]
  print('wd shape in klt [batch, order, num_targs]:', wd.shape)
  wd = tf.reduce_mean(wd, axis=[-1])  # [batch, order]
  return wd

# Compute the reprojection error on the target frames.
def keypose_loss_proj(uvdw_pos, labels, mparams, num_order):
  """Compute the reprojection error on the target frames.

  Lifts the predicted uvd keypoints to world coordinates, reprojects them
  into each target frame, and compares against the ground-truth target
  keypoints.

  Args:
    uvdw_pos: predicted uvd, always positive.
    labels: sample labels.
    mparams: model parameters.
    num_order: number of order permutations.

  Returns:
    Reprojection loss of shape [batch, order].
  """
  num_kp = mparams.num_kp
  to_world = labels['to_world_L']  # [batch, 4, 4]
  to_world_order = tf.stack(
      [to_world] * num_order, axis=1)  # [batch, order, 4, 4]
  to_world_order = tf.ensure_shape(
      to_world_order, [None, num_order, 4, 4], name='to_world_order')
  world_coords = project(to_world_order, uvdw_pos,
                         True)  # [batch, order, 4, num_kp]
  world_coords = tf.ensure_shape(
      world_coords, [None, num_order, 4, num_kp], name='world_coords')
  print('world_coords shape [batch, order, 4, num_kp]:', world_coords.shape)

  # Target transform and keypoints.
  # [batch, num_targs, 4, 4] for transforms
  # [batch, num_targs, 4, num_kp] for keypoints (after transpose)
  targets_to_uvd = labels['targets_to_uvd_L']
  targets_keys_uvd = tf.transpose(labels['targets_keys_uvd_L'], [0, 1, 3, 2])
  targets_keys_uvd_order = tf.stack([targets_keys_uvd] * num_order, axis=1)
  print('Model fn targets_to_uvd shape [batch, num_targs, 4, 4]:',
        targets_to_uvd.shape)
  print(
      'Model fn targets_keys_uvd_order shape [batch, order, num_targs, 4, '
      'num_kp]:', targets_keys_uvd_order.shape)

  # [batch, order, num_targs, 4, num_kp]
  proj_uvds = project(
      tf.stack([targets_to_uvd] * num_order, axis=1),
      tf.stack([world_coords] * num_targs, axis=2))
  # Use the module constant num_targs instead of a hard-coded 5.
  proj_uvds = tf.ensure_shape(
      proj_uvds, [None, num_order, num_targs, 4, num_kp], name='proj_uvds')
  print('proj_uvds shape [batch, order, num_targs, 4, num_kp]:',
        proj_uvds.shape)
  loss_proj = keypoint_loss_targets(proj_uvds, targets_keys_uvd_order, mparams)
  loss_proj = tf.ensure_shape(loss_proj, [None, num_order], name='loss_proj')
  print('loss_proj shape [batch, order]:', loss_proj.shape)
  return loss_proj

# Keypoint loss function, direct comparison of u,v,d.
# Compares them in normalized image coordinates. For some reason this
# seems to work better than pixels.
# uvdw_order: [batch, order, num_kp, 4]
# keys_uvd_order: [batch, order, num_kp, 4]
# Returns: [batch, order]
def keypose_loss_kp(uvdw_order, keys_uvd_order, mparams):
  uvdw_order = nets.to_norm_vec(uvdw_order, mparams)
  keys_uvd_order = nets.to_norm_vec(keys_uvd_order, mparams)
  # Squared u,v,d error summed over keypoints and channels.
  ret = tf.reduce_sum(
      tf.square(uvdw_order[..., :3] - keys_uvd_order[..., :3]),
      axis=[-1, -2])  # [batch, order]
  return ret

# Probability coherence loss.
def keypose_loss_prob(prob_order, prob_label_order):
  # Correlate predicted and label probability maps over the spatial dims,
  # then average over keypoints.
  ret = tf.reduce_sum(
      prob_order * prob_label_order, axis=[-1, -2])  # [batch, order]
  return tf.reduce_mean(ret, axis=-1)

# Adjust the gain of loss_proj; return in [0,1].
def adjust_proj_factor(step, loss_step, minf=0.0):
  """Ramps the projection-loss gain linearly over a training-step window.

  Args:
    step: current training step.
    loss_step: (start, end) pair; the gain ramps from minf at start to 1.0
      at end. An end of 0 disables ramping (gain is always 1.0).
    minf: minimum gain before the ramp starts.

  Returns:
    Gain in [minf, 1.0].
  """
  if loss_step[1] == 0:
    return 1.0
  step = tf.cast(step, tf.float32)
  return tf.maximum(
      minf,
      tf.minimum((step - loss_step[0]) /
                 tf.cast(loss_step[1] - loss_step[0], tf.float32), 1.0))

# Custom loss function for the Keras model in Estimator
# Args are:
#   Dict of tensors for labels, with keys_uvd, offsets.
#   Tensor of raw uvd values for preds, [batch, num_kp, 3],
#   order is [u,v,d].
# Note that the loss is batched.
# This loss only works with tf Estimator, not keras models.
@tf.function
def keypose_loss(labels, preds, step, mparams, do_print=True):
  """Custom loss function for the Keras model in Estimator.

  Note that the loss is batched.
  This loss only works with tf Estimator, not keras models.

  Args:
    labels: dict of tensors for labels, with keys_uvd, offsets.
    preds: tensor of raw uvd values for preds, [batch, num_kp, 3], order is
      [u,v,d].
    step: training step.
    mparams: model training parameters.
    do_print: True to print loss values at every step.

  Returns:
    Scalar loss, plus the uvdw and xyzw predictions for the best-scoring
    symmetry order.
  """
  num_kp = mparams.num_kp
  sym = mparams.sym
  order = make_order(sym, num_kp)
  num_order = len(order)

  uvdw = preds['uvdw']
  uvdw_pos = preds['uvdw_pos']
  uv_pix_raw = preds['uv_pix_raw']
  prob = preds['prob']  # [batch, num_kp, resy, resx]
  xyzw = tf.transpose(preds['xyzw'], [0, 2, 1])  # [batch, num_kp, 4]

  # Expand predictions over all symmetry permutations.
  uvdw_order = reorder(uvdw, order)  # [batch, order, num_kp, 4]
  print('uvdw_order shape:', uvdw_order.shape)
  uvdw_pos_order = reorder(uvdw_pos, order)  # [batch, order, num_kp, 4]
  xyzw_order = reorder(xyzw, order)  # [batch, order, 4, num_kp]
  print('xyzw_order shape:', xyzw_order.shape)
  prob_order = reorder(prob, order)  # [ batch, order, num_kp, resy, resx]
  print('prob_order shape:', prob_order.shape)

  keys_uvd = labels['keys_uvd_L']  # [batch, num_kp, 4]
  # [batch, order, num_kp, 4]
  keys_uvd_order = tf.stack([keys_uvd] * num_order, axis=1)

  loss_kp = keypose_loss_kp(uvdw_order, keys_uvd_order,
                            mparams)  # [batch, order]
  loss_kp.set_shape([None, num_order])
  # [batch, order]
  loss_proj = keypose_loss_proj(uvdw_pos_order, labels, mparams, num_order)
  loss_proj.set_shape([None, num_order])
  loss_proj_adj = adjust_proj_factor(step, mparams.loss_proj_step)

  prob_label = labels['prob_label']  # [batch, num_kp, resy, resx]
  # [batch, order, num_kp, resy, resx]
  prob_label_order = tf.stack([prob_label] * num_order, axis=1)
  loss_prob = keypose_loss_prob(prob_order, prob_label_order)

  # Weighted sum of the three loss components, per order permutation.
  loss_order = (
      mparams.loss_kp * loss_kp +
      mparams.loss_proj * loss_proj_adj * loss_proj +
      mparams.loss_prob * loss_prob)

  print('loss_order shape [batch, order]:', loss_order.shape)
  # Pick the best (minimum-loss) symmetry order per batch element.
  loss = tf.reduce_min(loss_order, axis=1, keepdims=True)  # shape [batch, 1]
  print('loss shape [batch, 1]:', loss.shape)

  # 0/1 mask marking the winning order; used to select matching predictions.
  loss_mask = tf.cast(tf.equal(loss, loss_order), tf.float32)  # [batch, order]
  loss_mask3 = tf.expand_dims(loss_mask, -1)  # [batch, order, 1]
  loss_mask4 = tf.expand_dims(loss_mask3, -1)  # [batch, order, 1, 1]
  print('loss_mask shape [batch, order]:', loss_mask.shape)

  loss = tf.reduce_mean(loss)  # Scalar, reduction over batch.
  loss_kp = reduce_order(loss_kp, loss_mask, mean=True)
  loss_proj = reduce_order(loss_proj, loss_mask, mean=True)
  loss_prob = reduce_order(loss_prob, loss_mask, mean=True)

  uvdw = reduce_order(uvdw_order, loss_mask4)  # [batch, num_kp, 4]
  xyzw = reduce_order(xyzw_order, loss_mask4)  # [batch, num_kp, 4]
  print('xyzw shape:', xyzw.shape)

  if do_print:
    tf.print(
        ' ',
        step,
        'Keypose loss:',
        loss,
        mparams.loss_kp * loss_kp,
        mparams.loss_proj * loss_proj_adj * loss_proj,
        mparams.loss_prob * loss_prob,
        ' ',
        loss_proj_adj,
        uv_pix_raw[0, 0, :3],
        uvdw_pos[0, 0, :3],
        keys_uvd[0, 0, :3],
        summarize=-1)
  return loss, uvdw, xyzw

#
# Metrics and visualization.
#

def add_keypoints(img, uv, colors=None):
  """Add keypoint markers to an image, using draw_bounding_boxes.

  Args:
    img: [batch, vh, vw, 3]
    uv: [batch, num_kp, 2], in normalized coords [-1,1], xy order.
    colors: color palette for keypoints.

  Returns:
    tf images with drawn keypoints.
  """
  if colors is None:
    colors = tf.constant([[0.0, 1.0, 0.0, 1.0]])  # default: green
  else:
    colors = tf.constant(colors)
  uv = uv[:, :, :2] * 0.5 + 0.5  # [-1,1] -> [0,1]
  # draw_bounding_boxes expects boxes in (y, x) order.
  keys_bb_ul = tf.stack([uv[:, :, 1], uv[:, :, 0]], axis=2)
  # A small box, ~3 pixels relative to the image height.
  keys_bb_lr = keys_bb_ul + 3.0 / tf.cast(tf.shape(img)[1], dtype=tf.float32)
  keys_bb = tf.concat([keys_bb_ul, keys_bb_lr], axis=2)
  print('Bounding box shape:', keys_bb.shape)
  return tf.image.draw_bounding_boxes(img, tf.cast(keys_bb, dtype=tf.float32),
                                      colors)

def add_keypoints_uv(img, uv, colors=None):
  """Add keypoint markers to an image, using draw_bounding_boxes.

  Args:
    img: [batch, vh, vw, 3]
    uv: [batch, num_kp, 2], in image coords, xy order.
    colors: color palette for drawing keypoints.

  Returns:
    tf images with drawn keypoints.
  """
  resy = img.shape[1]
  resx = img.shape[2]
  # Convert pixel coords to [0,1], then to [-1,1] for add_keypoints.
  uvx = uv[:, :, 0] / resx
  uvy = uv[:, :, 1] / resy
  uv = tf.stack([uvx, uvy], axis=2)
  return add_keypoints(img, (uv - 0.5) * 2.0, colors)

def uv_error(labels, uvdw, _):
  """Euclidean pixel error between predicted and ground-truth uv coords."""
  diff = labels['keys_uvd_L'][..., :2] - uvdw[..., :2]
  return tf.sqrt(tf.reduce_sum(tf.square(diff), axis=-1))  # [batch, num_kp]

def disp_error(labels, uvdw, _):
  """Absolute disparity error between predicted and ground-truth d coords."""
  return tf.abs(labels['keys_uvd_L'][..., 2] - uvdw[..., 2])  # [batch, num_kp]

def world_error(labels, xyzw):
  """Euclidean world-space error between predicted and gt keypoints."""
  xyzw = tf.transpose(xyzw, [0, 2, 1])  # [batch, 4, num_kp]
  # Project ground-truth uvd keypoints to world coords. [batch, 4, num_kp]
  gt_world_coords = project(labels['to_world_L'], labels['keys_uvd_L'], True)
  sub = xyzw[:, :3, :] - gt_world_coords[:, :3, :]
  wd = tf.square(sub)
  wd = tf.reduce_sum(wd, axis=[-2])  # [batch, num_kp] result.
  wd = tf.sqrt(wd)
  return wd  # [batch, num_kp]

363def lt_2cm_error(labels, xyzw):364err = world_error(labels, xyzw)365lt = tf.less(err, 0.02)366return (100.0 * tf.cast(tf.math.count_nonzero(lt, axis=[-1]), tf.float32) /367tf.cast(tf.shape(err)[1], tf.float32)) # [batch]368