# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test Lower Triangular Multiplication Algorithm."""

from absl.testing import absltest
from absl.testing import parameterized
import jax.numpy as jnp

from polysketchformer import lower_triangular_multiplication
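
# What the tests below pin down about the API under test:
#   lt_multiply(a, b, c, grain_size, build_cache)
#       -> (jnp.tril(a @ b^T) @ c, b^T @ c if build_cache else None)
#   tensor_lt_multiply(a, b, c, grain_size, build_cache)
#       -> (jnp.tril((a @ b^T) ** 2) @ c, tensor cache if build_cache else None)
# `grain_size` presumably selects the row-block size of a chunked
# implementation; the tests only require that every grain size reproduce the
# direct masked product.
#
# For orientation, here is a minimal sketch of one plausible blocked scheme
# consistent with what the tests assert. `_blocked_lt_multiply_sketch` is a
# hypothetical helper, NOT the library's implementation: it illustrates how a
# running b^T @ c prefix (the "cache") lets each row block avoid materializing
# the full n x n masked matrix.
def _blocked_lt_multiply_sketch(a, b, c, grain_size):
  """Illustrative only: blocked computation of jnp.tril(a @ b^T) @ c."""
  batches, n, r = a.shape
  d = c.shape[-1]
  out = jnp.zeros((batches, n, d), dtype=c.dtype)
  cache = jnp.zeros((batches, r, d), dtype=c.dtype)
  for start in range(0, n, grain_size):
    end = min(start + grain_size, n)
    a_blk, b_blk, c_blk = a[:, start:end], b[:, start:end], c[:, start:end]
    # In-block term: mask only the small block-local score matrix.
    local = jnp.tril(a_blk @ b_blk.transpose(0, 2, 1)) @ c_blk
    # All earlier blocks contribute through the running prefix cache.
    out = out.at[:, start:end].set(local + a_blk @ cache)
    cache = cache + b_blk.transpose(0, 2, 1) @ c_blk
  return out, cache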


class LowerTriangularMultiplicationTest(parameterized.TestCase):

  @parameterized.named_parameters(
      {'testcase_name': 'grain_size_1', 'grain_size': 1},
      {'testcase_name': 'grain_size_2', 'grain_size': 2},
      {'testcase_name': 'grain_size_4', 'grain_size': 4},
  )
  def test_lt_multiply_build_cache(self, grain_size):
    """Test lt_multiply with different grain sizes when build_cache is True."""
    n = 4
    r = 3
    d = 8
    batches = 4

    # Instantiate inputs for the test.
    a = jnp.arange(batches * n * r).astype(jnp.float32).reshape((batches, n, r))
    b = (
        jnp.arange(batches * n * r, 2 * batches * n * r)
        .astype(jnp.float32)
        .reshape((batches, n, r))
    )
    c = jnp.arange(batches * n * d).astype(jnp.float32).reshape((batches, n, d))

    # Expected outputs computed directly from the definitions.
    direct_result = jnp.tril(a @ b.transpose(0, 2, 1)) @ c
    direct_cache = b.transpose(0, 2, 1) @ c
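    # Note (interpretation, not asserted by the library): the cache equals
    # b^T @ c over the whole sequence, i.e. the cross term a streaming or
    # decoding implementation would carry forward.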

    # Compute outputs from the implementation.
    lt_multiply_result, cache = lower_triangular_multiplication.lt_multiply(
        a,
        b,
        c,
        grain_size=grain_size,
        build_cache=True,
    )

    self.assertTrue(
        jnp.allclose(lt_multiply_result, direct_result, rtol=1e-3, atol=1e-5)
    )
    self.assertTrue(jnp.allclose(cache, direct_cache, rtol=1e-3, atol=1e-5))

  @parameterized.named_parameters(
      {'testcase_name': 'grain_size_1', 'grain_size': 1},
      {'testcase_name': 'grain_size_2', 'grain_size': 2},
      {'testcase_name': 'grain_size_4', 'grain_size': 4},
  )
  def test_lt_multiply_no_build_cache(self, grain_size):
    """Test lt_multiply with different grain sizes when build_cache is False."""
    n = 4
    r = 3
    d = 8
    batches = 4

    # Instantiate inputs for the test.
    a = jnp.arange(batches * n * r).astype(jnp.float32).reshape((batches, n, r))
    b = (
        jnp.arange(batches * n * r, 2 * batches * n * r)
        .astype(jnp.float32)
        .reshape((batches, n, r))
    )
    c = jnp.arange(batches * n * d).astype(jnp.float32).reshape((batches, n, d))

    # Expected output computed directly from the definition.
    direct_result = jnp.tril(a @ b.transpose(0, 2, 1)) @ c
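    # With build_cache=False only the masked product should be returned; the
    # second element of the pair is expected to be None (asserted below).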

    # Compute outputs from the implementation.
    lt_multiply_result, cache = lower_triangular_multiplication.lt_multiply(
        a,
        b,
        c,
        grain_size=grain_size,
        build_cache=False,
    )

    # Check that the expected outputs are close to the actual outputs.
    self.assertTrue(
        jnp.allclose(lt_multiply_result, direct_result, rtol=1e-3, atol=1e-5)
    )
    self.assertIsNone(cache)

  @parameterized.named_parameters(
      {'testcase_name': 'grain_size_1', 'grain_size': 1},
      {'testcase_name': 'grain_size_2', 'grain_size': 2},
      {'testcase_name': 'grain_size_4', 'grain_size': 4},
  )
  def test_tensor_lt_multiply_build_cache(self, grain_size):
    """Test tensor_lt_multiply with grain sizes when build_cache is True."""
    n = 4
    r = 3
    d = 8
    batches = 4

    # Instantiate inputs for the test.
    a = jnp.arange(batches * n * r).astype(jnp.float32).reshape((batches, n, r))
    b = (
        jnp.arange(batches * n * r, 2 * batches * n * r)
        .astype(jnp.float32)
        .reshape((batches, n, r))
    )
    c = jnp.arange(batches * n * d).astype(jnp.float32).reshape((batches, n, d))

    # Expected outputs computed directly from the definitions.
    direct_result = jnp.tril((a @ b.transpose(0, 2, 1)) ** 2) @ c
    direct_cache = jnp.einsum('...ti, ...tj, ...td -> ...ijd', b, b, c)
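    # Entrywise squaring of (a @ b^T) matches a degree-2 polynomial kernel:
    # (a_i . b_t)**2 = <a_i tensor a_i, b_t tensor b_t>. Accordingly, the
    # einsum builds cache[i, j, d] = sum_t b[t, i] * b[t, j] * c[t, d],
    # the tensorized analogue of the b^T @ c cache above.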

    tensor_lt_multiply_result, cache = (
        lower_triangular_multiplication.tensor_lt_multiply(
            a, b, c, grain_size=grain_size, build_cache=True
        )
    )

    # Check that the outputs from the implementation are close to the
    # expected outputs.
    self.assertTrue(
        jnp.allclose(
            tensor_lt_multiply_result, direct_result, rtol=1e-3, atol=1e-5
        )
    )
    self.assertTrue(jnp.allclose(cache, direct_cache, rtol=1e-3, atol=1e-5))

  @parameterized.named_parameters(
      {'testcase_name': 'grain_size_1', 'grain_size': 1},
      {'testcase_name': 'grain_size_2', 'grain_size': 2},
      {'testcase_name': 'grain_size_4', 'grain_size': 4},
  )
  def test_tensor_lt_multiply_no_build_cache(self, grain_size):
    """Test tensor_lt_multiply when build_cache=False."""
    n = 32
    r = 8
    d = 8
    batches = 3

    # Instantiate inputs for the test.
    a = jnp.arange(batches * n * r).astype(jnp.float32).reshape((batches, n, r))
    b = (
        jnp.arange(batches * n * r, 2 * batches * n * r)
        .astype(jnp.float32)
        .reshape((batches, n, r))
    )
    c = jnp.arange(batches * n * d).astype(jnp.float32).reshape((batches, n, d))

    # Expected output computed directly from the definition.
    direct_result = jnp.tril((a @ b.transpose(0, 2, 1)) ** 2) @ c
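    # With n = 32, every grain size here (1, 2, 4) splits the sequence into
    # many blocks, so (assuming grain_size is the block width) the chunked
    # path is exercised well beyond the single-block case.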

    # Compute outputs from the implementation.
    tensor_lt_multiply_result, cache = (
        lower_triangular_multiplication.tensor_lt_multiply(
            a, b, c, grain_size=grain_size, build_cache=False
        )
    )

    # Check that the output from the implementation is close to the expected
    # output.
    self.assertTrue(
        jnp.allclose(
            tensor_lt_multiply_result, direct_result, rtol=1e-3, atol=1e-5
        )
    )
    self.assertIsNone(cache)


if __name__ == '__main__':
  absltest.main()