# google-research
# 595 lines · 21.4 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""One Transformer layer, in hard xmap."""
17
18from functools import partial # pylint: disable = g-importing-member19from typing import Sequence, Tuple20
21import jax22from jax import lax23import jax.numpy as jnp24import jax.scipy25
26from scaling_transformer_inference_efficiency import attention27from scaling_transformer_inference_efficiency import checkpoint28from scaling_transformer_inference_efficiency import collectives29from scaling_transformer_inference_efficiency import special230from scaling_transformer_inference_efficiency import weights31from scaling_transformer_inference_efficiency.chunk import Chunk32from scaling_transformer_inference_efficiency.layers.layers_pjit import _rope33from scaling_transformer_inference_efficiency.partitioning import AttnAllToAll34from scaling_transformer_inference_efficiency.weights import Layer35
36HParams = checkpoint.HParams37CheckpointSpec = checkpoint.CheckpointSpec38Layer = weights.Layer39QuantizedLayer = weights.QuantizedLayer40Weights = weights.Weights41
42ATTN_3D_SHARDING_THRESHOLD_PER_CHIP = 243
44# pylint: disable = invalid-name
45# pylint: disable = protected-access
46# pylint: disable = g-bare-generic
47
48
def assert_equal(x, y):
  """Fails with a descriptive AssertionError unless the two values match."""
  assert x == y, f'{x} != {y}'
52
def allgather_layernorm(x,
                        shard_seqlen_vs_batch,
                        batch_unsharded = False,
                        scale = None):
  """All gathers around layernorm, minimises comms by first doing per-chip.

  Args:
    x: activations, [batch.Z, maxlen, embed.XY] or, when batch_unsharded,
      [batch, maxlen, embed.XYZ].
    shard_seqlen_vs_batch: if True, the final 'z' gather re-forms the
      sequence dimension (axis 1) instead of the batch dimension (axis 0).
    batch_unsharded: batch is fully replicated; the 'z' shard lives on the
      embed dimension instead, so it is gathered up front.
    scale: optional layernorm scale, stored zero-centred ('center_scale_at_zero'
      option in T5X), hence the +1.0 below.

  Returns:
    xnorm: bf16 normalised activations, gathered to [batch, maxlen, embed.X].
    xnorm_z: bf16 normalised activations before the final 'z' gather
      (still batch- or seqlen-sharded by z, unless batch_unsharded).
  """
  with jax.named_scope('allgather_layernorm'):
    # allgather xnorm: [batch.Z, maxlen, embed.XY] || [batch, maxlen, embed.XYZ]
    # -> [batch.Z, maxlen, embed.X] (xnorm_z)
    # -> [batch, maxlen, embed.X]
    xgather = x
    if batch_unsharded:
      # [batch, maxlen, embed.XY]
      xgather = lax.all_gather(xgather, 'z', axis=2, tiled=True)
    # [batch.Z, maxlen, embed.X] || [batch, maxlen, embed.X]
    xgather = lax.all_gather(xgather, 'y', axis=2, tiled=True)

    epsilon = 1e-6
    # Normalisation statistics are computed in float32 for stability, then
    # the result is cast back down to bfloat16.
    xgather = jnp.float32(xgather)
    mean2 = lax.pmean(
        jnp.mean(lax.square(xgather), axis=-1, keepdims=True), axis_name='x')
    xnorm_z = jnp.bfloat16(xgather * lax.rsqrt(mean2 + epsilon))
    if scale is not None:
      scale += 1.0  # 'center_scale_at_zero' option in T5X
      xnorm_z = jnp.bfloat16(xnorm_z * scale)
    # when attention_all_to_all is None we can partition over sequence len not
    # batch
    if shard_seqlen_vs_batch:
      xnorm = lax.all_gather(xnorm_z, 'z', axis=1, tiled=True)
    else:
      if batch_unsharded:  # in this case already done above
        xnorm = xnorm_z
      else:
        xnorm = lax.all_gather(xnorm_z, 'z', axis=0, tiled=True)
  # [batch, maxlen, embed.X]
  return xnorm, xnorm_z
89
@partial(jax.jit, static_argnums=(3, 4, 5))
def embed_manual(
    params,  # pylint: disable=g-bare-generic, invalid-name
    kv_caches,
    token_chunk,
    shard_seqlen_vs_batch = False,
    batch_unsharded = False,
    one_d = False,
):
  """Embeds a chunk of logits.

  Args:
    params: Weights object
    kv_caches: List of chunks preprocessed earlier
    token_chunk: An unsharded token chunk. Assume .tokens is int32[batch,
      maxlen]
    shard_seqlen_vs_batch: Whether to shard seqlen or batch by z.
    batch_unsharded: global_batch is less than z so we cannot shard along
    one_d: whether it is one dimensional

  Returns:
    embeddings: bfloat16[[batch.Z, time, embed.XY] || [batch, time, embed.XYZ]
    sin: RoPE embeddings starting at the appropriate index determined by
      pre-existing kv_cache for each index in the batch.
    cos: ""
  """

  z_axis = lax.psum(1, 'z')
  # Start indices are the sums of the lengths of the KV caches.
  start_indices = attention.prefix_lengths(kv_caches)
  prefix_batch, = start_indices.shape
  batch, max_length = token_chunk.tokens.shape
  assert batch % prefix_batch == 0, 'Incompatible batch sizes'

  # Do RoPE lookups in the sin/cos tables. Only needed once per prefix_batch.
  def slice_at(index, table):
    # table: [precomputed_length, qkv // 2]
    return lax.dynamic_slice_in_dim(table, index, max_length)

  def slices_at(indices, table):
    # vmap over the per-prefix start indices; the table itself is shared.
    return jax.vmap(slice_at, in_axes=(0, None))(indices, table)

  sin = slices_at(start_indices, params.sin)
  cos = slices_at(start_indices, params.cos)
  # sin, cos: bf16[prefix_batch, max_length, qkv // 2]

  # x: int32[batch, maxlen]
  # embed: bfloat16[vocab.YZ, embed.X]
  x = token_chunk.tokens
  vocab_yz, _ = params.embedding.shape

  # Each (y, z) pair owns a contiguous slice of the vocabulary.
  yz_index = lax.axis_index('y') * z_axis + lax.axis_index('z')
  vocab_start = yz_index * vocab_yz

  # Initial embedding lookup:
  with jax.named_scope('embed'):
    one_x = x - vocab_start
    embeds = params.embedding[one_x, :]
    one_x = one_x[:, :, jnp.newaxis]
    # Tokens outside this chip's vocab shard contribute zero; the
    # psum_scatter below recombines the partial lookups across chips.
    embeds = jnp.where((one_x >= 0) & (one_x < vocab_yz), embeds, 0)
    # [batch, time, embed.X]

  if one_d:
    return embeds, sin, cos

  # [batch, time, embed.XY]
  embeds = lax.psum_scatter(embeds, 'y', scatter_dimension=2, tiled=True)

  if shard_seqlen_vs_batch:
    # [batch, time.Z, embed.XY]
    embeds = lax.psum_scatter(embeds, 'z', scatter_dimension=1, tiled=True)
  else:
    if batch_unsharded:
      # [batch, time, embed.XYZ]
      embeds = lax.psum_scatter(embeds, 'z', scatter_dimension=2, tiled=True)
    else:
      # [batch.Z, time, embed.XY]
      embeds = lax.psum_scatter(embeds, 'z', scatter_dimension=0, tiled=True)

  return embeds, sin, cos
168
def unembed_manual(
    xnorm,
    params,
    batch_unsharded = False,
    one_d = False,
):
  """Unembedding function for 2D.

  Projects normalised activations back onto the vocabulary in float32,
  then reduce-scatters the x-partial sums.

  Args:
    xnorm: bfloat16[batch, maxlen, dmodel.X] normalised activations.
    params: Weights object whose .embedding is [vocab.YZ, embed.X].
    batch_unsharded: scatter over vocab rather than batch.
    one_d: one-dimensional layout; also scatters over vocab.

  Returns:
    logits: float32[batch, maxlen, vocab.YZX] when batch_unsharded or one_d,
      else float32[batch.X, maxlen, vocab.YZ].
  """
  # TODO(sholto): We could squeeze out more memory by doing this
  # with a collective
  with jax.named_scope('unembed'):
    # [batch, maxlen, embed.X] @ [vocab.YZ, embed.X] -> {x unreduced}
    partial_logits = jnp.einsum(
        'bte,ve->btv', jnp.float32(xnorm), jnp.float32(params.embedding)
    )
    # When the batch cannot be sharded we scatter over the vocab axis (2);
    # otherwise the batch axis (0) absorbs the x shard.
    scatter_dim = 2 if (batch_unsharded or one_d) else 0
    logits = lax.psum_scatter(
        partial_logits, 'x', scatter_dimension=scatter_dim, tiled=True
    )
  return logits
196
197# pylint: disable = g-doc-return-or-yield
198# pylint: disable = g-doc-args
199# TODO(sholto): Update to new, tested parsing collectives.
200
201
def transformer_layer_weight_stationary(
    hparams,
    layer,
    params,
    sin,
    cos,
    kv_caches,
    x,
    x_axis,
    y_axis,
    z_axis,
    *,
    attn_all_to_all,
    latency_collectives,
    shard_seqlen_vs_batch = False,
    batch_unsharded = False,
    intermediate_dtype = jnp.bfloat16,
):
  """Wraps _fn so that we can use remat while bug is fixed."""
  # Bind the keyword-only configuration first so that the rematerialised
  # callable takes only the positional arguments.
  bound = partial(
      _transformer_layer_weight_stationary,
      attn_all_to_all=attn_all_to_all,
      latency_collectives=latency_collectives,
      shard_seqlen_vs_batch=shard_seqlen_vs_batch,
      batch_unsharded=batch_unsharded,
      intermediate_dtype=intermediate_dtype,
  )
  # hparams and the mesh axis sizes are static; prevent_cse keeps XLA from
  # undoing the recompute-on-backward of jax.checkpoint.
  remat_layer = jax.checkpoint(
      bound, static_argnums=(0, 7, 8, 9), prevent_cse=True
  )
  return remat_layer(
      hparams, layer, params, sin, cos, kv_caches, x, x_axis, y_axis, z_axis
  )
234
def _transformer_layer_weight_stationary(
    hparams,
    layer,
    params,
    sin,
    cos,
    kv_caches,
    x,
    x_axis,
    y_axis,
    z_axis,
    *,
    attn_all_to_all,
    latency_collectives,
    shard_seqlen_vs_batch = False,
    batch_unsharded = False,
    intermediate_dtype = jnp.bfloat16,
):
  """Forward pass through a single layer, returning output, K, V.

  This implementation has 'x'=d_model sharding,
  ('y', 'z')=d_ff sharding.
  * params are assumed already sharded this way, i.e. embed.X and heads.YZ
  * sin and cos are sharded by batch.YZx (or batch.YZ or batch.Y as necessary)
  * kv_cache is sharded by batch.YZx (or batch.YZ or batch.Y as necessary)
  * x: [batch.Z, maxlen, embed.XY]

  Args:
    hparams: model HParams (qkv, ff, heads, padded_heads are read here).
    layer: integer layer index used to slice the stacked per-layer params.
    params: Layer or QuantizedLayer weights (quantized adds *_scale factors).
    sin: RoPE sin table slice for this chunk.
    cos: RoPE cos table slice for this chunk.
    kv_caches: prior KV caches, handed to attention.attend.
    x: layer input activations, [batch.Z, maxlen, embed.XY].
    x_axis: size of mesh axis 'x'.
    y_axis: size of mesh axis 'y'.
    z_axis: size of mesh axis 'z'.
    attn_all_to_all: AttnAllToAll layout choice for the attention batch.
    latency_collectives: use the latency-optimised collective matmuls.
    shard_seqlen_vs_batch: shard sequence length instead of batch by z.
    batch_unsharded: batch is fully replicated.
    intermediate_dtype: dtype used for intermediate activations.

  Returns:
    (z, k, v): residual output and the new K/V, all in intermediate_dtype.
  """
  intermediate_dtype = jax.core.concrete_or_error(None, intermediate_dtype)
  if latency_collectives:
    matmul_reducescatter = partial(
        collectives.matmul_reducescatter_latency, subsplit_axis=2)
    # reducescatter = collectives.reducescatter_latency
    # subsplit along heads as they are indepedent
    # partial here because the one-way algorithm does not use subsplit
    matmul_allgather = partial(
        collectives.allgather_matmul_latency, subsplit_axis=2)
  else:
    matmul_reducescatter = collectives.matmul_reducescatter_oneway
    # reducescatter = collectives.reducescatter_oneway
    matmul_allgather = collectives.allgather_matmul_one_way

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # Compare
  # flaxformer/architectures/t5/parallel_fused_decoder.py
  # flaxformer/components/attention/dense_attention.py;l=1147;
  # flaxformer/components/attention/dense_attention.py;l=252;

  # Derive the global batch and the per-shard batch sizes from the local
  # shape plus the layout flags.
  batch_z, max_len, _ = x.shape
  if shard_seqlen_vs_batch:
    max_len *= z_axis
    batch = batch_z
    batch_xyz = batch // (x_axis * y_axis * z_axis)
  else:
    if batch_unsharded:
      batch = x.shape[0]
    else:
      batch = batch_z * z_axis
    batch_xyz = batch // (x_axis * y_axis * z_axis)
    batch_yz = batch // (y_axis * z_axis)
    batch_z = batch // (z_axis)

  # Quantized checkpoints store a separate layernorm scale.
  if isinstance(params, weights.QuantizedLayer):
    xnorm, xnorm_z = allgather_layernorm(
        x,
        shard_seqlen_vs_batch,
        batch_unsharded,
        scale=my_layer(params.layernorm_scale))
  else:
    xnorm, xnorm_z = allgather_layernorm(x, shard_seqlen_vs_batch,
                                         batch_unsharded)

  # einsum(xnorm, q_wi):
  # [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
  # -> (matmul)
  # -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
  # -> (reducescatter over x into X heads, B batches)
  # -> [batch, maxlen, heads.YZX, q_wi_per_head]
  with jax.named_scope('q_wi'):
    xnorm = intermediate_dtype(xnorm)
    q_wi = matmul_reducescatter(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        scatter_axis=0,
        axis_name='x',
        layer=layer)

    if isinstance(params, weights.QuantizedLayer):
      # Rescale quantized weights; shape must be unchanged by the squeeze.
      prev_shape = q_wi.shape
      q_wi = intermediate_dtype(q_wi * jnp.squeeze(my_layer(params.q_wi_scale)))
      assert_equal(prev_shape, q_wi.shape)

    # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
    # swiGLU with full d_ff dimension, rather than 2/3 scaled
    wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // (hparams.heads - hparams.padded_heads))]  # pylint: disable = line-too-long
    wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // (hparams.heads - hparams.padded_heads)):]  # pylint: disable = line-too-long

  # einsum(xnorm, kv):
  #
  # if attn>=AXES_YZ:
  #   xnorm_z: [batch.Z, maxlen, embed.X]
  #     -> [batch.(X?)YZ, maxlen, embed.X]  (slice down)
  #
  # Then:
  #
  # [batch.(Y?)Z, maxlen, embed.X] @ [embed.X, 1, 2*qkv]
  # -> (matmul)
  # -> [batch.(Y?)Z, maxlen, 1, 2*qkv]{x unreduced}
  # -> (reducescatter over x into batch)
  #    *NOT* collective matmul, because it's batch
  # -> { Attn.NONE:     [batch.B, maxlen, 1, 2*qkv]
  #    { Attn.AXIS_Z:   [batch.ZB, maxlen, 1, 2*qkv]
  #    { Attn.AXES_YZ:  [batch.YZB, maxlen, 1, 2*qkv]
  #    { Attn.AXES_YZX: [batch.YZXB, maxlen, 1, 2*qkv]
  with jax.named_scope('kv'):
    # TODO(sholto): update this in oversharded
    yz_index = lax.axis_index('y') * z_axis + lax.axis_index('z')
    # TODO(reinerp): Consider using xnorm instead of xnorm_z in NONE case?
    # I don't know yet if that's better.
    if attn_all_to_all.value >= AttnAllToAll.AXES_YZ.value:
      xnorm_sliced = lax.dynamic_slice_in_dim(
          xnorm, yz_index * batch_yz, batch_yz, axis=0)
    else:
      xnorm_sliced = xnorm_z

    kv_unreduced = jnp.einsum('bte,ezd->btzd', xnorm_sliced,
                              my_layer(params.kv))

    if attn_all_to_all == AttnAllToAll.NONE:
      if shard_seqlen_vs_batch:
        # [batch, maxlen.Z, 1, 2*qkv]{x_unreduced}
        # -> [batch.B, maxlen, 1, 2*qkv]
        kv = lax.psum(kv_unreduced, 'x')
        kv = lax.all_gather(kv, 'z', axis=1, tiled=True)
      else:
        # [batch.Z, maxlen, 1, 2*qkv]{x_unreduced} || [b, ml, 1, 2qkv] {x_unred}
        # --ARx--> [batch.Z, maxlen, 1, 2*qkv]
        # --AGZ--> [batch, maxlen, 1, 2*qkv]
        kv = lax.psum(kv_unreduced, 'x')
        if not batch_unsharded:
          kv = lax.all_gather(kv, 'z', axis=0, tiled=True)
    elif attn_all_to_all == AttnAllToAll.AXIS_Z:
      # [batch.Z, maxlen, 1, 2*qkv]{x_unreduced}
      # --ARx--> [batch.Z, maxlen, 1, 2*qkv]
      kv = lax.psum(kv_unreduced, 'x')
    elif attn_all_to_all == AttnAllToAll.AXES_YZ:
      # [batch.YZ, maxlen, 1, 2*qkv]{x_unreduced}
      # --ARx--> [batch.YZ, maxlen, 1, 2*qkv]
      kv = lax.psum(kv_unreduced, 'x')
    elif attn_all_to_all == AttnAllToAll.AXES_YZX:
      # [batch.YZ, maxlen, 1, 2*qkv]{x_unreduced}
      # --RSx--> [batch.YZX, maxlen, 1, 2*qkv]
      assert batch_xyz >= 1, ('Batch size too small for AXES_XYZ and this chip '
                              'count')
      kv = lax.psum_scatter(kv_unreduced, 'x', scatter_dimension=0, tiled=True)

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = kv.shape
      kv = intermediate_dtype(kv * jnp.squeeze(my_layer(params.kv_scale)))
      assert_equal(prev_shape, kv.shape)

    # K occupies the first hparams.qkv features of the fused kv projection,
    # V the remainder.
    k = kv[:, :, 0, :hparams.qkv]
    v = kv[:, :, 0, hparams.qkv:]

  with jax.named_scope('attn'):
    k = _rope(sin, cos, k)

    # q: [batch, maxlen, heads.YZX, qkv]
    # -> { NONE:     [batch., maxlen, heads.YZX, qkv]
    #    { AXIS_Z:   [batch.Z, maxlen, heads.YX, qkv]
    #    { AXES_YZ:  [batch.YZ, maxlen, heads.X, qkv]
    #    { AXES_YZX: [batch.YZX, maxlen, heads, qkv]
    q = q_wi[:, :, :, :hparams.qkv]
    if attn_all_to_all == AttnAllToAll.NONE:
      pass
    elif attn_all_to_all == AttnAllToAll.AXIS_Z:
      q = lax.all_to_all(
          q, axis_name='z', split_axis=0, concat_axis=2, tiled=True)
    elif attn_all_to_all == AttnAllToAll.AXES_YZ:
      q = lax.all_to_all(
          q, axis_name=('y', 'z'), split_axis=0, concat_axis=2, tiled=True)
    elif attn_all_to_all == AttnAllToAll.AXES_YZX:
      q = lax.all_to_all(
          q, axis_name=('y', 'z', 'x'), split_axis=0, concat_axis=2, tiled=True)

    q = _rope(sin, cos, q)

    y_att = intermediate_dtype(attention.attend(q, k, v, kv_caches, layer))
    # y_att:
    # { NONE:     [batch.B, maxlen, heads.YZX, qkv]
    # { AXIS_Z:   [batch.ZB, maxlen, heads.YX, qkv]
    # { AXES_YZ:  [batch.YZB, maxlen, heads.X, qkv]
    # { AXES_YZX: [batch.YZX, maxlen, heads, qkv]
    # -> [batch, maxlen, heads.YZX, qkv]
    # Inverse all_to_all of the one applied to q above.
    if attn_all_to_all == AttnAllToAll.NONE:
      pass
    elif attn_all_to_all == AttnAllToAll.AXIS_Z:
      y_att = lax.all_to_all(
          y_att, axis_name='z', split_axis=2, concat_axis=0, tiled=True)
    elif attn_all_to_all == AttnAllToAll.AXES_YZ:
      y_att = lax.all_to_all(
          y_att, axis_name=('y', 'z'), split_axis=2, concat_axis=0, tiled=True)
    elif attn_all_to_all == AttnAllToAll.AXES_YZX:
      y_att = lax.all_to_all(
          y_att,
          axis_name=('y', 'z', 'x'),
          split_axis=2,
          concat_axis=0,
          tiled=True)

  with jax.named_scope('SwiGLU'):
    y_mlp = special2.swish2(wi0) * wi1

  # einsum(y_fused, o_wo):
  # [batch, maxlen, heads.YZ, o_wo_per_head] @
  #       [heads.YZ, o_wo_per_head, embed.X]
  # -> (matmul)
  # -> [batch, maxlen, embed.X]{YZ unreduced}
  # -> (fused reducescatter)
  # -> [batch, maxlen, embed.XY]
  # -> (non-fused reducescatter)
  # -> [batch.Z, maxlen, embed.XY]
  with jax.named_scope('o_wo'):
    y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)

    # do the second half of the mlp and the self-attn projection in parallel
    # allgather y_fused: [batch, maxlen, heads.YZX, o_wo_per_head]
    #       -> [batch, maxlen, heads.YZ, o_wo_per_head]
    # we use the collective matmul/reducescatter instead
    y_out = matmul_allgather(
        'bthd,hde->bte',
        y_fused,
        params.o_wo,
        rhs_split_axis=0,
        axis_name='x',
        layer=layer)
    # y_out = reducescatter(
    #     y_out, scatter_dimension=2, axis_name='y', subsplit_axis=2)
    y_out = lax.psum_scatter(y_out, 'y', scatter_dimension=2, tiled=True)

    if shard_seqlen_vs_batch:
      # y_out = reducescatter(
      #     y_out, scatter_dimension=1, axis_name='z', subsplit_axis=0)
      # [batch, maxlen.Z, embed.XY]
      y_out = lax.psum_scatter(y_out, 'z', scatter_dimension=1, tiled=True)
    else:
      # y_out = reducescatter(
      #     y_out, scatter_dimension=0, axis_name='z', subsplit_axis=0)
      # TODO(sholto): Test if manual faster, update
      if batch_unsharded:
        # [batch, maxlen, embed.XYZ]
        y_out = lax.psum_scatter(y_out, 'z', scatter_dimension=2, tiled=True)
      else:
        # [batch.Z, maxlen, embed.XY]
        y_out = lax.psum_scatter(y_out, 'z', scatter_dimension=0, tiled=True)

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = y_out.shape
      y_out = intermediate_dtype(y_out *
                                 jnp.squeeze(my_layer(params.o_wo_scale)))
      assert_equal(y_out.shape, prev_shape)

  with jax.named_scope('residual'):
    z = intermediate_dtype(y_out + x)

  # Return K/V in the intermediate dtype so they can be cached uniformly.
  k, v = k.astype(intermediate_dtype), v.astype(intermediate_dtype)
  return z, k, v
510
def transformer_layer_weight_gathered(
    hparams, layer, params, sin,
    cos, kv_caches, x,
    x_axis, y_axis,
    z_axis):
  """Weight gathered parallel layer. Typically prefill.

  Batch is sharded over all of (x, y, z), so activations stay local and the
  weights are gathered to the data by the collective matmuls.

  Args:
    hparams: model HParams (qkv, ff, heads, padded_heads are read here).
    layer: integer layer index into the stacked per-layer params.
    params: Layer or QuantizedLayer weights.
    sin: RoPE sin table slice for this chunk.
    cos: RoPE cos table slice for this chunk.
    kv_caches: prior KV caches, handed to attention.attend.
    x: input activations, [batch.XYZ, t, e].
    x_axis: unused, kept for API compatibility.
    y_axis: unused, kept for API compatibility.
    z_axis: unused, kept for API compatibility.

  Returns:
    (z, k, v): bf16 residual output and the new K/V.
  """
  del x_axis, y_axis, z_axis  # for API compatibility
  # x: [batch.XYZ, t, e]
  with jax.named_scope('allgather_layernorm'):
    # No need to communicate across batch, so everything is local
    x_prec = jnp.float32(x)
    epsilon = 1e-6
    mean2 = jnp.mean(lax.square(x_prec), axis=-1, keepdims=True)
    # NOTE(review): multiplies bf16 `x` (not the f32 `x_prec`) by the f32
    # rsqrt — presumably intentional since the product promotes to f32
    # anyway, but differs from allgather_layernorm; confirm.
    xnorm = jnp.bfloat16(x * lax.rsqrt(mean2 + epsilon))

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # [batch.XYZ, t, e] @ [heads.YZ, e.X, q_wi_per_head]
  with jax.named_scope('q_wi'):
    q_wi = collectives.matmul_collective_weights_gather_q_wi(
        'bte,hed->bthd',
        xnorm,
        my_layer(
            params.q_wi
        ),  # in this case it makes sense to do this here because its once
        lhs_split_axis=2)  # -> [batch.XYZ, t, h, q_wi_per_head]

  if isinstance(params, weights.QuantizedLayer):
    # Rescale quantized weights; the squeeze must not change the shape.
    prev_shape = q_wi.shape
    q_wi = jnp.bfloat16(q_wi * jnp.squeeze(my_layer(params.q_wi_scale)))
    assert_equal(prev_shape, q_wi.shape)

  # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
  # swiGLU with full d_ff dimension, rather than 2/3 scaled
  wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // (hparams.heads - hparams.padded_heads))]  # pylint: disable = line-too-long
  wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // (hparams.heads - hparams.padded_heads)):]  # pylint: disable = line-too-long

  # kv is only batch sharded
  with jax.named_scope('kv'):
    # [batch.XYZ, t, e] @ [e, 1, 2*qkv] -> [batch.XYZ, t, 1, 2*qkv]
    # Two options here:
    # a) Split along x, and then all reduce along x
    # b) We fully replicate kv
    kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))

    if isinstance(params, weights.QuantizedLayer):
      prev_shape = kv.shape
      kv = jnp.bfloat16(kv * jnp.squeeze(my_layer(params.kv_scale)))
      assert_equal(prev_shape, kv.shape)

    # Fused projection: K is the first hparams.qkv features, V the rest.
    k = kv[:, :, 0, :hparams.qkv]  # [batch.XYZ, t, qkv]
    v = kv[:, :, 0, hparams.qkv:]  # [batch.XYZ, t, qkv]

  with jax.named_scope('attn'):
    k = _rope(sin, cos, k)  # [batch.XYZ, t, qkv]
    q = q_wi[:, :, :, :hparams.qkv]
    q = _rope(sin, cos, q)  # [batch.XYZ, t, h, qkv]

    # [batch.XYZ, t, h, qkv]
    y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))

  with jax.named_scope('SwiGLU'):
    y_mlp = special2.swish2(wi0) * wi1  # [batch.XYZ, t, h, ff_per_head]

  # [bach.XYZ, t , h, d] @ [h.YZ, d, e.X] -> [batch.XYZ, t, e]
  with jax.named_scope('o_wo'):
    y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)

    # previously concat yz, contracting over x - reconstructing heads dim
    # here, we contract over yz, concat over x to reconstruct embed dim
    y_out = collectives.matmul_collective_weights_gather_o_wo(
        'bthd,hde->bte', y_fused, my_layer(params.o_wo),
        lhs_split_axis=2)  # -> [batch.XYZ, t, e]

  if isinstance(params, weights.QuantizedLayer):
    prev_shape = y_out.shape
    y_out = jnp.bfloat16(y_out * jnp.squeeze(my_layer(params.o_wo_scale)))
    assert_equal(y_out.shape, prev_shape)

  with jax.named_scope('residual'):
    z = jnp.bfloat16(y_out + x)

  return z, k, v