# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Metric util functions for contrastive learning experiments.

Defines custom metrics and metrics helper classes for experiments where we want
to track performance along multiple dataset-specific axes.
"""

import abc
import os

from absl import flags
from absl import logging

import tensorflow.compat.v2 as tf

FLAGS = flags.FLAGS


class R2Metric(tf.keras.metrics.Metric):
  """Compute and store running R^2 score."""

  def __init__(self, tss, name='R^2', **kwargs):
    super().__init__(name=name, **kwargs)
    if tf.rank(tss) > 0:
      # TODO(zeef) Find a way to store TSS and RSS values as arrays.
      # Currently, resetting the metric will throw an error if RSS is not 0-dim.
      self.tss = tf.reduce_mean(tss)
    else:
      self.tss = tss
    self.rss = self.add_weight(name='rss', initializer='zeros')

  def update_state(self, actual, preds):
    res_squared = tf.reduce_mean(
        (actual - preds)**2, axis=tf.range(1, tf.rank(actual)))
    res_squared_summed = tf.reduce_sum(res_squared)
    self.rss.assign_add(res_squared_summed)

  def result(self):
    return 1 - self.rss / self.tss
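

# Illustrative sketch (not part of the original library): how R2Metric is
# typically driven. The numbers below are made up; in the real eval loop the
# total sum of squares comes from get_tss_for_r2 (defined at the bottom of
# this file) and `actual`/`preds` come from the model under evaluation.
def _example_r2_metric_usage():
  tss = tf.constant(8.0)  # dataset-wide total sum of squares, precomputed
  metric = R2Metric(tss)
  actual = tf.constant([[1.0], [2.0], [3.0]])
  preds = tf.constant([[1.5], [2.0], [2.5]])
  metric.update_state(actual, preds)  # accumulates the residual sum of squares
  return metric.result()  # 1 - RSS / TSS = 1 - 0.5 / 8.0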


class DspritesAccuracy(tf.keras.metrics.Metric):
  """A measure of correctness for dsprites (non-shape) latents.

  We treat a predicted set of latents as 'correct' if it is closer to the
  correct values than to nearby values.
  """

  def __init__(self, tolerance, name='dsprites_accuracy', **kwargs):
    super().__init__(name=name, **kwargs)
    self.tolerance = tf.constant(tolerance)
    self.correct = self.add_weight(name='correct', initializer='zeros')
    self.seen = self.add_weight(name='seen', initializer='zeros')

  def update_state(self, actual, preds):
    is_correct = tf.math.abs(actual - preds) < self.tolerance
    is_correct = tf.cast(is_correct, tf.float32)
    # need to count no. of examples seen but can't use .shape[0] in graph mode
    is_seen = tf.reduce_sum(tf.ones_like(is_correct))
    self.correct.assign_add(tf.reduce_sum(is_correct))
    self.seen.assign_add(is_seen)

  def result(self):
    return self.correct / self.seen
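

# Illustrative sketch (assumption, not in the original code): with the
# 1 / (2 * 32) tolerance used for the dsprites position latents below, a
# prediction counts as correct only when it is (roughly) closer to the true
# position value than to the neighbouring grid values.
def _example_dsprites_accuracy_usage():
  metric = DspritesAccuracy(tolerance=1 / (2 * 32))
  actual = tf.constant([0.5, 0.5])
  preds = tf.constant([0.51, 0.60])  # first within tolerance, second not
  metric.update_state(actual, preds)
  return metric.result()  # 0.5: one of the two examples counted as correct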


class DspritesShapeAccuracy(tf.keras.metrics.Metric):
  """A measure of correctness for dsprites shape prediction.

  We treat a prediction as 'correct' if it is close to 1 for the correct shape
  AND close to 0 for the other shapes.
  """

  def __init__(self, tolerance, name='dsprites_shape_accuracy', **kwargs):
    super().__init__(name=name, **kwargs)
    self.tolerance = tf.constant(tolerance)
    self.correct = self.add_weight(name='correct', initializer='zeros')
    self.seen = self.add_weight(name='seen', initializer='zeros')

  def update_state(self, actual, preds):
    # actual and preds are shape (batch_size, 3)
    is_correct = tf.math.abs(actual - preds) < self.tolerance
    # require all three shape predictions to be accurate to count as correct
    is_correct = tf.experimental.numpy.all(is_correct, axis=1)
    is_correct = tf.cast(is_correct, tf.float32)
    # need to count no. of examples seen but can't use .shape[0] in graph mode
    is_seen = tf.reduce_sum(tf.ones_like(is_correct))
    self.correct.assign_add(tf.reduce_sum(is_correct))
    self.seen.assign_add(is_seen)

  def result(self):
    return self.correct / self.seen


class MetricsInterface(object):
  """Interface for managing metric definition, collection, and updating.
  """

  @abc.abstractmethod
  def __init__(self, data_dir):
    pass

  def setup_metrics(self):
    """Creates a consistent way of managing metrics for each individual latent.

    Returns:
      Dictionary with a key for each axis to be measured.
    """
    raise NotImplementedError()

  def update_metrics(self):
    """Updates the metric values for each axis created in setup_metrics."""
    raise NotImplementedError()

  def setup_summary_writers(self, data_dir, writer_names):
    """Creates a tf summary writer for each name in writer_names.

    Args:
      data_dir: Str, path to folder where summary writers should write to.
      writer_names: List of writer names, e.g. ['train', 'test'] or
        ['eval_overall', 'eval_shape_accuracy', 'eval_position'], etc.

    Returns:
      Dict with (key,value) pairs of the form ('writer_name': writer).
    """
    all_summary_writers = {}
    for name in writer_names:
      log_dir = os.path.join(data_dir, name)
      summary_writer = tf.summary.create_file_writer(log_dir)
      all_summary_writers[name] = summary_writer
    return all_summary_writers

  def write_metrics_to_summary(self, all_metrics, global_step):
    """Updates the summary writers at the end of each step.

    Call this at the end of each step, from within a
    `with summary_writer_name.as_default():` context.

    Args:
      all_metrics: List of tf.keras.metrics objects.
      global_step: Int.
    """
    for metric in all_metrics:
      metric_value = metric.result().numpy().astype(float)
      logging.info('Step: [%d] %s = %f', global_step, metric.name, metric_value)
      tf.summary.scalar(metric.name, metric_value, step=global_step)
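

# Illustrative sketch (assumption): the call pattern described in the
# write_metrics_to_summary docstring. `metrics_handler` is a subclass instance
# (e.g. DspritesEvalMetrics below); each axis' metrics are written from inside
# the matching summary writer's `as_default()` context.
def _example_write_all_metrics(metrics_handler, global_step):
  for name, writer in metrics_handler.summary_writers.items():
    with writer.as_default():
      metrics_handler.write_metrics_to_summary(
          metrics_handler.metrics_dict[name], global_step)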


class DspritesEvalMetrics(MetricsInterface):
  """Handles storing and updating metrics during dsprites evaluation loops.

  Simplifies the process of collecting metrics on multiple individual latents
  as well as overall performance by abstracting it away from the training loop.
  To add a new axis of metric collection: simply specify its name and behaviour
  in setup_metrics and update_metrics, and (optionally) create a separate
  summary writer for it by adding it to writer_names.
  """

  def __init__(self, data_dir, tss):
    super().__init__(data_dir)
    self.writer_names = [
        'eval_overall', 'eval_shapes', 'eval_scale', 'eval_orientation',
        'eval_x_pos', 'eval_y_pos'
    ]
    self.tss = tss
    self.summary_writers = self.setup_summary_writers(data_dir,
                                                      self.writer_names)
    self.metrics_dict = self.setup_metrics()

  def setup_metrics(self):
    """Sets up metrics for dsprites eval loop.

    Returns:
      Dictionary with a key for each axis to be measured (overall performance,
        individual latents, etc).
    """
    metrics_dict = {}
    tss = self.tss
    metrics_dict['eval_overall'] = [tf.keras.metrics.Mean('MSE loss')]
    metrics_dict['eval_shapes'] = self.create_metric_for_latent(
        0.1, tf.reduce_mean(tss[0:3]), is_shape=True)
    metrics_dict['eval_scale'] = self.create_metric_for_latent(
        1 / (2 * 10), tss[3])
    metrics_dict['eval_orientation'] = self.create_metric_for_latent(
        1 / (2 * 40), tss[4])
    metrics_dict['eval_x_pos'] = self.create_metric_for_latent(
        1 / (2 * 32), tss[5])
    metrics_dict['eval_y_pos'] = self.create_metric_for_latent(
        1 / (2 * 32), tss[6])
    return metrics_dict

  def update_metrics(self, total_loss, actual, preds):
    """Updates all metric values for dsprites eval.

    Args:
      total_loss: Float, loss score for current global step.
      actual: 2d array of shape (minibatch_size, 7) of actual latent values.
      preds: 2d array of shape (minibatch_size, 7) of predicted latent values.
    """
    metrics_dict = self.metrics_dict
    for k in metrics_dict:
      if k == 'eval_overall':
        self.update_individual_metrics(metrics_dict[k], total_loss)
      elif k == 'eval_shapes':
        self.update_individual_metrics(
            metrics_dict[k], actual=actual[:, :3], preds=preds[:, :3])
      elif k == 'eval_scale':
        self.update_individual_metrics(
            metrics_dict[k], actual=actual[:, 3], preds=preds[:, 3])
      elif k == 'eval_orientation':
        self.update_individual_metrics(
            metrics_dict[k], actual=actual[:, 4], preds=preds[:, 4])
      elif k == 'eval_x_pos':
        self.update_individual_metrics(
            metrics_dict[k], actual=actual[:, 5], preds=preds[:, 5])
      elif k == 'eval_y_pos':
        self.update_individual_metrics(
            metrics_dict[k], actual=actual[:, 6], preds=preds[:, 6])
      else:
        pass

  def create_metric_for_latent(self, tolerance, tss, is_shape=False):
    """Creates the tf.keras.metrics objects for an individual latent axis.

    Args:
      tolerance: Specifies how close to correct a measurement must be to the
        ground truth, for determining accuracy.
      tss: Total sum of squares value (over entire dataset) for individual
        latent.
      is_shape: Whether to use the DspritesShapeAccuracy metric for accuracy.

    Returns:
      List of tf.keras.metrics objects for latent.
    """
    metrics = []
    metrics.append(tf.keras.metrics.Mean('MSE loss'))
    if is_shape:
      metrics.append(DspritesShapeAccuracy(tolerance, 'accuracy'))
    else:
      metrics.append(DspritesAccuracy(tolerance, 'accuracy'))
    metrics.append(R2Metric(tss, 'R^2'))
    return metrics

  def update_individual_metrics(self,
                                metrics_list,
                                total_loss=None,
                                actual=None,
                                preds=None):
    """Logic for updating individual dsprites eval metrics within a collection.

    Args:
      metrics_list: List of tf.keras.metrics objects.
      total_loss: Optional float, total loss score from current step.
      actual: 2d array, actual latent values for minibatch.
      preds: 2d array, predicted latent values for minibatch.
    """
    for metric in metrics_list:
      if metric.name == 'MSE loss':
        if total_loss is not None:
          metric.update_state(total_loss)
        else:
          mse = (actual - preds)**2
          metric.update_state(mse)
      elif metric.name == 'accuracy':
        metric.update_state(actual, preds)
      elif metric.name == 'R^2':
        metric.update_state(actual, preds)
      else:
        logging.info(
            'Received unknown metric %s, please add desired behaviour to dsprites update_individual_metrics function',
            metric.name)
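

# Illustrative sketch (assumption): one dsprites eval step. `eval_metrics` is a
# DspritesEvalMetrics instance; `actual` and `preds` are (batch_size, 7) latent
# tensors laid out as columns 0-2 shape, 3 scale, 4 orientation, 5 x position,
# 6 y position, and `total_loss` is the scalar loss for the step. All of the
# per-latent slicing and metric bookkeeping happens inside update_metrics.
def _example_dsprites_eval_step(eval_metrics, total_loss, actual, preds):
  eval_metrics.update_metrics(total_loss, actual, preds)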


class DspritesTrainMetrics(MetricsInterface):
  """Handles storing and updating metrics during dsprites train loops.
  """

  def __init__(self, data_dir):
    super().__init__(data_dir)
    self.writer_names = ['train']
    self.summary_writers = self.setup_summary_writers(data_dir,
                                                      self.writer_names)
    self.metrics_dict = self.setup_metrics()

  def setup_metrics(self):
    metrics_dict = {}
    metrics_dict['train'] = [tf.keras.metrics.Mean('MSE loss')]
    return metrics_dict

  def update_metrics(self, total_loss, actual, preds):
    del actual, preds  # not used here
    for k in self.metrics_dict:
      if k == 'train':
        if self.metrics_dict[k][0].name == 'MSE loss':
          self.metrics_dict[k][0].update_state(total_loss)
      else:
        pass


@tf.function
def get_tss_for_r2(strategy, ds, num_classes, num_examples, batch_size=1):
  """Computes dataset-wide stats for use in R^2 computation.

  Args:
    strategy: tf.distribute.Strategy object.
    ds: tf.data.Dataset object.
    num_classes: Int.
    num_examples: Int, number of examples in dataset.
    batch_size: If ds is batched, specify batch size here.

  Returns:
    Tuple (y_bar, tss): arrays of size (7,), containing the average value and
      total sum of squares for each of the seven latents.
  """

  def y_bar_step(x):
    return tf.reduce_sum(x['values'], axis=0)

  def tss_step(x, y_bar):
    return tf.reduce_sum((y_bar - x['values'])**2, axis=0)

  y_bar = tf.zeros(num_classes)
  num_steps = num_examples // batch_size
  ds_iter = iter(ds)
  for _ in tf.range(num_steps):
    x = next(ds_iter)
    per_replica = strategy.run(y_bar_step, args=(x,))
    y_bar += strategy.reduce('SUM', per_replica, axis=None)
  y_bar = y_bar / num_examples
  tss = tf.zeros(num_classes)
  ds_iter = iter(ds)
  for _ in tf.range(num_steps):
    x = next(ds_iter)
    per_replica = strategy.run(tss_step, args=(x, y_bar))
    tss += strategy.reduce('SUM', per_replica, axis=None)
  return y_bar, tss
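

# Illustrative sketch (assumption): wiring get_tss_for_r2 into the eval
# metrics. Dataset elements are assumed to be dicts with a 'values' key holding
# the (batch_size, 7) ground-truth latents, matching y_bar_step/tss_step above.
def _example_build_eval_metrics(strategy, ds, data_dir, num_examples,
                                batch_size):
  _, tss = get_tss_for_r2(
      strategy, ds, num_classes=7, num_examples=num_examples,
      batch_size=batch_size)
  return DspritesEvalMetrics(data_dir, tss)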