# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Implements the general form of the loss.

This is the simplest way of using this loss. No parameters will be tuned
automatically; it's just a simple function that takes in parameters (likely
hand-tuned ones) and returns a loss. For an adaptive loss, look at adaptive.py
or distribution.py.
"""

import jax
import jax.numpy as jnp


@jax.custom_jvp
def fake_clip(a, a_min, a_max):
  """jnp.clip() but the gradient doesn't get clipped on the backward pass."""
  return jnp.clip(a, a_min, a_max)

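# The custom JVP below makes the backward pass an identity: for example,
# jax.grad(lambda a: fake_clip(a, -1.0, 1.0))(5.0) evaluates to 1.0, whereas
# the gradient of jnp.clip at the same point would be 0.0, since the clipped
# output is constant outside [-1, 1].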
@fake_clip.defjvp
def fake_clip_jvp(primals, tangents):
  """Override fake_clip()'s gradient so that it's a no-op."""
  return jnp.clip(*primals), tangents[0]


@jax.jit
def lossfun(x, alpha, scale):
  r"""Implements the general form of the loss.

  This implements the rho(x, \alpha, c) function described in "A General and
  Adaptive Robust Loss Function", Jonathan T. Barron,
  https://arxiv.org/abs/1701.03077.

  Args:
    x: The residual for which the loss is being computed. x can have any shape,
      and alpha and scale will be broadcasted to match x's shape if necessary.
    alpha: The shape parameter of the loss (\alpha in the paper), where more
      negative values produce a loss with more robust behavior (outliers "cost"
      less), and more positive values produce a loss with less robust behavior
      (outliers are penalized more heavily). Alpha can be any value in
      [-infinity, infinity], but the gradient of the loss with respect to alpha
      is 0 at -infinity, infinity, 0, and 2. Varying alpha allows for smooth
      interpolation between several discrete robust losses:
        alpha=-Infinity: Welsch/Leclerc loss.
        alpha=-2: Geman-McClure loss.
        alpha=0: Cauchy/Lorentzian loss.
        alpha=1: Charbonnier/pseudo-Huber loss.
        alpha=2: L2 loss.
    scale: The scale parameter of the loss. When |x| < scale, the loss is an
      L2-like quadratic bowl, and when |x| > scale the loss function takes on a
      different shape according to alpha.

  Returns:
    The losses for each element of x, in the same shape as x.
  """
  eps = jnp.finfo(jnp.float32).eps
  maxval = 1e15

  # A "safe" version of expm1 that will not NaN-out on large inputs.
  expm1_safe = lambda x: jnp.expm1(jnp.minimum(x, 43))

  # `scale` must be > 0.
  scale = jnp.maximum(eps, scale)

  # Large values of |x| can cause non-finite gradients.
  x = fake_clip(x, -maxval, maxval)

  # The loss when alpha == 2. This will get reused repeatedly.
  loss_two = 0.5 * (x / scale)**2

  # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by.
  a = jnp.where(alpha >= 0, jnp.ones_like(alpha),
                -jnp.ones_like(alpha)) * jnp.maximum(eps, jnp.abs(alpha))

  # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by.
  b = jnp.maximum(eps, jnp.abs(a - 2))

  # The loss when not in one of the special cases.
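  # In the paper's notation this is
  #   rho(x, alpha, c) = (|alpha - 2| / alpha)
  #                      * (((x / c)**2 / |alpha - 2| + 1)**(alpha / 2) - 1),
  # with c = scale, and with a and b standing in for the epsilon-clamped
  # alpha and |alpha - 2| computed above.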
  loss_ow = (b / a) * ((loss_two / (0.5 * b) + 1)**(0.5 * a) - 1)

  # Select which of the cases of the loss to return as a function of alpha.
  return jnp.where(
      alpha == -jnp.inf, -expm1_safe(-loss_two),
      jnp.where(
          alpha == 0, jnp.log1p(loss_two),
          jnp.where(alpha == 2, loss_two,
                    jnp.where(alpha == jnp.inf, expm1_safe(loss_two),
                              loss_ow))))

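# A minimal usage sketch (illustrative only): evaluate the loss for a few
# residuals at several hand-picked settings of alpha, and take a gradient with
# respect to the residual. The specific alpha and scale values are arbitrary.
if __name__ == "__main__":
  x = jnp.linspace(-4.0, 4.0, 9)
  for alpha in (-2.0, 0.0, 1.0, 2.0):
    # alpha and scale are broadcast against x, so scalars are fine here.
    print("alpha =", alpha, lossfun(x, alpha, 1.0))
  # Gradient of the loss with respect to a scalar residual, using the
  # Charbonnier setting (alpha = 1) and unit scale.
  print(jax.grad(lambda xi: lossfun(xi, 1.0, 1.0))(3.0))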