# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Loss functions from https://openreview.net/forum?id=8twKpG5s8Qh.

Loss functions have the following args (in addition to hyperparameters):
  labels: (num_examples, num_classes) matrix of one-hot labels.
  final_layer: (num_examples, num_classes) matrix of final-layer outputs.
  weights: Scalar or vector of weights for individual examples. In standard
    usage, this argument should be 1 or, when evaluating with batches that
    may be padded, a vector containing 1s for real examples and 0s for
    padding.

Each loss function returns:
  loss: The value of the loss.
  outputs: (num_examples, num_classes) matrix of outputs. For loss functions
    that normalize or scale outputs before computing softmax cross-entropy,
    these outputs incorporate that normalization/scaling. For other loss
    functions, these outputs are identical to `final_layer`.
"""

import math
import tensorflow.compat.v1 as tf


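# Each loss below follows the call contract described in the module
# docstring. A hypothetical usage sketch (the optimizer choice and the
# `labels` / `final_layer` tensors are placeholders, not part of this
# module):
#
#   loss, outputs = label_smoothing(labels, final_layer, weights=1.0,
#                                   alpha=0.1)
#   train_op = tf.train.MomentumOptimizer(0.1, 0.9).minimize(loss)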
def softmax(labels, final_layer, weights):
  """Standard softmax cross-entropy on unmodified logits."""
  return (tf.losses.softmax_cross_entropy(labels, final_layer, weights),
          final_layer)


def label_smoothing(labels, final_layer, weights, alpha):
  """Softmax cross-entropy with labels smoothed by a factor of `alpha`."""
  return (
      tf.losses.softmax_cross_entropy(
          labels, final_layer, weights * 1 / (1 - alpha),
          label_smoothing=alpha),
      final_layer)


# Penultimate-layer dropout is implemented in resnet_model.py and uses
# ordinary softmax cross-entropy.
dropout = softmax


def extra_final_layer_l2(labels, final_layer, weights, lambda_,
                         final_layer_weights_variable_name='dense/kernel'):
  """Softmax cross-entropy plus an extra L2 penalty on final-layer weights."""
  xent = tf.losses.softmax_cross_entropy(labels, final_layer, weights)
  # Look up the final-layer weight matrix by variable name.
  final_layer_weights = tf.trainable_variables(
      final_layer_weights_variable_name)[0]
  l2 = lambda_ * tf.nn.l2_loss(final_layer_weights)
  tf.losses.add_loss(l2)
  return xent + l2, final_layer


def logit_penalty(labels, final_layer, weights, beta):
  """Softmax cross-entropy plus a penalty on the squared logit norm."""
  xent = tf.losses.softmax_cross_entropy(labels, final_layer, weights)
  penalty = tf.losses.compute_weighted_loss(
      beta * tf.reduce_sum(final_layer ** 2, -1) / 2, weights)
  return xent + penalty, final_layer


def logit_normalization(labels, final_layer, weights, tau):
  """Softmax cross-entropy on L2-normalized logits with temperature `tau`."""
  logits = tf.nn.l2_normalize(final_layer, axis=-1) / tau
  return tf.losses.softmax_cross_entropy(labels, logits, weights), logits


def cosine_softmax(labels, final_layer, weights, tau):
  """Softmax cross-entropy on logits scaled by temperature `tau`.

  For cosine softmax, `final_layer` is expected to already contain cosine
  similarities between embeddings and class weight vectors.
  """
  logits = final_layer / tau
  return tf.losses.softmax_cross_entropy(labels, logits, weights), logits


def sigmoid(labels, final_layer, weights):
  """Sigmoid cross-entropy with logits shifted down by log(num_classes)."""
  num_classes = int(final_layer.shape[1])
  logits = final_layer - math.log(num_classes)
  return tf.losses.sigmoid_cross_entropy(
      labels, logits, weights * num_classes), logits


def squared_error(labels, final_layer, weights, kappa, m, loss_scale):
  """Squared error pushing correct-class outputs to `m` and others to 0."""
  correct_class_loss = tf.squared_difference(
      tf.reduce_sum(final_layer * labels, -1), m)
  other_class_loss = tf.reduce_sum(tf.square(final_layer * (1.0 - labels)), -1)
  return tf.losses.compute_weighted_loss(
      loss_scale * (kappa * correct_class_loss + other_class_loss),
      weights), final_layer
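

# A minimal smoke-test sketch, not part of the original module: evaluates two
# of the losses on a tiny hand-built batch. Assumes the TF1 compatibility
# layer, hence tf.disable_v2_behavior() and an explicit Session.
if __name__ == '__main__':
  tf.disable_v2_behavior()
  labels_ = tf.constant([[1.0, 0.0], [0.0, 1.0]])
  final_layer_ = tf.constant([[2.0, -1.0], [0.5, 1.5]])
  # Per-example weights: treat the second example as padding.
  weights_ = tf.constant([1.0, 0.0])
  loss_, _ = softmax(labels_, final_layer_, weights_)
  norm_loss_, _ = logit_normalization(
      labels_, final_layer_, weights_, tau=0.05)
  with tf.Session() as sess:
    print(sess.run({'softmax': loss_, 'logit_normalization': norm_loss_}))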