# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""1D weight stationary xmap layer."""

from typing import Sequence, Tuple

import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy

from scaling_transformer_inference_efficiency import attention
from scaling_transformer_inference_efficiency import checkpoint
from scaling_transformer_inference_efficiency import collectives
from scaling_transformer_inference_efficiency import inference
from scaling_transformer_inference_efficiency import partitioning
from scaling_transformer_inference_efficiency import special2
from scaling_transformer_inference_efficiency import weights
from scaling_transformer_inference_efficiency.layers import two_d_parallel_xmap
from scaling_transformer_inference_efficiency.layers.layers_pjit import _rope
from scaling_transformer_inference_efficiency.weights import Layer

HParams = checkpoint.HParams
CheckpointSpec = checkpoint.CheckpointSpec


# pylint: disable = protected-access
# pylint: disable = g-doc-return-or-yield
# pylint: disable = g-doc-args
# TODO(sholto): Update
def weight_stationary_simple(
    hparams,
    layer,
    params,
    sin,
    cos,
    kv_caches,
    x,
    x_axis,
    y_axis,
    z_axis,
    latency_collectives,
    intermediate_dtype = jnp.bfloat16,
):
  """Forward pass through a single layer, returning output, K, V.

  Partitioning:
  * 'x' is the only axis (most chips).
  * weights are sharded [dmodel, heads.XYZ, *_per_head]
  * hidden-dimension activations are sharded [batch, time, heads.XYZ,
  *_per_head]
  * embed-dimension activations are "naturally" sharded [batch, time, dmodel]
    but may be "oversharded" after reducescatter operations.

  To support XYZ>heads, we simply increase the number of heads, by padding. The
  FFN can be sharded finer and continues to gain speedup with more chips, but
  the ATTN will just be padded and not gain speedup as we add chips.
  """
  if latency_collectives:
    matmul_reducescatter = collectives.matmul_reducescatter_latency
    # reducescatter = collectives.reducescatter_latency
    matmul_allgather = collectives.allgather_matmul_latency
  else:
    # matmul_reducescatter = partial(
    #     collectives.matmul_reducescatter_throughput, subsplit_axis=0
    # )
    # # reducescatter = collectives.reducescatter_throughput
    # matmul_allgather = partial(
    #     collectives.allgather_matmul_throughput, subsplit_axis=2
    # )
    matmul_reducescatter = collectives.matmul_reducescatter_oneway
    # reducescatter = collectives.reducescatter_throughput
    matmul_allgather = collectives.allgather_matmul_one_way

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  batch, max_len, _ = x.shape

  with jax.named_scope('layernorm'):
    # x: [batch, maxlen, dmodel.X]
    # mean2: [batch, maxlen]
    # xnorm: [batch, maxlen, dmodel.X]
    epsilon = 1e-6
    mean2 = lax.pmean(
        jnp.mean(lax.square(x), axis=-1, keepdims=True), axis_name='x'
    )
    xnorm = intermediate_dtype(x * lax.rsqrt(mean2 + epsilon))  # pytype: disable=not-callable  # jnp-type
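    # Unsharded sketch of the two statements above: dmodel is split over 'x',
    # so the local jnp.mean is a partial mean and lax.pmean averages the
    # equal-sized partials, i.e. mean2 equals the mean of x_full**2 over the
    # full dmodel axis. Roughly (x_full is hypothetical, the unsharded x):
    #
    #   mean2_full = jnp.mean(jnp.square(x_full), axis=-1, keepdims=True)
    #   xnorm_full = intermediate_dtype(x_full * lax.rsqrt(mean2_full + epsilon))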

  # einsum(xnorm, q_wi):
  # [batch, maxlen, dmodel.X] @ [heads.XYZ, dmodel, q_wi_per_head]
  # -> (allgather lhs)   (fused with matmul)
  # -> [batch, maxlen, dmodel]
  # -> (matmul)
  # -> [batch, maxlen, heads.XYZ, q_wi_per_head]
  with jax.named_scope('q_wi'):
    q_wi = matmul_allgather(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        rhs_split_axis=1,
        axis_name='x',
        layer=layer,
    )

    # No need to scatter over y and z, as y and z will always be 1 here.

    two_d_parallel_xmap.assert_equal(
        q_wi.shape,
        (
            batch,
            max_len,
            hparams.heads // (x_axis * y_axis * z_axis),
            hparams.q_wi_per_head,
        ),
    )

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = q_wi.shape
      q_wi = intermediate_dtype(q_wi * jnp.squeeze(my_layer(params.q_wi_scale)))
      two_d_parallel_xmap.assert_equal(prev_shape, q_wi.shape)

    # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
    # swiGLU with full d_ff dimension, rather than 2/3 scaled
    wi0 = q_wi[
        :, :, :, hparams.qkv : hparams.qkv + (hparams.ff // hparams.heads)
    ]
    wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads) :]
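    # Layout of the packed q_wi_per_head axis assumed by the slices above
    # (sizes are hypothetical, just to make the arithmetic concrete):
    #
    #   [0, qkv)                           -> q   (consumed below in 'attn')
    #   [qkv, qkv + ff // heads)           -> wi0 (SwiGLU gate)
    #   [qkv + ff // heads, q_wi_per_head) -> wi1
    #
    #   e.g. qkv=128, ff=16384, heads=64 => q_wi_per_head = 128 + 2 * 256 = 640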

  # einsum(xnorm, kv):
  #
  # [batch, maxlen, dmodel.X] @ [dmodel.X, 1, 2*qkv]
  # -> (matmul)
  # -> [batch, maxlen, 1, 2*qkv]{x unreduced}
  # -> (reducescatter over x into batch)
  #         *NOT* collective matmul, because it's batch
  # -> { Attn.NONE:      [batch, maxlen,  1, 2*qkv]
  with jax.named_scope('kv'):

    def kv_einsum(lhs):
      return jnp.einsum('bte,ezd->btzd', lhs, my_layer(params.kv))

    # kv_unreduced = jnp.einsum('bte,ezd->btzd', xnorm,
    #                           my_layer(params.kv))
    # [batch, maxlen, 1, 2*qkv]{x_unreduced}
    # --ARx-->   [batch, maxlen, 1, 2*qkv]
    kv = lax.psum(kv_einsum(xnorm), 'x')

    if isinstance(params, inference.QuantizedLayer):
      prev_shape = kv.shape
      kv = intermediate_dtype(kv * jnp.squeeze(my_layer(params.kv_scale)))
      two_d_parallel_xmap.assert_equal(prev_shape, kv.shape)

    k = kv[:, :, 0, : hparams.qkv]
    v = kv[:, :, 0, hparams.qkv :]

  with jax.named_scope('attn'):
    k = _rope(sin, cos, k)

    # q: [batch, maxlen, heads.XYZ, qkv]
    q = q_wi[:, :, :, : hparams.qkv]
    q = _rope(sin, cos, q)

    # y_att: -> [batch.B, maxlen, heads.XYZ, qkv]
    y_att = intermediate_dtype(attention.attend(q, k, v, kv_caches, layer))  # pytype: disable=not-callable  # jnp-type

  with jax.named_scope('SwiGLU'):
    y_mlp = special2.swish2(wi0) * wi1
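    # This is the SwiGLU gate from the 'q_wi' comment above, with wi0 and wi1
    # the two packed slices of q_wi: y_mlp = swish(xnorm @ W0) * (xnorm @ W1).
    # A plain-JAX sketch of the same gate, assuming special2.swish2 is a
    # swish/SiLU-style nonlinearity:
    #
    #   y_mlp_ref = jax.nn.silu(wi0) * wi1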

  # einsum(y_fused, o_wo):
  # [batch, maxlen, heads.XYZ, o_wo_per_head]
  #   @ [heads.XYZ, o_wo_per_head, dmodel]
  # -> (matmul)
  # -> [batch, maxlen, dmodel]{XYZ unreduced}
  # -> (fused reducescatter over X)
  # -> [batch, maxlen, dmodel.X]{YZ unreduced}
  # -> (non-fused allreduce)
  # -> [batch, maxlen, dmodel.X]
  with jax.named_scope('o_wo'):
    y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
    two_d_parallel_xmap.assert_equal(
        y_fused.shape,
        (
            batch,
            max_len,
            hparams.heads // (x_axis * y_axis * z_axis),
            hparams.o_wo_per_head,
        ),
    )

    y_out = matmul_reducescatter(
        'bthd,hde->bte',
        y_fused,
        params.o_wo,
        scatter_axis=2,
        axis_name='x',
        layer=layer,
    )
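    # Unfused sketch of the collective matmul above, assuming the helper is
    # equivalent to a local einsum over this chip's heads followed by a
    # reduce-scatter of the embed axis over 'x':
    #
    #   partial_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
    #   y_out_ref = lax.psum_scatter(
    #       partial_out, 'x', scatter_dimension=2, tiled=True)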

    # No output psum because only the x axis is used here.
    # y_out = lax.psum(y_out, axis_name=('y', 'z'))

    if isinstance(params, inference.QuantizedLayer):
      prev_shape = y_out.shape
      y_out = intermediate_dtype(
          y_out * jnp.squeeze(my_layer(params.o_wo_scale))
      )
      two_d_parallel_xmap.assert_equal(y_out.shape, prev_shape)

  with jax.named_scope('residual'):
    z = intermediate_dtype(y_out + x)  # pytype: disable=not-callable  # jnp-type
  k, v = k.astype(intermediate_dtype), v.astype(intermediate_dtype)
  return z, k, v


def weight_stationary(
    hparams,
    layer,
    params,
    sin,
    cos,
    kv_caches,
    x,
    x_axis,
    y_axis,
    z_axis,
    attn_all_to_all,
    latency_collectives,
):
  """Forward pass through a single layer, returning output, K, V.

  Partitioning:
  * 'x' is the longest axis (most chips), then 'y', then 'z'.
  * weights are sharded [dmodel, heads.XYZ, *_per_head]
  * hidden-dimension activations are sharded [batch, time, heads.XYZ,
  *_per_head]
  * embed-dimension activations are "naturally" sharded [batch, time, dmodel]
    but may be "oversharded" after reducescatter operations.

  To support XYZ>heads, we simply increase the number of heads, by padding. The
  FFN can be sharded finer and continues to gain speedup with more chips, but
  the ATTN will just be padded and not gain speedup as we add chips.
  """
  if latency_collectives:
    matmul_reducescatter = collectives.matmul_reducescatter_latency
    # reducescatter = collectives.reducescatter_latency
    matmul_allgather = collectives.allgather_matmul_latency
  else:
    matmul_reducescatter = collectives.matmul_reducescatter_throughput
    # reducescatter = collectives.reducescatter_throughput
    matmul_allgather = collectives.allgather_matmul_throughput

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  batch, max_len, _ = x.shape
  batch_z = batch // z_axis
  batch_yz = batch_z // y_axis
  batch_xyz = batch_yz // x_axis

  # x_index = lax.axis_index('x')
  y_index = lax.axis_index('y')
  z_index = lax.axis_index('z')
  yz_index = y_index * z_axis + z_index
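  # Worked example of the subdivision above (hypothetical 4x2x2 mesh and
  # batch=16, not from any real config):
  #
  #   batch_z   = 16 // 2 = 8   # per z-slice
  #   batch_yz  =  8 // 2 = 4   # per (y, z) slice
  #   batch_xyz =  4 // 4 = 1   # per chip
  #   yz_index  = y_index * z_axis + z_index   # in [0, y_axis * z_axis)
  #
  # These sizes/offsets drive the attn_all_to_all slicing below.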

  with jax.named_scope('layernorm'):
    # x: [batch, maxlen, dmodel.X]
    # mean2: [batch, maxlen]
    # xnorm: [batch, maxlen, dmodel.X]
    epsilon = 1e-6
    mean2 = lax.pmean(
        jnp.mean(lax.square(x), axis=-1, keepdims=True), axis_name='x'
    )
    xnorm = jnp.bfloat16(x * lax.rsqrt(mean2 + epsilon))

  # einsum(xnorm, q_wi):
  # [batch, maxlen, dmodel.X] @ [heads.XYZ, dmodel, q_wi_per_head]
  # -> (allgather lhs)   (fused with matmul)
  # -> [batch, maxlen, dmodel]
  # -> (matmul)
  # -> [batch, maxlen, heads.XYZ, q_wi_per_head]
  with jax.named_scope('q_wi'):
    q_wi = matmul_allgather(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        rhs_split_axis=1,
        axis_name='x',
        layer=layer,
        subsplit_axis=2,
    )

    two_d_parallel_xmap.assert_equal(
        q_wi.shape,
        (
            batch,
            max_len,
            hparams.heads // (x_axis * y_axis * z_axis),
            hparams.q_wi_per_head,
        ),
    )

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = q_wi.shape
      q_wi = jnp.bfloat16(q_wi * jnp.squeeze(my_layer(params.q_wi_scale)))
      two_d_parallel_xmap.assert_equal(prev_shape, q_wi.shape)

    # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
    # swiGLU with full d_ff dimension, rather than 2/3 scaled
    wi0 = q_wi[
        :, :, :, hparams.qkv : hparams.qkv + (hparams.ff // hparams.heads)
    ]
    wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads) :]

  # einsum(xnorm, kv):
  #
  # [batch, maxlen, dmodel.X] @ [dmodel.X, 1, 2*qkv]
  # -> (matmul)
  # -> [batch, maxlen, 1, 2*qkv]{x unreduced}
  # -> (reducescatter over x into batch)
  #         *NOT* collective matmul, because it's batch
  # -> { Attn.NONE:      [batch.B, maxlen,  1, 2*qkv]
  #    { Attn.AXIS_Z:    [batch.ZB, maxlen, 1, 2*qkv]
  #    { Attn.AXES_YZ:   [batch.YZB, maxlen, 1, 2*qkv]
  #    { Attn.AXES_YZX:  [batch.YZXB, maxlen, 1, 2*qkv]
  with jax.named_scope('kv'):

    def kv_einsum(lhs):
      return jnp.einsum('bte,ezd->btzd', lhs, my_layer(params.kv))

    # kv_unreduced = jnp.einsum('bte,ezd->btzd', xnorm,
    #                           my_layer(params.kv))

    if attn_all_to_all == partitioning.AttnAllToAll.NONE:
      # [batch, maxlen, 1, 2*qkv]{x_unreduced}
      # --ARx-->   [batch, maxlen, 1, 2*qkv]
      kv = lax.psum(kv_einsum(xnorm), 'x')
    elif attn_all_to_all == partitioning.AttnAllToAll.AXIS_Z:
      assert batch_z >= 1, 'Batch size too small for AXIS_Z and this chip count'
      # xnorm: [batch, maxlen, dmodel.X] -> [batch.Z, maxlen, dmodel.X]
      xnorm = lax.dynamic_slice_in_dim(
          xnorm, z_index * batch_z, batch_z, axis=0
      )
      # [batch.Z, maxlen, dmodel.X] @ [dmodel.X, 1, 2*qkv]
      # --matmul--> [batch.Z, maxlen, 1, 2*qkv]{x unreduced}
      # --ARx-->    [batch.Z, maxlen, 1, 2*qkv]
      kv = lax.psum(kv_einsum(xnorm), 'x')
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZ:
      assert (
          batch_yz >= 1
      ), 'Batch size too small for AXES_YZ and this chip count'
      # xnorm: [batch, maxlen, dmodel.X] -> [batch.YZ, maxlen, dmodel.X]
      xnorm = lax.dynamic_slice_in_dim(
          xnorm, yz_index * batch_yz, batch_yz, axis=0
      )
      # [batch.YZ, maxlen, dmodel.X] @ [dmodel.X, 1, 2*qkv]
      # --matmul--> [batch.YZ, maxlen, 1, 2*qkv]{x unreduced}
      # --ARx-->    [batch.YZ, maxlen, 1, 2*qkv]
      kv = lax.psum(kv_einsum(xnorm), 'x')
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZX:
      assert (
          batch_xyz >= 1
      ), 'Batch size too small for AXES_YZX and this chip count'
      # xnorm: [batch, maxlen, dmodel.X] -> [batch.YZ, maxlen, dmodel.X]
      xnorm = lax.dynamic_slice_in_dim(
          xnorm, yz_index * batch_yz, batch_yz, axis=0
      )
      # [batch.YZ, maxlen, dmodel.X] @ [dmodel.X, 1, 2*qkv]
      # --matmul--> [batch.YZ, maxlen, 1, 2*qkv]{x unreduced}
      # --RSx-->    [batch.YZX, maxlen, 1, 2*qkv]
      kv = lax.psum_scatter(
          kv_einsum(xnorm), 'x', scatter_dimension=0, tiled=True
      )
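      # The psum_scatter above is roughly (a sketch, assuming the standard
      # tiled semantics) an all-reduce over 'x' followed by keeping only this
      # chip's block of the batch axis:
      #
      #   full = lax.psum(kv_einsum(xnorm), 'x')
      #   kv_ref = lax.dynamic_slice_in_dim(
      #       full, lax.axis_index('x') * batch_xyz, batch_xyz, axis=0)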

    if isinstance(params, inference.QuantizedLayer):
      prev_shape = kv.shape
      kv = jnp.bfloat16(kv * jnp.squeeze(my_layer(params.kv_scale)))
      two_d_parallel_xmap.assert_equal(prev_shape, kv.shape)

    k = kv[:, :, 0, : hparams.qkv]
    v = kv[:, :, 0, hparams.qkv :]

  with jax.named_scope('attn'):
    k = _rope(sin, cos, k)

    # q: [batch, maxlen, heads.XYZ, qkv]
    # -> { NONE:                   [batch,  maxlen, heads.XYZ, qkv]
    #    { AXIS_Z:                 [batch.Z, maxlen, heads.XY, qkv]
    #    { AXES_YZ:                [batch.YZ, maxlen, heads.X, qkv]
    #    { AXES_YZX:               [batch.YZX, maxlen, heads,  qkv]
    q = q_wi[:, :, :, : hparams.qkv]
    if attn_all_to_all == partitioning.AttnAllToAll.NONE:
      pass
    elif attn_all_to_all == partitioning.AttnAllToAll.AXIS_Z:
      q = lax.all_to_all(
          q, axis_name='z', split_axis=0, concat_axis=2, tiled=True
      )
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZ:
      q = lax.all_to_all(
          q, axis_name=('y', 'z'), split_axis=0, concat_axis=2, tiled=True
      )
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZX:
      q = lax.all_to_all(
          q, axis_name='x', split_axis=0, concat_axis=2, tiled=True
      )
      q = lax.all_to_all(
          q, axis_name=('y', 'z'), split_axis=0, concat_axis=2, tiled=True
      )
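    # Sketch of what the tiled all_to_alls above do: split_axis=0 with
    # concat_axis=2 trades batch sharding for head sharding, so in the
    # AXES_YZX case q goes from
    #   [batch,     maxlen, heads.XYZ, qkv]   (full batch, heads sharded)
    # to
    #   [batch.YZX, maxlen, heads,     qkv]   (batch sharded, full heads),
    # matching the shape table at the top of this scope. The y_att
    # all_to_alls below invert this before the o_wo projection.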

    q = _rope(sin, cos, q)

    y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))
    # y_att:
    #    { NONE:                   [batch,  maxlen, heads.YZX, qkv]
    #    { AXIS_Z:                 [batch.Z, maxlen, heads.YX, qkv]
    #    { AXES_YZ:                [batch.YZ, maxlen, heads.X, qkv]
    #    { AXES_YZX:               [batch.YZX, maxlen, heads,  qkv]
    # -> [batch.B, maxlen, heads.YZX, qkv]
    if attn_all_to_all == partitioning.AttnAllToAll.NONE:
      pass
    elif attn_all_to_all == partitioning.AttnAllToAll.AXIS_Z:
      y_att = lax.all_to_all(
          y_att, axis_name='z', split_axis=2, concat_axis=0, tiled=True
      )
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZ:
      y_att = lax.all_to_all(
          y_att, axis_name=('y', 'z'), split_axis=2, concat_axis=0, tiled=True
      )
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZX:
      y_att = lax.all_to_all(
          y_att, axis_name=('y', 'z'), split_axis=2, concat_axis=0, tiled=True
      )
      y_att = lax.all_to_all(
          y_att, axis_name='x', split_axis=2, concat_axis=0, tiled=True
      )

  with jax.named_scope('SwiGLU'):
    y_mlp = special2.swish2(wi0) * wi1

  # einsum(y_fused, o_wo):
  # [batch, maxlen, heads.XYZ, o_wo_per_head]
  #   @ [heads.XYZ, o_wo_per_head, dmodel]
  # -> (matmul)
  # -> [batch, maxlen, dmodel]{XYZ unreduced}
  # -> (fused reducescatter over X)
  # -> [batch, maxlen, dmodel.X]{YZ unreduced}
  # -> (non-fused allreduce)
  # -> [batch, maxlen, dmodel.X]
  with jax.named_scope('o_wo'):
    y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
    two_d_parallel_xmap.assert_equal(
        y_fused.shape,
        (
            batch,
            max_len,
            hparams.heads // (x_axis * y_axis * z_axis),
            hparams.o_wo_per_head,
        ),
    )

    y_out = matmul_reducescatter(
        'bthd,hde->bte',
        y_fused,
        params.o_wo,
        scatter_axis=2,
        axis_name='x',
        layer=layer,
        subsplit_axis=2,
    )

    # TODO(sholto): Explore psum-scatter?
    y_out = lax.psum(y_out, axis_name=('y', 'z'))

    if isinstance(params, inference.QuantizedLayer):
      prev_shape = y_out.shape
      y_out = jnp.bfloat16(y_out * jnp.squeeze(my_layer(params.o_wo_scale)))
      two_d_parallel_xmap.assert_equal(y_out.shape, prev_shape)

  with jax.named_scope('residual'):
    z = jnp.bfloat16(y_out + x)
  return z, k[:batch_xyz], v[:batch_xyz]