google-research

Форк
0
/
layers_test.py 
108 строк · 3.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for soft sorting tensorflow layers."""
17

18
from absl.testing import parameterized
19
import tensorflow.compat.v2 as tf
20

21
from soft_sort import layers
22

23

24
class LayersTest(parameterized.TestCase, tf.test.TestCase):
25

26
  def setUp(self):
27
    super().setUp()
28
    tf.random.set_seed(0)
29
    self._input_shape = (1, 10, 2)
30
    self._axis = 1
31
    self._inputs = tf.random.normal(self._input_shape)
32

33
  @parameterized.parameters([6, 3])
34
  def test_soft_topk_layer(self, topk):
35
    direction = 'DESCENDING'
36
    layer = layers.SoftSortLayer(
37
        axis=self._axis, topk=topk, direction=direction, epsilon=1e-3)
38
    outputs = layer(self._inputs)
39
    expected_shape = list(self._input_shape)
40
    expected_shape[self._axis] = topk
41
    self.assertAllEqual(outputs.shape, expected_shape)
42
    sorted_inputs = tf.sort(self._inputs, axis=self._axis, direction=direction)
43
    self.assertAllClose(sorted_inputs[:, :topk, :], outputs, atol=1e-2)
44

45
  def test_softsortlayer(self):
46
    direction = 'DESCENDING'
47
    layer = layers.SoftSortLayer(
48
        axis=self._axis, direction=direction, epsilon=1e-3)
49
    outputs = layer(self._inputs)
50
    self.assertAllEqual(outputs.shape, self._inputs.shape)
51
    sorted_inputs = tf.sort(self._inputs, axis=self._axis, direction=direction)
52
    self.assertAllClose(sorted_inputs, outputs, atol=1e-2)
53

54
  def take_model_output(self, layer, inputs):
55
    model = tf.keras.Sequential([
56
        tf.keras.layers.Flatten(input_shape=inputs[0].shape),
57
        layer,
58
        tf.keras.layers.Dense(1, activation='softmax')
59
    ])
60
    model.build(inputs.shape)
61
    model.compile(tf.keras.optimizers.SGD(1e-3),
62
                  loss='binary_crossentropy',
63
                  metrics=['accuracy'])
64
    return model(inputs)
65

66
  @parameterized.parameters([None, 4])
67
  def test_sortlayer_in_model(self, topk):
68
    inputs = tf.random.uniform((32, 10))
69
    outputs = self.take_model_output(layers.SoftSortLayer(topk=topk), inputs)
70
    self.assertAllEqual([inputs.shape[0], 1], outputs.shape)
71

72
  def test_rankslayer_in_model(self):
73
    inputs = tf.random.uniform((32, 10))
74
    outputs = self.take_model_output(layers.SoftRanksLayer(), inputs)
75
    self.assertAllEqual([inputs.shape[0], 1], outputs.shape)
76

77
  def test_quantilelayer_in_model(self):
78
    inputs = tf.random.uniform((32, 10))
79
    outputs = self.take_model_output(
80
        layers.SoftQuantilesLayer(
81
            quantiles=[0.2, 0.5, 0.8], output_shape=(32, 3)),
82
        inputs)
83
    self.assertAllEqual([inputs.shape[0], 1], outputs.shape)
84

85
  def test_softranks(self):
86
    layer = layers.SoftRanksLayer(axis=self._axis, epsilon=1e-4)
87
    outputs = layer(self._inputs)
88
    self.assertAllEqual(outputs.shape, self._inputs.shape)
89
    ranks = tf.argsort(
90
        tf.argsort(self._inputs, axis=self._axis), axis=self._axis)
91
    self.assertAllClose(ranks, outputs, atol=0.5)
92

93
  def test_softquantiles(self):
94
    inputs = tf.reshape(tf.range(101, dtype=tf.float32), (1, -1))
95
    axis = 1
96
    quantiles = [0.25, 0.50, 0.75]
97
    layer = layers.SoftQuantilesLayer(
98
        quantiles=quantiles, output_shape=None, axis=axis, epsilon=1e-3)
99

100
    outputs = layer(inputs)
101
    self.assertAllEqual(outputs.shape, (1, 3))
102

103
    self.assertAllClose(tf.constant([[25., 50., 75.]]), outputs, atol=0.5)
104

105

106
if __name__ == '__main__':
107
  tf.enable_v2_behavior()
108
  tf.test.main()
109

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.