google-research
365 строк · 14.8 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Fast decoding routines for inference from a trained model."""
17
18import typing
19import flax
20import jax
21from jax import lax
22import jax.numpy as jnp
23import numpy as np
24
# Constants
# We assume the default End-of-Sentence token id is 2 (SentencePiece).
EOS_ID = 2
# "Effective negative infinity" constant for masking in beam search.
# A large finite value is used instead of -inf because the beam search
# multiplies it by boolean masks (e.g. `newly_finished * NEG_INF`); with a
# true -inf the `False * -inf` entries would become NaN.
NEG_INF = np.array(-1.0e7)
30
31
def brevity_penalty(alpha, length):
  """Brevity penalty function for beam search penalizing short sequences.

  Computes `((5 + length) / 6) ** alpha`. Callers in this file divide
  sequence log-probabilities by this value to obtain length-normalized
  scores.

  Args:
    alpha: float: brevity-penalty scaling parameter. With `alpha == 0` the
      penalty is 1 for every length.
    length: int: length of considered sequence.

  Returns:
    Brevity penalty score as jax scalar.
  """
  return jnp.power(((5.0 + length) / 6.0), alpha)
43
44
45# Beam handling utility functions:
46
47
def add_beam_dim(x, beam_size):
  """Creates new beam dimension in non-scalar array and tiles into it.

  Args:
    x: array of shape [d0, ...] or a scalar (0-d) array.
    beam_size: int: number of copies to place along the new axis 1.

  Returns:
    Array of shape [d0, beam_size, ...] in which every slice along axis 1 is
    a copy of `x`. Scalars are returned unchanged.
  """
  if x.ndim == 0:  # ignore scalars (e.g. cache index)
    return x
  # Insert the beam axis right after the leading axis, then replicate along
  # that axis only.
  x = jnp.expand_dims(x, axis=1)
  tile_dims = [1] * x.ndim
  tile_dims[1] = beam_size
  return jnp.tile(x, tile_dims)
56
57
def flatten_beam_dim(x):
  """Flattens the first two dimensions of a non-scalar array.

  Args:
    x: array of shape [d0, d1, ...] or a scalar (0-d) array.

  Returns:
    Array of shape [d0 * d1, ...]. Scalars are returned unchanged.
  """
  if x.ndim == 0:  # ignore scalars (e.g. cache index)
    return x
  return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
63
64
def unflatten_beam_dim(x, batch_size, beam_size):
  """Unflattens the first, flat batch*beam dimension of a non-scalar array.

  Args:
    x: array of shape [batch_size * beam_size, ...] or a scalar (0-d) array.
    batch_size: int: size of batch.
    beam_size: int: size of beam dimension.

  Returns:
    Array of shape [batch_size, beam_size, ...]. Scalars are returned
    unchanged.

  Raises:
    AssertionError: if the leading dimension of `x` is not
      `batch_size * beam_size`.
  """
  if x.ndim == 0:  # ignore scalars (e.g. cache index)
    return x
  assert batch_size * beam_size == x.shape[0], (
      f"leading dimension {x.shape[0]} != batch_size * beam_size "
      f"({batch_size} * {beam_size})")
  return x.reshape((batch_size, beam_size) + x.shape[1:])
71
72
def flat_batch_beam_expand(x, beam_size):
  """Expands each batch item by beam_size in the batch dimension.

  Args:
    x: array of shape [batch, ...] or a scalar (0-d) array.
    beam_size: int: number of times each batch item is repeated.

  Returns:
    Array of shape [batch * beam_size, ...] in which each batch item appears
    `beam_size` times consecutively. Scalars are returned unchanged.
  """
  return flatten_beam_dim(add_beam_dim(x, beam_size))
76
77
def gather_beams(nested, beam_indices, batch_size, new_beam_size):
  """Gathers the beam slices indexed by beam_indices into new beam array.

  Args:
    nested: pytree of arrays or scalars (the latter ignored).
    beam_indices: [batch_size, new_beam_size] array of indices into the old
      beam dimension.
    batch_size: int: size of batch.
    new_beam_size: int: size of _new_ beam dimension.

  Returns:
    New pytree with new beam arrays.
    [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
  """
  # Row i of batch_indices is filled with i, so pairing it elementwise with
  # beam_indices selects x[i, beam_indices[i, j]] for every (i, j).
  batch_indices = jnp.reshape(
      jnp.arange(batch_size * new_beam_size) // new_beam_size,
      (batch_size, new_beam_size))

  def gather_fn(x):
    if x.ndim == 0:  # ignore scalars (e.g. cache index)
      return x
    return x[batch_indices, beam_indices]

  # `jax.tree_map` is a deprecated alias that newer JAX releases drop;
  # `jax.tree_util.tree_map` is available in both old and new releases.
  return jax.tree_util.tree_map(gather_fn, nested)
100
101
def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
  """Gathers the top-k beam slices given by score_or_log_prob array.

  Args:
    nested: pytree of arrays or scalars (the latter ignored).
    score_or_log_prob: [batch_size, old_beam_size] array of values to sort by
      for top-k selection of beam slices.
    batch_size: int: size of batch.
    new_beam_size: int: size of _new_ top-k selected beam dimension

  Returns:
    New pytree with new beam arrays containing top k new_beam_size slices.
    [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
    The selected slices are ordered from lowest to highest score, so the best
    beam sits at index -1 of the new beam dimension.
  """
  _, topk_indices = lax.top_k(score_or_log_prob, k=new_beam_size)
  # lax.top_k returns entries in descending order; flip so the best beam ends
  # up last along the beam axis.
  topk_indices = jnp.flip(topk_indices, axis=1)
  return gather_beams(nested, topk_indices, batch_size, new_beam_size)
119
120
121# Beam search state:
122
123
@flax.struct.dataclass
class BeamState:
  """Holds beam search state data.

  This is the loop-carried state of the `lax.while_loop` in `beam_search`.
  """
  # The position of the decoding loop in the length dimension.
  cur_index: jax.Array  # scalar int32: current decoded length index
  # The active sequence log probabilities and finished sequence scores.
  live_logprobs: jax.Array  # float32: [batch_size, beam_size]
  finished_scores: jax.Array  # float32: [batch_size, beam_size]
  # The current active-beam-searching and finished sequences.
  live_seqs: jax.Array  # int32: [batch_size, beam_size, max_decode_len]
  finished_seqs: jax.Array  # int32: [batch_size, beam_size,
  #                                   max_decode_len]
  # Records which of the 'finished_seqs' is occupied and not a filler slot.
  finished_flags: jax.Array  # bool: [batch_size, beam_size]
  # The current state of the autoregressive decoding caches.
  cache: typing.Any  # Any pytree of arrays, e.g. flax attention Cache object
140
141
def beam_init(batch_size, beam_size, max_decode_len, cache):
  """Initializes the beam search state data structure.

  Args:
    batch_size: int: size of batch.
    beam_size: int: number of beams.
    max_decode_len: int: length of the sequence buffers.
    cache: pytree of arrays or scalars; every non-scalar leaf gets a beam
      dimension inserted at axis 1.

  Returns:
    A BeamState at decoding position 0 with all-zero sequence buffers and no
    finished sequences.
  """
  cur_index0 = jnp.array(0)
  # Only the first beam starts with log-probability 0; the others start at
  # NEG_INF so the first step does not select beam_size identical copies of
  # the same hypothesis.
  live_logprobs0 = jnp.tile(
      jnp.array([0.0] + [NEG_INF] * (beam_size - 1)),
      [batch_size, 1])
  finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
  live_seqs0 = jnp.zeros(
      (batch_size, beam_size, max_decode_len), jnp.int32)
  finished_seqs0 = jnp.zeros(
      (batch_size, beam_size, max_decode_len), jnp.int32)
  finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
  # add beam dimension to attention cache pytree elements
  # (`jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.)
  beam_cache0 = jax.tree_util.tree_map(
      lambda x: add_beam_dim(x, beam_size), cache)
  return BeamState(cur_index=cur_index0,
                   live_logprobs=live_logprobs0,
                   finished_scores=finished_scores0,
                   live_seqs=live_seqs0,
                   finished_seqs=finished_seqs0,
                   finished_flags=finished_flags0,
                   cache=beam_cache0)
163
164
165# Beam search routine:
166
167
def beam_search(inputs,
                cache,
                tokens_to_logits,
                beam_size=4,
                alpha=0.6,
                eos_id=EOS_ID,
                max_decode_len=None):
  """Beam search for transformer machine translation.

  Args:
    inputs: array: [batch_size, length] int32 sequence of tokens. Only its
      shape is read here: it supplies the batch size and the default
      `max_decode_len`.
    cache: flax attention cache.
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and cache and returning next-token logits and updated cache.
      It is called as `tokens_to_logits(flat_ids, flat_cache)` with
      `flat_ids` of shape [batch_size * beam_size, 1].
    beam_size: int: number of beams to use in beam search.
    alpha: float: scaling factor for brevity penalty.
    eos_id: int: id of end-of-sentence token for target vocabulary.
    max_decode_len: int: maximum length of decoded translations. Defaults to
      `inputs.shape[1]`.

  Returns:
     Tuple of:
       [batch_size, beam_size, max_decode_len] top-scoring sequences
       [batch_size, beam_size] beam-search scores.
     Beams are ordered from worst to best along the beam axis, so the best
     hypothesis is at beam index -1. Position 0 of every sequence holds the
     initial token id 0. For batch items with no finished sequence, the live
     sequences and their raw (un-normalized) log-probabilities are returned
     instead.
  """
  # We liberally annotate shape information for clarity below.

  batch_size = inputs.shape[0]
  if max_decode_len is None:
    max_decode_len = inputs.shape[1]
  end_marker = jnp.array(eos_id)

  # initialize beam search state
  beam_search_init_state = beam_init(batch_size,
                                     beam_size,
                                     max_decode_len,
                                     cache)

  def beam_search_loop_cond_fn(state):
    """Beam search loop termination condition."""
    # Have we reached max decoding length?
    not_at_end = (state.cur_index < max_decode_len - 1)

    # Is no further progress in the beam search possible?
    # Get the best possible scores from alive sequences.
    # Live beams are kept in ascending order, so the best one is at index -1.
    min_brevity_penalty = brevity_penalty(alpha, max_decode_len)
    best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
    # Get the worst scores from finished sequences.
    worst_finished_scores = jnp.min(
        state.finished_scores, axis=1, keepdims=True)
    # Mask out scores from slots without any actual finished sequences.
    worst_finished_scores = jnp.where(
        state.finished_flags, worst_finished_scores, NEG_INF)
    # If no best possible live score is better than current worst finished
    # scores, the search cannot improve the finished set further.
    search_terminated = jnp.all(worst_finished_scores > best_live_scores)

    # If we're not at the max decode length, and the search hasn't terminated,
    # continue looping.
    return not_at_end & (~search_terminated)

  def beam_search_loop_body_fn(state):
    """Beam search loop state update function."""
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model.  Flatten the beam dimension into batch
    # dimension for feeding into the model.
    # --> [batch * beam, 1]
    flat_ids = flatten_beam_dim(lax.dynamic_slice(
        state.live_seqs,
        (0, 0, state.cur_index),
        (batch_size, beam_size, 1)))
    # Flatten beam dimension into batch to be compatible with model.
    # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
    flat_cache = jax.tree_util.tree_map(flatten_beam_dim, state.cache)

    # Call fast-decoder model on current tokens to get next-position logits.
    # --> [batch * beam, vocab]
    flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

    # unflatten beam dimension
    # [batch * beam, vocab] --> [batch, beam, vocab]
    logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
    # Unflatten beam dimension in attention cache arrays
    # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
    new_cache = jax.tree_util.tree_map(
        lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)

    # Gather log probabilities from logits
    candidate_log_probs = jax.nn.log_softmax(logits)
    # Add new logprobs to existing prefix logprobs.
    # --> [batch, beam, vocab]
    log_probs = (candidate_log_probs +
                 jnp.expand_dims(state.live_logprobs, axis=2))

    # We'll need the vocab size, gather it from the log probability dimension.
    vocab_size = log_probs.shape[2]

    # Each item in batch has beam_size * vocab_size candidate sequences.
    # For each item, get the top 2*k candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    beams_to_keep = 2 * beam_size
    # Flatten beam and vocab dimensions.
    flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
    # Gather the top 2*K scores from _all_ beams.
    # --> [batch, 2*beams], [batch, 2*beams]
    topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep)
    # Recover the beam index by floor division.
    topk_beam_indices = topk_indices // vocab_size
    # Gather 2*k top beams.
    # --> [batch, 2*beams, length]
    topk_seq = gather_beams(state.live_seqs,
                            topk_beam_indices,
                            batch_size, beams_to_keep)

    # Append the most probable 2*K token IDs to the top 2*K sequences
    # Recover token id by modulo division and expand Id array for broadcasting.
    # --> [batch, 2*beams, 1]
    topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
    # Update sequences for the 2*K top-k new sequences.
    # --> [batch, 2*beams, length]
    topk_seq = lax.dynamic_update_slice(
        topk_seq, topk_ids, (0, 0, state.cur_index + 1))

    # Update LIVE (in-progress) sequences:
    # Did any of these sequences reach an end marker?
    # --> [batch, 2*beams]
    newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
    # To prevent these newly finished sequences from being added to the LIVE
    # set of active beam search sequences, set their log probs to a very large
    # negative value.
    new_log_probs = topk_log_probs + newly_finished * NEG_INF
    # Determine the top k beam indices (from top 2*k beams) from log probs.
    # --> [batch, beams]
    _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
    # Flip so live beams are stored worst-to-best (best at index -1), which
    # the loop condition relies on.
    new_topk_indices = jnp.flip(new_topk_indices, axis=1)
    # Gather the top k beams (from top 2*k beams).
    # --> [batch, beams, length], [batch, beams]
    top_alive_seq, top_alive_log_probs = gather_beams(
        [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size)

    # Determine the top k beam indices from the original set of all beams.
    # --> [batch, beams]
    top_alive_indices = gather_beams(
        topk_beam_indices, new_topk_indices, batch_size, beam_size)
    # With these, gather the top k beam-associated caches.
    # --> {[batch, beams, ...], ...}
    top_alive_cache = gather_beams(
        new_cache, top_alive_indices, batch_size, beam_size)

    # Update FINISHED (reached end of sentence) sequences:
    # Calculate new seq scores from log probabilities.
    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
    # Mask out the still unfinished sequences by adding large negative value.
    # --> [batch, 2*beams]
    new_scores += (~newly_finished) * NEG_INF

    # Combine sequences, scores, and flags along the beam dimension and compare
    # new finished sequence scores to existing finished scores and select the
    # best from the new set of beams.
    finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
        [state.finished_seqs, topk_seq], axis=1)
    finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_scores, new_scores], axis=1)
    finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_flags, newly_finished], axis=1)
    # --> [batch, beams, length], [batch, beams], [batch, beams]
    top_finished_seq, top_finished_scores, top_finished_flags = (
        gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                          finished_scores, batch_size, beam_size))

    return BeamState(cur_index=state.cur_index + 1,
                     live_logprobs=top_alive_log_probs,
                     finished_scores=top_finished_scores,
                     live_seqs=top_alive_seq,
                     finished_seqs=top_finished_seq,
                     finished_flags=top_finished_flags,
                     cache=top_alive_cache)

  # Run while loop and get final beam search state.
  final_state = lax.while_loop(beam_search_loop_cond_fn,
                               beam_search_loop_body_fn,
                               beam_search_init_state)

  # Account for the edge-case where there are no finished sequences for a
  # particular batch item. If so, return live sequences for that batch item.
  # `any_finished` is True for batch items with at least one finished beam.
  # --> [batch]
  any_finished = jnp.any(final_state.finished_flags, axis=1)
  # --> [batch, beams, length]
  finished_seqs = jnp.where(any_finished[:, None, None],
                            final_state.finished_seqs,
                            final_state.live_seqs)
  # --> [batch, beams]
  finished_scores = jnp.where(any_finished[:, None],
                              final_state.finished_scores,
                              final_state.live_logprobs)

  return finished_seqs, finished_scores
366