# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""1D weight stationary xmap layer."""
17
18from typing import Sequence, Tuple
19
20import jax
21from jax import lax
22import jax.numpy as jnp
23import jax.scipy
24
25from scaling_transformer_inference_efficiency import attention
26from scaling_transformer_inference_efficiency import checkpoint
27from scaling_transformer_inference_efficiency import collectives
28from scaling_transformer_inference_efficiency import inference
29from scaling_transformer_inference_efficiency import partitioning
30from scaling_transformer_inference_efficiency import special2
31from scaling_transformer_inference_efficiency import weights
32from scaling_transformer_inference_efficiency.layers import two_d_parallel_xmap
33from scaling_transformer_inference_efficiency.layers.layers_pjit import _rope
34from scaling_transformer_inference_efficiency.weights import Layer
35
36HParams = checkpoint.HParams
37CheckpointSpec = checkpoint.CheckpointSpec
38
39
# pylint: disable = protected-access
# pylint: disable = g-doc-return-or-yield
# pylint: disable = g-doc-args
# TODO(sholto): Update
def weight_stationary_simple(
    hparams,
    layer,
    params,
    sin,
    cos,
    kv_caches,
    x,
    x_axis,
    y_axis,
    z_axis,
    latency_collectives,
    intermediate_dtype=jnp.bfloat16,
):
58"""Forward pass through a single layer, returning output, K, V.
59
60Partitioning:
61* 'x' is the only axis (most chips).
62* weights are sharded [dmodel, heads.XYZ, *_per_head]
63* hidden-dimension activations are sharded [batch, time, heads.XYZ,
64*_per_head]
65* embed-dimension activations are "naturally" sharded [batch, time, dmodel]
66but may be "oversharded" after reducescatter operations.
67
68To support XYZ>heads, we simply increase the number of heads, by padding. The
69FFN can be sharded finer and continues to gain speedup with more chips, but
70the ATTN will just be padded and not gain speedup as we add chips.
71"""
  if latency_collectives:
    matmul_reducescatter = collectives.matmul_reducescatter_latency
    # reducescatter = collectives.reducescatter_latency
    matmul_allgather = collectives.allgather_matmul_latency
  else:
    # matmul_reducescatter = partial(
    #     collectives.matmul_reducescatter_throughput, subsplit_axis=0
    # )
    # # reducescatter = collectives.reducescatter_throughput
    # matmul_allgather = partial(
    #     collectives.allgather_matmul_throughput, subsplit_axis=2
    # )
    matmul_reducescatter = collectives.matmul_reducescatter_oneway
    # reducescatter = collectives.reducescatter_throughput
    matmul_allgather = collectives.allgather_matmul_one_way

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)
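  # For example, my_layer(params.kv) below yields params.kv[layer], indexed
  # dynamically since `layer` may be a traced value under jit/scan.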

  batch, max_len, _ = x.shape

  with jax.named_scope('layernorm'):
    # x: [batch, maxlen, dmodel.X]
    # mean2: [batch, maxlen]
    # xnorm: [batch, maxlen, dmodel.X]
    epsilon = 1e-6
    mean2 = lax.pmean(
        jnp.mean(lax.square(x), axis=-1, keepdims=True), axis_name='x'
    )
    xnorm = intermediate_dtype(x * lax.rsqrt(mean2 + epsilon))  # pytype: disable=not-callable  # jnp-type
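    # Distributed RMSNorm: each chip averages squares over its local dmodel.X
    # slice, and pmean over 'x' averages the equal-sized shard means, so mean2
    # equals the mean square over the full dmodel axis.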

  # einsum(xnorm, q_wi):
  # [batch, maxlen, dmodel.X] @ [heads.XYZ, dmodel, q_wi_per_head]
  # -> (allgather lhs) (fused with matmul)
  # -> [batch, maxlen, dmodel]
  # -> (matmul)
  # -> [batch, maxlen, heads.XYZ, q_wi_per_head]
  with jax.named_scope('q_wi'):
    q_wi = matmul_allgather(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        rhs_split_axis=1,
        axis_name='x',
        layer=layer,
    )

    # No need to scatter over y and z, as y and z will always be 1 here.

    two_d_parallel_xmap.assert_equal(
        q_wi.shape,
        (
            batch,
            max_len,
            hparams.heads // (x_axis * y_axis * z_axis),
            hparams.q_wi_per_head,
        ),
    )

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = q_wi.shape
      q_wi = intermediate_dtype(q_wi * jnp.squeeze(my_layer(params.q_wi_scale)))
      two_d_parallel_xmap.assert_equal(prev_shape, q_wi.shape)

    # Unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
    # SwiGLU with full d_ff dimension, rather than 2/3 scaled.
    wi0 = q_wi[
        :, :, :, hparams.qkv : hparams.qkv + (hparams.ff // hparams.heads)
    ]
    wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads) :]
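    # Layout of the fused q_wi last axis, as implied by the slices above:
    # [q: qkv | wi0: ff // heads | wi1: ff // heads], i.e. q_wi_per_head is
    # expected to equal qkv + 2 * (ff // heads).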

  # einsum(xnorm, kv):
  #
  # [batch, maxlen, dmodel.X] @ [dmodel.X, 1, 2*qkv]
  # -> (matmul)
  # -> [batch, maxlen, 1, 2*qkv]{x unreduced}
  # -> (allreduce over x; *NOT* collective matmul, because it's batch)
  # -> [batch, maxlen, 1, 2*qkv]
  with jax.named_scope('kv'):

    def kv_einsum(lhs):
      return jnp.einsum('bte,ezd->btzd', lhs, my_layer(params.kv))

    # kv_unreduced = jnp.einsum('bte,ezd->btzd', xnorm,
    #                           my_layer(params.kv))
    # [batch, maxlen, 1, 2*qkv]{x_unreduced}
    # --ARx--> [batch, maxlen, 1, 2*qkv]
    kv = lax.psum(kv_einsum(xnorm), 'x')

    if isinstance(params, inference.QuantizedLayer):
      prev_shape = kv.shape
      kv = intermediate_dtype(kv * jnp.squeeze(my_layer(params.kv_scale)))
      two_d_parallel_xmap.assert_equal(prev_shape, kv.shape)

    k = kv[:, :, 0, : hparams.qkv]
    v = kv[:, :, 0, hparams.qkv :]

  with jax.named_scope('attn'):
    k = _rope(sin, cos, k)

    # q: [batch, maxlen, heads.XYZ, qkv]
    q = q_wi[:, :, :, : hparams.qkv]
    q = _rope(sin, cos, q)

    # y_att: -> [batch.B, maxlen, heads.XYZ, qkv]
    y_att = intermediate_dtype(attention.attend(q, k, v, kv_caches, layer))  # pytype: disable=not-callable  # jnp-type

  with jax.named_scope('SwiGLU'):
    y_mlp = special2.swish2(wi0) * wi1
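    # SwiGLU: swish(wi0) acts as the gate on wi1; swish2 is the swish variant
    # provided by this codebase's special2 module.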

  # einsum(y_fused, o_wo):
  # [batch, maxlen, heads.XYZ, o_wo_per_head]
  #   @ [heads.XYZ, o_wo_per_head, dmodel]
  # -> (matmul)
  # -> [batch, maxlen, dmodel]{XYZ unreduced}
  # -> (fused reducescatter over X)
  # -> [batch, maxlen, dmodel.X]
  # (y and z are 1 here, so no further allreduce is needed)
  with jax.named_scope('o_wo'):
    y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
    two_d_parallel_xmap.assert_equal(
        y_fused.shape,
        (
            batch,
            max_len,
            hparams.heads // (x_axis * y_axis * z_axis),
            hparams.o_wo_per_head,
        ),
    )

    y_out = matmul_reducescatter(
        'bthd,hde->bte',
        y_fused,
        params.o_wo,
        scatter_axis=2,
        axis_name='x',
        layer=layer,
    )

    # No output psum needed, because only the x axis is in use.
    # y_out = lax.psum(y_out, axis_name=('y', 'z'))

    if isinstance(params, inference.QuantizedLayer):
      prev_shape = y_out.shape
      y_out = intermediate_dtype(
          y_out * jnp.squeeze(my_layer(params.o_wo_scale))
      )
      two_d_parallel_xmap.assert_equal(y_out.shape, prev_shape)

  with jax.named_scope('residual'):
    z = intermediate_dtype(y_out + x)  # pytype: disable=not-callable  # jnp-type

  k, v = k.astype(intermediate_dtype), v.astype(intermediate_dtype)
  return z, k, v
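

# A minimal sketch (hypothetical values, not from the original file) of how
# weight_stationary_simple is driven: it must run inside an xmap/shard_map
# body that binds the mesh axis 'x' (lax.pmean etc. refer to it), e.g.:
#
#   z, k, v = weight_stationary_simple(
#       hparams, layer, params, sin, cos, kv_caches, x,
#       x_axis=8, y_axis=1, z_axis=1, latency_collectives=False)
#
# y_axis = z_axis = 1 matches the "'x' is the only sharded axis" note above.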


def weight_stationary(
    hparams,
    layer,
    params,
    sin,
    cos,
    kv_caches,
    x,
    x_axis,
    y_axis,
    z_axis,
    attn_all_to_all,
    latency_collectives,
):
  """Forward pass through a single layer, returning output, K, V.

  Partitioning:
  * 'x' is the longest axis (most chips), then 'y', then 'z'.
  * weights are sharded [dmodel, heads.XYZ, *_per_head]
  * hidden-dimension activations are sharded [batch, time, heads.XYZ,
    *_per_head]
  * embed-dimension activations are "naturally" sharded [batch, time, dmodel]
    but may be "oversharded" after reducescatter operations.

  To support XYZ > heads, we simply increase the number of heads by padding.
  The FFN can be sharded finer and continues to gain speedup with more chips,
  but the ATTN will just be padded and not gain speedup as we add chips.
  """
  if latency_collectives:
    matmul_reducescatter = collectives.matmul_reducescatter_latency
    # reducescatter = collectives.reducescatter_latency
    matmul_allgather = collectives.allgather_matmul_latency
  else:
    matmul_reducescatter = collectives.matmul_reducescatter_throughput
    # reducescatter = collectives.reducescatter_throughput
    matmul_allgather = collectives.allgather_matmul_throughput

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  batch, max_len, _ = x.shape
  batch_z = batch // z_axis
  batch_yz = batch_z // y_axis
  batch_xyz = batch_yz // x_axis
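  # Worked example (hypothetical sizes): batch=16 with z_axis=2, y_axis=2,
  # x_axis=4 gives batch_z=8, batch_yz=4, batch_xyz=1: the per-chip batch
  # shard under successively finer attention all-to-all layouts.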

  # x_index = lax.axis_index('x')
  y_index = lax.axis_index('y')
  z_index = lax.axis_index('z')
  yz_index = y_index * z_axis + z_index

  with jax.named_scope('layernorm'):
    # x: [batch, maxlen, dmodel.X]
    # mean2: [batch, maxlen]
    # xnorm: [batch, maxlen, dmodel.X]
    epsilon = 1e-6
    mean2 = lax.pmean(
        jnp.mean(lax.square(x), axis=-1, keepdims=True), axis_name='x'
    )
    xnorm = jnp.bfloat16(x * lax.rsqrt(mean2 + epsilon))

  # einsum(xnorm, q_wi):
  # [batch, maxlen, dmodel.X] @ [heads.XYZ, dmodel, q_wi_per_head]
  # -> (allgather lhs) (fused with matmul)
  # -> [batch, maxlen, dmodel]
  # -> (matmul)
  # -> [batch, maxlen, heads.XYZ, q_wi_per_head]
  with jax.named_scope('q_wi'):
    q_wi = matmul_allgather(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        rhs_split_axis=1,
        axis_name='x',
        layer=layer,
        subsplit_axis=2,
    )

    two_d_parallel_xmap.assert_equal(
        q_wi.shape,
        (
            batch,
            max_len,
            hparams.heads // (x_axis * y_axis * z_axis),
            hparams.q_wi_per_head,
        ),
    )

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = q_wi.shape
      q_wi = jnp.bfloat16(q_wi * jnp.squeeze(my_layer(params.q_wi_scale)))
      two_d_parallel_xmap.assert_equal(prev_shape, q_wi.shape)

    # Unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
    # SwiGLU with full d_ff dimension, rather than 2/3 scaled.
    wi0 = q_wi[
        :, :, :, hparams.qkv : hparams.qkv + (hparams.ff // hparams.heads)
    ]
    wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads) :]

  # einsum(xnorm, kv):
  #
  # [batch, maxlen, dmodel.X] @ [dmodel.X, 1, 2*qkv]
  # -> (matmul)
  # -> [batch, maxlen, 1, 2*qkv]{x unreduced}
  # -> (reducescatter over x into batch)
  #    *NOT* collective matmul, because it's batch
  # -> { Attn.NONE: [batch.B, maxlen, 1, 2*qkv]
  #    { Attn.AXIS_Z: [batch.ZB, maxlen, 1, 2*qkv]
  #    { Attn.AXES_YZ: [batch.YZB, maxlen, 1, 2*qkv]
  #    { Attn.AXES_YZX: [batch.YZXB, maxlen, 1, 2*qkv]
  with jax.named_scope('kv'):

    def kv_einsum(lhs):
      return jnp.einsum('bte,ezd->btzd', lhs, my_layer(params.kv))

    # kv_unreduced = jnp.einsum('bte,ezd->btzd', xnorm,
    #                           my_layer(params.kv))

    if attn_all_to_all == partitioning.AttnAllToAll.NONE:
      # [batch, maxlen, 1, 2*qkv]{x_unreduced}
      # --ARx--> [batch, maxlen, 1, 2*qkv]
      kv = lax.psum(kv_einsum(xnorm), 'x')
    elif attn_all_to_all == partitioning.AttnAllToAll.AXIS_Z:
      assert batch_z >= 1, 'Batch size too small for AXIS_Z and this chip count'
      # xnorm: [batch, maxlen, dmodel.X] -> [batch.Z, maxlen, dmodel.X]
      xnorm = lax.dynamic_slice_in_dim(
          xnorm, z_index * batch_z, batch_z, axis=0
      )
      # [batch.Z, maxlen, dmodel.X] @ [dmodel.X, 1, 2*qkv]
      # --matmul--> [batch.Z, maxlen, 1, 2*qkv]{x unreduced}
      # --ARx--> [batch.Z, maxlen, 1, 2*qkv]
      kv = lax.psum(kv_einsum(xnorm), 'x')
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZ:
      assert (
          batch_yz >= 1
      ), 'Batch size too small for AXES_YZ and this chip count'
      # xnorm: [batch, maxlen, dmodel.X] -> [batch.YZ, maxlen, dmodel.X]
      xnorm = lax.dynamic_slice_in_dim(
          xnorm, yz_index * batch_yz, batch_yz, axis=0
      )
      # [batch.YZ, maxlen, dmodel.X] @ [dmodel.X, 1, 2*qkv]
      # --matmul--> [batch.YZ, maxlen, 1, 2*qkv]{x unreduced}
      # --ARx--> [batch.YZ, maxlen, 1, 2*qkv]
      kv = lax.psum(kv_einsum(xnorm), 'x')
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZX:
      assert (
          batch_xyz >= 1
      ), 'Batch size too small for AXES_YZX and this chip count'
      # xnorm: [batch, maxlen, dmodel.X] -> [batch.YZ, maxlen, dmodel.X]
      xnorm = lax.dynamic_slice_in_dim(
          xnorm, yz_index * batch_yz, batch_yz, axis=0
      )
      # [batch.YZ, maxlen, dmodel.X] @ [dmodel.X, 1, 2*qkv]
      # --matmul--> [batch.YZ, maxlen, 1, 2*qkv]{x unreduced}
      # --RSx--> [batch.YZX, maxlen, 1, 2*qkv]
      kv = lax.psum_scatter(
          kv_einsum(xnorm), 'x', scatter_dimension=0, tiled=True
      )

    if isinstance(params, inference.QuantizedLayer):
      prev_shape = kv.shape
      kv = jnp.bfloat16(kv * jnp.squeeze(my_layer(params.kv_scale)))
      two_d_parallel_xmap.assert_equal(prev_shape, kv.shape)

    k = kv[:, :, 0, : hparams.qkv]
    v = kv[:, :, 0, hparams.qkv :]

  with jax.named_scope('attn'):
    k = _rope(sin, cos, k)

    # q: [batch, maxlen, heads.XYZ, qkv]
    # -> { NONE: [batch, maxlen, heads.XYZ, qkv]
    #    { AXIS_Z: [batch.Z, maxlen, heads.XY, qkv]
    #    { AXES_YZ: [batch.YZ, maxlen, heads.X, qkv]
    #    { AXES_YZX: [batch.YZX, maxlen, heads, qkv]
    q = q_wi[:, :, :, : hparams.qkv]
    if attn_all_to_all == partitioning.AttnAllToAll.NONE:
      pass
    elif attn_all_to_all == partitioning.AttnAllToAll.AXIS_Z:
      q = lax.all_to_all(
          q, axis_name='z', split_axis=0, concat_axis=2, tiled=True
      )
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZ:
      q = lax.all_to_all(
          q, axis_name=('y', 'z'), split_axis=0, concat_axis=2, tiled=True
      )
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZX:
      q = lax.all_to_all(
          q, axis_name='x', split_axis=0, concat_axis=2, tiled=True
      )
      q = lax.all_to_all(
          q, axis_name=('y', 'z'), split_axis=0, concat_axis=2, tiled=True
      )
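      # Two chained tiled all_to_alls (over 'x', then ('y', 'z')) move batch
      # shards into the head dimension, so each chip ends up with all heads
      # for its batch.YZX slice.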

    q = _rope(sin, cos, q)

    y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))
    # y_att:
    # { NONE: [batch, maxlen, heads.YZX, qkv]
    # { AXIS_Z: [batch.Z, maxlen, heads.YX, qkv]
    # { AXES_YZ: [batch.YZ, maxlen, heads.X, qkv]
    # { AXES_YZX: [batch.YZX, maxlen, heads, qkv]
    # -> [batch.B, maxlen, heads.YZX, qkv]
    if attn_all_to_all == partitioning.AttnAllToAll.NONE:
      pass
    elif attn_all_to_all == partitioning.AttnAllToAll.AXIS_Z:
      y_att = lax.all_to_all(
          y_att, axis_name='z', split_axis=2, concat_axis=0, tiled=True
      )
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZ:
      y_att = lax.all_to_all(
          y_att, axis_name=('y', 'z'), split_axis=2, concat_axis=0, tiled=True
      )
    elif attn_all_to_all == partitioning.AttnAllToAll.AXES_YZX:
      y_att = lax.all_to_all(
          y_att, axis_name=('y', 'z'), split_axis=2, concat_axis=0, tiled=True
      )
      y_att = lax.all_to_all(
          y_att, axis_name='x', split_axis=2, concat_axis=0, tiled=True
      )
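      # Inverse of the q all_to_alls above, applied in reverse order
      # (('y', 'z'), then 'x'): head shards move back into the batch
      # dimension ahead of the output projection.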

  with jax.named_scope('SwiGLU'):
    y_mlp = special2.swish2(wi0) * wi1

  # einsum(y_fused, o_wo):
  # [batch, maxlen, heads.XYZ, o_wo_per_head]
  #   @ [heads.XYZ, o_wo_per_head, dmodel]
  # -> (matmul)
  # -> [batch, maxlen, dmodel]{XYZ unreduced}
  # -> (fused reducescatter over X)
  # -> [batch, maxlen, dmodel.X]{YZ unreduced}
  # -> (non-fused allreduce)
  # -> [batch, maxlen, dmodel.X]
  with jax.named_scope('o_wo'):
    y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
    two_d_parallel_xmap.assert_equal(
        y_fused.shape,
        (
            batch,
            max_len,
            hparams.heads // (x_axis * y_axis * z_axis),
            hparams.o_wo_per_head,
        ),
    )

    y_out = matmul_reducescatter(
        'bthd,hde->bte',
        y_fused,
        params.o_wo,
        scatter_axis=2,
        axis_name='x',
        layer=layer,
        subsplit_axis=2,
    )

    # TODO(sholto): Explore psum-scatter?
    y_out = lax.psum(y_out, axis_name=('y', 'z'))

    if isinstance(params, inference.QuantizedLayer):
      prev_shape = y_out.shape
      y_out = jnp.bfloat16(y_out * jnp.squeeze(my_layer(params.o_wo_scale)))
      two_d_parallel_xmap.assert_equal(y_out.shape, prev_shape)

  with jax.named_scope('residual'):
    z = jnp.bfloat16(y_out + x)
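  # Only the leading batch_xyz rows of k and v are returned: the per-chip
  # shard of the KV cache in the fully partitioned [batch.YZX, ...] layout.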
  return z, k[:batch_xyz], v[:batch_xyz]