# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to help set up and run experiments."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os.path

from absl import logging
import numpy as np
import scipy.special
from six.moves import range
from six.moves import zip
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

gfile = tf.io.gfile


class _SimpleJsonEncoder(json.JSONEncoder):

  def default(self, o):
    # Fall back to the object's attribute dict for types JSON can't encode.
    return o.__dict__


def json_dumps(x):
  """Serializes `x` to indented JSON, using `__dict__` for custom objects."""
  return json.dumps(x, indent=2, cls=_SimpleJsonEncoder)


def record_config(config, path):
  """Writes a JSON dump of `config` to `path`, creating parent directories."""
  out = json_dumps(config)
  logging.info('Recording config to %s\n %s', path, out)
  gfile.makedirs(os.path.dirname(path))
  with gfile.GFile(path, 'w') as fh:
    fh.write(out)


def load_config(path):
  """Loads a JSON config previously written by `record_config`."""
  logging.info('Loading config from %s', path)
  with gfile.GFile(path) as fh:
    return json.loads(fh.read())
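
# A minimal round-trip sketch for the config helpers above; the config dict
# and path are hypothetical:
#
#   config = {'learning_rate': 1e-3, 'batch_size': 32}
#   record_config(config, '/tmp/experiment/config.json')
#   restored = load_config('/tmp/experiment/config.json')
#   assert restored['batch_size'] == 32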


def save_model(model, output_dir):
  """Save Keras model weights and architecture as an HDF5 file."""
  save_path = '%s/model.hdf5' % output_dir
  logging.info('Saving model to %s', save_path)
  model.save(save_path, include_optimizer=False)
  return save_path


def load_model(path):
  """Load a Keras model saved by `save_model`."""
  logging.info('Loading model from %s', path)
  return tf.keras.models.load_model(path)
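
# Round-trip sketch for the model helpers; the toy model and output directory
# are hypothetical:
#
#   model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(4,))])
#   path = save_model(model, '/tmp/experiment')  # /tmp/experiment/model.hdf5
#   restored = load_model(path)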


def metrics_from_stats(stats):
  """Compute metrics to report to the hyperparameter tuner."""
  labels, probs = stats['labels'], stats['probs']
  # Reshape binary predictions to 2-class.
  if len(probs.shape) == 1:
    probs = np.stack([1 - probs, probs], axis=-1)
  assert len(probs.shape) == 2

  predictions = np.argmax(probs, axis=-1)
  accuracy = np.equal(labels, predictions)

  # Probability assigned to the true label; the log is clipped at -1e10 so
  # that log(0) does not produce -inf.
  label_probs = probs[np.arange(len(labels)), labels]
  log_probs = np.maximum(-1e10, np.log(label_probs))
  # Multiclass Brier score, dropping the constant +1 term:
  # sum_k (p_k - [k == y])^2 = sum_k p_k^2 - 2 * p_y + 1.
  brier_scores = np.square(probs).sum(-1) - 2 * label_probs

  return {'accuracy': accuracy.mean(0),
          'brier_score': brier_scores.mean(0),
          'log_prob': log_probs.mean(0)}
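
# Worked example for `metrics_from_stats` on synthetic data:
#
#   stats = {'labels': np.array([0, 1]),
#            'probs': np.array([[0.9, 0.1],
#                               [0.4, 0.6]])}
#   metrics_from_stats(stats)
#   # accuracy = 1.0 (both argmax predictions match the labels)
#   # brier_score = mean(0.82 - 1.8, 0.52 - 1.2) = -0.83
#   # log_prob = mean(log(0.9), log(0.6)) ~= -0.31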


def make_predictions(
    model, batched_dataset, predictions_per_example=1, writers=None,
    predictions_are_logits=True, record_image_samples=True, max_batches=1e6):
  """Build a dictionary of predictions for examples from a dataset.

  Args:
    model: Trained Keras model.
    batched_dataset: tf.data.Dataset that yields batches of image, label pairs.
    predictions_per_example: Number of predictions to generate per example.
    writers: `dict` with keys 'small' and 'full', containing
      array_utils.StatsWriter instances for small prediction results (omitting
      the per-example prediction samples) and full prediction results,
      respectively.
    predictions_are_logits: Indicates whether model outputs are logits or
      probabilities.
    record_image_samples: `bool`. Record one batch of input examples.
    max_batches: `int`, maximum number of batches to predict.

  Returns:
    If `writers` is given, results are streamed to the writers and None is
    returned. Otherwise, a dictionary containing:
      labels: Labels copied from the dataset (shape=[N]).
      logits_samples: Samples of model predict outputs for each example
          (shape=[N, M, K]); keyed as 'probs_samples' when
          `predictions_are_logits` is False.
      probs: Probabilities after averaging over samples (shape=[N, K]).
      image_samples: One batch of input images (for sanity checking).
  """
  if predictions_are_logits:
    samples_key = 'logits_samples'
    avg_probs_fn = lambda x: scipy.special.softmax(x, axis=-1).mean(-2)
  else:
    samples_key = 'probs_samples'
    avg_probs_fn = lambda x: x.mean(-2)

  labels, outputs = [], []
  # Support both Keras models and bare callables.
  predict_fn = model.predict if hasattr(model, 'predict') else model
  for i, (inputs_i, labels_i) in enumerate(tfds.as_numpy(batched_dataset)):
    logging.info('iteration: %d', i)
    # Stack M stochastic forward passes along axis 1: shape [batch, M, K].
    outputs_i = np.stack(
        [predict_fn(inputs_i) for _ in range(predictions_per_example)], axis=1)

    if writers is None:
      labels.extend(labels_i)
      outputs.append(outputs_i)
    else:
      avg_probs_i = avg_probs_fn(outputs_i)
      prediction_batch = dict(labels=labels_i, probs=avg_probs_i)
      if i == 0 and record_image_samples:
        prediction_batch['image_samples'] = inputs_i

      writers['small'].write_batch(prediction_batch)
      prediction_batch[samples_key] = outputs_i
      writers['full'].write_batch(prediction_batch)

    # Don't predict the whole ImageNet training set.
    if i + 1 >= max_batches:
      break

  if writers is None:
    image_samples = inputs_i  # pylint: disable=undefined-loop-variable
    labels = np.stack(labels, axis=0)
    outputs = np.concatenate(outputs, axis=0)

    stats = {'labels': labels,
             samples_key: outputs, 'probs': avg_probs_fn(outputs)}
    if record_image_samples:
      stats['image_samples'] = image_samples
    return stats
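
# Usage sketch for `make_predictions`; the dataset and stochastic model are
# hypothetical stand-ins (e.g. a Keras model with dropout enabled at test
# time):
#
#   ds = tfds.load('cifar10', split='test', as_supervised=True).batch(128)
#   stats = make_predictions(model, ds, predictions_per_example=8)
#   metrics = metrics_from_stats(stats)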


def download_dataset(dataset, batch_size_for_dl=1024):
  """Materialize a tf.data.Dataset into a tuple of in-memory numpy arrays."""
  logging.info('Starting dataset download...')
  # Batch for throughput, then regroup per component and concatenate batches.
  tup = list(zip(*tfds.as_numpy(dataset.batch(batch_size_for_dl))))
  logging.info('Dataset download complete.')
  return tuple(np.concatenate(x, axis=0) for x in tup)
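
# Sketch: pull a small TFDS split fully into memory; the dataset name is
# illustrative:
#
#   ds = tfds.load('mnist', split='test', as_supervised=True)
#   images, labels = download_dataset(ds)
#   # images.shape == (10000, 28, 28, 1), labels.shape == (10000,)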


def get_distribution_strategy(distribution_strategy='default',
                              num_gpus=0,
                              num_workers=1,
                              all_reduce_alg=None,
                              num_packs=1):
  """Return a DistributionStrategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are 'off', 'default', 'one_device', 'mirrored',
      'parameter_server', 'multi_worker_mirrored', case insensitive. 'off'
      means not to use Distribution Strategy; 'default' means to choose from
      `MirroredStrategy`, `MultiWorkerMirroredStrategy`, or `OneDeviceStrategy`
      according to the number of GPUs and number of workers.
    num_gpus: Number of GPUs to run this model.
    num_workers: Number of workers to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are 'nccl' and
      'hierarchical_copy'. For `MultiWorkerMirroredStrategy`, valid values are
      'ring' and 'nccl'.  If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional.  Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.

  Returns:
    tf.distribute.DistributionStrategy object.
  Raises:
    ValueError: if `distribution_strategy` is 'off' or 'one_device' and
      `num_gpus` is larger than 1; or `num_gpus` is negative.
  """
  if num_gpus < 0:
    raise ValueError('`num_gpus` cannot be negative.')

  distribution_strategy = distribution_strategy.lower()
  if distribution_strategy == 'off':
    if num_gpus > 1:
      raise ValueError(
          'When {} GPUs and {} workers are specified, distribution_strategy '
          'flag cannot be set to "off".'.format(num_gpus, num_workers))
    return None

  if distribution_strategy == 'multi_worker_mirrored':
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if (distribution_strategy == 'one_device' or
      (distribution_strategy == 'default' and num_gpus <= 1)):
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy('device:CPU:0')
    else:
      if num_gpus > 1:
        raise ValueError('`OneDeviceStrategy` cannot be used for more than '
                         'one device.')
      return tf.distribute.OneDeviceStrategy('device:GPU:0')

  if distribution_strategy in ('mirrored', 'default'):
    if num_gpus == 0:
      assert distribution_strategy == 'mirrored'
      devices = ['device:CPU:0']
    else:
      devices = ['device:GPU:%d' % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=devices,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if distribution_strategy == 'parameter_server':
    return tf.compat.v1.distribute.experimental.ParameterServerStrategy()

  raise ValueError(
      'Unrecognized Distribution Strategy: %r' % distribution_strategy)
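
# Usage sketch: build a strategy and construct the model under its scope; the
# `build_model` helper is hypothetical:
#
#   strategy = get_distribution_strategy('mirrored', num_gpus=2,
#                                        all_reduce_alg='nccl', num_packs=2)
#   with strategy.scope():
#     model = build_model()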


def _collective_communication(all_reduce_alg):
  """Return a CollectiveCommunication based on all_reduce_alg.

  Args:
    all_reduce_alg: a string specifying which collective communication to pick,
      or None.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object

  Raises:
    ValueError: if `all_reduce_alg` not in [None, 'ring', 'nccl']
  """
  collective_communication_options = {
      None: tf.distribute.experimental.CollectiveCommunication.AUTO,
      'ring': tf.distribute.experimental.CollectiveCommunication.RING,
      'nccl': tf.distribute.experimental.CollectiveCommunication.NCCL
  }
  if all_reduce_alg not in collective_communication_options:
    raise ValueError(
        'When used with `multi_worker_mirrored`, valid values for '
        'all_reduce_alg are ["ring", "nccl"].  Supplied value: {}'.format(
            all_reduce_alg))
  return collective_communication_options[all_reduce_alg]


def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
  """Return a CrossDeviceOps based on all_reduce_alg and num_packs.

  Args:
    all_reduce_alg: a string specifying which cross device op to pick, or None.
    num_packs: an integer specifying number of packs for the cross device op.

  Returns:
    tf.distribute.CrossDeviceOps object or None.

  Raises:
    ValueError: if `all_reduce_alg` not in [None, 'nccl', 'hierarchical_copy'].
  """
  if all_reduce_alg is None:
    return None
  mirrored_all_reduce_options = {
      'nccl': tf.distribute.NcclAllReduce,
      'hierarchical_copy': tf.distribute.HierarchicalCopyAllReduce
  }
  if all_reduce_alg not in mirrored_all_reduce_options:
    raise ValueError(
        'When used with `mirrored`, valid values for all_reduce_alg are '
        '["nccl", "hierarchical_copy"].  Supplied value: {}'.format(
            all_reduce_alg))
  cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
  return cross_device_ops_class(num_packs=num_packs)
