google-research
94 строки · 3.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for the ops module in Jax."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22
23from absl.testing import absltest
24from absl.testing import parameterized
25import jax.numpy as jnp
26import numpy as np
27
28from soft_sort.jax import ops
29
30
class OpsTestCase(parameterized.TestCase):
  """Tests for the soft sorting operators in `soft_sort.jax.ops`.

  Every operator is called with small `threshold` and `epsilon` values. The
  tests check the output shape, and then either the ordering of the output or
  its agreement with the exact result computed with `jnp`.
  """

  def setUp(self):
    super(OpsTestCase, self).setUp()
    # Shape (1, 3, 5). The values within each row along the last axis are
    # distinct, so the exact ranks used as ground truth below are unambiguous.
    self.x = jnp.array([[[7.9, 1.2, 5.5, 9.8, 3.5],
                         [7.9, 12.2, 45.5, 9.8, 3.5],
                         [17.9, 14.2, 55.5, 9.8, 3.5]]])

  @staticmethod
  def _hard_ranks(x, axis=-1):
    """Returns the exact zero-based ascending ranks of `x` along `axis`.

    Ranks are computed as the inverse permutation of `argsort`. This only
    yields unambiguous ranks when the values along `axis` are distinct.
    """
    return jnp.argsort(jnp.argsort(x, axis=axis), axis=axis)

  def test_sort(self):
    s = ops.softsort(self.x, axis=-1, threshold=1e-3, epsilon=1e-3)
    self.assertEqual(s.shape, self.x.shape)
    # The soft-sorted values must be strictly increasing along the last axis.
    # assert_array_less reports the offending differences on failure, unlike
    # comparing a boolean mask against an all-True array.
    np.testing.assert_array_less(0.0, jnp.diff(s, axis=-1))

  def test_sort_descending(self):
    x = self.x[0][0]
    s = ops.softsort(x, axis=-1, direction='DESCENDING',
                     threshold=1e-3, epsilon=1e-3)
    self.assertEqual(s.shape, x.shape)
    # The values must be strictly decreasing along the last axis.
    np.testing.assert_array_less(jnp.diff(s, axis=-1), 0.0)

  def test_ranks(self):
    ranks = ops.softranks(self.x, axis=-1, threshold=1e-3, epsilon=1e-3)
    self.assertEqual(ranks.shape, self.x.shape)
    np.testing.assert_allclose(ranks, self._hard_ranks(self.x), atol=1e-3)

  def test_ranks_one_based(self):
    ranks = ops.softranks(self.x, axis=-1, zero_based=False,
                          threshold=1e-3, epsilon=1e-3)
    self.assertEqual(ranks.shape, self.x.shape)
    np.testing.assert_allclose(ranks, self._hard_ranks(self.x) + 1, atol=1e-3)

  def test_ranks_descending(self):
    ranks = ops.softranks(
        self.x, axis=-1, zero_based=True, direction='DESCENDING',
        threshold=1e-3, epsilon=1e-3)
    self.assertEqual(ranks.shape, self.x.shape)

    # Descending zero-based ranks mirror the ascending ones around max_rank.
    max_rank = self.x.shape[-1] - 1
    true_ranks = max_rank - self._hard_ranks(self.x)
    np.testing.assert_allclose(ranks, true_ranks, atol=1e-3)

  @parameterized.named_parameters(
      ('medians_-1', 0.5, -1),
      ('medians_1', 0.5, 1),
      ('percentile25_-1', 0.25, -1))
  def test_softquantile(self, quantile, axis):
    x = jnp.array([[[7.9, 1.2, 5.5, 9.8, 3.5], [7.9, 12.2, 45.5, 9.8, 3.5],
                    [17.9, 14.2, 55.5, 9.8, 3.5]],
                   [[4.9, 1.2, 15.5, 4.8, 3.5], [7.9, 1.2, 5.5, 7.8, 2.5],
                    [1.9, 4.2, 55.5, 9.8, 1.5]]])
    qs = ops.softquantile(x, quantile, axis=axis, threshold=1e-3, epsilon=1e-3)
    # The quantile reduces `axis`, so it disappears from the output shape.
    expected_shape = list(x.shape)
    expected_shape.pop(axis)
    self.assertTupleEqual(qs.shape, tuple(expected_shape))
    np.testing.assert_allclose(
        qs, jnp.quantile(x, quantile, axis=axis), rtol=1e-2)
91
92
if __name__ == '__main__':
  # Script entry point: hand control to absltest, which runs the test cases
  # defined in this module.
  absltest.main()
95