# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A JAX implementation of Robust Bi-tempered loss.

Source: https://bit.ly/3jSol8T
"""

import functools

import jax
from jax.lax import cond
from jax.lax import while_loop
import jax.numpy as jnp
from jax.scipy.special import logsumexp

@jax.jit
def _cross_entropy_loss(logits, labels):
  """Computes sum(labels * (log(labels) - log_softmax(logits))) over classes."""
  log_preds = jax.nn.log_softmax(logits)
  return jnp.sum(labels * (jnp.log(labels + 1e-15) - log_preds), axis=-1)


@jax.jit
def log_t(u, t):
  """Compute log_t for `u`."""

  def _internal_log_t(u, t):
    return (jnp.power(u, (1.0 - t)) - 1.0) / (1.0 - t)

  return cond(
      jnp.abs(t - 1.0) < 1e-15, jnp.log,
      functools.partial(_internal_log_t, t=t), u)


@jax.jit
def exp_t(u, t):
  """Compute exp_t for `u`."""

  def _internal_exp_t(u, t):
    return jnp.power(jnp.maximum(1.0 + (1.0 - t) * u, 0.0), 1.0 / (1.0 - t))

  return cond(
      jnp.abs(t - 1.0) < 1e-15, jnp.exp,
      functools.partial(_internal_exp_t, t=t), u)
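# Editor's note: the sketch below is an illustrative usage example added for
# clarity; it is not part of the original google-research source and the
# helper name is the editor's own.  It checks that exp_t inverts log_t for
# temperatures below, at, and above 1.0 (where both reduce to jnp.exp/jnp.log).
def _example_log_t_exp_t_roundtrip():
  u = jnp.array([0.25, 1.0, 4.0])
  for t in (0.5, 1.0, 1.5):
    assert jnp.allclose(exp_t(log_t(u, t), t), u, atol=1e-5)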


@jax.jit
def compute_normalization_fixed_point(activations, t, num_iters=5):
  """Returns the normalization value for each example (t > 1.0).

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature 2 (> 1.0 for tail heaviness).
    num_iters: Number of iterations to run the method.

  Returns:
    An array of the same rank as activations with the last dimension being 1.
  """

  mu = jnp.max(activations, -1, keepdims=True)
  normalized_activations_step_0 = activations - mu

  def cond_fun(carry):
    _, iters = carry
    return iters < num_iters

  def body_fun(carry):
    normalized_activations, iters = carry
    logt_partition = jnp.sum(
        exp_t(normalized_activations, t), -1, keepdims=True)
    normalized_activations_t = normalized_activations_step_0 * jnp.power(
        logt_partition, 1.0 - t)
    return normalized_activations_t, iters + 1

  normalized_activations_t, _ = while_loop(cond_fun, body_fun,
                                           (normalized_activations_step_0, 0))
  logt_partition = jnp.sum(
      exp_t(normalized_activations_t, t), -1, keepdims=True)
  return -log_t(1.0 / logt_partition, t) + mu


@jax.jit
def compute_normalization_binary_search(activations, t, num_iters=10):
  """Returns the normalization value for each example (t < 1.0).

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature 2 (< 1.0 for finite support).
    num_iters: Number of iterations to run the method.

  Returns:
    An array of the same rank as activations with the last dimension being 1.
  """
  mu = jnp.max(activations, -1, keepdims=True)
  normalized_activations = activations - mu
  shape_activations = activations.shape
  effective_dim = jnp.float32(
      jnp.sum(
          jnp.int32(normalized_activations > -1.0 / (1.0 - t)),
          -1,
          keepdims=True))
  shape_partition = list(shape_activations[:-1]) + [1]
  lower = jnp.zeros(shape_partition)
  upper = -log_t(1.0 / effective_dim, t) * jnp.ones(shape_partition)

  def cond_fun(carry):
    _, _, iters = carry
    return iters < num_iters

  def body_fun(carry):
    lower, upper, iters = carry
    logt_partition = (upper + lower) / 2.0
    sum_probs = jnp.sum(
        exp_t(normalized_activations - logt_partition, t), -1, keepdims=True)
    update = jnp.float32(sum_probs < 1.0)
    lower = jnp.reshape(lower * update + (1.0 - update) * logt_partition,
                        shape_partition)
    upper = jnp.reshape(upper * (1.0 - update) + update * logt_partition,
                        shape_partition)
    return lower, upper, iters + 1

  lower = jnp.zeros(shape_partition)
  upper = -log_t(1.0 / effective_dim, t) * jnp.ones(shape_partition)
  lower, upper, _ = while_loop(cond_fun, body_fun, (lower, upper, 0))

  logt_partition = (upper + lower) / 2.0
  return logt_partition + mu
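# Editor's note: an illustrative sketch added by the editor (not part of the
# original source; the helper name is hypothetical).  For t < 1.0 the tempered
# exponential has finite support, so a class whose activation lies far below
# the normalizer receives exactly zero probability while the rest still sum
# to roughly one.
def _example_finite_support():
  activations = jnp.array([[-5.0, 0.0, 1.0]])
  normalization = compute_normalization_binary_search(activations, 0.5, 30)
  probabilities = exp_t(activations - normalization, 0.5)
  assert probabilities[0, 0] == 0.0
  assert jnp.allclose(jnp.sum(probabilities, -1), 1.0, atol=1e-4)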


@jax.jit
def compute_tempered_normalization(activations, t, num_iters=5):
  """Returns the tempered normalization value for each example (t != 1.0)."""
  return cond(
      t < 1.0,
      functools.partial(
          compute_normalization_binary_search, t=t, num_iters=num_iters),
      functools.partial(
          compute_normalization_fixed_point, t=t, num_iters=num_iters),
      activations)


@jax.jit
def compute_normalization(activations, t, num_iters=5):
  """Returns the normalization value for each example.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature 2 (< 1.0 for finite support, > 1.0 for tail heaviness).
    num_iters: Number of iterations to run the method.

  Returns:
    An array of the same rank as activations with the last dimension being 1.
  """
  return cond(
      jnp.abs(t - 1.0) < 1e-15,
      functools.partial(logsumexp, axis=-1, keepdims=True),
      functools.partial(
          compute_tempered_normalization, t=t, num_iters=num_iters),
      activations)
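# Editor's note: an illustrative sketch added by the editor (not part of the
# original source; the helper name is hypothetical).  Subtracting the returned
# normalization makes the tempered exponentials sum to roughly one along the
# class axis, for t below, at, and above 1.0.
def _example_compute_normalization():
  activations = jnp.array([[-0.5, 0.1, 2.0], [0.3, 0.3, 0.3]])
  for t in (0.5, 1.0, 1.5):
    normalization = compute_normalization(activations, t, 20)
    probabilities = exp_t(activations - normalization, t)
    assert jnp.allclose(jnp.sum(probabilities, -1), 1.0, atol=1e-3)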


@jax.jit
def tempered_sigmoid(activations, t, num_iters=5):
  """Tempered sigmoid function.

  Args:
    activations: Activations for the positive class for binary classification.
    t: Temperature array > 0.0.
    num_iters: Number of iterations to run the method.

  Returns:
    A probabilities array.
  """
  input_shape = activations.shape
  activations_2d = jnp.reshape(activations, [-1, 1])
  internal_activations = jnp.concatenate(
      [jnp.zeros_like(activations_2d), activations_2d], 1)
  internal_probabilities = tempered_softmax(internal_activations, t, num_iters)
  one_class_probabilities = internal_probabilities[:, 1]
  return jnp.reshape(one_class_probabilities, input_shape)
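# Editor's note: an illustrative sketch added by the editor (not part of the
# original source; the helper name is hypothetical).  At t == 1.0 the tempered
# sigmoid coincides with the standard logistic sigmoid.
def _example_tempered_sigmoid():
  activations = jnp.array([-2.0, 0.0, 3.0])
  probabilities = tempered_sigmoid(activations, 1.0)
  assert jnp.allclose(probabilities, jax.nn.sigmoid(activations), atol=1e-5)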


@jax.custom_vjp
def tempered_softmax(activations, t, num_iters=5):
  """Tempered softmax function with custom gradient.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature array > 0.0.
    num_iters: Number of iterations to run the method.

  Returns:
    A probabilities array.
  """
  probabilities, _ = tempered_softmax_fwd(activations, t, num_iters)
  return probabilities


@jax.jit
def tempered_softmax_fwd(activations, t, num_iters=5):
  """Forward pass function for tempered softmax function.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    t: Temperature array > 0.0.
    num_iters: Number of iterations to run the method.

  Returns:
    A probabilities array, residuals.
  """
  activations = jnp.asarray(activations, dtype=float)
  normalization_constants = compute_normalization(activations, t, num_iters)
  probabilities = exp_t(activations - normalization_constants, t)
  return probabilities, (probabilities, t)


@jax.jit
def tempered_softmax_bwd(res, d_softmax):
  """Backward pass function for tempered softmax function.

  Args:
    res: Residuals.
    d_softmax: Differential.

  Returns:
    Derivatives.
  """
  probabilities, t = res
  probabilities_pow_t = jnp.power(probabilities, t)
  escorts = probabilities_pow_t / jnp.sum(
      probabilities_pow_t, -1, keepdims=True)
  derivative = probabilities_pow_t * (1. - escorts)
  return (jnp.multiply(d_softmax, derivative), None, None)


tempered_softmax.defvjp(tempered_softmax_fwd, tempered_softmax_bwd)
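# Editor's note: an illustrative sketch added by the editor (not part of the
# original source; the helper name is hypothetical).  Rows of the tempered
# softmax sum to roughly one for t != 1.0 (given enough normalizer
# iterations), and t == 1.0 recovers jax.nn.softmax.
def _example_tempered_softmax():
  activations = jnp.array([[-1.0, 0.0, 2.5]])
  probabilities = tempered_softmax(activations, 1.5, 20)
  assert jnp.allclose(jnp.sum(probabilities, -1), 1.0, atol=1e-3)
  assert jnp.allclose(
      tempered_softmax(activations, 1.0, 5), jax.nn.softmax(activations),
      atol=1e-5)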


def _internal_bi_tempered_logistic_loss(activations, labels, t1, t2):
  """Computes the Bi-Tempered logistic loss.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    labels: An array with shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness).

  Returns:
    A loss array for robust loss.
  """
  normalization_constants = compute_normalization(activations, t2, num_iters=5)
  if t2 == 1.0:
    if t1 == 1.0:
      return normalization_constants + jnp.sum(
          jnp.multiply(labels,
                       jnp.log(labels + 1e-10) - activations), -1)
    else:
      shifted_activations = jnp.exp(activations - normalization_constants)
      one_minus_t1 = (1.0 - t1)
      one_minus_t2 = 1.0
  else:
    one_minus_t1 = (1.0 - t1)
    one_minus_t2 = (1.0 - t2)
    shifted_activations = jnp.maximum(
        1.0 + one_minus_t2 * (activations - normalization_constants), 0.0)

  if t1 == 1.0:
    return jnp.sum(
        jnp.multiply(
            jnp.log(labels + 1e-10) -
            jnp.log(jnp.power(shifted_activations, 1.0 / one_minus_t2)),
            labels), -1)
  else:
    beta = 1.0 + one_minus_t1
    logt_probs = (jnp.power(shifted_activations, one_minus_t1 / one_minus_t2) -
                  1.0) / one_minus_t1
    return jnp.sum(
        jnp.multiply(log_t(labels, t1) - logt_probs, labels) - 1.0 / beta *
        (jnp.power(labels, beta) -
         jnp.power(shifted_activations, beta / one_minus_t2)), -1)


@jax.custom_vjp
def bi_tempered_logistic_loss(activations,
                              labels,
                              t1,
                              t2,
                              label_smoothing=0.0,
                              num_iters=5):
  """Bi-Tempered Logistic Loss with custom gradient.

  Args:
    activations: A multi-dimensional array with last dimension `num_classes`.
    labels: An array with shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    label_smoothing: Label smoothing parameter between [0, 1).
    num_iters: Number of iterations to run the method.

  Returns:
    A loss array.
  """
  activations = jnp.asarray(activations, dtype=float)
  labels = jnp.asarray(labels, dtype=float)
  loss_values, _ = bi_tempered_logistic_loss_fwd(activations, labels, t1, t2,
                                                 label_smoothing, num_iters)
  return loss_values
323

324

325
@jax.jit
326
def bi_tempered_logistic_loss_fwd(activations,
327
                                  labels,
328
                                  t1,
329
                                  t2,
330
                                  label_smoothing=0.0,
331
                                  num_iters=5):
332
  """Forward pass function for bi-tempered logistic loss.
333

334
  Args:
335
    activations: A multi-dimensional array with last dimension `num_classes`.
336
    labels: An array with shape and dtype as activations.
337
    t1: Temperature 1 (< 1.0 for boundedness).
338
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
339
    label_smoothing: Label smoothing parameter between [0, 1).
340
    num_iters: Number of iterations to run the method.
341

342
  Returns:
343
    A loss array, residuals.
344
  """
345
  num_classes = jnp.int32(labels.shape[-1])
346
  labels = cond(
347
      label_smoothing > 0.0,
348
      lambda u:  # pylint: disable=g-long-lambda
349
      (1 - num_classes /
350
       (num_classes - 1) * label_smoothing) * u + label_smoothing /
351
      (num_classes - 1),
352
      lambda u: u,
353
      labels)
354
  probabilities = tempered_softmax(activations, t2, num_iters)
355

356
  def _tempred_cross_entropy_loss(unused_activations):
357
    loss_values = jnp.multiply(
358
        labels,
359
        log_t(labels + 1e-10, t1) -
360
        log_t(probabilities, t1)) - 1.0 / (2.0 - t1) * (
361
            jnp.power(labels, 2.0 - t1) - jnp.power(probabilities, 2.0 - t1))
362
    loss_values = jnp.sum(loss_values, -1)
363
    return loss_values
364

365
  loss_values = cond(
366
      jnp.logical_and(
367
          jnp.less(jnp.abs(t1 - 1.0), 1e-15),
368
          jnp.less(jnp.abs(t2 - 1.0), 1e-15)),
369
      functools.partial(_cross_entropy_loss, labels=labels),
370
      _tempred_cross_entropy_loss,
371
      activations)
372
  return loss_values, (labels, t1, t2, probabilities)
373

374

375
@jax.jit
376
def bi_tempered_logistic_loss_bwd(res, d_loss):
377
  """Backward pass function for bi-tempered logistic loss.
378

379
  Args:
380
    res: Residuals.
381
    d_loss: Differential.
382

383
  Returns:
384
    Derivatives.
385
  """
386
  labels, t1, t2, probabilities = res
387
  delta_probs = probabilities - labels
388
  forget_factor = jnp.power(probabilities, t2 - t1)
389
  delta_probs_times_forget_factor = jnp.multiply(delta_probs, forget_factor)
390
  delta_forget_sum = jnp.sum(
391
      delta_probs_times_forget_factor, -1, keepdims=True)
392
  escorts = jnp.power(probabilities, t2)
393
  escorts = escorts / jnp.sum(escorts, -1, keepdims=True)
394
  derivative = delta_probs_times_forget_factor - jnp.multiply(
395
      escorts, delta_forget_sum)
396
  if len(d_loss.shape) < len(derivative.shape):
397
    d_loss = jnp.expand_dims(d_loss, -1)
398
  return (jnp.multiply(d_loss, derivative), None, None, None, None, None)
399

400

401
bi_tempered_logistic_loss.defvjp(
402
    bi_tempered_logistic_loss_fwd, bi_tempered_logistic_loss_bwd)
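# Editor's note: an illustrative sketch added by the editor (not part of the
# original source; the helper name is hypothetical).  Typical multi-class use
# with one-hot labels, including a gradient through the custom VJP above.
def _example_bi_tempered_logistic_loss():
  activations = jnp.array([[-0.5, 0.1, 2.0], [0.1, 1.5, -5.0]])
  labels = jnp.array([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
  loss = bi_tempered_logistic_loss(activations, labels, 0.7, 1.3, 0.0, 5)
  grads = jax.grad(lambda a: jnp.sum(
      bi_tempered_logistic_loss(a, labels, 0.7, 1.3, 0.0, 5)))(activations)
  return loss, grads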


def bi_tempered_binary_logistic_loss(activations,
                                     labels,
                                     t1,
                                     t2,
                                     label_smoothing=0.0,
                                     num_iters=5):
  """Bi-Tempered binary logistic loss.

  Args:
    activations: An array containing activations for class 1.
    labels: An array with shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    label_smoothing: Label smoothing parameter between [0, 1).
    num_iters: Number of iterations to run the method.

  Returns:
    A loss array.
  """
  out_shape = labels.shape
  labels_2d = jnp.reshape(labels, [-1, 1])
  activations_2d = jnp.reshape(activations, [-1, 1])
  internal_labels = jnp.concatenate([1.0 - labels_2d, labels_2d], 1)
  internal_logits = jnp.concatenate(
      [jnp.zeros_like(activations_2d), activations_2d], 1)
  losses = bi_tempered_logistic_loss(internal_logits, internal_labels, t1, t2,
                                     label_smoothing, num_iters)
  return jnp.reshape(losses, out_shape)
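# Editor's note: an illustrative sketch added by the editor (not part of the
# original source; the helper name is hypothetical).  Binary use: activations
# are logits for the positive class and labels are its target probabilities.
def _example_bi_tempered_binary_logistic_loss():
  activations = jnp.array([-1.0, 0.5, 4.0])
  labels = jnp.array([0.0, 1.0, 1.0])
  return bi_tempered_binary_logistic_loss(activations, labels, 0.7, 1.3)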