# Source: google-research repository (47 lines, 1.6 KB).
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for special2."""

from absl.testing import absltest
import jax
import jax.numpy as jnp
import numpy as np

from scaling_transformer_inference_efficiency import special2

class Special2Test(absltest.TestCase):
  """Compares special2 helpers against their jax reference implementations.

  Every test draws a fixed-seed float32 sample from jax.random.normal,
  evaluates the jax reference on it, evaluates the special2 counterpart on a
  rescaled copy of the same sample, and requires the two to agree to
  rtol=1e-6.
  """

  def test_softmax2(self):
    """softmax2(x * LOG2_E) must match jax.nn.softmax(x) on an (8,) vector."""
    x = jax.random.normal(jax.random.PRNGKey(0), (8,), jnp.float32)
    expected = jax.nn.softmax(x)
    # Presumably softmax2 exponentiates in base 2. Since
    # 2 ** (x * log2(e)) == e ** x, scaling by LOG2_E (assumed to be log2(e),
    # TODO confirm in special2) should reproduce the natural-base softmax.
    actual = special2.softmax2(x * special2.LOG2_E)
    np.testing.assert_allclose(expected, actual, rtol=1e-6)

  def test_logsumexp2(self):
    """logsumexp2(x * LOG2_E) * LN_2 must match logsumexp(x) along axis -1."""
    x = jax.random.normal(jax.random.PRNGKey(0), (2, 8), jnp.float32)
    expected = jax.scipy.special.logsumexp(x, axis=-1)
    # Assumes logsumexp2 computes log2(sum(2 ** y)) and LN_2 is ln(2) -- TODO
    # confirm in special2. Then log2(sum(e ** x)) * ln(2) == ln(sum(e ** x)).
    # Comparing against the per-row reference also presumes logsumexp2
    # reduces over the last axis by default.
    actual = special2.logsumexp2(x * special2.LOG2_E) * special2.LN_2
    np.testing.assert_allclose(expected, actual, rtol=1e-6)

  def test_swish(self):
    """swish2(0.5 * x) must match jax.nn.swish(x) on a (2, 8) array."""
    x = jax.random.normal(jax.random.PRNGKey(0), (2, 8), jnp.float32)
    expected = jax.nn.swish(x)
    # NOTE(review): the halved argument suggests swish2 is parameterized on
    # x / 2 (e.g. via x * sigmoid(x) == (x / 2) * (1 + tanh(x / 2))) -- confirm
    # against special2.swish2.
    actual = special2.swish2(x * 0.5)
    np.testing.assert_allclose(expected, actual, rtol=1e-6)
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  absltest.main()