# Source: google-research repository (133 lines, 3.7 KB).
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Evaluates the model based on a performance metric."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22from absl import app23from absl import flags24import model25import model_utils26import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import27from tensorflow import estimator as tf_estimator28import tensorflow_datasets as tfds29
# Command-line flags owned by this evaluation script.
flags.DEFINE_string(
    name='ckpt_path',
    default='',
    help='Path to evaluation checkpoint')
flags.DEFINE_string(
    name='cls_dense_name',
    default='',
    help='Final dense layer name')

flags.DEFINE_integer(
    name='train_batch_size',
    default=600,
    help='The batch size for the target dataset.')

FLAGS = flags.FLAGS

# Number of evaluation images per supported target dataset. `main` divides
# this by the batch size to obtain the number of evaluation steps.
NUM_EVAL_IMAGES = {
    'mnist': 10000,
    'svhn_cropped_small': 6000,
}
42
43
def get_model_fn():
  """Builds the Estimator `model_fn` used for evaluation."""

  def model_fn(features, labels, mode, params):
    """Assembles the loss and evaluation metrics for one batch.

    Args:
      features: Dict whose 'feature' entry is fed to `model.conv_model`.
      labels: Dict whose 'label' entry holds the labels.
      mode: Estimator mode key; forwarded to the network and to the spec.
      params: Forwarded unchanged to `model_utils.get_label`.

    Returns:
      A `tf_estimator.EstimatorSpec` carrying the loss and the evaluation
      metric ops; no train op is attached.
    """
    inputs = features['feature']
    labels = labels['label']
    num_classes = FLAGS.src_num_classes

    onehot_targets = model_utils.get_label(
        labels, params, num_classes, batch_size=FLAGS.train_batch_size)

    # Backbone network followed by the final dense classification layer,
    # created under the 'target_CLS' variable scope.
    backbone_output = model.conv_model(
        inputs,
        mode,
        target_dataset=FLAGS.target_dataset,
        src_hw=FLAGS.src_hw,
        target_hw=FLAGS.target_hw)
    dense_name = FLAGS.cls_dense_name
    with tf.variable_scope('target_CLS'):
      raw_logits = tf.layers.dense(
          inputs=backbone_output, units=num_classes, name=dense_name)
    logits = tf.cast(raw_logits, tf.float32)

    loss = tf.losses.softmax_cross_entropy(
        logits=logits,
        onehot_labels=onehot_targets,
    )

    return tf_estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=None,
        eval_metric_ops=model_utils.metric_fn(labels, logits),
    )

  return model_fn
91
def main(unused_argv):
  """Runs Estimator evaluation on the 'test' split of the target dataset.

  The value of `--ckpt_path` is passed to `Estimator.evaluate` as
  `checkpoint_path`. The number of evaluation steps is the number of
  evaluation images for the dataset divided (floor) by the batch size.

  Args:
    unused_argv: Positional command-line arguments from `app.run`; ignored.

  Raises:
    ValueError: If `--train_batch_size` is not positive, if
      `--target_dataset` has no entry in `NUM_EVAL_IMAGES`, or if the batch
      size exceeds the number of evaluation images (zero evaluation steps).
  """
  # Validate the flags up front so that misconfiguration fails with a clear
  # message instead of a bare KeyError / ZeroDivisionError later on.
  batch_size = FLAGS.train_batch_size
  if batch_size <= 0:
    raise ValueError(
        'train_batch_size must be positive, got %d.' % batch_size)

  target_dataset = FLAGS.target_dataset
  if target_dataset not in NUM_EVAL_IMAGES:
    raise ValueError(
        'Unsupported target_dataset %r; expected one of %s.' %
        (target_dataset, sorted(NUM_EVAL_IMAGES)))

  num_eval_images = NUM_EVAL_IMAGES[target_dataset]
  # Floor division: only whole batches are evaluated.
  eval_steps = num_eval_images // batch_size
  if eval_steps < 1:
    raise ValueError(
        'train_batch_size (%d) exceeds the %d evaluation images of %r.' %
        (batch_size, num_eval_images, target_dataset))

  config = tf_estimator.RunConfig()
  classifier = tf_estimator.Estimator(get_model_fn(), config=config)

  def _merge_datasets(test_batch):
    """Splits a batch dict into the (features, labels) dicts for `model_fn`."""
    feature, label = test_batch['image'], test_batch['label']
    features = {
        'feature': feature,
    }
    labels = {
        'label': label,
    }
    return (features, labels)

  def get_dataset(dataset_split):
    """Returns dataset creation function."""

    def make_input_dataset():
      """Returns input dataset."""
      test_data = tfds.load(name=target_dataset, split=dataset_split)
      test_data = test_data.batch(batch_size)
      dataset = tf.data.Dataset.zip((test_data,))
      dataset = dataset.map(_merge_datasets)
      dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
      return dataset

    return make_input_dataset

  classifier.evaluate(
      input_fn=get_dataset('test'),
      steps=eval_steps,
      checkpoint_path=FLAGS.ckpt_path,
  )
130
if __name__ == '__main__':
  # Emit INFO-level TensorFlow logs, then hand control to absl, which parses
  # the command-line flags before invoking `main`.
  tf.logging.set_verbosity(tf.logging.INFO)
  app.run(main)