# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Handy import wrappers."""
17
18import dataclasses
19from enum import Enum # pylint: disable=g-importing-member
20from functools import partial # pylint: disable=g-importing-member
21import logging
22from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union
23
24from flax import struct
25from flax.training import common_utils
26import jax
27import jax.numpy as jnp
28import numpy as np
29from seqio.vocabularies import Vocabulary
30from t5x import losses
31from t5x.models import DecoderOnlyModel
32
33from scaling_transformer_inference_efficiency import checkpoint
34from scaling_transformer_inference_efficiency import chunk
35from scaling_transformer_inference_efficiency import incremental
36from scaling_transformer_inference_efficiency import inference
37from scaling_transformer_inference_efficiency import partitioning
38from scaling_transformer_inference_efficiency import sampling
39from scaling_transformer_inference_efficiency import weights
40from scaling_transformer_inference_efficiency.layers import one_d_parallel_xmap
41from scaling_transformer_inference_efficiency.layers import two_d_parallel_xmap
42
43
44PyTree = Any
45
46
@struct.dataclass
class TestVocab:
  eos_id = 0
  bos_id = 0
  pad_id = 0

  def encode_tf(self, text):
    chars = np.array([ord(c) for c in text]).astype(np.int32)
    return chars

  def decode_tf(self, tokens):
    results = np.split(tokens, tokens.shape[0])
    return np.array([[chr(r) for r in list(line[0])] for line in results])


class Layout(Enum):
  TWO_D = 'two_d'
  ONE_D = 'one_d'
  WEIGHT_GATHERED = 'weight_gathered'


@dataclasses.dataclass
class ModelConfig:
  """An object to make gin file input elegant.

  ckpt_path: typically a CNS path
  size: 8, 62, 540
  quantized: whether to load a quantized checkpoint
  generate_steps: number of generation steps to run
  kv_cache_sharding: the degree of kv cache sharding (0: None, 1: Z, 2: YZ,
    3: YZX)
  latency_collectives: whether to use latency-optimised forms (double the
    compute per step, half the steps for collective matmuls)
  batch_unsharded: whether to leave the batch dim unsharded
  shard_seqlen_vs_batch: whether to shard the seqlen dim instead of the batch
    dim
  stream: An object to facilitate streaming back to X (you define the
    callbacks).
  transpose_scan_axis: transpose if layers were not saved as the leading axis
  bos_id: Optionally overwrite the model's bos_id.
  """

  ckpt_path: str
  size: int
  quantized: bool
  generate_steps: int
  kv_cache_sharding: int
  latency_collectives: bool
  batch_unsharded: bool
  shard_seqlen_vs_batch: bool
  stream: Optional[incremental.StreamClient] = None
  transpose_scan_axis: bool = True
  layout: Layout = Layout.TWO_D
  bos_id: Optional[int] = None


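# Illustrative only: a minimal sketch of how a ModelConfig might be filled in
# for the 8B model. The checkpoint path and the specific values below are
# assumptions, not defaults from this module.
# cfg = ModelConfig(
#     ckpt_path='/path/to/palm_8b_checkpoint',
#     size=8,
#     quantized=False,
#     generate_steps=64,
#     kv_cache_sharding=2,
#     latency_collectives=True,
#     batch_unsharded=False,
#     shard_seqlen_vs_batch=False,
# )

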
def return_minimal_palm(
    cfg,
    params_already_loaded=False,
    remat = None,
    devices = None,
):  # pylint: disable = g-bare-generic, line-too-long
  """Utility function to return a model.

  Args:
    cfg: A model configuration
    params_already_loaded: whether params have been loaded yet
    remat: Whether to remat the layer, used for training, e.g.
      jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims or
      jax.checkpoint_policies.nothing_saveable
    devices: devices to make a mesh from

  Returns:
    model: A model wrapper
    params: The params
    prefill_fn: Function to pass as prefill (to ensure it is compilation
      cached)
    generate_fn: Function to pass as generation (to ensure it is compilation
      cached)
  """
  one_d = cfg.layout == Layout.ONE_D
  if cfg.shard_seqlen_vs_batch and cfg.batch_unsharded:
    raise NotImplementedError(
        "Either shard seqlen instead of batch or don't shard batch."
    )

  del remat  # for the moment, always remat
  # We have preset sizes
  if cfg.size == 0:
    hparams = checkpoint.HParams.TOY
  if cfg.size == 8:
    hparams = checkpoint.HParams.PALM_8B
  elif cfg.size == 62:
    hparams = checkpoint.HParams.PALM_62B
  elif cfg.size == 540:
    hparams = checkpoint.HParams.PALM_540B

  if cfg.quantized:
    ckpt = checkpoint.QuantizedCheckpoint
    params_spec = weights.QuantizedWeights
  else:
    ckpt = checkpoint.Checkpoint
    params_spec = weights.Weights

  if cfg.size == 0:
    loaded_ckpt = ckpt.init_zero(hparams)
  else:
    spec = checkpoint.CheckpointSpec(
        hparams=hparams,
        dir=cfg.ckpt_path,
        transpose_scan_axis=cfg.transpose_scan_axis,
    )
    loaded_ckpt = ckpt.load_spec(spec)

  if cfg.kv_cache_sharding == 0:
    attn_batch_sharding = partitioning.AttnAllToAll.NONE
  elif cfg.kv_cache_sharding == 1:
    attn_batch_sharding = partitioning.AttnAllToAll.AXIS_Z
  elif cfg.kv_cache_sharding == 2:
    attn_batch_sharding = partitioning.AttnAllToAll.AXES_YZ
  elif cfg.kv_cache_sharding == 3:
    attn_batch_sharding = partitioning.AttnAllToAll.AXES_YZX
  else:
    raise NotImplementedError

  if cfg.layout == Layout.TWO_D:
    rules = partitioning.make_rules_two_d(
        attn_batch_sharding, batch_unsharded=cfg.batch_unsharded
    )
    layer_fn = partial(
        two_d_parallel_xmap.transformer_layer_weight_stationary,
        attn_all_to_all=attn_batch_sharding,
        latency_collectives=cfg.latency_collectives,
        shard_seqlen_vs_batch=cfg.shard_seqlen_vs_batch,
        batch_unsharded=cfg.batch_unsharded,
    )
    # sample_fn = partial(sampling.sample_manual,
    #                     batch_unsharded=cfg.batch_unsharded)
    sample_fn = sampling.sample

  elif cfg.layout == Layout.ONE_D:
    rules = partitioning.make_rules_one_d()
    layer_fn = partial(
        one_d_parallel_xmap.weight_stationary_simple,
        latency_collectives=cfg.latency_collectives,
    )
    sample_fn = sampling.sample_manual_batch_unsharded
  elif cfg.layout == Layout.WEIGHT_GATHERED:
    rules = partitioning.make_rules_weight_gathered()
    sample_fn = sampling.sample
    raise NotImplementedError
  else:
    raise NotImplementedError

  if cfg.size == 0:
    the_vocab = TestVocab()
  else:
    the_vocab = checkpoint.load_vocab()

  mesh = partitioning.make_mesh(one_d=one_d, devices=devices)
  sharding_config = partitioning.ShardingConfig(
      mesh=mesh,
      attn_all_to_all=attn_batch_sharding,
      latency_collectives=cfg.latency_collectives,
      shard_seqlen_vs_batch=cfg.shard_seqlen_vs_batch,
      batch_unsharded=cfg.batch_unsharded,
  )

  embed_fn = partial(
      two_d_parallel_xmap.embed_manual,
      shard_seqlen_vs_batch=cfg.shard_seqlen_vs_batch,
      batch_unsharded=cfg.batch_unsharded,
      one_d=one_d,
  )

  unembed_fn = partial(
      two_d_parallel_xmap.unembed_manual,
      batch_unsharded=cfg.batch_unsharded,
      one_d=one_d,
  )

  forward_pass = partial(
      inference.manual_fwd_pass,
      hparams,
      sharding_config,
      embed_fn,
      layer_fn,
      unembed_fn,
  )

  infer_stack = partial(
      inference.infer_template,
      hparams,
      sharding_config,
      forward_pass,
  )

  model = incremental.InferenceModel(
      hparams,
      the_vocab.eos_id,
      infer_stack,
      sample_fn,
      mesh,
      rules,
      the_vocab,
      bos_id=cfg.bos_id,
  )

  generate_fn = model.instantiate_generating_fn(cfg.generate_steps)
  prefill_fn = model.instantiate_prefill_fn()

  if params_already_loaded:
    return model, None, prefill_fn, generate_fn
  else:
    # actually load the weights
    with model.mesh, model.rules:
      params = params_spec.from_checkpoint(hparams, model.mesh, loaded_ckpt)

    logging.info('Weights loaded.')

    # cs2 = cs.replace(hparams = cs.hparams.replace(heads=64, padded_heads=32))
    params = (
        model.rotate_weights(params, cfg.latency_collectives)
        if cfg.latency_collectives
        else params
    )
    logging.info('Weights formatted.')
    return model, params, prefill_fn, generate_fn


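# Illustrative only: a minimal sketch of wiring the returned pieces together,
# mirroring how InferenceT5X below uses them. `cfg`, `prompt` (a chunk.Chunk of
# tokenized inputs) and `batch_size` are assumed to exist.
# model, params, prefill_fn, generate_fn = return_minimal_palm(cfg)
# cache = [model.prefill(params, prefill_fn, [], prompt)]
# sample_ids = np.arange(batch_size)
# output, output_result = model.generate(
#     params, generate_fn, cache, sample_ids,
#     sampling.SamplingHyperParams(temperature=0.7))

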
@jax.jit
def find_common_prefix(tokens):
  # find a common prefix
  base_case = tokens[0, :]
  is_equal = jnp.int8(tokens == base_case)  # broadcasts across the batch
  equal_at = jnp.prod(is_equal, axis=0)  # get a single dimensional array
  cp = jnp.cumprod(equal_at, 0)
  first_non_equal = jnp.argmin(cp)  # will get the first 0
  return first_non_equal


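# Worked example (illustrative): for tokens = [[5, 7, 9, 2],
#                                              [5, 7, 3, 1]]
# equal_at = [1, 1, 0, 0], so cumprod = [1, 1, 0, 0] and argmin returns 2: the
# first position at which the batch rows disagree, i.e. the shared prefix
# length.

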
@jax.jit
def ce_loss(
    score_result, batch
):
  """Cross entropy loss."""
  token_scores = (
      -losses.cross_entropy_with_logits(
          score_result.logits,
          common_utils.onehot(
              batch['decoder_target_tokens'],
              score_result.logits.shape[-1],
              on_value=1,
              off_value=0,
          ),
          z_loss=0.0,
      )[0]
      * batch['decoder_loss_weights']
  )
  return token_scores


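# Note: ce_loss returns per-token log-likelihoods (negated cross entropy)
# masked by 'decoder_loss_weights'; score_batch below sums them over the time
# axis to produce per-sequence scores.

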
# pylint: disable = g-bare-generic
# pylint: disable = invalid-name
@dataclasses.dataclass
class InferenceT5X(DecoderOnlyModel):
  """Creates an API that fits T5X."""

  model: incremental.InferenceModel
  params: weights.Weights
  prefill_fn: Callable
  generate_fn: Callable
  _batch_size: int
  _input_vocabulary: Vocabulary
  _output_vocabulary: Vocabulary
  sample_ids: jax.Array
  max_input_length: int
  max_generate_length: int

  def __init__(
      self,
      cfg,
      _input_vocabulary,
      batch_size,
      task_feature_lengths,
  ):
    model, params, prefill_fn, generate_fn = return_minimal_palm(cfg)  # pylint: disable = unbalanced-tuple-unpacking
    self.model = model
    self.params = params
    self.prefill_fn = prefill_fn
    self.generate_fn = generate_fn
    self.get_logits_fn = model.instantiate_prefill_fn(return_full_chunk=True)
    self._batch_size = batch_size
    self._input_vocabulary = _input_vocabulary
    self._output_vocabulary = _input_vocabulary
    self.max_input_length = task_feature_lengths['inputs']
    self.max_generate_length = task_feature_lengths['targets']

    # make a custom model for the common_prefix / prefill sections
    # this is only function defs, not params
    prefix_model_cfg = dataclasses.replace(
        cfg, kv_cache_sharding=0, batch_unsharded=True
    )
    prefix_model, _, prefix_prefill_fn, _ = return_minimal_palm(
        prefix_model_cfg, params_already_loaded=True
    )
    self.prefix_model = prefix_model
    self.prefix_prefill_fn = prefix_prefill_fn

  def predict_batch(self, params, batch):
    """Does an inference step.

    Args:
      params: Pytree definition of weights
      batch: assumed to have fields {'decoder_causal_attention': int [batch,
        length], 'decoder_input_tokens': same}

    Returns:
      inferences: (output.tokens, {'scores': output_result.per_token_scores})
        tokens is either [batch, tokens] or [batch, num_decodes, tokens]
    """

    return self.predict_batch_with_aux(params, batch)

  def predict_batch_with_aux(
      self,
      params,
      batch,
      rng = None,
      num_decodes = 1,
      temperature = 0.7,
      return_all_decodes = True,
      decoder_params=None,
  ):
    with jax.named_scope('make_batch'):
      prefix, prompt = self.make_batch(batch)
    processed_cache = self.process_cache(params, prompt, prefix)
    with jax.named_scope('generate'):
      sample_hyperparams = sampling.SamplingHyperParams(temperature=temperature)
      sample_ids = np.arange(self._batch_size * num_decodes)
      output, output_result = self.model.generate(
          params,
          self.generate_fn,
          processed_cache,
          sample_ids,
          sample_hyperparams,
      )

    if num_decodes > 1:
      tokens = output.tokens.reshape((self._batch_size, num_decodes, -1))
      scores = output_result.per_token_scores.sum(-1).reshape(
          (self._batch_size, num_decodes)
      )
    else:
      tokens = output.tokens
      scores = output_result.per_token_scores.sum(-1)

    inferences = tokens, {
        'scores': scores
    }  # none in place of scores for the moment

    return inferences

  def score_batch(
      self,
      params,
      batch,
      return_intermediates = False,
  ):
    inputs_lengths = np.sum(batch['decoder_causal_attention'], axis=1) - 1
    masked_inputs = (
        batch['decoder_input_tokens'] * batch['decoder_causal_attention']
    )
    score_chunk = chunk.Chunk(masked_inputs, inputs_lengths)  # [batch, time]

    # TODO(sholto): We could play the common prefix trick here too
    score_result = self.model.prefill(
        self.params, self.get_logits_fn, [], score_chunk
    )
    # TODO(sholto): Test if manual version made for cascades uses less memory
    token_scores = ce_loss(score_result, batch)
    sequence_scores = token_scores.sum(-1)
    return sequence_scores

  def make_batch(
      self,
      batch,
      extract_prefix = False,
      common_prefix_heuristic = 32,
  ):
    inputs_lengths = np.sum(batch['decoder_causal_attention'], axis=1) - 1
    masked_inputs = (
        batch['decoder_input_tokens'] * batch['decoder_causal_attention']
    )
    inputs = masked_inputs[:, : self.max_input_length]  # [batch, time]

    if extract_prefix:
      # NB: the below is not jax jittable.
      common_prefix = find_common_prefix(inputs)  # integer
      # Heuristic for whether prefix extraction is worth doing
      if (common_prefix > common_prefix_heuristic) and (
          self.max_input_length - common_prefix_heuristic > common_prefix
      ):
        logging.info('Detected common prefix of length %i', common_prefix)
        prefix = chunk.Chunk(
            jnp.expand_dims(inputs[0, :common_prefix], 0),
            jnp.array([common_prefix]),
        )
        prompt = chunk.Chunk(
            inputs[:, common_prefix:], inputs_lengths - common_prefix
        )
        return prefix, prompt
    # Default to no prefix extraction
    prompt = chunk.Chunk(inputs, inputs_lengths)
    prefix = None
    return prefix, prompt

  def process_cache(
      self, params, prompt, prefix=None
  ):
    processed_cache = []
    if prefix is not None:
      with jax.named_scope('common_prefill'):
        # the common prefix will be batch size 1, shard appropriately
        common_prefix = self.prefix_model.prefill(
            params, self.prefix_prefill_fn, [], prefix
        )
        processed_cache.append(common_prefix)
    with jax.named_scope('different_prefill'):
      prompt = self.model.prefill(
          params, self.prefill_fn, processed_cache, prompt
      )
      processed_cache.append(prompt)
    return processed_cache

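
# Illustrative only: a minimal sketch of constructing and calling InferenceT5X.
# `cfg` is a ModelConfig and `batch` a seqio-style feature dict; the batch size
# and feature lengths below are assumptions.
# interface = InferenceT5X(
#     cfg,
#     _input_vocabulary=checkpoint.load_vocab(),
#     batch_size=8,
#     task_feature_lengths={'inputs': 1024, 'targets': 256},
# )
# tokens, aux = interface.predict_batch(interface.params, batch)
# sequence_scores = interface.score_batch(interface.params, batch)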