google-research

Форк
0
93 строки · 3.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Loss functions from https://openreview.net/forum?id=8twKpG5s8Qh.
17

18
Loss functions have the following args (in addition to hyperparameters):
19
  labels: (num_examples, num_classes) matrix of one-hot labels.
20
  final_layer: (num_examples, num_classes) matrix of final-layer outputs.
21
  weights: Scalar or vector of weights for individual examples. In standard
22
    usage, this argument should be 1 or, when evaluating with batches that,
23
    may be padded, a vector containing 1s for real examples and 0 for padding.
24

25
Each loss function returns:
26
  loss: The value of the loss.
27
  outputs: (num_examples, num_classes) matrix of outputs. For loss functions
28
    that normalize or scale outputs before computing softmax cross-entropy,
29
    these outputs incorporate that normalization/scaling. For other loss
30
    functions, these outputs are identical to `final_layer`.
31
"""
32

33
import math
34
import tensorflow.compat.v1 as tf
35

36

37
def softmax(labels, final_layer, weights):
38
  return (tf.losses.softmax_cross_entropy(labels, final_layer, weights),
39
          final_layer)
40

41

42
def label_smoothing(labels, final_layer, weights, alpha):
43
  return (
44
      tf.losses.softmax_cross_entropy(
45
          labels, final_layer, weights * 1/(1 - alpha), label_smoothing=alpha),
46
      final_layer)
47

48

49
# Penultimate layer dropout implemented in resnet_model.py and uses ordinary
50
# softmax.
51
dropout = softmax
52

53

54
def extra_final_layer_l2(labels, final_layer, weights, lambda_,
55
                         final_layer_weights_variable_name='dense/kernel'):
56
  xent = tf.losses.softmax_cross_entropy(labels, final_layer, weights)
57
  final_layer_weights = tf.trainable_variables(
58
      final_layer_weights_variable_name)[0]
59
  l2 = lambda_ * tf.nn.l2_loss(final_layer_weights)
60
  tf.losses.add_loss(l2)
61
  return xent + l2, final_layer
62

63

64
def logit_penalty(labels, final_layer, weights, beta):
65
  xent = tf.losses.softmax_cross_entropy(labels, final_layer, weights)
66
  penalty = tf.losses.compute_weighted_loss(
67
      beta * tf.reduce_sum(final_layer ** 2, -1) / 2, weights)
68
  return xent + penalty, final_layer
69

70

71
def logit_normalization(labels, final_layer, weights, tau):
72
  logits = tf.nn.l2_normalize(final_layer, axis=-1) / tau
73
  return tf.losses.softmax_cross_entropy(labels, logits, weights), logits
74

75

76
def cosine_softmax(labels, final_layer, weights, tau):
77
  logits = final_layer / tau
78
  return tf.losses.softmax_cross_entropy(labels, logits, weights), logits
79

80

81
def sigmoid(labels, final_layer, weights):
82
  logits = final_layer - math.log(int(final_layer.shape[1]))
83
  return tf.losses.sigmoid_cross_entropy(
84
      labels, logits, weights * int(final_layer.shape[1])), logits
85

86

87
def squared_error(labels, final_layer, weights, kappa, m, loss_scale):
88
  correct_class_loss = tf.squared_difference(
89
      tf.reduce_sum(final_layer * labels, -1), m)
90
  other_class_loss = tf.reduce_sum(tf.square(final_layer * (1.0 - labels)), -1)
91
  return tf.losses.compute_weighted_loss(
92
      loss_scale * (kappa * correct_class_loss + other_class_loss),
93
      weights), final_layer
94

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.