google-research

Форк
0
595 строк · 21.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""One Transformer layer, in hard xmap."""
17

18
from functools import partial  # pylint: disable = g-importing-member
19
from typing import Sequence, Tuple
20

21
import jax
22
from jax import lax
23
import jax.numpy as jnp
24
import jax.scipy
25

26
from scaling_transformer_inference_efficiency import attention
27
from scaling_transformer_inference_efficiency import checkpoint
28
from scaling_transformer_inference_efficiency import collectives
29
from scaling_transformer_inference_efficiency import special2
30
from scaling_transformer_inference_efficiency import weights
31
from scaling_transformer_inference_efficiency.chunk import Chunk
32
from scaling_transformer_inference_efficiency.layers.layers_pjit import _rope
33
from scaling_transformer_inference_efficiency.partitioning import AttnAllToAll
34
from scaling_transformer_inference_efficiency.weights import Layer
35

36
# Re-exported aliases so callers can reach these types through this module.
HParams = checkpoint.HParams
CheckpointSpec = checkpoint.CheckpointSpec
# NOTE: `Layer` is also imported directly from the weights module above; this
# rebinding refers to the same object.
Layer, QuantizedLayer, Weights = (
    weights.Layer,
    weights.QuantizedLayer,
    weights.Weights,
)

# Not referenced in the visible part of this file; by its name, a per-chip
# threshold for choosing 3D attention sharding — TODO confirm with callers.
ATTN_3D_SHARDING_THRESHOLD_PER_CHIP = 2
43

44
# pylint: disable = invalid-name
45
# pylint: disable = protected-access
46
# pylint: disable = g-bare-generic
47

48

49
def assert_equal(x, y):
  """Raises AssertionError unless `x == y`.

  Used for trace-time shape checks. Unlike a bare `assert` statement, this
  check is not stripped when Python runs with `-O`.

  Args:
    x: First value (e.g. a shape tuple).
    y: Value that must compare equal to `x`.

  Raises:
    AssertionError: if `x == y` is false; the message shows both values.
  """
  # `not x == y` (rather than `x != y`) mirrors the original `assert x == y`.
  if not x == y:
    raise AssertionError(f'{x} != {y}')
51

52

53
def allgather_layernorm(x,
                        shard_seqlen_vs_batch,
                        batch_unsharded = False,
                        scale = None):
  """All gathers around layernorm, minimises comms by first doing per-chip.

  Gathers the embedding axis over 'y' (and over 'z' first when the batch is
  unsharded). It then applies RMS normalization,
  x * rsqrt(mean(x**2) + 1e-6), with the mean of squares averaged over 'x'.
  Finally it gathers whichever axis is still sharded over 'z'.

  Args:
    x: [batch.Z, maxlen, embed.XY]; [batch, maxlen, embed.XYZ] when
      batch_unsharded; [batch, maxlen.Z, embed.XY] when shard_seqlen_vs_batch
      (the layouts used by the layers in this file).
    shard_seqlen_vs_batch: if True, the final 'z' gather is along the sequence
      axis (1) instead of the batch axis (0).
    batch_unsharded: if True, 'z' shards the embedding axis. It is gathered
      before normalizing, so no final 'z' gather is done.
    scale: optional scale. The normalized output is multiplied by
      `scale + 1.0` ('center_scale_at_zero' option in T5X). The caller's
      array is not modified.

  Returns:
    (xnorm, xnorm_z), both bfloat16. xnorm is [batch, maxlen, embed.X].
    xnorm_z is the same value before the final 'z' gather.
  """
  with jax.named_scope('allgather_layernorm'):
    xgather = x
    if batch_unsharded:
      # [batch, maxlen, embed.XYZ] -> [batch, maxlen, embed.XY]
      xgather = lax.all_gather(xgather, 'z', axis=2, tiled=True)
    # -> [batch.Z, maxlen, embed.X] || [batch, maxlen, embed.X]
    xgather = lax.all_gather(xgather, 'y', axis=2, tiled=True)

    epsilon = 1e-6
    xgather = jnp.float32(xgather)
    # Embed is still split over 'x'. Averaging the per-shard means over 'x'
    # gives the full-width mean, assuming equal-size (tiled) shards.
    mean2 = lax.pmean(
        jnp.mean(lax.square(xgather), axis=-1, keepdims=True), axis_name='x')
    xnorm_z = jnp.bfloat16(xgather * lax.rsqrt(mean2 + epsilon))
    if scale is not None:
      # 'center_scale_at_zero' option in T5X. Use a non-mutating add: `+=`
      # would modify a caller-owned NumPy array in place.
      xnorm_z = jnp.bfloat16(xnorm_z * (scale + 1.0))
    # When attention_all_to_all is None we can partition over sequence len
    # instead of batch.
    if shard_seqlen_vs_batch:
      xnorm = lax.all_gather(xnorm_z, 'z', axis=1, tiled=True)
    elif batch_unsharded:
      # 'z' was already gathered (along embed) above.
      xnorm = xnorm_z
    else:
      xnorm = lax.all_gather(xnorm_z, 'z', axis=0, tiled=True)
  # [batch, maxlen, embed.X]
  return xnorm, xnorm_z
88

89

90
@partial(jax.jit, static_argnums=(3, 4, 5))
def embed_manual(
    params,  # pylint: disable=g-bare-generic, invalid-name
    kv_caches,
    token_chunk,
    shard_seqlen_vs_batch = False,
    batch_unsharded = False,
    one_d = False,
):
  """Embeds a chunk of tokens and slices the matching RoPE sin/cos rows.

  Must run where the named axes 'y' and 'z' are bound, because it calls
  `lax.psum` / `lax.axis_index` / `lax.psum_scatter` on them.

  Args:
    params: Weights object. Reads `params.embedding` (this chip's vocab shard,
      [vocab.YZ, embed.X]) and the precomputed `params.sin` / `params.cos`
      tables ([precomputed_length, qkv // 2]).
    kv_caches: List of chunks preprocessed earlier. Their prefix lengths give
      the RoPE start position for each prefix.
    token_chunk: An unsharded token chunk. Assume .tokens is
      int32[batch, maxlen].
    shard_seqlen_vs_batch: If True, shard the result over 'z' along the
      sequence (time) axis instead of the batch axis.
    batch_unsharded: If True (global batch is smaller than the 'z' axis, so
      it cannot be sharded along it), shard the result over 'z' along the
      embedding axis instead.
    one_d: If True, return the local lookup before any 'y'/'z'
      reduce-scatter.

  Returns:
    embeddings: same dtype as params.embedding (TODO confirm: bfloat16), laid
      out as
        [batch.Z, time, embed.XY]  by default,
        [batch, time.Z, embed.XY]  if shard_seqlen_vs_batch,
        [batch, time, embed.XYZ]   if batch_unsharded,
        [batch, time, embed.X]     if one_d (not yet summed over 'y'/'z').
    sin: RoPE sin rows starting at each prefix's pre-existing kv_cache length,
      [prefix_batch, maxlen, qkv // 2].
    cos: Same as `sin`, for the cos table.
  """

  z_axis = lax.psum(1, 'z')
  # Start indices are the sums of the lengths of the KV caches.
  start_indices = attention.prefix_lengths(kv_caches)
  prefix_batch, = start_indices.shape
  batch, max_length = token_chunk.tokens.shape
  assert batch % prefix_batch == 0, 'Incompatible batch sizes'

  # Do RoPE lookups in the sin/cos tables. Only needed once per prefix_batch.
  def slice_at(index, table):
    # table: [precomputed_length, qkv // 2]
    return lax.dynamic_slice_in_dim(table, index, max_length)

  def slices_at(indices, table):
    return jax.vmap(slice_at, in_axes=(0, None))(indices, table)

  sin = slices_at(start_indices, params.sin)
  cos = slices_at(start_indices, params.cos)
  # sin, cos: [prefix_batch, max_length, qkv // 2]

  # x: int32[batch, maxlen]
  # embedding: [vocab.YZ, embed.X]; this chip owns vocab rows
  # [vocab_start, vocab_start + vocab_yz).
  x = token_chunk.tokens
  vocab_yz, _ = params.embedding.shape

  yz_index = lax.axis_index('y') * z_axis + lax.axis_index('z')
  vocab_start = yz_index * vocab_yz

  # Initial embedding lookup:
  with jax.named_scope('embed'):
    # Token ids relative to this chip's shard. Ids outside the shard read an
    # arbitrary row but are zeroed by the mask, so summing over the 'y'/'z'
    # shards below recovers the full lookup.
    one_x = x - vocab_start
    embeds = params.embedding[one_x, :]
    one_x = one_x[:, :, jnp.newaxis]
    embeds = jnp.where((one_x >= 0) & (one_x < vocab_yz), embeds, 0)
    # [batch, time, embed.X]
    if one_d:
      return embeds, sin, cos
    # [batch, time, embed.XY]
    embeds = lax.psum_scatter(embeds, 'y', scatter_dimension=2, tiled=True)

    if shard_seqlen_vs_batch:
      # [batch, time.Z, embed.XY]
      embeds = lax.psum_scatter(embeds, 'z', scatter_dimension=1, tiled=True)
    elif batch_unsharded:
      # [batch, time, embed.XYZ]
      embeds = lax.psum_scatter(embeds, 'z', scatter_dimension=2, tiled=True)
    else:
      # [batch.Z, time, embed.XY]
      embeds = lax.psum_scatter(embeds, 'z', scatter_dimension=0, tiled=True)

  return embeds, sin, cos
167

168

169
def unembed_manual(
    xnorm,
    params,
    batch_unsharded = False,
    one_d = False,
):
  """Unembedding function for 2D: projects activations onto the vocab shard.

  Args:
    xnorm: activations, [batch, maxlen, embed.X].
    params: object whose `embedding` is [vocab.YZ, embed.X].
    batch_unsharded: if True, scatter the 'x'-reduced logits along vocab.
    one_d: same effect as `batch_unsharded` on the output layout.

  Returns:
    float32 logits summed over 'x':
      [batch, maxlen, vocab.YZX] when `batch_unsharded or one_d`,
      [batch.X, maxlen, vocab.YZ] otherwise.
  """
  # TODO(sholto): We could squeeze out more memory by doing this
  # with a collective
  with jax.named_scope('unembed'):
    # Each 'x' shard contracts only its slice of embed, so this product is
    # partial: [batch, maxlen, vocab.YZ] {X unreduced}.
    partial_logits = jnp.einsum(
        'bte,ve->btv',
        jnp.float32(xnorm),
        jnp.float32(params.embedding),
    )
    # Finish the 'x' reduction, scattering along vocab (2) or batch (0).
    scatter_dim = 2 if (batch_unsharded or one_d) else 0
    logits = lax.psum_scatter(
        partial_logits, 'x', scatter_dimension=scatter_dim, tiled=True)
  return logits
195

196

197
# pylint: disable = g-doc-return-or-yield
198
# pylint: disable = g-doc-args
199
# TODO(sholto): Update to new, tested parsing collectives.
200

201

202
def transformer_layer_weight_stationary(
    hparams,
    layer,
    params,
    sin,
    cos,
    kv_caches,
    x,
    x_axis,
    y_axis,
    z_axis,
    *,
    attn_all_to_all,
    latency_collectives,
    shard_seqlen_vs_batch = False,
    batch_unsharded = False,
    intermediate_dtype = jnp.bfloat16,
):
  """Wraps _fn so that we can use remat while bug is fixed.

  Binds the keyword-only configuration onto
  `_transformer_layer_weight_stationary`, then wraps it in `jax.checkpoint`
  (remat) with `prevent_cse=True`. Arguments and return value are exactly
  those of the wrapped function: (z, k, v).
  """
  layer_fn = partial(
      _transformer_layer_weight_stationary,
      attn_all_to_all=attn_all_to_all,
      latency_collectives=latency_collectives,
      shard_seqlen_vs_batch=shard_seqlen_vs_batch,
      batch_unsharded=batch_unsharded,
      intermediate_dtype=intermediate_dtype,
  )
  # Positions 0 (hparams) and 7-9 (x_axis, y_axis, z_axis) are treated as
  # static by jax.checkpoint rather than traced.
  remat_layer_fn = jax.checkpoint(
      layer_fn, static_argnums=(0, 7, 8, 9), prevent_cse=True)
  positional_args = (hparams, layer, params, sin, cos, kv_caches, x,
                     x_axis, y_axis, z_axis)
  return remat_layer_fn(*positional_args)
233

234

235
def _transformer_layer_weight_stationary(
    hparams,
    layer,
    params,
    sin,
    cos,
    kv_caches,
    x,
    x_axis,
    y_axis,
    z_axis,
    *,
    attn_all_to_all,
    latency_collectives,
    shard_seqlen_vs_batch = False,
    batch_unsharded = False,
    intermediate_dtype = jnp.bfloat16,
):
  """Forward pass through a single layer, returning output, K, V.

  This implementation has 'x'=d_model sharding,
  ('y', 'z')=d_ff sharding.
  * params are assumed already sharded this way, i.e. embed.X and heads.YZ
  * sin and cos are sharded by batch.YZx (or batch.YZ or batch.Y as necessary)
  * kv_cache is sharded by batch.YZx (or batch.YZ or batch.Y as necessary)
  * x: [batch.Z, maxlen, embed.XY]
      ([batch, maxlen.Z, embed.XY] if shard_seqlen_vs_batch,
       [batch, maxlen, embed.XYZ] if batch_unsharded)

  Args:
    hparams: model hyperparameters; reads qkv, ff, heads, padded_heads.
    layer: index of this layer into the stacked per-layer weights.
    params: Layer or QuantizedLayer weights, stacked over layers.
    sin, cos: RoPE tables, applied to q and k.
    kv_caches: previously computed K/V chunks, forwarded to attention.attend.
    x: residual-stream input (layouts above).
    x_axis, y_axis, z_axis: sizes of the 'x', 'y', 'z' axes, used for the
      batch arithmetic below.
    attn_all_to_all: AttnAllToAll mode selecting how attention is sharded.
    latency_collectives: use the latency-optimized collective matmuls.
    shard_seqlen_vs_batch: x is sharded over 'z' along sequence, not batch.
    batch_unsharded: batch is not sharded over 'z' (embed.XYZ instead).
    intermediate_dtype: dtype (callable) that intermediates are cast to.

  Returns:
    (z, k, v): the layer output, in the same layout as x, followed by this
    chunk's keys and values cast to intermediate_dtype.

  Raises:
    ValueError: if shard_seqlen_vs_batch is combined with an attn_all_to_all
      mode of AXES_YZ or above. Those modes slice a batch.YZ shard, which
      needs batch (not sequence) to be sharded over 'z'.
  """
  intermediate_dtype = jax.core.concrete_or_error(None, intermediate_dtype)
  if latency_collectives:
    # Subsplit along heads (axis 2) as they are independent.
    matmul_reducescatter = partial(
        collectives.matmul_reducescatter_latency, subsplit_axis=2)
    matmul_allgather = partial(
        collectives.allgather_matmul_latency, subsplit_axis=2)
  else:
    # The one-way algorithms do not take a subsplit argument.
    matmul_reducescatter = collectives.matmul_reducescatter_oneway
    matmul_allgather = collectives.allgather_matmul_one_way

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # Compare
  # flaxformer/architectures/t5/parallel_fused_decoder.py
  # flaxformer/components/attention/dense_attention.py;l=1147;
  # flaxformer/components/attention/dense_attention.py;l=252;

  if (shard_seqlen_vs_batch and
      attn_all_to_all.value >= AttnAllToAll.AXES_YZ.value):
    # Previously this surfaced as an UnboundLocalError on `batch_yz`.
    raise ValueError(
        'shard_seqlen_vs_batch is not supported with attn_all_to_all='
        f'{attn_all_to_all}')

  # Unpacking also checks that x is rank 3.
  batch_z, _, _ = x.shape
  if shard_seqlen_vs_batch or batch_unsharded:
    # Batch is not split over 'z' (sequence or embed is, instead).
    batch = batch_z
  else:
    batch = batch_z * z_axis
  batch_xyz = batch // (x_axis * y_axis * z_axis)
  batch_yz = batch // (y_axis * z_axis)

  # Only QuantizedLayer applies the layernorm scale here.
  # NOTE(review): presumably unquantized weights have it folded in — confirm.
  if isinstance(params, weights.QuantizedLayer):
    ln_scale = my_layer(params.layernorm_scale)
  else:
    ln_scale = None
  xnorm, xnorm_z = allgather_layernorm(
      x, shard_seqlen_vs_batch, batch_unsharded, scale=ln_scale)

  # einsum(xnorm, q_wi):
  # [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
  # -> (matmul)
  # -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
  # -> (reducescatter over x into X heads, B batches)
  # -> [batch, maxlen, heads.YZX, q_wi_per_head]
  with jax.named_scope('q_wi'):
    xnorm = intermediate_dtype(xnorm)
    q_wi = matmul_reducescatter(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        scatter_axis=0,
        axis_name='x',
        layer=layer)

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = q_wi.shape
      q_wi = intermediate_dtype(q_wi * jnp.squeeze(my_layer(params.q_wi_scale)))
      assert_equal(prev_shape, q_wi.shape)

    # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
    # swiGLU with full d_ff dimension, rather than 2/3 scaled
    # Last axis per head: [q (qkv) | wi0 (ff_per_head) | wi1 (rest)].
    ff_per_head = hparams.ff // (hparams.heads - hparams.padded_heads)
    wi_split = hparams.qkv + ff_per_head
    wi0 = q_wi[:, :, :, hparams.qkv:wi_split]
    wi1 = q_wi[:, :, :, wi_split:]

  # einsum(xnorm, kv):
  #
  # if attn>=AXES_YZ:
  #   xnorm_z: [batch.Z, maxlen, embed.X]
  #     -> [batch.(X?)YZ, maxlen, embed.X]  (slice down)
  #
  # Then:
  #
  # [batch.(Y?)Z, maxlen, embed.X] @ [embed.X, 1, 2*qkv]
  # -> (matmul)
  # -> [batch.(Y?)Z, maxlen, 1, 2*qkv]{x unreduced}
  # -> (reducescatter over x into batch)
  #         *NOT* collective matmul, because it's batch
  # -> { Attn.NONE:      [batch.B, maxlen,  1, 2*qkv]
  #    { Attn.AXIS_Z:    [batch.ZB, maxlen, 1, 2*qkv]
  #    { Attn.AXES_YZ:   [batch.YZB, maxlen, 1, 2*qkv]
  #    { Attn.AXES_YZX:  [batch.YZXB, maxlen, 1, 2*qkv]
  with jax.named_scope('kv'):
    # TODO(sholto): update this in oversharded
    yz_index = lax.axis_index('y') * z_axis + lax.axis_index('z')
    # TODO(reinerp): Consider using xnorm instead of xnorm_z in NONE case?
    # I don't know yet if that's better.
    if attn_all_to_all.value >= AttnAllToAll.AXES_YZ.value:
      xnorm_sliced = lax.dynamic_slice_in_dim(
          xnorm, yz_index * batch_yz, batch_yz, axis=0)
    else:
      xnorm_sliced = xnorm_z

    kv_unreduced = jnp.einsum('bte,ezd->btzd', xnorm_sliced,
                              my_layer(params.kv))

    if attn_all_to_all == AttnAllToAll.NONE:
      if shard_seqlen_vs_batch:
        # [batch, maxlen.Z, 1, 2*qkv]{x_unreduced}
        # -> [batch.B, maxlen, 1, 2*qkv]
        kv = lax.psum(kv_unreduced, 'x')
        kv = lax.all_gather(kv, 'z', axis=1, tiled=True)
      else:
        # [batch.Z, maxlen, 1, 2*qkv]{x_unreduced} || [b, ml, 1, 2qkv] {x_unred}
        # --ARx-->   [batch.Z, maxlen, 1, 2*qkv]
        # --AGZ-->   [batch, maxlen, 1, 2*qkv]
        kv = lax.psum(kv_unreduced, 'x')
        if not batch_unsharded:
          kv = lax.all_gather(kv, 'z', axis=0, tiled=True)
    elif attn_all_to_all == AttnAllToAll.AXIS_Z:
      # [batch.Z, maxlen, 1, 2*qkv]{x_unreduced}
      # --ARx-->   [batch.Z, maxlen, 1, 2*qkv]
      kv = lax.psum(kv_unreduced, 'x')
    elif attn_all_to_all == AttnAllToAll.AXES_YZ:
      # [batch.YZ, maxlen, 1, 2*qkv]{x_unreduced}
      # --ARx-->   [batch.YZ, maxlen, 1, 2*qkv]
      kv = lax.psum(kv_unreduced, 'x')
    elif attn_all_to_all == AttnAllToAll.AXES_YZX:
      # [batch.YZ, maxlen, 1, 2*qkv]{x_unreduced}
      # --RSx-->   [batch.YZX, maxlen, 1, 2*qkv]
      assert batch_xyz >= 1, ('Batch size too small for AXES_YZX and this chip '
                              'count')
      kv = lax.psum_scatter(kv_unreduced, 'x', scatter_dimension=0, tiled=True)

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = kv.shape
      kv = intermediate_dtype(kv * jnp.squeeze(my_layer(params.kv_scale)))
      assert_equal(prev_shape, kv.shape)

    k = kv[:, :, 0, :hparams.qkv]
    v = kv[:, :, 0, hparams.qkv:]

  with jax.named_scope('attn'):
    k = _rope(sin, cos, k)

    # Axes over which batch and heads are exchanged in each mode; NONE does
    # no exchange.
    a2a_axes = {
        AttnAllToAll.AXIS_Z: 'z',
        AttnAllToAll.AXES_YZ: ('y', 'z'),
        AttnAllToAll.AXES_YZX: ('y', 'z', 'x'),
    }.get(attn_all_to_all)

    # q: [batch, maxlen, heads.YZX, qkv]
    # -> { NONE:                   [batch., maxlen, heads.YZX, qkv]
    #    { AXIS_Z:                 [batch.Z, maxlen, heads.YX, qkv]
    #    { AXES_YZ:                [batch.YZ, maxlen, heads.X, qkv]
    #    { AXES_YZX:               [batch.YZX, maxlen, heads, qkv]
    q = q_wi[:, :, :, :hparams.qkv]
    if a2a_axes is not None:
      q = lax.all_to_all(
          q, axis_name=a2a_axes, split_axis=0, concat_axis=2, tiled=True)

    q = _rope(sin, cos, q)

    y_att = intermediate_dtype(attention.attend(q, k, v, kv_caches, layer))
    # y_att:
    #    { NONE:                   [batch.B, maxlen, heads.YZX, qkv]
    #    { AXIS_Z:                 [batch.ZB, maxlen, heads.YX, qkv]
    #    { AXES_YZ:                [batch.YZB, maxlen, heads.X, qkv]
    #    { AXES_YZX:               [batch.YZX, maxlen, heads, qkv]
    # -> [batch, maxlen, heads.YZX, qkv]
    if a2a_axes is not None:
      y_att = lax.all_to_all(
          y_att, axis_name=a2a_axes, split_axis=2, concat_axis=0, tiled=True)

  with jax.named_scope('SwiGLU'):
    y_mlp = special2.swish2(wi0) * wi1

  # einsum(y_fused, o_wo):
  # [batch, maxlen, heads.YZ, o_wo_per_head] @
  #       [heads.YZ, o_wo_per_head, embed.X]
  # -> (matmul)
  # -> [batch, maxlen, embed.X]{YZ unreduced}
  # -> (fused reducescatter)
  # -> [batch, maxlen, embed.XY]
  # -> (non-fused reducescatter)
  # -> [batch.Z, maxlen, embed.XY]
  with jax.named_scope('o_wo'):
    y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)

    # do the second half of the mlp and the self-attn projection in parallel
    # allgather y_fused: [batch, maxlen, heads.YZX, o_wo_per_head]
    #       -> [batch, maxlen, heads.YZ, o_wo_per_head]
    # we use the collective matmul/reducescatter instead
    y_out = matmul_allgather(
        'bthd,hde->bte',
        y_fused,
        params.o_wo,
        rhs_split_axis=0,
        axis_name='x',
        layer=layer)

    # -> [batch, maxlen, embed.XY]
    y_out = lax.psum_scatter(y_out, 'y', scatter_dimension=2, tiled=True)

    # TODO(sholto): Test if a manual (collectives) reducescatter is faster.
    if shard_seqlen_vs_batch:
      z_scatter_dim = 1  # -> [batch, maxlen.Z, embed.XY]
    elif batch_unsharded:
      z_scatter_dim = 2  # -> [batch, maxlen, embed.XYZ]
    else:
      z_scatter_dim = 0  # -> [batch.Z, maxlen, embed.XY]
    y_out = lax.psum_scatter(
        y_out, 'z', scatter_dimension=z_scatter_dim, tiled=True)

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = y_out.shape
      y_out = intermediate_dtype(y_out *
                                 jnp.squeeze(my_layer(params.o_wo_scale)))
      assert_equal(y_out.shape, prev_shape)

  with jax.named_scope('residual'):
    z = intermediate_dtype(y_out + x)

  k, v = k.astype(intermediate_dtype), v.astype(intermediate_dtype)
  return z, k, v
509

510

511
def transformer_layer_weight_gathered(
    hparams, layer, params, sin, cos, kv_caches, x, x_axis, y_axis, z_axis):
  """Weight gathered parallel layer. Typically prefill.

  Activations stay batch-sharded ([batch.XYZ, t, e]). The layernorm, kv
  projection and attention are local. The q_wi and o_wo matmuls go through
  collectives.matmul_collective_weights_gather_{q_wi,o_wo}.

  Args:
    hparams: model hyperparameters; reads qkv, ff, heads, padded_heads.
    layer: index of this layer into the stacked per-layer weights.
    params: Layer or QuantizedLayer weights, stacked over layers.
    sin, cos: RoPE tables, applied to q and k.
    kv_caches: previously computed K/V chunks, forwarded to attention.attend.
    x: [batch.XYZ, t, e] residual-stream input.
    x_axis, y_axis, z_axis: unused; accepted for API compatibility with
      transformer_layer_weight_stationary.

  Returns:
    (z, k, v): the bfloat16 layer output [batch.XYZ, t, e], and this chunk's
    keys and values [batch.XYZ, t, qkv].
  """
  del x_axis, y_axis, z_axis  # for API compatibility

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # x: [batch.XYZ, t, e]
  with jax.named_scope('allgather_layernorm'):
    # No need to communicate across batch, so everything is local.
    # NOTE(review): unlike transformer_layer_weight_stationary, no
    # layernorm_scale is applied for QuantizedLayer here — confirm intended.
    x_prec = jnp.float32(x)
    epsilon = 1e-6
    mean2 = jnp.mean(lax.square(x_prec), axis=-1, keepdims=True)
    # Normalize the float32 copy, matching allgather_layernorm.
    xnorm = jnp.bfloat16(x_prec * lax.rsqrt(mean2 + epsilon))

  # [batch.XYZ, t, e] @ [heads.YZ, e.X, q_wi_per_head]
  with jax.named_scope('q_wi'):
    # Indexing the layer here is fine because it happens once per call.
    q_wi = collectives.matmul_collective_weights_gather_q_wi(
        'bte,hed->bthd',
        xnorm,
        my_layer(params.q_wi),
        lhs_split_axis=2)  # -> [batch.XYZ, t, h, q_wi_per_head]

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = q_wi.shape
      q_wi = jnp.bfloat16(q_wi * jnp.squeeze(my_layer(params.q_wi_scale)))
      assert_equal(prev_shape, q_wi.shape)

    # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
    # swiGLU with full d_ff dimension, rather than 2/3 scaled
    # Last axis per head: [q (qkv) | wi0 (ff_per_head) | wi1 (rest)].
    ff_per_head = hparams.ff // (hparams.heads - hparams.padded_heads)
    wi_split = hparams.qkv + ff_per_head
    wi0 = q_wi[:, :, :, hparams.qkv:wi_split]
    wi1 = q_wi[:, :, :, wi_split:]

  # kv is only batch sharded
  with jax.named_scope('kv'):
    # [batch.XYZ, t, e] @ [e, 1, 2*qkv] -> [batch.XYZ, t, 1, 2*qkv]
    # Two options here:
    # a) Split along x, and then all reduce along x
    # b) We fully replicate kv
    kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = kv.shape
      kv = jnp.bfloat16(kv * jnp.squeeze(my_layer(params.kv_scale)))
      assert_equal(prev_shape, kv.shape)

    k = kv[:, :, 0, :hparams.qkv]  # [batch.XYZ, t, qkv]
    v = kv[:, :, 0, hparams.qkv:]  # [batch.XYZ, t, qkv]

  with jax.named_scope('attn'):
    k = _rope(sin, cos, k)  # [batch.XYZ, t, qkv]
    q = q_wi[:, :, :, :hparams.qkv]
    q = _rope(sin, cos, q)  # [batch.XYZ, t, h, qkv]

    # [batch.XYZ, t, h, qkv]
    y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))

  with jax.named_scope('SwiGLU'):
    y_mlp = special2.swish2(wi0) * wi1  # [batch.XYZ, t, h, ff_per_head]

  # [batch.XYZ, t, h, d] @ [h.YZ, d, e.X] -> [batch.XYZ, t, e]
  with jax.named_scope('o_wo'):
    y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)

    # previously concat yz, contracting over x - reconstructing heads dim
    # here, we contract over yz, concat over x to reconstruct embed dim
    y_out = collectives.matmul_collective_weights_gather_o_wo(
        'bthd,hde->bte', y_fused, my_layer(params.o_wo),
        lhs_split_axis=2)  # -> [batch.XYZ, t, e]

  if isinstance(params, weights.QuantizedLayer):
    prev_shape = y_out.shape
    y_out = jnp.bfloat16(y_out * jnp.squeeze(my_layer(params.o_wo_scale)))
    assert_equal(y_out.shape, prev_shape)

  with jax.named_scope('residual'):
    z = jnp.bfloat16(y_out + x)

  return z, k, v
596

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.