# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contextual T5X Models."""

import functools
from typing import Any, Mapping, MutableMapping, Optional, Tuple

from absl import logging
import jax
import jax.numpy as jnp
from t5x import models

from kl_guided_sampling import decoding
from kl_guided_sampling import feature_converters

PyTree = Any


class ContextualEncoderDecoderModel(models.EncoderDecoderModel):
  """Wrapper class for the models.Transformer nn.module."""

  FEATURE_CONVERTER_CLS = feature_converters.ContextualEncDecFeatureConverter

  def __init__(
      self,
      *args,
      decode_fn=decoding.temperature_sample,
      **kwargs,
  ):
    super().__init__(
        *args,
        decode_fn=decode_fn,
        **kwargs,
    )

  def predict_batch_with_aux(
      self,
      params,
      batch,
      rng=None,
      decoder_params=None,
      return_all_decodes=False,
      num_decodes=1,
      prompt_with_targets=False,
  ):
    """Predict with fast decoding beam search on a batch.

    For ContextualEncoderDecoderModel, running two decoding sequences in
    parallel is decoupled into two copies of the encoder, with the rest of
    the complexity hidden inside the sampling algorithm. The two inputs
    "encoder_input_tokens" and "encoder_input_tokens_wo" are fed to the two
    encoder copies (each encapsulated in its own tokens_ids_to_logits
    closure), respectively. The decoders are kept intact, except that they
    attend to different encoders. The temperature sampling parameters inputs,
    cache, and initial_index keep their standard form and are simply
    duplicated for the second stream, since differences in the inputs only
    affect the encoders, not the decoders.

    Args:
      params: model parameters.
      batch: a batch of inputs.
      rng: an optional RNG key to use during prediction, which is passed as
        'decode_rng' to the decoding function.
      decoder_params: additional (model-independent) parameters for the
        decoder.
      return_all_decodes: whether to return the entire beam or just the top-1.
      num_decodes: the number of beams to use in beam search.
      prompt_with_targets: whether to force-decode decoder_inputs.

    Returns:
      A tuple containing:
        the batch of predictions, with the entire beam if requested
        an auxiliary dictionary of decoder scores
    """
    # [batch, input_len]
    encoder_input_tokens = batch['encoder_input_tokens']
    encoder_input_tokens_wo = batch['encoder_input_tokens_wo']
    decoder_input_tokens = batch['decoder_input_tokens']

    # Prepare transformer fast-decoder call for beam search: for beam search,
    # we need to set up our decoder model to handle a batch size equal to
    # batch_size * num_decodes, where each batch item's data is expanded
    # in-place rather than tiled,
    # i.e. if we denote each batch element subtensor as el[n]:
    # [el0, el1, el2] --> beamsize=2 --> [el0, el0, el1, el1, el2, el2]
    # [batch * num_decodes, input_len, emb_dim]
    encoded_inputs = decoding.flat_batch_beam_expand(
        self.module.apply(
            {'params': params},
            encoder_input_tokens,
            enable_dropout=False,
            method=self.module.encode,
        ),
        num_decodes,
    )
    encoded_inputs_wo = decoding.flat_batch_beam_expand(
        self.module.apply(
            {'params': params},
            encoder_input_tokens_wo,
            enable_dropout=False,
            method=self.module.encode,
        ),
        num_decodes,
    )

    # `decoder_prompt_inputs` is initialized from the batch's
    # `decoder_input_tokens`. Tokens matching `output_vocabulary.eos_id` are
    # zeroed out, so that decoding does not stop right after the prompt.
    # These inputs are ignored by the beam search decode fn.
    if prompt_with_targets:
      decoder_prompt_inputs = decoder_input_tokens
      decoder_prompt_inputs = decoder_prompt_inputs * (
          decoder_prompt_inputs != self.output_vocabulary.eos_id
      )
    else:
      decoder_prompt_inputs = jnp.zeros_like(decoder_input_tokens)
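
    # A toy illustration (values assumed, not from a real batch) of the EOS
    # stripping above: with eos_id == 1,
    #   [5, 7, 1, 0] * ([5, 7, 1, 0] != 1)  ->  [5, 7, 0, 0]
    # so sampling can continue past the prompt instead of halting at EOS.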

    # Prepare the autoregressive cache for the stream with contexts.
    cache, initial_index = self._compute_kv_cache(
        params,
        encoded_inputs=encoded_inputs,
        encoder_input_tokens=encoder_input_tokens,
        decoder_input_tokens=decoder_prompt_inputs,
    )
    # Prepare the autoregressive cache for the stream without contexts.
    cache_wo, initial_index_wo = self._compute_kv_cache(
        params,
        encoded_inputs=encoded_inputs_wo,
        encoder_input_tokens=encoder_input_tokens_wo,
        decoder_input_tokens=decoder_prompt_inputs,
    )

    # [batch * num_decodes, input_len]
    raw_inputs = decoding.flat_batch_beam_expand(
        encoder_input_tokens, num_decodes
    )
    raw_inputs_wo = decoding.flat_batch_beam_expand(
        encoder_input_tokens_wo, num_decodes
    )

    tokens_ids_to_logits = functools.partial(
        self._compute_logits_from_slice,
        params=params,
        encoded_inputs=encoded_inputs,
        raw_inputs=raw_inputs,
        max_decode_length=decoder_input_tokens.shape[1],
    )
    tokens_ids_to_logits_wo = functools.partial(
        self._compute_logits_from_slice,
        params=params,
        encoded_inputs=encoded_inputs_wo,
        raw_inputs=raw_inputs_wo,
        max_decode_length=decoder_input_tokens.shape[1],
    )

    if decoder_params is None:
      decoder_params = {}
    if initial_index is not None:
      # We only set initial_index when it's non-None since it is not supported
      # by all decoders.
      decoder_params['initial_index'] = initial_index
    if initial_index_wo is not None:
      decoder_params['initial_index_wo'] = initial_index_wo

    if rng is not None:
      if decoder_params.get('decode_rng') is not None:
        raise ValueError(
            f'Got RNG both from the `rng` argument ({rng}) and '
            f"`decoder_params['decode_rng']` ({decoder_params['decode_rng']}). "
            'Please specify one or the other.')
      decoder_params['decode_rng'] = rng

    # TODO(hwchung): rename the returned value names to more generic ones.
    # Using the above-defined single-step decoder function, run a
    # beam search over possible sequences given input encoding.
    # decodes: [batch, num_decodes, max_decode_len + 1]
    # scores: [batch, num_decodes]
    scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers

    if 'eos_id' not in decoder_params:
      decoder_params['eos_id'] = self.output_vocabulary.eos_id
    decodes, scores = self._decode_fn(
        inputs=decoder_prompt_inputs,
        inputs_wo=decoder_prompt_inputs,
        cache=cache,
        cache_wo=cache_wo,
        tokens_to_logits=tokens_ids_to_logits,
        tokens_to_logits_wo=tokens_ids_to_logits_wo,
        num_decodes=num_decodes,
        cache_offset=1 if scanned else 0,
        cache_offset_wo=1 if scanned else 0,
        **decoder_params)
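
    # The `_decode_fn` call above assumes the contextual
    # `decoding.temperature_sample` from kl_guided_sampling, which takes the
    # with-context stream (`cache`, `tokens_to_logits`) and the
    # without-context stream (`cache_wo`, `tokens_to_logits_wo`) so that the
    # two sequences can be decoded in lockstep and their logits contrasted.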

    # Beam search returns [n_batch, n_beam, n_length] with the beam dimension
    # sorted in increasing order of log-probability.
    # Return the highest scoring beam sequence.
    if return_all_decodes:
      return decodes, {'scores': scores}
    else:
      return decodes[:, -1, :], {'scores': scores[:, -1]}


class ContextualDecoderOnlyModel(models.DecoderOnlyModel):
  """Model class for the decoder-only modules with contexts."""

  FEATURE_CONVERTER_CLS = feature_converters.ContextualPrefixLMFeatureConverter

  def __init__(
      self,
      *args,
      decode_fn=decoding.temperature_sample,
      **kwargs,
  ):
    super().__init__(
        *args,
        decode_fn=decode_fn,
        **kwargs,
    )

  def predict_batch_with_aux(
      self,
      params,
      batch,
      rng=None,
      *,
      return_all_decodes=False,
      num_decodes=1,
      decoder_params=None,
  ):
    """Predict with prefix and contexts.

    For ContextualDecoderOnlyModel, running two decoding sequences in
    parallel involves two copies of the inputs, cache, tokens_to_logits, and
    initial_index, since differences in the inputs affect every step of the
    decoder.

    Args:
      params: model parameters.
      batch: batch element with the model features specified in
        seqio.DecoderFeatureConverter.
      rng: an optional RNG key to use during prediction, which is passed as
        'decode_rng' to the decoding function.
      return_all_decodes: if True, will return all batch_size * num_decodes
        samples from the model as an array of shape [batch_size, num_decodes,
        sequence_length]. Otherwise returns only the most likely samples as an
        array of shape [batch_size, sequence_length].
      num_decodes: number of decoded sequences to be returned.
      decoder_params: additional (model-independent) parameters for the
        decoder.

    Returns:
      sampled_sequences: an array of shape [batch, max_decode_length].
    """
    if 'decoder_causal_attention' not in batch:
      raise ValueError(
          'Batch does not have the right format for text generation: probably '
          'because `task_feature_lengths` passed to the feature converter does '
          'not have both `inputs` and `targets`.'
      )

    # Since `decoder_input_tokens` is shifted to the right and
    # `decoder_causal_attention` has one more 1 than the number of input
    # tokens, this masks out the targets portion of `decoder_input_tokens`.
    inputs = batch['decoder_input_tokens'] * batch['decoder_causal_attention']
    inputs_wo = batch[
        'decoder_input_tokens_wo'] * batch['decoder_causal_attention_wo']
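
    # A toy illustration (values assumed) of the masking above, for a
    # two-token prefix followed by two target tokens:
    #   decoder_input_tokens     = [0, 11, 12, 21, 22]   (shifted right)
    #   decoder_causal_attention = [1,  1,  1,  0,  0]
    #   product                  = [0, 11, 12,  0,  0]   (targets zeroed out)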

    prefilled_cache, initial_index = self._compute_kv_cache(
        params, inputs, batch['decoder_causal_attention'])
    prefilled_cache_wo, initial_index_wo = self._compute_kv_cache(
        params, inputs_wo, batch['decoder_causal_attention_wo'])

    target_shape = batch['decoder_input_tokens'].shape
    max_decode_length = target_shape[1]

    tokens_ids_to_logits = functools.partial(
        self._compute_logits_from_slice,
        params=params,
        max_decode_length=max_decode_length)

    if decoder_params is None:
      decoder_params = {}
    if rng is not None:
      if decoder_params.get('decode_rng') is not None:
        raise ValueError(
            f'Got RNG both from the `rng` argument ({rng}) and '
            f"`decoder_params['decode_rng']` ({decoder_params['decode_rng']}). "
            'Please specify one or the other.')
      decoder_params['decode_rng'] = rng

    # Using the above-defined single-step decoder function, run temperature
    # sampling with the prefix.
    # [batch, max_decode_length]
    scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers

    if 'eos_id' not in decoder_params:
      decoder_params['eos_id'] = self.output_vocabulary.eos_id
    decoded_sequences, scores = self._decode_fn(
        inputs=inputs,
        inputs_wo=inputs_wo,
        cache=prefilled_cache,
        cache_wo=prefilled_cache_wo,
        tokens_to_logits=tokens_ids_to_logits,
        tokens_to_logits_wo=tokens_ids_to_logits,
        num_decodes=num_decodes,
        initial_index=initial_index,
        initial_index_wo=initial_index_wo,
        cache_offset=1 if scanned else 0,
        cache_offset_wo=1 if scanned else 0,
        **decoder_params)
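
    # Both streams share the same `tokens_ids_to_logits` here: for a
    # decoder-only model the step function depends only on `params`, so the
    # with- and without-context streams differ only in their inputs, caches,
    # and initial indices.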

    if not return_all_decodes:
      # Search returns [n_batch, n_beam/decodes, n_length] with the
      # beam/decode dimension sorted in increasing order of log-probability;
      # `scores` is [batch, beam/decode_size]. We take the highest scoring
      # sequence (-1) and its score.
      decoded_sequences = decoded_sequences[:, -1, :]
      aux = {'scores': scores[:, -1]}
    else:
      # We return all samples and scores, rather than just the top ones.
      aux = {'scores': scores}

    return models.remove_prefix(decoded_sequences, initial_index), aux
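

# A minimal, self-contained sketch (not part of the library) of the batch
# expansion used by ContextualEncoderDecoderModel above. It assumes
# `decoding.flat_batch_beam_expand` repeats each batch element `num_decodes`
# times contiguously, as described in predict_batch_with_aux.
if __name__ == '__main__':
  x = jnp.array([[1, 2], [3, 4]])  # [batch=2, input_len=2]
  # [el0, el1] -> num_decodes=2 -> [el0, el0, el1, el1]
  print(decoding.flat_batch_beam_expand(x, 2))
  # Expected output:
  # [[1 2]
  #  [1 2]
  #  [3 4]
  #  [3 4]]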