google-research
78 строк · 3.3 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Configures and runs distributional-skew UQ experiments on MNIST."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from absl import logging
23import tensorflow.compat.v2 as tf
24
25from uq_benchmark_2019 import array_utils
26from uq_benchmark_2019 import experiment_utils
27from uq_benchmark_2019.mnist import data_lib
28from uq_benchmark_2019.mnist import hparams_lib
29from uq_benchmark_2019.mnist import models_lib
30gfile = tf.io.gfile
31
32
def get_experiment_config(method, architecture,
                          test_level, output_dir=None):
  """Builds the model options and the list of dataset options for a run.

  Args:
    method: Name of the modeling method; forwarded to
      hparams_lib.get_tuned_model_options.
    architecture: Name of the DNN architecture; forwarded to
      hparams_lib.get_tuned_model_options.
    test_level: Testing level. Any truthy value keeps only the first four
      entries of data_lib.DATA_OPTIONS_LIST. Values above zero request fake
      training, and values above one additionally request fake data.
    output_dir: Optional output directory. When truthy, the model options are
      recorded to '<output_dir>/model_options.json'.

  Returns:
    A (model_opts, data_opts_list) tuple.
  """
  all_data_opts = data_lib.DATA_OPTIONS_LIST
  # Test runs only cover a short prefix of the configured datasets.
  selected_data_opts = all_data_opts[:4] if test_level else all_data_opts

  use_fake_data = test_level > 1
  use_fake_training = test_level > 0
  tuned_opts = hparams_lib.get_tuned_model_options(
      architecture, method,
      fake_data=use_fake_data,
      fake_training=use_fake_training)

  if output_dir:
    config_path = output_dir + '/model_options.json'
    experiment_utils.record_config(tuned_opts, config_path)

  return tuned_opts, selected_data_opts
46
47
def run(method, architecture, output_dir, test_level):
  """Trains a model, saves it, and records predictions for every dataset.

  Args:
    method: Name of modeling method (e.g. vanilla, dropout, svi, ll_svi).
    architecture: Name of the DNN architecture; passed through to
      get_experiment_config.
    output_dir: Directory that receives the recorded model options, the model
      weights ('model.ckpt') and the per-dataset stats files. Created if
      needed.
    test_level: Zero indicates no testing. One indicates testing with real
      data. Two is for testing with fake data.
  """
  use_fake_data = test_level > 1

  def _build(opts):
    """Builds one dataset with this run's fake-data setting."""
    return data_lib.build_dataset(opts, fake_data=use_fake_data)

  gfile.makedirs(output_dir)
  model_opts, data_opts_list = get_experiment_config(
      method, architecture, test_level=test_level, output_dir=output_dir)

  # The first two dataset configs feed training and in-training evaluation.
  # NOTE(review): an earlier comment said dataset[0] is built with
  # shuffle=True; that is presumably set inside the data options -- confirm.
  train_data = _build(data_opts_list[0])
  eval_data = _build(data_opts_list[1])
  trained_model = models_lib.build_and_train(
      model_opts, train_data, eval_data, output_dir)

  logging.info('Saving model to output_dir.')
  weights_path = output_dir + '/model.ckpt'
  trained_model.save_weights(weights_path)

  for index, opts in enumerate(data_opts_list):
    current_data = _build(opts)
    logging.info('Running predictions for dataset #%d', index)
    prediction_stats = models_lib.make_predictions(
        model_opts, trained_model, current_data)

    # Write the complete stats first, then a second copy without the
    # 'logits_samples' entry under the "small" file name.
    full_name = 'stats_%d.npz' % index
    array_utils.write_npz(output_dir, full_name, prediction_stats)
    del prediction_stats['logits_samples']
    small_name = 'stats_small_%d.npz' % index
    array_utils.write_npz(output_dir, small_name, prediction_stats)
79