# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""T5X decoding routine for arithmetic sampling."""
import functools

from typing import Any, Callable, Mapping, Optional, Tuple, Union
import flax
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import numpy as np
from t5x import decoding

# Constants
# "Effective negative infinity" constant for masking in beam search.
NEG_INF = np.array(-1.0e7)

# Temperatures lower than this are considered 0.0, which is handled specially
# with a conditional. This is to avoid numeric issues from exponentiating on
# 1.0/temperature when temperature is close to 0.0.
MIN_TEMPERATURE = np.array(1e-4)

#------------------------------------------------------------------------------
# Arithmetic Sampling
#------------------------------------------------------------------------------


@flax.struct.dataclass
class ArithmeticSamplingLoopState:
  """Holds sampling state data.

  Attributes:
    cur_index: [batch_size] array position of the sampling loop in the length
      dimension.
    sequences: [batch_size * num_decodes, max_decode_len] array of current
      sampled sequence prefixes.
    cache: any mapping of arrays, e.g. flax attention cache.
    cur_token: [batch_size, num_decodes] single timestep slice containing
      current tokens.
    ended: [batch_size, num_decodes] binary array marking completed sequences.
    rng: JAX PRNGKey.
    log_prob: [batch_size, num_decodes] array of log probs for each sequence.
    codes: [batch_size, num_decodes] array containing the arithmetic codes for
      the remainder of the sequence at the current time step for each sample.
  """
  cur_index: jnp.ndarray
  sequences: jnp.ndarray
  cache: Mapping[str, jnp.ndarray]
  cur_token: jnp.ndarray
  ended: jnp.ndarray
  rng: jnp.ndarray
  log_prob: jnp.ndarray
  codes: jnp.ndarray
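

# Illustrative sketch (added for exposition; not part of the original
# module): constructing a loop state the way the sampling loop does, with the
# batch and decode dimensions already flattened to batch_size * num_decodes
# (here 2 * 3 = 6) and an empty dict standing in for a real attention cache.
def _example_loop_state():
  flat_batch, max_decode_len = 2 * 3, 8
  return ArithmeticSamplingLoopState(
      cur_index=jnp.zeros((flat_batch,), jnp.int32),
      sequences=jnp.zeros((flat_batch, max_decode_len), jnp.int32),
      cache={},
      cur_token=jnp.zeros((flat_batch, 1), jnp.int32),
      ended=jnp.zeros((flat_batch, 1), jnp.bool_),
      rng=jax.random.PRNGKey(0),
      log_prob=jnp.zeros((flat_batch,), jnp.float32),
      codes=jnp.zeros((flat_batch,), jnp.float32))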


_dynamic_update_vector_slice_in_dim = jax.vmap(
    lax.dynamic_update_slice_in_dim, in_axes=(0, 0, 0, None))


def _is_tracer(value: Any):
  return isinstance(value, jax.core.Tracer)


def _sequential_cumsum(arr: jnp.ndarray, axis: int) -> jnp.ndarray:
  """Sequential scan-based implementation of cumulative sum for Jax.

  The Jax implementation of cumulative sum does not guarantee that the output
  array is nondecreasing when applied to nonnegative inputs. This breaks
  the use of cumulative sum for bucketing. Using scan forces the sum to
  happen sequentially, which avoids the floating point reordering that
  causes the normal Jax cumsum to exhibit this bad behavior.

  Args:
    arr: Jax array to sum.
    axis: axis to sum over.

  Returns:
    Jax array of partial cumulative sums.
  """

  # Swap axes so that the axis to be scanned over is the leading axis.
  xs = jnp.swapaxes(arr, 0, axis)
  init_carry = jnp.zeros(xs.shape[1:], xs.dtype)
  _, res = jax.lax.scan(lambda c, x: (c + x, c + x), init_carry, xs)
  return jnp.swapaxes(res, 0, axis)
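

# Illustrative sketch (added for exposition; not part of the original
# module): for nonnegative inputs each partial sum adds a nonnegative term,
# so the scan-based result is nondecreasing along the summed axis and can
# safely serve as bucket boundaries.
def _example_sequential_cumsum():
  probs = jax.nn.softmax(jnp.array([[0.1, 2.0, -1.0, 0.5]]), axis=1)
  bounds = _sequential_cumsum(probs, axis=1)
  # Consecutive differences recover the (nonnegative) probabilities.
  assert bool(jnp.all(jnp.diff(bounds, axis=1) >= 0.0))
  return bounds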


def _arithmetic_categorical(
    rng: jnp.ndarray, logits: jnp.ndarray,
    codes: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Sample from a categorical using arithmetic sampling.

  Returns samples from an arithmetic codebook based on provided codes. This
  gives an unbiased sample for each code randomly picked from the unit
  interval.

  Args:
    rng: JAX PRNGKey.
    logits: array: [batch_size, vocab_size] float32 sequence of logits.
    codes: array: [batch_size] float32 codes for each batch element.

  Returns:
    A tuple (samples, new_codes) where `samples` are sampled indices with shape
    [batch_size], and `new_codes` are shape [batch_size] containing codes for
    the remaining suffix if doing ancestral sampling.
  """
  # We randomly permute the logits at each timestep to avoid depending on
  # the default order of the vocabulary. This isn't strictly necessary.
  # We need to invert this permutation at the end because it changes the
  # identities of the sampled indices.
  _, vocab_size = logits.shape
  perm = jax.random.permutation(rng, vocab_size)
  invperm = jnp.argsort(perm)

  logits = logits[:, perm]

  # Now we want to, for each element in the batch, get the normalized
  # probabilities, stack them in the unit interval into buckets, and figure
  # out which bucket the code falls into.
  probs = jax.nn.softmax(logits, axis=1)

  # Use the sequential scan-based cumsum to guarantee a nondecreasing array
  # of partial sums.
  cumprobs = _sequential_cumsum(probs, axis=1)

  # Because of precision, make sure the max value (and everything with that
  # value, to not change bucket widths) is at least 1.0.
  max_probs = jnp.expand_dims(jnp.max(cumprobs, axis=1), 1)
  all_bucket_maxes = jnp.where((cumprobs == max_probs) & (cumprobs < 1.0), 1.0,
                               cumprobs)

  # Now the cumulative probabilities represent the max value of each of the
  # buckets. So let's make a mask of all the buckets whose maxes are less
  # than and greater than the given codes.
  expanded_codes = jnp.expand_dims(codes, axis=1)
  bucket_maxes_lte_codes = all_bucket_maxes <= expanded_codes
  bucket_maxes_gt_codes = all_bucket_maxes > expanded_codes

  # Pick the minimum value for the bucket for the code. Note this will be
  # 0.0 if the code falls into the zeroth bucket, as desired.
  code_bucket_mins = jnp.max(all_bucket_maxes * bucket_maxes_lte_codes, axis=1)

  # We have to do some masking here, and for probabilities, anything > 1.0
  # is as good as infinity.
  prob_infty = 1.1
  # Pick the maximum value for the bucket, the first bucket whose max is
  # greater than the code.
  code_bucket_maxes = jnp.min(
      all_bucket_maxes * bucket_maxes_gt_codes +
      bucket_maxes_lte_codes * prob_infty,
      axis=1)
  # We have to take the argmin before inverting the permutation,
  # otherwise it messes up the default tie-breaking behavior for size-zero
  # buckets (take the lowest index).
  sampled_indices_permed = jnp.argmin(
      (all_bucket_maxes * bucket_maxes_gt_codes +
       bucket_maxes_lte_codes * prob_infty),
      axis=1)
  sampled_indices = jnp.argmax(
      jax.nn.one_hot(sampled_indices_permed, vocab_size)[:, invperm], axis=1)

  remainder_codes = (codes - code_bucket_mins) / (
      code_bucket_maxes - code_bucket_mins)

  samples = sampled_indices
  new_codes = remainder_codes

  return samples, new_codes
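

# Illustrative sketch (added for exposition; not part of the original
# module): ignoring the random permutation, probabilities [0.2, 0.5, 0.3]
# stack into buckets [0, 0.2), [0.2, 0.7), [0.7, 1.0); a code of 0.75 lands
# in the third bucket, and the remainder code rescales 0.75 within that
# bucket. The permutation reshuffles the bucket layout but not the
# unbiasedness of the draw.
def _example_arithmetic_categorical():
  rng = jax.random.PRNGKey(0)
  logits = jnp.log(jnp.array([[0.2, 0.5, 0.3]]))
  codes = jnp.array([0.75])
  sample, remainder = _arithmetic_categorical(rng, logits, codes)
  # `sample` is the vocabulary index of the bucket containing the code;
  # `remainder` is (code - bucket_min) / (bucket_max - bucket_min), reused
  # as the code for the next ancestral sampling step.
  return sample, remainder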


def arithmetic_sample(
    inputs: jnp.ndarray,
    cache: Mapping[str, jnp.ndarray],
    tokens_to_logits: Callable[[decoding.DecodingState],
                               Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]],
    eos_id: int,
    decode_rng: Optional[jnp.ndarray] = None,
    num_decodes: int = 1,
    temperature: Union[float, jnp.ndarray] = 1.0,
    topk: int = 1,
    topp: float = 0.0,
    cache_offset: int = 0,
    initial_index: Optional[jnp.ndarray] = None,
    max_decode_steps: Optional[Union[int, jnp.ndarray]] = None,
    max_decode_steps_hard_limit: Optional[int] = None,
    rescale_log_probs: bool = True,
    state_callback_fn: Optional[Callable[[ArithmeticSamplingLoopState],
                                         ArithmeticSamplingLoopState]] = None,
    logit_callback_fn: Optional[Callable[
        [jnp.ndarray, ArithmeticSamplingLoopState], jnp.ndarray]] = None
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Arithmetic sampling for language model generation.

  The sampling is performed `num_decodes` times in a vectorized
  manner by expanding the batch dimension. This is similar to how beam search
  expands the batch dimension to process each batch element with multiple beams.

  Args:
    inputs: array: [batch_size, max_decode_len] int32 sequence of tokens.
    cache: flax attention cache.
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and cache and returning next-token logits and updated cache.
    eos_id: int: end-of-sentence token for target vocabulary.
    decode_rng: JAX PRNGKey.
    num_decodes: number of decoded sequences to be returned.
    temperature: float: sampling temperature factor. As it approaches zero this
      becomes equivalent to greedy sampling.
    topk: integer: if nonzero only use the top-k logits to sample next token, if
      zero don't use any cutoff and sample from full logits over vocabulary.
    topp: float: if nonzero only use the smallest number of logits whose
      cumulative sum of probs adds up to (at least) topp. Will raise ValueError
      if it's nonzero when topk is nonzero.
    cache_offset: axis offset for cache, arising from scanned layers.
    initial_index: Optional[array]: [batch_size] int32 a vector of loop indexes
      to start decoding at.
    max_decode_steps: int: an optional maximum number of decoding steps. If
      None, it will decode until the full input shape `inputs.shape[1]` is
      filled. max_decode_steps begins counting after the prompt, so it will
      decode at most len(prompt) + max_decode_steps tokens.
    max_decode_steps_hard_limit: int: an optional fixed hard limit on
      max_decode_steps. If this is set (not None and > 0), and max_decode_steps
      is also set, then max_decode_steps will be clipped to this limit. The
      value max_decode_steps can be an ndarray, but max_decode_steps_hard_limit
      must be a Python integer or None.
    rescale_log_probs: bool: whether to apply temperature, topp, and topk
      rescaling to the log probs which are returned. If True, the log_probs will
      include these transformations (for example, with topk=1, all log_probs
      will be identically 0.0). If False, the log_probs will not be affected,
      and topk/topp/temperature will not affect sequence probabilities.
    state_callback_fn: Function that modifies the sampling loop state before
      each step. This can be used to manipulate any part of the state either on
      the accelerator or on the host using host callback. The function should
      take an ArithmeticSamplingLoopState as argument and return the updated
      state. See `decoding_test.py` for an example usage.
    logit_callback_fn: Function that modifies the logits before each temperature
      sampling step. The function should take arguments (logits, state) and
      return the modified logits. See `decoding_test.py` for an example usage.

  Returns:
    A tuple (decodes, log_prob) where `decodes` is sampled sequences with shape
    [batch_size, num_decodes, max_decode_len] sorted by `log_prob`, the log
    probability of each of the sampled sequences.
  """
  if decode_rng is None:
    decode_rng = jax.random.PRNGKey(0)

  if (max_decode_steps_hard_limit is not None and
      max_decode_steps_hard_limit > 0 and max_decode_steps is not None):
    max_decode_steps = jnp.minimum(max_decode_steps,
                                   max_decode_steps_hard_limit)

  initial_codes = _make_default_codes(inputs.shape[0], num_decodes, decode_rng)
  flattened_codes = decoding.flatten_beam_dim(initial_codes)

  # [batch, len] -> [batch * num_decodes, len]
  expanded_inputs = decoding.flat_batch_beam_expand(inputs, num_decodes)
  expanded_cache = decoding.cache_map(
      functools.partial(
          decoding.flat_batch_beam_expand,
          beam_size=num_decodes,
          offset=cache_offset),
      cache,
      # When we start with a prefilled cache, the cache index is no longer a
      # scalar that will broadcast across multiple decodes, it is a vector and
      # needs to be updated to handle the multiple decodes.
      apply_to_index=initial_index is not None)
  if initial_index is not None:
    initial_index = decoding.flat_batch_beam_expand(initial_index, num_decodes)

  # expanded_decodes: [batch * num_decodes, len]
  # expanded_log_prob: [batch * num_decodes]
  expanded_decodes, expanded_log_prob = _arithmetic_sample_single_trial(
      expanded_inputs,
      flattened_codes,
      expanded_cache,
      tokens_to_logits,
      eos_id,
      decode_rng,
      temperature,
      topk,
      topp,
      initial_index=initial_index,
      max_decode_steps=max_decode_steps,
      rescale_log_probs=rescale_log_probs,
      state_callback_fn=state_callback_fn,
      logit_callback_fn=logit_callback_fn)

  batch_size = inputs.shape[0]
  # [batch * num_decodes, len] -> [batch, num_decodes, len]
  decodes = decoding.unflatten_beam_dim(expanded_decodes, batch_size,
                                        num_decodes)
  # [batch * num_decodes] -> [batch, num_decodes]
  log_prob = decoding.unflatten_beam_dim(expanded_log_prob, batch_size,
                                         num_decodes)

  # Sort `decodes` and `log_prob` by increasing log probabilities of the
  # sampled sequence.
  # [batch, num_decodes, 1]
  idxs = jnp.expand_dims(jnp.argsort(log_prob, axis=-1), axis=-1)

  # Returns [batch, num_decodes, len], [batch, num_decodes] in sorted order.
  return jnp.take_along_axis(
      decodes, idxs, axis=1), jnp.take_along_axis(
          log_prob, jnp.squeeze(idxs, axis=-1), axis=-1)
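

# Hypothetical usage sketch (added for exposition; not part of the original
# module). `fake_tokens_to_logits` is a stand-in for a real T5X model's
# decode step: it returns uniform next-token logits and passes the (empty)
# cache through unchanged. Real callers would supply a model's
# tokens-to-logits function and its prefilled attention cache.
def _example_arithmetic_sample():
  vocab_size, batch_size, max_decode_len = 5, 2, 8

  def fake_tokens_to_logits(decoding_state):
    # [batch * num_decodes, vocab] uniform logits; cache is passed through.
    logits = jnp.zeros((decoding_state.cur_token.shape[0], vocab_size))
    return logits, decoding_state.cache

  inputs = jnp.zeros((batch_size, max_decode_len), jnp.int32)
  decodes, log_prob = arithmetic_sample(
      inputs, {}, fake_tokens_to_logits, eos_id=1,
      decode_rng=jax.random.PRNGKey(0), num_decodes=4, topk=0)
  # decodes: [2, 4, 8] sequences sorted by log_prob: [2, 4].
  return decodes, log_prob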


def _make_default_codes(batch_size: int, num_decodes: int,
                        rng: jnp.ndarray) -> jnp.ndarray:
  """Make default codebook for a batch of `num_decodes` samples.

  The codes are initialized evenly spaced in the unit interval, with a random
  offset applied. This lets them evenly cover the sample space while also
  providing an unbiased estimate of any sample average.

  Args:
    batch_size: size of input batch.
    num_decodes: number of samples per batch element.
    rng: JAX PRNGKey.

  Returns:
    [batch_size, num_decodes] array of codes.
  """
  offset = jax.random.uniform(rng, (batch_size, 1))
  codes = jnp.tile(
      jnp.expand_dims(
          jnp.arange(1, num_decodes + 1, dtype=jnp.float32) / (num_decodes + 1),
          axis=0), (batch_size, 1))
  return jnp.mod(codes + offset, 1.0)
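

# Illustrative sketch (added for exposition; not part of the original
# module): for num_decodes=4 the base codes are [1/5, 2/5, 3/5, 4/5]; a
# single uniform offset per batch element shifts them modulo 1, so each
# code is marginally uniform while the set stays evenly spread.
def _example_make_default_codes():
  codes = _make_default_codes(batch_size=2, num_decodes=4,
                              rng=jax.random.PRNGKey(0))
  # codes: [2, 4] floats in [0, 1); within a row, consecutive (sorted)
  # codes are spaced by multiples of 1/5.
  return codes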


def _arithmetic_sample_single_trial(
    inputs: jnp.ndarray,
    initial_codes: jnp.ndarray,
    cache: Mapping[str, jnp.ndarray],
    tokens_to_logits: Callable[[decoding.DecodingState],
                               Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]],
    eos_id: int,
    prng_key: jnp.ndarray,
    temperature: Union[float, jnp.ndarray] = 1.0,
    topk: int = 20,
    topp: Union[float, jnp.ndarray] = 0.0,
    initial_index: Optional[jnp.ndarray] = None,
    max_decode_steps: Optional[Union[int, jnp.ndarray]] = None,
    rescale_log_probs: bool = True,
    state_callback_fn: Optional[Callable[[ArithmeticSamplingLoopState],
                                         ArithmeticSamplingLoopState]] = None,
    logit_callback_fn: Optional[Callable[
        [jnp.ndarray, ArithmeticSamplingLoopState], jnp.ndarray]] = None
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """A helper function for `arithmetic_sample`."""

  # We can check the values of topp and topk only if they are not dynamic.
  if not _is_tracer(topp) and topp and topk:
    raise ValueError('At most one of `topp` or `topk` may be non-zero.')

  batch_size, max_decode_len = inputs.shape

  if max_decode_steps is not None:
    # We can check the max_decode_steps bounds only if it is not dynamic.
    if not _is_tracer(max_decode_steps) and max_decode_steps > inputs.shape[1]:
      raise ValueError('Cannot decode more steps than the sequence length.')

    # The number of decode steps required to process the prefix is the number
    #   of non-zero tokens, since inputs[0] == 0 is the BOS token.
    # `max_decode_len[j]` is the number of non-padding tokens in the jth
    #   element of the returned sequences, capped at `len(inputs)`, assuming
    #   that early stopping doesn't occur. This is true with or without
    #   `max_decode_steps`.
    # When the while loop index `i` for the `j`th element reaches `i[j] =
    #   max_decode_len[j] - 1`, the generated token populates
    #   `sequences[:, i[j]+1]`. Since sequences[:, 0] is the BOS token, the
    #   generated token is the `max_decode_len[j]`th non-padding token, and
    #   hence the `j`th element has ended.
    max_decode_len = jnp.sum(inputs != 0, axis=1) + max_decode_steps
    max_decode_len = jnp.minimum(inputs.shape[1], max_decode_len)
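    # Worked example (added for exposition): with a row of `inputs` equal to
    # [0, 5, 6, 0, 0] (BOS plus a two-token prompt, padded to length 5) and
    # max_decode_steps = 2, the prefix contributes jnp.sum(inputs != 0) = 2
    # steps, so max_decode_len = min(5, 2 + 2) = 4 and positions 3 and 4 of
    # the sequence receive the newly sampled tokens.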

  # In the case of starting generation from a non-zero index, it is possible
  # for one batch element to reach `max_decode_len` number of decoding steps
  # before another. In order to let the last element decode all the way to
  # `max_decode_len` number of steps, we add a final garbage token to the end
  # of the sequences. Any element that has reached `max_decode_len` before the
  # rest of the elements will continually overwrite this token until all
  # elements finish.
  # [batch, length] -> [batch, length+2]
  extra_input_tokens = 2
  expanded_prompt_inputs = jnp.append(
      inputs,
      jnp.zeros((batch_size, extra_input_tokens), dtype=inputs.dtype),
      axis=1)
  end_marker = jnp.array(eos_id)

  temperature = jnp.asarray(temperature)

  # Initialize sampling loop state.
  # initial loop PRNGKey
  rng0 = prng_key

  if initial_index is None:
    # The per batch-item loop position counter.
    i0 = jnp.zeros((batch_size), dtype=jnp.int32)
    # The per batch-item holding current token in loop.
    token0 = jnp.zeros((batch_size, 1), dtype=jnp.int32)
  else:
    # The per batch-item loop position counter.
    i0 = initial_index
    # The per batch-item holding current token in loop.
    # Select the token that the initial index is pointing to.
    token0 = jnp.take_along_axis(
        expanded_prompt_inputs, jnp.expand_dims(i0, axis=1), axis=1)
  # Per batch-item state bit indicating if sentence has finished.
  ended0 = jnp.zeros((batch_size, 1), dtype=jnp.bool_)
  # (batch, length+2) array containing prefix prompt tokens for sampling loop
  # as well as the generated output of newly sampled tokens.
  sequences0 = expanded_prompt_inputs
  log_prob0 = jnp.zeros((batch_size,), dtype=jnp.float32)

  sampling_loop_init_state = ArithmeticSamplingLoopState(
      i0, sequences0, cache, token0, ended0, rng0, log_prob0, initial_codes)
  # Initial eos count to be used to determine whether eos is "generated". Many
  # inputs follow the format bos, inputs..., eos, targets..., eos. By counting
  # the number of eos tokens we can detect when a new one is added, instead of
  # just finding the one that probably ends the inputs.
  # [batch, 1]
  initial_eos_count = jnp.sum(sequences0 == end_marker, axis=-1, keepdims=True)

  def sampling_loop_cond_fn(state: ArithmeticSamplingLoopState) -> bool:
    """Sampling loop termination condition."""
    # Have all sampled sequences reached an end marker?
    # Different elements in the batch can be at different loop indices; if any
    # of our examples are not at the end, keep going.
    all_sequences_ended = jnp.all(state.ended)
    return ~all_sequences_ended  # pytype: disable=bad-return-type  # jnp-type

  def sampling_loop_body_fn(
      state: ArithmeticSamplingLoopState) -> ArithmeticSamplingLoopState:
    """Sampling loop state update."""

    if state_callback_fn is not None:
      state = state_callback_fn(state)

    # Split RNG for sampling.
    rng1, rng2 = random.split(state.rng)
    # Call fast-decoder model on current tokens to get next-position logits.
    decoding_state = decoding.DecodingState(
        cur_index=state.cur_index,
        sequences=state.sequences[:, :-extra_input_tokens],
        cur_token=state.cur_token,
        cache=state.cache)
    logits, new_cache = tokens_to_logits(decoding_state)
    # Sample next token from logits.

    if logit_callback_fn is not None:
      logits = logit_callback_fn(logits, state)

    def sample_logits_with_nonzero_temperature(logits):

      # Before setting up the arithmetic sampling, we preprocess the logits
      # into their final form.
      scaled_logits = logits / jnp.maximum(temperature, MIN_TEMPERATURE)
      if topk:
        # Get top-k logits and their indices, sample within these top-k tokens.
        topk_logits, _ = lax.top_k(scaled_logits, topk)
        cutoff_logit = topk_logits[:, -1, None]
        scaled_logits = jnp.where(scaled_logits < cutoff_logit,
                                  jnp.full_like(scaled_logits, NEG_INF),
                                  scaled_logits)

      # When topp is dynamic, we always use it since we cannot check
      # non-zeroness (but it will have no effect if topp is 0.0).
      if _is_tracer(topp) or topp:
        logits_sorted = jnp.sort(
            scaled_logits, axis=-1)[:, ::-1]  # sort descending
        sorted_cum_probs = jnp.cumsum(
            jax.nn.softmax(logits_sorted, axis=-1), axis=-1)
        cutoff_index = jnp.sum(sorted_cum_probs < topp, axis=-1, keepdims=True)
        cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1)
        scaled_logits = jnp.where(scaled_logits < cutoff_logit,
                                  jnp.full_like(scaled_logits, NEG_INF),
                                  scaled_logits)

      next_token, next_code = _arithmetic_categorical(rng1, scaled_logits,
                                                      state.codes)

      # Log probability of the current token conditioned on the previously
      # sampled and prefix tokens.
      # [batch, vocab] -> [batch, vocab]
      if rescale_log_probs:
        log_probs = jax.nn.log_softmax(scaled_logits)
      else:
        log_probs = jax.nn.log_softmax(logits)
      # [batch, vocab] -> [batch]
      next_log_prob = jnp.squeeze(
          jnp.take_along_axis(
              log_probs, jnp.expand_dims(next_token, axis=1), axis=-1),
          axis=-1)

      return (next_token, next_log_prob, next_code)

    def sample_logits_with_zero_temperature(logits):
      # For zero temperature, we always want the greedy output, regardless
      # of the values of topk and topp.

      next_token = jnp.argmax(logits, -1).astype(jnp.int32)

      if rescale_log_probs:
        next_log_prob = jnp.zeros_like(next_token, dtype=jnp.float32)
      else:
        log_probs = jax.nn.log_softmax(logits)
        next_log_prob = jnp.squeeze(
            jnp.take_along_axis(
                log_probs, jnp.expand_dims(next_token, axis=1), axis=-1),
            axis=-1)

      return (next_token, next_log_prob, state.codes)

    # Perform sampling with temperature.
    (next_token, next_log_prob,
     next_code) = lax.cond(temperature > MIN_TEMPERATURE,
                           sample_logits_with_nonzero_temperature,
                           sample_logits_with_zero_temperature, logits)

    # When different batch elements are at different points in the loop counter,
    # it is possible that an element that started at a higher index will reach
    # `max_decode_len` before other elements. When this happens we need to make
    # sure this element continually overwrites the garbage collection index.
    # Here we clamp `i` to `max_decode_len`. This will cause a write to
    # `max_decode_len + 1` which is the final index in `sequences`. Subsequent
    # loop body executions will also get their value clamped causing continual
    # overwriting of the final garbage position until all examples are finished.
    i = jnp.minimum(state.cur_index, max_decode_len)

    # Only use sampled tokens if we're past provided prefix tokens.
    # Select the next token from sequences.
    # [batch]
    next_input_token = jnp.squeeze(
        jnp.take_along_axis(
            state.sequences, jnp.expand_dims(i + 1, axis=1), axis=1),
        axis=1)
    # Check if the next token is padding (a target) or non-padding (an input).
    # Mask will have `1` for targets and `0` for inputs.
    out_of_prompt = (next_input_token == 0)
    # Select the sampled next token for targets and the actual next token for
    # inputs (teacher forcing).
    # [batch]
    next_token = (
        next_token * out_of_prompt + next_input_token * ~out_of_prompt)

    # Only add probability if outside prefix region.
    # [batch] -> [batch]
    next_log_prob = state.log_prob + (
        next_log_prob * out_of_prompt) * jnp.squeeze(
            ~state.ended, axis=-1).astype(jnp.int32)

    # [batch] -> [batch, 1]
    next_token = jnp.expand_dims(next_token, axis=-1)

    # If end-marker reached for batch item, only emit padding tokens.
    # [batch, 1] * [batch, 1] -> [batch, 1]
    next_token_or_endpad = next_token * ~state.ended
    # Add current sampled tokens to recorded sequences.
    one_hot = jax.nn.one_hot(
        i + 1, state.sequences.shape[1], dtype=state.sequences.dtype)
    new_sequences = state.sequences * (1 -
                                       one_hot) + next_token_or_endpad * one_hot
    # new_sequences = dynamic_update_vector_slice_in_dim(sequences,
    #                                                    next_token_or_endpad,
    #                                                    i + 1,
    #                                                    0)
    # Count eos tokens in the sequences and compare to the initial count.
    # [batch, 1]
    cur_eos_count = jnp.sum(new_sequences == end_marker, axis=-1, keepdims=True)

    # Have we reached max decoding length?
    # We generally index into sequences[:, i + 1], and sequences.shape[1] =
    # max_decode_len + 2, therefore i == max_decode_len - 1 will write to
    # sequences[-2] which is our last valid location. i == max_decode_len will
    # write to sequences[-1] which is our garbage collection token. Thus `i`
    # should be strictly less than max_decode_len.
    has_additional_eos = cur_eos_count > initial_eos_count
    ended = state.ended | has_additional_eos | jnp.expand_dims(
        i >= max_decode_len - 1, axis=1)

    return ArithmeticSamplingLoopState(i + 1, new_sequences, new_cache,
                                       next_token_or_endpad, ended, rng2,
                                       next_log_prob, next_code)

  # Run sampling loop and collect final state.
  final_state = lax.while_loop(sampling_loop_cond_fn, sampling_loop_body_fn,
                               sampling_loop_init_state)

  # Pick the part of the state corresponding to the sampled sequences.
  final_sequences = final_state.sequences
  log_prob = final_state.log_prob
  # Drop the first position because it holds the dummy BOS token, and drop the
  # garbage collection token at the end too.
  return final_sequences[:, 1:-1], log_prob  # pytype: disable=bad-return-type  # jax-ndarray