google-research
89 строк · 3.0 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Generate predictions from a trained model on a range of datasets."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from absl import app
23from absl import flags
24from absl import logging
25
26import tensorflow.compat.v2 as tf
27from uq_benchmark_2019 import array_utils
28from uq_benchmark_2019 import experiment_utils
29from uq_benchmark_2019 import image_data_utils
30from uq_benchmark_2019.cifar import data_lib
31from uq_benchmark_2019.cifar import models_lib
32
# Number of examples per batch when the dataset is batched for
# `experiment_utils.make_predictions` in `run`.
_BATCH_SIZE = 256

# absl's global flag registry; the flags read by `main` are declared in
# `_declare_flags`.
FLAGS = flags.FLAGS
36
37
def _declare_flags():
  """Declare flags; not invoked when this module is imported as a library.

  Registers the command-line flags that `main` reads and forwards to `run`.

  `model_dir`, `output_dir` and `dataset_name` default to None, but `main`
  uses them unconditionally: `output_dir` is dereferenced with `.replace`,
  and the other two are passed straight into `run`. They are therefore
  marked as required, so a missing value is reported by absl as a usage
  error when the flags are parsed. Without this, the run would fail later
  with an unrelated error such as `AttributeError` on `None.replace`.
  """
  flags.DEFINE_string('model_dir', None, 'Path to Keras model.')
  flags.DEFINE_string('output_dir', None, 'Output directory.')

  flags.DEFINE_integer('predictions_per_example', 1,
                       'Number of prediction samples to generate per example.')
  flags.DEFINE_integer('max_examples', None,
                       'Maximum number of examples to process per dataset.')

  flags.DEFINE_string('dataset_name', None, 'Configured dataset name.')
  flags.DEFINE_integer('task', 0, 'Task number.')

  # Fail fast at flag-parsing time rather than deep inside main/run.
  flags.mark_flags_as_required(['model_dir', 'output_dir', 'dataset_name'])
50
51
def run(dataset_name, model_dir,
        predictions_per_example, max_examples,
        output_dir,
        fake_data=False):
  """Predict on one configured dataset with a saved model and save results.

  Args:
    dataset_name: Configured dataset name, resolved through
      `image_data_utils.get_data_config`; also used in the output filenames.
    model_dir: Model location, handed to `models_lib.load_model`.
    predictions_per_example: Number of prediction samples per example,
      forwarded to `experiment_utils.make_predictions`.
    max_examples: When truthy, only the first `max_examples` examples are
      used; a falsy value (None or 0) keeps the whole dataset.
    output_dir: Directory receiving the .npz files; created via
      `tf.io.gfile.makedirs` before anything else happens.
    fake_data: Forwarded unchanged to `data_lib.build_dataset`.

  Two files are written to `output_dir`:
    * `predictions_<dataset_name>.npz`, which holds every entry returned by
      `make_predictions`.
    * `predictions_small_<dataset_name>.npz`, which holds the same entries
      minus 'logits_samples'.
  """
  tf.io.gfile.makedirs(output_dir)

  config = image_data_utils.get_data_config(dataset_name)
  examples = data_lib.build_dataset(config, fake_data=fake_data)
  # Truthiness test on purpose: both None and 0 mean "no limit".
  examples = examples.take(max_examples) if max_examples else examples

  keras_model = models_lib.load_model(model_dir)
  logging.info('Starting predictions.')
  batched_examples = examples.batch(_BATCH_SIZE)
  results = experiment_utils.make_predictions(
      keras_model, batched_examples, predictions_per_example)

  logging.info('Done computing predictions; recording results to disk.')
  full_filename = 'predictions_%s.npz' % dataset_name
  array_utils.write_npz(output_dir, full_filename, results)

  # The compact copy omits 'logits_samples' (presumably the bulky
  # per-sample logits -- confirm against make_predictions).
  del results['logits_samples']
  small_filename = 'predictions_small_%s.npz' % dataset_name
  array_utils.write_npz(output_dir, small_filename, results)
74
75
def main(argv):
  """absl entry point that forwards the parsed flags to `run`.

  Args:
    argv: Positional arguments left after absl flag parsing. Only the
      program name itself is accepted.

  Raises:
    app.UsageError: If any extra positional argument is supplied.
  """
  extra_args = argv[1:]
  if extra_args:
    raise app.UsageError('Too many command-line arguments.')

  # Every '%task%' placeholder in --output_dir is replaced by the --task
  # number (presumably so separate tasks write to separate directories).
  resolved_output_dir = FLAGS.output_dir.replace('%task%', str(FLAGS.task))

  run(dataset_name=FLAGS.dataset_name,
      model_dir=FLAGS.model_dir,
      predictions_per_example=FLAGS.predictions_per_example,
      max_examples=FLAGS.max_examples,
      output_dir=resolved_output_dir)
85
86
if __name__ == '__main__':
  # Flags are declared only when this file runs as a script. Importing the
  # module as a library (e.g. to call `run` directly) does not register them.
  _declare_flags()
  app.run(main)
90