google-research
67 строк · 2.0 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Tests for partitioning."""
17
18import os19
20from absl.testing import absltest21import jax22from jax import core23import jax.numpy as jnp24from jax.sharding import NamedSharding25from jax.sharding import PartitionSpec as P26
27import resources28from scaling_transformer_inference_efficiency import checkpoint29from scaling_transformer_inference_efficiency import partitioning30
31
# Small hyperparameter set for partitioning tests.
# NOTE(review): this constant is not referenced by any test in this file —
# confirm whether it is still needed.
_TOY_HPARAMS = checkpoint.HParams(
    vocab=32128,
    max_len=128,
    layers=3,
    heads=2,
    qkv=32,
    embed=128,
    ff=256,
)
41
42
class PartitioningTest(absltest.TestCase):
  """Checks shape/dtype of arrays produced by `partitioning.copy_to_device`."""

  def _sharding(self):
    """Returns the NamedSharding shared by the tests in this case.

    Built on the mesh from `partitioning.make_mesh()`; the PartitionSpec maps
    array dim 0 to mesh axis 'x' and dim 1 to the combined ('y', 'z') axes.
    NOTE(review): assumes make_mesh() defines axes named 'x', 'y' and 'z'.
    """
    return NamedSharding(partitioning.make_mesh(), P('x', ('y', 'z')))

  def test_copy_to_device_from_shape(self):
    # The source is an abstract ShapedArray (no data); the same value is also
    # passed as the target shape argument.
    abstract = core.ShapedArray((4, 4), dtype=jnp.bfloat16)
    result = partitioning.copy_to_device(abstract, self._sharding(), abstract)
    self.assertEqual(result.shape, abstract.shape)
    self.assertEqual(result.dtype, abstract.dtype)

  def test_copy_to_device_from_array(self):
    # The source is a concrete zero-filled array; the target shape is an
    # abstract ShapedArray with the same shape and dtype.
    source = jnp.zeros((4, 4), jnp.bfloat16)
    target = core.ShapedArray((4, 4), dtype=jnp.bfloat16)
    result = partitioning.copy_to_device(source, self._sharding(), target)
    self.assertEqual(result.shape, source.shape)
    self.assertEqual(result.dtype, source.dtype)
65
if __name__ == '__main__':
  # Hand control to absltest, which discovers and runs the TestCase classes
  # defined in this module.
  absltest.main()