# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Test Lower Triangular Multiplication Algorithm."""
17
18from absl.testing import absltest19from absl.testing import parameterized20import jax.numpy as jnp21from polysketchformer import lower_triangular_multiplication22
23
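# The tests below compare lt_multiply against the direct dense computation
# tril(a @ b^T) @ c, together with an optional cache b^T @ c. As a reading
# aid, here is a minimal blockwise sketch of the kind of grain_size-blocked
# recurrence such a function can use; the helper name is ours and this is
# illustrative only, not the library's actual implementation.
def _reference_lt_multiply(a, b, c, grain_size):
  """Blockwise sketch of tril(a @ b^T) @ c. Illustrative only."""
  batches, n, r = a.shape
  d = c.shape[-1]
  state = jnp.zeros((batches, r, d), dtype=c.dtype)  # Running prefix b^T @ c.
  outputs = []
  for start in range(0, n, grain_size):
    a_blk = a[:, start:start + grain_size]
    b_blk = b[:, start:start + grain_size]
    c_blk = c[:, start:start + grain_size]
    # Contributions of all earlier blocks arrive through the prefix cache...
    out_blk = a_blk @ state
    # ...plus the causal (lower triangular) part within the current block.
    out_blk += jnp.tril(a_blk @ b_blk.transpose(0, 2, 1)) @ c_blk
    state = state + b_blk.transpose(0, 2, 1) @ c_blk
    outputs.append(out_blk)
  return jnp.concatenate(outputs, axis=1), state

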
class LowerTriangularMultiplicationTest(parameterized.TestCase):

  @parameterized.named_parameters(
      {'testcase_name': 'grain_size_1', 'grain_size': 1},
      {'testcase_name': 'grain_size_2', 'grain_size': 2},
      {'testcase_name': 'grain_size_4', 'grain_size': 4},
  )
  def test_lt_multiply_build_cache(self, grain_size):
    """Test lt_multiply with different grain sizes when build_cache is True."""
    n = 4
    r = 3
    d = 8
    batches = 4

    # Instantiate inputs for the test.
    a = jnp.arange(batches * n * r).astype(jnp.float32).reshape((batches, n, r))
    b = (
        jnp.arange(batches * n * r, 2 * batches * n * r)
        .astype(jnp.float32)
        .reshape((batches, n, r))
    )
    c = jnp.arange(batches * n * d).astype(jnp.float32).reshape((batches, n, d))

    # Expected outputs, computed directly.
    direct_result = jnp.tril(a @ b.transpose(0, 2, 1)) @ c
    direct_cache = b.transpose(0, 2, 1) @ c

    # Compute outputs from the implementation.
    lt_multiply_result, cache = (
        lower_triangular_multiplication.lt_multiply(
            a,
            b,
            c,
            grain_size=grain_size,
            build_cache=True,
        )
    )

    self.assertTrue(
        jnp.allclose(
            lt_multiply_result, direct_result, rtol=1e-3, atol=1e-5
        )
    )
    self.assertTrue(jnp.allclose(cache, direct_cache, rtol=1e-3, atol=1e-5))

  @parameterized.named_parameters(
      {'testcase_name': 'grain_size_1', 'grain_size': 1},
      {'testcase_name': 'grain_size_2', 'grain_size': 2},
      {'testcase_name': 'grain_size_4', 'grain_size': 4},
  )
  def test_lt_multiply_no_build_cache(self, grain_size):
    """Test lt_multiply with different grain sizes when build_cache is False."""
    n = 4
    r = 3
    d = 8
    batches = 4

    # Instantiate inputs for the test.
    a = jnp.arange(batches * n * r).astype(jnp.float32).reshape((batches, n, r))
    b = (
        jnp.arange(batches * n * r, 2 * batches * n * r)
        .astype(jnp.float32)
        .reshape((batches, n, r))
    )
    c = jnp.arange(batches * n * d).astype(jnp.float32).reshape((batches, n, d))

    # Expected output, computed directly.
    direct_result = jnp.tril(a @ b.transpose(0, 2, 1)) @ c

    # Compute outputs from the implementation.
    lt_multiply_result, cache = (
        lower_triangular_multiplication.lt_multiply(
            a,
            b,
            c,
            grain_size=grain_size,
            build_cache=False,
        )
    )

    # Check that the expected outputs are close to the actual outputs.
    self.assertTrue(
        jnp.allclose(
            lt_multiply_result, direct_result, rtol=1e-3, atol=1e-5
        )
    )
    self.assertIsNone(cache)

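  # The remaining tests cover tensor_lt_multiply, whose expected output is
  # tril((a @ b^T) ** 2) @ c, the entrywise square of the score matrix.
  # Since ((a @ b^T) ** 2)[s, t] = sum_{i,j} a[s,i] * a[s,j] * b[t,i] * b[t,j],
  # prefix contributions can be summarized by the order-3 cache
  # cache[i, j, :] = sum_t b[t,i] * b[t,j] * c[t,:], which is exactly the
  # einsum used for direct_cache below.
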
  @parameterized.named_parameters(
      {'testcase_name': 'grain_size_1', 'grain_size': 1},
      {'testcase_name': 'grain_size_2', 'grain_size': 2},
      {'testcase_name': 'grain_size_4', 'grain_size': 4},
  )
  def test_tensor_lt_multiply_build_cache(self, grain_size):
    """Test tensor_lt_multiply with different grain sizes."""
    n = 4
    r = 3
    d = 8
    batches = 4

    # Instantiating inputs for the test.
    a = jnp.arange(batches * n * r).astype(jnp.float32).reshape((batches, n, r))
    b = (
        jnp.arange(batches * n * r, 2 * batches * n * r)
        .astype(jnp.float32)
        .reshape((batches, n, r))
    )
    c = jnp.arange(batches * n * d).astype(jnp.float32).reshape((batches, n, d))

    # Expected outputs.
    direct_result = jnp.tril((a @ b.transpose(0, 2, 1)) ** 2) @ c
    direct_cache = jnp.einsum('...ti, ...tj, ...td -> ...ijd', b, b, c)

    tensor_lt_multiply_result, cache = (
        lower_triangular_multiplication.tensor_lt_multiply(
            a, b, c, grain_size=grain_size, build_cache=True
        )
    )

    # Check that the implementation outputs are close to the expected outputs.
    self.assertTrue(
        jnp.allclose(
            tensor_lt_multiply_result,
            direct_result,
            rtol=1e-3,
            atol=1e-5,
        )
    )
    self.assertTrue(
        jnp.allclose(
            cache,
            direct_cache,
            rtol=1e-3,
            atol=1e-5,
        )
    )

  @parameterized.named_parameters(
      {'testcase_name': 'grain_size_1', 'grain_size': 1},
      {'testcase_name': 'grain_size_2', 'grain_size': 2},
      {'testcase_name': 'grain_size_4', 'grain_size': 4},
  )
  def test_tensor_lt_multiply_no_build_cache(self, grain_size):
    """Test tensor_lt_multiply when build_cache=False."""
    n = 32
    r = 8
    d = 8
    batches = 3

    # Instantiating inputs for the test.
    a = jnp.arange(batches * n * r).astype(jnp.float32).reshape((batches, n, r))
    b = (
        jnp.arange(batches * n * r, 2 * batches * n * r)
        .astype(jnp.float32)
        .reshape((batches, n, r))
    )
    c = jnp.arange(batches * n * d).astype(jnp.float32).reshape((batches, n, d))

    # Expected output.
    direct_result = jnp.tril((a @ b.transpose(0, 2, 1)) ** 2) @ c

    # Outputs from the implementation.
    tensor_lt_multiply_result, cache = (
        lower_triangular_multiplication.tensor_lt_multiply(
            a, b, c, grain_size=grain_size, build_cache=False
        )
    )

    # Check that the implementation output is close to the expected output.
    self.assertTrue(
        jnp.allclose(
            tensor_lt_multiply_result,
            direct_result,
            rtol=1e-3,
            atol=1e-5,
        )
    )
    self.assertIsNone(cache)

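
# A matching blockwise sketch for the squared (tensor) variant, analogous to
# `_reference_lt_multiply` above. Again illustrative under our own naming,
# not the library's implementation: earlier blocks contribute through the
# order-3 prefix cache described before the tensor tests.
def _reference_tensor_lt_multiply(a, b, c, grain_size):
  """Blockwise sketch of tril((a @ b^T) ** 2) @ c. Illustrative only."""
  batches, n, r = a.shape
  d = c.shape[-1]
  state = jnp.zeros((batches, r, r, d), dtype=c.dtype)
  outputs = []
  for start in range(0, n, grain_size):
    a_blk = a[:, start:start + grain_size]
    b_blk = b[:, start:start + grain_size]
    c_blk = c[:, start:start + grain_size]
    # Earlier blocks enter via the prefix cache of b[t,i] * b[t,j] * c[t,:]...
    out_blk = jnp.einsum('...si, ...sj, ...ijd -> ...sd', a_blk, a_blk, state)
    # ...plus the causal part of the squared scores within the block.
    out_blk += jnp.tril((a_blk @ b_blk.transpose(0, 2, 1)) ** 2) @ c_blk
    state = state + jnp.einsum(
        '...ti, ...tj, ...td -> ...ijd', b_blk, b_blk, c_blk
    )
    outputs.append(out_blk)
  return jnp.concatenate(outputs, axis=1), state
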

if __name__ == '__main__':
  absltest.main()