# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for inference."""

from functools import partial  # pylint: disable = g-importing-member

from absl.testing import absltest
import jax
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
import numpy as np

from scaling_transformer_inference_efficiency import checkpoint
from scaling_transformer_inference_efficiency import chunk
from scaling_transformer_inference_efficiency import collectives
from scaling_transformer_inference_efficiency import inference
from scaling_transformer_inference_efficiency import partitioning
from scaling_transformer_inference_efficiency import weights
from scaling_transformer_inference_efficiency.layers import layers_pjit
from scaling_transformer_inference_efficiency.layers import one_d_parallel_xmap
from scaling_transformer_inference_efficiency.layers import two_d_parallel_xmap

X, Y, Z = 2, 2, 2  # slice sizes pylint: disable = invalid-name


def setup(
    batch_size,
    seq_len,
    latency_collectives,
    one_d=False,
):
  """Sets up necessary inputs."""
  assert len(jax.devices()) == X * Y * Z

  mesh = partitioning.make_mesh(one_d=one_d)

  key = jax.random.PRNGKey(0)
  dtype = jnp.float32
  h = checkpoint.HParams(
      layers=8, embed=16, ff=32, heads=16, qkv=4, max_len=256, vocab=1024)
  key, k2, k3, k4, k5 = jax.random.split(key, 5)
  q_wi = jax.random.normal(k2, (h.layers, h.heads, h.embed, h.q_wi_per_head),
                           dtype)
  kv = jax.random.normal(k3, (h.layers, h.embed, 1, 2 * h.qkv), dtype)
  o_wo = jax.random.normal(k4, (h.layers, h.heads, h.o_wo_per_head, h.embed),
                           dtype)
  embedding = jax.random.normal(k5, (h.vocab, h.embed), dtype)
  sin = jnp.ones((h.max_len, h.qkv // 2), dtype)
  cos = jnp.ones((h.max_len, h.qkv // 2), dtype)

  # create the params
  params_pjit = weights.Weights(
      weights.Layer(q_wi, kv, o_wo), sin, cos, embedding)

  # create the token inputs
  token_chunk = chunk.Chunk(
      tokens=jnp.reshape(
          jnp.arange(batch_size * seq_len), (batch_size, seq_len)),
      lengths=jnp.array([seq_len] * batch_size))

  def to_named_sharding(mesh, spec):
    return jax.sharding.NamedSharding(mesh, spec)

  to_named_sharding = partial(to_named_sharding, mesh)

  # pjit sharding
  chunk_spec = jax.tree_util.tree_map(
      to_named_sharding, chunk.Chunk.physical_axes()
  )
  param_spec = jax.tree_util.tree_map(
      to_named_sharding, weights.Weights.physical_axes()
  )
  # result_spec = jax.tree_util.tree_map(to_named_sharding, result_sharding)

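  # Shard the inputs and weights across the mesh according to their physical
  # axis specs before tracing.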
  token_chunk = jax.device_put(token_chunk, chunk_spec)
  params_pjit = jax.device_put(params_pjit, param_spec)

  def rotate_weights(params):
    """Rotate the weights for the collectives.

    Assumed to occur in a per-device form. Assumes 2D partitioning.
    q_wi: [layers, heads.YZ, dmodel.X, q_wi_per_head]
    o_wo: [layers, heads.YZ, owo_per_head, dmodel.X]

    Args:
      params: parameters

    Returns:
      params: rotated parameters
    """
    new_layer = params.layer
    new_layer = new_layer.replace(
        q_wi=collectives.preshuffle_for_reducescatter_latency(
            new_layer.q_wi, scatter_axis=1, axis_name='x'))
    new_layer = new_layer.replace(
        o_wo=collectives.preshuffle_for_allgather_matmul_latency(
            new_layer.o_wo, shuffle_axis=1, axis_name='x'))
    return params.replace(layer=new_layer)

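  # Pre-shuffle the weights per device (under shard_map) only when
  # latency-optimized collectives are used; check_rep=False skips shard_map's
  # replication checks.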
  if latency_collectives:
    with mesh:
      rotated_params = jax.jit(
          shard_map(
              rotate_weights,
              mesh,
              in_specs=(weights.Weights.physical_axes(),),
              out_specs=weights.Weights.physical_axes(),
              check_rep=False,
          )
      )(params_pjit)
  else:
    rotated_params = params_pjit

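  # An empty list means no prior KV caches, so the test effectively runs a
  # pure prefill over token_chunk.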
  kv_caches = []

  return (dtype, h, mesh, params_pjit, rotated_params, kv_caches, token_chunk)


# pylint: disable = dangerous-default-value
def xmap_pjit_equivalency(
    batch_size=4,
    seq_len=32,
    rules=[],
    attn_sharding=partitioning.AttnAllToAll.NONE,
    latency_collectives=False,
    batch_unsharded=False,
    shard_seqlen_vs_batch=False,
    layer_fn=two_d_parallel_xmap.transformer_layer_weight_stationary,
    atol=1e-03,
    rtol=1e-06,
):
  """Tests that the manually sharded fwd pass matches the pjit baseline."""
  # Within this function, we device_put the relevant arrays ahead of time.
  one_d = layer_fn == one_d_parallel_xmap.weight_stationary_simple

  with rules:
    (dtype, h, mesh, params, rotated_params, kv_caches, token_chunk) = setup(
        batch_size=batch_size,
        seq_len=seq_len,
        latency_collectives=latency_collectives,
        one_d=one_d,
    )

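    # Baseline: the plain pjit implementation of the forward pass.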
    def fwd_pjit(params, token_chunk):
      return inference.infer(
          h,
          layers_pjit.pjit_transformer_layer,
          params,
          kv_caches,
          token_chunk,
          intermediate_dtype=dtype)

    with mesh:
      result_baseline = jax.jit(fwd_pjit)(params, token_chunk)

    sharding_config = partitioning.ShardingConfig(
        mesh=mesh,
        attn_all_to_all=attn_sharding,
        latency_collectives=latency_collectives,
        shard_seqlen_vs_batch=shard_seqlen_vs_batch,
        batch_unsharded=batch_unsharded,
    )

    embed_fn = partial(
        two_d_parallel_xmap.embed_manual,
        shard_seqlen_vs_batch=shard_seqlen_vs_batch,
        batch_unsharded=batch_unsharded,
        one_d=one_d,
    )

    if layer_fn == two_d_parallel_xmap.transformer_layer_weight_stationary:
      layer_fn = partial(
          layer_fn,
          attn_all_to_all=attn_sharding,
          latency_collectives=latency_collectives,
          shard_seqlen_vs_batch=shard_seqlen_vs_batch,
          batch_unsharded=batch_unsharded,
      )
    elif layer_fn == one_d_parallel_xmap.weight_stationary_simple:
      layer_fn = partial(layer_fn, latency_collectives=latency_collectives)
    elif layer_fn == two_d_parallel_xmap.transformer_layer_weight_gathered:
      raise NotImplementedError

    unembed_fn = partial(
        two_d_parallel_xmap.unembed_manual,
        batch_unsharded=batch_unsharded,
        one_d=one_d,
    )

    forward_pass = partial(
        inference.manual_fwd_pass,
        h,
        sharding_config,
        embed_fn,
        layer_fn,
        unembed_fn,
    )

    def fwd(params, token_chunk):
      """Wraps the inference fn to ease shard_map's pytree handling."""
      return inference.infer_template(
          h,
          sharding_config,
          forward_pass,
          params,
          kv_caches,
          token_chunk,
          intermediate_dtype=dtype,
      )

    with mesh:
      result_shardmap = jax.jit(fwd)(rotated_params, token_chunk)

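    # The manually sharded result should agree with the pjit baseline.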
    np.testing.assert_allclose(
        result_baseline.kv_cache.k.astype(jnp.float32),
        result_shardmap.kv_cache.k.astype(jnp.float32),
        rtol=1e-1,
    )  # none_b1 needs this tolerance - XLA? TODO(sholto): Check
    np.testing.assert_allclose(
        result_baseline.logits, result_shardmap.logits, rtol=rtol, atol=atol
    )
    # pylint: disable = unused-variable
    # TODO(sholto): The final grad(shard_map) bug
    # pylint: disable = protected-access
    def grads_pjit(params, token_chunk):
      def loss_fn(params, token_chunk):
        result = fwd_pjit(params, token_chunk)
        return result.logits.mean()

      loss, grads = jax.value_and_grad(loss_fn)(params, token_chunk)
      grads = jax.tree_util.tree_map(
          partitioning._with_sharding_constraint,
          grads,
          weights.Weights.logical_axes(),
      )
      return loss, grads

    def grads(params, token_chunk):
      def loss_fn(params, token_chunk):
        result = fwd(params, token_chunk)
        return result.logits.mean()

      loss, grads = jax.value_and_grad(loss_fn)(params, token_chunk)
      grads = jax.tree_util.tree_map(
          partitioning._with_sharding_constraint,
          grads,
          weights.Weights.logical_axes(),
      )
      return loss, grads

    if attn_sharding == partitioning.AttnAllToAll.NONE:
      with mesh:
        loss_pjit, grads_pjit = jax.jit(grads_pjit)(params, token_chunk)
        loss, grads = jax.jit(grads)(params, token_chunk)

      # jax.tree_map(
      #     partial(np.testing.assert_allclose, atol=atol),
      #     grads_pjit,
      #     grads,
      # )


class InferenceTest(absltest.TestCase):
  """Tests for the inference fwd pass."""

  def test_none_sharding_b1(self):
    attn_sharding = partitioning.AttnAllToAll.NONE
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding, batch_unsharded=True)
    )
    xmap_pjit_equivalency(
        batch_size=1,
        seq_len=1,
        rules=rules,
        attn_sharding=partitioning.AttnAllToAll.NONE,
        batch_unsharded=True,
        atol=1e-01,
    )  # TODO(sholto): Check if this is because it occurs on VPU like b/246436629 pylint: disable=line-too-long

  def test_none_sharding(self):
    attn_sharding = partitioning.AttnAllToAll.NONE
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding, batch_unsharded=True)
    )
    xmap_pjit_equivalency(
        batch_size=2,
        rules=rules,
        attn_sharding=attn_sharding,
        batch_unsharded=True,
    )

  def test_one_d(self):
    rules = partitioning.PartitioningRules(partitioning.make_rules_one_d())
    xmap_pjit_equivalency(
        batch_size=2,
        rules=rules,
        layer_fn=one_d_parallel_xmap.weight_stationary_simple,
    )

  def test_attn_z_sharding(self):
    attn_sharding = partitioning.AttnAllToAll.AXIS_Z
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding)
    )
    xmap_pjit_equivalency(
        batch_size=2, rules=rules, attn_sharding=attn_sharding
    )

  def test_attn_yz_sharding(self):
    attn_sharding = partitioning.AttnAllToAll.AXES_YZ
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding)
    )
    xmap_pjit_equivalency(
        batch_size=4, rules=rules, attn_sharding=attn_sharding
    )

  def test_attn_yz_sharding_batch_unsharded(self):
    attn_sharding = partitioning.AttnAllToAll.AXES_YZ
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding, batch_unsharded=True)
    )
    xmap_pjit_equivalency(
        batch_size=4, rules=rules, attn_sharding=attn_sharding,
        batch_unsharded=True,
    )

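  # AXES_YZX shards the attention batch over all three mesh axes, so the
  # batch here is sized to divide evenly across Y * Z * X = 8 devices.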
  def test_attn_yzx_sharding(self):
    attn_sharding = partitioning.AttnAllToAll.AXES_YZX
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding)
    )
    xmap_pjit_equivalency(
        batch_size=8, rules=rules, attn_sharding=attn_sharding
    )

  def test_none_sharding_with_latency(self):
    attn_sharding = partitioning.AttnAllToAll.NONE
    rules = partitioning.PartitioningRules(
        partitioning.make_rules_two_d(attn_sharding, batch_unsharded=True)
    )
    xmap_pjit_equivalency(
        batch_size=2,
        rules=rules,
        attn_sharding=attn_sharding,
        latency_collectives=True,
        batch_unsharded=True,
    )


if __name__ == '__main__':
  absltest.main()
