google-research

Форк
0
/
model_utils.py 
66 строк · 2.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Model utils."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import
23

24

25
def metric_fn(labels, logits):
26
  """Metric function for evaluation."""
27
  predictions = tf.argmax(logits, axis=1)
28
  top_1_accuracy = tf.metrics.accuracy(labels, predictions)
29

30
  return {
31
      'top_1_accuracy': top_1_accuracy,
32
  }
33

34

35
def get_label(labels, params, num_classes, batch_size=-1):  # pylint: disable=unused-argument
36
  """Returns the label."""
37
  one_hot_labels = tf.one_hot(tf.cast(labels, tf.int64), num_classes)
38
  return one_hot_labels
39

40

41
def update_exponential_moving_average(tensor, momentum, name=None):
42
  """Returns an exponential moving average of `tensor`.
43

44
  We will update the moving average every time the returned `tensor` is
45
  evaluated. A zero-debias will be applied, so we will return unbiased
46
  estimates during the first few training steps.
47
  Args:
48
    tensor: A floating point tensor.
49
    momentum: A scalar floating point Tensor with the same dtype as `tensor`.
50
    name: Optional string, the name of the operation in the TensorFlow graph.
51

52
  Returns:
53
    A Tensor with the same shape and dtype as `tensor`.
54
  """
55
  with tf.variable_scope(name, 'update_exponential_moving_average',
56
                         [tensor, momentum]):
57
    numerator = tf.get_variable(
58
        'numerator', initializer=0.0, trainable=False, use_resource=True)
59
    denominator = tf.get_variable(
60
        'denominator', initializer=0.0, trainable=False, use_resource=True)
61
    update_ops = [
62
        numerator.assign(momentum * numerator + (1 - momentum) * tensor),
63
        denominator.assign(momentum * denominator + (1 - momentum)),
64
    ]
65
    with tf.control_dependencies(update_ops):
66
      return numerator.read_value() / denominator.read_value()
67

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.