google-research
350 lines · 12.3 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Metric util functions for contrastive learning experiments.
17
18Defines custom metrics and metrics helper classes for experiments where we want
19to track performance along multiple dataset-specific axes.
20"""
21
22import abc23import os24
25from absl import flags26from absl import logging27
28import tensorflow.compat.v2 as tf29
30FLAGS = flags.FLAGS31
32
class R2Metric(tf.keras.metrics.Metric):
  """Compute and store a running R^2 score.

  The total sum of squares (TSS) is fixed at construction time; the residual
  sum of squares (RSS) accumulates across `update_state` calls, so `result`
  reports 1 - RSS/TSS for everything seen so far.
  """

  def __init__(self, tss, name='R^2', **kwargs):
    super().__init__(name=name, **kwargs)
    if tf.rank(tss) == 0:
      self.tss = tss
    else:
      # TODO(zeef) Find a way to store TSS and RSS values as arrays.
      # Currently, resetting the metric will throw an error if RSS is not
      # 0-dim, so a vector TSS is collapsed to its mean.
      self.tss = tf.reduce_mean(tss)
    self.rss = self.add_weight(name='rss', initializer='zeros')

  def update_state(self, actual, preds):
    # Average the squared error over every non-batch axis, then fold the
    # per-example values into the running residual total.
    non_batch_axes = tf.range(1, tf.rank(actual))
    per_example_err = tf.reduce_mean(
        tf.math.squared_difference(actual, preds), axis=non_batch_axes)
    self.rss.assign_add(tf.reduce_sum(per_example_err))

  def result(self):
    return 1 - self.rss / self.tss
55
class DspritesAccuracy(tf.keras.metrics.Metric):
  """A measure of correctness for dsprites (non-shape) latents.

  We treat a predicted set of latents as 'correct' if it is closer to the
  correct values than to nearby values.
  """

  def __init__(self, tolerance, name='dsprites_accuracy', **kwargs):
    super().__init__(name=name, **kwargs)
    # Maximum absolute deviation from the ground truth that still counts as
    # a correct prediction.
    self.tolerance = tf.constant(tolerance)
    self.correct = self.add_weight(name='correct', initializer='zeros')
    self.seen = self.add_weight(name='seen', initializer='zeros')

  def update_state(self, actual, preds):
    within_tolerance = tf.cast(
        tf.math.abs(actual - preds) < self.tolerance, tf.float32)
    self.correct.assign_add(tf.reduce_sum(within_tolerance))
    # Count entries seen; .shape[0] is unavailable in graph mode, so count
    # via a same-shaped tensor of ones instead.
    self.seen.assign_add(tf.reduce_sum(tf.ones_like(within_tolerance)))

  def result(self):
    return self.correct / self.seen
80
class DspritesShapeAccuracy(tf.keras.metrics.Metric):
  """A measure of correctness for dsprites shape prediction.

  We treat a prediction as 'correct' if it is close to 1 for the correct shape
  AND close to 0 for the other shapes.
  """

  def __init__(self, tolerance, name='dsprites_shape_accuracy', **kwargs):
    super().__init__(name=name, **kwargs)
    # Maximum absolute deviation from the target shape encoding that still
    # counts as correct.
    self.tolerance = tf.constant(tolerance)
    self.correct = self.add_weight(name='correct', initializer='zeros')
    self.seen = self.add_weight(name='seen', initializer='zeros')

  def update_state(self, actual, preds):
    """Accumulates correct/seen counts for a minibatch.

    Args:
      actual: Tensor of shape (batch_size, 3), target shape values.
      preds: Tensor of shape (batch_size, 3), predicted shape values.
    """
    is_correct = tf.math.abs(actual - preds) < self.tolerance
    # Require all three shape predictions to be accurate to count as correct.
    # tf.reduce_all is the stable, idiomatic equivalent of the previously
    # used tf.experimental.numpy.all.
    is_correct = tf.reduce_all(is_correct, axis=1)
    is_correct = tf.cast(is_correct, tf.float32)
    # Need to count no. of examples seen but can't use .shape[0] in graph mode.
    is_seen = tf.reduce_sum(tf.ones_like(is_correct))
    self.correct.assign_add(tf.reduce_sum(is_correct))
    self.seen.assign_add(is_seen)

  def result(self):
    return self.correct / self.seen
108
class MetricsInterface(abc.ABC):
  """Interface for managing metric definition, collection, and updating.

  Subclassing abc.ABC (instead of object) makes @abc.abstractmethod actually
  enforce the contract: previously the decorator had no effect, so the
  'interface' could be instantiated without implementing anything.
  """

  @abc.abstractmethod
  def __init__(self, data_dir):
    pass

  def setup_metrics(self):
    """Creates a consistent way of managing metrics for each individual latent.

    Returns:
      Dictionary with a key for each axis to be measured.
    """
    raise NotImplementedError()

  def update_metrics(self):
    """Updates the metric values for each axis created in setup_metrics."""
    raise NotImplementedError()

  def setup_summary_writers(self, data_dir, writer_names):
    """Creates a tf summary writer for each name in writer_names.

    Args:
      data_dir: Str, path to folder where summary writers should write to.
      writer_names: List of writer names, e.g. ['train', 'test'] or
        ['eval_overall', 'eval_shape_accuracy', 'eval_position'], etc.

    Returns:
      Dict with (key, value) pairs of the form ('writer_name': writer).
    """
    all_summary_writers = {}
    for name in writer_names:
      # Each writer gets its own subdirectory so TensorBoard shows them as
      # separate runs.
      log_dir = os.path.join(data_dir, name)
      all_summary_writers[name] = tf.summary.create_file_writer(log_dir)
    return all_summary_writers

  def write_metrics_to_summary(self, all_metrics, global_step):
    """Updates the summary writers at the end of each step.

    Call this at the end of each step, from within a
    `with summary_writer_name.as_default():` context.

    Args:
      all_metrics: List of tf.keras.metrics objects.
      global_step: Int.
    """
    for metric in all_metrics:
      metric_value = metric.result().numpy().astype(float)
      logging.info('Step: [%d] %s = %f', global_step, metric.name,
                   metric_value)
      tf.summary.scalar(metric.name, metric_value, step=global_step)
162
class DspritesEvalMetrics(MetricsInterface):
  """Handles storing and updating metrics during dsprites evaluation loops.

  Simplifies the process of collecting metrics on multiple individual latents
  as well as overall performance by abstracting it away from the training loop.
  To add a new axis of metric collection: simply specify its name and behaviour
  in setup_metrics and update_metrics, and (optionally) create a separate
  summary writer for it by adding it to writer_names.
  """

  # Maps each per-latent metric key to the column(s) it occupies in the
  # (minibatch_size, 7) latent arrays. Keeping this in one table (instead of
  # repeating the indices in an if/elif chain) means setup/update can't drift
  # out of sync. 'eval_shapes' spans the three shape columns; the remaining
  # latents are single scalar columns.
  _LATENT_COLUMNS = {
      'eval_shapes': slice(0, 3),
      'eval_scale': 3,
      'eval_orientation': 4,
      'eval_x_pos': 5,
      'eval_y_pos': 6,
  }

  def __init__(self, data_dir, tss):
    """Creates writers and metrics for the eval loop.

    Args:
      data_dir: Str, path to folder where summary writers should write to.
      tss: 1d array of size (7,), total sum of squares per latent (used by
        the R^2 metrics).
    """
    super().__init__(data_dir)
    self.writer_names = [
        'eval_overall', 'eval_shapes', 'eval_scale', 'eval_orientation',
        'eval_x_pos', 'eval_y_pos'
    ]
    self.tss = tss
    self.summary_writers = self.setup_summary_writers(data_dir,
                                                      self.writer_names)
    self.metrics_dict = self.setup_metrics()

  def setup_metrics(self):
    """Sets up metrics for dsprites eval loop.

    Accuracy tolerances are half the assumed spacing between neighbouring
    quantized latent values (e.g. 1/(2*32) for the positions); shape
    predictions use a fixed 0.1 tolerance.

    Returns:
      Dictionary with a key for each axis to be measured (overall performance,
      individual latents, etc).
    """
    metrics_dict = {}
    tss = self.tss
    metrics_dict['eval_overall'] = [tf.keras.metrics.Mean('MSE loss')]
    metrics_dict['eval_shapes'] = self.create_metric_for_latent(
        0.1, tf.reduce_mean(tss[0:3]), is_shape=True)
    metrics_dict['eval_scale'] = self.create_metric_for_latent(
        1 / (2 * 10), tss[3])
    metrics_dict['eval_orientation'] = self.create_metric_for_latent(
        1 / (2 * 40), tss[4])
    metrics_dict['eval_x_pos'] = self.create_metric_for_latent(
        1 / (2 * 32), tss[5])
    metrics_dict['eval_y_pos'] = self.create_metric_for_latent(
        1 / (2 * 32), tss[6])
    return metrics_dict

  def update_metrics(self, total_loss, actual, preds):
    """Updates all metric values for dsprites eval.

    Args:
      total_loss: Float, loss score for current global step.
      actual: 2d array of shape (minibatch_size, 7) of actual latent values.
      preds: 2d array of shape (minibatch_size, 7) of predicted latent values.
    """
    for key, metrics in self.metrics_dict.items():
      if key == 'eval_overall':
        self.update_individual_metrics(metrics, total_loss)
      elif key in self._LATENT_COLUMNS:
        cols = self._LATENT_COLUMNS[key]
        self.update_individual_metrics(
            metrics, actual=actual[:, cols], preds=preds[:, cols])
      # Keys with no registered columns are silently skipped, matching the
      # previous behaviour.

  def create_metric_for_latent(self, tolerance, tss, is_shape=False):
    """Creates the tf.keras.metrics objects for an individual latent axis.

    Args:
      tolerance: Specifies how close to correct a measurement must be to the
        ground truth, for determining accuracy.
      tss: Total sum of squares value (over entire dataset) for individual
        latent.
      is_shape: Whether to use the DspritesShapeAccuracy metric for accuracy.

    Returns:
      List of tf.keras.metrics objects for latent.
    """
    metrics = []
    metrics.append(tf.keras.metrics.Mean('MSE loss'))
    if is_shape:
      metrics.append(DspritesShapeAccuracy(tolerance, 'accuracy'))
    else:
      metrics.append(DspritesAccuracy(tolerance, 'accuracy'))
    metrics.append(R2Metric(tss, 'R^2'))
    return metrics

  def update_individual_metrics(self,
                                metrics_list,
                                total_loss=None,
                                actual=None,
                                preds=None):
    """Logic for updating individual dsprites eval metrics within a collection.

    Args:
      metrics_list: List of tf.keras.metrics objects.
      total_loss: Optional float, total loss score from current step.
      actual: 2d array, actual latent values for minibatch.
      preds: 2d array, predicted latent values for minibatch.
    """
    for metric in metrics_list:
      if metric.name == 'MSE loss':
        if total_loss is not None:
          metric.update_state(total_loss)
        else:
          # Per-latent loss: mean of the elementwise squared error.
          metric.update_state((actual - preds)**2)
      elif metric.name in ('accuracy', 'R^2'):
        # Both custom metrics share the (actual, preds) update signature.
        metric.update_state(actual, preds)
      else:
        # Surface unexpected metrics loudly; INFO level made this too easy
        # to miss.
        logging.warning(
            'Received unknown metric %s, please add desired behaviour to '
            'dsprites update_individual_metrics function', metric.name)
287
class DspritesTrainMetrics(MetricsInterface):
  """Handles storing and updating metrics during dsprites train loops."""

  def __init__(self, data_dir):
    super().__init__(data_dir)
    self.writer_names = ['train']
    self.summary_writers = self.setup_summary_writers(data_dir,
                                                      self.writer_names)
    self.metrics_dict = self.setup_metrics()

  def setup_metrics(self):
    """Creates the single train-loop metric collection (mean MSE loss)."""
    return {'train': [tf.keras.metrics.Mean('MSE loss')]}

  def update_metrics(self, total_loss, actual, preds):
    """Records the step loss; latent values are unused during training."""
    del actual, preds  # not used here
    for writer_key, metrics in self.metrics_dict.items():
      if writer_key != 'train':
        continue
      loss_metric = metrics[0]
      if loss_metric.name == 'MSE loss':
        loss_metric.update_state(total_loss)
313
@tf.function
def get_tss_for_r2(strategy, ds, num_classes, num_examples, batch_size=1):
  """Computes dataset-wide stats for use in R^2 computation.

  Makes two full passes over the dataset: one to compute the per-latent mean
  y_bar, and a second to accumulate the total sum of squares around it.

  Args:
    strategy: tf.distribute.Strategy object.
    ds: tf.data.Dataset object. Each element is expected to carry latent
      values under the 'values' key — confirm against the input pipeline.
    num_classes: Int.
    num_examples: Int, number of examples in dataset.
    batch_size: If ds is batched, specify batch size here.

  Returns:
    Tuple (y_bar, tss): arrays of size (7,), containing the average value and
    total sum of squares for each of the seven latents.
  """

  def y_bar_step(x):
    # Per-replica partial sum of latent values over the batch axis.
    return tf.reduce_sum(x['values'], axis=0)

  def tss_step(x, y_bar):
    # Per-replica partial sum of squared deviations from the dataset mean.
    return tf.reduce_sum((y_bar - x['values'])**2, axis=0)

  # First pass: accumulate the per-latent sum across all steps, then divide
  # by the number of examples to get the mean. NOTE(review): with
  # batch_size > 1, examples in a trailing partial batch are dropped by the
  # floor division — presumably num_examples is divisible by batch_size.
  y_bar = tf.zeros(num_classes)
  num_steps = num_examples // batch_size
  ds_iter = iter(ds)
  for _ in tf.range(num_steps):
    x = next(ds_iter)
    per_replica = strategy.run(y_bar_step, args=(x,))
    y_bar += strategy.reduce('SUM', per_replica, axis=None)
  y_bar = y_bar / num_examples
  # Second pass: re-iterate the dataset and accumulate the total sum of
  # squares around y_bar.
  tss = tf.zeros(num_classes)
  ds_iter = iter(ds)
  for _ in tf.range(num_steps):
    x = next(ds_iter)
    per_replica = strategy.run(tss_step, args=(x, y_bar))
    tss += strategy.reduce('SUM', per_replica, axis=None)
  return y_bar, tss