# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Tests for inference."""
17
18from functools import partial # pylint: disable = g-importing-member19
20from absl.testing import absltest21import jax22from jax.experimental.shard_map import shard_map23import jax.numpy as jnp24import numpy as np25
26from scaling_transformer_inference_efficiency import checkpoint27from scaling_transformer_inference_efficiency import chunk28from scaling_transformer_inference_efficiency import collectives29from scaling_transformer_inference_efficiency import inference30from scaling_transformer_inference_efficiency import partitioning31from scaling_transformer_inference_efficiency import weights32from scaling_transformer_inference_efficiency.layers import layers_pjit33from scaling_transformer_inference_efficiency.layers import one_d_parallel_xmap34from scaling_transformer_inference_efficiency.layers import two_d_parallel_xmap35
X, Y, Z = 2, 2, 2  # slice sizes  pylint: disable = invalid-name


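# The tests below assume exactly X * Y * Z = 8 devices (setup() asserts this).
# On a CPU-only machine the 8 devices can be emulated by setting
# XLA_FLAGS=--xla_force_host_platform_device_count=8 before starting the test.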
def setup(
    batch_size,
    seq_len,
    latency_collectives,
    one_d = False,
):
  """Sets up necessary inputs."""
  assert len(jax.devices()) == X * Y * Z

  mesh = partitioning.make_mesh(one_d=one_d)

  key = jax.random.PRNGKey(0)
  dtype = jnp.float32
  h = checkpoint.HParams(
      layers=8, embed=16, ff=32, heads=16, qkv=4, max_len=256, vocab=1024)
  key, k2, k3, k4, k5 = jax.random.split(key, 5)
  q_wi = jax.random.normal(k2, (h.layers, h.heads, h.embed, h.q_wi_per_head),
                           dtype)
  kv = jax.random.normal(k3, (h.layers, h.embed, 1, 2 * h.qkv), dtype)
  o_wo = jax.random.normal(k4, (h.layers, h.heads, h.o_wo_per_head, h.embed),
                           dtype)
  embedding = jax.random.normal(k5, (h.vocab, h.embed), dtype)
  sin = jnp.ones((h.max_len, h.qkv // 2), dtype)
  cos = jnp.ones((h.max_len, h.qkv // 2), dtype)

  # create the params
  params_pjit = weights.Weights(
      weights.Layer(q_wi, kv, o_wo), sin, cos, embedding)

  # create the token inputs
  token_chunk = chunk.Chunk(
      tokens=jnp.reshape(
          jnp.arange(batch_size * seq_len), (batch_size, seq_len)),
      lengths=jnp.array([seq_len] * batch_size))

  def to_named_sharding(mesh, spec):
    return jax.sharding.NamedSharding(mesh, spec)

  to_named_sharding = partial(to_named_sharding, mesh)

  # pjit sharding
  chunk_spec = jax.tree_util.tree_map(
      to_named_sharding, chunk.Chunk.physical_axes()
  )
  param_spec = jax.tree_util.tree_map(
      to_named_sharding, weights.Weights.physical_axes()
  )
  # result_spec = jax.tree_util.tree_map(to_named_sharding, result_sharding)

  token_chunk = jax.device_put(token_chunk, chunk_spec)
  params_pjit = jax.device_put(params_pjit, param_spec)

  def rotate_weights(params):
    """Rotate the weights for the collectives.

    Assumed to occur in a per device form. Assumes 2D partitioning.
    q_wi: [layers, heads.YZ, dmodel.X, q_wi_per_head]
    o_wo: [layers, heads.YZ, owo_per_head, dmodel.X]

    Args:
      params: parameters

    Returns:
      params: rotated parameters
    """
    new_layer = params.layer
    new_layer = new_layer.replace(
        q_wi=collectives.preshuffle_for_reducescatter_latency(
            new_layer.q_wi, scatter_axis=1, axis_name='x'))
    new_layer = new_layer.replace(
        o_wo=collectives.preshuffle_for_allgather_matmul_latency(
            new_layer.o_wo, shuffle_axis=1, axis_name='x'))
    return params.replace(layer=new_layer)

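  # With latency-optimised collectives, the per-device weight shards are
  # pre-shuffled into the order the collective-matmul kernels expect;
  # shard_map gives rotate_weights the per-device view needed for the shuffle.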
  if latency_collectives:
    with mesh:
      rotated_params = jax.jit(
          shard_map(
              rotate_weights,
              mesh,
              in_specs=(weights.Weights.physical_axes(),),
              out_specs=weights.Weights.physical_axes(),
              check_rep=False,
          )
      )(params_pjit)
  else:
    rotated_params = params_pjit

  kv_caches = []

  return (dtype, h, mesh, params_pjit, rotated_params, kv_caches, token_chunk)


# pylint: disable = dangerous-default-value
def xmap_pjit_equivalency(
    batch_size=4,
    seq_len=32,
    rules = [],
    attn_sharding=partitioning.AttnAllToAll.NONE,
    latency_collectives=False,
    batch_unsharded=False,
    shard_seqlen_vs_batch=False,
    layer_fn=two_d_parallel_xmap.transformer_layer_weight_stationary,
    atol=1e-03,
    rtol=1e-06,
):
  """Tests shard map."""
  # Within this function, we device-put the relevant arrays ahead of time.
  one_d = layer_fn == one_d_parallel_xmap.weight_stationary_simple

  with rules:
    (dtype, h, mesh, params, rotated_params, kv_caches, token_chunk) = setup(
        batch_size=batch_size,
        seq_len=seq_len,
        latency_collectives=latency_collectives,
        one_d=one_d,
    )

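    # Baseline: the fully pjit-partitioned forward pass.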
    def fwd_pjit(params, token_chunk):
      return inference.infer(
          h,
          layers_pjit.pjit_transformer_layer,
          params,
          kv_caches,
          token_chunk,
          intermediate_dtype=dtype)

    with mesh:
      result_baseline = jax.jit(fwd_pjit)(params, token_chunk)

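    # Manually-partitioned equivalent: the same forward pass expressed with
    # shard_map over explicit embed / layer / unembed functions.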
    sharding_config = partitioning.ShardingConfig(
        mesh=mesh,
        attn_all_to_all=attn_sharding,
        latency_collectives=latency_collectives,
        shard_seqlen_vs_batch=shard_seqlen_vs_batch,
        batch_unsharded=batch_unsharded,
    )

    embed_fn = partial(
        two_d_parallel_xmap.embed_manual,
        shard_seqlen_vs_batch=shard_seqlen_vs_batch,
        batch_unsharded=batch_unsharded,
        one_d=one_d,
    )

    if layer_fn == two_d_parallel_xmap.transformer_layer_weight_stationary:
      layer_fn = partial(
          layer_fn,
          attn_all_to_all=attn_sharding,
          latency_collectives=latency_collectives,
          shard_seqlen_vs_batch=shard_seqlen_vs_batch,
          batch_unsharded=batch_unsharded,
      )
    elif layer_fn == one_d_parallel_xmap.weight_stationary_simple:
      layer_fn = partial(layer_fn, latency_collectives=latency_collectives)
    elif layer_fn == two_d_parallel_xmap.transformer_layer_weight_gathered:
      raise NotImplementedError

    unembed_fn = partial(
        two_d_parallel_xmap.unembed_manual,
        batch_unsharded=batch_unsharded,
        one_d=one_d,
    )

    forward_pass = partial(
        inference.manual_fwd_pass,
        h,
        sharding_config,
        embed_fn,
        layer_fn,
        unembed_fn,
    )

    def fwd(params, token_chunk):
      """Wraps the inference fn to ease the shard_map pytree definition."""
      return inference.infer_template(
          h,
          sharding_config,
          forward_pass,
          params,
          kv_caches,
          token_chunk,
          intermediate_dtype=dtype,
      )

    with mesh:
      result_shardmap = jax.jit(fwd)(rotated_params, token_chunk)

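    # The pjit baseline and the manual shard_map pass should agree on both the
    # KV cache contents and the output logits (up to numerical tolerance).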
    np.testing.assert_allclose(
        result_baseline.kv_cache.k.astype(jnp.float32),
        result_shardmap.kv_cache.k.astype(jnp.float32),
        rtol=1e-1,
    )  # none_b1 needs this tolerance - XLA? TODO(sholto): Check
    np.testing.assert_allclose(
        result_baseline.logits, result_shardmap.logits, rtol=rtol, atol=atol
    )
    # pylint: disable = unused-variable
    # TODO(sholto): The final grad(shard_map) bug
    # pylint: disable = protected-access
    def grads_pjit(params, token_chunk):
      def loss_fn(params, token_chunk):
        result = fwd_pjit(params, token_chunk)
        return result.logits.mean()

      loss, grads = jax.value_and_grad(loss_fn)(params, token_chunk)
      grads = jax.tree_map(
          partitioning._with_sharding_constraint,
          grads,
          weights.Weights.logical_axes(),
      )
      return loss, grads

    def grads(params, token_chunk):
      def loss_fn(params, token_chunk):
        result = fwd(params, token_chunk)
        return result.logits.mean()

      loss, grads = jax.value_and_grad(loss_fn)(params, token_chunk)
      grads = jax.tree_map(
          partitioning._with_sharding_constraint,
          grads,
          weights.Weights.logical_axes(),
      )
      return loss, grads

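    # Gradients are only exercised for unsharded attention; the comparison
    # itself is disabled below pending the grad(shard_map) issue noted above.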
    if attn_sharding == partitioning.AttnAllToAll.NONE:
      with mesh:
        loss_pjit, grads_pjit = jax.jit(grads_pjit)(params, token_chunk)
        loss, grads = jax.jit(grads)(params, token_chunk)

      # jax.tree_map(
      #     partial(np.testing.assert_allclose, atol=atol),
      #     grads_pjit,
      #     grads,
      # )


class InferenceTest(absltest.TestCase):
  """Tests for inference fwd pass."""

  def test_none_sharding_b1(self):
    attn_sharding = partitioning.AttnAllToAll.NONE
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding, batch_unsharded=True)
    )
    xmap_pjit_equivalency(
        batch_size=1,
        seq_len=1,
        rules=rules,
        attn_sharding=partitioning.AttnAllToAll.NONE,
        batch_unsharded=True,
        atol=1e-01,
    )  # TODO(sholto): Check if this is because it occurs on VPU like b/246436629  pylint: disable= line-too-long

  def test_none_sharding(self):
    attn_sharding = partitioning.AttnAllToAll.NONE
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding, batch_unsharded=True)
    )
    xmap_pjit_equivalency(
        batch_size=2,
        rules=rules,
        attn_sharding=attn_sharding,
        batch_unsharded=True,
    )

  def test_one_d(self):
    rules = partitioning.PartitioningRules(partitioning.make_rules_one_d())
    xmap_pjit_equivalency(
        batch_size=2,
        rules=rules,
        layer_fn=one_d_parallel_xmap.weight_stationary_simple,
    )

  def test_attn_z_sharding(self):
    attn_sharding = partitioning.AttnAllToAll.AXIS_Z
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding)
    )
    xmap_pjit_equivalency(
        batch_size=2, rules=rules, attn_sharding=attn_sharding
    )

  def test_attn_yz_sharding(self):
    attn_sharding = partitioning.AttnAllToAll.AXES_YZ
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding)
    )
    xmap_pjit_equivalency(
        batch_size=4, rules=rules, attn_sharding=attn_sharding
    )

  def test_attn_yz_sharding_batch_unsharded(self):
    attn_sharding = partitioning.AttnAllToAll.AXES_YZ
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding, batch_unsharded=True)
    )
    xmap_pjit_equivalency(
        batch_size=4, rules=rules, attn_sharding=attn_sharding,
        batch_unsharded=True,
    )

  def test_attn_yzx_sharding(self):
    attn_sharding = partitioning.AttnAllToAll.AXES_YZX
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding)
    )
    xmap_pjit_equivalency(
        batch_size=8, rules=rules, attn_sharding=attn_sharding
    )

  def test_none_sharding_with_latency(self):
    attn_sharding = partitioning.AttnAllToAll.NONE
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding, batch_unsharded=True)
    )
    xmap_pjit_equivalency(
        batch_size=2,
        rules=rules,
        attn_sharding=attn_sharding,
        latency_collectives=True,
        batch_unsharded=True,
    )


if __name__ == '__main__':
  absltest.main()