google-research
66 строк · 2.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Model utils."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import23
24
def metric_fn(labels, logits):
  """Builds the dictionary of evaluation metrics.

  Args:
    labels: Ground-truth labels, passed unchanged as the first argument of
      `tf.metrics.accuracy`. Presumably integer class ids rather than one-hot
      vectors -- TODO confirm against the caller.
    logits: Per-class scores; the predicted class is the argmax over axis 1.

  Returns:
    A dict with a single key, 'top_1_accuracy', whose value is whatever
    `tf.metrics.accuracy` returns for these labels and predictions.
  """
  predicted_class = tf.argmax(logits, axis=1)
  accuracy_metric = tf.metrics.accuracy(labels, predicted_class)
  return dict(top_1_accuracy=accuracy_metric)
34
def get_label(labels, params, num_classes, batch_size=-1):  # pylint: disable=unused-argument
  """Converts class-id labels into one-hot vectors.

  Args:
    labels: Class-id labels; they are cast to int64 before encoding.
    params: Not used by this function.
    num_classes: Depth of the one-hot encoding.
    batch_size: Not used by this function.

  Returns:
    The result of `tf.one_hot` applied to the int64-cast labels with depth
    `num_classes`.
  """
  label_ids = tf.cast(labels, tf.int64)
  return tf.one_hot(label_ids, num_classes)
40
def update_exponential_moving_average(tensor, momentum, name=None):
  """Returns an exponential moving average of `tensor`.

  The moving average is updated every time the returned tensor is evaluated.
  A zero-debias is applied: the running weighted sum is divided by the running
  sum of the weights. The estimates are therefore unbiased during the first
  few training steps, instead of being pulled toward the zero initial value.

  Args:
    tensor: A floating point tensor (or Python float). To track a non-scalar
      average, its static shape must be fully defined. If the static shape is
      unknown, a scalar accumulator is used, so `tensor` must then be a scalar
      at runtime.
    momentum: A scalar floating point Tensor (or Python float) with the same
      dtype as `tensor`.
    name: Optional string, the name of the operation in the TensorFlow graph.

  Returns:
    A Tensor with the same shape and dtype as `tensor`.
  """
  with tf.variable_scope(name, 'update_exponential_moving_average',
                         [tensor, momentum]):
    tensor = tf.convert_to_tensor(tensor)
    dtype = tensor.dtype.base_dtype
    # The accumulators must match `tensor` in dtype (and, for the numerator,
    # in shape). A bare `initializer=0.0` always yields a scalar float32
    # variable, which only works for scalar float32 inputs. When the static
    # shape is unknown, fall back to a scalar numerator (the old behaviour).
    if tensor.shape.is_fully_defined():
      numerator_shape = tensor.shape
    else:
      numerator_shape = tf.TensorShape([])
    # Running weighted sum of the observed values.
    numerator = tf.get_variable(
        'numerator', shape=numerator_shape, dtype=dtype,
        initializer=tf.zeros_initializer(), trainable=False, use_resource=True)
    # Running sum of the weights (1 - momentum); dividing by it removes the
    # bias toward the zero initial value.
    denominator = tf.get_variable(
        'denominator', shape=[], dtype=dtype,
        initializer=tf.zeros_initializer(), trainable=False, use_resource=True)
    update_ops = [
        numerator.assign(momentum * numerator + (1 - momentum) * tensor),
        denominator.assign(momentum * denominator + (1 - momentum)),
    ]
    # Reads are ordered after both updates, so each evaluation of the result
    # advances the average once and returns the freshly debiased value.
    with tf.control_dependencies(update_ops):
      return numerator.read_value() / denominator.read_value()