# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A JAX implementation of Robust Bi-tempered loss.

Source: https://bit.ly/3jSol8T
"""

import functools

import jax
from jax.lax import cond
from jax.lax import while_loop
import jax.numpy as jnp
from jax.scipy.special import logsumexp

@jax.jit
def _cross_entropy_loss(logits,
                        labels):
  log_preds = jax.nn.log_softmax(logits)
  return jnp.sum(labels * (jnp.log(labels + 1e-15) - log_preds), axis=-1)

@jax.jit
def log_t(u, t):
  """Compute log_t for `u`."""

  def _internal_log_t(u, t):
    return (jnp.power(u, (1.0 - t)) - 1.0) / (1.0 - t)

  return cond(
      jnp.abs(t - 1.0) < 1e-15, jnp.log,
      functools.partial(_internal_log_t, t=t), u)

@jax.jit
def exp_t(u, t):
  """Compute exp_t for `u`."""

  def _internal_exp_t(u, t):
    return jnp.power(jnp.maximum(1.0 + (1.0 - t) * u, 0.0), 1.0 / (1.0 - t))

  return cond(
      jnp.abs(t - 1.0) < 1e-15, jnp.exp,
      functools.partial(_internal_exp_t, t=t), u)

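# Added note (illustrative, not in the original source): `log_t` and `exp_t`
# are the tempered logarithm and exponential used throughout this file. They
# reduce to `jnp.log` / `jnp.exp` as t -> 1, and `exp_t` inverts `log_t` for
# positive inputs. A minimal sketch:
#
#   u = jnp.array([0.25, 1.0, 4.0])
#   exp_t(log_t(u, 1.5), 1.5)  # ~= u
#   log_t(u, 1.0)              # == jnp.log(u) via the t == 1 branch
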
@jax.jit
def compute_normalization_fixed_point(activations,
                                      t,
                                      num_iters=5):
  """Returns the normalization value for each example (t > 1.0).

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature 2 (> 1.0 for tail heaviness).
    num_iters: Number of iterations to run the method.

  Returns:
    An array of the same rank as activations with the last dimension being 1.
  """

  mu = jnp.max(activations, -1, keepdims=True)
  normalized_activations_step_0 = activations - mu

  def cond_fun(carry):
    _, iters = carry
    return iters < num_iters

  def body_fun(carry):
    normalized_activations, iters = carry
    logt_partition = jnp.sum(
        exp_t(normalized_activations, t), -1, keepdims=True)
    normalized_activations_t = normalized_activations_step_0 * jnp.power(
        logt_partition, 1.0 - t)
    return normalized_activations_t, iters + 1

  normalized_activations_t, _ = while_loop(cond_fun, body_fun,
                                           (normalized_activations_step_0, 0))
  logt_partition = jnp.sum(
      exp_t(normalized_activations_t, t), -1, keepdims=True)
  return -log_t(1.0 / logt_partition, t) + mu

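# Added note: each iteration rescales the max-shifted activations by the
# current partition estimate raised to (1 - t); at the fixed point, the
# tempered probabilities exp_t(activations - Z, t) sum to one per example.
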
@jax.jit
def compute_normalization_binary_search(activations,
                                        t,
                                        num_iters=10):
  """Returns the normalization value for each example (t < 1.0).

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature 2 (< 1.0 for finite support).
    num_iters: Number of iterations to run the method.

  Returns:
    An array of the same rank as activations with the last dimension being 1.
  """
  mu = jnp.max(activations, -1, keepdims=True)
  normalized_activations = activations - mu
  shape_activations = activations.shape
  effective_dim = jnp.float32(
      jnp.sum(
          jnp.int32(normalized_activations > -1.0 / (1.0 - t)),
          -1,
          keepdims=True))
  shape_partition = list(shape_activations[:-1]) + [1]
  lower = jnp.zeros(shape_partition)
  upper = -log_t(1.0 / effective_dim, t) * jnp.ones(shape_partition)

  def cond_fun(carry):
    _, _, iters = carry
    return iters < num_iters

  def body_fun(carry):
    lower, upper, iters = carry
    logt_partition = (upper + lower) / 2.0
    sum_probs = jnp.sum(
        exp_t(normalized_activations - logt_partition, t), -1, keepdims=True)
    update = jnp.float32(sum_probs < 1.0)
    lower = jnp.reshape(lower * update + (1.0 - update) * logt_partition,
                        shape_partition)
    upper = jnp.reshape(upper * (1.0 - update) + update * logt_partition,
                        shape_partition)
    return lower, upper, iters + 1

  lower, upper, _ = while_loop(cond_fun, body_fun, (lower, upper, 0))
  logt_partition = (upper + lower) / 2.0
  return logt_partition + mu

@jax.jit
def compute_tempered_normalization(activations,
                                   t,
                                   num_iters=5):
  return cond(
      t < 1.0,
      functools.partial(
          compute_normalization_binary_search, t=t, num_iters=num_iters),
      functools.partial(
          compute_normalization_fixed_point, t=t, num_iters=num_iters),
      activations)

@jax.jit
def compute_normalization(activations,
                          t,
                          num_iters=5):
  """Returns the normalization value for each example.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature 2 (< 1.0 for finite support, > 1.0 for tail heaviness).
    num_iters: Number of iterations to run the method.

  Returns:
    An array of the same rank as activations with the last dimension being 1.
  """
  return cond(
      jnp.abs(t - 1.0) < 1e-15,
      functools.partial(logsumexp, axis=-1, keepdims=True),
      functools.partial(
          compute_tempered_normalization, t=t, num_iters=num_iters),
      activations)

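# Illustrative check (added, not in the original file): whichever branch
# computes the normalization Z, the resulting tempered probabilities sum to
# approximately one:
#
#   a = jnp.array([[0.0, 1.0, 2.0]])
#   for t in (0.5, 1.0, 1.5):
#     z = compute_normalization(a, t)
#     jnp.sum(exp_t(a - z, t), -1)  # ~= 1.0 in each case
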
@jax.jit
def tempered_sigmoid(activations, t, num_iters=5):
  """Tempered sigmoid function.

  Args:
    activations: Activations for the positive class for binary classification.
    t: Temperature array > 0.0.
    num_iters: Number of iterations to run the method.

  Returns:
    A probabilities array.
  """
  input_shape = activations.shape
  activations_2d = jnp.reshape(activations, [-1, 1])
  internal_activations = jnp.concatenate(
      [jnp.zeros_like(activations_2d), activations_2d], 1)
  internal_probabilities = tempered_softmax(internal_activations, t, num_iters)
  one_class_probabilities = internal_probabilities[:, 1]
  return jnp.reshape(one_class_probabilities, input_shape)

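# Example (added; values are arbitrary): at t == 1 this reduces to the
# logistic sigmoid, while t < 1 gives finite support, so the tails saturate
# at hard 0/1 (up to the iterative solver's tolerance):
#
#   x = jnp.array([-3.0, 0.0, 3.0])
#   tempered_sigmoid(x, 1.0)  # ~= jax.nn.sigmoid(x)
#   tempered_sigmoid(x, 0.5)  # saturates to 0/1 in the tails
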
@jax.custom_vjp
def tempered_softmax(activations, t, num_iters=5):
  """Tempered softmax function with custom gradient.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature array > 0.0.
    num_iters: Number of iterations to run the method.

  Returns:
    A probabilities array.
  """
  probabilities, _ = tempered_softmax_fwd(activations, t, num_iters)
  return probabilities

@jax.jit
def tempered_softmax_fwd(activations, t, num_iters=5):
  """Forward pass function for tempered softmax function.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature array > 0.0.
    num_iters: Number of iterations to run the method.

  Returns:
    A probabilities array, residuals.
  """
  activations = jnp.asarray(activations, dtype=float)
  normalization_constants = compute_normalization(activations, t, num_iters)
  probabilities = exp_t(activations - normalization_constants, t)
  return probabilities, (probabilities, t)

@jax.jit
def tempered_softmax_bwd(res, d_softmax):
  """Backward pass function for tempered softmax function.

  Args:
    res: Residuals.
    d_softmax: Differential.

  Returns:
    Derivatives.
  """
  probabilities, t = res
  probabilities_pow_t = jnp.power(probabilities, t)
  escorts = probabilities_pow_t / jnp.sum(
      probabilities_pow_t, -1, keepdims=True)
  derivative = probabilities_pow_t * (1. - escorts)
  return (jnp.multiply(d_softmax, derivative), None, None)

tempered_softmax.defvjp(tempered_softmax_fwd, tempered_softmax_bwd)

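# Example usage (added; shapes and temperatures are arbitrary):
#
#   logits = jnp.array([[1.0, 0.0, -1.0]])
#   p = tempered_softmax(logits, 1.2)  # rows sum to ~1
#   g = jax.grad(lambda a: jnp.sum(tempered_softmax(a, 1.2)[:, 0]))(logits)
#
# The custom VJP registered above uses a closed-form expression built on the
# escort distribution p**t / sum(p**t) instead of differentiating through
# the iterative normalization solver.
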
def _internal_bi_tempered_logistic_loss(activations, labels, t1, t2):
  """Computes the Bi-Tempered logistic loss.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    labels: An array with shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness).

  Returns:
    A loss array for robust loss.
  """
  normalization_constants = compute_normalization(activations, t2, num_iters=5)
  if t2 == 1.0:
    if t1 == 1.0:
      # Squeeze the trailing singleton dimension so the normalization term
      # broadcasts against the per-example sum.
      return jnp.squeeze(normalization_constants, -1) + jnp.sum(
          jnp.multiply(labels,
                       jnp.log(labels + 1e-10) - activations), -1)
    else:
      shifted_activations = jnp.exp(activations - normalization_constants)
      one_minus_t1 = (1.0 - t1)
      one_minus_t2 = 1.0
  else:
    one_minus_t1 = (1.0 - t1)
    one_minus_t2 = (1.0 - t2)
    shifted_activations = jnp.maximum(
        1.0 + one_minus_t2 * (activations - normalization_constants), 0.0)

  if t1 == 1.0:
    return jnp.sum(
        jnp.multiply(
            jnp.log(labels + 1e-10) -
            jnp.log(jnp.power(shifted_activations, 1.0 / one_minus_t2)),
            labels), -1)
  else:
    beta = 1.0 + one_minus_t1
    logt_probs = (jnp.power(shifted_activations, one_minus_t1 / one_minus_t2) -
                  1.0) / one_minus_t1
    return jnp.sum(
        jnp.multiply(log_t(labels, t1) - logt_probs, labels) - 1.0 / beta *
        (jnp.power(labels, beta) -
         jnp.power(shifted_activations, beta / one_minus_t2)), -1)

@jax.custom_vjp
def bi_tempered_logistic_loss(activations,
                              labels,
                              t1,
                              t2,
                              label_smoothing=0.0,
                              num_iters=5):
  """Bi-Tempered Logistic Loss with custom gradient.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    labels: An array with shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    label_smoothing: Label smoothing parameter between [0, 1).
    num_iters: Number of iterations to run the method.

  Returns:
    A loss array.
  """
  activations = jnp.asarray(activations, dtype=float)
  labels = jnp.asarray(labels, dtype=float)
  loss_values, _ = bi_tempered_logistic_loss_fwd(activations, labels, t1, t2,
                                                 label_smoothing, num_iters)
  return loss_values

@jax.jit
def bi_tempered_logistic_loss_fwd(activations,
                                  labels,
                                  t1,
                                  t2,
                                  label_smoothing=0.0,
                                  num_iters=5):
  """Forward pass function for bi-tempered logistic loss.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    labels: An array with shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    label_smoothing: Label smoothing parameter between [0, 1).
    num_iters: Number of iterations to run the method.

  Returns:
    A loss array, residuals.
  """
  num_classes = jnp.int32(labels.shape[-1])
  labels = cond(
      label_smoothing > 0.0,
      lambda u:  # pylint: disable=g-long-lambda
      (1 - num_classes /
       (num_classes - 1) * label_smoothing) * u + label_smoothing /
      (num_classes - 1),
      lambda u: u,
      labels)
  probabilities = tempered_softmax(activations, t2, num_iters)

  def _tempered_cross_entropy_loss(unused_activations):
    loss_values = jnp.multiply(
        labels,
        log_t(labels + 1e-10, t1) -
        log_t(probabilities, t1)) - 1.0 / (2.0 - t1) * (
            jnp.power(labels, 2.0 - t1) - jnp.power(probabilities, 2.0 - t1))
    loss_values = jnp.sum(loss_values, -1)
    return loss_values

  loss_values = cond(
      jnp.logical_and(
          jnp.less(jnp.abs(t1 - 1.0), 1e-15),
          jnp.less(jnp.abs(t2 - 1.0), 1e-15)),
      functools.partial(_cross_entropy_loss, labels=labels),
      _tempered_cross_entropy_loss,
      activations)
  return loss_values, (labels, t1, t2, probabilities)

@jax.jit
def bi_tempered_logistic_loss_bwd(res, d_loss):
  """Backward pass function for bi-tempered logistic loss.

  Args:
    res: Residuals.
    d_loss: Differential.

  Returns:
    Derivatives.
  """
  labels, t1, t2, probabilities = res
  delta_probs = probabilities - labels
  forget_factor = jnp.power(probabilities, t2 - t1)
  delta_probs_times_forget_factor = jnp.multiply(delta_probs, forget_factor)
  delta_forget_sum = jnp.sum(
      delta_probs_times_forget_factor, -1, keepdims=True)
  escorts = jnp.power(probabilities, t2)
  escorts = escorts / jnp.sum(escorts, -1, keepdims=True)
  derivative = delta_probs_times_forget_factor - jnp.multiply(
      escorts, delta_forget_sum)
  if len(d_loss.shape) < len(derivative.shape):
    d_loss = jnp.expand_dims(d_loss, -1)
  return (jnp.multiply(d_loss, derivative), None, None, None, None, None)

bi_tempered_logistic_loss.defvjp(
    bi_tempered_logistic_loss_fwd, bi_tempered_logistic_loss_bwd)

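# Example usage (added; values are arbitrary):
#
#   logits = jnp.array([[2.0, 0.5, -1.0]])
#   onehot = jnp.array([[1.0, 0.0, 0.0]])
#   loss = bi_tempered_logistic_loss(logits, onehot, 0.8, 1.2)
#   grads = jax.grad(lambda a: jnp.sum(
#       bi_tempered_logistic_loss(a, onehot, 0.8, 1.2)))(logits)
#
# With t1 == t2 == 1.0 the forward pass falls back to the standard
# cross-entropy branch (`_cross_entropy_loss`).
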
def bi_tempered_binary_logistic_loss(activations,
                                     labels,
                                     t1,
                                     t2,
                                     label_smoothing=0.0,
                                     num_iters=5):
  """Bi-Tempered binary logistic loss.

  Args:
    activations: An array containing activations for class 1.
    labels: An array with shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    label_smoothing: Label smoothing parameter between [0, 1).
    num_iters: Number of iterations to run the method.

  Returns:
    A loss array.
  """
  out_shape = labels.shape
  labels_2d = jnp.reshape(labels, [-1, 1])
  activations_2d = jnp.reshape(activations, [-1, 1])
  internal_labels = jnp.concatenate([1.0 - labels_2d, labels_2d], 1)
  internal_logits = jnp.concatenate(
      [jnp.zeros_like(activations_2d), activations_2d], 1)
  losses = bi_tempered_logistic_loss(internal_logits, internal_labels, t1, t2,
                                     label_smoothing, num_iters)
  return jnp.reshape(losses, out_shape)
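

# Minimal smoke test (added example, not part of the original file; the
# shapes and temperature values are arbitrary choices for illustration).
if __name__ == '__main__':
  logits = jnp.array([[-0.5, 0.1, 2.0]])
  labels = jnp.array([[0.2, 0.5, 0.3]])
  print('tempered softmax:', tempered_softmax(logits, 1.5))
  print('bi-tempered loss:',
        bi_tempered_logistic_loss(logits, labels, 0.7, 1.3))
  print('binary loss:',
        bi_tempered_binary_logistic_loss(
            jnp.array([-1.0, 2.0]), jnp.array([0.0, 1.0]), 0.7, 1.3))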