google-research
109 строк · 4.2 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for semiring."""
17
18from absl.testing import parameterized
19from lingvo import compat as tf
20import semiring
21
22
class SemiringTest(parameterized.TestCase, tf.test.TestCase):
  """Checks the semiring axioms for the implementations in `semiring`.

  Each private helper receives a semiring object `r` together with elements of
  that semiring and compares both sides of one axiom using `assertAllClose`.
  """

  def _TestAdditiveMonoid(self, r, elem_1, elem_2, elem_3):
    """Asserts that `r.add` is commutative, has an identity, and associates.

    Args:
      r: Semiring object under test.
      elem_1: First semiring element.
      elem_2: Second semiring element.
      elem_3: Third semiring element.
    """
    # Commutativity: a + b == b + a.
    forward = r.add(elem_1, elem_2)
    backward = r.add(elem_2, elem_1)
    self.assertAllClose(forward, backward)

    # The additive identity is neutral on the right and on the left.
    right_neutral = r.add(elem_1, r.additive_identity())
    self.assertAllClose(right_neutral, elem_1)
    left_neutral = r.add(r.additive_identity(), elem_1)
    self.assertAllClose(left_neutral, elem_1)

    # Associativity: (a + b) + c == a + (b + c).
    grouped_left = r.add(r.add(elem_1, elem_2), elem_3)
    grouped_right = r.add(elem_1, r.add(elem_2, elem_3))
    self.assertAllClose(grouped_left, grouped_right)

  def _TestMultiplicativeMonoid(self, r, elem_1, elem_2, elem_3):
    """Asserts that `r.multiply` has an identity and associates.

    Args:
      r: Semiring object under test.
      elem_1: First semiring element.
      elem_2: Second semiring element.
      elem_3: Third semiring element.
    """
    # The multiplicative identity is neutral on the right and on the left.
    right_neutral = r.multiply(elem_1, r.multiplicative_identity())
    self.assertAllClose(right_neutral, elem_1)
    left_neutral = r.multiply(r.multiplicative_identity(), elem_1)
    self.assertAllClose(left_neutral, elem_1)

    # Associativity: (a * b) * c == a * (b * c).
    grouped_left = r.multiply(r.multiply(elem_1, elem_2), elem_3)
    grouped_right = r.multiply(elem_1, r.multiply(elem_2, elem_3))
    self.assertAllClose(grouped_left, grouped_right)

  def _TestAdditionList(self, r, elem_list):
    """Asserts that `r.add_list` matches folding `r.add` left to right.

    Args:
      r: Semiring object under test.
      elem_list: Non-empty list of semiring elements.
    """
    running_total = elem_list[0]
    remaining = elem_list[1:]
    for term in remaining:
      running_total = r.add(running_total, term)
    self.assertAllClose(running_total, r.add_list(elem_list))

  def _TestMultiplicationList(self, r, elem_list):
    """Asserts that `r.multiply_list` matches folding `r.multiply` in order.

    Args:
      r: Semiring object under test.
      elem_list: Non-empty list of semiring elements.
    """
    running_product = elem_list[0]
    remaining = elem_list[1:]
    for factor in remaining:
      running_product = r.multiply(running_product, factor)
    self.assertAllClose(running_product, r.multiply_list(elem_list))

  def _TestDistributiveProperty(self, r, elem_1, elem_2, elem_3):
    """Asserts that `r.multiply` distributes over `r.add` on both sides.

    Args:
      r: Semiring object under test.
      elem_1: Element that multiplies the sum.
      elem_2: First summand.
      elem_3: Second summand.
    """
    # Left distributivity: a * (b + c) == a * b + a * c.
    left_factored = r.multiply(elem_1, r.add(elem_2, elem_3))
    left_expanded = r.add(
        r.multiply(elem_1, elem_2), r.multiply(elem_1, elem_3))
    self.assertAllClose(left_factored, left_expanded)

    # Right distributivity: (b + c) * a == b * a + c * a.
    right_factored = r.multiply(r.add(elem_2, elem_3), elem_1)
    right_expanded = r.add(
        r.multiply(elem_2, elem_1), r.multiply(elem_3, elem_1))
    self.assertAllClose(right_factored, right_expanded)

  def _TestAnnihilation(self, r, elem_1):
    """Asserts that the additive identity annihilates under `r.multiply`.

    Args:
      r: Semiring object under test.
      elem_1: Semiring element multiplied with the additive identity.
    """
    zero_times_elem = r.multiply(r.additive_identity(), elem_1)
    self.assertAllClose(zero_times_elem, r.additive_identity())
    elem_times_zero = r.multiply(elem_1, r.additive_identity())
    self.assertAllClose(elem_times_zero, r.additive_identity())

  # Each case is (test name, semiring instance, three semiring elements).
  # Elements are tuples of tensors: one component for the log semiring, two
  # for the log entropy semiring and four for the log reverse-KL semiring.
  @parameterized.named_parameters(
      (
          'Log Semiring',
          semiring.LogSemiring(),
          (tf.constant([-2.0]),),
          (tf.constant([-3.0]),),
          (tf.constant([-4.0]),),
      ),
      (
          'Log Entropy Semiring',
          semiring.LogEntropySemiring(),
          (tf.constant([-2.0]), tf.constant([-2.5])),
          (tf.constant([-3.0]), tf.constant([-3.5])),
          (tf.constant([-4.0]), tf.constant([-4.5])),
      ),
      (
          'Log Reverse-KL Semiring',
          semiring.LogReverseKLSemiring(),
          (
              tf.constant([-2.0]),
              tf.constant([-2.5]),
              tf.constant([-2.6]),
              tf.constant([-2.7]),
          ),
          (
              tf.constant([-3.0]),
              tf.constant([-3.5]),
              tf.constant([-3.6]),
              tf.constant([-3.7]),
          ),
          (
              tf.constant([-4.0]),
              tf.constant([-4.5]),
              tf.constant([-4.6]),
              tf.constant([-4.7]),
          ),
      ),
  )
  def testSemiring(self, r, elem_1, elem_2, elem_3):
    """Runs every semiring axiom check against `r`.

    Args:
      r: Semiring object under test.
      elem_1: First semiring element.
      elem_2: Second semiring element.
      elem_3: Third semiring element.
    """
    axiom_args = (r, elem_1, elem_2, elem_3)
    self._TestAdditiveMonoid(*axiom_args)
    self._TestMultiplicativeMonoid(*axiom_args)
    self._TestAdditionList(r, [elem_1, elem_2, elem_3])
    self._TestMultiplicationList(r, [elem_1, elem_2, elem_3])
    self._TestDistributiveProperty(*axiom_args)
    self._TestAnnihilation(r, elem_1)
106
107
if __name__ == '__main__':
  # Run this module's tests via tf.test.main() when executed as a script.
  tf.test.main()
110