# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16"""Support for incremental processing on Transformers.
17
18We provide two functions, `prefill` and `generate`, which operate on the
19`Chunk` and `ChunkResult` types from chunk.py.
20* The function `prefill`, sometimes also called 'encode' in Transformer
21codebases, runs a single forwards pass over a batch of input sequences,
22returning scores and KV caches for those tokens.
23* The function `generate`, sometimes also called 'decode' in Transformer
24codebases, generates new text autoregressively, in a sequential loop that
25generates one token at a time per sequence.
26
27Example use cases follow. Each example builds upon the previous.
28
29Example 1: scoring some text
30============================
31
32We create a `Chunk` of input text and then run `prefill` on it.
33
34```
35jitted_model = JittedModel(...)
36
37# Create a batch=1 input chunk of text.
38few_shot_examples = Chunk.tokenize(
39vocab, ["Cows have 4 legs. Fish have 0 legs."], is_first_chunk=True)
40few_shot_examples_result = jitted_model.prefill(params, [], few_shot_examples)
41print(few_shot_examples_result.per_token_scores)
42```
43
44Example 2: generating text using the prompt
45===========================================
46
47We use the `few_shot_examples_result` from the previous example as attention
48context (the KV cache) from which we generate new text.
49
50```
51# Controls random sampling
52my_sampling = Sampling(temperature=0.7)
53# 4 random seeds, so that we generate 4 different outputs.
54sample_ids = jnp.arange(4, jnp.int32)
55generated_text, generated_text_result = jitted_model.generate(
56params, my_sampling, [few_shot_examples_result], sample_ids)
57# Print all 4 samples
58for s in generated_text.detokenize(vocab):
59print(s)
60```
61
Example 3: Multiple prompts sharing a common prefix
===================================================

In a few-shot-prompted scenario, we typically have a common prefix (the few-shot
prompts), shared over a batch of tasks, and for each task we generate multiple
samples. By splitting each of these steps into its own `prefill` or `generate`
call, we can do this in a way that maximally exploits the sharing.

In Example 1 we already called `prefill` on a single shared sequence which has
the few-shot examples. Next we call `prefill` on the batch of tasks, using the
few-shot examples as attention context. It is permissible to have more tasks
than few-shot examples, as we demonstrate here:

```
# Prefill a batch=3 set of tasks.
tasks = Chunk.tokenize(vocab, ["Humans have", "Potatoes have", "Dinosaurs have"])
tasks_result = jitted_model.prefill(params, [few_shot_examples_result], tasks)
# Generate 2 samples for each task. This sums to 6 samples in total.
sample_ids = jnp.arange(6, dtype=jnp.int32)
task_samples, task_samples_results = jitted_model.generate(
    params, my_sampling, [few_shot_examples_result, tasks_result], sample_ids)
```

Example 4: appending even more text, and then generating some more
===================================================================

If we were in a chatbot scenario, at this point we might append some more
user-provided text to the context, and then generate yet another response. This
consists of another call to `prefill` followed by another call to `generate`.
As this example shows, they can be combined in any fashion.

```
# Add the user response, using `prefill`.
user_response = Chunk.tokenize(vocab, ["How many legs does a chicken have?"])
user_response_result = jitted_model.prefill(
    params, [few_shot_examples_result, generated_text_result], user_response)
# Generate another AI response, using `generate`.
ai_response_text, ai_response_result = jitted_model.generate(
    params, my_sampling,
    [few_shot_examples_result, generated_text_result, user_response_result],
    sample_ids
)
# Print all the samples
for s in ai_response_text.detokenize(vocab):
  print(s)
```

TODO(reinerp): Example 4 uses an ever-increasing list of ChunkResults as
context arguments. In a more realistic chatbot scenario we would concatenate all
the ChunkResults into a single longer ChunkResult, subject to batch size
restrictions.
"""
from dataclasses import dataclass  # pylint: disable=g-importing-member
from functools import partial  # pylint: disable=g-importing-member
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import jax
from jax import lax
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
import jax.scipy
from jax.sharding import Mesh
import numpy as np
from seqio.vocabularies import Vocabulary

from scaling_transformer_inference_efficiency import attention
from scaling_transformer_inference_efficiency import checkpoint
from scaling_transformer_inference_efficiency import collectives
from scaling_transformer_inference_efficiency import partitioning
from scaling_transformer_inference_efficiency import weights
from scaling_transformer_inference_efficiency.chunk import Chunk
from scaling_transformer_inference_efficiency.chunk import ChunkResult
from scaling_transformer_inference_efficiency.chunk import FullChunkResult
from scaling_transformer_inference_efficiency.chunk import InferFn
from scaling_transformer_inference_efficiency.sampling import SampleFn
from scaling_transformer_inference_efficiency.sampling import SamplingHyperParams

Weights = weights.Weights
P = jax.sharding.PartitionSpec


# pylint: disable = g-bare-generic
@dataclass
class StreamClient:
  """Used to handle streaming results."""

  prev_token_decoded: Optional[jnp.ndarray] = None
  prev_token: Optional[jnp.ndarray] = None
  stream_callback: Callable = lambda x: print(x, end='')
  stream_done_callback: Callable = lambda: None
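  # A minimal usage sketch (the callback values below are assumptions, not
  # part of this module): callers can override the callbacks to accumulate
  # the streamed text instead of printing it.
  #
  #   pieces = []
  #   client = StreamClient(stream_callback=pieces.append,
  #                         stream_done_callback=lambda: pieces.append('\n'))
  #
  # `stream_result` is then invoked once per generated token, and
  # `clear_prev_token` once the stream is finished.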
  def find_new_chars(self, vocab: Vocabulary, next_token: np.ndarray):
    """We decode pairs because the tokenizer strips whitespace."""
    prefix = self.prev_token_decoded
    whole = (
        vocab.decode_tf(np.concatenate([self.prev_token, next_token], -1))
        .numpy()
        .decode('utf-8')
    )
    new_text = whole[len(prefix) :]
    return new_text
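  # Illustrative example (hypothetical tokens): decoding the new token on its
  # own may drop its leading space, yielding e.g. "world", whereas decoding
  # the pair yields "hello world"; slicing off the previously decoded "hello"
  # recovers " world" with its whitespace intact.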
  def stream_result(
      self, logits: jax.Array, vocab: Vocabulary, x: int, y: int, z: int
  ):
    """Stream result back to stdout. For the moment, only stream the first element."""

    if x == 0 and y == 0 and z == 0:
      logits = np.array(logits)
      current_token = np.array(logits[0:1])
      if self.prev_token is None:
        new_chars = vocab.decode_tf(current_token).numpy().decode('utf-8')
      else:
        new_chars = self.find_new_chars(vocab, current_token)

      self.stream_callback(new_chars)
      self.prev_token = current_token  # pytype: disable=annotation-type-mismatch  # jax-ndarray
      self.prev_token_decoded = new_chars.lstrip(' ').rstrip(' ')
  def clear_prev_token(self):
    self.prev_token = None
    self.stream_done_callback()


def _bos_logits(vocab_size: int, bos_id: int = 0) -> jnp.ndarray:
  """Logits that assign probability 1.0 to bos_id."""
  logits = jnp.full((vocab_size,), -1e10)
  return logits.at[bos_id].set(0.0)
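# For example, with vocab_size=4 and bos_id=0 the returned logits are
# [0.0, -1e10, -1e10, -1e10], so a softmax over them places essentially all
# probability mass on token 0.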


class InferenceModel:
  """A model with xmapped JIT-compiled prefill and generate functions."""

  def __init__(
      self,
      hparams: checkpoint.HParams,
      eos_id: int,
      infer_fn: InferFn,
      sample_fn: SampleFn,
      mesh: Mesh,
      rules: Sequence[Tuple[str, Any]],
      vocab: Optional[Vocabulary] = None,
      bos_id: Optional[int] = None,  # Allow to overwrite the default value.
  ):
    self._hparams = hparams
    self._eos_id = eos_id
    self._infer = infer_fn
    self._sample = sample_fn
    self.mesh = mesh
    self.rules = partitioning.PartitioningRules(rules)
    with self.rules:
      self.sample_ids_sharding = partitioning.logical_to_physical(
          P('logit_batch')
      )
      self.embeddings_logical = P(
          'residual_batch', 'residual_time', 'residual_embed'
      )
      self.embeddings_sharding = jax.tree_map(
          partitioning.logical_to_physical, self.embeddings_logical
      )
    self.vocab = vocab
    if bos_id is None:
      if vocab is not None:
        bos_id = vocab.bos_id
      else:
        bos_id = 0
    self.bos_id = bos_id
    # _prefill_p: maps num_prefixes -> jitted _prefill_impl function
    self._prefill_p = {}
    # _score_p: maps num_prefixes -> jitted _generate_impl function
    self._generate_p = {}
  def rotate_weights(self, params: Weights, latency: bool = True) -> Weights:
    """Rotate the weights for the collectives.

    Assumed to occur in a per device form. Assumes 2D partitioning.
    q_wi: [layers, heads.YZ, dmodel.X, q_wi_per_head]
    o_wo: [layers, heads.YZ, owo_per_head, dmodel.X]

    Args:
      params: unmodified
      latency: Whether to do latency collectives

    Returns:
      params: new parameters, rotated for a given collective
    """
    def rotate(params):
      new_layer = params.layer
      if latency:
        new_layer = new_layer.replace(
            q_wi=collectives.preshuffle_for_reducescatter_latency(
                new_layer.q_wi, scatter_axis=1, axis_name='x'
            )
        )
        new_layer = new_layer.replace(
            o_wo=collectives.preshuffle_for_allgather_matmul_latency(
                new_layer.o_wo, shuffle_axis=1, axis_name='x'
            )
        )
      else:
        new_layer = new_layer.replace(
            q_wi=collectives.preshuffle_for_reducescatter_throughput(
                new_layer.q_wi, scatter_axis=1, subsplit_axis=3, axis_name='x'
            )
        )
        new_layer = new_layer.replace(
            o_wo=collectives.preshuffle_for_allgather_matmul_throughput(
                new_layer.o_wo, shuffle_axis=1, axis_name='x'
            )
        )

      return params.replace(layer=new_layer)

    with self.mesh, self.rules:
      params = jax.jit(
          shard_map(
              rotate,
              self.mesh,
              in_specs=(params.physical_axes(),),
              out_specs=params.physical_axes(),
              check_rep=False,
          ),
          donate_argnums=(0,),
      )(params)

    return params
  # pylint: disable = g-bare-generic
  # pylint: disable = protected-access
  @staticmethod
  def _prefill_impl(
      model,
      params: Weights,
      cache: Sequence[ChunkResult],
      chunk: Chunk,
      prev_logits: Optional[jnp.ndarray],
      pre_embedded_inputs: Optional[jax.Array] = None,
      return_full_chunk: bool = False,
  ) -> Union[ChunkResult, FullChunkResult]:
    """Wrap both prefill and results formatting in a single xmap call."""
    if pre_embedded_inputs is not None:
      full_chunk_result = model._infer(
          params, cache, chunk, pre_embedded_inputs=pre_embedded_inputs
      )
    else:
      full_chunk_result = model._infer(params, cache, chunk)
    if return_full_chunk:
      return full_chunk_result
    else:
      return full_chunk_result.to_chunk_result(
          prev_logits, chunk, bos_id=model.bos_id
      )

  def instantiate_prefill_fn(self, return_full_chunk: bool = False):
    return partial(
        self._prefill_impl,
        self,
        return_full_chunk=return_full_chunk,
    )
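  # A minimal call-pattern sketch (the argument values are assumptions, not
  # part of this module): the partial returned above is what `prefill` expects
  # as its `prefill_impl` argument.
  #
  #   prefill_fn = model.instantiate_prefill_fn()
  #   first_result = model.prefill(params, prefill_fn, [], first_chunk)
  #   later_result = model.prefill(params, prefill_fn, [first_result],
  #                                next_chunk)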
  def prefill(
      self,
      params: Weights,
      prefill_impl: Callable,
      prefix: Sequence[ChunkResult],
      chunk: Chunk,
      pre_embedded_inputs: Optional[jax.Array] = None,
  ) -> Union[ChunkResult, FullChunkResult]:
    """Non-generative inference for a batch.

    Args:
      params: Model weights.
      prefill_impl: Partialed _prefill_impl.
      prefix: Already-processed tokens in the prefix, if any.
      chunk: The tokens to prefill.
      pre_embedded_inputs: If we want to do the embeddings outside (e.g. for
        prompt tuning).

    Returns:
      Scores for the batch.
    """
    with self.mesh, self.rules:
      if prefix:
        prev_logits = prefix[-1].next_token_logits
      else:
        prev_logits = None

      prefix = [p.kv_cache for p in prefix]
      return jax.jit(prefill_impl)(
          params, prefix, chunk, prev_logits, pre_embedded_inputs
      )
  @staticmethod
  def create_output_buffer(
      hparams: checkpoint.HParams,
      sample_ids: jnp.ndarray,
      prefix: List[attention.KVCache],
      length: int,
      prev_chunk_next_token_logits: Optional[jnp.ndarray] = None,
      circular: bool = False,
      bos_id: int = 0,
  ):
    """Create everything we need to deterministically write output samples."""
    # Seeding of the RNG itself is deterministic.
    # To generate different samples, users can provide sample_number_offset.
    (batch,) = sample_ids.shape
    sample_rngs = jax.vmap(jax.random.fold_in, in_axes=(None, 0))(  # pytype: disable=wrong-arg-types  # jax-ndarray
        jax.random.PRNGKey(0), sample_ids
    )
    token_indexes_start = attention.prefix_lengths(prefix)
    token_indexes_start = attention.flat_broadcast(token_indexes_start, batch)
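    # ^ sample_rngs: one PRNG key per sample_id (typically uint32[batch, 2]
    #   under the default PRNG implementation), derived by folding each
    #   sample_id into a fixed base key. This matches the seeding scheme
    #   documented in `generate`.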
    # Generation loop.
    last_logits = prev_chunk_next_token_logits
    if last_logits is None:
      last_logits = _bos_logits(hparams.vocab, bos_id)[np.newaxis, :]
    last_logits = attention.flat_broadcast(last_logits, batch)
    chunk = Chunk.zeros(batch, length)
    chunk_result = ChunkResult.zeros(hparams, batch, length, circular=circular)
    chunk_result = chunk_result.replace(next_token_logits=last_logits)
    return sample_rngs, token_indexes_start, chunk, chunk_result
  @staticmethod
  def sample_infer_write(
      model,
      params: Weights,
      prefix: List[attention.KVCache],
      sample_params: SamplingHyperParams,
      token_indexes_start: jnp.ndarray,
      sample_rngs: jnp.ndarray,
      write_index: int,
      state: Tuple[Chunk, ChunkResult],
  ):
    """Samples prev inference, infers, writes to cache.

    We sample first and then do the next inference step because we already
    have preexisting logits from prefill, so it saves us a step. Additionally,
    it lowers better.

    We have two different chunks at play in this loop:
    1. `chunk`/`chunk_result`, of shape (batch, steps). This is the
       mutable state that we fill one token at a time, starting at step=0
       and going up to step=steps-1.
    2. `token_chunk`/`token_chunk_result`/`token_full_chunk_result`, of
       shape (batch, 1). These are the input/output of a single forwards
       pass through the model, which processes just one token per batch
       element.

    We write token_* into chunk/chunk_result at a new position each loop
    iteration.

    Args:
      model: InferenceModel
      params: Model weights
      prefix: Pre-existing KV cache
      sample_params: Temperature etc.
      token_indexes_start: Per-element token index in the full sequence, for
        fully deterministic sampling.
      sample_rngs: RNG per element
      write_index: Current index in the state. Always the same for all
        elements so that it is a slice, not a scatter.
      state: chunk / chunk_result pair described above. May be a circular
        buffer.

    Returns:
      chunk: Written to (for tokens)
      chunk_result: Written to (for KV cache)
    """
    batch, _ = sample_rngs.shape
    chunk, chunk_result = state
    step_rngs = jax.vmap(jax.random.fold_in)(  # pytype: disable=wrong-arg-types  # jax-ndarray
        sample_rngs, token_indexes_start + chunk.lengths
    )
    next_token = model._sample(
        chunk_result.next_token_logits, step_rngs, sample_params, model.mesh,
    )
    # ^ next_token: [batch]
    token_chunk = Chunk(
        tokens=next_token[:, np.newaxis],
        lengths=jnp.full((batch,), 1, jnp.int32),
    )
    # ^ token_chunk: Chunk[batch, 1]

    token_full_chunk_result = model._infer(
        params, prefix + [chunk_result.kv_cache], token_chunk
    )
    chunk = chunk.update(write_index, token_chunk)
    chunk_result = chunk_result.update(
        write_index,
        token_chunk,
        token_full_chunk_result,
        bos_id=model.bos_id,
    )
    return chunk, chunk_result
  # pylint: disable = protected-access
  # pylint: disable = g-long-lambda
  # pylint: disable = unnecessary-lambda
  # pytype: disable=attribute-error
  # pytype: disable=bad-unpacking
  @partial(jax.jit, static_argnums=(0, 5, 6, 7, 8))
  def _generate_impl(
      self,
      params: Weights,
      prefix: List[attention.KVCache],
      prev_chunk_next_token_logits: jnp.ndarray,
      sample_ids: jnp.ndarray,
      sample_params: SamplingHyperParams,
      steps: int,
      stream: Optional[StreamClient] = None,
      return_all_logits=False,
  ) -> Union[Tuple[Chunk, ChunkResult], Tuple[Chunk, ChunkResult, jax.Array]]:
    """Generates a chunk of text, given some prefixes. Jitted function."""
    del return_all_logits  # TODO(sholto): For all logit scoring
    (batch,) = sample_ids.shape
    del stream  # TODO(sholto): reimplement once shardmap callback is done

    sample_rngs, token_indexes_start, chunk, chunk_result = (
        self.create_output_buffer(
            self._hparams,
            sample_ids,
            prefix,
            steps,
            prev_chunk_next_token_logits,
            circular=False,
            bos_id=self.bos_id,
        )
    )

    loop_body = partial(
        self.sample_infer_write,
        self,
        params,
        prefix,
        sample_params,
        token_indexes_start,
        sample_rngs,
    )

    chunk, chunk_result = lax.fori_loop(
        0, steps, loop_body, (chunk, chunk_result)
    )
    # The length of the chunk is the index of the first EOS ID. We don't
    # calculate this during the generation loop, so instead we calculate it
    # now.
    is_eos = chunk.tokens == self._eos_id
    token_iota = lax.broadcasted_iota(jnp.int32, (batch, steps), 1)
    chunk = chunk.replace(
        lengths=jnp.min(jnp.where(is_eos, token_iota, steps), axis=1)
    )
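    # Worked example (illustrative values): with steps=4 and a row of tokens
    # [12, 7, eos, 9], is_eos is [F, F, T, F], so jnp.where gives [4, 4, 2, 4]
    # and the row's length becomes 2 (the index of the first EOS). Rows with
    # no EOS get length `steps`.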
    return chunk, chunk_result
  def instantiate_generating_fn(
      self,
      steps: int,
      stream: Optional[StreamClient] = None,
      return_all_logits=False,
  ) -> Callable:  # pylint: disable = g-bare-generic
    """Create partial fn to ensure caching."""

    return partial(
        self._generate_impl,
        steps=steps,
        stream=stream,
        return_all_logits=return_all_logits,
    )
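  # A minimal call-pattern sketch (argument values are assumptions, not part
  # of this module): the partial returned above is what `generate` expects as
  # its `generate_fn` argument.
  #
  #   generate_fn = model.instantiate_generating_fn(steps=128)
  #   new_text, new_result = model.generate(
  #       params, generate_fn, [prefill_result], sample_ids, sample_params)
  #
  # Reusing the same `generate_fn` across calls with the same `steps` reuses
  # the jitted `_generate_impl` cache rather than re-tracing it.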
  def generate(
      self,
      params: Weights,
      generate_fn: Callable,  # pylint: disable = g-bare-generic
      prefix: Sequence[ChunkResult],
      sample_ids: jnp.ndarray,
      sample_params: SamplingHyperParams,
  ) -> Tuple[Chunk, ChunkResult]:
    """Generative inference for a batch.

    Note about random number seeding:
    We provide strong guarantees about random number seeding, to make it
    possible for callers to get deterministic results that are independent of
    batch packing and independent of splitting into `Chunk`s for incremental
    processing. Specifically, we guarantee that random numbers are constructed
    by:

    ```
    def rng_for_token(sample_id: int, token_index: int) -> jax.Array:
      rng = jax.random.PRNGKey(0)
      rng = jax.random.fold_in(rng, sample_id)
      rng = jax.random.fold_in(rng, token_index)
      return rng
    ```

    Here, `sample_id` is taken from the `sample_ids` array provided by the
    caller, and `token_index` is the number of non-padding tokens in this
    sample prior to the token currently being generated. This scheme ensures
    that any text generated with the same `sample_id`, from the same prefix,
    using the same sampling hyperparameters, will make the same random
    decisions and will therefore be deterministic, independent of batch
    packing or chunk splitting.

    Args:
      params: Model weights.
      generate_fn: Cached generation fn.
      prefix: Already-processed tokens in the prefix, if any.
      sample_ids: Per-sample random seeds to use for sampling. By convention,
        you should generally number these sequentially starting from 0.
        int32[num_samples]
      sample_params: Sampling parameters.

    Returns:
      The generated text, together with its processed results.
    """
    with self.mesh, self.rules:
      if prefix:
        prev_chunk_next_token_logits = prefix[-1].next_token_logits
      else:
        prev_chunk_next_token_logits = None

      cache = [p.kv_cache for p in prefix]

      return generate_fn(
          params=params,
          prefix=cache,
          prev_chunk_next_token_logits=prev_chunk_next_token_logits,
          sample_ids=sample_ids,
          sample_params=sample_params,
      )