1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities to help set up and run experiments."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22import json23import os.path24
25from absl import logging26import numpy as np27import scipy.special28from six.moves import range29from six.moves import zip30import tensorflow.compat.v2 as tf31import tensorflow_datasets as tfds32
33gfile = tf.io.gfile34
35
36class _SimpleJsonEncoder(json.JSONEncoder):37
38def default(self, o):39return o.__dict__40
41
42def json_dumps(x):43return json.dumps(x, indent=2, cls=_SimpleJsonEncoder)44
45
46def record_config(config, path):47out = json_dumps(config)48logging.info('Recording config to %s\n %s', path, out)49gfile.makedirs(os.path.dirname(path))50with gfile.GFile(path, 'w') as fh:51fh.write(out)52
53
54def load_config(path):55logging.info('Loading config from %s', path)56with gfile.GFile(path) as fh:57return json.loads(fh.read())58
59
60def save_model(model, output_dir):61"""Save Keras model weights and architecture as HDF5 file."""62save_path = '%s/model.hdf5' % output_dir63logging.info('Saving model to %s', save_path)64model.save(save_path, include_optimizer=False)65return save_path66
67
68def load_model(path):69logging.info('Loading model from %s', path)70return tf.keras.models.load_model(path)71
72
def metrics_from_stats(stats):
  """Compute metrics to report to hyperparameter tuner.

  Args:
    stats: dict with 'labels' (int array, shape [N]) and 'probs'
      (shape [N] for binary positive-class probabilities, or [N, K]
      for K-class probabilities).

  Returns:
    dict with scalar 'accuracy', 'brier_score', and 'log_prob', each
    averaged over the N examples.
  """
  labels = stats['labels']
  probs = stats['probs']
  if probs.ndim == 1:
    # Expand binary positive-class probabilities to explicit 2-class form.
    probs = np.stack([1 - probs, probs], axis=-1)
  assert probs.ndim == 2

  correct = labels == np.argmax(probs, axis=-1)

  # Probability each example assigns to its true label.
  label_probs = probs[np.arange(len(labels)), labels]
  # Clamp log(0) = -inf to a large finite value so the mean stays finite.
  log_probs = np.maximum(np.log(label_probs), -1e10)
  brier_scores = np.sum(np.square(probs), axis=-1) - 2 * label_probs

  return {
      'accuracy': correct.mean(0),
      'brier_score': brier_scores.mean(0),
      'log_prob': log_probs.mean(0),
  }
92
def make_predictions(
    model, batched_dataset, predictions_per_example=1, writers=None,
    predictions_are_logits=True, record_image_samples=True, max_batches=1e6):
  """Build a dictionary of predictions for examples from a dataset.

  Args:
    model: Trained Keras model (or any callable mapping inputs to outputs).
    batched_dataset: tf.data.Dataset that yields batches of image, label pairs.
    predictions_per_example: Number of predictions to generate per example.
    writers: `dict` with keys 'small' and 'full', containing
      array_utils.StatsWriter instances for full prediction results and small
      prediction results (omitting logits). If provided, results are streamed
      to the writers and this function returns None.
    predictions_are_logits: Indicates whether model outputs are logits or
      probabilities.
    record_image_samples: `bool` Record one batch of input examples.
    max_batches: `int`, maximum number of batches to predict.

  Returns:
    None if `writers` is given; otherwise a dictionary containing:
      labels: Labels copied from the dataset (shape=[N]).
      logits_samples (or probs_samples): Samples of model predict outputs for
        each example (shape=[N, M, K]).
      probs: Probabilities after averaging over samples (shape=[N, K]).
      image_samples: One batch of input images (only when
        `record_image_samples` is True).
  """
  if predictions_are_logits:
    samples_key = 'logits_samples'
    avg_probs_fn = lambda x: scipy.special.softmax(x, axis=-1).mean(-2)
  else:
    samples_key = 'probs_samples'
    avg_probs_fn = lambda x: x.mean(-2)

  labels, outputs = [], []
  predict_fn = model.predict if hasattr(model, 'predict') else model
  for i, (inputs_i, labels_i) in enumerate(tfds.as_numpy(batched_dataset)):
    logging.info('iteration: %d', i)
    # Stack M stochastic forward passes along a new "samples" axis.
    outputs_i = np.stack(
        [predict_fn(inputs_i) for _ in range(predictions_per_example)], axis=1)

    if writers is None:
      labels.extend(labels_i)
      outputs.append(outputs_i)
    else:
      avg_probs_i = avg_probs_fn(outputs_i)
      prediction_batch = dict(labels=labels_i, probs=avg_probs_i)
      if i == 0 and record_image_samples:
        prediction_batch['image_samples'] = inputs_i

      writers['small'].write_batch(prediction_batch)
      # The 'full' writer additionally receives the per-sample outputs.
      prediction_batch[samples_key] = outputs_i
      writers['full'].write_batch(prediction_batch)

    # Don't predict whole ImageNet training set.
    # BUGFIX: was `i > max_batches`, which processed max_batches + 2 batches
    # before stopping; now stops after exactly max_batches batches.
    if i + 1 >= max_batches:
      break

  if writers is None:
    labels = np.stack(labels, axis=0)
    outputs = np.concatenate(outputs, axis=0)
    stats = {'labels': labels,
             samples_key: outputs,
             'probs': avg_probs_fn(outputs)}
    if record_image_samples:
      # BUGFIX: previously 'image_samples' was always included in the result,
      # ignoring record_image_samples=False.
      stats['image_samples'] = inputs_i  # pylint: disable=undefined-loop-variable
    return stats
159
160def download_dataset(dataset, batch_size_for_dl=1024):161logging.info('Starting dataset download...')162tup = list(zip(*tfds.as_numpy(dataset.batch(batch_size_for_dl))))163logging.info('dataset download complete.')164return tuple(np.concatenate(x, axis=0) for x in tup)165
166
def get_distribution_strategy(distribution_strategy='default',
                              num_gpus=0,
                              num_workers=1,
                              all_reduce_alg=None,
                              num_packs=1):
  """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are 'off', 'default', 'one_device', 'mirrored',
      'parameter_server', 'multi_worker_mirrored', case insensitive. 'off'
      means not to use Distribution Strategy; 'default' means to choose from
      `MirroredStrategy`, `MultiWorkerMirroredStrategy`, or `OneDeviceStrategy`
      according to the number of GPUs and number of workers.
    num_gpus: Number of GPUs to run this model.
    num_workers: Number of workers to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are 'nccl' and
      'hierarchical_copy'. For `MultiWorkerMirroredStrategy`, valid values are
      'ring' and 'nccl'. If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.

  Returns:
    tf.distribute.DistibutionStrategy object, or None when strategy is 'off'.

  Raises:
    ValueError: if `distribution_strategy` is 'off' or 'one_device' and
      `num_gpus` is larger than 1; or `num_gpus` is negative; or the
      strategy name is unrecognized.
  """
  if num_gpus < 0:
    raise ValueError('`num_gpus` can not be negative.')

  strategy = distribution_strategy.lower()

  if strategy == 'off':
    if num_gpus > 1:
      raise ValueError(
          'When {} GPUs and {} workers are specified, distribution_strategy '
          'flag cannot be set to "off".'.format(num_gpus, num_workers))
    return None

  if strategy == 'multi_worker_mirrored':
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if strategy == 'one_device' or (strategy == 'default' and num_gpus <= 1):
    if num_gpus > 1:
      raise ValueError('`OneDeviceStrategy` can not be used for more than '
                       'one device.')
    device = 'device:CPU:0' if num_gpus == 0 else 'device:GPU:0'
    return tf.distribute.OneDeviceStrategy(device)

  if strategy in ('mirrored', 'default'):
    if num_gpus == 0:
      # 'default' with num_gpus <= 1 was handled above, so only an explicit
      # 'mirrored' request can reach here with zero GPUs.
      assert strategy == 'mirrored'
      devices = ['device:CPU:0']
    else:
      devices = ['device:GPU:%d' % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if strategy == 'parameter_server':
    return tf.compat.v1.distribute.experimental.ParameterServerStrategy()

  raise ValueError(
      'Unrecognized Distribution Strategy: %r' % strategy)
238
def _collective_communication(all_reduce_alg):
  """Return a CollectiveCommunication based on all_reduce_alg.

  Args:
    all_reduce_alg: a string specifying which collective communication to
      pick, or None to let TensorFlow choose automatically.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object

  Raises:
    ValueError: if `all_reduce_alg` not in [None, 'ring', 'nccl']
  """
  name_to_option = {
      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
      'ring': tf.distribute.experimental.CollectiveCommunication.RING,
      'nccl': tf.distribute.experimental.CollectiveCommunication.NCCL,
  }
  if all_reduce_alg not in name_to_option:
    raise ValueError(
        'When used with `multi_worker_mirrored`, valid values for '
        'all_reduce_alg are ["ring", "nccl"]. Supplied value: {}'.format(
            all_reduce_alg))
  return name_to_option[all_reduce_alg]
264
265def _mirrored_cross_device_ops(all_reduce_alg, num_packs):266"""Return a CrossDeviceOps based on all_reduce_alg and num_packs.267
268Args:
269all_reduce_alg: a string specifying which cross device op to pick, or None.
270num_packs: an integer specifying number of packs for the cross device op.
271
272Returns:
273tf.distribute.CrossDeviceOps object or None.
274
275Raises:
276ValueError: if `all_reduce_alg` not in [None, 'nccl', 'hierarchical_copy'].
277"""
278if all_reduce_alg is None:279return None280mirrored_all_reduce_options = {281'nccl': tf.distribute.NcclAllReduce,282'hierarchical_copy': tf.distribute.HierarchicalCopyAllReduce283}284if all_reduce_alg not in mirrored_all_reduce_options:285raise ValueError(286'When used with `mirrored`, valid values for all_reduce_alg are '287'["nccl", "hierarchical_copy"]. Supplied value: {}'.format(288all_reduce_alg))289cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]290return cross_device_ops_class(num_packs=num_packs)291