google-research

Форк
0
89 строк · 3.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Generate predictions from a trained model on a range of datasets."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
from absl import app
23
from absl import flags
24
from absl import logging
25

26
import tensorflow.compat.v2 as tf
27
from uq_benchmark_2019 import array_utils
28
from uq_benchmark_2019 import experiment_utils
29
from uq_benchmark_2019 import image_data_utils
30
from uq_benchmark_2019.cifar import data_lib
31
from uq_benchmark_2019.cifar import models_lib
32

33
_BATCH_SIZE = 256
34

35
FLAGS = flags.FLAGS
36

37

38
def _declare_flags():
39
  """Declare flags; not invoked when this module is imported as a library."""
40
  flags.DEFINE_string('model_dir', None, 'Path to Keras model.')
41
  flags.DEFINE_string('output_dir', None, 'Output directory.')
42

43
  flags.DEFINE_integer('predictions_per_example', 1,
44
                       'Number of prediction samples to generate per example.')
45
  flags.DEFINE_integer('max_examples', None,
46
                       'Maximum number of examples to process per dataset.')
47

48
  flags.DEFINE_string('dataset_name', None, 'Configured dataset name.')
49
  flags.DEFINE_integer('task', 0, 'Task number.')
50

51

52
def run(dataset_name, model_dir,
53
        predictions_per_example, max_examples,
54
        output_dir,
55
        fake_data=False):
56
  """Runs predictions on the given dataset using the specified model."""
57
  tf.io.gfile.makedirs(output_dir)
58
  data_config = image_data_utils.get_data_config(dataset_name)
59
  dataset = data_lib.build_dataset(data_config, fake_data=fake_data)
60
  if max_examples:
61
    dataset = dataset.take(max_examples)
62

63
  model = models_lib.load_model(model_dir)
64
  logging.info('Starting predictions.')
65
  predictions = experiment_utils.make_predictions(
66
      model, dataset.batch(_BATCH_SIZE), predictions_per_example)
67

68
  logging.info('Done computing predictions; recording results to disk.')
69
  array_utils.write_npz(output_dir, 'predictions_%s.npz' % dataset_name,
70
                        predictions)
71
  del predictions['logits_samples']
72
  array_utils.write_npz(output_dir, 'predictions_small_%s.npz' % dataset_name,
73
                        predictions)
74

75

76
def main(argv):
77
  if len(argv) > 1:
78
    raise app.UsageError('Too many command-line arguments.')
79

80
  run(FLAGS.dataset_name,
81
      FLAGS.model_dir,
82
      FLAGS.predictions_per_example,
83
      FLAGS.max_examples,
84
      FLAGS.output_dir.replace('%task%', str(FLAGS.task)))
85

86

87
if __name__ == '__main__':
88
  _declare_flags()
89
  app.run(main)
90

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.