google-research

Форк
0
162 строки · 5.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Build and train MNIST models for UQ experiments."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import collections
23

24
import attr
25
import numpy as np
26
import scipy.special
27
import six
28
from six.moves import range
29
import tensorflow.compat.v2 as tf
30
import tensorflow_datasets as tfds
31
from uq_benchmark_2019 import uq_utils
32
# Module-level alias for the Keras API bundled with TensorFlow; the model
# builders and the training/prediction code below refer to `keras.*` via it.
keras = tf.keras
33

34
# Width of the final logits layer built by the model builders.
_NUM_CLASSES = 10
# Per-example image shape passed to keras.layers.Input.
_MNIST_SHAPE = (28, 28, 1)
# Threshold used by make_predictions when deciding whether to record more
# input images in its output.
_NUM_IMAGE_EXAMPLES_TO_RECORD = 32
# Batch size used when iterating over the dataset in make_predictions.
_BATCH_SIZE_FOR_PREDICT = 1024

# Architecture names; these match the keys dispatched on by build_model.
ARCHITECTURES = [
    'mlp',
    'lenet',
]
# Uncertainty-method names; presumably the values accepted for
# ModelOptions.method (forwarded to uq_utils.get_layer_builders) -- verify.
METHODS = [
    'vanilla',
    'dropout',
    'svi',
    'll_dropout',
    'll_svi',
]
41

42

43
@attr.s
class ModelOptions(object):
  """Parameters for model construction and fitting.

  Read by build_model, build_and_train and make_predictions in this module.
  """
  # Number of epochs passed to keras.Model.fit.
  train_epochs = attr.ib()
  # Expected size of the training set; checked in build_and_train and
  # forwarded to uq_utils.get_layer_builders.
  num_train_examples = attr.ib()
  # Minibatch size passed to keras.Model.fit.
  batch_size = attr.ib()
  # Learning rate of the Adam optimizer used in build_and_train.
  learning_rate = attr.ib()
  # Uncertainty method name forwarded to uq_utils.get_layer_builders
  # (see METHODS).
  method = attr.ib()
  # Network architecture name used by build_model (see ARCHITECTURES).
  architecture = attr.ib()
  # Iterable of hidden-layer widths; only used by the 'mlp' architecture.
  mlp_layer_sizes = attr.ib()
  # Dropout rate forwarded to uq_utils.get_layer_builders.
  dropout_rate = attr.ib()
  # If truthy, make_predictions only uses this many leading examples.
  num_examples_for_predict = attr.ib()
  # Number of repeated model.predict calls per batch in make_predictions.
  predictions_per_example = attr.ib()
56

57

58
def _build_mlp(opts):
  """Builds a multi-layer perceptron Keras model.

  Args:
    opts: ModelOptions; uses `method`, `dropout_rate`, `num_train_examples`
      (to obtain layer builders from uq_utils) and `mlp_layer_sizes`.
  Returns:
    Uncompiled keras.Model mapping inputs of shape _MNIST_SHAPE to
    _NUM_CLASSES logits.
  """
  layer_builders = uq_utils.get_layer_builders(opts.method, opts.dropout_rate,
                                               opts.num_train_examples)
  # The first builder (conv2d) is not needed for an MLP.
  _, dense_layer, dense_last, dropout_fn, dropout_fn_last = layer_builders

  inputs = keras.layers.Input(_MNIST_SHAPE)
  net = keras.layers.Flatten(input_shape=_MNIST_SHAPE)(inputs)
  for size in opts.mlp_layer_sizes:
    # dropout_fn comes from uq_utils and depends on opts.method; it is
    # applied in front of every hidden layer.
    net = dropout_fn(net)
    net = dense_layer(size, activation='relu')(net)
  net = dropout_fn_last(net)
  logits = dense_last(_NUM_CLASSES)(net)
  return keras.Model(inputs=inputs, outputs=logits)
72

73

74
def _build_lenet(opts):
  """Builds a LeNet Keras model.

  Layer stack: conv2d(32, 3x3, relu) -> conv2d(64, 3x3, relu) ->
  2x2 max pooling -> dropout_fn -> flatten -> dense(128, relu) ->
  dropout_fn_last -> dense_last(_NUM_CLASSES). The conv/dense/dropout
  implementations are chosen by uq_utils from opts.method.

  Args:
    opts: ModelOptions; uses `method`, `dropout_rate` and
      `num_train_examples` to obtain layer builders from uq_utils.
  Returns:
    Uncompiled keras.Model mapping inputs of shape _MNIST_SHAPE to
    _NUM_CLASSES logits.
  """
  layer_builders = uq_utils.get_layer_builders(opts.method, opts.dropout_rate,
                                               opts.num_train_examples)
  conv2d, dense_layer, dense_last, dropout_fn, dropout_fn_last = layer_builders

  inputs = keras.layers.Input(_MNIST_SHAPE)
  net = inputs
  net = conv2d(32, kernel_size=(3, 3),
               activation='relu',
               input_shape=_MNIST_SHAPE)(net)
  net = conv2d(64, (3, 3), activation='relu')(net)
  net = keras.layers.MaxPooling2D(pool_size=(2, 2))(net)
  net = dropout_fn(net)
  net = keras.layers.Flatten()(net)
  net = dense_layer(128, activation='relu')(net)
  net = dropout_fn_last(net)
  logits = dense_last(_NUM_CLASSES)(net)
  return keras.Model(inputs=inputs, outputs=logits)
93

94

95
def build_model(opts):
  """Builds (uncompiled) Keras model from ModelOptions instance.

  Args:
    opts: ModelOptions; `opts.architecture` selects the builder and must be
      one of ARCHITECTURES.
  Returns:
    Uncompiled keras.Model.
  Raises:
    KeyError: if `opts.architecture` is not a known architecture.
  """
  builders = {'mlp': _build_mlp, 'lenet': _build_lenet}
  if opts.architecture not in builders:
    # KeyError keeps compatibility with callers of the plain dict lookup,
    # while the message now lists the valid choices.
    raise KeyError('Unknown architecture %r; expected one of %s.' %
                   (opts.architecture, sorted(builders)))
  return builders[opts.architecture](opts)
98

99

100
def build_and_train(opts, dataset_train, dataset_eval, output_dir):
  """Returns a trained MNIST model, logging TensorBoard events to output_dir.

  Args:
    opts: ModelOptions
    dataset_train: Pair of images, labels np.ndarrays for training.
    dataset_eval: Pair of images, labels np.ndarrays for continuous eval.
    output_dir: Directory for TensorBoard event files. The model itself is
      not saved by this function.
  Returns:
    Trained Keras model.
  Raises:
    ValueError: if the number of training images differs from
      opts.num_train_examples.
  """
  train_images, train_labels = dataset_train
  # Validate up front with an explicit raise: an `assert` is stripped under
  # `python -O`. num_train_examples is also forwarded to
  # uq_utils.get_layer_builders, so it must match the actual data.
  if len(train_images) != opts.num_train_examples:
    raise ValueError('Expected %d training examples, got %d.' %
                     (opts.num_train_examples, len(train_images)))

  model = build_model(opts)
  model.compile(
      keras.optimizers.legacy.Adam(opts.learning_rate),
      # The builders' outputs are treated as logits, not probabilities.
      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'],
  )

  tensorboard_cb = keras.callbacks.TensorBoard(
      log_dir=output_dir, write_graph=False)

  model.fit(
      train_images, train_labels,
      epochs=opts.train_epochs,
      # NOTE: steps_per_epoch will cause OOM for some reason.
      validation_data=dataset_eval,
      batch_size=opts.batch_size,
      callbacks=[tensorboard_cb],
  )
  return model
133

134

135
def make_predictions(opts, model, dataset):
  """Build a dictionary of model predictions on a given dataset.

  Args:
    opts: ModelOptions; uses `num_examples_for_predict` and
      `predictions_per_example`.
    model: Trained Keras model whose outputs are logits.
    dataset: Pair of (images, labels) arrays, sliceable along the first axis
      and accepted by tf.data.Dataset.from_tensor_slices.
  Returns:
    Dictionary of stacked np.ndarrays:
      'labels': label of each example.
      'logits_samples': [num_examples, predictions_per_example, num_classes]
        logits from repeated model.predict calls.
      'probs': softmax probabilities averaged over the samples axis.
      'image_examples': at most _NUM_IMAGE_EXAMPLES_TO_RECORD input images.
    If the dataset is empty no values are collected and the dict is empty.
  """
  if opts.num_examples_for_predict:
    dataset = tuple(x[:opts.num_examples_for_predict] for x in dataset)

  batched_dataset = (tf.data.Dataset.from_tensor_slices(dataset)
                     .batch(_BATCH_SIZE_FOR_PREDICT))
  out = collections.defaultdict(list)
  for images, labels in tfds.as_numpy(batched_dataset):
    # One model.predict call per sample; stacked along a new samples axis.
    logits_samples = np.stack(
        [model.predict(images) for _ in range(opts.predictions_per_example)],
        axis=1)  # shape: [batch_size, num_samples, num_classes]
    # Convert each sample to probabilities, then average over samples.
    probs = scipy.special.softmax(logits_samples, axis=-1).mean(-2)
    out['labels'].extend(labels)
    out['logits_samples'].extend(logits_samples)
    out['probs'].extend(probs)
    # Take only the images still needed; extending with the whole batch
    # would store up to _BATCH_SIZE_FOR_PREDICT images instead of
    # _NUM_IMAGE_EXAMPLES_TO_RECORD.
    num_missing = _NUM_IMAGE_EXAMPLES_TO_RECORD - len(out['image_examples'])
    if num_missing > 0:
      out['image_examples'].extend(images[:num_missing])

  return {k: np.stack(a) for k, a in six.iteritems(out)}
163

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.