# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Support for incremental processing on Transformers.

We provide two functions, `prefill` and `generate`, which operate on the
`Chunk` and `ChunkResult` types from chunk.py.
* The function `prefill`, sometimes also called 'encode' in Transformer
  codebases, runs a single forwards pass over a batch of input sequences,
  returning scores and KV caches for those tokens.
* The function `generate`, sometimes also called 'decode' in Transformer
  codebases, generates new text autoregressively, in a sequential loop that
  generates one token at a time per sequence.

Example use cases follow. Each example builds upon the previous.

Example 1: scoring some text
============================

We create a `Chunk` of input text and then run `prefill` on it.

```
jitted_model = JittedModel(...)

# Create a batch=1 input chunk of text.
few_shot_examples = Chunk.tokenize(
    vocab, ["Cows have 4 legs. Fish have 0 legs."], is_first_chunk=True)
few_shot_examples_result = jitted_model.prefill(params, [], few_shot_examples)
print(few_shot_examples_result.per_token_scores)
```

Example 2: generating text using the prompt
===========================================

We use the `few_shot_examples_result` from the previous example as attention
context (the KV cache) from which we generate new text.

```
# Controls random sampling
my_sampling = Sampling(temperature=0.7)
# 4 random seeds, so that we generate 4 different outputs.
sample_ids = jnp.arange(4, dtype=jnp.int32)
generated_text, generated_text_result = jitted_model.generate(
    params, my_sampling, [few_shot_examples_result], sample_ids)
# Print all 4 samples
for s in generated_text.detokenize(vocab):
  print(s)
```

Example 3: Multiple prompts sharing a common prefix
===================================================

In a few-shot-prompted scenario, we typically have a common prefix (the few-shot
prompts), shared over a batch of tasks, and for each task we generate multiple
samples. By splitting each of these steps into its own `prefill` or `generate`
call, we can do this in a way that maximally exploits the sharing.

In Example 1 we already called `prefill` on a single shared sequence containing
the few-shot examples. Next we call `prefill` on the batch of tasks, using the
few-shot examples as attention context. It is permissible to have more tasks
than few-shot examples, as we demonstrate here:

```
# Prefill a batch=3 set of tasks.
tasks = Chunk.tokenize(vocab, ["Humans have", "Potatoes have", "Dinosaurs have"])
tasks_result = jitted_model.prefill(params, [few_shot_examples_result], tasks)
# Generate 2 samples for each task. This sums to 6 samples in total.
sample_ids = jnp.arange(6, dtype=jnp.int32)
task_samples, task_samples_results = jitted_model.generate(
    params, my_sampling, [few_shot_examples_result, tasks_result], sample_ids)
```

Example 4: appending even more text, and then generating some more
==================================================================

If we were in a chatbot scenario, at this point we might append some more
user-provided text to the context, and then generate yet another response. This
consists of another call to `prefill` followed by another call to `generate`.
As this example shows, the two can be combined in any fashion.

```
# Add the user response, using `prefill`.
user_response = Chunk.tokenize(vocab, ["How many legs does a chicken have?"])
user_response_result = jitted_model.prefill(
    params, [few_shot_examples_result, generated_text_result], user_response)
# Generate another AI response, using `generate`.
ai_response_text, ai_response_result = jitted_model.generate(
    params, my_sampling,
    [few_shot_examples_result, generated_text_result, user_response_result],
    sample_ids
)
# Print all the samples
for s in ai_response_text.detokenize(vocab):
  print(s)
```

TODO(reinerp): Example 4 uses an ever-increasing list of ChunkResults as
context arguments. In a more realistic chatbot scenario we would concatenate all
the ChunkResults into a single longer ChunkResult, subject to batch size
restrictions.
"""

from dataclasses import dataclass  # pylint: disable=g-importing-member
from functools import partial  # pylint: disable=g-importing-member
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import jax
from jax import lax
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
import jax.scipy
from jax.sharding import Mesh
import numpy as np
from seqio.vocabularies import Vocabulary

from scaling_transformer_inference_efficiency import attention
from scaling_transformer_inference_efficiency import checkpoint
from scaling_transformer_inference_efficiency import collectives
from scaling_transformer_inference_efficiency import partitioning
from scaling_transformer_inference_efficiency import weights
from scaling_transformer_inference_efficiency.chunk import Chunk
from scaling_transformer_inference_efficiency.chunk import ChunkResult
from scaling_transformer_inference_efficiency.chunk import FullChunkResult
from scaling_transformer_inference_efficiency.chunk import InferFn
from scaling_transformer_inference_efficiency.sampling import SampleFn
from scaling_transformer_inference_efficiency.sampling import SamplingHyperParams

Weights = weights.Weights
P = jax.sharding.PartitionSpec


# pylint: disable = g-bare-generic
@dataclass
class StreamClient:
  """Used to handle streaming results."""

  prev_token_decoded: Optional[jnp.ndarray] = None
  prev_token: Optional[jnp.ndarray] = None
  stream_callback: Callable = lambda x: print(x, end='')
  stream_done_callback: Callable = lambda: None

  def find_new_chars(self, vocab: Vocabulary, next_token: np.ndarray):
    """We decode pairs because the tokenizer strips whitespace."""
    prefix = self.prev_token_decoded
    whole = (
        vocab.decode_tf(np.concatenate([self.prev_token, next_token], -1))
        .numpy()
        .decode('utf-8')
    )
    new_text = whole[len(prefix) :]
    return new_text

  def stream_result(
      self, logits: jax.Array, vocab: Vocabulary, x: int, y: int, z: int
  ):
    """Stream result back to stdout. For now, only stream the first element."""

    if x == 0 and y == 0 and z == 0:
      logits = np.array(logits)
      current_token = np.array(logits[0:1])
      if self.prev_token is None:
        new_chars = vocab.decode_tf(current_token).numpy().decode('utf-8')
      else:
        new_chars = self.find_new_chars(vocab, current_token)

      self.stream_callback(new_chars)
      self.prev_token = current_token  # pytype: disable=annotation-type-mismatch  # jax-ndarray
      self.prev_token_decoded = new_chars.lstrip(' ').rstrip(' ')

  def clear_prev_token(self):
    self.prev_token = None
    self.stream_done_callback()


def _bos_logits(vocab_size: int, bos_id: int = 0) -> jnp.ndarray:
  """Logits that assign probability 1.0 to `bos_id`."""
  logits = jnp.full((vocab_size,), -1e10)
  return logits.at[bos_id].set(0.0)


class InferenceModel:
  """A model with xmapped JIT-compiled prefill and generate functions."""

  def __init__(
      self,
      hparams: checkpoint.HParams,
      eos_id: int,
      infer_fn: InferFn,
      sample_fn: SampleFn,
      mesh: Mesh,
      rules: Sequence[Tuple[str, Any]],
      vocab: Optional[Vocabulary] = None,
      bos_id: Optional[int] = None,  # Allows overriding the default value.
  ):
    self._hparams = hparams
    self._eos_id = eos_id
    self._infer = infer_fn
    self._sample = sample_fn
    self.mesh = mesh
    self.rules = partitioning.PartitioningRules(rules)
    with self.rules:
      self.sample_ids_sharding = partitioning.logical_to_physical(
          P('logit_batch')
      )
      self.embeddings_logical = P(
          'residual_batch', 'residual_time', 'residual_embed'
      )
      self.embeddings_sharding = jax.tree_map(
          partitioning.logical_to_physical, self.embeddings_logical
      )
    self.vocab = vocab
    if bos_id is None:
      if vocab is not None:
        bos_id = vocab.bos_id
      else:
        bos_id = 0
    self.bos_id = bos_id
    # _prefill_p: maps num_prefixes -> jitted _prefill_impl function
    self._prefill_p = {}
    # _generate_p: maps num_prefixes -> jitted _generate_impl function
    self._generate_p = {}

  def rotate_weights(self, params: Weights, latency: bool = True) -> Weights:
    """Rotate the weights for the collectives.

    Assumed to occur in a per-device form. Assumes 2D partitioning.
    q_wi: [layers, heads.YZ, dmodel.X, q_wi_per_head]
    o_wo: [layers, heads.YZ, owo_per_head, dmodel.X]

    Args:
      params: Unmodified model weights.
      latency: Whether to use the latency collectives.

    Returns:
      params: New parameters, rotated for the given collective.
    """

    def rotate(params):
      new_layer = params.layer
      if latency:
        new_layer = new_layer.replace(
            q_wi=collectives.preshuffle_for_reducescatter_latency(
                new_layer.q_wi, scatter_axis=1, axis_name='x'
            )
        )
        new_layer = new_layer.replace(
            o_wo=collectives.preshuffle_for_allgather_matmul_latency(
                new_layer.o_wo, shuffle_axis=1, axis_name='x'
            )
        )
      else:
        new_layer = new_layer.replace(
            q_wi=collectives.preshuffle_for_reducescatter_throughput(
                new_layer.q_wi, scatter_axis=1, subsplit_axis=3, axis_name='x'
            )
        )
        new_layer = new_layer.replace(
            o_wo=collectives.preshuffle_for_allgather_matmul_throughput(
                new_layer.o_wo, shuffle_axis=1, axis_name='x'
            )
        )

      return params.replace(layer=new_layer)

    with self.mesh, self.rules:
      params = jax.jit(
          shard_map(
              rotate,
              self.mesh,
              in_specs=(params.physical_axes(),),
              out_specs=params.physical_axes(),
              check_rep=False,
          ),
          donate_argnums=(0,),
      )(params)

      return params

  # pylint: disable = g-bare-generic
  # pylint: disable = protected-access
  @staticmethod
  def _prefill_impl(
      model,
      params: Weights,
      cache: Sequence[ChunkResult],
      chunk: Chunk,
      prev_logits: Optional[jnp.ndarray],
      pre_embedded_inputs: Optional[jax.Array] = None,
      return_full_chunk: bool = False,
  ) -> Union[ChunkResult, FullChunkResult]:
    """Wrap both prefill and results formatting in a single xmap call."""
    if pre_embedded_inputs is not None:
      full_chunk_result = model._infer(
          params, cache, chunk, pre_embedded_inputs=pre_embedded_inputs
      )
    else:
      full_chunk_result = model._infer(params, cache, chunk)
    if return_full_chunk:
      return full_chunk_result
    else:
      return full_chunk_result.to_chunk_result(
          prev_logits, chunk, bos_id=model.bos_id
      )

  def instantiate_prefill_fn(self, return_full_chunk: bool = False):
    return partial(
        self._prefill_impl,
        self,
        return_full_chunk=return_full_chunk,
    )

  def prefill(
      self,
      params: Weights,
      prefill_impl: Callable,
      prefix: Sequence[ChunkResult],
      chunk: Chunk,
      pre_embedded_inputs: Optional[jax.Array] = None,
  ) -> Union[ChunkResult, FullChunkResult]:
    """Non-generative inference for a batch.

    Args:
      params: Model weights.
      prefill_impl: Partially applied `_prefill_impl`, as returned by
        `instantiate_prefill_fn`.
      prefix: Already-processed tokens in the prefix, if any.
      chunk: The tokens to prefill.
      pre_embedded_inputs: Optionally, inputs that have already been embedded
        outside this call (e.g. for prompt tuning).

    Returns:
      Scores for the batch.
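
    Example (sketch; assumes `model` is an `InferenceModel`, with `params` and
    a seqio `vocab` already loaded):

    ```
    prefill_fn = model.instantiate_prefill_fn()
    chunk = Chunk.tokenize(vocab, ["Cows have 4 legs."], is_first_chunk=True)
    prefix_result = model.prefill(params, prefill_fn, [], chunk)
    ```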
    """
    with self.mesh, self.rules:
      if prefix:
        prev_logits = prefix[-1].next_token_logits
      else:
        prev_logits = None

      prefix = [p.kv_cache for p in prefix]
      return jax.jit(prefill_impl)(
          params, prefix, chunk, prev_logits, pre_embedded_inputs
      )

  @staticmethod
  def create_output_buffer(
      hparams: checkpoint.HParams,
      sample_ids: jnp.ndarray,
      prefix: List[attention.KVCache],
      length: int,
      prev_chunk_next_token_logits: Optional[jnp.ndarray] = None,
      circular: bool = False,
      bos_id: int = 0,
  ):
    """Create everything we need to deterministically write output samples."""
    # Seeding of the RNG itself is deterministic.
    # To generate different samples, users provide different `sample_ids`.
    (batch,) = sample_ids.shape
    sample_rngs = jax.vmap(jax.random.fold_in, in_axes=(None, 0))(  # pytype: disable=wrong-arg-types  # jax-ndarray
        jax.random.PRNGKey(0), sample_ids
    )
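    # sample_rngs: one independent PRNG key per sample id (shape [batch, 2]
    # with the default threefry key representation).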
    token_indexes_start = attention.prefix_lengths(prefix)
    token_indexes_start = attention.flat_broadcast(token_indexes_start, batch)

    # Generation loop.
    last_logits = prev_chunk_next_token_logits
    if last_logits is None:
      last_logits = _bos_logits(hparams.vocab, bos_id)[np.newaxis, :]
    last_logits = attention.flat_broadcast(last_logits, batch)
    chunk = Chunk.zeros(batch, length)
    chunk_result = ChunkResult.zeros(hparams, batch, length, circular=circular)
    chunk_result = chunk_result.replace(next_token_logits=last_logits)
    return sample_rngs, token_indexes_start, chunk, chunk_result

  @staticmethod
  def sample_infer_write(
      model,
      params: Weights,
      prefix: List[attention.KVCache],
      sample_params: SamplingHyperParams,
      token_indexes_start: jnp.ndarray,
      sample_rngs: jnp.ndarray,
      write_index: int,
      state: Tuple[Chunk, ChunkResult],
  ):
    """Samples from the previous step's logits, infers, and writes to cache.

    We sample first and then do the next inference step because we already have
    pre-existing logits from prefill, so this saves us a step. Additionally, it
    lowers better.
    We have two different chunks at play in this loop:
    1. `chunk`/`chunk_result`, of shape (batch, steps). This is the
       mutable state that we fill one token at a time, starting at step=0
       and going up to step=steps-1.
    2. `token_chunk`/`token_chunk_result`/`token_full_chunk_result`, of
       shape (batch, 1). These are the input/output of a single forwards
       pass through the model, which processes just one chunk per batch
       element.

    We write token_* into chunk/chunk_result at a new position each loop
    iteration.

    Args:
      model: InferenceModel
      params: Model weights
      prefix: Pre-existing KV cache.
      sample_params: Temperature etc.
      token_indexes_start: Per-element token index in the full sequence, for
        fully deterministic sampling.
      sample_rngs: RNG per batch element.
      write_index: Current index in the state. Always the same for all elements
        so that it is a slice, not a scatter.
      state: chunk / chunk_result pair described above. May be a circular buffer.

    Returns:
      chunk: Written to (for tokens).
      chunk_result: Written to (for KV cache).
    """
    batch, _ = sample_rngs.shape
    chunk, chunk_result = state
    step_rngs = jax.vmap(jax.random.fold_in)(  # pytype: disable=wrong-arg-types  # jax-ndarray
        sample_rngs, token_indexes_start + chunk.lengths
    )
    next_token = model._sample(
        chunk_result.next_token_logits, step_rngs, sample_params, model.mesh,
    )
    # ^ next_token: [batch]
    token_chunk = Chunk(
        tokens=next_token[:, np.newaxis],
        lengths=jnp.full((batch,), 1, jnp.int32),
    )
    # ^ token_chunk: Chunk[batch, 1]

    token_full_chunk_result = model._infer(
        params, prefix + [chunk_result.kv_cache], token_chunk
    )
    chunk = chunk.update(write_index, token_chunk)
    chunk_result = chunk_result.update(
        write_index,
        token_chunk,
        token_full_chunk_result,
        bos_id=model.bos_id,
    )
    return chunk, chunk_result

  # pylint: disable = protected-access
  # pylint: disable = g-long-lambda
  # pylint: disable = unnecessary-lambda
  # pytype: disable=attribute-error
  # pytype: disable=bad-unpacking
  @partial(jax.jit, static_argnums=(0, 5, 6, 7, 8))
  def _generate_impl(
      self,
      params: Weights,
      prefix: List[attention.KVCache],
      prev_chunk_next_token_logits: jnp.ndarray,
      sample_ids: jnp.ndarray,
      sample_params: SamplingHyperParams,
      steps: int,
      stream: Optional[StreamClient] = None,
      return_all_logits=False,
  ) -> Union[Tuple[Chunk, ChunkResult], Tuple[Chunk, ChunkResult, jax.Array]]:
    """Generates a chunk of text, given some prefixes. Jitted function."""
    del return_all_logits  # TODO(sholto): For all logit scoring
    (batch,) = sample_ids.shape
    del stream  # TODO(sholto): reimplement once shardmap callback is done

    sample_rngs, token_indexes_start, chunk, chunk_result = (
        self.create_output_buffer(
            self._hparams,
            sample_ids,
            prefix,
            steps,
            prev_chunk_next_token_logits,
            circular=False,
            bos_id=self.bos_id,
        )
    )

    loop_body = partial(
        self.sample_infer_write,
        self,
        params,
        prefix,
        sample_params,
        token_indexes_start,
        sample_rngs,
    )

    chunk, chunk_result = lax.fori_loop(
        0, steps, loop_body, (chunk, chunk_result)
    )

    # The length of the chunk is the index of the first EOS ID. We don't
    # calculate this during the generation loop, so instead we calculate it now.
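    # For example, a row of tokens [5, 9, eos, 7] with steps=4 gets length 2:
    # non-EOS positions contribute `steps` to the min below, so the result is
    # the index of the first EOS, or `steps` if no EOS was generated.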
    is_eos = chunk.tokens == self._eos_id
    token_iota = lax.broadcasted_iota(jnp.int32, (batch, steps), 1)
    chunk = chunk.replace(
        lengths=jnp.min(jnp.where(is_eos, token_iota, steps), axis=1)
    )

    return chunk, chunk_result

  def instantiate_generating_fn(
      self,
      steps: int,
      stream: Optional[StreamClient] = None,
      return_all_logits=False,
  ) -> Callable:  # pylint: disable = g-bare-generic
    """Create partial fn to ensure caching."""

    return partial(
        self._generate_impl,
        steps=steps,
        stream=stream,
        return_all_logits=return_all_logits,
    )

  def generate(
      self,
      params: Weights,
      generate_fn: Callable,  # pylint: disable = g-bare-generic
      prefix: Sequence[ChunkResult],
      sample_ids: jnp.ndarray,
      sample_params: SamplingHyperParams,
  ) -> Tuple[Chunk, ChunkResult]:
    """Generative inference for a batch.

    Note about random number seeding:
    We provide strong guarantees about random number seeding, to make it
    possible for callers to get deterministic results that are independent of
    batch packing and independent of splitting into `Chunk`s for incremental
    processing. Specifically, we guarantee that random numbers are constructed
    by:

    ```
    def rng_for_token(sample_id: int, token_index: int) -> jax.Array:
      rng = jax.random.PRNGKey(0)
      rng = jax.random.fold_in(rng, sample_id)
      rng = jax.random.fold_in(rng, token_index)
      return rng
    ```

    Here, `sample_id` is taken from the `sample_ids` array provided by the
    caller, and `token_index` is the number of non-padding tokens in this sample
    prior to the token currently being generated. This scheme ensures that any
    text generated with the same `sample_id`, from the same prefix, using the
    same sampling hyperparameters, will make the same random decisions and will
    therefore be deterministic, independent of batch packing or chunk splitting.

    Args:
      params: Model weights.
      generate_fn: Cached generation function, as returned by
        `instantiate_generating_fn`.
      prefix: Already-processed tokens in the prefix, if any.
      sample_ids: Per-sample random seeds to use for sampling. By convention,
        you should generally number these sequentially starting from 0.
        int32[num_samples]
      sample_params: Sampling parameters.

    Returns:
      The generated text, together with its processed results.
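
    Example (sketch; continues the `prefill` docstring example, with
    `sample_params` a `SamplingHyperParams` built elsewhere):

    ```
    generate_fn = model.instantiate_generating_fn(steps=32)
    sample_ids = jnp.arange(4, dtype=jnp.int32)
    samples, samples_result = model.generate(
        params, generate_fn, [prefix_result], sample_ids, sample_params)
    ```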
    """
    with self.mesh, self.rules:
      if prefix:
        prev_chunk_next_token_logits = prefix[-1].next_token_logits
      else:
        prev_chunk_next_token_logits = None

      cache = [p.kv_cache for p in prefix]

      return generate_fn(
          params=params,
          prefix=cache,
          prev_chunk_next_token_logits=prev_chunk_next_token_logits,
          sample_ids=sample_ids,
          sample_params=sample_params,
      )