# google-research
# 108 lines · 3.8 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for soft sorting TensorFlow layers."""

from absl.testing import parameterized
import tensorflow.compat.v2 as tf

from soft_sort import layers
23
class LayersTest(parameterized.TestCase, tf.test.TestCase):
  """Checks the soft sort/rank/quantile Keras layers against exact references.

  Each soft layer is compared, within a tolerance, to an exact counterpart
  (`tf.sort`, a double `tf.argsort`, or known quantile values). Each layer is
  also placed inside a small Keras model to check that it composes with other
  layers and yields the expected output shape.
  """

  def setUp(self):
    """Seeds the TF RNG and builds a shared (1, 10, 2) standard-normal input."""
    super().setUp()
    tf.random.set_seed(0)
    self._input_shape = (1, 10, 2)
    # Sorting / ranking happens along the length-10 dimension.
    self._axis = 1
    self._inputs = tf.random.normal(self._input_shape)

  @parameterized.parameters([6, 3])
  def test_soft_topk_layer(self, topk):
    """SoftSortLayer with `topk` should return only the `topk` largest values."""
    direction = 'DESCENDING'
    layer = layers.SoftSortLayer(
        axis=self._axis, topk=topk, direction=direction, epsilon=1e-3)
    outputs = layer(self._inputs)

    # The test expects only the sorted axis to shrink to `topk`.
    expected_shape = list(self._input_shape)
    expected_shape[self._axis] = topk
    self.assertAllEqual(outputs.shape, expected_shape)

    # At epsilon=1e-3 the soft values are expected to be within 1e-2 of the
    # first `topk` entries of an exact descending sort. The slice below
    # assumes the sorted axis is 1, matching `self._axis`.
    sorted_inputs = tf.sort(self._inputs, axis=self._axis, direction=direction)
    self.assertAllClose(sorted_inputs[:, :topk, :], outputs, atol=1e-2)

  def test_softsortlayer(self):
    """SoftSortLayer without `topk` should approximate a full descending sort."""
    direction = 'DESCENDING'
    layer = layers.SoftSortLayer(
        axis=self._axis, direction=direction, epsilon=1e-3)
    outputs = layer(self._inputs)
    self.assertAllEqual(outputs.shape, self._inputs.shape)
    sorted_inputs = tf.sort(self._inputs, axis=self._axis, direction=direction)
    self.assertAllClose(sorted_inputs, outputs, atol=1e-2)

  def take_model_output(self, layer, inputs):
    """Wraps `layer` in a tiny Keras model and returns the model's outputs.

    The model is Flatten -> `layer` -> Dense(1, sigmoid).

    Args:
      layer: the Keras layer under test.
      inputs: a batched tensor; `inputs[0].shape` is used as the per-example
        input shape.

    Returns:
      The result of calling the built model on `inputs`.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=inputs[0].shape),
        layer,
        # A softmax over a single unit normalizes that unit against itself and
        # always outputs 1.0. Sigmoid is the activation that pairs with the
        # binary cross-entropy loss configured below.
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])
    model.build(inputs.shape)
    # NOTE(review): compile() is not needed for the forward pass below;
    # presumably kept to check that the layer can be compiled for training.
    model.compile(tf.keras.optimizers.SGD(1e-3),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model(inputs)

  @parameterized.parameters([None, 4])
  def test_sortlayer_in_model(self, topk):
    """SoftSortLayer (full or top-k) composes inside a Keras model."""
    inputs = tf.random.uniform((32, 10))
    outputs = self.take_model_output(layers.SoftSortLayer(topk=topk), inputs)
    self.assertAllEqual([inputs.shape[0], 1], outputs.shape)

  def test_rankslayer_in_model(self):
    """SoftRanksLayer composes inside a Keras model."""
    inputs = tf.random.uniform((32, 10))
    outputs = self.take_model_output(layers.SoftRanksLayer(), inputs)
    self.assertAllEqual([inputs.shape[0], 1], outputs.shape)

  def test_quantilelayer_in_model(self):
    """SoftQuantilesLayer composes inside a Keras model."""
    inputs = tf.random.uniform((32, 10))
    quantiles = [0.2, 0.5, 0.8]
    # Derive the declared output shape from the batch rather than repeating
    # the literal batch size.
    outputs = self.take_model_output(
        layers.SoftQuantilesLayer(
            quantiles=quantiles,
            output_shape=(inputs.shape[0], len(quantiles))),
        inputs)
    self.assertAllEqual([inputs.shape[0], 1], outputs.shape)

  def test_softranks(self):
    """SoftRanksLayer should be within 0.5 of the exact 0-based ranks."""
    layer = layers.SoftRanksLayer(axis=self._axis, epsilon=1e-4)
    outputs = layer(self._inputs)
    self.assertAllEqual(outputs.shape, self._inputs.shape)
    # argsort of argsort yields the 0-based rank of each element.
    ranks = tf.argsort(
        tf.argsort(self._inputs, axis=self._axis), axis=self._axis)
    self.assertAllClose(ranks, outputs, atol=0.5)

  def test_softquantiles(self):
    """Quartiles of 0..100 should come out close to 25, 50 and 75."""
    # A single row holding the values 0.0, 1.0, ..., 100.0.
    inputs = tf.reshape(tf.range(101, dtype=tf.float32), (1, -1))
    axis = 1
    quantiles = [0.25, 0.50, 0.75]
    layer = layers.SoftQuantilesLayer(
        quantiles=quantiles, output_shape=None, axis=axis, epsilon=1e-3)

    outputs = layer(inputs)
    self.assertAllEqual(outputs.shape, (1, 3))

    self.assertAllClose(tf.constant([[25., 50., 75.]]), outputs, atol=0.5)
105
if __name__ == '__main__':
  # Enable TF2 behavior, then hand control to the TensorFlow test runner.
  tf.enable_v2_behavior()
  tf.test.main()