# Source: google-research repository (162 lines, 5.4 KB).
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Build and train MNIST models for UQ experiments."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23
24import attr
25import numpy as np
26import scipy.special
27import six
28from six.moves import range
29import tensorflow.compat.v2 as tf
30import tensorflow_datasets as tfds
31from uq_benchmark_2019 import uq_utils
# Shorthand for the Keras API of the TF2 compat import above.
keras = tf.keras

# MNIST: ten digit classes, 28x28 single-channel images.
_NUM_CLASSES = 10
_MNIST_SHAPE = (28, 28, 1)
# Cap on raw input images recorded alongside predictions (see make_predictions).
_NUM_IMAGE_EXAMPLES_TO_RECORD = 32
# Batch size used only at prediction time (see make_predictions).
_BATCH_SIZE_FOR_PREDICT = 1024

# Valid values for ModelOptions.architecture (dispatched in build_model).
ARCHITECTURES = ['mlp', 'lenet']
# Valid UQ methods, forwarded to uq_utils.get_layer_builders.
METHODS = ['vanilla', 'dropout', 'svi', 'll_dropout', 'll_svi']
41
42
@attr.s
class ModelOptions(object):
  """Parameters for model construction and fitting."""
  # Number of passes over the training data (epochs arg to model.fit).
  train_epochs = attr.ib()
  # Expected size of the training set; asserted against in build_and_train.
  num_train_examples = attr.ib()
  # Minibatch size used during training.
  batch_size = attr.ib()
  # Learning rate for the Adam optimizer.
  learning_rate = attr.ib()
  # UQ method name; presumably one of METHODS — forwarded to
  # uq_utils.get_layer_builders.
  method = attr.ib()
  # Model architecture: 'mlp' or 'lenet' (keys of the build_model dispatch).
  architecture = attr.ib()
  # Sequence of hidden-layer widths; only used by the 'mlp' architecture.
  mlp_layer_sizes = attr.ib()
  # Dropout probability, forwarded to uq_utils.get_layer_builders.
  dropout_rate = attr.ib()
  # If truthy, truncate the prediction dataset to this many examples.
  num_examples_for_predict = attr.ib()
  # Number of stochastic forward passes per example at prediction time.
  predictions_per_example = attr.ib()
56
57
def _build_mlp(opts):
  """Constructs an uncompiled multi-layer-perceptron Keras classifier.

  Args:
    opts: ModelOptions with method, dropout_rate, num_train_examples and
      mlp_layer_sizes set.
  Returns:
    keras.Model mapping MNIST images to class logits.
  """
  (_, make_dense, make_dense_last,
   apply_dropout, apply_dropout_last) = uq_utils.get_layer_builders(
       opts.method, opts.dropout_rate, opts.num_train_examples)

  image_input = keras.layers.Input(_MNIST_SHAPE)
  hidden = keras.layers.Flatten(input_shape=_MNIST_SHAPE)(image_input)
  # Each hidden layer is preceded by the method-specific dropout variant.
  for width in opts.mlp_layer_sizes:
    hidden = apply_dropout(hidden)
    hidden = make_dense(width, activation='relu')(hidden)
  hidden = apply_dropout_last(hidden)
  # Final layer emits raw logits; loss is configured with from_logits=True.
  logits = make_dense_last(_NUM_CLASSES)(hidden)
  return keras.Model(inputs=image_input, outputs=logits)
72
73
def _build_lenet(opts):
  """Constructs an uncompiled LeNet-style convolutional Keras classifier.

  Args:
    opts: ModelOptions with method, dropout_rate and num_train_examples set.
  Returns:
    keras.Model mapping MNIST images to class logits.
  """
  (make_conv, make_dense, make_dense_last,
   apply_dropout, apply_dropout_last) = uq_utils.get_layer_builders(
       opts.method, opts.dropout_rate, opts.num_train_examples)

  image_input = keras.layers.Input(_MNIST_SHAPE)
  hidden = make_conv(32, kernel_size=(3, 3),
                     activation='relu',
                     input_shape=_MNIST_SHAPE)(image_input)
  hidden = make_conv(64, (3, 3), activation='relu')(hidden)
  hidden = keras.layers.MaxPooling2D(pool_size=(2, 2))(hidden)
  hidden = apply_dropout(hidden)
  hidden = keras.layers.Flatten()(hidden)
  hidden = make_dense(128, activation='relu')(hidden)
  hidden = apply_dropout_last(hidden)
  # Final layer emits raw logits; loss is configured with from_logits=True.
  logits = make_dense_last(_NUM_CLASSES)(hidden)
  return keras.Model(inputs=image_input, outputs=logits)
93
94
def build_model(opts):
  """Builds an (uncompiled) Keras model from a ModelOptions instance.

  Args:
    opts: ModelOptions; opts.architecture selects the builder.
  Returns:
    keras.Model produced by the matching builder.
  Raises:
    KeyError: if opts.architecture is not a known architecture.
  """
  builders = {'mlp': _build_mlp, 'lenet': _build_lenet}
  builder = builders[opts.architecture]
  return builder(opts)
98
99
def build_and_train(opts, dataset_train, dataset_eval, output_dir):
  """Returns a trained MNIST model and saves it to output_dir.

  Args:
    opts: ModelOptions.
    dataset_train: Pair of images, labels np.ndarrays for training.
    dataset_eval: Pair of images, labels np.ndarrays for continuous eval.
    output_dir: Directory for the saved model and tensorboard events.
  Returns:
    Trained Keras model.
  """
  model = build_model(opts)
  optimizer = keras.optimizers.legacy.Adam(opts.learning_rate)
  loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  model.compile(optimizer, loss=loss_fn, metrics=['accuracy'])

  callbacks = [
      keras.callbacks.TensorBoard(log_dir=output_dir, write_graph=False)
  ]

  images, labels = dataset_train
  # Guard against a mismatch between the options and the actual data size.
  assert len(images) == opts.num_train_examples, (
      '%d != %d' % (len(images), opts.num_train_examples))
  model.fit(
      images, labels,
      epochs=opts.train_epochs,
      # NOTE: steps_per_epoch will cause OOM for some reason.
      validation_data=dataset_eval,
      batch_size=opts.batch_size,
      callbacks=callbacks,
  )
  return model
133
134
def make_predictions(opts, model, dataset):
  """Builds a dictionary of model predictions on a given dataset.

  Args:
    opts: ModelOptions.
    model: Trained Keras model.
    dataset: Pair of (images, labels) arrays; each component is sliced and
      then batched via tf.data.Dataset.from_tensor_slices.
  Returns:
    Dict with 'labels', 'logits_samples' (shape
    [num_examples, num_samples, num_classes]), 'probs', and a small
    'image_examples' sample, each stacked into an np.ndarray.
  """
  if opts.num_examples_for_predict:
    # Truncate every component of the dataset pair to the requested count.
    dataset = tuple(x[:opts.num_examples_for_predict] for x in dataset)

  batched = (tf.data.Dataset.from_tensor_slices(dataset)
             .batch(_BATCH_SIZE_FOR_PREDICT))
  results = collections.defaultdict(list)
  for images, labels in tfds.as_numpy(batched):
    # One stochastic forward pass per sample, stacked along axis 1:
    # shape [batch_size, num_samples, num_classes].
    per_sample_logits = [
        model.predict(images) for _ in range(opts.predictions_per_example)
    ]
    logits_samples = np.stack(per_sample_logits, axis=1)
    # Average the per-sample softmax distributions over the samples axis.
    probs = scipy.special.softmax(logits_samples, axis=-1).mean(-2)
    results['labels'].extend(labels)
    results['logits_samples'].extend(logits_samples)
    results['probs'].extend(probs)
    # Record only a bounded number of raw images for later inspection.
    if len(results['image_examples']) < _NUM_IMAGE_EXAMPLES_TO_RECORD:
      results['image_examples'].extend(images)

  return {key: np.stack(values) for key, values in results.items()}
163