# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Misc. utilities."""
import numpy as np
import scipy.optimize
import tensorflow as tf


def l2_loss(prediction, target):
  """Mean squared error between `prediction` and `target`."""
  return tf.reduce_mean(tf.math.squared_difference(prediction, target))


def hungarian_huber_loss(x, y):
  """Huber loss for sets, matching elements with the Hungarian algorithm.

  This loss is used as reconstruction loss in the paper 'Deep Set Prediction
  Networks' https://arxiv.org/abs/1906.06565, see Eq. 2. For each element in
  the batches we wish to compute min_{pi} ||y_i - x_{pi(i)}||^2, where pi is
  a permutation of the set elements. We first compute the pairwise distances
  between each point in both sets and then match the elements using the scipy
  implementation of the Hungarian algorithm. This is applied to every set in
  the two batches. Note that if the number of points does not match, some of
  the elements will not be matched. As distance function we use the Huber
  loss.

  Args:
    x: Batch of sets of size [batch_size, n_points, dim_points]. Each set in
      the batch contains n_points many points, each represented as a vector
      of dimension dim_points.
    y: Batch of sets of size [batch_size, n_points, dim_points].

  Returns:
    Average distance between all sets in the two batches.
  """
  # Pairwise Huber cost between every element of y and every element of x:
  # shape [batch_size, n_points, n_points].
  pairwise_cost = tf.losses.Huber(reduction=tf.keras.losses.Reduction.NONE)(
      tf.expand_dims(y, axis=-2), tf.expand_dims(x, axis=-3))
  # Solve the assignment problem for each set; the result has shape
  # [batch_size, 2, n_points] (row indices and column indices).
  indices = np.array(
      list(map(scipy.optimize.linear_sum_assignment, pairwise_cost)))

  transposed_indices = np.transpose(indices, axes=(0, 2, 1))

  # Gather the cost of every matched pair and sum it per set.
  actual_costs = tf.gather_nd(
      pairwise_cost, transposed_indices, batch_dims=1)

  return tf.reduce_mean(tf.reduce_sum(actual_costs, axis=1))
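

# Hypothetical usage sketch (not part of the original module): a quick sanity
# check of `hungarian_huber_loss`. All shapes and the helper name
# `_check_hungarian_huber_loss` are made up for illustration. Because the
# Hungarian matching minimizes over permutations of the set elements, feeding
# a shuffled copy of the predictions as targets should yield a loss of
# (numerically) zero.
def _check_hungarian_huber_loss():
  x = tf.random.normal([4, 6, 3])  # 4 sets of 6 points of dimension 3.
  perm = tf.random.shuffle(tf.range(6))
  y = tf.gather(x, perm, axis=1)  # The same sets with elements reordered.
  loss = hungarian_huber_loss(x, y)
  print(float(loss))  # Expected: ~0.0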


def average_precision_clevr(pred, attributes, distance_threshold):
  """Computes the average precision for CLEVR.

  This function computes the average precision of the predictions specifically
  for the CLEVR dataset. First, we sort the predictions of the model by
  confidence (highest confidence first). Then, for each prediction we check
  whether there was a corresponding object in the input image. A prediction is
  considered a true positive if the discrete features are predicted correctly
  and the predicted position is within a certain distance from the ground
  truth object.

  Args:
    pred: Tensor of shape [batch_size, num_elements, dimension] containing
      predictions. The last dimension is expected to be the confidence of the
      prediction.
    attributes: Tensor of shape [batch_size, num_elements, dimension]
      containing ground-truth object properties.
    distance_threshold: Threshold to accept a match. -1 indicates no
      threshold.

  Returns:
    Average precision of the predictions.
  """

  [batch_size, _, element_size] = attributes.shape
  [_, predicted_elements, _] = pred.shape

  def unsorted_id_to_image(detection_id, predicted_elements):
    """Finds the index of the image from the unsorted detection index."""
    return int(detection_id // predicted_elements)

  flat_size = batch_size * predicted_elements
  flat_pred = np.reshape(pred, [flat_size, element_size])
  sort_idx = np.argsort(flat_pred[:, -1], axis=0)[::-1]  # Reverse order.

  sorted_predictions = np.take_along_axis(
      flat_pred, np.expand_dims(sort_idx, axis=1), axis=0)
  idx_sorted_to_unsorted = np.take_along_axis(
      np.arange(flat_size), sort_idx, axis=0)

  def process_targets(target):
    """Unpacks the target into the CLEVR properties."""
    coords = target[:3]
    object_size = tf.argmax(target[3:5])
    material = tf.argmax(target[5:7])
    shape = tf.argmax(target[7:10])
    color = tf.argmax(target[10:18])
    real_obj = target[18]
    return coords, object_size, material, shape, color, real_obj

  true_positives = np.zeros(sorted_predictions.shape[0])
  false_positives = np.zeros(sorted_predictions.shape[0])

  detection_set = set()

  for detection_id in range(sorted_predictions.shape[0]):
    # Extract the current prediction.
    current_pred = sorted_predictions[detection_id, :]
    # Find which image the prediction belongs to. Get the unsorted index from
    # the sorted one, then apply the unsorted_id_to_image function to undo the
    # reshape.
    original_image_idx = unsorted_id_to_image(
        idx_sorted_to_unsorted[detection_id], predicted_elements)
    # Get the ground-truth image.
    gt_image = attributes[original_image_idx, :, :]

    # Initialize the best distance with a large sentinel value and the id of
    # the ground-truth object that was found.
    best_distance = 10000
    best_id = None

    # Unpack the prediction by taking the argmax on the discrete attributes.
    (pred_coords, pred_object_size, pred_material, pred_shape, pred_color,
     _) = process_targets(current_pred)

    # Loop through all objects in the ground-truth image to check for hits.
    for target_object_id in range(gt_image.shape[0]):
      target_object = gt_image[target_object_id, :]
      # Unpack the targets, taking the argmax on the discrete attributes.
      (target_coords, target_object_size, target_material, target_shape,
       target_color, target_real_obj) = process_targets(target_object)
      # Only consider real objects as matches.
      if target_real_obj:
        # For the match to be valid, all attributes need to be correctly
        # predicted.
        pred_attr = [pred_object_size, pred_material, pred_shape, pred_color]
        target_attr = [
            target_object_size, target_material, target_shape, target_color]
        match = pred_attr == target_attr
        if match:
          # If a match was found, we check if the distance is below the
          # specified threshold. Recall that we have rescaled the coordinates
          # in the dataset from [-3, 3] to [0, 1], both for `target_coords`
          # and `pred_coords`. To compare in the original scale, we thus need
          # to multiply the coordinate differences by 6 before applying the
          # norm.
          distance = np.linalg.norm((target_coords - pred_coords) * 6.)

          # If this is the best match we've found so far, we remember it.
          if distance < best_distance:
            best_distance = distance
            best_id = target_object_id
    if best_distance < distance_threshold or distance_threshold == -1:
      # We have detected an object correctly within the distance threshold.
      # If this object was not detected before, it's a true positive.
      if best_id is not None:
        if (original_image_idx, best_id) not in detection_set:
          true_positives[detection_id] = 1
          detection_set.add((original_image_idx, best_id))
        else:
          false_positives[detection_id] = 1
      else:
        false_positives[detection_id] = 1
    else:
      false_positives[detection_id] = 1
  accumulated_fp = np.cumsum(false_positives)
  accumulated_tp = np.cumsum(true_positives)
  # Recall is normalized by the total number of real objects in the batch.
  recall_array = accumulated_tp / np.sum(attributes[:, :, -1])
  precision_array = np.divide(accumulated_tp, (accumulated_fp + accumulated_tp))

  return compute_average_precision(
      np.array(precision_array, dtype=np.float32),
      np.array(recall_array, dtype=np.float32))
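

# Hypothetical evaluation sketch (not part of the original module). As
# unpacked by `process_targets` above, each CLEVR element is a 19-dimensional
# vector: 3 coordinates, 2 size logits, 2 material logits, 3 shape logits,
# 8 color logits, and a last entry holding the real-object flag (ground
# truth) or the confidence (predictions). The helper name and all values are
# made up for illustration: feeding the ground truth back in as a perfectly
# confident prediction should give an average precision of 1.0.
def _check_average_precision_clevr():
  rng = np.random.default_rng(0)
  batch_size, num_elements = 2, 10
  attributes = rng.random((batch_size, num_elements, 19)).astype(np.float32)
  attributes[:, :, -1] = 1.  # Mark every ground-truth slot as a real object.
  pred = attributes.copy()  # Perfect predictions with confidence 1.
  print(average_precision_clevr(pred, attributes, 0.5))  # Expected: 1.0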


def compute_average_precision(precision, recall):
  """Computes the average precision from precision and recall arrays."""
  recall = recall.tolist()
  precision = precision.tolist()
  recall = [0] + recall + [1]
  precision = [0] + precision + [0]

  # Make the precision curve monotonically decreasing (the standard
  # interpolation step for average precision).
  for i in range(len(precision) - 1, 0, -1):
    precision[i - 1] = max(precision[i - 1], precision[i])

  # Indices where the recall value changes.
  indices_recall = [
      i for i in range(len(recall) - 1) if recall[i + 1] != recall[i]
  ]

  # Sum the areas of the rectangles under the interpolated curve.
  average_precision = 0.
  for i in indices_recall:
    average_precision += precision[i + 1] * (recall[i + 1] - recall[i])
  return average_precision
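

# Hypothetical worked example (not part of the original module). Suppose
# three detections over two ground-truth objects come in as (TP, FP, TP),
# giving precision [1, 1/2, 2/3] at recall [1/2, 1/2, 1]. The monotone
# envelope raises the precision to 1 on recall [0, 1/2] and to 2/3 on
# (1/2, 1], so the average precision is 1 * 1/2 + (2/3) * 1/2 = 5/6.
def _check_compute_average_precision():
  precision = np.array([1., 0.5, 2. / 3.], dtype=np.float32)
  recall = np.array([0.5, 0.5, 1.], dtype=np.float32)
  print(compute_average_precision(precision, recall))  # Expected: ~0.8333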