# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Misc. utilities."""
import numpy as np
import scipy.optimize
import tensorflow as tf


def l2_loss(prediction, target):
  """Returns the mean squared error between `prediction` and `target`."""
  return tf.reduce_mean(tf.math.squared_difference(prediction, target))
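

# A minimal usage sketch (added for illustration; not part of the original
# module). `l2_loss` averages the squared element-wise differences, so these
# hypothetical tensors give ((0.5)^2 + (0.5)^2) / 2 = 0.25.
def _l2_loss_example():
  prediction = tf.constant([0.0, 1.0])
  target = tf.constant([0.5, 1.5])
  return l2_loss(prediction, target)  # -> 0.25

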
def hungarian_huber_loss(x, y):
  """Huber loss for sets, matching elements with the Hungarian algorithm.

  This loss is used as reconstruction loss in the paper 'Deep Set Prediction
  Networks' https://arxiv.org/abs/1906.06565, see Eq. 2. For each element in
  the batches we wish to compute min_{pi} ||y_i - x_{pi(i)}||^2 where pi is a
  permutation of the set elements. We first compute the pairwise distances
  between each point in both sets and then match the elements using the scipy
  implementation of the Hungarian algorithm. This is applied for every set in
  the two batches. Note that if the number of points does not match, some of
  the elements will not be matched. As distance function we use the Huber
  loss.

  Args:
    x: Batch of sets of size [batch_size, n_points, dim_points]. Each set in
      the batch contains n_points many points, each represented as a vector of
      dimension dim_points.
    y: Batch of sets of size [batch_size, n_points, dim_points].

  Returns:
    Average distance between all sets in the two batches.
  """
  pairwise_cost = tf.losses.Huber(reduction=tf.keras.losses.Reduction.NONE)(
      tf.expand_dims(y, axis=-2), tf.expand_dims(x, axis=-3))
  indices = np.array(
      list(map(scipy.optimize.linear_sum_assignment, pairwise_cost)))

  transposed_indices = np.transpose(indices, axes=(0, 2, 1))

  actual_costs = tf.gather_nd(
      pairwise_cost, transposed_indices, batch_dims=1)

  return tf.reduce_mean(tf.reduce_sum(actual_costs, axis=1))
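

# A minimal usage sketch (added for illustration; not part of the original
# module). The two hypothetical batches below contain the same two 2-D points
# in swapped order, so the Hungarian matching recovers the permutation and
# the loss is exactly zero.
def _hungarian_huber_loss_example():
  x = tf.constant([[[0., 0.], [1., 1.]]])  # [batch_size=1, n_points=2, dim=2]
  y = tf.constant([[[1., 1.], [0., 0.]]])  # Same points, permuted.
  return hungarian_huber_loss(x, y)  # -> 0.0

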
def average_precision_clevr(pred, attributes, distance_threshold):
  """Computes the average precision for CLEVR.

  This function computes the average precision of the predictions specifically
  for the CLEVR dataset. First, we sort the predictions of the model by
  confidence (highest confidence first). Then, for each prediction we check
  whether there was a corresponding object in the input image. A prediction is
  considered a true positive if the discrete features are predicted correctly
  and the predicted position is within a certain distance from the ground
  truth object.

  Args:
    pred: Tensor of shape [batch_size, num_elements, dimension] containing
      predictions. The last dimension is expected to be the confidence of the
      prediction.
    attributes: Tensor of shape [batch_size, num_elements, dimension]
      containing ground-truth object properties.
    distance_threshold: Threshold to accept match. -1 indicates no threshold.

  Returns:
    Average precision of the predictions.
  """

  [batch_size, _, element_size] = attributes.shape
  [_, predicted_elements, _] = pred.shape

  def unsorted_id_to_image(detection_id, predicted_elements):
    """Find the index of the image from the unsorted detection index."""
    return int(detection_id // predicted_elements)

  flat_size = batch_size * predicted_elements
  flat_pred = np.reshape(pred, [flat_size, element_size])
  sort_idx = np.argsort(flat_pred[:, -1], axis=0)[::-1]  # Reverse order.

  sorted_predictions = np.take_along_axis(
      flat_pred, np.expand_dims(sort_idx, axis=1), axis=0)
  idx_sorted_to_unsorted = np.take_along_axis(
      np.arange(flat_size), sort_idx, axis=0)

  def process_targets(target):
    """Unpacks the target into the CLEVR properties."""
    coords = target[:3]
    object_size = tf.argmax(target[3:5])
    material = tf.argmax(target[5:7])
    shape = tf.argmax(target[7:10])
    color = tf.argmax(target[10:18])
    real_obj = target[18]
    return coords, object_size, material, shape, color, real_obj

  true_positives = np.zeros(sorted_predictions.shape[0])
  false_positives = np.zeros(sorted_predictions.shape[0])

  detection_set = set()

  for detection_id in range(sorted_predictions.shape[0]):
    # Extract the current prediction.
    current_pred = sorted_predictions[detection_id, :]
    # Find which image the prediction belongs to. Get the unsorted index from
    # the sorted one, then apply the unsorted_id_to_image function, which
    # undoes the reshape.
    original_image_idx = unsorted_id_to_image(
        idx_sorted_to_unsorted[detection_id], predicted_elements)
    # Get the ground-truth image.
    gt_image = attributes[original_image_idx, :, :]

    # Initialize the best (smallest) distance found so far and the id of the
    # matching ground-truth object.
    best_distance = 10000
    best_id = None

    # Unpack the prediction by taking the argmax on the discrete attributes.
    (pred_coords, pred_object_size, pred_material, pred_shape, pred_color,
     _) = process_targets(current_pred)

    # Loop through all objects in the ground-truth image to check for hits.
    for target_object_id in range(gt_image.shape[0]):
      target_object = gt_image[target_object_id, :]
      # Unpack the targets, taking the argmax on the discrete attributes.
      (target_coords, target_object_size, target_material, target_shape,
       target_color, target_real_obj) = process_targets(target_object)
      # Only consider real objects as matches.
      if target_real_obj:
        # For the match to be valid, all attributes need to be correctly
        # predicted.
        pred_attr = [pred_object_size, pred_material, pred_shape, pred_color]
        target_attr = [
            target_object_size, target_material, target_shape, target_color]
        match = pred_attr == target_attr
        if match:
          # If a match was found, we check if the distance is below the
          # specified threshold. Recall that we have rescaled the coordinates
          # in the dataset from [-3, 3] to [0, 1], both for `target_coords`
          # and `pred_coords`. To compare in the original scale, we thus
          # multiply the coordinate differences by 6 before taking the norm.
          distance = np.linalg.norm((target_coords - pred_coords) * 6.)

          # If this is the best match we've found so far, remember it.
          if distance < best_distance:
            best_distance = distance
            best_id = target_object_id
    if best_distance < distance_threshold or distance_threshold == -1:
      # We have detected an object correctly within the distance threshold.
      # If this object was not detected before, it's a true positive.
      if best_id is not None:
        if (original_image_idx, best_id) not in detection_set:
          true_positives[detection_id] = 1
          detection_set.add((original_image_idx, best_id))
        else:
          false_positives[detection_id] = 1
      else:
        false_positives[detection_id] = 1
    else:
      false_positives[detection_id] = 1
  accumulated_fp = np.cumsum(false_positives)
  accumulated_tp = np.cumsum(true_positives)
  recall_array = accumulated_tp / np.sum(attributes[:, :, -1])
  precision_array = np.divide(accumulated_tp, (accumulated_fp + accumulated_tp))

  return compute_average_precision(
      np.array(precision_array, dtype=np.float32),
      np.array(recall_array, dtype=np.float32))
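

# A minimal usage sketch (added for illustration; not part of the original
# module), assuming TF2 eager execution. Per `process_targets` above, a CLEVR
# element is a 19-dim vector: 3 coordinates, one-hot size (2), material (2),
# shape (3) and color (8), plus a final entry that marks a real object in the
# targets and holds the confidence in the predictions. A perfect prediction
# of a single object yields an AP of 1.0.
def _average_precision_clevr_example():
  element = [0.5, 0.5, 0.5,  # Coordinates, rescaled to [0, 1].
             1., 0.,  # Size.
             1., 0.,  # Material.
             1., 0., 0.,  # Shape.
             1., 0., 0., 0., 0., 0., 0., 0.,  # Color.
             1.]  # Real-object flag / confidence.
  attributes = np.array([[element]], dtype=np.float32)  # [1, 1, 19]
  pred = np.array([[element]], dtype=np.float32)
  return average_precision_clevr(pred, attributes, distance_threshold=0.5)
  # -> 1.0

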
def compute_average_precision(precision, recall):
  """Computation of the average precision from precision and recall arrays."""
  recall = recall.tolist()
  precision = precision.tolist()
  recall = [0] + recall + [1]
  precision = [0] + precision + [0]

  for i in range(len(precision) - 1, 0, -1):
    precision[i - 1] = max(precision[i - 1], precision[i])

  indices_recall = [
      i for i in range(len(recall) - 1) if recall[1:][i] != recall[:-1][i]
  ]

  average_precision = 0.
  for i in indices_recall:
    average_precision += precision[i + 1] * (recall[i + 1] - recall[i])
  return average_precision
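

# A minimal usage sketch (added for illustration; not part of the original
# module): with precision [1., 0.5] at recall [0.5, 1.], the interpolated
# precision-recall curve integrates to 1. * 0.5 + 0.5 * 0.5 = 0.75.
def _compute_average_precision_example():
  precision = np.array([1., 0.5], dtype=np.float32)
  recall = np.array([0.5, 1.], dtype=np.float32)
  return compute_average_precision(precision, recall)  # -> 0.75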