# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contextual T5X Models."""

import functools
from typing import Any, Mapping, MutableMapping, Optional, Tuple
from absl import logging
import jax
import jax.numpy as jnp
from t5x import models
from kl_guided_sampling import decoding
from kl_guided_sampling import feature_converters

PyTree = Any


class ContextualEncoderDecoderModel(models.EncoderDecoderModel):
  """Wrapper class for the models.Transformer nn.module."""

  FEATURE_CONVERTER_CLS = feature_converters.ContextualEncDecFeatureConverter

  def __init__(
      self,
      *args,
      decode_fn=decoding.temperature_sample,
      **kwargs,
  ):
    super().__init__(
        *args,
        decode_fn=decode_fn,
        **kwargs,
    )

  def predict_batch_with_aux(
      self,
      params,
      batch,
      rng=None,
      decoder_params=None,
      return_all_decodes=False,
      num_decodes=1,
      prompt_with_targets=False,
  ):
    """Predict with fast decoding beam search on a batch.

    For ContextualEncoderDecoderModel, running two decoding sequences in
    parallel is decoupled into two copies of the encoder, with the rest of the
    complexity hidden in the sampling algorithm. The two inputs
    "encoder_input_tokens" and "encoder_input_tokens_wo" are fed to the two
    encoder copies (encapsulated in tokens_ids_to_logits and
    tokens_ids_to_logits_wo, respectively). The decoder is kept intact, apart
    from being connected to the two different encoders. The temperature
    sampling parameters inputs, cache, and initial_index are shared, since
    differences in the inputs only affect the encoders, not the decoder.

    Args:
      params: model parameters.
      batch: a batch of inputs.
      rng: an optional RNG key to use during prediction, which is passed as
        'decode_rng' to the decoding function.
      decoder_params: additional (model-independent) parameters for the
        decoder.
      return_all_decodes: whether to return the entire beam or just the top-1.
      num_decodes: the number of beams to use in beam search.
      prompt_with_targets: whether to force-decode decoder_inputs.

    Returns:
      A tuple containing:
        the batch of predictions, with the entire beam if requested
        an auxiliary dictionary of decoder scores
    """
    # [batch, input_len]
    encoder_input_tokens = batch['encoder_input_tokens']
    encoder_input_tokens_wo = batch['encoder_input_tokens_wo']
    decoder_input_tokens = batch['decoder_input_tokens']

    # Prepare the transformer fast-decoder call for beam search: we need to
    # set up the decoder model to handle a batch size equal to
    # batch_size * num_decodes, where each batch item's data is expanded
    # in-place rather than tiled.
    # i.e. if we denote each batch element subtensor as el[n]:
    # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2]
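    # (This expansion is jnp.repeat-style along the batch axis, not jnp.tile,
    # so the num_decodes copies of each example stay adjacent.)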
    # [batch * num_decodes, input_len, emb_dim]
    encoded_inputs = decoding.flat_batch_beam_expand(
        self.module.apply(
            {'params': params},
            encoder_input_tokens,
            enable_dropout=False,
            method=self.module.encode,
        ),
        num_decodes,
    )
    encoded_inputs_wo = decoding.flat_batch_beam_expand(
        self.module.apply(
            {'params': params},
            encoder_input_tokens_wo,
            enable_dropout=False,
            method=self.module.encode,
        ),
        num_decodes,
    )

    # `decoder_prompt_inputs` is initialized from the batch's
    # `decoder_input_tokens`. The EOS is stripped so that decoding does not
    # stop right after the prompt when a token matches
    # `output_vocabulary.eos_id`. These inputs are ignored by the beam search
    # decode fn.
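    # (E.g., with eos_id == 1, a prompt [12, 34, 1, 0] becomes [12, 34, 0, 0].)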
    if prompt_with_targets:
      decoder_prompt_inputs = decoder_input_tokens
      decoder_prompt_inputs = decoder_prompt_inputs * (
          decoder_prompt_inputs != self.output_vocabulary.eos_id
      )
    else:
      decoder_prompt_inputs = jnp.zeros_like(decoder_input_tokens)

    # Prepare the autoregressive cache for the with-context stream.
    cache, initial_index = self._compute_kv_cache(
        params,
        encoded_inputs=encoded_inputs,
        encoder_input_tokens=encoder_input_tokens,
        decoder_input_tokens=decoder_prompt_inputs,
    )
    # Prepare the autoregressive cache for the without-context stream.
    cache_wo, initial_index_wo = self._compute_kv_cache(
        params,
        encoded_inputs=encoded_inputs_wo,
        encoder_input_tokens=encoder_input_tokens_wo,
        decoder_input_tokens=decoder_prompt_inputs,
    )

    # [batch * num_decodes, input_len]
    raw_inputs = decoding.flat_batch_beam_expand(
        encoder_input_tokens, num_decodes
    )
    raw_inputs_wo = decoding.flat_batch_beam_expand(
        encoder_input_tokens_wo, num_decodes
    )

    tokens_ids_to_logits = functools.partial(
        self._compute_logits_from_slice,
        params=params,
        encoded_inputs=encoded_inputs,
        raw_inputs=raw_inputs,
        max_decode_length=decoder_input_tokens.shape[1],
    )
    tokens_ids_to_logits_wo = functools.partial(
        self._compute_logits_from_slice,
        params=params,
        encoded_inputs=encoded_inputs_wo,
        raw_inputs=raw_inputs_wo,
        max_decode_length=decoder_input_tokens.shape[1],
    )

    if decoder_params is None:
      decoder_params = {}
    if initial_index is not None:
      # We only set initial_index when it's non-None since it is not supported
      # by all decoders.
      decoder_params['initial_index'] = initial_index
    if initial_index_wo is not None:
      decoder_params['initial_index_wo'] = initial_index_wo

    if rng is not None:
      if decoder_params.get('decode_rng') is not None:
        raise ValueError(
            f'Got RNG both from the `rng` argument ({rng}) and '
            f"`decoder_params['decode_rng']` ({decoder_params['decode_rng']}). "
            'Please specify one or the other.')
      decoder_params['decode_rng'] = rng

    # TODO(hwchung): rename the returned value names to more generic ones.
    # Using the above-defined single-step decoder function, run a
    # beam search over possible sequences given input encoding.
    # decodes: [batch, num_decodes, max_decode_len + 1]
    # scores: [batch, num_decodes]
    scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers

    if 'eos_id' not in decoder_params:
      decoder_params['eos_id'] = self.output_vocabulary.eos_id
    decodes, scores = self._decode_fn(
        inputs=decoder_prompt_inputs,
        inputs_wo=decoder_prompt_inputs,
        cache=cache,
        cache_wo=cache_wo,
        tokens_to_logits=tokens_ids_to_logits,
        tokens_to_logits_wo=tokens_ids_to_logits_wo,
        num_decodes=num_decodes,
        cache_offset=1 if scanned else 0,
        cache_offset_wo=1 if scanned else 0,
        **decoder_params)

    # Beam search returns [n_batch, n_beam, n_length] with the beam dimension
    # sorted in increasing order of log-probability.
    # Return the highest scoring beam sequence.
    if return_all_decodes:
      return decodes, {'scores': scores}
    else:
      return decodes[:, -1, :], {'scores': scores[:, -1]}
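
# A minimal usage sketch for the class above (hypothetical setup: `module`,
# `vocab`, `optimizer_def`, `params`, and a `batch` produced by
# ContextualEncDecFeatureConverter are assumed to come from a standard T5X
# configuration); kept as a comment since this file is library code.
#
#   model = ContextualEncoderDecoderModel(
#       module=module,
#       input_vocabulary=vocab,
#       output_vocabulary=vocab,
#       optimizer_def=optimizer_def,
#   )
#   decodes, aux = model.predict_batch_with_aux(
#       params, batch, num_decodes=4, return_all_decodes=True)
#   # decodes: [batch, num_decodes, max_decode_len]
#   # aux['scores']: [batch, num_decodes]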

class ContextualDecoderOnlyModel(models.DecoderOnlyModel):
  """Model class for decoder-only modules with contexts."""

  FEATURE_CONVERTER_CLS = feature_converters.ContextualPrefixLMFeatureConverter

  def __init__(
      self,
      *args,
      decode_fn=decoding.temperature_sample,
      **kwargs,
  ):
    super().__init__(
        *args,
        decode_fn=decode_fn,
        **kwargs,
    )

  def predict_batch_with_aux(
      self,
      params,
      batch,
      rng=None,
      *,
      return_all_decodes=False,
      num_decodes=1,
      decoder_params=None,
  ):
    """Predict with prefix and contexts.

    For ContextualDecoderOnlyModel, running two decoding sequences in parallel
    involves two copies of inputs, cache, tokens_to_logits, and initial_index,
    since differences in the inputs affect every step of the decoder.

    Args:
      params: model parameters.
      batch: batch element with the model features specified in
        seqio.DecoderFeatureConverter.
      rng: an optional RNG key to use during prediction, which is passed as
        'decode_rng' to the decoding function.
      return_all_decodes: if True, will return all batch_size * num_decodes
        samples from the model as an array of shape [batch_size, num_decodes,
        sequence_length]. Otherwise returns only the most likely samples as an
        array of shape [batch_size, sequence_length].
      num_decodes: number of decoded sequences to be returned.
      decoder_params: additional (model-independent) parameters for the
        decoder.

    Returns:
      sampled_sequences: an array of shape [batch, max_decode_length].
    """
    if 'decoder_causal_attention' not in batch:
      raise ValueError(
          'Batch does not have the right format for text generation: probably '
          'because `task_feature_lengths` passed to the feature converter does '
          'not have both `inputs` and `targets`.'
      )

    # Since decoder_input_tokens is shifted to the right and
    # `decoder_causal_attention` has one more 1 than the number of input
    # tokens, this masks out the targets portion of decoder_input_tokens.
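    # (Toy example, for illustration: decoder_input_tokens = [0, 7, 8, 9, 5, 6]
    # with decoder_causal_attention = [1, 1, 1, 1, 0, 0] gives
    # inputs = [0, 7, 8, 9, 0, 0], i.e. only the prompt survives.)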
    inputs = batch['decoder_input_tokens'] * batch['decoder_causal_attention']
    inputs_wo = batch[
        'decoder_input_tokens_wo'] * batch['decoder_causal_attention_wo']

    prefilled_cache, initial_index = self._compute_kv_cache(
        params, inputs, batch['decoder_causal_attention'])
    prefilled_cache_wo, initial_index_wo = self._compute_kv_cache(
        params, inputs_wo, batch['decoder_causal_attention_wo'])

    target_shape = batch['decoder_input_tokens'].shape
    max_decode_length = target_shape[1]

    tokens_ids_to_logits = functools.partial(
        self._compute_logits_from_slice,
        params=params,
        max_decode_length=max_decode_length)

    if decoder_params is None:
      decoder_params = {}
    if rng is not None:
      if decoder_params.get('decode_rng') is not None:
        raise ValueError(
            f'Got RNG both from the `rng` argument ({rng}) and '
            f"`decoder_params['decode_rng']` ({decoder_params['decode_rng']}). "
            'Please specify one or the other.')
      decoder_params['decode_rng'] = rng

    # Using the above-defined single-step decoder function, run temperature
    # sampling with the prefix.
    # [batch, max_decode_length]
    scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers

    if 'eos_id' not in decoder_params:
      decoder_params['eos_id'] = self.output_vocabulary.eos_id
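    # Note: both streams share the same step function (tokens_ids_to_logits)
    # below; the with-context vs. without-context difference lives entirely in
    # the two prefilled caches and initial indices.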
    decoded_sequences, scores = self._decode_fn(
        inputs=inputs,
        inputs_wo=inputs_wo,
        cache=prefilled_cache,
        cache_wo=prefilled_cache_wo,
        tokens_to_logits=tokens_ids_to_logits,
        tokens_to_logits_wo=tokens_ids_to_logits,
        num_decodes=num_decodes,
        initial_index=initial_index,
        initial_index_wo=initial_index_wo,
        cache_offset=1 if scanned else 0,
        cache_offset_wo=1 if scanned else 0,
        **decoder_params)

    if not return_all_decodes:
      # Search returns [n_batch, n_beam/decodes, n_length] with the beam/decode
      # dimension sorted in increasing order of log-probability.
      # `scores` is [batch, beam/decode_size].
      # We take the highest scoring sequence (-1) and its score.
      decoded_sequences = decoded_sequences[:, -1, :]
      aux = {'scores': scores[:, -1]}
    else:
      # We return all samples and scores, rather than just the top ones.
      aux = {'scores': scores}

    return models.remove_prefix(decoded_sequences, initial_index), aux
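
# The two logit streams set up above are combined inside
# kl_guided_sampling.decoding (its temperature_sample accepts the paired *_wo
# arguments wired up in this file); this module only provides the plumbing.
# Purely as a hypothetical illustration of how a sampler might use the
# per-step divergence between the two streams (the names and formula below
# are assumptions, not this project's actual rule):
#
#   def kl_between_streams(logits_w, logits_wo):
#     """KL(p_with_context || p_without_context) for one decoding step."""
#     log_p_w = jax.nn.log_softmax(logits_w)
#     log_p_wo = jax.nn.log_softmax(logits_wo)
#     return jnp.sum(jnp.exp(log_p_w) * (log_p_w - log_p_wo), axis=-1)
#
#   # A large KL means the context strongly reshapes this step's distribution;
#   # a KL-guided sampler could, for example, lower the temperature on such
#   # steps to stay faithful to the context.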