# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16r"""Implements the general form of the loss.
17
18This is the simplest way of using this loss. No parameters will be tuned
19automatically, it's just a simple function that takes in parameters (likely
20hand-tuned ones) and return a loss. For an adaptive loss, look at adaptive.py
21or distribution.py.
22"""
23
import jax
import jax.numpy as jnp


@jax.custom_jvp
def fake_clip(a, a_min, a_max):
  """jnp.clip() but the gradient doesn't get clipped on the backward pass."""
  return jnp.clip(a, a_min, a_max)


@fake_clip.defjvp
def fake_clip_jvp(primals, tangents):
  """Override fake_clip()'s gradient so that it's a no-op."""
  return jnp.clip(*primals), tangents[0]
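

# A quick illustration of the straight-through behavior above (values are
# easy to verify by hand): jnp.clip() has zero gradient outside the clipping
# bounds, while fake_clip() passes the tangent through unchanged.
#
#   jax.grad(lambda v: jnp.clip(v, -1., 1.))(5.)   # == 0., gradient clipped.
#   jax.grad(lambda v: fake_clip(v, -1., 1.))(5.)  # == 1., passes through.

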
@jax.jit
def lossfun(x, alpha, scale):
  r"""Implements the general form of the loss.

  This implements the rho(x, \alpha, c) function described in "A General and
  Adaptive Robust Loss Function", Jonathan T. Barron,
  https://arxiv.org/abs/1701.03077.

  Args:
    x: The residual for which the loss is being computed. x can have any shape,
      and alpha and scale will be broadcasted to match x's shape if necessary.
    alpha: The shape parameter of the loss (\alpha in the paper), where more
      negative values produce a loss with more robust behavior (outliers "cost"
      less), and more positive values produce a loss with less robust behavior
      (outliers are penalized more heavily). Alpha can be any value in
      [-infinity, infinity], but the gradient of the loss with respect to alpha
      is 0 at -infinity, infinity, 0, and 2. Varying alpha allows for smooth
      interpolation between several discrete robust losses:
        alpha=-Infinity: Welsch/Leclerc loss.
        alpha=-2: Geman-McClure loss.
        alpha=0: Cauchy/Lorentzian loss.
        alpha=1: Charbonnier/pseudo-Huber loss.
        alpha=2: L2 loss.
    scale: The scale parameter of the loss. When |x| < scale, the loss is an
      L2-like quadratic bowl, and when |x| > scale the loss function takes on a
      different shape according to alpha.

  Returns:
    The losses for each element of x, in the same shape as x.
  """
  eps = jnp.finfo(jnp.float32).eps
  maxval = 1e15

  # A "safe" version of expm1 that will not NaN-out on large inputs.
  expm1_safe = lambda x: jnp.expm1(jnp.minimum(x, 43))

  # `scale` must be > 0.
  scale = jnp.maximum(eps, scale)

  # Large values of |x| can cause non-finite gradients.
  x = fake_clip(x, -maxval, maxval)

  # The loss when alpha == 2. This will get reused repeatedly.
  loss_two = 0.5 * (x / scale)**2

  # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by.
  a = jnp.where(alpha >= 0, jnp.ones_like(alpha),
                -jnp.ones_like(alpha)) * jnp.maximum(eps, jnp.abs(alpha))

  # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by.
  b = jnp.maximum(eps, jnp.abs(a - 2))

  # The loss when not in one of the special cases.
  loss_ow = (b / a) * ((loss_two / (0.5 * b) + 1)**(0.5 * a) - 1)
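  # For reference, loss_ow is the paper's general form,
  #   rho(x, alpha, c) = (|alpha - 2| / alpha)
  #                      * (((x / c)**2 / |alpha - 2| + 1)**(alpha / 2) - 1),
  # written in terms of loss_two = 0.5 * (x / c)**2 and the clamped a and b.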

  # Select which of the cases of the loss to return as a function of alpha.
  return jnp.where(
      alpha == -jnp.inf, -expm1_safe(-loss_two),
      jnp.where(
          alpha == 0, jnp.log1p(loss_two),
          jnp.where(alpha == 2, loss_two,
                    jnp.where(alpha == jnp.inf, expm1_safe(loss_two),
                              loss_ow))))
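

# A minimal usage sketch (a hypothetical driver, not part of the library):
# sweep a batch of residuals through the named special cases of alpha, then
# print the gradient at an enormous residual, the situation fake_clip() above
# guards against.
if __name__ == '__main__':
  xs = jnp.linspace(-4., 4., 9)
  for alpha_val in (-jnp.inf, -2., 0., 1., 2., jnp.inf):
    print(alpha_val, lossfun(xs, alpha_val, 1.))
  # Expected to be finite despite |x| >> maxval, since only the value of x is
  # clipped, not its gradient.
  print(jax.grad(lambda v: lossfun(v, 0., 1.))(1e20))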