# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for general.py."""

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
from jax import random
import jax.numpy as jnp
import numpy as np

from robust_loss_jax import general
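
# Reference note: general.lossfun(x, alpha, scale), exercised by the tests
# below, follows the general robust loss of Barron, "A General and Adaptive
# Robust Loss Function" (CVPR 2019). Away from the special cases at
# alpha in {-inf, 0, 2, +inf} (defined as limits), that loss has the form
#
#   lossfun(x, alpha, scale) = (|alpha - 2| / alpha) *
#       (((x / scale)**2 / |alpha - 2| + 1)**(alpha / 2) - 1),
#
# which is consistent with the special-case tests in this file
# (alpha in {-inf, -2, 0, 1, 2, 4, +inf}). A minimal usage sketch (argument
# order as used throughout these tests; scalar alpha and scale broadcast
# against an array-valued x; the example values are illustrative only):
#
#   x = jnp.linspace(-5., 5., 11)
#   loss = general.lossfun(x, 1.0, 2.0)  # alpha == 1: Charbonnier-like loss.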


class LossfunTest(chex.TestCase, parameterized.TestCase):

  def _precompute_lossfun_inputs(self):
    """Precompute a loss and its derivatives for random inputs and parameters.

    Generates a large number of random inputs to the loss, and random
    shape/scale parameters for the loss function at each sample, and
    computes the loss and its derivatives with respect to all inputs and
    parameters, returning everything to be used to assert various properties
    in our unit tests.

    Returns:
      A tuple containing:
       (The number of samples (int), also the length of all following arrays,
        A tensor of losses for each sample,
        A tensor of residuals of each sample (the loss inputs),
        A tensor of shape parameters of each loss,
        A tensor of scale parameters of each loss,
        A tensor of derivatives of each loss wrt each x,
        A tensor of derivatives of each loss wrt each alpha,
        A tensor of derivatives of each loss wrt each scale)

    Typical usage example:
      (num_samples, loss, x, alpha, scale, d_x, d_alpha,
       d_scale) = self._precompute_lossfun_inputs()
    """
    num_samples = 100000
    rng = random.PRNGKey(0)

    # Normally distributed inputs.
    rng, key = random.split(rng)
    x = random.normal(key, shape=[num_samples])

    # Uniformly distributed values in (-16, 3), quantized to the nearest 0.1
    # to ensure that we hit the special cases at 0, 2.
    rng, key = random.split(rng)
    alpha = jnp.round(
        random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
        10) / 10.
    # Push the sampled alphas at the extents of the range to +/- infinity, so
    # that we probe those cases too.
    alpha = jnp.where(alpha == 3, jnp.inf, alpha)
    alpha = jnp.where(alpha == -16, -jnp.inf, alpha)

    # Random log-normally distributed values in approx (1e-5, 100000):
    rng, key = random.split(rng)
    scale = jnp.exp(random.normal(key, shape=[num_samples]) * 4.) + 1e-5

    fn = self.variant(general.lossfun)
    loss = fn(x, alpha, scale)
    d_x, d_alpha, d_scale = (
        jax.grad(lambda x, a, s: jnp.sum(fn(x, a, s)), [0, 1, 2])(x, alpha,
                                                                  scale))

    return (num_samples, loss, x, alpha, scale, d_x, d_alpha, d_scale)

  @chex.all_variants()
  def testDerivativeIsMonotonicWrtX(self):
    # Check that the loss increases monotonically with |x|.
    _, _, x, alpha, _, d_x, _, _ = self._precompute_lossfun_inputs()
    # This is just to suppress a warning below.
    d_x = jnp.where(jnp.isfinite(d_x), d_x, jnp.zeros_like(d_x))
    mask = jnp.isfinite(alpha) & (
        jnp.abs(d_x) > (300. * jnp.finfo(jnp.float32).eps))
    chex.assert_trees_all_close(jnp.sign(d_x[mask]), jnp.sign(x[mask]))

  @chex.all_variants()
  def testLossIsNearZeroAtOrigin(self):
    # Check that the loss is near-zero when x is near-zero.
    _, loss, x, _, _, _, _, _ = self._precompute_lossfun_inputs()
    loss_near_zero = loss[jnp.abs(x) < 1e-5]
    chex.assert_trees_all_close(
        loss_near_zero, jnp.zeros_like(loss_near_zero), atol=1e-5)

  @chex.all_variants()
  def testLossIsQuadraticNearOrigin(self):
    # Check that the loss is well-approximated by a quadratic bowl when
    # |x| < 0.5 * scale.
    _, loss, x, _, scale, _, _, _ = self._precompute_lossfun_inputs()
    mask = jnp.abs(x) < (0.5 * scale)
    loss_quad = 0.5 * jnp.square(x / scale)
    chex.assert_trees_all_close(
        loss_quad[mask], loss[mask], rtol=1e-5, atol=1e-2)

  @chex.all_variants()
  def testLossIsBoundedWhenAlphaIsNegative(self):
    # Assert that loss < (alpha - 2)/alpha when alpha < 0.
    _, loss, _, alpha, _, _, _, _ = self._precompute_lossfun_inputs()
    mask = alpha < 0.
    min_val = jnp.finfo(jnp.float32).min
    alpha_clipped = jnp.maximum(min_val, alpha[mask])
    self.assertTrue(
        jnp.all(loss[mask] <= ((alpha_clipped - 2.) / alpha_clipped)))

  @chex.all_variants()
  def testDerivativeIsBoundedWhenAlphaIsBelow2(self):
    # Assert that |d_x| < |x|/scale^2 when alpha <= 2.
    _, _, x, alpha, scale, d_x, _, _ = self._precompute_lossfun_inputs()
    mask = jnp.isfinite(alpha) & (alpha <= 2)
    grad = jnp.abs(d_x[mask])
    bound = ((jnp.abs(x[mask]) + (300. * jnp.finfo(jnp.float32).eps)) /
             scale[mask]**2)
    self.assertTrue(jnp.all(grad <= bound))

  @chex.all_variants()
  def testDerivativeIsBoundedWhenAlphaIsBelow1(self):
    # Assert that |d_x| < 1/scale when alpha <= 1.
    _, _, _, alpha, scale, d_x, _, _ = self._precompute_lossfun_inputs()
    mask = jnp.isfinite(alpha) & (alpha <= 1)
    grad = jnp.abs(d_x[mask])
    bound = ((1. + (300. * jnp.finfo(jnp.float32).eps)) / scale[mask])
    self.assertTrue(jnp.all(grad <= bound))

  @chex.all_variants()
  def testAlphaDerivativeIsPositive(self):
    # Assert that d_loss / d_alpha > 0.
    _, _, _, alpha, _, _, d_alpha, _ = self._precompute_lossfun_inputs()
    mask = jnp.isfinite(alpha)
    self.assertTrue(
        jnp.all(d_alpha[mask] > (-300. * jnp.finfo(jnp.float32).eps)))

  @chex.all_variants()
  def testScaleDerivativeIsNegative(self):
    # Assert that d_loss / d_scale < 0.
    _, _, _, alpha, _, _, _, d_scale = self._precompute_lossfun_inputs()
    mask = jnp.isfinite(alpha)
    self.assertTrue(
        jnp.all(d_scale[mask] < (300. * jnp.finfo(jnp.float32).eps)))
157

158
  @chex.all_variants()
159
  def testLossIsScaleInvariant(self):
160
    # Check that loss(mult * x, alpha, mult * scale) == loss(x, alpha, scale)
161
    (num_samples, loss, x, alpha, scale, _, _, _) = (
162
        self._precompute_lossfun_inputs())
163
    # Random log-normally distributed scalings in ~(0.2, 20)
164

165
    rng = random.PRNGKey(0)
166
    mult = jnp.maximum(0.2, jnp.exp(random.normal(rng, shape=[num_samples])))
167

168
    # Compute the scaled loss.
169
    loss_scaled = general.lossfun(mult * x, alpha, mult * scale)
170
    chex.assert_trees_all_close(loss, loss_scaled, atol=1e-4, rtol=1e-4)
171

172
  @chex.all_variants()
173
  def testAlphaEqualsNegativeInfinity(self):
174
    # Check that alpha == -Infinity reproduces Welsch aka Leclerc loss.
175
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
176
    alpha = -float('inf')
177
    scale = 1.7
178

179
    # Our loss.
180
    loss = self.variant(general.lossfun)(x, alpha, scale)
181

182
    # Welsch/Leclerc loss.
183
    loss_true = (1. - np.exp(-0.5 * np.square(x / scale)))
184

185
    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsNegativeTwo(self):
    # Check that alpha == -2 reproduces Geman-McClure loss.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = -2.
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # Geman-McClure loss.
    loss_true = (2. * np.square(x / scale) / (np.square(x / scale) + 4.))

    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsZero(self):
    # Check that alpha == 0 reproduces Cauchy aka Lorentzian loss.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = 0.
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # Cauchy/Lorentzian loss.
    loss_true = (np.log(0.5 * np.square(x / scale) + 1))

    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsOne(self):
    # Check that alpha == 1 reproduces Charbonnier aka pseudo-Huber loss.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = 1.
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # Charbonnier loss.
    loss_true = (np.sqrt(np.square(x / scale) + 1) - 1)

    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsTwo(self):
    # Check that alpha == 2 reproduces L2 loss.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = 2.
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # L2 Loss.
    loss_true = 0.5 * np.square(x / scale)

    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsFour(self):
    # Check that alpha == 4 reproduces a quartic.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = 4.
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # The true loss.
    loss_true = np.square(np.square(x / scale)) / 8 + np.square(x / scale) / 2
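    # Sanity-check derivation, assuming the general form noted after the
    # imports: with alpha = 4 and z = x / scale,
    #   (|4 - 2| / 4) * ((z**2 / |4 - 2| + 1)**(4 / 2) - 1)
    #       = 0.5 * ((z**2 / 2 + 1)**2 - 1)
    #       = z**4 / 8 + z**2 / 2,
    # which matches loss_true above.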

    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsInfinity(self):
    # Check that alpha == Infinity takes the correct form.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = float('inf')
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # The true loss.
    loss_true = (jnp.exp(0.5 * jnp.square(x / scale)) - 1.)

    chex.assert_trees_all_close(loss, loss_true, atol=1e-4, rtol=1e-4)

  @chex.all_variants()
  def testLossAndGradientsAreFinite(self):
    # Test that the loss and its approximation both give finite losses and
    # derivatives everywhere that they should for a wide range of values.
    num_samples = 100000
    rng = random.PRNGKey(0)

    # Normally distributed inputs.
    rng, key = random.split(rng)
    x = random.normal(key, shape=[num_samples])

    # Uniformly distributed values in (-16, 3), quantized to the nearest 0.1
    # to ensure that we hit the special cases at 0, 2.
    rng, key = random.split(rng)
    alpha = jnp.round(
        random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
        10) / 10.

    # Random log-normally distributed values in approx (1e-5, 100000):
    rng, key = random.split(rng)
    scale = jnp.exp(random.normal(key, shape=[num_samples]) * 4.) + 1e-5

    fn = self.variant(general.lossfun)
    loss = fn(x, alpha, scale)
    d_x, d_alpha, d_scale = (
        jax.grad(lambda x, a, s: jnp.sum(fn(x, a, s)), [0, 1, 2])(x, alpha,
                                                                  scale))

    for v in [loss, d_x, d_alpha, d_scale]:
      chex.assert_tree_all_finite(v)

  @chex.all_variants()
  def testGradientMatchesFiniteDifferences(self):
    # Test that the loss and its approximation both return gradients that are
    # close to the numerical gradient from finite differences, with forward
    # differencing. Returning correct gradients is JAX's job, so this is
    # just an aggressive sanity check.
    num_samples = 100000
    rng = random.PRNGKey(0)

    # Normally distributed inputs.
    rng, key = random.split(rng)
    x = random.normal(key, shape=[num_samples])

    # Uniformly distributed values in (-16, 3), quantized to the nearest
    # 0.1 and then shifted by 0.05 so that we avoid the special cases at
    # 0 and 2 where the analytical gradient won't match finite differences.
    rng, key = random.split(rng)
    alpha = jnp.round(
        random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
        10) / 10. + 0.05

    # Uniformly distributed scale values in (0.5, 1.5):
    rng, key = random.split(rng)
    scale = random.uniform(key, shape=[num_samples], minval=0.5, maxval=1.5)

    loss = general.lossfun(x, alpha, scale)
    d_x, d_alpha, d_scale = (
        jax.grad(lambda x, a, s: jnp.sum(general.lossfun(x, a, s)),
                 [0, 1, 2])(x, alpha, scale))

    step_size = 1e-3
    fn = self.variant(general.lossfun)
    n_x = (fn(x + step_size, alpha, scale) - loss) / step_size
    n_alpha = (fn(x, alpha + step_size, scale) - loss) / step_size
    n_scale = (fn(x, alpha, scale + step_size) - loss) / step_size

    chex.assert_trees_all_close(n_x, d_x, atol=1e-2, rtol=1e-2)
    chex.assert_trees_all_close(n_alpha, d_alpha, atol=1e-2, rtol=1e-2)
    chex.assert_trees_all_close(n_scale, d_scale, atol=1e-2, rtol=1e-2)

  @chex.all_variants()
  @parameterized.parameters((-2,), (-1,), (0,), (1,), (2,))
  def testGradientsAreFiniteWithAllInputs(self, alpha):
    x_half = jnp.concatenate(
        [jnp.exp(jnp.linspace(-80, 80, 1001)),
         jnp.array([jnp.inf])])
    x = jnp.concatenate([-x_half[::-1], jnp.array([0.]), x_half])
    scale = jnp.full_like(x, 1.)

    fn = self.variant(lambda x, s: general.lossfun(x, alpha, s))
    loss = fn(x, scale)
    d_x, d_scale = jax.vmap(jax.grad(fn, [0, 1]))(x, scale)

    for v in [loss, d_x, d_scale]:
      chex.assert_tree_all_finite(v)


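# Running this module directly executes every test method above via
# absltest.main(). A minimal invocation sketch (assumes absl-py, chex, jax,
# numpy, and the robust_loss_jax package are installed; the file name below
# is a placeholder for this test file):
#
#   python general_test.py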
if __name__ == '__main__':
  absltest.main()