google-research

Форк
0
/
cubic_spline_test.py 
198 строк · 7.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for cubic_spline.py."""
17

18
from absl.testing import absltest
19
import chex
20
import jax
21
import jax.numpy as jnp
22
import jax.random as random
23
from robust_loss_jax import cubic_spline
24

25

26
class CubicSplineTest(chex.TestCase):
  """Tests for `cubic_spline.interpolate1d`."""

  def _setup_toy_data(self, n=32768):
    """Build a deterministic random spline with `n` knots.

    Args:
      n: The number of knots.

    Returns:
      A tuple `(x, values, tangents)`. `x` is `arange(n)` as float32, i.e. the
      integer positions of the knots. `values` and `tangents` are
      standard-normal samples of shape `[n]` drawn from a fixed PRNG seed, so
      every call returns the same data.
    """
    x = jnp.float32(jnp.arange(n))
    rng = random.PRNGKey(0)
    rng, key = random.split(rng)
    values = random.normal(key, shape=[n])
    # The remaining PRNG state is not needed after this split.
    _, key = random.split(rng)
    tangents = random.normal(key, shape=[n])
    return x, values, tangents

  def _interpolate1d(self, x, values, tangents):
    """Compute interpolate1d(x, values, tangents) and its derivative.

    This is just a helper function around cubic_spline.interpolate1d() that
    computes a tensor of values and gradients. The spline function is wrapped
    in `self.variant`, so callers should NOT wrap this helper in
    `self.variant` again.

    Args:
      x: A tensor of values to interpolate with.
      values: A tensor of knot values for the spline.
      tangents: A tensor of knot tangents for the spline.

    Returns:
      A tuple containing:
       (A tensor of interpolated values,
        A tensor of derivatives of interpolated values wrt `x`)

    Typical usage example:
      y, dy_dx = self._interpolate1d(x, values, tangents)
    """
    fn = self.variant(cubic_spline.interpolate1d)
    y = fn(x, values, tangents)
    # Summing the outputs gives a scalar whose gradient wrt `x` is the
    # per-query derivative dy/dx.
    dy_dx = jax.grad(lambda z: jnp.sum(fn(z, values, tangents)))(x)
    return y, dy_dx

  @chex.all_variants()
  def testInterpolationReproducesValuesAtKnots(self):
    """Check that interpolating at a knot produces the value at that knot."""
    x, values, tangents = self._setup_toy_data()
    y = self.variant(cubic_spline.interpolate1d)(x, values, tangents)
    chex.assert_trees_all_close(y, values, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testInterpolationReproducesTangentsAtKnots(self):
    """Check that the derivative at a knot produces the tangent at that knot."""
    x, values, tangents = self._setup_toy_data()
    # `_interpolate1d` already applies `self.variant` to the spline function,
    # so it is called directly rather than being wrapped in the variant
    # transform a second time.
    _, dy_dx = self._interpolate1d(x, values, tangents)
    chex.assert_trees_all_close(dy_dx, tangents, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testZeroTangentMidpointValuesAndDerivativesAreCorrect(self):
    """Check that splines with zero tangents behave correctly at midpoints.

    Make a spline whose tangents are all zeros, and then verify that
    midpoints between each pair of knots have the mean value of their adjacent
    knots, and have a derivative that is 1.5x the difference between their
    adjacent knots.
    """
    # Make a spline with random values and all-zero tangents.
    _, values, _ = self._setup_toy_data()
    tangents = jnp.zeros_like(values)

    # Query n-1 points placed exactly in between each pair of knots.
    x = jnp.arange(len(values) - 1) + 0.5

    # Get the interpolated values and derivatives.
    y, dy_dx = self._interpolate1d(x, values, tangents)

    # Check that the interpolated value of every query lies at the midpoint of
    # its surrounding knot values.
    y_true = (values[0:-1] + values[1:]) / 2.
    chex.assert_trees_all_close(y, y_true, atol=1e-5, rtol=1e-5)

    # Check that the derivative of all interpolated values is (fun fact!) 1.5x
    # the numerical difference between adjacent knot values.
    dy_dx_true = 1.5 * (values[1:] - values[0:-1])
    chex.assert_trees_all_close(dy_dx, dy_dx_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testZeroTangentIntermediateValuesAndDerivativesDoNotOvershoot(self):
    """Check that splines with zero tangents behave correctly between knots.

    Make a spline whose tangents are all zeros, and then verify that points
    between each knot lie in between the knot values, and have derivatives
    that are between 0 and 1.5x the numerical difference between knot values
    (mathematically, 1.5x is the max derivative if the tangents are zero).
    """
    # Make a spline with all-zero tangents and random values.
    _, values, _ = self._setup_toy_data()
    tangents = jnp.zeros_like(values)

    # Query n-1 points placed somewhere randomly in between all adjacent knots.
    _, key = random.split(random.PRNGKey(0))
    x = jnp.arange(len(values) - 1) + random.uniform(
        key, shape=[len(values) - 1])

    # Get the interpolated values and derivatives.
    y, dy_dx = self._interpolate1d(x, values, tangents)

    # Check that the interpolated value of every query lies between its
    # surrounding knot values (in either order, since knots may decrease).
    self.assertTrue(
        jnp.all(((values[0:-1] <= y) & (y <= values[1:]))
                | ((values[0:-1] >= y) & (y >= values[1:]))))

    # Check that all derivatives of interpolated values are between 0 and 1.5x
    # the numerical difference between adjacent knot values. The small extra
    # factor on 1.5 leaves slack for floating-point error.
    max_dy_dx = (1.5 + 1e-3) * (values[1:] - values[0:-1])
    self.assertTrue(
        jnp.all(((0 <= dy_dx) & (dy_dx <= max_dy_dx))
                | ((0 >= dy_dx) & (dy_dx >= max_dy_dx))))

  @chex.all_variants()
  def testLinearRampsReproduceCorrectly(self):
    """Check that interpolating a ramp reproduces a ramp.

    Generate linear ramps, render them into splines, and then interpolate and
    extrapolate the splines and verify that they reproduce the ramp.
    """
    n = 256
    # Generate queries inside and outside the support of the spline; the
    # queries span [-0.5, 1.5) * (n - 1).
    rng, key = random.split(random.PRNGKey(0))
    x = (random.uniform(key, shape=[1024]) * 2 - 0.5) * (n - 1)
    idx = jnp.float32(jnp.arange(n))
    fn = self.variant(cubic_spline.interpolate1d)
    for _ in range(8):
      rng, key = random.split(rng)
      slope = random.normal(key)
      rng, key = random.split(rng)
      bias = random.normal(key)
      values = slope * idx + bias
      tangents = jnp.ones_like(values) * slope
      y = fn(x, values, tangents)
      y_true = slope * x + bias
      chex.assert_trees_all_close(y, y_true, atol=1e-5, rtol=1e-5)

  @chex.all_variants()
  def testExtrapolationIsLinear(self):
    """Check that extrapolation is linear with respect to the endpoint knots.

    Generate random splines and query them outside of the support of the
    spline, and verify that extrapolation is linear with respect to the
    endpoint knots.
    """
    n = 256
    # Generate queries above and below the support [0, n - 1] of the spline.
    rng, key = random.split(random.PRNGKey(0))
    x_below = -(random.uniform(key, shape=[1024])) * (n - 1)
    rng, key = random.split(rng)
    x_above = (random.uniform(key, shape=[1024]) + 1.) * (n - 1)
    fn = self.variant(cubic_spline.interpolate1d)
    for _ in range(8):
      rng, key = random.split(rng)
      values = random.normal(key, shape=[n])
      rng, key = random.split(rng)
      tangents = random.normal(key, shape=[n])

      # Query the spline below its support and check that it's a linear ramp
      # with the slope and bias of the beginning of the spline.
      y_below = fn(x_below, values, tangents)
      y_below_true = tangents[0] * x_below + values[0]
      chex.assert_trees_all_close(y_below, y_below_true, atol=1e-5, rtol=1e-5)

      # Query the spline above its support and check that it's a linear ramp
      # with the slope and bias of the end of the spline.
      y_above = fn(x_above, values, tangents)
      y_above_true = tangents[-1] * (x_above - (n - 1)) + values[-1]
      chex.assert_trees_all_close(y_above, y_above_true, atol=1e-5, rtol=1e-5)
195

196

197
if __name__ == '__main__':
  # Script entry point: hand this module's test cases to absl's test runner.
  absltest.main()
199

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.