google-research

Форк
0
109 строк · 4.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for semiring."""
17

18
from absl.testing import parameterized
19
from lingvo import compat as tf
20
import semiring
21

22

23
class SemiringTest(parameterized.TestCase, tf.test.TestCase):
24

25
  def _TestAdditiveMonoid(self, r, elem_1, elem_2, elem_3):
26
    """Checks that addition is a commutative monoid."""
27
    self.assertAllClose(r.add(elem_1, elem_2), r.add(elem_2, elem_1))
28
    self.assertAllClose(r.add(elem_1, r.additive_identity()), elem_1)
29
    self.assertAllClose(r.add(r.additive_identity(), elem_1), elem_1)
30
    self.assertAllClose(
31
        r.add(r.add(elem_1, elem_2), elem_3),
32
        r.add(elem_1, r.add(elem_2, elem_3)))
33

34
  def _TestMultiplicativeMonoid(self, r, elem_1, elem_2, elem_3):
35
    """Checks that multiplication is a monoid."""
36
    self.assertAllClose(r.multiply(elem_1, r.multiplicative_identity()), elem_1)
37
    self.assertAllClose(r.multiply(r.multiplicative_identity(), elem_1), elem_1)
38
    self.assertAllClose(
39
        r.multiply(r.multiply(elem_1, elem_2), elem_3),
40
        r.multiply(elem_1, r.multiply(elem_2, elem_3)))
41

42
  def _TestAdditionList(self, r, elem_list):
43
    """Compare result of r.add_list with r.add."""
44
    manual_sum = elem_list[0]
45
    for elem in elem_list[1:]:
46
      manual_sum = r.add(manual_sum, elem)
47
    self.assertAllClose(manual_sum, r.add_list(elem_list))
48

49
  def _TestMultiplicationList(self, r, elem_list):
50
    """Compare result of r.multiply_list with r.multiply."""
51
    manual_prod = elem_list[0]
52
    for elem in elem_list[1:]:
53
      manual_prod = r.multiply(manual_prod, elem)
54
    self.assertAllClose(manual_prod, r.multiply_list(elem_list))
55

56
  def _TestDistributiveProperty(self, r, elem_1, elem_2, elem_3):
57
    """Checks that multiplication distributes over addition."""
58
    self.assertAllClose(
59
        r.multiply(elem_1, r.add(elem_2, elem_3)),
60
        r.add(r.multiply(elem_1, elem_2), r.multiply(elem_1, elem_3)))
61
    self.assertAllClose(
62
        r.multiply(r.add(elem_2, elem_3), elem_1),
63
        r.add(r.multiply(elem_2, elem_1), r.multiply(elem_3, elem_1)))
64

65
  def _TestAnnihilation(self, r, elem_1):
66
    """Checks that additive identity is a multiplicative annihilator."""
67
    self.assertAllClose(
68
        r.multiply(r.additive_identity(), elem_1), r.additive_identity())
69
    self.assertAllClose(
70
        r.multiply(elem_1, r.additive_identity()), r.additive_identity())
71

72
  @parameterized.named_parameters(
73
      (
74
          'Log Semiring',
75
          semiring.LogSemiring(),
76
          (tf.constant([-2.0]),),
77
          (tf.constant([-3.0]),),
78
          (tf.constant([-4.0]),),
79
      ),
80
      (
81
          'Log Entropy Semiring',
82
          semiring.LogEntropySemiring(),
83
          (tf.constant([-2.0]), tf.constant([-2.5])),
84
          (tf.constant([-3.0]), tf.constant([-3.5])),
85
          (tf.constant([-4.0]), tf.constant([-4.5])),
86
      ),
87
      (
88
          'Log Reverse-KL Semiring',
89
          semiring.LogReverseKLSemiring(),
90
          (tf.constant([-2.0]), tf.constant([-2.5]), tf.constant(
91
              [-2.6]), tf.constant([-2.7])),
92
          (tf.constant([-3.0]), tf.constant([-3.5]), tf.constant(
93
              [-3.6]), tf.constant([-3.7])),
94
          (tf.constant([-4.0]), tf.constant([-4.5]), tf.constant(
95
              [-4.6]), tf.constant([-4.7])),
96
      ),
97
  )
98
  def testSemiring(self, r, elem_1, elem_2, elem_3):
99
    """Tests if r is a semiring."""
100
    self._TestAdditiveMonoid(r, elem_1, elem_2, elem_3)
101
    self._TestMultiplicativeMonoid(r, elem_1, elem_2, elem_3)
102
    self._TestAdditionList(r, [elem_1, elem_2, elem_3])
103
    self._TestMultiplicationList(r, [elem_1, elem_2, elem_3])
104
    self._TestDistributiveProperty(r, elem_1, elem_2, elem_3)
105
    self._TestAnnihilation(r, elem_1)
106

107

108
if __name__ == '__main__':
109
  tf.test.main()
110

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.