# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for cubic_spline.py."""

from absl.testing import absltest
import chex
import jax
import jax.numpy as jnp
import jax.random as random
from robust_loss_jax import cubic_spline


class CubicSplineTest(chex.TestCase):
  """Tests for cubic_spline.interpolate1d, run under every chex variant."""

  def _setup_toy_data(self, n=32768):
    """Build `n` integer-spaced knots with random values and tangents.

    Args:
      n: The number of knots to generate.

    Returns:
      A tuple (x, values, tangents) where `x` holds the float32 query points
      0, 1, ..., n-1 (one per knot), and `values` and `tangents` are
      length-`n` standard-normal samples derived from PRNGKey(0).
    """
    x = jnp.float32(jnp.arange(n))
    rng = random.PRNGKey(0)
    rng, key = random.split(rng)
    values = random.normal(key, shape=[n])
    _, key = random.split(rng)
    tangents = random.normal(key, shape=[n])
    return x, values, tangents

  def _interpolate1d(self, x, values, tangents):
    """Compute interpolate1d(x, values, tangents) and its derivative.

    This is just a helper function around cubic_spline.interpolate1d() that
    computes a tensor of values and gradients. The wrapped function is passed
    through `self.variant`, so both the forward pass and the gradient are
    evaluated with the chex variant that is active for the calling test.

    Args:
      x: A tensor of values to interpolate with.
      values: A tensor of knot values for the spline.
      tangents: A tensor of knot tangents for the spline.

    Returns:
      A tuple containing:
        (A tensor of interpolated values,
         A tensor of derivatives of interpolated values wrt `x`)

    Typical usage example:
      y, dy_dx = self._interpolate1d(x, values, tangents)
    """
    fn = self.variant(cubic_spline.interpolate1d)
    y = fn(x, values, tangents)
    # Differentiating the sum of all outputs yields the elementwise dy/dx,
    # assuming each output depends only on its own query point (expected for
    # pointwise 1-D interpolation).
    dy_dx = jax.grad(lambda z: jnp.sum(fn(z, values, tangents)))(x)
    return y, dy_dx

  @chex.all_variants()
  def testInterpolationReproducesValuesAtKnots(self):
    """Check that interpolating at a knot produces the value at that knot."""
    x, values, tangents = self._setup_toy_data()
    y = self.variant(cubic_spline.interpolate1d)(x, values, tangents)
    chex.assert_trees_all_close(y, values, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testInterpolationReproducesTangentsAtKnots(self):
    """Check that the derivative at a knot produces the tangent at that knot."""
    x, values, tangents = self._setup_toy_data()
    # `_interpolate1d` also applies `self.variant` internally, so the variant
    # wraps both this outer helper call and the inner interpolate1d call.
    _, dy_dx = self.variant(self._interpolate1d)(x, values, tangents)
    chex.assert_trees_all_close(dy_dx, tangents, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testZeroTangentMidpointValuesAndDerivativesAreCorrect(self):
    """Check that splines with zero tangents behave correctly at midpoints.

    Make a spline whose tangents are all zeros, and then verify that
    midpoints between each pair of knots have the mean value of their adjacent
    knots, and have a derivative that is 1.5x the difference between their
    adjacent knots.
    """
    # Make a spline with random values and all-zero tangents.
    _, values, _ = self._setup_toy_data()
    tangents = jnp.zeros_like(values)

    # Query n-1 points placed exactly in between each pair of knots.
    x = jnp.arange(len(values) - 1) + 0.5

    # Get the interpolated values and derivatives.
    y, dy_dx = self._interpolate1d(x, values, tangents)

    # Check that the interpolated value of every query lies at the midpoint of
    # its surrounding knot values.
    y_true = (values[0:-1] + values[1:]) / 2.
    chex.assert_trees_all_close(y, y_true, atol=1e-5, rtol=1e-5)

    # Check that the derivative of every interpolated value is (fun fact!) 1.5x
    # the numerical difference between adjacent knot values.
    dy_dx_true = 1.5 * (values[1:] - values[0:-1])
    chex.assert_trees_all_close(dy_dx, dy_dx_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testZeroTangentIntermediateValuesAndDerivativesDoNotOvershoot(self):
    """Check that splines with zero tangents behave correctly between knots.

    Make a spline whose tangents are all zeros, and then verify that points
    between each pair of knots lie between the knot values, and have
    derivatives that are between 0 and 1.5x the numerical difference between
    the knot values (mathematically, 1.5x is the max derivative if the
    tangents are zero).
    """
    # Make a spline with all-zero tangents and random values.
    _, values, _ = self._setup_toy_data()
    tangents = jnp.zeros_like(values)

    # Query n-1 points placed somewhere randomly in between all adjacent knots.
    # NOTE(review): this splits PRNGKey(0) exactly as `_setup_toy_data` does,
    # so the offsets are drawn from the same key as `values`; consider a
    # distinct seed (and re-check the strict bounds below if you change it).
    _, key = random.split(random.PRNGKey(0))
    x = jnp.arange(len(values) - 1) + random.uniform(
        key, shape=[len(values) - 1])

    # Get the interpolated values and derivatives.
    y, dy_dx = self._interpolate1d(x, values, tangents)

    # Check that the interpolated value of every query lies between its
    # surrounding knot values (for either ordering of the two knots).
    # NOTE(review): this comparison uses no floating-point tolerance.
    self.assertTrue(
        jnp.all(((values[0:-1] <= y) & (y <= values[1:]))
                | ((values[0:-1] >= y) & (y >= values[1:]))))

    # Check that all derivatives of interpolated values are between 0 and 1.5x
    # the numerical difference between adjacent knot values (with a small
    # 1e-3 slack on the 1.5x bound).
    max_dy_dx = (1.5 + 1e-3) * (values[1:] - values[0:-1])
    self.assertTrue(
        jnp.all(((0 <= dy_dx) & (dy_dx <= max_dy_dx))
                | ((0 >= dy_dx) & (dy_dx >= max_dy_dx))))

  @chex.all_variants()
  def testLinearRampsReproduceCorrectly(self):
    """Check that interpolating a ramp reproduces a ramp.

    Generate linear ramps, render them into splines, and then interpolate and
    extrapolate the splines and verify that they reproduce the ramp.
    """
    n = 256
    # Generate queries inside and outside the support of the spline: the
    # queries span [-0.5 * (n - 1), 1.5 * (n - 1)).
    rng, key = random.split(random.PRNGKey(0))
    x = (random.uniform(key, shape=[1024]) * 2 - 0.5) * (n - 1)
    idx = jnp.float32(jnp.arange(n))
    fn = self.variant(cubic_spline.interpolate1d)
    for _ in range(8):
      rng, key = random.split(rng)
      slope = random.normal(key)
      rng, key = random.split(rng)
      bias = random.normal(key)
      # A ramp's knot values lie on the line, and every tangent equals the slope.
      values = slope * idx + bias
      tangents = jnp.ones_like(values) * slope
      y = fn(x, values, tangents)
      y_true = slope * x + bias
      chex.assert_trees_all_close(y, y_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testExtrapolationIsLinear(self):
    """Check that extrapolation is linear with respect to the endpoint knots.

    Generate random splines and query them outside of the support of the
    spline, and verify that extrapolation is linear with respect to the
    endpoint knots.
    """
    n = 256
    # Generate queries below (in (-(n - 1), 0]) and above (in [n - 1,
    # 2 * (n - 1))) the support of the spline.
    rng, key = random.split(random.PRNGKey(0))
    x_below = -(random.uniform(key, shape=[1024])) * (n - 1)
    rng, key = random.split(rng)
    x_above = (random.uniform(key, shape=[1024]) + 1.) * (n - 1)
    fn = self.variant(cubic_spline.interpolate1d)
    for _ in range(8):
      rng, key = random.split(rng)
      values = random.normal(key, shape=[n])
      rng, key = random.split(rng)
      tangents = random.normal(key, shape=[n])

      # Query the spline below its support and check that it's a linear ramp
      # with the slope and bias of the beginning of the spline.
      y_below = fn(x_below, values, tangents)
      y_below_true = tangents[0] * x_below + values[0]
      chex.assert_trees_all_close(y_below, y_below_true, atol=1e-5, rtol=1e-5)

      # Query the spline above its support and check that it's a linear ramp
      # with the slope and bias of the end of the spline.
      y_above = fn(x_above, values, tangents)
      y_above_true = tangents[-1] * (x_above - (n - 1)) + values[-1]
      chex.assert_trees_all_close(y_above, y_above_true, atol=1e-5, rtol=1e-5)


if __name__ == '__main__':
  # Run every test in this module through absl's test runner.
  absltest.main()