google-research

Форк
0
94 строки · 3.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for the ops module in Jax."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22

23
from absl.testing import absltest
24
from absl.testing import parameterized
25
import jax.numpy as jnp
26
import numpy as np
27

28
from soft_sort.jax import ops
29

30

31
class OpsTestCase(parameterized.TestCase):
32
  """Tests for the ops module in jax."""
33

34
  def setUp(self):
35
    super(OpsTestCase, self).setUp()
36
    self.x = jnp.array([[[7.9, 1.2, 5.5, 9.8, 3.5], [7.9, 12.2, 45.5, 9.8, 3.5],
37
                         [17.9, 14.2, 55.5, 9.8, 3.5]]])
38

39
  def test_sort(self):
40
    s = ops.softsort(self.x, axis=-1, threshold=1e-3, epsilon=1e-3)
41
    self.assertEqual(s.shape, self.x.shape)
42
    deltas = jnp.diff(s, axis=-1) > 0
43
    np.testing.assert_allclose(deltas, jnp.ones(deltas.shape, dtype=bool))
44

45
  def test_sort_descending(self):
46
    x = self.x[0][0]
47
    s = ops.softsort(x, axis=-1, direction='DESCENDING',
48
                     threshold=1e-3, epsilon=1e-3)
49
    self.assertEqual(s.shape, x.shape)
50
    deltas = jnp.diff(s, axis=-1) < 0
51
    np.testing.assert_allclose(deltas, jnp.ones(deltas.shape, dtype=bool))
52

53
  def test_ranks(self):
54
    ranks = ops.softranks(self.x, axis=-1, threshold=1e-3, epsilon=1e-3)
55
    self.assertEqual(ranks.shape, self.x.shape)
56
    true_ranks = jnp.argsort(jnp.argsort(self.x, axis=-1), axis=-1)
57
    np.testing.assert_allclose(ranks, true_ranks, atol=1e-3)
58

59
  def test_ranks_one_based(self):
60
    ranks = ops.softranks(self.x, axis=-1, zero_based=False,
61
                          threshold=1e-3, epsilon=1e-3)
62
    self.assertEqual(ranks.shape, self.x.shape)
63
    true_ranks = jnp.argsort(jnp.argsort(self.x, axis=-1), axis=-1) + 1
64
    np.testing.assert_allclose(ranks, true_ranks, atol=1e-3)
65

66
  def test_ranks_descending(self):
67
    ranks = ops.softranks(
68
        self.x, axis=-1, zero_based=True, direction='DESCENDING',
69
        threshold=1e-3, epsilon=1e-3)
70
    self.assertEqual(ranks.shape, self.x.shape)
71

72
    max_rank = self.x.shape[-1] - 1
73
    true_ranks = max_rank - jnp.argsort(jnp.argsort(self.x, axis=-1), axis=-1)
74
    np.testing.assert_allclose(ranks, true_ranks, atol=1e-3)
75

76
  @parameterized.named_parameters(
77
      ('medians_-1', 0.5, -1),
78
      ('medians_1', 0.5, 1),
79
      ('percentile25_-1', 0.25, -1))
80
  def test_softquantile(self, quantile, axis):
81
    x = jnp.array([[[7.9, 1.2, 5.5, 9.8, 3.5], [7.9, 12.2, 45.5, 9.8, 3.5],
82
                    [17.9, 14.2, 55.5, 9.8, 3.5]],
83
                   [[4.9, 1.2, 15.5, 4.8, 3.5], [7.9, 1.2, 5.5, 7.8, 2.5],
84
                    [1.9, 4.2, 55.5, 9.8, 1.5]]])
85
    qs = ops.softquantile(x, quantile, axis=axis, threshold=1e-3, epsilon=1e-3)
86
    s = list(x.shape)
87
    s.pop(axis)
88
    self.assertTupleEqual(qs.shape, tuple(s))
89
    np.testing.assert_allclose(
90
        qs, jnp.quantile(x, quantile, axis=axis), rtol=1e-2)
91

92

93
if __name__ == '__main__':
94
  absltest.main()
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.