google-research
606 строк · 25.0 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""T5X decoding routine for arithmetic sampling."""
17import functools18
19from typing import Any, Callable, Mapping, Optional, Tuple, Union20import flax21import jax22from jax import lax23from jax import random24import jax.numpy as jnp25import numpy as np26from t5x import decoding27
# Constants

# Stand-in for negative infinity: logits removed by top-k / top-p filtering in
# the sampling loop are overwritten with this large negative (but finite)
# value.
NEG_INF = np.array(-1e7)

# Temperature threshold for the greedy path. The sampling loop treats any
# temperature <= this value as 0.0 (taking the argmax), and also uses it as a
# floor on the divisor when scaling logits by 1/temperature, which avoids
# numeric blow-up for temperatures close to 0.0.
MIN_TEMPERATURE = np.array(0.0001)

#------------------------------------------------------------------------------
# Arithmetic Sampling
#------------------------------------------------------------------------------
@flax.struct.dataclass
class ArithmeticSamplingLoopState:
  """Holds sampling state data for the arithmetic sampling while-loop.

  As constructed and updated by `_arithmetic_sample_single_trial`, every
  leading dimension below is the flattened batch, i.e. one row per
  (batch element, decode) pair; it is written `flat_batch` here
  (`batch_size * num_decodes` when called via `arithmetic_sample`).

  Fields are positional: `_arithmetic_sample_single_trial` constructs this
  class positionally, so the field order must not change.

  Attributes:
    cur_index: [flat_batch] array, position of the sampling loop in the length
      dimension for each row (rows may be at different positions).
    sequences: [flat_batch, input_length + 2] array of current sampled
      sequence prefixes: the prompt/targets plus two extra trailing columns,
      the last of which is a garbage slot that rows which already reached
      their maximum length keep overwriting.
    cache: any mapping of arrays, e.g. flax attention cache.
    cur_token: [flat_batch, 1] single timestep slice containing current tokens.
    ended: [flat_batch, 1] boolean array marking completed sequences.
    rng: Jax PRNGKey; split once per loop step.
    log_prob: [flat_batch] float32 array of accumulated log probs for each
      sequence.
    codes: [flat_batch] array containing the arithmetic codes for the
      remainder of the sequence at the current time step for each sample.
  """
  cur_index: jnp.ndarray
  sequences: jnp.ndarray
  cache: Mapping[str, jnp.ndarray]
  cur_token: jnp.ndarray
  ended: jnp.ndarray
  rng: jnp.ndarray
  log_prob: jnp.ndarray
  codes: jnp.ndarray
# Row-wise `lax.dynamic_update_slice_in_dim`: operand, update and start index
# are mapped over their leading axis, while the trailing `axis` argument is
# shared by every row. In the visible code it is only referenced from a
# commented-out line in the sampling loop, which writes via a one-hot update.
_dynamic_update_vector_slice_in_dim = jax.vmap(
    lax.dynamic_update_slice_in_dim,
    in_axes=(0, 0, 0, None),
)
def _is_tracer(value):
  """Returns True iff `value` is a JAX tracer (an abstract value under tracing).

  Used to decide whether a hyper-parameter is concrete enough to be checked
  with ordinary Python control flow.
  """
  tracer_cls = jax.core.Tracer
  return isinstance(value, tracer_cls)
def _sequential_cumsum(arr, axis):
  """Cumulative sum along `axis`, accumulated strictly one element at a time.

  The JAX implementation of cumulative sum does not guarantee that its output
  is nondecreasing when applied to nonnegative inputs, which breaks using the
  partial sums as bucket boundaries. Driving the accumulation with
  `lax.scan` fixes the evaluation order: each partial sum is exactly the
  previous partial sum plus one more term.

  Args:
    arr: Jax array to sum.
    axis: axis to sum over (negative values are allowed).

  Returns:
    Jax array of the same shape as `arr` holding the running sums.
  """

  def _accumulate(running_total, term):
    running_total = running_total + term
    # Carry the running total forward and also emit it as this step's output.
    return running_total, running_total

  # Bring the summation axis to the front so `scan` iterates over it.
  scanned = jnp.moveaxis(arr, axis, 0)
  zero = jnp.zeros(scanned.shape[1:], scanned.dtype)
  _, partial_sums = jax.lax.scan(_accumulate, zero, scanned)
  return jnp.moveaxis(partial_sums, 0, axis)
def _arithmetic_categorical(rng, logits, codes):
  """Sample from a categorical using arithmetic sampling.

  Each row's probability mass is laid out as consecutive buckets in the unit
  interval (in a randomly permuted vocabulary order). The row's code selects
  the bucket it falls into, and the code's relative position inside that
  bucket becomes the code for the remainder of the sequence. For a code drawn
  uniformly from [0, 1) this gives an unbiased sample.

  Args:
    rng: JAX PRNGKey, used only to permute the vocabulary order.
    logits: array: [batch_size, vocab_size] float32 sequence of logits.
    codes: array: [batch_size] float32 codes in [0, 1) for each batch element.

  Returns:
    A tuple (samples, new_codes) where `samples` are sampled indices with shape
    [batch_size], and `new_codes` are shape [batch_size] containing codes in
    [0, 1) for the remaining suffix if doing ancestral sampling.
  """
  _, vocab_size = logits.shape

  # Randomly permute the vocabulary at each timestep so the bucket layout does
  # not depend on the default vocabulary order. This isn't strictly necessary;
  # the permutation is undone on the sampled indices below.
  perm = jax.random.permutation(rng, vocab_size)
  probs = jax.nn.softmax(logits[:, perm], axis=1)

  # The scan-based cumulative sum guarantees a nondecreasing array of partial
  # sums, i.e. well-formed bucket upper bounds.
  cumprobs = _sequential_cumsum(probs, axis=1)

  # Because of precision, make sure the max value (and everything with that
  # value, to not change bucket widths) is at least 1.0.
  max_probs = jnp.max(cumprobs, axis=1, keepdims=True)
  bucket_maxes = jnp.where((cumprobs == max_probs) & (cumprobs < 1.0), 1.0,
                           cumprobs)

  # Masks of the buckets whose upper bounds are <= / > each row's code.
  expanded_codes = codes[:, None]
  at_or_below_code = bucket_maxes <= expanded_codes
  above_code = bucket_maxes > expanded_codes

  # Lower edge of the selected bucket: the largest bucket max <= code. This is
  # 0.0 if the code falls into the zero'th bucket, as desired.
  code_bucket_mins = jnp.max(bucket_maxes * at_or_below_code, axis=1)

  # For probabilities, anything > 1.0 is as good as infinity; use it to mask
  # out buckets that end at or below the code.
  prob_infty = 1.1
  masked_maxes = bucket_maxes * above_code + at_or_below_code * prob_infty
  # Upper edge of the selected bucket: the first bucket max > code.
  code_bucket_maxes = jnp.min(masked_maxes, axis=1)
  # Take the argmin in permuted order; otherwise it messes up the default
  # tie breaking behavior for size zero buckets (take lowest index).
  sampled_indices_permed = jnp.argmin(masked_maxes, axis=1)
  # Undo the permutation: permuted position j holds vocabulary id perm[j].
  sampled_indices = perm[sampled_indices_permed]

  remainder_codes = (codes - code_bucket_mins) / (
      code_bucket_maxes - code_bucket_mins)
  # Rounding in the subtraction/division above can produce exactly 1.0. Such a
  # code may be >= every bucket max at the next step and then fall through to
  # permuted index 0, which can be a zero-probability (masked) token. Keep the
  # codes strictly inside [0, 1).
  largest_below_one = 1.0 - jnp.finfo(remainder_codes.dtype).epsneg
  new_codes = jnp.minimum(remainder_codes, largest_below_one)

  return sampled_indices, new_codes
def arithmetic_sample(
    inputs,
    cache,
    tokens_to_logits,
    eos_id,
    decode_rng=None,
    num_decodes=1,
    temperature=1.0,
    topk=1,
    topp=0.0,
    cache_offset=0,
    initial_index=None,
    max_decode_steps=None,
    max_decode_steps_hard_limit=None,
    rescale_log_probs=True,
    state_callback_fn=None,
    logit_callback_fn=None):
  """Arithmetic sampling for language model generation.

  The sampling is performed `num_decodes` times in a vectorized
  manner by expanding the batch dimension. This is similar to how beam search
  expands the batch dimension to process each batch element with multiple beams.

  Args:
    inputs: array: [batch_size, max_decode_len] int32 sequence of tokens.
    cache: flax attention cache.
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and cache and returning next-token logits and updated cache.
    eos_id: int: end-of-sentence token for target vocabulary.
    decode_rng: JAX PRNGKey. It is split into two independent keys: one for
      the random offset of the arithmetic codebook and one for the sampling
      loop. Defaults to `jax.random.PRNGKey(0)`.
    num_decodes: number of decoded sequences to be returned.
    temperature: float: sampling temperature factor. As it approaches zero this
      becomes equivalent to greedy sampling.
    topk: integer: if nonzero only use the top-k logits to sample next token, if
      zero don't use any cutoff and sample from full logits over vocabulary.
    topp: float: if nonzero only use the smallest number of logits whose
      cumulative sum of probs adds up to (at least) topp. Will raise ValueError
      if it's nonzero when topk is nonzero.
    cache_offset: axis offset for cache, arising from scanned layers.
    initial_index: Optional[array]: [batch_size] int32 a vector of loop indexes
      to start decoding at.
    max_decode_steps: int: an optional maximum number of decoding steps. If
      None, it will decode until the full input shape `inputs.shape[1]` is
      filled. max_decode_steps begins counting after the prompt, so it will
      decode at most len(prompt) + max_decode_steps tokens.
    max_decode_steps_hard_limit: int: an optional fixed hard limit on
      max_decode_steps. If this is set (not None and > 0), and max_decode_steps
      is also set, then max_decode_steps will be clipped to this limit. The
      value max_decode_steps can be an ndarray, but max_decode_steps_hard_limit
      must be a Python integer or None.
    rescale_log_probs: bool: whether to apply temperature, topp, and topk
      rescaling to the log probs which are returned. If True, the log_probs will
      include these transformations (for example, with topk=1, all log_probs
      will be identically 0.0). If False, the log_probs will not be affected,
      and topk/topp/temperature will not affect sequence probabilities.
    state_callback_fn: Function that modifies the sampling loop state before
      each step. This can be used to manipulate any part of the state either on
      the accelerator or on the host using host callback. The function should
      take an ArithmeticSamplingLoopState as argument, and it returns the
      updated state.
    logit_callback_fn: Function that modifies the logits before each temperature
      sampling step. The function should take arguments (logits, state) and it
      should return the modified logits.

  Returns:
    A tuple (decodes, log_prob) where `decodes` is sampled sequences with shape
    [batch_size, num_decodes, max_decode_len] and `log_prob` is the
    [batch_size, num_decodes] log probability of each sampled sequence. Both
    are sorted by increasing `log_prob` along the `num_decodes` axis.

  Raises:
    ValueError: (from the sampling loop) if both `topp` and `topk` are
      non-zero, or if a static `max_decode_steps` exceeds `inputs.shape[1]`.
  """
  if decode_rng is None:
    decode_rng = jax.random.PRNGKey(0)
  # Use independent keys for the codebook offsets and for the sampling loop
  # (which splits its key once per step); a JAX key must not be consumed twice.
  codes_rng, sampling_rng = jax.random.split(decode_rng)

  if (max_decode_steps_hard_limit is not None and
      max_decode_steps_hard_limit > 0 and max_decode_steps is not None):
    max_decode_steps = jnp.minimum(max_decode_steps,
                                   max_decode_steps_hard_limit)

  batch_size = inputs.shape[0]

  # [batch, num_decodes] -> [batch * num_decodes]
  initial_codes = _make_default_codes(batch_size, num_decodes, codes_rng)
  flattened_codes = decoding.flatten_beam_dim(initial_codes)

  # [batch, len] -> [batch * num_decodes, len]
  expanded_inputs = decoding.flat_batch_beam_expand(inputs, num_decodes)
  expanded_cache = decoding.cache_map(
      functools.partial(
          decoding.flat_batch_beam_expand,
          beam_size=num_decodes,
          offset=cache_offset),
      cache,
      # When we start with a prefilled cache, the cache index is no longer a
      # scalar that will broadcast across multiple decodes, it is a vector and
      # needs to be updated to handle the multiple decodes.
      apply_to_index=initial_index is not None)
  if initial_index is not None:
    initial_index = decoding.flat_batch_beam_expand(initial_index, num_decodes)

  # expanded_decodes: [batch * num_decodes, len]
  # expanded_log_prob: [batch * num_decodes]
  expanded_decodes, expanded_log_prob = _arithmetic_sample_single_trial(
      expanded_inputs,
      flattened_codes,
      expanded_cache,
      tokens_to_logits,
      eos_id,
      sampling_rng,
      temperature,
      topk,
      topp,
      initial_index=initial_index,
      max_decode_steps=max_decode_steps,
      rescale_log_probs=rescale_log_probs,
      state_callback_fn=state_callback_fn,
      logit_callback_fn=logit_callback_fn)

  # [batch * num_decodes, len] -> [batch, num_decodes, len]
  decodes = decoding.unflatten_beam_dim(expanded_decodes, batch_size,
                                        num_decodes)
  # [batch * num_decodes] -> [batch, num_decodes]
  log_prob = decoding.unflatten_beam_dim(expanded_log_prob, batch_size,
                                         num_decodes)

  # Sort `decodes` and `log_prob` by increasing log probabilities of the sampled
  # sequence.
  # [batch, num_decodes, 1]
  idxs = jnp.expand_dims(jnp.argsort(log_prob, axis=-1), axis=-1)

  # returns [batch, num_decodes, len], [batch, num_decodes] in sorted order.
  return (jnp.take_along_axis(decodes, idxs, axis=1),
          jnp.take_along_axis(log_prob, jnp.squeeze(idxs, axis=-1), axis=-1))
def _make_default_codes(batch_size, num_decodes, rng):
  """Builds the default arithmetic codebook for `num_decodes` samples per row.

  Every row receives the lattice k / (num_decodes + 1) for
  k = 1..num_decodes, shifted by a single uniform random offset drawn for that
  row and wrapped back into the unit interval. The lattice spreads the codes
  evenly, and the random shift keeps any sample average unbiased.

  Args:
    batch_size: size of input batch.
    num_decodes: number of samples per batch element.
    rng: JAX PRNGKey used to draw the per-row offsets.

  Returns:
    [batch_size, num_decodes] float32 array of codes.
  """
  lattice = jnp.arange(
      1, num_decodes + 1, dtype=jnp.float32) / (num_decodes + 1)
  row_offsets = jax.random.uniform(rng, (batch_size, 1))
  # [1, num_decodes] + [batch_size, 1] broadcasts to [batch_size, num_decodes].
  shifted = lattice[None, :] + row_offsets
  return jnp.mod(shifted, 1.0)
def _arithmetic_sample_single_trial(
    inputs,
    initial_codes,
    cache,
    tokens_to_logits,
    eos_id,
    prng_key,
    temperature=1.0,
    topk=20,
    topp=0.0,
    initial_index=None,
    max_decode_steps=None,
    rescale_log_probs=True,
    state_callback_fn=None,
    logit_callback_fn=None):
  """A helper function for `arithmetic_sample`.

  Runs one vectorized arithmetic-sampling while-loop over an already flattened
  batch (one row per sample).

  Args:
    inputs: [batch, length] integer tokens. Non-zero entries are
      teacher-forced (prompt) tokens; zero entries are filled with sampled
      tokens.
    initial_codes: [batch] arithmetic codes, one per row.
    cache: flax attention cache, already expanded to `batch` rows.
    tokens_to_logits: fast autoregressive decoder function (see
      `arithmetic_sample`).
    eos_id: end-of-sentence token id.
    prng_key: JAX PRNGKey; split once per loop step.
    temperature: see `arithmetic_sample`.
    topk: see `arithmetic_sample`.
    topp: see `arithmetic_sample`.
    initial_index: optional [batch] vector of loop indexes to start at.
    max_decode_steps: see `arithmetic_sample`.
    rescale_log_probs: see `arithmetic_sample`.
    state_callback_fn: see `arithmetic_sample`.
    logit_callback_fn: see `arithmetic_sample`.

  Returns:
    A tuple (sequences, log_prob): `sequences` is [batch, length], the working
    sequences with the leading dummy BOS position and the trailing garbage
    slot dropped; `log_prob` is [batch], the accumulated log probabilities of
    the sampled (non-prompt, not yet ended) tokens.

  Raises:
    ValueError: if both `topp` and `topk` are non-zero (checked only when
      `topp` is static), or if a static `max_decode_steps` exceeds
      `inputs.shape[1]`.
  """

  # We can check the values of topp and topk only if they are not dynamic.
  if not _is_tracer(topp) and topp and topk:
    raise ValueError('At most one of `topp` or `topk` may be non-zero.')

  batch_size, max_decode_len = inputs.shape

  if max_decode_steps is not None:
    # We can check the max_decode_steps bounds only if it is not dynamic.
    if not _is_tracer(max_decode_steps) and max_decode_steps > inputs.shape[1]:
      raise ValueError('Cannot decode more steps than the sequence length.')

    # The number of decode steps required to process the prefix is the number
    # of non-zero tokens, since inputs[0] == 0 is the BOS token.
    # `max_decode_len[j]` is the number of non-padding tokens in the jth element
    # of the returned sequences capped at `len(inputs)`, assuming that the
    # early stop doesn't occur. This is true with or without
    # `max_decode_steps`.
    # When the while loop index `i` for the `j`th element `i[j] =
    # max_decode_len[j] - 1`, the generated token populate sequences[i[j]+1]].
    # Since sequences[:, 0] is BOS token, the generated token is
    # `max_decode_len[j]`th non-padding tokens and hence `j`th element is
    # ended.
    max_decode_len = jnp.sum(inputs != 0, axis=1) + max_decode_steps
    max_decode_len = jnp.minimum(inputs.shape[1], max_decode_len)

  # In the case of starting generation from a non-zero index, it is possible for
  # one batch element to reach `max_decode_len` number of decoding steps before
  # another. In order to let the last element decode all the way to
  # `max_decode_len` number of steps, we add a final garbage token to the end of
  # the sequences. Any element that has reached `max_decode_len` before the rest
  # of the elements will continually overwrite this token until all elements
  # finish.
  # [batch, length] -> [batch, length + 2]
  extra_input_tokens = 2
  expanded_prompt_inputs = jnp.append(
      inputs,
      jnp.zeros((batch_size, extra_input_tokens), dtype=inputs.dtype),
      axis=1)
  end_marker = jnp.array(eos_id)

  temperature = jnp.asarray(temperature)

  # Initialize sampling loop state.
  # initial loop PRNGKey
  rng0 = prng_key

  if initial_index is None:
    # The per batch-item loop position counter.
    i0 = jnp.zeros((batch_size), dtype=jnp.int32)
    # The per batch-item holding current token in loop.
    token0 = jnp.zeros((batch_size, 1), dtype=jnp.int32)
  else:
    # The per batch-item loop position counter.
    i0 = initial_index
    # The per batch-item holding current token in loop.
    # Select the token that the initial index is pointing to.
    token0 = jnp.take_along_axis(
        expanded_prompt_inputs, jnp.expand_dims(i0, axis=1), axis=1)
  # per batch-item state bit indicating if sentence has finished.
  ended0 = jnp.zeros((batch_size, 1), dtype=jnp.bool_)
  # (batch, length+2) array containing prefix prompt tokens for sampling loop
  # as well as the generated output of newly sampled tokens.
  sequences0 = expanded_prompt_inputs
  log_prob0 = jnp.zeros((batch_size,), dtype=jnp.float32)

  sampling_loop_init_state = ArithmeticSamplingLoopState(
      i0, sequences0, cache, token0, ended0, rng0, log_prob0, initial_codes)
  # Initial eos count to be used to determine whether eos is "generated". Many
  # inputs follow the format bos, inputs..., eos, targets..., eos. By counting
  # the number of eos tokens we can detect when a new one is added, instead of
  # just finding the one that probably ends the inputs.
  # [batch, 1]
  initial_eos_count = jnp.sum(sequences0 == end_marker, axis=-1, keepdims=True)

  def sampling_loop_cond_fn(state):
    """Sampling loop termination condition."""
    # Have all sampled sequences reached an end marker?
    # Different elements in the batch can be at different loop indices, if any
    # of our examples are not at the end, keep going.
    all_sequences_ended = jnp.all(state.ended)
    return ~all_sequences_ended  # pytype: disable=bad-return-type  # jnp-type

  def sampling_loop_body_fn(state):
    """Sampling loop state update."""

    if state_callback_fn is not None:
      state = state_callback_fn(state)

    # Split RNG for sampling.
    rng1, rng2 = random.split(state.rng)
    # Call fast-decoder model on current tokens to get next-position logits.
    decoding_state = decoding.DecodingState(
        cur_index=state.cur_index,
        sequences=state.sequences[:, :-extra_input_tokens],
        cur_token=state.cur_token,
        cache=state.cache)
    logits, new_cache = tokens_to_logits(decoding_state)
    # Sample next token from logits.

    if logit_callback_fn is not None:
      logits = logit_callback_fn(logits, state)

    def sample_logits_with_nonzero_temperature(logits):

      # Before setting up the arithmetic sampling, we preprocess the logits into
      # their final form.
      scaled_logits = logits / jnp.maximum(temperature, MIN_TEMPERATURE)
      if topk:
        # Get top-k logits and their indices, sample within these top-k tokens.
        topk_logits, _ = lax.top_k(scaled_logits, topk)
        cutoff_logit = topk_logits[:, -1, None]
        scaled_logits = jnp.where(scaled_logits < cutoff_logit,
                                  jnp.full_like(scaled_logits, NEG_INF),
                                  scaled_logits)

      # When topp is dynamic, we always use it since we cannot check
      # non-zeroness (but it will have no effect if topp is 0.0).
      if _is_tracer(topp) or topp:
        logits_sorted = jnp.sort(
            scaled_logits, axis=-1)[:, ::-1]  # sort descending
        sorted_cum_probs = jnp.cumsum(
            jax.nn.softmax(logits_sorted, axis=-1), axis=-1)
        cutoff_index = jnp.sum(sorted_cum_probs < topp, axis=-1, keepdims=True)
        cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1)
        scaled_logits = jnp.where(scaled_logits < cutoff_logit,
                                  jnp.full_like(scaled_logits, NEG_INF),
                                  scaled_logits)

      next_token, next_code = _arithmetic_categorical(rng1, scaled_logits,
                                                      state.codes)

      # log probability of the current token conditioned on the previously
      # sampled and prefix tokens.
      # [batch, vocab] -> [batch, vocab]
      if rescale_log_probs:
        log_probs = jax.nn.log_softmax(scaled_logits)
      else:
        log_probs = jax.nn.log_softmax(logits)
      # [batch, vocab] -> [batch]
      next_log_prob = jnp.squeeze(
          jnp.take_along_axis(
              log_probs, jnp.expand_dims(next_token, axis=1), axis=-1),
          axis=-1)

      return (next_token, next_log_prob, next_code)

    def sample_logits_with_zero_temperature(logits):
      # For zero temperature, we always want the greedy output, regardless
      # of the values of topk and topp.

      next_token = jnp.argmax(logits, -1).astype(jnp.int32)

      if rescale_log_probs:
        next_log_prob = jnp.zeros_like(next_token, dtype=jnp.float32)
      else:
        log_probs = jax.nn.log_softmax(logits)
        next_log_prob = jnp.squeeze(
            jnp.take_along_axis(
                log_probs, jnp.expand_dims(next_token, axis=1), axis=-1),
            axis=-1)

      return (next_token, next_log_prob, state.codes)

    # Perform sampling with temperature
    (next_token, next_log_prob,
     next_code) = lax.cond(temperature > MIN_TEMPERATURE,
                           sample_logits_with_nonzero_temperature,
                           sample_logits_with_zero_temperature, logits)

    # When different batch elements are at different points in the loop counter,
    # it is possible that an element that started at a higher index will reach
    # `max_decode_len` before other elements. When this happens we need to make
    # sure this element continuously overwrites our new garbage collection
    # index. Here we clamp `i` to `max_decode_len`. This will cause a write to
    # `max_decode_len + 1` which is the final index in `sequences`. Subsequent
    # loop body executions will also get their value clamped causing continual
    # overwriting of the final garbage position until all examples are finished.
    i = jnp.minimum(state.cur_index, max_decode_len)

    # Only use sampled tokens if we're past provided prefix tokens.
    # Select the next token from sequences.
    # [batch]
    next_input_token = jnp.squeeze(
        jnp.take_along_axis(
            state.sequences, jnp.expand_dims(i + 1, axis=1), axis=1),
        axis=1)
    # Check if the next token is padding (a target) or non-padding (an input).
    # Mask will have `1` for targets and `0` for inputs.
    out_of_prompt = (next_input_token == 0)
    # Select the sampled next token for targets and the actual next token for
    # inputs (teacher forcing).
    # [batch]
    next_token = (
        next_token * out_of_prompt + next_input_token * ~out_of_prompt)

    # Only consume the arithmetic code when the token was actually sampled.
    # On teacher-forced prompt positions the sampled token is discarded, so
    # remapping the code through its bucket would scramble the evenly spaced
    # codebook (distinct codes can collapse onto the same remainder) before any
    # token is generated.
    # [batch]
    next_code = jnp.where(out_of_prompt, next_code, state.codes)

    # only add probability if outside prefix region
    # [batch] -> [batch]
    next_log_prob = state.log_prob + (
        next_log_prob * out_of_prompt) * jnp.squeeze(
            ~state.ended, axis=-1).astype(jnp.int32)

    # [batch] -> [batch, 1]
    next_token = jnp.expand_dims(next_token, axis=-1)

    # If end-marker reached for batch item, only emit padding tokens.
    # [batch, 1] * [batch, 1] -> [batch, 1]
    next_token_or_endpad = next_token * ~state.ended
    # Add current sampled tokens to recorded sequences.
    one_hot = jax.nn.one_hot(
        i + 1, state.sequences.shape[1], dtype=state.sequences.dtype)
    new_sequences = state.sequences * (1 -
                                       one_hot) + next_token_or_endpad * one_hot
    # Count eos tokens in the sequences and compare to the initial count
    # [batch, 1]
    cur_eos_count = jnp.sum(new_sequences == end_marker, axis=-1, keepdims=True)

    # Have we reached max decoding length?
    # We generally index into sequences[:, i + 1], and sequences.shape[1] =
    # max_decode_len + 2, therefore i == max_decode_len - 1 will write to
    # sequences[-2] which is our last valid location. i == max_decode_len will
    # write to sequences[-1] which is our garbage collection token. Thus `i`
    # should be strictly less than max_decode_len.
    # [batch, 1]
    has_additional_eos = cur_eos_count > initial_eos_count
    ended = state.ended | has_additional_eos | jnp.expand_dims(
        i >= max_decode_len - 1, axis=1)

    return ArithmeticSamplingLoopState(i + 1, new_sequences, new_cache,
                                       next_token_or_endpad, ended, rng2,
                                       next_log_prob, next_code)

  # Run sampling loop and collect final state.
  final_state = lax.while_loop(sampling_loop_cond_fn, sampling_loop_body_fn,
                               sampling_loop_init_state)

  # Pick part of the state corresponding to the sampled sequences.
  final_sequences = final_state.sequences
  log_prob = final_state.log_prob
  # Drop the first position because they are dummy bos tokens. Drop the new
  # garbage collection token at the end too.
  return final_sequences[:, 1:-1], log_prob  # pytype: disable=bad-return-type  # jax-ndarray