# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer model from "Attention Is All You Need".

The Transformer model consists of an encoder and a decoder. Both are stacks
of self-attention layers followed by feed-forward layers. This model yields
good results on a number of problems, especially in NLP and machine translation.

See "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) for the full
description of the model and the results obtained with its early version.

Branched from Tensor2Tensor implementation:
github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range  # pylint: disable=redefined-builtin

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import modalities
from tensor2tensor.layers import transformer_layers
from tensor2tensor.utils import beam_search
from tensor2tensor.utils import mlperf_log
from tensor2tensor.utils import registry

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator

from state_of_sparsity.sparse_transformer.layers import common_sparse
from state_of_sparsity.sparse_transformer.layers import sparse_attention
from state_of_sparsity.sparse_transformer.layers import sparse_modalities
from state_of_sparsity.sparse_transformer.layers import sparse_transformer_layers
from state_of_sparsity.sparse_transformer.models import sparse_model

from tensorflow.python.ops import inplace_ops  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.util import nest  # pylint: disable=g-direct-tensorflow-import


# Alias some commonly reused layers, here and elsewhere.
transformer_prepare_encoder = transformer_layers.transformer_prepare_encoder
transformer_encoder = sparse_transformer_layers.transformer_encoder
transformer_ffn_layer = sparse_transformer_layers.transformer_ffn_layer


@registry.register_model
class SparseTransformer(sparse_model.SparseModel):
  """Attention net.  See file docstring."""

  def __init__(self, *args, **kwargs):
    super(SparseTransformer, self).__init__(*args, **kwargs)
    self.attention_weights = dict()  # For visualizing attention heads.

  def encode(self, inputs, target_space, hparams, features=None, losses=None):
    """Encode transformer inputs.

    Args:
      inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim] which
        will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparameters for model.
      features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.
      losses: optional list onto which to append extra training losses

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encoder-decoder attention. [batch_size, input_length]
    """
    inputs = common_layers.flatten4d3d(inputs)

    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer_prepare_encoder(
            inputs, target_space, hparams, features=features))

    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
        value=hparams.layer_prepostprocess_dropout)

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    encoder_output = transformer_encoder(
        encoder_input,
        self_attention_bias,
        hparams,
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=self.attention_weights,
        make_image_summary=not common_layers.is_xla_compiled())

    return encoder_output, encoder_decoder_attention_bias

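  # Illustrative note on encode() above (not part of the original file): the
  # 4-D "inputs" are flattened to 3-D before the encoder stack runs. A purely
  # hypothetical shape walk-through, assuming hidden_size=512:
  #   inputs:                          [8, 50, 1, 512]
  #   after common_layers.flatten4d3d: [8, 50, 512]
  #   encoder_output:                  [8, 50, 512]
  #   encoder_decoder_attention_bias:  broadcastable against the attention
  #                                    logits, ~-1e9 at padded input positions.
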
  def decode(self,
             decoder_input,
             encoder_output,
             encoder_decoder_attention_bias,
             decoder_self_attention_bias,
             hparams,
             cache=None,
             decode_loop_step=None,
             losses=None):
    """Decode Transformer outputs from encoder representation.

    Args:
      decoder_input: inputs to bottom of the model.
          [batch_size, decoder_length, hidden_dim]
      encoder_output: Encoder representation.
          [batch_size, input_length, hidden_dim]
      encoder_decoder_attention_bias: Bias and mask weights for
          encoder-decoder attention. [batch_size, input_length]
      decoder_self_attention_bias: Bias and mask weights for decoder
          self-attention. [batch_size, decoder_length]
      hparams: hyperparameters for model.
      cache: dict, containing tensors which are the results of previous
          attentions, used for fast decoding.
      decode_loop_step: An integer, step number of the decoding loop.
          Only used for inference on TPU.
      losses: optional list onto which to append extra training losses

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
        value=hparams.layer_prepostprocess_dropout)
    decoder_input = tf.nn.dropout(decoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    decoder_output = transformer_decoder(
        decoder_input,
        encoder_output,
        decoder_self_attention_bias,
        encoder_decoder_attention_bias,
        hparams,
        cache=cache,
        decode_loop_step=decode_loop_step,
        save_weights_to=self.attention_weights,
        losses=losses)

    if (common_layers.is_xla_compiled() and
        hparams.mode == tf_estimator.ModeKeys.TRAIN):
      return decoder_output
    else:
      # Expand since t2t expects 4d tensors.
      return tf.expand_dims(decoder_output, axis=2)

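  # Illustrative note on decode() above (not part of the original file): during
  # training the full target sequence is decoded in one pass, while during fast
  # inference decode() is called once per step with a [batch, 1, hidden] slice
  # and a `cache` dict holding the keys/values already computed for earlier
  # positions (see _fast_decode / fast_decode below).
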
  def body(self, features):
    """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs.
              [batch_size, input_length, 1, hidden_dim].
          "targets": Target decoder outputs.
              [batch_size, decoder_length, 1, hidden_dim]
          "target_space_id": A scalar int from data_generators.problem.SpaceID.

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
    hparams = self._hparams

    losses = []

    if self.has_input:
      inputs = features["inputs"]
      target_space = features["target_space_id"]
      encoder_output, encoder_decoder_attention_bias = self.encode(
          inputs, target_space, hparams, features=features, losses=losses)
    else:
      encoder_output, encoder_decoder_attention_bias = (None, None)

    targets = features["targets"]
    targets_shape = common_layers.shape_list(targets)
    targets = common_layers.flatten4d3d(targets)
    decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
        targets, hparams, features=features)
    decoder_output = self.decode(
        decoder_input,
        encoder_output,
        encoder_decoder_attention_bias,
        decoder_self_attention_bias,
        hparams,
        losses=losses)

    sparsity_technique = hparams.get("sparsity_technique")
    expected_attentions = features.get("expected_attentions")
    if expected_attentions is not None:
      assert not sparsity_technique

      attention_loss = common_attention.encoder_decoder_attention_loss(
          expected_attentions, self.attention_weights,
          hparams.expected_attention_loss_type,
          hparams.expected_attention_loss_multiplier)
      return decoder_output, {"attention_loss": attention_loss}

    # Add the extra loss term needed for each sparsity technique
    if sparsity_technique == "variational_dropout":
      losses += common_sparse.variational_dropout_dkl_loss(
          sparsity_check=True,
          threshold=hparams.get("log_alpha_threshold"),
          dkl_weight=hparams.get("dkl_weight"),
          begin_step=hparams.get("dkl_weight_start"),
          end_step=(hparams.get("dkl_weight_start") +
                    hparams.get("dkl_weight_diff")),
          weight_function=hparams.get("dkl_weight_fn"),
          clip_alpha=hparams.get("clip_log_alpha"))
    elif sparsity_technique == "l0_regularization":
      losses += common_sparse.l0_regularization_term(
          sparsity_check=True,
          regularization_weight=hparams.get("l0_norm_weight"),
          weight_start=hparams.get("l0_weight_start"),
          weight_end=(hparams.get("l0_weight_start") +
                      hparams.get("l0_weight_diff")),
          weight_function=hparams.get("dkl_weight_fn"))

    ret = tf.reshape(decoder_output, targets_shape)
    if losses:
      return ret, {"extra_loss": tf.add_n(losses)}
    else:
      return ret

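  # Illustrative note on body() above (not part of the original file): the
  # sparsity regularizers are returned through T2T's "extra_loss" mechanism,
  # so they are added to the training objective alongside the cross-entropy
  # loss. Their weights appear to be ramped between the *_weight_start and
  # *_weight_start + *_weight_diff steps according to the configured weight
  # function; see common_sparse for the exact schedule.
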
  def _greedy_infer(self, features, decode_length, use_tpu=False):
    """Fast version of greedy decoding.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      use_tpu: A bool. Whether to build the inference graph for TPU.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    # For real-valued modalities use the slow decode path for now.
    if (self._target_modality_is_real or
        self._hparams.self_attention_type != "dot_product"):
      return super(SparseTransformer, self)._greedy_infer(
          features, decode_length)
    with tf.variable_scope(self.name):
      return (self._fast_decode_tpu(features, decode_length) if use_tpu else
              self._fast_decode(features, decode_length))

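  # Illustrative note on _greedy_infer() above (not part of the original file):
  # the fast, cache-based path is only taken for "dot_product" self-attention
  # with discrete targets; anything else falls back to the slower generic
  # implementation inherited from the base model.
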
  def _beam_decode(self,
                   features,
                   decode_length,
                   beam_size,
                   top_beams,
                   alpha,
                   use_tpu=False):
    """Beam search decoding.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.
      use_tpu: A bool, whether to do beam decode on TPU.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }
    """
    if self._hparams.self_attention_type != "dot_product":
      # Caching is not guaranteed to work with attention types other than
      # dot_product.
      return self._beam_decode_slow(features, decode_length, beam_size,
                                    top_beams, alpha, use_tpu)
    with tf.variable_scope(self.name):
      if use_tpu:
        return self._fast_decode_tpu(
            features, decode_length, beam_size, top_beams, alpha)
      else:
        return self._fast_decode(
            features, decode_length, beam_size, top_beams, alpha)

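  # Illustrative note on _beam_decode() above (not part of the original file):
  # alpha feeds the length penalty inside beam_search. In tensor2tensor this is
  # typically ((5 + length) / 6) ** alpha, so alpha=0 disables length
  # normalization and larger alpha increasingly favors longer hypotheses.
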
  def _fast_decode_tpu(self,
312
                       features,
313
                       decode_length,
314
                       beam_size=1,
315
                       top_beams=1,
316
                       alpha=1.0):
317
    """Fast decoding.
318

319
    Implements both greedy and beam search decoding on TPU, uses beam search
320
    iff beam_size > 1, otherwise beam search related arguments are ignored.
321

322
    Args:
323
      features: A map of string to model features.
324
      decode_length: An integer, how many additional timesteps to decode.
325
      beam_size: An integer, number of beams.
326
      top_beams: An integer, how many of the beams to return.
327
      alpha: A float that controls the length penalty. Larger the alpha,
328
        stronger the preference for longer translations.
329

330
    Returns:
331
      A dict of decoding results {
332
          "outputs": integer `Tensor` of decoded ids of shape
333
              [batch_size, <= decode_length] if beam_size == 1 or
334
              [batch_size, top_beams, <= decode_length]
335
          "scores": decoding log probs from the beam search,
336
              None if using greedy decoding (beam_size=1)
337
      }.
338

339
    Raises:
340
      NotImplementedError: If there are multiple data shards.
341
    """
342
    if self._num_datashards != 1:
343
      raise NotImplementedError("Fast decoding only supports a single shard.")
344
    if "targets_segmentation" in features:
345
      raise NotImplementedError(
346
          "Decoding not supported on packed datasets "
347
          " If you want to decode from a dataset, use the non-packed version"
348
          " of the dataset when decoding.")
349
    dp = self._data_parallelism
350
    hparams = self._hparams
351
    target_modality = self._problem_hparams.modality["targets"]
352
    target_vocab_size = self._problem_hparams.vocab_size["targets"]
353
    if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
354
      target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor
355

356
    if self.has_input:
357
      inputs = features["inputs"]
358
      if target_modality == modalities.ModalityType.CLASS_LABEL:
359
        decode_length = 1
360
      else:
361
        decode_length = (
362
            common_layers.shape_list(inputs)[1] + features.get(
363
                "decode_length", decode_length))
364

365
      inputs = tf.expand_dims(inputs, axis=1)
366
      if len(inputs.shape) < 5:
367
        inputs = tf.expand_dims(inputs, axis=4)
368
      s = common_layers.shape_list(inputs)
369
      batch_size = s[0]
370
      inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
371
      # _shard_features called to ensure that the variable names match
372
      inputs = self._shard_features({"inputs": inputs})["inputs"]
373
      input_modality = self._problem_hparams.modality["inputs"]
374
      input_vocab_size = self._problem_hparams.vocab_size["inputs"]
375
      if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
376
        input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
377
      modality_name = hparams.name.get(
378
          "inputs",
379
          modalities.get_name(input_modality))(hparams, input_vocab_size)
380
      with tf.variable_scope(modality_name):
381
        bottom = hparams.bottom.get(
382
            "inputs", modalities.get_bottom(input_modality))
383
        inputs = dp(bottom, inputs, hparams, input_vocab_size)
384
      with tf.variable_scope("body"):
385
        encoder_output, encoder_decoder_attention_bias = dp(
386
            self.encode,
387
            inputs,
388
            features["target_space_id"],
389
            hparams,
390
            features=features)
391
      encoder_output = encoder_output[0]
392
      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
393
      partial_targets = None
394
    else:
395
      # The problem has no inputs.
396
      encoder_output = None
397
      encoder_decoder_attention_bias = None
398

399
      # Prepare partial targets.
400
      # In either features["inputs"] or features["targets"].
401
      # We force the outputs to begin with these sequences.
402
      partial_targets = features.get("inputs")
403
      if partial_targets is None:
404
        partial_targets = features["targets"]
405
      assert partial_targets is not None
406
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
407
      partial_targets = tf.to_int64(partial_targets)
408
      partial_targets_shape = common_layers.shape_list(partial_targets)
409
      partial_targets_length = partial_targets_shape[1]
410
      decode_length = (
411
          partial_targets_length + features.get("decode_length", decode_length))
412
      batch_size = partial_targets_shape[0]
413

414
    if hparams.pos == "timing":
415
      positional_encoding = common_attention.get_timing_signal_1d(
416
          decode_length + 1, hparams.hidden_size)
417
    elif hparams.pos == "emb":
418
      positional_encoding = common_attention.add_positional_embedding(
419
          tf.zeros([1, decode_length + 1, hparams.hidden_size]),
420
          hparams.max_length, "body/targets_positional_embedding", None)
421
    else:
422
      positional_encoding = None
423

424
    def preprocess_targets(targets, i):
425
      """Performs preprocessing steps on the targets to prepare for the decoder.
426

427
      This includes:
428
        - Embedding the ids.
429
        - Flattening to 3D tensor.
430
        - Optionally adding timing signals.
431

432
      Args:
433
        targets: A tensor, inputs ids to the decoder. [batch_size, 1].
434
        i: An integer, Step number of the decoding loop.
435

436
      Returns:
437
        A tensor, processed targets [batch_size, 1, hidden_dim].
438
      """
439
      # _shard_features called to ensure that the variable names match
440
      targets = self._shard_features({"targets": targets})["targets"]
441
      modality_name = hparams.name.get(
442
          "targets",
443
          modalities.get_name(target_modality))(hparams, target_vocab_size)
444
      with tf.variable_scope(modality_name):
445
        bottom = hparams.bottom.get(
446
            "targets", modalities.get_targets_bottom(target_modality))
447
        targets = dp(bottom, targets, hparams, target_vocab_size)[0]
448
      targets = common_layers.flatten4d3d(targets)
449

450
      targets = tf.cond(
451
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)
452

453
      if positional_encoding is not None:
454
        positional_encoding_shape = positional_encoding.shape.as_list()
455
        targets += tf.slice(
456
            positional_encoding, [0, i, 0],
457
            [positional_encoding_shape[0], 1, positional_encoding_shape[2]])
458
      return targets
459

460
    decoder_self_attention_bias = (
461
        common_attention.attention_bias_lower_triangle(decode_length))
462
    if hparams.proximity_bias:
463
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
464
          decode_length)
465

466
    def symbols_to_logits_tpu_fn(ids, i, cache):
467
      """Go from ids to logits for next symbol on TPU.
468

469
      Args:
470
        ids: A tensor, symbol IDs.
471
        i: An integer, step number of the decoding loop. Only used for inference
472
            on TPU.
473
        cache: A dict, containing tensors which are the results of previous
474
            attentions, used for fast decoding.
475

476
      Returns:
477
        ret: A tensor, computed logits.
478
        cache: A dict, containing tensors which are the results of previous
479
            attentions, used for fast decoding.
480
      """
481
      ids = ids[:, -1:]
482
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
483
      targets = preprocess_targets(targets, i)
484

485
      bias_shape = decoder_self_attention_bias.shape.as_list()
486
      bias = tf.slice(decoder_self_attention_bias, [0, 0, i, 0],
487
                      [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
488

489
      with tf.variable_scope("body"):
490
        body_outputs = dp(
491
            self.decode,
492
            targets,
493
            cache.get("encoder_output"),
494
            cache.get("encoder_decoder_attention_bias"),
495
            bias,
496
            hparams,
497
            cache,
498
            i)
499

500
      modality_name = hparams.name.get(
501
          "targets",
502
          modalities.get_name(target_modality))(hparams, target_vocab_size)
503
      with tf.variable_scope(modality_name):
504
        top = hparams.top.get("targets", modalities.get_top(target_modality))
505
        logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]
506

507
      ret = tf.squeeze(logits, axis=[1, 2, 3])
508
      if partial_targets is not None:
509
        # If the position is within the given partial targets, we alter the
510
        # logits to always return those values.
511
        # A faster approach would be to process the partial targets in one
512
        # iteration in order to fill the corresponding parts of the cache.
513
        # This would require broader changes, though.
514
        vocab_size = tf.shape(ret)[1]
515

516
        def forced_logits():
517
          return tf.one_hot(
518
              tf.tile(
519
                  tf.slice(partial_targets, [0, i],
520
                           [partial_targets.shape.as_list()[0], 1]),
521
                  [beam_size]), vocab_size, 0.0, -1e9)
522

523
        ret = tf.cond(
524
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
525
      return ret, cache
526

527
    ret = fast_decode_tpu(
528
        encoder_output=encoder_output,
529
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
530
        symbols_to_logits_fn=symbols_to_logits_tpu_fn,
531
        hparams=hparams,
532
        decode_length=decode_length,
533
        vocab_size=target_vocab_size,
534
        beam_size=beam_size,
535
        top_beams=top_beams,
536
        alpha=alpha,
537
        batch_size=batch_size,
538
        force_decode_length=self._decode_hparams.force_decode_length)
539
    if partial_targets is not None:
540
      if beam_size <= 1 or top_beams <= 1:
541
        ret["outputs"] = ret["outputs"][:, partial_targets_length:]
542
      else:
543
        ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
544
    return ret
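
  # Illustrative note on _fast_decode_tpu() above (not part of the original
  # file): TPU decoding keeps every tensor shape static, so the positional
  # encoding and the causal bias are read with tf.slice at the loop step, and
  # when partial targets are given the logits are overwritten with a one-hot
  # (0 for the forced id, -1e9 elsewhere) so the search reproduces the prefix.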
545

546
  def _fast_decode(self,
547
                   features,
548
                   decode_length,
549
                   beam_size=1,
550
                   top_beams=1,
551
                   alpha=1.0):
552
    """Fast decoding.
553

554
    Implements both greedy and beam search decoding, uses beam search iff
555
    beam_size > 1, otherwise beam search related arguments are ignored.
556

557
    Args:
      features: a map of string to model features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.
564

565
    Returns:
566
      A dict of decoding results {
567
          "outputs": integer `Tensor` of decoded ids of shape
568
              [batch_size, <= decode_length] if beam_size == 1 or
569
              [batch_size, top_beams, <= decode_length]
570
          "scores": decoding log probs from the beam search,
571
              None if using greedy decoding (beam_size=1)
572
      }
573

574
    Raises:
575
      NotImplementedError: If there are multiple data shards.
576
    """
577
    if self._num_datashards != 1:
578
      raise NotImplementedError("Fast decoding only supports a single shard.")
579
    dp = self._data_parallelism
580
    hparams = self._hparams
581
    target_modality = self._problem_hparams.modality["targets"]
582
    target_vocab_size = self._problem_hparams.vocab_size["targets"]
583
    if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
584
      target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor
585
    if "targets_segmentation" in features:
586
      raise NotImplementedError(
587
          "Decoding not supported on packed datasets "
588
          " If you want to decode from a dataset, use the non-packed version"
589
          " of the dataset when decoding.")
590
    if self.has_input:
591
      inputs = features["inputs"]
592
      if target_modality == modalities.ModalityType.CLASS_LABEL:
593
        decode_length = 1
594
      else:
595
        decode_length = (
596
            common_layers.shape_list(inputs)[1] + features.get(
597
                "decode_length", decode_length))
598

599
      inputs = tf.expand_dims(inputs, axis=1)
600
      if len(inputs.shape) < 5:
601
        inputs = tf.expand_dims(inputs, axis=4)
602
      s = common_layers.shape_list(inputs)
603
      batch_size = s[0]
604
      inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
605
      # _shard_features called to ensure that the variable names match
606
      inputs = self._shard_features({"inputs": inputs})["inputs"]
607
      input_modality = self._problem_hparams.modality["inputs"]
608
      input_vocab_size = self._problem_hparams.vocab_size["inputs"]
609
      if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
610
        input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
611
      modality_name = hparams.name.get(
612
          "inputs",
613
          modalities.get_name(input_modality))(hparams, input_vocab_size)
614
      with tf.variable_scope(modality_name):
615
        bottom = hparams.bottom.get(
616
            "inputs", modalities.get_bottom(input_modality))
617
        inputs = dp(bottom, inputs, hparams, input_vocab_size)
618
      with tf.variable_scope("body"):
619
        encoder_output, encoder_decoder_attention_bias = dp(
620
            self.encode,
621
            inputs,
622
            features["target_space_id"],
623
            hparams,
624
            features=features)
625
      encoder_output = encoder_output[0]
626
      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
627
      partial_targets = None
628
    else:
629
      # The problem has no inputs.
630
      encoder_output = None
631
      encoder_decoder_attention_bias = None
632

633
      # Prepare partial targets.
634
      # In either features["inputs"] or features["targets"].
635
      # We force the outputs to begin with these sequences.
636
      partial_targets = features.get("inputs")
637
      if partial_targets is None:
638
        partial_targets = features["targets"]
639
      assert partial_targets is not None
640
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
641
      partial_targets = tf.to_int64(partial_targets)
642
      partial_targets_shape = common_layers.shape_list(partial_targets)
643
      partial_targets_length = partial_targets_shape[1]
644
      decode_length = (
645
          partial_targets_length + features.get("decode_length", decode_length))
646
      batch_size = partial_targets_shape[0]
647

648
    if hparams.pos == "timing":
649
      positional_encoding = common_attention.get_timing_signal_1d(
650
          decode_length + 1, hparams.hidden_size)
651
    elif hparams.pos == "emb":
652
      positional_encoding = common_attention.add_positional_embedding(
653
          tf.zeros([1, decode_length, hparams.hidden_size]),
654
          hparams.max_length, "body/targets_positional_embedding", None)
655
    else:
656
      positional_encoding = None
657

658
    def preprocess_targets(targets, i):
659
      """Performs preprocessing steps on the targets to prepare for the decoder.
660

661
      This includes:
662
        - Embedding the ids.
663
        - Flattening to 3D tensor.
664
        - Optionally adding timing signals.
665

666
      Args:
667
        targets: inputs ids to the decoder. [batch_size, 1]
668
        i: scalar, Step number of the decoding loop.
669

670
      Returns:
671
        Processed targets [batch_size, 1, hidden_dim]
672
      """
673
      # _shard_features called to ensure that the variable names match
674
      targets = self._shard_features({"targets": targets})["targets"]
675
      modality_name = hparams.name.get(
676
          "targets",
677
          modalities.get_name(target_modality))(hparams, target_vocab_size)
678
      with tf.variable_scope(modality_name):
679
        bottom = hparams.bottom.get(
680
            "targets", modalities.get_targets_bottom(target_modality))
681
        targets = dp(bottom, targets, hparams, target_vocab_size)[0]
682
      targets = common_layers.flatten4d3d(targets)
683

684
      targets = tf.cond(
685
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)
686

687
      if positional_encoding is not None:
688
        targets += positional_encoding[:, i:i + 1]
689
      return targets
690

691
    decoder_self_attention_bias = (
692
        common_attention.attention_bias_lower_triangle(decode_length))
693
    if hparams.proximity_bias:
694
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
695
          decode_length)
696

697
    def symbols_to_logits_fn(ids, i, cache):
698
      """Go from ids to logits for next symbol."""
699
      ids = ids[:, -1:]
700
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
701
      targets = preprocess_targets(targets, i)
702

703
      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
704

705
      with tf.variable_scope("body"):
706
        body_outputs = dp(
707
            self.decode,
708
            targets,
709
            cache.get("encoder_output"),
710
            cache.get("encoder_decoder_attention_bias"),
711
            bias,
712
            hparams,
713
            cache)
714

715
      modality_name = hparams.name.get(
716
          "targets",
717
          modalities.get_name(target_modality))(hparams, target_vocab_size)
718
      with tf.variable_scope(modality_name):
719
        top = hparams.top.get("targets", modalities.get_top(target_modality))
720
        logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]
721

722
      ret = tf.squeeze(logits, axis=[1, 2, 3])
723
      if partial_targets is not None:
724
        # If the position is within the given partial targets, we alter the
725
        # logits to always return those values.
726
        # A faster approach would be to process the partial targets in one
727
        # iteration in order to fill the corresponding parts of the cache.
728
        # This would require broader changes, though.
729
        vocab_size = tf.shape(ret)[1]
730

731
        def forced_logits():
732
          return tf.one_hot(
733
              tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
734
              -1e9)
735

736
        ret = tf.cond(
737
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
738
      return ret, cache
739

740
    ret = fast_decode(
741
        encoder_output=encoder_output,
742
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
743
        symbols_to_logits_fn=symbols_to_logits_fn,
744
        hparams=hparams,
745
        decode_length=decode_length,
746
        vocab_size=target_vocab_size,
747
        beam_size=beam_size,
748
        top_beams=top_beams,
749
        alpha=alpha,
750
        batch_size=batch_size,
751
        force_decode_length=self._decode_hparams.force_decode_length)
752
    if partial_targets is not None:
753
      if beam_size <= 1 or top_beams <= 1:
754
        ret["outputs"] = ret["outputs"][:, partial_targets_length:]
755
      else:
756
        ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
757
    return ret
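
# A minimal usage sketch (not part of the original file), assuming the standard
# tensor2tensor registry behaviour: the @registry.register_model decorator on
# SparseTransformer makes the class retrievable under its snake_cased name.
def _example_registry_lookup():
  """Hedged illustration only; the registered name is assumed here."""
  return registry.model("sparse_transformer")
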
758

759

760
def fast_decode_tpu(encoder_output,
761
                    encoder_decoder_attention_bias,
762
                    symbols_to_logits_fn,
763
                    hparams,
764
                    decode_length,
765
                    vocab_size,
766
                    beam_size=1,
767
                    top_beams=1,
768
                    alpha=1.0,
769
                    sos_id=0,
770
                    eos_id=beam_search.EOS_ID,
771
                    batch_size=None,
772
                    force_decode_length=False,
773
                    scope_prefix="body/"):
774
  """Given encoder output and a symbols to logits function, does fast decoding.
775

776
  Implements both greedy and beam search decoding for TPU, uses beam search iff
777
  beam_size > 1, otherwise beam search related arguments are ignored.
778

779
  Args:
780
    encoder_output: A tensor, output from encoder.
781
    encoder_decoder_attention_bias: A tensor, bias for use in encoder-decoder
782
        attention.
783
    symbols_to_logits_fn: Incremental decoding, function mapping triple
784
        `(ids, step, cache)` to symbol logits.
785
    hparams: Run hyperparameters.
786
    decode_length: An integer, how many additional timesteps to decode.
787
    vocab_size: Output vocabulary size.
788
    beam_size: An integer, number of beams.
789
    top_beams: An integer, how many of the beams to return.
790
    alpha: A float that controls the length penalty. Larger the alpha, stronger
791
      the preference for longer translations.
792
    sos_id: Start-of-sequence symbol.
793
    eos_id: End-of-sequence symbol.
794
    batch_size: An integer, must be passed if there is no input.
795
    force_decode_length: A bool, whether to force the full decode length, or if
796
        False, stop when all beams hit eos_id.
797
    scope_prefix: str, prefix for decoder layer variable scopes.
798

799
  Returns:
800
    A dict of decoding results {
801
        "outputs": integer `Tensor` of decoded ids of shape
802
            [batch_size, <= decode_length] if top_beams == 1 or
803
            [batch_size, top_beams, <= decode_length] otherwise
804
        "scores": decoding log probs from the beam search,
805
            None if using greedy decoding (beam_size=1)
806
    }.
807

808
  Raises:
809
    NotImplementedError: If beam size > 1 with partial targets.
810
  """
811
  if encoder_output is not None:
812
    batch_size = common_layers.shape_list(encoder_output)[0]
813

814
  key_channels = hparams.attention_key_channels or hparams.hidden_size
815
  value_channels = hparams.attention_value_channels or hparams.hidden_size
816
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
817
  vars_3d_num_heads = (
818
      hparams.num_heads if hparams.get("attention_variables_3d") else 0)
819

820
  cache = {
821
      "layer_%d" % layer: {  # pylint: disable=g-complex-comprehension
822
          "k":
823
          common_attention.split_heads(
824
              tf.zeros([batch_size, decode_length, key_channels]),
825
              hparams.num_heads),
826
          "v":
827
          common_attention.split_heads(
828
              tf.zeros([batch_size, decode_length, value_channels]),
829
              hparams.num_heads),
830
          "f":
831
          tf.zeros([batch_size, decode_length, hparams.hidden_size]),
832
      } for layer in range(num_layers)
833
  }
834

835
  if encoder_output is not None:
836
    for layer in range(num_layers):
837
      layer_name = "layer_%d" % layer
838
      with tf.variable_scope(
839
          "%sdecoder/%s/encdec_attention/multihead_attention" % (scope_prefix,
840
                                                                 layer_name)):
841
        initial_sparsity = None
842
        if hparams.get("load_masks_from"):
843
          initial_sparsity = hparams.get("initial_sparsity")
844

845
        k_encdec = sparse_attention.compute_attention_component(
846
            encoder_output, key_channels, name="k",
847
            vars_3d_num_heads=vars_3d_num_heads,
848
            sparsity_technique=hparams.get("sparsity_technique"),
849
            threshold=hparams.get("log_alpha_threshold"),
850
            training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
851
            clip_alpha=hparams.get("clip_log_alpha"),
852
            initial_sparsity=initial_sparsity,
853
            split_heads=hparams.get("split_heads"),
854
            num_heads=hparams.num_heads)
855
        k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads)
856
        v_encdec = sparse_attention.compute_attention_component(
857
            encoder_output, value_channels, name="v",
858
            vars_3d_num_heads=vars_3d_num_heads,
859
            sparsity_technique=hparams.get("sparsity_technique"),
860
            threshold=hparams.get("log_alpha_threshold"),
861
            training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
862
            clip_alpha=hparams.get("clip_log_alpha"),
863
            initial_sparsity=initial_sparsity,
864
            split_heads=hparams.get("split_heads"),
865
            num_heads=hparams.num_heads)
866
        v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads)
867
      cache[layer_name]["k_encdec"] = k_encdec
868
      cache[layer_name]["v_encdec"] = v_encdec
869

870
    cache["encoder_output"] = encoder_output
871
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
872

873
  mlperf_log.transformer_print(
874
      key=mlperf_log.MODEL_HP_SEQ_BEAM_SEARCH,
875
      value={
876
          "vocab_size": vocab_size,
877
          "batch_size": batch_size,
878
          "beam_size": beam_size,
879
          "alpha": alpha,
880
          "max_decode_length": decode_length
881
      })
882
  if beam_size > 1:  # Beam Search
883
    initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32)
884
    decoded_ids, scores, _ = beam_search.beam_search(
885
        symbols_to_logits_fn,
886
        initial_ids,
887
        beam_size,
888
        decode_length,
889
        vocab_size,
890
        alpha,
891
        states=cache,
892
        eos_id=eos_id,
893
        stop_early=(top_beams == 1),
894
        use_tpu=True)
895

896
    if top_beams == 1:
897
      decoded_ids = decoded_ids[:, 0, 1:]
898
      scores = scores[:, 0]
899
    else:
900
      decoded_ids = decoded_ids[:, :top_beams, 1:]
901
      scores = scores[:, :top_beams]
902
  else:  # Greedy
903
    def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
904
      """One step of greedy decoding."""
905
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
906
      log_probs = common_layers.log_prob_from_logits(logits)
907
      temperature = (0.0 if hparams.sampling_method == "argmax" else
908
                     hparams.sampling_temp)
909
      next_id = common_layers.sample_with_temperature(logits, temperature)
910
      hit_eos |= tf.equal(next_id, eos_id)
911

912
      log_prob_indices = tf.stack(
913
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
914
      log_prob += tf.gather_nd(log_probs, log_prob_indices)
915

916
      next_id = tf.expand_dims(next_id, axis=1)
917
      decoded_ids = tf.transpose(decoded_ids)
918
      decoded_ids = inplace_ops.alias_inplace_update(
919
          decoded_ids, i, tf.squeeze(next_id, axis=1))
920
      decoded_ids = tf.transpose(decoded_ids)
921
      return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob
922

923
    def is_not_finished(i, hit_eos, *_):
924
      finished = i >= decode_length
925
      if not force_decode_length:
926
        finished |= tf.reduce_all(hit_eos)
927
      return tf.logical_not(finished)
928

929
    decoded_ids = tf.zeros([batch_size, decode_length], dtype=tf.int64)
930
    hit_eos = tf.fill([batch_size], False)
931
    next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
932
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
933

934
    def compute_cache_shape_invariants(tensor):
935
      return tf.TensorShape(tensor.shape.as_list())
936

937
    _, _, _, decoded_ids, _, log_prob = tf.while_loop(
938
        is_not_finished,
939
        inner_loop, [
940
            tf.constant(0), hit_eos, next_id, decoded_ids, cache,
941
            initial_log_prob
942
        ],
943
        shape_invariants=[
944
            tf.TensorShape([]),
945
            tf.TensorShape([batch_size]),
946
            tf.TensorShape([batch_size, 1]),
947
            tf.TensorShape([batch_size, decode_length]),
948
            nest.map_structure(compute_cache_shape_invariants, cache),
949
            tf.TensorShape([batch_size]),
950
        ])
951
    scores = log_prob
952

953
  return {"outputs": decoded_ids, "scores": scores}
954

955

956
def fast_decode(encoder_output,
957
                encoder_decoder_attention_bias,
958
                symbols_to_logits_fn,
959
                hparams,
960
                decode_length,
961
                vocab_size,
962
                beam_size=1,
963
                top_beams=1,
964
                alpha=1.0,
965
                sos_id=0,
966
                eos_id=beam_search.EOS_ID,
967
                batch_size=None,
968
                force_decode_length=False,
969
                scope_prefix="body/"):
970
  """Given encoder output and a symbols to logits function, does fast decoding.
971

972
  Implements both greedy and beam search decoding, uses beam search iff
973
  beam_size > 1, otherwise beam search related arguments are ignored.
974

975
  Args:
976
    encoder_output: Output from encoder.
977
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
978
      attention
979
    symbols_to_logits_fn: Incremental decoding; function mapping triple
980
      `(ids, step, cache)` to symbol logits.
981
    hparams: run hyperparameters
982
    decode_length: an integer.  How many additional timesteps to decode.
983
    vocab_size: Output vocabulary size.
984
    beam_size: number of beams.
985
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha, the
      stronger the preference for longer translations.
    sos_id: Start-of-sequence symbol in beam search.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input
    force_decode_length: bool, whether to force the full decode length, or if
      False, stop when all beams hit eos_id.
    scope_prefix: str, prefix for decoder layer variable scopes.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if top_beams == 1 or
            [batch_size, top_beams, <= decode_length] otherwise
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If beam size > 1 with partial targets.
  """
1007
  if encoder_output is not None:
1008
    batch_size = common_layers.shape_list(encoder_output)[0]
1009

1010
  key_channels = hparams.attention_key_channels or hparams.hidden_size
1011
  value_channels = hparams.attention_value_channels or hparams.hidden_size
1012
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
1013
  vars_3d_num_heads = (
1014
      hparams.num_heads if hparams.get("attention_variables_3d") else 0)
1015

1016
  cache = {
1017
      "layer_%d" % layer: {  # pylint: disable=g-complex-comprehension
1018
          "k":
1019
              common_attention.split_heads(
1020
                  tf.zeros([batch_size, 0, key_channels]), hparams.num_heads),
1021
          "v":
1022
              common_attention.split_heads(
1023
                  tf.zeros([batch_size, 0, value_channels]), hparams.num_heads),
1024
          "f":
1025
              tf.zeros([batch_size, 0, hparams.hidden_size]),
1026
      } for layer in range(num_layers)
1027
  }
1028

1029
  if encoder_output is not None:
1030
    for layer in range(num_layers):
1031
      layer_name = "layer_%d" % layer
1032
      with tf.variable_scope(
1033
          "%sdecoder/%s/encdec_attention/multihead_attention" % (scope_prefix,
1034
                                                                 layer_name)):
1035
        initial_sparsity = None
1036
        if hparams.get("load_masks_from"):
1037
          initial_sparsity = hparams.get("initial_sparsity")
1038

1039
        k_encdec = sparse_attention.compute_attention_component(
1040
            encoder_output, key_channels, name="k",
1041
            vars_3d_num_heads=vars_3d_num_heads,
1042
            sparsity_technique=hparams.get("sparsity_technique"),
1043
            threshold=hparams.get("log_alpha_threshold"),
1044
            training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
1045
            clip_alpha=hparams.get("clip_log_alpha"),
1046
            initial_sparsity=initial_sparsity,
1047
            split_heads=hparams.get("split_heads"),
1048
            num_heads=hparams.num_heads)
1049
        k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads)
1050
        v_encdec = sparse_attention.compute_attention_component(
1051
            encoder_output, value_channels, name="v",
1052
            vars_3d_num_heads=vars_3d_num_heads,
1053
            sparsity_technique=hparams.get("sparsity_technique"),
1054
            threshold=hparams.get("log_alpha_threshold"),
1055
            training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
1056
            clip_alpha=hparams.get("clip_log_alpha"),
1057
            initial_sparsity=initial_sparsity,
1058
            split_heads=hparams.get("split_heads"),
1059
            num_heads=hparams.num_heads)
1060
        v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads)
1061
      cache[layer_name]["k_encdec"] = k_encdec
1062
      cache[layer_name]["v_encdec"] = v_encdec
1063

1064
    cache["encoder_output"] = encoder_output
1065
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
1066

1067
  if beam_size > 1:  # Beam Search
1068
    initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32)
1069
    decoded_ids, scores, _ = beam_search.beam_search(
1070
        symbols_to_logits_fn,
1071
        initial_ids,
1072
        beam_size,
1073
        decode_length,
1074
        vocab_size,
1075
        alpha,
1076
        states=cache,
1077
        eos_id=eos_id,
1078
        stop_early=(top_beams == 1))
1079

1080
    if top_beams == 1:
1081
      decoded_ids = decoded_ids[:, 0, 1:]
1082
      scores = scores[:, 0]
1083
    else:
1084
      decoded_ids = decoded_ids[:, :top_beams, 1:]
1085
      scores = scores[:, :top_beams]
1086
  else:  # Greedy
1087

1088
    def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
1089
      """One step of greedy decoding."""
1090
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
1091
      log_probs = common_layers.log_prob_from_logits(logits)
1092
      temperature = (0.0 if hparams.sampling_method == "argmax" else
1093
                     hparams.sampling_temp)
1094
      next_id = common_layers.sample_with_temperature(logits, temperature)
1095
      hit_eos |= tf.equal(next_id, eos_id)
1096

1097
      log_prob_indices = tf.stack(
1098
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
1099
      log_prob += tf.gather_nd(log_probs, log_prob_indices)
1100

1101
      next_id = tf.expand_dims(next_id, axis=1)
1102
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
1103
      return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob
1104

1105
    def is_not_finished(i, hit_eos, *_):
1106
      finished = i >= decode_length
1107
      if not force_decode_length:
1108
        finished |= tf.reduce_all(hit_eos)
1109
      return tf.logical_not(finished)
1110

1111
    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
1112
    hit_eos = tf.fill([batch_size], False)
1113
    next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
1114
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
1115
    _, _, _, decoded_ids, _, log_prob = tf.while_loop(
1116
        is_not_finished,
1117
        inner_loop, [
1118
            tf.constant(0), hit_eos, next_id, decoded_ids, cache,
1119
            initial_log_prob
1120
        ],
1121
        shape_invariants=[
1122
            tf.TensorShape([]),
1123
            tf.TensorShape([None]),
1124
            tf.TensorShape([None, None]),
1125
            tf.TensorShape([None, None]),
1126
            nest.map_structure(beam_search.get_state_shape_invariants, cache),
1127
            tf.TensorShape([None]),
1128
        ])
1129
    scores = log_prob
1130

1131
  return {"outputs": decoded_ids, "scores": scores}
1132

1133

1134
def features_to_nonpadding(features, inputs_or_targets="inputs"):
  key = inputs_or_targets + "_segmentation"
  if features and key in features:
    return tf.minimum(tf.to_float(features[key]), 1.0)
  return None
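
# Illustrative sketch (not part of the original file): for a packed batch the
# segmentation ids are collapsed to a {0, 1} nonpadding mask, e.g. a
# hypothetical segmentation [[1, 1, 2, 2, 0, 0]] yields [[1, 1, 1, 1, 0, 0]].
def _example_nonpadding_mask():
  """Hedged usage sketch for features_to_nonpadding."""
  features = {"inputs_segmentation": tf.constant([[1, 1, 2, 2, 0, 0]])}
  return features_to_nonpadding(features, "inputs")
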

1140

1141
def transformer_prepare_decoder(targets, hparams, features=None):
1142
  """Prepare one shard of the model for the decoder.
1143

1144
  Args:
1145
    targets: a Tensor.
1146
    hparams: run hyperparameters
1147
    features: optionally pass the entire features dictionary as well.
1148
      This is needed now for "packed" datasets.
1149

1150
  Returns:
1151
    decoder_input: a Tensor, bottom of decoder stack
1152
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
1153
  """
1154
  if hparams.causal_decoder_self_attention:
1155
    # Causal attention.
1156
    if hparams.prepend_mode == "prepend_inputs_full_attention":
1157
      decoder_self_attention_bias = (
1158
          common_attention.attention_bias_prepend_inputs_full_attention(
1159
              common_attention.embedding_to_padding(targets)))
1160
    else:
1161
      decoder_self_attention_bias = (
1162
          common_attention.attention_bias_lower_triangle(
1163
              common_layers.shape_list(targets)[1]))
1164
  else:
1165
    # Full attention.
1166
    decoder_padding = common_attention.embedding_to_padding(targets)
1167
    decoder_self_attention_bias = (
1168
        common_attention.attention_bias_ignore_padding(decoder_padding))
1169

1170
  if features and "targets_segmentation" in features:
1171
    # "Packed" dataset - keep the examples from seeing each other.
1172
    targets_segmentation = features["targets_segmentation"]
1173
    targets_position = features["targets_position"]
1174
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
1175
        targets_segmentation, targets_segmentation)
1176
  else:
1177
    targets_position = None
1178
  if hparams.proximity_bias:
1179
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
1180
        common_layers.shape_list(targets)[1])
1181
  decoder_input = common_layers.shift_right_3d(targets)
1182
  if hparams.pos == "timing":
1183
    if targets_position is not None:
1184
      decoder_input = common_attention.add_timing_signal_1d_given_position(
1185
          decoder_input, targets_position)
1186
    else:
1187
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
1188
  elif hparams.pos == "emb":
1189
    decoder_input = common_attention.add_positional_embedding(
1190
        decoder_input, hparams.max_length, "targets_positional_embedding",
1191
        targets_position)
1192

1193
  if hparams.activation_dtype == "bfloat16":
1194
    decoder_self_attention_bias = tf.cast(decoder_self_attention_bias,
1195
                                          tf.bfloat16)
1196
  return (decoder_input, decoder_self_attention_bias)
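
# Illustrative sketch (not part of the original file): the decoder input is the
# target sequence shifted right by one position, and the causal bias built
# above is ~0 at allowed positions and a large negative number above the
# diagonal, so step i cannot attend to steps > i.
def _example_causal_bias(length=4):
  """Hedged sketch of the lower-triangular self-attention bias."""
  bias = common_attention.attention_bias_lower_triangle(length)
  # Shape [1, 1, length, length]; entries above the diagonal are ~ -1e9.
  return bias
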
1197

1198

1199
def transformer_decoder(decoder_input,
1200
                        encoder_output,
1201
                        decoder_self_attention_bias,
1202
                        encoder_decoder_attention_bias,
1203
                        hparams,
1204
                        cache=None,
1205
                        decode_loop_step=None,
1206
                        name="decoder",
1207
                        save_weights_to=None,
1208
                        make_image_summary=True,
1209
                        losses=None):  # pylint: disable=unused-argument
1210
  """A stack of transformer layers.
1211

1212
  Args:
1213
    decoder_input: a Tensor
1214
    encoder_output: a Tensor
1215
    decoder_self_attention_bias: bias Tensor for self-attention
1216
      (see common_attention.attention_bias())
1217
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
1218
      (see common_attention.attention_bias())
1219
    hparams: hyperparameters for model
1220
    cache: dict, containing tensors which are the results of previous
1221
        attentions, used for fast decoding.
1222
    decode_loop_step: An integer, step number of the decoding loop.
1223
        Only used for inference on TPU.
1224
    name: a string
1225
    save_weights_to: an optional dictionary to capture attention weights
1226
      for visualization; the weights tensor will be appended there under
1227
      a string key created from the variable scope (including name).
1228
    make_image_summary: Whether to make an attention image summary.
1229
    losses: optional list onto which to append extra training losses
1230

1231
  Returns:
1232
    y: a Tensor
1233
  """
  x = decoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
      value=hparams.num_decoder_layers or hparams.num_hidden_layers)
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT,
      value=hparams.attention_dropout)
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DENSE,
      value={
          "use_bias": "false",
          "num_heads": hparams.num_heads,
          "hidden_size": hparams.hidden_size
      })

  with tf.variable_scope(name):
    for layer in range(hparams.num_decoder_layers or
                       hparams.num_hidden_layers):
      initial_sparsity = None
      if hparams.get("load_masks_from"):
        initial_sparsity = hparams.get("initial_sparsity")

      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = sparse_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=layer_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              sparsity_technique=hparams.get("sparsity_technique"),
              threshold=hparams.get("log_alpha_threshold"),
              training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
              clip_alpha=hparams.get("clip_log_alpha"),
              initial_sparsity=initial_sparsity,
              split_heads=hparams.get("split_heads"))
          x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            y = sparse_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=layer_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                sparsity_technique=hparams.get("sparsity_technique"),
                threshold=hparams.get("log_alpha_threshold"),
                training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
                clip_alpha=hparams.get("clip_log_alpha"),
                initial_sparsity=initial_sparsity,
                split_heads=hparams.get("split_heads"))
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_NORM,
        value={"hidden_size": hparams.hidden_size})
    return common_layers.layer_preprocess(x, hparams)
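
# Example (illustrative sketch, not from the original file): with
# hparams = sparse_transformer_tiny(), `decoder_input` has shape
# [batch, target_length, hparams.hidden_size] and `encoder_output` has shape
# [batch, input_length, hparams.hidden_size]; the bias tensors can be built
# with, e.g., common_attention.attention_bias_lower_triangle(target_length)
# and common_attention.attention_bias_ignore_padding(memory_padding):
#
#   decoder_output = transformer_decoder(
#       decoder_input, encoder_output, decoder_self_attention_bias,
#       encoder_decoder_attention_bias, hparams)
#   # -> [batch, target_length, hparams.hidden_size]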


@registry.register_hparams
def sparse_transformer_base_v1():
  """Set of hyperparameters."""
  hparams = common_hparams.basic_params1()
  hparams.norm_type = "layer"
  hparams.hidden_size = 512
  hparams.batch_size = 4096
  hparams.max_length = 256
  hparams.clip_grad_norm = 0.  # i.e. no gradient clipping
  hparams.optimizer_adam_epsilon = 1e-9
  hparams.learning_rate_schedule = "legacy"
  hparams.learning_rate_decay_scheme = "noam"
  hparams.learning_rate = 0.1
  hparams.learning_rate_warmup_steps = 4000
  hparams.initializer_gain = 1.0
  hparams.num_hidden_layers = 6
  hparams.initializer = "uniform_unit_scaling"
  hparams.weight_decay = 0.0
  hparams.optimizer_adam_beta1 = 0.9
  hparams.optimizer_adam_beta2 = 0.98
  hparams.num_sampled_classes = 0
  hparams.label_smoothing = 0.1
  hparams.shared_embedding_and_softmax_weights = True
  hparams.symbol_modality_num_shards = 16

  # Add new ones like this.
  hparams.add_hparam("filter_size", 2048)
  # Layer-related flags. If zero, these fall back on hparams.num_hidden_layers.
  hparams.add_hparam("num_encoder_layers", 0)
  hparams.add_hparam("num_decoder_layers", 0)
  # Attention-related flags.
  hparams.add_hparam("num_heads", 8)
  hparams.add_hparam("attention_key_channels", 0)
  hparams.add_hparam("attention_value_channels", 0)
  hparams.add_hparam("ffn_layer", "dense_relu_dense")
  hparams.add_hparam("parameter_attention_key_channels", 0)
  hparams.add_hparam("parameter_attention_value_channels", 0)
  # All hyperparameters ending in "dropout" are automatically set to 0.0
  # when not in training mode.
  hparams.add_hparam("attention_dropout", 0.0)
  hparams.add_hparam("attention_dropout_broadcast_dims", "")
  hparams.add_hparam("relu_dropout", 0.0)
  hparams.add_hparam("relu_dropout_broadcast_dims", "")
  hparams.add_hparam("pos", "timing")  # timing, none
  hparams.add_hparam("nbr_decoder_problems", 1)
  hparams.add_hparam("proximity_bias", False)
  hparams.add_hparam("causal_decoder_self_attention", True)
  hparams.add_hparam("use_pad_remover", True)
  hparams.add_hparam("self_attention_type", "dot_product")
  hparams.add_hparam("conv_first_kernel", 3)
  hparams.add_hparam("attention_variables_3d", False)
  hparams.add_hparam("use_target_space_embedding", True)
  # These parameters are only used when ffn_layer=="local_moe_tpu".
  hparams.add_hparam("moe_overhead_train", 1.0)
  hparams.add_hparam("moe_overhead_eval", 2.0)
  hparams.moe_num_experts = 16
  hparams.moe_loss_coef = 1e-3

  # Sparsity hyper-parameters
  hparams.add_hparam("sparsity_technique", None)
  hparams.add_hparam("log_alpha_threshold", 3.0)

  # variational dropout & l0 parameters
  hparams.add_hparam("dkl_weight_fn", "linear")

  # variational dropout parameters
  hparams.add_hparam("dkl_weight", 1 / (4.5 * 10 ** 6))
  hparams.add_hparam("clip_log_alpha", 8.0)
  hparams.add_hparam("dkl_weight_start", 100000)
  hparams.add_hparam("dkl_weight_diff", 100000)

  # l0-regularization parameters
  hparams.add_hparam("l0_norm_weight", 1 / (4.5 * 10 ** 6))
  hparams.add_hparam("l0_weight_start", 100000)
  hparams.add_hparam("l0_weight_diff", 100000)

  # magnitude & random pruning parameters
  hparams.add_hparam("begin_pruning_step", 0)
  hparams.add_hparam("end_pruning_step", 200000)
  hparams.add_hparam("pruning_frequency", 10000)
  hparams.add_hparam("target_sparsity", .9)

  # whether we should prune the weights for each attention head separately
  hparams.add_hparam("split_heads", False)

  # magnitude & random pruning parameters we don't really change
  hparams.add_hparam("threshold_decay", 0.0)
  hparams.add_hparam("nbins", 1024)
  hparams.add_hparam("sparsity_function_exponent", 3.0)

  # use sparse embedding and softmax layer
  hparams.bottom = {
      "targets": sparse_modalities.targets_bottom,
      "inputs": sparse_modalities.bottom
  }
  hparams.top = {
      "targets": sparse_modalities.top,
  }

  # specify to load trained masks and/or weights from a checkpoint
  hparams.add_hparam("load_masks_from", "")
  hparams.add_hparam("load_weights_from", "")
  hparams.add_hparam("initial_sparsity", 0.0)

  # If >= 0, use this sparsity level for the embedding
  # matrix instead of the target_sparsity.
  hparams.add_hparam("embedding_sparsity", -1.0)
  return hparams
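

# Example (illustrative, not from the original file; assumes the usual
# t2t-trainer entry point, and the problem name and paths below are
# placeholders): the sparsity hyperparameters above can be overridden at
# launch time via the standard Tensor2Tensor --hparams flag, e.g.
#
#   t2t-trainer \
#     --model=sparse_transformer \
#     --hparams_set=sparse_transformer_base \
#     --hparams="sparsity_technique=magnitude_pruning,target_sparsity=0.8" \
#     --problem=translate_ende_wmt32k \
#     --data_dir=/path/to/data --output_dir=/path/to/checkpoints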


@registry.register_hparams
def sparse_transformer_base_v2():
  """Set of hyperparameters."""
  hparams = sparse_transformer_base_v1()
  hparams.layer_preprocess_sequence = "n"
  hparams.layer_postprocess_sequence = "da"
  hparams.layer_prepostprocess_dropout = 0.1
  hparams.attention_dropout = 0.1
  hparams.relu_dropout = 0.1
  hparams.learning_rate_warmup_steps = 8000
  hparams.learning_rate = 0.2
  return hparams


@registry.register_hparams
def sparse_transformer_base_v3():
  """Base parameters for Transformer model."""
  # Update parameters here, then occasionally cut a versioned set, e.g.
  # transformer_base_v2.
  hparams = sparse_transformer_base_v2()
  hparams.optimizer_adam_beta2 = 0.997
  # New way of specifying learning rate schedule.
  # Equivalent to previous version.
  hparams.learning_rate_schedule = (
      "constant*linear_warmup*rsqrt_decay*rsqrt_hidden_size")
  hparams.learning_rate_constant = 2.0
  return hparams


@registry.register_hparams
def sparse_transformer_base():
  """Base parameters for Transformer model."""
  hparams = sparse_transformer_base_v3()
  return hparams


@registry.register_hparams
def sparse_transformer_tiny():
  """Tiny set of hyperparameters for quick experiments."""
  hparams = sparse_transformer_base()
  hparams.num_hidden_layers = 2
  hparams.hidden_size = 128
  hparams.filter_size = 512
  hparams.num_heads = 4
  return hparams
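

# Example (illustrative, not from the original file): the registered sets
# above compose by plain function calls, so a one-off variant can be built
# the same way in user code:
#
#   hp = sparse_transformer_tiny()
#   hp.sparsity_technique = "magnitude_pruning"
#   hp.target_sparsity = 0.5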


@registry.register_hparams
def sparse_transformer_tiny_variational_dropout():
  hparams = sparse_transformer_tiny()
  hparams.sparsity_technique = "variational_dropout"
  return hparams


@registry.register_hparams
def sparse_transformer_tiny_l0_regularization():
  hparams = sparse_transformer_tiny()
  hparams.sparsity_technique = "l0_regularization"
  return hparams


@registry.register_hparams
def sparse_transformer_tiny_magnitude_pruning():
  hparams = sparse_transformer_tiny()
  hparams.sparsity_technique = "magnitude_pruning"
  return hparams


@registry.register_hparams
def sparse_transformer_tiny_shmp():
  """Split-head magnitude pruning on the tiny model."""
  hparams = sparse_transformer_tiny()
  hparams.sparsity_technique = "magnitude_pruning"
  hparams.split_heads = True
  return hparams


@registry.register_hparams
def sparse_transformer_tiny_random_pruning():
  hparams = sparse_transformer_tiny()
  hparams.sparsity_technique = "random_pruning"
  return hparams


def update_hparams_for_tpu(hparams):
  """Change hparams to be compatible with TPU training."""

  # Adafactor uses less memory than Adam.
  # Switch to Adafactor with its recommended learning rate scheme.
  hparams.optimizer = "Adafactor"
  hparams.learning_rate_schedule = "rsqrt_decay"
  hparams.learning_rate_warmup_steps = 10000

  # Avoid an expensive concat on TPU.
  # >1 shards helps with faster parameter distribution on multi-GPU machines.
  hparams.symbol_modality_num_shards = 1

  # Adaptive batch sizes and sequence lengths are not supported on TPU.
  # Instead, every batch has the same sequence length and the same batch size.
  # Longer sequences are dropped and shorter ones are padded.
  #
  # It is therefore suggested to use a problem where examples have been combined
  # to a longer length, e.g. the "_packed" problems.
  #
  # For problems with variable sequence lengths, this parameter controls the
  # maximum sequence length.  Longer sequences are dropped and shorter ones
  # are padded.
  #
  # For problems with fixed sequence lengths - e.g. the "_packed" problems,
  # this hyperparameter is ignored.
  hparams.max_length = 64

  # TPUs have less memory than GPUs, so decrease the batch size.
  hparams.batch_size = 2048

  # Using noise broadcast in the dropout layers saves memory during training.
  hparams.attention_dropout_broadcast_dims = "0,1"  # batch, heads
  hparams.relu_dropout_broadcast_dims = "1"  # length
  hparams.layer_prepostprocess_dropout_broadcast_dims = "1"  # length
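
  # Note (added, not from the original file): these broadcast-dim strings are
  # parsed with common_layers.comma_separated_string_to_integer_list (as in
  # transformer_decoder above), so "0,1" means a single dropout noise mask is
  # shared across dims 0 (batch) and 1 (heads), which is what saves memory.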


@registry.register_hparams
def sparse_transformer_tpu():
  """HParams for Transformer model on TPU."""
  hparams = sparse_transformer_base()
  update_hparams_for_tpu(hparams)
  return hparams


@registry.register_hparams
def sparse_transformer_tiny_tpu():
  hparams = sparse_transformer_tiny()
  update_hparams_for_tpu(hparams)
  return hparams


@registry.register_hparams
def sparse_transformer_magnitude_pruning_tpu():
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 2048

  hparams.sparsity_technique = "magnitude_pruning"
  return hparams


@registry.register_hparams
def sparse_transformer_random_pruning_tpu():
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 2048

  hparams.sparsity_technique = "random_pruning"
  return hparams


@registry.register_hparams
def sparse_transformer_variational_dropout_tpu():
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 2048

  hparams.sparsity_technique = "variational_dropout"
  return hparams


@registry.register_hparams
def sparse_transformer_l0_regularization_tpu():
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 2048

  hparams.sparsity_technique = "l0_regularization"
  return hparams


@registry.register_hparams
def sparse_transformer_mpfc_tpu():
  """Magnitude pruning without embedding pruning."""
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 4096  # double the batch size

  hparams.sparsity_technique = "magnitude_pruning"

  # use the default modality, i.e. don't prune the embedding
  # or the final linear layer before the softmax.
  hparams.modality = {}
  return hparams


@registry.register_hparams
def sparse_transformer_mpfc_2k_tpu():
  hparams = sparse_transformer_mpfc_tpu()
  hparams.batch_size = 2048  # use the standard batch size
  return hparams


@registry.register_hparams
def sparse_transformer_split_head_mpfc_tpu():
  hparams = sparse_transformer_mpfc_tpu()

  # prune the weights for each attention head separately
  hparams.split_heads = True
  return hparams


@registry.register_hparams
def sparse_transformer_magnitude_pruning_4k_tpu():
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 4096  # double the batch size

  hparams.sparsity_technique = "magnitude_pruning"
  return hparams


@registry.register_hparams
def sparse_transformer_split_head_magnitude_pruning_4k_tpu():
  hparams = sparse_transformer_magnitude_pruning_4k_tpu()
  hparams.split_heads = True
  return hparams