# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for general.py."""

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
from jax import random
import jax.numpy as jnp
import numpy as np

from robust_loss_jax import general

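# For reference when reading the tests below: general.lossfun implements a
# single family of robust losses parameterized by a shape `alpha` and a
# `scale`. Away from the special cases handled separately inside
# general.lossfun (alpha in {0, 2, +/-inf}), the family exercised here should
# reduce to the following form (a sketch inferred from the named special cases
# asserted in the testAlphaEquals* tests, not taken from the implementation):
#
#   lossfun(x, alpha, scale) =
#       (abs(alpha - 2) / alpha) *
#       (((x / scale)**2 / abs(alpha - 2) + 1)**(alpha / 2) - 1)
#
# e.g. alpha = 2 gives L2, alpha = 1 gives Charbonnier, alpha = 0 gives
# Cauchy/Lorentzian, and alpha = -inf gives Welsch.
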
class LossfunTest(chex.TestCase, parameterized.TestCase):

  def _precompute_lossfun_inputs(self):
    """Precompute a loss and its derivatives for random inputs and parameters.

    Generates a large number of random inputs to the loss, and random
    shape/scale parameters for the loss function at each sample, and
    computes the loss and its derivative with respect to all inputs and
    parameters, returning everything to be used to assert various properties
    in our unit tests.

    Returns:
      A tuple containing:
       (the number (int) of samples, and the length of all following arrays,
        A tensor of losses for each sample,
        A tensor of residuals of each sample (the loss inputs),
        A tensor of shape parameters of each loss,
        A tensor of scale parameters of each loss,
        A tensor of derivatives of each loss wrt each x,
        A tensor of derivatives of each loss wrt each alpha,
        A tensor of derivatives of each loss wrt each scale)

    Typical usage example:
      (num_samples, loss, x, alpha, scale, d_x, d_alpha, d_scale)
          = self._precompute_lossfun_inputs()
    """
    num_samples = 100000
    rng = random.PRNGKey(0)

    # Normally distributed inputs.
    rng, key = random.split(rng)
    x = random.normal(key, shape=[num_samples])

    # Uniformly distributed values in (-16, 3), quantized to the nearest 0.1
    # to ensure that we hit the special cases at 0, 2.
    rng, key = random.split(rng)
    alpha = jnp.round(
        random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
        10) / 10.
    # Push the sampled alphas at the extents of the range to +/- infinity, so
    # that we probe those cases too.
    alpha = jnp.where(alpha == 3, jnp.inf, alpha)
    alpha = jnp.where(alpha == -16, -jnp.inf, alpha)

    # Random log-normally distributed values in approx (1e-5, 100000):
    rng, key = random.split(rng)
    scale = jnp.exp(random.normal(key, shape=[num_samples]) * 4.) + 1e-5

    fn = self.variant(general.lossfun)
    loss = fn(x, alpha, scale)
    d_x, d_alpha, d_scale = (
        jax.grad(lambda x, a, s: jnp.sum(fn(x, a, s)), [0, 1, 2])(x, alpha,
                                                                  scale))

    return (num_samples, loss, x, alpha, scale, d_x, d_alpha, d_scale)

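  # Note on the pattern above: general.lossfun is applied elementwise, so
  # differentiating the summed loss with jax.grad yields each sample's
  # derivative with respect to its own x, alpha, and scale. The assertions
  # below rely on these per-sample derivatives.
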
  @chex.all_variants()
  def testDerivativeIsMonotonicWrtX(self):
    # Check that the loss increases monotonically with |x|.
    _, _, x, alpha, _, d_x, _, _ = self._precompute_lossfun_inputs()
    # This is just to suppress a warning below.
    d_x = jnp.where(jnp.isfinite(d_x), d_x, jnp.zeros_like(d_x))
    mask = jnp.isfinite(alpha) & (
        jnp.abs(d_x) > (300. * jnp.finfo(jnp.float32).eps))
    chex.assert_trees_all_close(jnp.sign(d_x[mask]), jnp.sign(x[mask]))

  @chex.all_variants()
  def testLossIsNearZeroAtOrigin(self):
    # Check that the loss is near-zero when x is near-zero.
    _, loss, x, _, _, _, _, _ = self._precompute_lossfun_inputs()
    loss_near_zero = loss[jnp.abs(x) < 1e-5]
    chex.assert_trees_all_close(
        loss_near_zero, jnp.zeros_like(loss_near_zero), atol=1e-5)

  @chex.all_variants()
  def testLossIsQuadraticNearOrigin(self):
    # Check that the loss is well-approximated by a quadratic bowl when
    # |x| < scale.
    _, loss, x, _, scale, _, _, _ = self._precompute_lossfun_inputs()
    mask = jnp.abs(x) < (0.5 * scale)
    loss_quad = 0.5 * jnp.square(x / scale)
    chex.assert_trees_all_close(
        loss_quad[mask], loss[mask], rtol=1e-5, atol=1e-2)

  @chex.all_variants()
  def testLossIsBoundedWhenAlphaIsNegative(self):
    # Assert that loss < (alpha - 2)/alpha when alpha < 0.
    _, loss, _, alpha, _, _, _, _ = self._precompute_lossfun_inputs()
    mask = alpha < 0.
    min_val = jnp.finfo(jnp.float32).min
    alpha_clipped = jnp.maximum(min_val, alpha[mask])
    self.assertTrue(
        jnp.all(loss[mask] <= ((alpha_clipped - 2.) / alpha_clipped)))

  @chex.all_variants()
  def testDerivativeIsBoundedWhenAlphaIsBelow2(self):
    # Assert that |d_x| < |x|/scale^2 when alpha <= 2.
    _, _, x, alpha, scale, d_x, _, _ = self._precompute_lossfun_inputs()
    mask = jnp.isfinite(alpha) & (alpha <= 2)
    grad = jnp.abs(d_x[mask])
    bound = ((jnp.abs(x[mask]) + (300. * jnp.finfo(jnp.float32).eps)) /
             scale[mask]**2)
    self.assertTrue(jnp.all(grad <= bound))

  @chex.all_variants()
  def testDerivativeIsBoundedWhenAlphaIsBelow1(self):
    # Assert that |d_x| < 1/scale when alpha <= 1.
    _, _, _, alpha, scale, d_x, _, _ = self._precompute_lossfun_inputs()
    mask = jnp.isfinite(alpha) & (alpha <= 1)
    grad = jnp.abs(d_x[mask])
    bound = ((1. + (300. * jnp.finfo(jnp.float32).eps)) / scale[mask])
    self.assertTrue(jnp.all(grad <= bound))

  @chex.all_variants()
  def testAlphaDerivativeIsPositive(self):
    # Assert that d_loss / d_alpha > 0.
    _, _, _, alpha, _, _, d_alpha, _ = self._precompute_lossfun_inputs()
    mask = jnp.isfinite(alpha)
    self.assertTrue(
        jnp.all(d_alpha[mask] > (-300. * jnp.finfo(jnp.float32).eps)))

  @chex.all_variants()
  def testScaleDerivativeIsNegative(self):
    # Assert that d_loss / d_scale < 0.
    _, _, _, alpha, _, _, _, d_scale = self._precompute_lossfun_inputs()
    mask = jnp.isfinite(alpha)
    self.assertTrue(
        jnp.all(d_scale[mask] < (300. * jnp.finfo(jnp.float32).eps)))

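  # Why scale invariance is expected below: the loss depends on x only through
  # the ratio x / scale (see the special-case forms asserted further down), so
  # multiplying x and scale by the same factor should leave the loss unchanged.
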
  @chex.all_variants()
  def testLossIsScaleInvariant(self):
    # Check that loss(mult * x, alpha, mult * scale) == loss(x, alpha, scale)
    (num_samples, loss, x, alpha, scale, _, _, _) = (
        self._precompute_lossfun_inputs())

    # Random log-normally distributed scalings in ~(0.2, 20)
    rng = random.PRNGKey(0)
    mult = jnp.maximum(0.2, jnp.exp(random.normal(rng, shape=[num_samples])))

    # Compute the scaled loss.
    loss_scaled = general.lossfun(mult * x, alpha, mult * scale)
    chex.assert_trees_all_close(loss, loss_scaled, atol=1e-4, rtol=1e-4)

  @chex.all_variants()
  def testAlphaEqualsNegativeInfinity(self):
    # Check that alpha == -Infinity reproduces Welsch aka Leclerc loss.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = -float('inf')
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # Welsch/Leclerc loss.
    loss_true = (1. - np.exp(-0.5 * np.square(x / scale)))

    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsNegativeTwo(self):
    # Check that alpha == -2 reproduces Geman-McClure loss.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = -2.
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # Geman-McClure loss.
    loss_true = (2. * np.square(x / scale) / (np.square(x / scale) + 4.))

    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsZero(self):
    # Check that alpha == 0 reproduces Cauchy aka Lorentzian loss.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = 0.
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # Cauchy/Lorentzian loss.
    loss_true = (np.log(0.5 * np.square(x / scale) + 1))

    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsOne(self):
    # Check that alpha == 1 reproduces Charbonnier aka pseudo-Huber loss.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = 1.
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # Charbonnier loss.
    loss_true = (np.sqrt(np.square(x / scale) + 1) - 1)

    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsTwo(self):
    # Check that alpha == 2 reproduces L2 loss.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = 2.
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # L2 Loss.
    loss_true = 0.5 * np.square(x / scale)

    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsFour(self):
    # Check that alpha == 4 reproduces a quartic.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = 4.
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # The true loss.
    loss_true = np.square(np.square(x / scale)) / 8 + np.square(x / scale) / 2

    chex.assert_trees_all_close(loss, loss_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testAlphaEqualsInfinity(self):
    # Check that alpha == Infinity takes the correct form.
    x = np.linspace(-15, 15, 1000, dtype=np.float64)
    alpha = float('inf')
    scale = 1.7

    # Our loss.
    loss = self.variant(general.lossfun)(x, alpha, scale)

    # The true loss.
    loss_true = (jnp.exp(0.5 * jnp.square(x / scale)) - 1.)

    chex.assert_trees_all_close(loss, loss_true, atol=1e-4, rtol=1e-4)

  @chex.all_variants()
  def testLossAndGradientsAreFinite(self):
    # Test that the loss and its approximation both give finite losses and
    # derivatives everywhere that they should for a wide range of values.
    num_samples = 100000
    rng = random.PRNGKey(0)

    # Normally distributed inputs.
    rng, key = random.split(rng)
    x = random.normal(key, shape=[num_samples])

    # Uniformly distributed values in (-16, 3), quantized to the nearest 0.1
    # to ensure that we hit the special cases at 0, 2.
    rng, key = random.split(rng)
    alpha = jnp.round(
        random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
        10) / 10.

    # Random log-normally distributed values in approx (1e-5, 100000):
    rng, key = random.split(rng)
    scale = jnp.exp(random.normal(key, shape=[num_samples]) * 4.) + 1e-5

    fn = self.variant(general.lossfun)
    loss = fn(x, alpha, scale)
    d_x, d_alpha, d_scale = (
        jax.grad(lambda x, a, s: jnp.sum(fn(x, a, s)), [0, 1, 2])(x, alpha,
                                                                  scale))

    for v in [loss, d_x, d_alpha, d_scale]:
      chex.assert_tree_all_finite(v)

  @chex.all_variants()
  def testGradientMatchesFiniteDifferences(self):
    # Test that the loss and its approximation both return gradients that are
    # close to the numerical gradient from finite differences, with forward
    # differencing. Returning correct gradients is JAX's job, so this is
    # just an aggressive sanity check.
    num_samples = 100000
    rng = random.PRNGKey(0)

    # Normally distributed inputs.
    rng, key = random.split(rng)
    x = random.normal(key, shape=[num_samples])

    # Uniformly distributed values in (-16, 3), quantized to the nearest
    # 0.1 and then shifted by 0.05 so that we avoid the special cases at
    # 0 and 2 where the analytical gradient won't match finite differences.
    rng, key = random.split(rng)
    alpha = jnp.round(
        random.uniform(key, shape=[num_samples], minval=-16, maxval=3) *
        10) / 10. + 0.05

    # Uniformly distributed scales in [0.5, 1.5).
    rng, key = random.split(rng)
    scale = random.uniform(key, shape=[num_samples], minval=0.5, maxval=1.5)

    loss = general.lossfun(x, alpha, scale)
    d_x, d_alpha, d_scale = (
        jax.grad(lambda x, a, s: jnp.sum(general.lossfun(x, a, s)),
                 [0, 1, 2])(x, alpha, scale))

    step_size = 1e-3
    fn = self.variant(general.lossfun)
    n_x = (fn(x + step_size, alpha, scale) - loss) / step_size
    n_alpha = (fn(x, alpha + step_size, scale) - loss) / step_size
    n_scale = (fn(x, alpha, scale + step_size) - loss) / step_size

    chex.assert_trees_all_close(n_x, d_x, atol=1e-2, rtol=1e-2)
    chex.assert_trees_all_close(n_alpha, d_alpha, atol=1e-2, rtol=1e-2)
    chex.assert_trees_all_close(n_scale, d_scale, atol=1e-2, rtol=1e-2)

  @chex.all_variants()
  @parameterized.parameters((-2,), (-1,), (0,), (1,), (2,))
  def testGradientsAreFiniteWithAllInputs(self, alpha):
    x_half = jnp.concatenate(
        [jnp.exp(jnp.linspace(-80, 80, 1001)),
         jnp.array([jnp.inf])])
    x = jnp.concatenate([-x_half[::-1], jnp.array([0.]), x_half])
    scale = jnp.full_like(x, 1.)

    fn = self.variant(lambda x, s: general.lossfun(x, alpha, s))
    loss = fn(x, scale)
    d_x, d_scale = jax.vmap(jax.grad(fn, [0, 1]))(x, scale)

    for v in [loss, d_x, d_scale]:
      chex.assert_tree_all_finite(v)


if __name__ == '__main__':
  absltest.main()