google-research

Форк
0
365 строк · 14.8 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Fast decoding routines for inference from a trained model."""
17

18
import typing
19
import flax
20
import jax
21
from jax import lax
22
import jax.numpy as jnp
23
import numpy as np
24

25
# Module-level constants.
# Default end-of-sentence token id (2 is the SentencePiece convention).
EOS_ID = 2
# Large *finite* negative number standing in for -inf when masking beam scores.
# Keeping it finite matters: beam_search multiplies it by 0/1 masks, and
# 0 * -inf would produce NaN.
NEG_INF = np.array(-1e7)
30

31

32
def brevity_penalty(alpha, length):
  """Length-normalization term used to rescore beam-search hypotheses.

  Computes ((5 + length) / 6) ** alpha. Dividing a (non-positive) log
  probability by this value makes longer sequences score relatively better
  when alpha > 0, i.e. it penalizes short sequences.

  Args:
    alpha: float: brevity-penalty scaling parameter (the exponent).
    length: int: length of considered sequence.

  Returns:
    Brevity penalty score as jax scalar.
  """
  base = (length + 5.0) / 6.0
  return jnp.power(base, alpha)
43

44

45
# Beam handling utility functions:
46

47

48
def add_beam_dim(x, beam_size):
  """Inserts a beam axis at position 1 and repeats the array along it.

  [batch, ...] --> [batch, beam_size, ...]. Scalars (e.g. a cache index) are
  returned unchanged.
  """
  if x.ndim == 0:
    return x
  expanded = jnp.expand_dims(x, axis=1)
  # Repeat only along the freshly inserted axis; every other axis keeps size.
  reps = (1, beam_size) + (1,) * (expanded.ndim - 2)
  return jnp.tile(expanded, reps)
56

57

58
def flatten_beam_dim(x):
  """Merges the leading two axes of a non-scalar array into one.

  [batch, beam, ...] --> [batch * beam, ...]. Scalars pass through untouched.
  """
  if x.ndim == 0:
    return x
  merged = x.shape[0] * x.shape[1]
  return x.reshape((merged,) + tuple(x.shape[2:]))
63

64

65
def unflatten_beam_dim(x, batch_size, beam_size):
  """Splits the leading flat batch*beam axis of a non-scalar array in two.

  [batch_size * beam_size, ...] --> [batch_size, beam_size, ...]. Scalars pass
  through untouched; for arrays the leading axis must equal
  batch_size * beam_size (checked with an assert).
  """
  if x.ndim == 0:
    return x
  assert x.shape[0] == batch_size * beam_size
  target_shape = (batch_size, beam_size) + x.shape[1:]
  return x.reshape(target_shape)
71

72

73
def flat_batch_beam_expand(x, beam_size):
  """Repeats each batch item beam_size times along a single flat batch axis.

  [batch, ...] --> [batch * beam_size, ...]; the copies of one batch item are
  adjacent. Scalars pass through unchanged.
  """
  tiled = add_beam_dim(x, beam_size)
  return flatten_beam_dim(tiled)
76

77

78
def gather_beams(nested, beam_indices, batch_size, new_beam_size):
  """Gathers the beam slices indexed by beam_indices into new beam array.

  Args:
    nested: pytree of arrays or scalars (the latter ignored).
    beam_indices: int array of shape [batch_size, new_beam_size]; entry [b, j]
      is the old beam index whose slice becomes new beam j of batch item b.
    batch_size: int: size of batch.
    new_beam_size: int: size of _new_ beam dimension.

  Returns:
    New pytree with new beam arrays.
    [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
  """
  # batch_indices[b, j] == b, so the advanced indexing below picks
  # x[b, beam_indices[b, j]] for every (b, j) pair.
  batch_indices = jnp.reshape(
      jnp.arange(batch_size * new_beam_size) // new_beam_size,
      (batch_size, new_beam_size))

  def gather_fn(x):
    if x.ndim == 0:  # ignore scalars (e.g. cache index)
      return x
    return x[batch_indices, beam_indices]

  # Use jax.tree_util.tree_map: the jax.tree_map alias is deprecated and no
  # longer exists in recent JAX releases.
  return jax.tree_util.tree_map(gather_fn, nested)
100

101

102
def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
  """Selects the new_beam_size best beams per batch item and gathers them.

  The selected beams are laid out in ASCENDING score order: the highest
  scoring beam ends up in the last slot of the new beam dimension.

  Args:
    nested: pytree of arrays or scalars (the latter ignored).
    score_or_log_prob: [batch_size, old_beam_size] array of values to sort by
      for top-k selection of beam slices.
    batch_size: int: size of batch.
    new_beam_size: int: size of _new_ top-k selected beam dimension

  Returns:
    New pytree with new beam arrays containing top k new_beam_size slices.
    [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
  """
  # lax.top_k yields indices best-first; reverse them along the beam axis.
  best_first = lax.top_k(score_or_log_prob, k=new_beam_size)[1]
  ascending = best_first[:, ::-1]
  return gather_beams(nested, ascending, batch_size, new_beam_size)
119

120

121
# Beam search state:
122

123

124
@flax.struct.dataclass
class BeamState:
  """Immutable container for everything the beam-search loop carries."""

  # Current decoding position along the length axis (scalar int32).
  cur_index: jax.Array
  # Accumulated log-probabilities of live (still extending) beams;
  # float32 [batch_size, beam_size].
  live_logprobs: jax.Array
  # Length-normalized scores of finished beams; float32 [batch_size, beam_size].
  finished_scores: jax.Array
  # Token ids of live beams; int32 [batch_size, beam_size, max_decode_len].
  live_seqs: jax.Array
  # Token ids of finished beams; int32 [batch_size, beam_size, max_decode_len].
  finished_seqs: jax.Array
  # True where a finished_seqs slot holds a real finished sequence rather than
  # filler; bool [batch_size, beam_size].
  finished_flags: jax.Array
  # Autoregressive decoding cache: any pytree of arrays (e.g. a flax attention
  # cache) whose non-scalar leaves carry a beam axis.
  cache: typing.Any
140

141

142
def beam_init(batch_size, beam_size, max_decode_len, cache):
  """Initializes the beam search state data structure.

  Args:
    batch_size: int: number of batch items.
    beam_size: int: number of beams per batch item.
    max_decode_len: int: length of the per-beam token buffers.
    cache: pytree of decoding-cache arrays without a beam axis; every
      non-scalar leaf gets a beam axis inserted at position 1.

  Returns:
    BeamState in which only live beam 0 of each batch item is active
    (log-prob 0.0). Every other live beam and every finished slot is masked
    with NEG_INF, and all sequence buffers are zero-filled.
  """
  cur_index0 = jnp.array(0)
  # Only beam 0 starts active. All beams share the same (all-zero) prefix and
  # cache, so activating more would make the first top-k step select
  # duplicate candidates.
  live_logprobs0 = jnp.tile(
      jnp.array([0.0] + [NEG_INF] * (beam_size - 1)),
      [batch_size, 1])
  finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
  live_seqs0 = jnp.zeros(
      (batch_size, beam_size, max_decode_len), jnp.int32)
  finished_seqs0 = jnp.zeros(
      (batch_size, beam_size, max_decode_len), jnp.int32)
  finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
  # Add a beam dimension to every attention-cache pytree element.
  # jax.tree_util.tree_map replaces the deprecated (now removed) jax.tree_map.
  beam_cache0 = jax.tree_util.tree_map(
      lambda x: add_beam_dim(x, beam_size), cache)
  return BeamState(cur_index=cur_index0,
                   live_logprobs=live_logprobs0,
                   finished_scores=finished_scores0,
                   live_seqs=live_seqs0,
                   finished_seqs=finished_seqs0,
                   finished_flags=finished_flags0,
                   cache=beam_cache0)
163

164

165
# Beam search routine:
166

167

168
def beam_search(inputs,
                cache,
                tokens_to_logits,
                beam_size=4,
                alpha=0.6,
                eos_id=EOS_ID,
                max_decode_len=None):
  """Beam search for transformer machine translation.

  Args:
    inputs: array: [batch_size, length] int32 sequence of tokens. Only its
      shape is read here: batch_size, and max_decode_len when that is None.
    cache: flax attention cache.
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and cache and returning next-token logits and updated cache.
    beam_size: int: number of beams to use in beam search.
    alpha: float: scaling factor for brevity penalty.
    eos_id: int: id of end-of-sentence token for target vocabulary.
    max_decode_len: int: maximum length of decoded translations; defaults to
      inputs.shape[1]. Position 0 of every sequence holds the initial (zero)
      token, so at most max_decode_len - 1 tokens are generated.

  Returns:
     Tuple of:
       [batch_size, beam_size, max_decode_len] top-scoring sequences
       [batch_size, beam_size] beam-search scores.
     Once at least one decoding step has run, beams within a batch item are in
     ascending score order (best beam last). A batch item with no finished
     sequence falls back to its live sequences and raw live log-probabilities.
  """
  # We liberally annotate shape information for clarity below.

  batch_size = inputs.shape[0]
  if max_decode_len is None:
    max_decode_len = inputs.shape[1]
  end_marker = jnp.array(eos_id)

  # initialize beam search state
  beam_search_init_state = beam_init(batch_size,
                                     beam_size,
                                     max_decode_len,
                                     cache)

  def beam_search_loop_cond_fn(state):
    """Beam search loop termination condition."""
    # Have we reached max decoding length?
    not_at_end = (state.cur_index < max_decode_len - 1)

    # Is no further progress in the beam search possible?
    # Get the best possible scores from alive sequences. Live beams are kept in
    # ascending order, so the last slot is the best one. Log-probs can only
    # decrease as tokens are appended, so (for alpha >= 0) dividing by the
    # penalty at max length bounds what any live beam can still score.
    min_brevity_penalty = brevity_penalty(alpha, max_decode_len)
    best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
    # Get the worst scores from finished sequences.
    worst_finished_scores = jnp.min(
        state.finished_scores, axis=1, keepdims=True)
    # Mask out scores from slots without any actual finished sequences.
    worst_finished_scores = jnp.where(
        state.finished_flags, worst_finished_scores, NEG_INF)
    # If no best possible live score is better than current worst finished
    # scores, the search cannot improve the finished set further.
    search_terminated = jnp.all(worst_finished_scores > best_live_scores)

    # If we're not at the max decode length, and the search hasn't terminated,
    # continue looping.
    return not_at_end & (~search_terminated)

  def beam_search_loop_body_fn(state):
    """Beam search loop state update function."""
    # Collect the current position slice along length to feed the fast
    # autoregressive decoder model.  Flatten the beam dimension into batch
    # dimension for feeding into the model.
    # --> [batch * beam, 1]
    flat_ids = flatten_beam_dim(lax.dynamic_slice(
        state.live_seqs,
        (0, 0, state.cur_index),
        (batch_size, beam_size, 1)))
    # Flatten beam dimension into batch to be compatible with model.
    # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
    # (jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.)
    flat_cache = jax.tree_util.tree_map(flatten_beam_dim, state.cache)

    # Call fast-decoder model on current tokens to get next-position logits.
    # --> [batch * beam, vocab]
    flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

    # unflatten beam dimension
    # [batch * beam, vocab] --> [batch, beam, vocab]
    logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
    # Unflatten beam dimension in attention cache arrays
    # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
    new_cache = jax.tree_util.tree_map(
        lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)

    # Gather log probabilities from logits
    candidate_log_probs = jax.nn.log_softmax(logits)
    # Add new logprobs to existing prefix logprobs.
    # --> [batch, beam, vocab]
    log_probs = (candidate_log_probs +
                 jnp.expand_dims(state.live_logprobs, axis=2))

    # We'll need the vocab size, gather it from the log probability dimension.
    vocab_size = log_probs.shape[2]

    # Each item in batch has beam_size * vocab_size candidate sequences.
    # For each item, get the top 2*k candidates with the highest log-
    # probabilities. We gather the top 2*K beams here so that even if the best
    # K sequences reach EOS simultaneously, we have another K sequences
    # remaining to continue the live beam search.
    beams_to_keep = 2 * beam_size
    # Flatten beam and vocab dimensions.
    flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size))
    # Gather the top 2*K scores from _all_ beams.
    # --> [batch, 2*beams], [batch, 2*beams]
    topk_log_probs, topk_indices = lax.top_k(flat_log_probs, k=beams_to_keep)
    # Recover the beam index by floor division.
    topk_beam_indices = topk_indices // vocab_size
    # Gather 2*k top beams.
    # --> [batch, 2*beams, length]
    topk_seq = gather_beams(state.live_seqs,
                            topk_beam_indices,
                            batch_size, beams_to_keep)

    # Append the most probable 2*K token IDs to the top 2*K sequences
    # Recover token id by modulo division and expand Id array for broadcasting.
    # --> [batch, 2*beams, 1]
    topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
    # Update sequences for the 2*K top-k new sequences.
    # --> [batch, 2*beams, length]
    topk_seq = lax.dynamic_update_slice(
        topk_seq, topk_ids, (0, 0, state.cur_index + 1))

    # Update LIVE (in-progress) sequences:
    # Did any of these sequences reach an end marker?
    # --> [batch, 2*beams]
    newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
    # To prevent these newly finished sequences from being added to the LIVE
    # set of active beam search sequences, set their log probs to a very large
    # negative value.
    new_log_probs = topk_log_probs + newly_finished * NEG_INF
    # Determine the top k beam indices (from top 2*k beams) from log probs.
    # Flip to ascending order so the best live beam sits in the last slot,
    # which the termination check above relies on.
    # --> [batch, beams]
    _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
    new_topk_indices = jnp.flip(new_topk_indices, axis=1)
    # Gather the top k beams (from top 2*k beams).
    # --> [batch, beams, length], [batch, beams]
    top_alive_seq, top_alive_log_probs = gather_beams(
        [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size)

    # Determine the top k beam indices from the original set of all beams.
    # --> [batch, beams]
    top_alive_indices = gather_beams(
        topk_beam_indices, new_topk_indices, batch_size, beam_size)
    # With these, gather the top k beam-associated caches.
    # --> {[batch, beams, ...], ...}
    top_alive_cache = gather_beams(
        new_cache, top_alive_indices, batch_size, beam_size)

    # Update FINISHED (reached end of sentence) sequences:
    # Calculate new seq scores from log probabilities.
    new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1)
    # Mask out the still unfinished sequences by adding large negative value.
    # --> [batch, 2*beams]
    new_scores += (~newly_finished) * NEG_INF

    # Combine sequences, scores, and flags along the beam dimension and compare
    # new finished sequence scores to existing finished scores and select the
    # best from the new set of beams.
    finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
        [state.finished_seqs, topk_seq], axis=1)
    finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_scores, new_scores], axis=1)
    finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
        [state.finished_flags, newly_finished], axis=1)
    # --> [batch, beams, length], [batch, beams], [batch, beams]
    top_finished_seq, top_finished_scores, top_finished_flags = (
        gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                          finished_scores, batch_size, beam_size))

    return BeamState(cur_index=state.cur_index + 1,
                     live_logprobs=top_alive_log_probs,
                     finished_scores=top_finished_scores,
                     live_seqs=top_alive_seq,
                     finished_seqs=top_finished_seq,
                     finished_flags=top_finished_flags,
                     cache=top_alive_cache)

  # Run while loop and get final beam search state.
  final_state = lax.while_loop(beam_search_loop_cond_fn,
                               beam_search_loop_body_fn,
                               beam_search_init_state)

  # Account for the edge-case where there are no finished sequences for a
  # particular batch item. If so, return live sequences for that batch item.
  # any_finished[b] is True when batch item b has at least one real finished
  # sequence (this was previously, misleadingly, named `none_finished`).
  # --> [batch]
  any_finished = jnp.any(final_state.finished_flags, axis=1)
  # --> [batch, beams, length]
  finished_seqs = jnp.where(any_finished[:, None, None],
                            final_state.finished_seqs,
                            final_state.live_seqs)
  # --> [batch, beams]
  finished_scores = jnp.where(any_finished[:, None],
                              final_state.finished_scores,
                              final_state.live_logprobs)

  return finished_seqs, finished_scores
366

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.