google-research

Форк
0
/
evaluate.py 
133 строки · 3.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Evaluates the model based on a performance metric."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
from absl import app
23
from absl import flags
24
import model
25
import model_utils
26
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
27
from tensorflow import estimator as tf_estimator
28
import tensorflow_datasets as tfds
29

30
flags.DEFINE_string('ckpt_path', '', 'Path to evaluation checkpoint')
31
flags.DEFINE_string('cls_dense_name', '', 'Final dense layer name')
32

33
flags.DEFINE_integer('train_batch_size', 600,
34
                     'The batch size for the target dataset.')
35

36
FLAGS = flags.FLAGS
37

38
NUM_EVAL_IMAGES = {
39
    'mnist': 10000,
40
    'svhn_cropped_small': 6000,
41
}
42

43

44
def get_model_fn():
45
  """Returns the model definition."""
46

47
  def model_fn(features, labels, mode, params):
48
    """Returns the model function."""
49
    feature = features['feature']
50
    labels = labels['label']
51
    one_hot_labels = model_utils.get_label(
52
        labels,
53
        params,
54
        FLAGS.src_num_classes,
55
        batch_size=FLAGS.train_batch_size)
56

57
    def get_logits():
58
      """Return the logits."""
59
      network_output = model.conv_model(
60
          feature,
61
          mode,
62
          target_dataset=FLAGS.target_dataset,
63
          src_hw=FLAGS.src_hw,
64
          target_hw=FLAGS.target_hw)
65
      name = FLAGS.cls_dense_name
66
      with tf.variable_scope('target_CLS'):
67
        logits = tf.layers.dense(
68
            inputs=network_output, units=FLAGS.src_num_classes, name=name)
69
      return logits
70

71
    logits = get_logits()
72
    logits = tf.cast(logits, tf.float32)
73

74
    dst_loss = tf.losses.softmax_cross_entropy(
75
        logits=logits,
76
        onehot_labels=one_hot_labels,
77
    )
78
    loss = dst_loss
79

80
    eval_metrics = model_utils.metric_fn(labels, logits)
81

82
    return tf_estimator.EstimatorSpec(
83
        mode=mode,
84
        loss=loss,
85
        train_op=None,
86
        eval_metric_ops=eval_metrics,
87
    )
88

89
  return model_fn
90

91

92
def main(unused_argv):
93
  config = tf_estimator.RunConfig()
94

95
  classifier = tf_estimator.Estimator(get_model_fn(), config=config)
96

97
  def _merge_datasets(test_batch):
98
    feature, label = test_batch['image'], test_batch['label'],
99
    features = {
100
        'feature': feature,
101
    }
102
    labels = {
103
        'label': label,
104
    }
105
    return (features, labels)
106

107
  def get_dataset(dataset_split):
108
    """Returns dataset creation function."""
109

110
    def make_input_dataset():
111
      """Returns input dataset."""
112
      test_data = tfds.load(name=FLAGS.target_dataset, split=dataset_split)
113
      test_data = test_data.batch(FLAGS.train_batch_size)
114
      dataset = tf.data.Dataset.zip((test_data,))
115
      dataset = dataset.map(_merge_datasets)
116
      dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
117
      return dataset
118

119
    return make_input_dataset
120

121
  num_eval_images = NUM_EVAL_IMAGES[FLAGS.target_dataset]
122
  eval_steps = num_eval_images // FLAGS.train_batch_size
123

124
  classifier.evaluate(
125
      input_fn=get_dataset('test'),
126
      steps=eval_steps,
127
      checkpoint_path=FLAGS.ckpt_path,
128
  )
129

130

131
if __name__ == '__main__':
132
  tf.logging.set_verbosity(tf.logging.INFO)
133
  app.run(main)
134

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.