# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer model from "Attention Is All You Need".

The Transformer model consists of an encoder and a decoder. Both are stacks
of self-attention layers followed by feed-forward layers. This model yields
good results on a number of problems, especially in NLP and machine translation.

See "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) for the full
description of the model and the results obtained with its early version.

Branched from Tensor2Tensor implementation:
github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import range  # pylint: disable=redefined-builtin

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_hparams
from tensor2tensor.layers import common_layers
from tensor2tensor.layers import modalities
from tensor2tensor.layers import transformer_layers
from tensor2tensor.utils import beam_search
from tensor2tensor.utils import mlperf_log
from tensor2tensor.utils import registry

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator

from state_of_sparsity.sparse_transformer.layers import common_sparse
from state_of_sparsity.sparse_transformer.layers import sparse_attention
from state_of_sparsity.sparse_transformer.layers import sparse_modalities
from state_of_sparsity.sparse_transformer.layers import sparse_transformer_layers
from state_of_sparsity.sparse_transformer.models import sparse_model

from tensorflow.python.ops import inplace_ops  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.util import nest  # pylint: disable=g-direct-tensorflow-import


# Alias some commonly reused layers, here and elsewhere.
transformer_prepare_encoder = transformer_layers.transformer_prepare_encoder
transformer_encoder = sparse_transformer_layers.transformer_encoder
transformer_ffn_layer = sparse_transformer_layers.transformer_ffn_layer


@registry.register_model
class SparseTransformer(sparse_model.SparseModel):
  """Attention net. See file docstring."""

  def __init__(self, *args, **kwargs):
    super(SparseTransformer, self).__init__(*args, **kwargs)
    self.attention_weights = dict()  # For visualizing attention heads.

  def encode(self, inputs, target_space, hparams, features=None, losses=None):
    """Encode transformer inputs.

    Args:
      inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim] which
        will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparameters for model.
      features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.
      losses: optional list onto which to append extra training losses

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encoder-decoder attention. [batch_size, input_length]
    """
    inputs = common_layers.flatten4d3d(inputs)

    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer_prepare_encoder(
            inputs, target_space, hparams, features=features))

    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
        value=hparams.layer_prepostprocess_dropout)

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    encoder_output = transformer_encoder(
        encoder_input,
        self_attention_bias,
        hparams,
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=self.attention_weights,
        make_image_summary=not common_layers.is_xla_compiled())

    return encoder_output, encoder_decoder_attention_bias

  def decode(self,
             decoder_input,
             encoder_output,
             encoder_decoder_attention_bias,
             decoder_self_attention_bias,
             hparams,
             cache=None,
             decode_loop_step=None,
             losses=None):
    """Decode Transformer outputs from encoder representation.

    Args:
      decoder_input: inputs to bottom of the model.
          [batch_size, decoder_length, hidden_dim]
      encoder_output: Encoder representation.
          [batch_size, input_length, hidden_dim]
      encoder_decoder_attention_bias: Bias and mask weights for
          encoder-decoder attention. [batch_size, input_length]
      decoder_self_attention_bias: Bias and mask weights for decoder
          self-attention. [batch_size, decoder_length]
      hparams: hyperparameters for model.
      cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
      decode_loop_step: An integer, step number of the decoding loop.
        Only used for inference on TPU.
      losses: optional list onto which to append extra training losses

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
        value=hparams.layer_prepostprocess_dropout)
    decoder_input = tf.nn.dropout(decoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    decoder_output = transformer_decoder(
        decoder_input,
        encoder_output,
        decoder_self_attention_bias,
        encoder_decoder_attention_bias,
        hparams,
        cache=cache,
        decode_loop_step=decode_loop_step,
        save_weights_to=self.attention_weights,
        losses=losses)

    if (common_layers.is_xla_compiled() and
        hparams.mode == tf_estimator.ModeKeys.TRAIN):
      return decoder_output
    else:
      # Expand since t2t expects 4d tensors.
      return tf.expand_dims(decoder_output, axis=2)

  def body(self, features):
    """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs.
              [batch_size, input_length, 1, hidden_dim].
          "targets": Target decoder outputs.
              [batch_size, decoder_length, 1, hidden_dim]
          "target_space_id": A scalar int from data_generators.problem.SpaceID.

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
    hparams = self._hparams

    losses = []

    if self.has_input:
      inputs = features["inputs"]
      target_space = features["target_space_id"]
      encoder_output, encoder_decoder_attention_bias = self.encode(
          inputs, target_space, hparams, features=features, losses=losses)
    else:
      encoder_output, encoder_decoder_attention_bias = (None, None)

    targets = features["targets"]
    targets_shape = common_layers.shape_list(targets)
    targets = common_layers.flatten4d3d(targets)
    decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
        targets, hparams, features=features)
    decoder_output = self.decode(
        decoder_input,
        encoder_output,
        encoder_decoder_attention_bias,
        decoder_self_attention_bias,
        hparams,
        losses=losses)

    sparsity_technique = hparams.get("sparsity_technique")
    expected_attentions = features.get("expected_attentions")
    if expected_attentions is not None:
      assert not sparsity_technique

      attention_loss = common_attention.encoder_decoder_attention_loss(
          expected_attentions, self.attention_weights,
          hparams.expected_attention_loss_type,
          hparams.expected_attention_loss_multiplier)
      return decoder_output, {"attention_loss": attention_loss}

    # Add the extra loss term needed for each sparsity technique
    if sparsity_technique == "variational_dropout":
      losses += common_sparse.variational_dropout_dkl_loss(
          sparsity_check=True,
          threshold=hparams.get("log_alpha_threshold"),
          dkl_weight=hparams.get("dkl_weight"),
          begin_step=hparams.get("dkl_weight_start"),
          end_step=(hparams.get("dkl_weight_start") +
                    hparams.get("dkl_weight_diff")),
          weight_function=hparams.get("dkl_weight_fn"),
          clip_alpha=hparams.get("clip_log_alpha"))
    elif sparsity_technique == "l0_regularization":
      losses += common_sparse.l0_regularization_term(
          sparsity_check=True,
          regularization_weight=hparams.get("l0_norm_weight"),
          weight_start=hparams.get("l0_weight_start"),
          weight_end=(hparams.get("l0_weight_start") +
                      hparams.get("l0_weight_diff")),
          weight_function=hparams.get("dkl_weight_fn"))

    ret = tf.reshape(decoder_output, targets_shape)
    if losses:
      return ret, {"extra_loss": tf.add_n(losses)}
    else:
      return ret

  def _greedy_infer(self, features, decode_length, use_tpu=False):
    """Fast version of greedy decoding.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer. How many additional timesteps to decode.
      use_tpu: A bool. Whether to build the inference graph for TPU.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    # For real-valued modalities use the slow decode path for now.
    if (self._target_modality_is_real or
        self._hparams.self_attention_type != "dot_product"):
      return super(SparseTransformer, self)._greedy_infer(
          features, decode_length)
    with tf.variable_scope(self.name):
      return (self._fast_decode_tpu(features, decode_length) if use_tpu else
              self._fast_decode(features, decode_length))

  def _beam_decode(self,
                   features,
                   decode_length,
                   beam_size,
                   top_beams,
                   alpha,
                   use_tpu=False):
    """Beam search decoding.

    Args:
      features: a map of string to `Tensor`
      decode_length: an integer. How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.
      use_tpu: A bool, whether to do beam decode on TPU.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }
    """
    if self._hparams.self_attention_type != "dot_product":
      # Caching is not guaranteed to work with attention types other than
      # dot_product.
      return self._beam_decode_slow(features, decode_length, beam_size,
                                    top_beams, alpha, use_tpu)
    with tf.variable_scope(self.name):
      if use_tpu:
        return self._fast_decode_tpu(
            features, decode_length, beam_size, top_beams, alpha)
      else:
        return self._fast_decode(
            features, decode_length, beam_size, top_beams, alpha)

  def _fast_decode_tpu(self,
                       features,
                       decode_length,
                       beam_size=1,
                       top_beams=1,
                       alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding on TPU, uses beam search
    iff beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: A map of string to model features.
      decode_length: An integer, how many additional timesteps to decode.
      beam_size: An integer, number of beams.
      top_beams: An integer, how many of the beams to return.
      alpha: A float that controls the length penalty. The larger the alpha,
        the stronger the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }.

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    if "targets_segmentation" in features:
      raise NotImplementedError(
          "Decoding not supported on packed datasets "
          " If you want to decode from a dataset, use the non-packed version"
          " of the dataset when decoding.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.modality["targets"]
    target_vocab_size = self._problem_hparams.vocab_size["targets"]
    if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
      target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor

    if self.has_input:
      inputs = features["inputs"]
      if target_modality == modalities.ModalityType.CLASS_LABEL:
        decode_length = 1
      else:
        decode_length = (
            common_layers.shape_list(inputs)[1] + features.get(
                "decode_length", decode_length))

      inputs = tf.expand_dims(inputs, axis=1)
      if len(inputs.shape) < 5:
        inputs = tf.expand_dims(inputs, axis=4)
      s = common_layers.shape_list(inputs)
      batch_size = s[0]
      inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
      # _shard_features called to ensure that the variable names match
      inputs = self._shard_features({"inputs": inputs})["inputs"]
      input_modality = self._problem_hparams.modality["inputs"]
      input_vocab_size = self._problem_hparams.vocab_size["inputs"]
      if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
        input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
      modality_name = hparams.name.get(
          "inputs",
          modalities.get_name(input_modality))(hparams, input_vocab_size)
      with tf.variable_scope(modality_name):
        bottom = hparams.bottom.get(
            "inputs", modalities.get_bottom(input_modality))
        inputs = dp(bottom, inputs, hparams, input_vocab_size)
      with tf.variable_scope("body"):
        encoder_output, encoder_decoder_attention_bias = dp(
            self.encode,
            inputs,
            features["target_space_id"],
            hparams,
            features=features)
      encoder_output = encoder_output[0]
      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
      partial_targets = None
    else:
      # The problem has no inputs.
      encoder_output = None
      encoder_decoder_attention_bias = None

      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs")
      if partial_targets is None:
        partial_targets = features["targets"]
      assert partial_targets is not None
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int64(partial_targets)
      partial_targets_shape = common_layers.shape_list(partial_targets)
      partial_targets_length = partial_targets_shape[1]
      decode_length = (
          partial_targets_length + features.get("decode_length", decode_length))
      batch_size = partial_targets_shape[0]

    if hparams.pos == "timing":
      positional_encoding = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)
    elif hparams.pos == "emb":
      positional_encoding = common_attention.add_positional_embedding(
          tf.zeros([1, decode_length + 1, hparams.hidden_size]),
          hparams.max_length, "body/targets_positional_embedding", None)
    else:
      positional_encoding = None

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: A tensor, inputs ids to the decoder. [batch_size, 1].
        i: An integer, Step number of the decoding loop.

      Returns:
        A tensor, processed targets [batch_size, 1, hidden_dim].
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      modality_name = hparams.name.get(
          "targets",
          modalities.get_name(target_modality))(hparams, target_vocab_size)
      with tf.variable_scope(modality_name):
        bottom = hparams.bottom.get(
            "targets", modalities.get_targets_bottom(target_modality))
        targets = dp(bottom, targets, hparams, target_vocab_size)[0]
      targets = common_layers.flatten4d3d(targets)

      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if positional_encoding is not None:
        positional_encoding_shape = positional_encoding.shape.as_list()
        targets += tf.slice(
            positional_encoding, [0, i, 0],
            [positional_encoding_shape[0], 1, positional_encoding_shape[2]])
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    def symbols_to_logits_tpu_fn(ids, i, cache):
      """Go from ids to logits for next symbol on TPU.

      Args:
        ids: A tensor, symbol IDs.
        i: An integer, step number of the decoding loop. Only used for inference
          on TPU.
        cache: A dict, containing tensors which are the results of previous
          attentions, used for fast decoding.

      Returns:
        ret: A tensor, computed logits.
        cache: A dict, containing tensors which are the results of previous
          attentions, used for fast decoding.
      """
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias_shape = decoder_self_attention_bias.shape.as_list()
      bias = tf.slice(decoder_self_attention_bias, [0, 0, i, 0],
                      [bias_shape[0], bias_shape[1], 1, bias_shape[3]])

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode,
            targets,
            cache.get("encoder_output"),
            cache.get("encoder_decoder_attention_bias"),
            bias,
            hparams,
            cache,
            i)

      modality_name = hparams.name.get(
          "targets",
          modalities.get_name(target_modality))(hparams, target_vocab_size)
      with tf.variable_scope(modality_name):
        top = hparams.top.get("targets", modalities.get_top(target_modality))
        logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      if partial_targets is not None:
        # If the position is within the given partial targets, we alter the
        # logits to always return those values.
        # A faster approach would be to process the partial targets in one
        # iteration in order to fill the corresponding parts of the cache.
        # This would require broader changes, though.
        vocab_size = tf.shape(ret)[1]

        def forced_logits():
          return tf.one_hot(
              tf.tile(
                  tf.slice(partial_targets, [0, i],
                           [partial_targets.shape.as_list()[0], 1]),
                  [beam_size]), vocab_size, 0.0, -1e9)

        ret = tf.cond(
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
      return ret, cache

    ret = fast_decode_tpu(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_tpu_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_vocab_size,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length)
    if partial_targets is not None:
      if beam_size <= 1 or top_beams <= 1:
        ret["outputs"] = ret["outputs"][:, partial_targets_length:]
      else:
        ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
    return ret

  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model features.
      decode_length: an integer. How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.modality["targets"]
    target_vocab_size = self._problem_hparams.vocab_size["targets"]
    if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
      target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor
    if "targets_segmentation" in features:
      raise NotImplementedError(
          "Decoding not supported on packed datasets "
          " If you want to decode from a dataset, use the non-packed version"
          " of the dataset when decoding.")
    if self.has_input:
      inputs = features["inputs"]
      if target_modality == modalities.ModalityType.CLASS_LABEL:
        decode_length = 1
      else:
        decode_length = (
            common_layers.shape_list(inputs)[1] + features.get(
                "decode_length", decode_length))

      inputs = tf.expand_dims(inputs, axis=1)
      if len(inputs.shape) < 5:
        inputs = tf.expand_dims(inputs, axis=4)
      s = common_layers.shape_list(inputs)
      batch_size = s[0]
      inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
      # _shard_features called to ensure that the variable names match
      inputs = self._shard_features({"inputs": inputs})["inputs"]
      input_modality = self._problem_hparams.modality["inputs"]
      input_vocab_size = self._problem_hparams.vocab_size["inputs"]
      if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
        input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
      modality_name = hparams.name.get(
          "inputs",
          modalities.get_name(input_modality))(hparams, input_vocab_size)
      with tf.variable_scope(modality_name):
        bottom = hparams.bottom.get(
            "inputs", modalities.get_bottom(input_modality))
        inputs = dp(bottom, inputs, hparams, input_vocab_size)
      with tf.variable_scope("body"):
        encoder_output, encoder_decoder_attention_bias = dp(
            self.encode,
            inputs,
            features["target_space_id"],
            hparams,
            features=features)
      encoder_output = encoder_output[0]
      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
      partial_targets = None
    else:
      # The problem has no inputs.
      encoder_output = None
      encoder_decoder_attention_bias = None

      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs")
      if partial_targets is None:
        partial_targets = features["targets"]
      assert partial_targets is not None
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int64(partial_targets)
      partial_targets_shape = common_layers.shape_list(partial_targets)
      partial_targets_length = partial_targets_shape[1]
      decode_length = (
          partial_targets_length + features.get("decode_length", decode_length))
      batch_size = partial_targets_shape[0]

    if hparams.pos == "timing":
      positional_encoding = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)
    elif hparams.pos == "emb":
      positional_encoding = common_attention.add_positional_embedding(
          tf.zeros([1, decode_length, hparams.hidden_size]),
          hparams.max_length, "body/targets_positional_embedding", None)
    else:
      positional_encoding = None

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      modality_name = hparams.name.get(
          "targets",
          modalities.get_name(target_modality))(hparams, target_vocab_size)
      with tf.variable_scope(modality_name):
        bottom = hparams.bottom.get(
            "targets", modalities.get_targets_bottom(target_modality))
        targets = dp(bottom, targets, hparams, target_vocab_size)[0]
      targets = common_layers.flatten4d3d(targets)

      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if positional_encoding is not None:
        targets += positional_encoding[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode,
            targets,
            cache.get("encoder_output"),
            cache.get("encoder_decoder_attention_bias"),
            bias,
            hparams,
            cache)

      modality_name = hparams.name.get(
          "targets",
          modalities.get_name(target_modality))(hparams, target_vocab_size)
      with tf.variable_scope(modality_name):
        top = hparams.top.get("targets", modalities.get_top(target_modality))
        logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      if partial_targets is not None:
        # If the position is within the given partial targets, we alter the
        # logits to always return those values.
        # A faster approach would be to process the partial targets in one
        # iteration in order to fill the corresponding parts of the cache.
        # This would require broader changes, though.
        vocab_size = tf.shape(ret)[1]

        def forced_logits():
          return tf.one_hot(
              tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
              -1e9)

        ret = tf.cond(
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
      return ret, cache

    ret = fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_vocab_size,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length)
    if partial_targets is not None:
      if beam_size <= 1 or top_beams <= 1:
        ret["outputs"] = ret["outputs"][:, partial_targets_length:]
      else:
        ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
    return ret


def fast_decode_tpu(encoder_output,
                    encoder_decoder_attention_bias,
                    symbols_to_logits_fn,
                    hparams,
                    decode_length,
                    vocab_size,
                    beam_size=1,
                    top_beams=1,
                    alpha=1.0,
                    sos_id=0,
                    eos_id=beam_search.EOS_ID,
                    batch_size=None,
                    force_decode_length=False,
                    scope_prefix="body/"):
  """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding for TPU, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: A tensor, output from encoder.
    encoder_decoder_attention_bias: A tensor, bias for use in encoder-decoder
      attention.
    symbols_to_logits_fn: Incremental decoding, function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: Run hyperparameters.
    decode_length: An integer, how many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: An integer, number of beams.
    top_beams: An integer, how many of the beams to return.
    alpha: A float that controls the length penalty. The larger the alpha, the
      stronger the preference for longer translations.
    sos_id: Start-of-sequence symbol.
    eos_id: End-of-sequence symbol.
    batch_size: An integer, must be passed if there is no input.
    force_decode_length: A bool, whether to force the full decode length, or if
      False, stop when all beams hit eos_id.
    scope_prefix: str, prefix for decoder layer variable scopes.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if top_beams == 1 or
            [batch_size, top_beams, <= decode_length] otherwise
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }.

  Raises:
    NotImplementedError: If beam size > 1 with partial targets.
  """
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
  vars_3d_num_heads = (
      hparams.num_heads if hparams.get("attention_variables_3d") else 0)

  cache = {
      "layer_%d" % layer: {  # pylint: disable=g-complex-comprehension
          "k":
              common_attention.split_heads(
                  tf.zeros([batch_size, decode_length, key_channels]),
                  hparams.num_heads),
          "v":
              common_attention.split_heads(
                  tf.zeros([batch_size, decode_length, value_channels]),
                  hparams.num_heads),
          "f":
              tf.zeros([batch_size, decode_length, hparams.hidden_size]),
      } for layer in range(num_layers)
  }

  if encoder_output is not None:
    for layer in range(num_layers):
      layer_name = "layer_%d" % layer
      with tf.variable_scope(
          "%sdecoder/%s/encdec_attention/multihead_attention" % (scope_prefix,
                                                                 layer_name)):
        initial_sparsity = None
        if hparams.get("load_masks_from"):
          initial_sparsity = hparams.get("initial_sparsity")

        k_encdec = sparse_attention.compute_attention_component(
            encoder_output, key_channels, name="k",
            vars_3d_num_heads=vars_3d_num_heads,
            sparsity_technique=hparams.get("sparsity_technique"),
            threshold=hparams.get("log_alpha_threshold"),
            training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
            clip_alpha=hparams.get("clip_log_alpha"),
            initial_sparsity=initial_sparsity,
            split_heads=hparams.get("split_heads"),
            num_heads=hparams.num_heads)
        k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads)
        v_encdec = sparse_attention.compute_attention_component(
            encoder_output, value_channels, name="v",
            vars_3d_num_heads=vars_3d_num_heads,
            sparsity_technique=hparams.get("sparsity_technique"),
            threshold=hparams.get("log_alpha_threshold"),
            training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
            clip_alpha=hparams.get("clip_log_alpha"),
            initial_sparsity=initial_sparsity,
            split_heads=hparams.get("split_heads"),
            num_heads=hparams.num_heads)
        v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads)
      cache[layer_name]["k_encdec"] = k_encdec
      cache[layer_name]["v_encdec"] = v_encdec

    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_SEQ_BEAM_SEARCH,
      value={
          "vocab_size": vocab_size,
          "batch_size": batch_size,
          "beam_size": beam_size,
          "alpha": alpha,
          "max_decode_length": decode_length
      })
  if beam_size > 1:  # Beam Search
    initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32)
    decoded_ids, scores, _ = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1),
        use_tpu=True)

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
      scores = scores[:, 0]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
      scores = scores[:, :top_beams]
  else:  # Greedy
    def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
      """One step of greedy decoding."""
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      log_probs = common_layers.log_prob_from_logits(logits)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      hit_eos |= tf.equal(next_id, eos_id)

      log_prob_indices = tf.stack(
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
      log_prob += tf.gather_nd(log_probs, log_prob_indices)

      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.transpose(decoded_ids)
      decoded_ids = inplace_ops.alias_inplace_update(
          decoded_ids, i, tf.squeeze(next_id, axis=1))
      decoded_ids = tf.transpose(decoded_ids)
      return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob

    def is_not_finished(i, hit_eos, *_):
      finished = i >= decode_length
      if not force_decode_length:
        finished |= tf.reduce_all(hit_eos)
      return tf.logical_not(finished)

    decoded_ids = tf.zeros([batch_size, decode_length], dtype=tf.int64)
    hit_eos = tf.fill([batch_size], False)
    next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)

    def compute_cache_shape_invariants(tensor):
      return tf.TensorShape(tensor.shape.as_list())

    _, _, _, decoded_ids, _, log_prob = tf.while_loop(
        is_not_finished,
        inner_loop, [
            tf.constant(0), hit_eos, next_id, decoded_ids, cache,
            initial_log_prob
        ],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([batch_size]),
            tf.TensorShape([batch_size, 1]),
            tf.TensorShape([batch_size, decode_length]),
            nest.map_structure(compute_cache_shape_invariants, cache),
            tf.TensorShape([batch_size]),
        ])
    scores = log_prob

  return {"outputs": decoded_ids, "scores": scores}


def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                sos_id=0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                force_decode_length=False,
                scope_prefix="body/"):
  """Given encoder output and a symbols to logits function, does fast decoding.

  Implements both greedy and beam search decoding, uses beam search iff
  beam_size > 1, otherwise beam search related arguments are ignored.

  Args:
    encoder_output: Output from encoder.
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
    symbols_to_logits_fn: Incremental decoding; function mapping triple
      `(ids, step, cache)` to symbol logits.
    hparams: run hyperparameters
    decode_length: an integer. How many additional timesteps to decode.
    vocab_size: Output vocabulary size.
    beam_size: number of beams.
    top_beams: an integer. How many of the beams to return.
    alpha: Float that controls the length penalty. The larger the alpha, the
      stronger the preference for longer translations.
    sos_id: Start-of-sequence symbol in beam search.
    eos_id: End-of-sequence symbol in beam search.
    batch_size: an integer scalar - must be passed if there is no input
    force_decode_length: bool, whether to force the full decode length, or if
      False, stop when all beams hit eos_id.
    scope_prefix: str, prefix for decoder layer variable scopes.

  Returns:
    A dict of decoding results {
        "outputs": integer `Tensor` of decoded ids of shape
            [batch_size, <= decode_length] if top_beams == 1 or
            [batch_size, top_beams, <= decode_length] otherwise
        "scores": decoding log probs from the beam search,
            None if using greedy decoding (beam_size=1)
    }

  Raises:
    NotImplementedError: If beam size > 1 with partial targets.
  """
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]

  key_channels = hparams.attention_key_channels or hparams.hidden_size
  value_channels = hparams.attention_value_channels or hparams.hidden_size
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
  vars_3d_num_heads = (
      hparams.num_heads if hparams.get("attention_variables_3d") else 0)

  cache = {
      "layer_%d" % layer: {  # pylint: disable=g-complex-comprehension
          "k":
              common_attention.split_heads(
                  tf.zeros([batch_size, 0, key_channels]), hparams.num_heads),
          "v":
              common_attention.split_heads(
                  tf.zeros([batch_size, 0, value_channels]), hparams.num_heads),
          "f":
              tf.zeros([batch_size, 0, hparams.hidden_size]),
      } for layer in range(num_layers)
  }

  if encoder_output is not None:
    for layer in range(num_layers):
      layer_name = "layer_%d" % layer
      with tf.variable_scope(
          "%sdecoder/%s/encdec_attention/multihead_attention" % (scope_prefix,
                                                                 layer_name)):
        initial_sparsity = None
        if hparams.get("load_masks_from"):
          initial_sparsity = hparams.get("initial_sparsity")

        k_encdec = sparse_attention.compute_attention_component(
            encoder_output, key_channels, name="k",
            vars_3d_num_heads=vars_3d_num_heads,
            sparsity_technique=hparams.get("sparsity_technique"),
            threshold=hparams.get("log_alpha_threshold"),
            training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
            clip_alpha=hparams.get("clip_log_alpha"),
            initial_sparsity=initial_sparsity,
            split_heads=hparams.get("split_heads"),
            num_heads=hparams.num_heads)
        k_encdec = common_attention.split_heads(k_encdec, hparams.num_heads)
        v_encdec = sparse_attention.compute_attention_component(
            encoder_output, value_channels, name="v",
            vars_3d_num_heads=vars_3d_num_heads,
            sparsity_technique=hparams.get("sparsity_technique"),
            threshold=hparams.get("log_alpha_threshold"),
            training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
            clip_alpha=hparams.get("clip_log_alpha"),
            initial_sparsity=initial_sparsity,
            split_heads=hparams.get("split_heads"),
            num_heads=hparams.num_heads)
        v_encdec = common_attention.split_heads(v_encdec, hparams.num_heads)
      cache[layer_name]["k_encdec"] = k_encdec
      cache[layer_name]["v_encdec"] = v_encdec

    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

  if beam_size > 1:  # Beam Search
    initial_ids = sos_id * tf.ones([batch_size], dtype=tf.int32)
    decoded_ids, scores, _ = beam_search.beam_search(
        symbols_to_logits_fn,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        alpha,
        states=cache,
        eos_id=eos_id,
        stop_early=(top_beams == 1))

    if top_beams == 1:
      decoded_ids = decoded_ids[:, 0, 1:]
      scores = scores[:, 0]
    else:
      decoded_ids = decoded_ids[:, :top_beams, 1:]
      scores = scores[:, :top_beams]
  else:  # Greedy

    def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
      """One step of greedy decoding."""
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      log_probs = common_layers.log_prob_from_logits(logits)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      hit_eos |= tf.equal(next_id, eos_id)

      log_prob_indices = tf.stack(
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
      log_prob += tf.gather_nd(log_probs, log_prob_indices)

      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob

    def is_not_finished(i, hit_eos, *_):
      finished = i >= decode_length
      if not force_decode_length:
        finished |= tf.reduce_all(hit_eos)
      return tf.logical_not(finished)

    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
    hit_eos = tf.fill([batch_size], False)
    next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
    initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)
    _, _, _, decoded_ids, _, log_prob = tf.while_loop(
        is_not_finished,
        inner_loop, [
            tf.constant(0), hit_eos, next_id, decoded_ids, cache,
            initial_log_prob
        ],
        shape_invariants=[
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None, None]),
            tf.TensorShape([None, None]),
            nest.map_structure(beam_search.get_state_shape_invariants, cache),
            tf.TensorShape([None]),
        ])
    scores = log_prob

  return {"outputs": decoded_ids, "scores": scores}
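

# The helper below is an illustrative sketch added for documentation and is
# not part of the original model code: it shows the (ids, step, cache) ->
# (logits, cache) contract that fast_decode expects from symbols_to_logits_fn,
# using a dummy logits function and no encoder (as for problems without
# inputs). The hparams are assumed to come from one of the hparams sets
# defined later in this file (e.g. sparse_transformer_tiny()).
def _example_fast_decode_greedy(hparams, batch_size=2, decode_length=5,
                                vocab_size=7):
  """Greedy-decoding usage sketch for fast_decode (illustrative only)."""

  def dummy_symbols_to_logits(ids, i, cache):
    # A real implementation would embed `ids` and run the decoder at step `i`;
    # here we return zero logits so the sketch stays self-contained.
    del ids, i
    return tf.zeros([batch_size, vocab_size]), cache

  return fast_decode(
      encoder_output=None,
      encoder_decoder_attention_bias=None,
      symbols_to_logits_fn=dummy_symbols_to_logits,
      hparams=hparams,
      decode_length=decode_length,
      vocab_size=vocab_size,
      batch_size=batch_size)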


def features_to_nonpadding(features, inputs_or_targets="inputs"):
  key = inputs_or_targets + "_segmentation"
  if features and key in features:
    return tf.minimum(tf.to_float(features[key]), 1.0)
  return None


def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  if hparams.causal_decoder_self_attention:
    # Causal attention.
    if hparams.prepend_mode == "prepend_inputs_full_attention":
      decoder_self_attention_bias = (
          common_attention.attention_bias_prepend_inputs_full_attention(
              common_attention.embedding_to_padding(targets)))
    else:
      decoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(targets)[1]))
  else:
    # Full attention.
    decoder_padding = common_attention.embedding_to_padding(targets)
    decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(decoder_padding))

  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  elif hparams.pos == "emb":
    decoder_input = common_attention.add_positional_embedding(
        decoder_input, hparams.max_length, "targets_positional_embedding",
        targets_position)

  if hparams.activation_dtype == "bfloat16":
    decoder_self_attention_bias = tf.cast(decoder_self_attention_bias,
                                          tf.bfloat16)
  return (decoder_input, decoder_self_attention_bias)


def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        hparams,
                        cache=None,
                        decode_loop_step=None,
                        name="decoder",
                        save_weights_to=None,
                        make_image_summary=True,
                        losses=None):  # pylint: disable=unused-argument
  """A stack of transformer layers.

  Args:
    decoder_input: a Tensor
    encoder_output: a Tensor
    decoder_self_attention_bias: bias Tensor for self-attention
      (see common_attention.attention_bias())
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop.
      Only used for inference on TPU.
    name: a string
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: optional list onto which to append extra training losses

  Returns:
    y: a Tensor
  """
  x = decoder_input
  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
      value=hparams.num_decoder_layers or hparams.num_hidden_layers)
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DROPOUT,
      value=hparams.attention_dropout)
  mlperf_log.transformer_print(
      key=mlperf_log.MODEL_HP_ATTENTION_DENSE,
      value={
          "use_bias": "false",
          "num_heads": hparams.num_heads,
          "hidden_size": hparams.hidden_size
      })

  with tf.variable_scope(name):
    for layer in range(hparams.num_decoder_layers or hparams.num_hidden_layers):
      initial_sparsity = None
      if hparams.get("load_masks_from"):
        initial_sparsity = hparams.get("initial_sparsity")

      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          y = sparse_attention.multihead_attention(
              common_layers.layer_preprocess(x, hparams),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=layer_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              sparsity_technique=hparams.get("sparsity_technique"),
              threshold=hparams.get("log_alpha_threshold"),
              training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
              clip_alpha=hparams.get("clip_log_alpha"),
              initial_sparsity=initial_sparsity,
              split_heads=hparams.get("split_heads"))
          x = common_layers.layer_postprocess(x, y, hparams)
        if encoder_output is not None:
          with tf.variable_scope("encdec_attention"):
            y = sparse_attention.multihead_attention(
                common_layers.layer_preprocess(x, hparams),
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=layer_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                sparsity_technique=hparams.get("sparsity_technique"),
                threshold=hparams.get("log_alpha_threshold"),
                training=hparams.get("mode") == tf_estimator.ModeKeys.TRAIN,
                clip_alpha=hparams.get("clip_log_alpha"),
                initial_sparsity=initial_sparsity,
                split_heads=hparams.get("split_heads"))
            x = common_layers.layer_postprocess(x, y, hparams)
        with tf.variable_scope("ffn"):
          y = transformer_ffn_layer(
              common_layers.layer_preprocess(x, hparams),
              hparams)
          x = common_layers.layer_postprocess(x, y, hparams)
    # if normalization is done in layer_preprocess, then it should also be done
    # on the output, since the output can grow very large, being the sum of
    # a whole stack of unnormalized layer outputs.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_NORM,
        value={"hidden_size": hparams.hidden_size})
    return common_layers.layer_preprocess(x, hparams)
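

# Illustrative sketch (not part of the original file): how
# transformer_prepare_decoder and transformer_decoder can be wired together
# for a single decoder-only forward pass, mirroring what the model body does
# for the targets. `target_embeddings` is assumed to already be embedded with
# shape [batch_size, decoder_length, hidden_size], and `hparams` is assumed to
# come from one of the hparams sets defined below.
def _example_decoder_only_pass(target_embeddings, hparams):
  """Runs the decoder stack over already-embedded targets (usage sketch)."""
  decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
      target_embeddings, hparams)
  decoder_input = tf.nn.dropout(decoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  # With no encoder output the encoder-decoder attention sub-layer is skipped,
  # so this is a purely autoregressive (decoder-only) stack.
  return transformer_decoder(
      decoder_input,
      encoder_output=None,
      decoder_self_attention_bias=decoder_self_attention_bias,
      encoder_decoder_attention_bias=None,
      hparams=hparams)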


@registry.register_hparams
def sparse_transformer_base_v1():
  """Set of hyperparameters."""
  hparams = common_hparams.basic_params1()
  hparams.norm_type = "layer"
  hparams.hidden_size = 512
  hparams.batch_size = 4096
  hparams.max_length = 256
  hparams.clip_grad_norm = 0.  # i.e. no gradient clipping
  hparams.optimizer_adam_epsilon = 1e-9
  hparams.learning_rate_schedule = "legacy"
  hparams.learning_rate_decay_scheme = "noam"
  hparams.learning_rate = 0.1
  hparams.learning_rate_warmup_steps = 4000
  hparams.initializer_gain = 1.0
  hparams.num_hidden_layers = 6
  hparams.initializer = "uniform_unit_scaling"
  hparams.weight_decay = 0.0
  hparams.optimizer_adam_beta1 = 0.9
  hparams.optimizer_adam_beta2 = 0.98
  hparams.num_sampled_classes = 0
  hparams.label_smoothing = 0.1
  hparams.shared_embedding_and_softmax_weights = True
  hparams.symbol_modality_num_shards = 16

  # Add new ones like this.
  hparams.add_hparam("filter_size", 2048)
  # Layer-related flags. If zero, these fall back on hparams.num_hidden_layers.
  hparams.add_hparam("num_encoder_layers", 0)
  hparams.add_hparam("num_decoder_layers", 0)
  # Attention-related flags.
  hparams.add_hparam("num_heads", 8)
  hparams.add_hparam("attention_key_channels", 0)
  hparams.add_hparam("attention_value_channels", 0)
  hparams.add_hparam("ffn_layer", "dense_relu_dense")
  hparams.add_hparam("parameter_attention_key_channels", 0)
  hparams.add_hparam("parameter_attention_value_channels", 0)
  # All hyperparameters ending in "dropout" are automatically set to 0.0
  # when not in training mode.
  hparams.add_hparam("attention_dropout", 0.0)
  hparams.add_hparam("attention_dropout_broadcast_dims", "")
  hparams.add_hparam("relu_dropout", 0.0)
  hparams.add_hparam("relu_dropout_broadcast_dims", "")
  hparams.add_hparam("pos", "timing")  # timing, none
  hparams.add_hparam("nbr_decoder_problems", 1)
  hparams.add_hparam("proximity_bias", False)
  hparams.add_hparam("causal_decoder_self_attention", True)
  hparams.add_hparam("use_pad_remover", True)
  hparams.add_hparam("self_attention_type", "dot_product")
  hparams.add_hparam("conv_first_kernel", 3)
  hparams.add_hparam("attention_variables_3d", False)
  hparams.add_hparam("use_target_space_embedding", True)
  # These parameters are only used when ffn_layer=="local_moe_tpu"
  hparams.add_hparam("moe_overhead_train", 1.0)
  hparams.add_hparam("moe_overhead_eval", 2.0)
  hparams.moe_num_experts = 16
  hparams.moe_loss_coef = 1e-3

  # Sparsity hyper-parameters
  hparams.add_hparam("sparsity_technique", None)
  hparams.add_hparam("log_alpha_threshold", 3.0)

  # variational dropout & l0 parameters
  hparams.add_hparam("dkl_weight_fn", "linear")

  # variational dropout parameters
  hparams.add_hparam("dkl_weight", 1 / (4.5 * 10 ** 6))
  hparams.add_hparam("clip_log_alpha", 8.0)
  hparams.add_hparam("dkl_weight_start", 100000)
  hparams.add_hparam("dkl_weight_diff", 100000)

  # l0-regularization parameters
  hparams.add_hparam("l0_norm_weight", 1 / (4.5 * 10 ** 6))
  hparams.add_hparam("l0_weight_start", 100000)
  hparams.add_hparam("l0_weight_diff", 100000)

  # magnitude & random pruning parameters
  hparams.add_hparam("begin_pruning_step", 0)
  hparams.add_hparam("end_pruning_step", 200000)
  hparams.add_hparam("pruning_frequency", 10000)
  hparams.add_hparam("target_sparsity", .9)

  # whether we should prune the weights for each attention head separately
  hparams.add_hparam("split_heads", False)

  # mp & rp parameters we don't really change
  hparams.add_hparam("threshold_decay", 0.0)
  hparams.add_hparam("nbins", 1024)
  hparams.add_hparam("sparsity_function_exponent", 3.0)

  # use sparse embedding and softmax layer
  hparams.bottom = {
      "targets": sparse_modalities.targets_bottom,
      "inputs": sparse_modalities.bottom
  }
  hparams.top = {
      "targets": sparse_modalities.top,
  }

  # specify to load trained masks from checkpoint
  hparams.add_hparam("load_masks_from", "")
  hparams.add_hparam("load_weights_from", "")
  hparams.add_hparam("initial_sparsity", 0.0)

  # If >= 0, use this sparsity level for the embedding
  # matrix instead of the target_sparsity.
  hparams.add_hparam("embedding_sparsity", -1.0)
  return hparams
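

# Illustrative sketch (not part of the original hparams sets): deriving a
# magnitude-pruning configuration from the base hparams above. Every field set
# here is added in sparse_transformer_base_v1(); the particular schedule
# values are only an example.
def _example_magnitude_pruning_hparams():
  """Builds an example magnitude-pruning config on top of the base set."""
  hparams = sparse_transformer_base_v1()
  hparams.sparsity_technique = "magnitude_pruning"
  hparams.target_sparsity = 0.8  # Prune 80% of the weights.
  hparams.begin_pruning_step = 10000  # Start pruning after some warmup.
  hparams.end_pruning_step = 150000  # Reach the target sparsity by this step.
  hparams.pruning_frequency = 1000  # Update the pruning masks periodically.
  return hparams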


@registry.register_hparams
def sparse_transformer_base_v2():
  """Set of hyperparameters."""
  hparams = sparse_transformer_base_v1()
  hparams.layer_preprocess_sequence = "n"
  hparams.layer_postprocess_sequence = "da"
  hparams.layer_prepostprocess_dropout = 0.1
  hparams.attention_dropout = 0.1
  hparams.relu_dropout = 0.1
  hparams.learning_rate_warmup_steps = 8000
  hparams.learning_rate = 0.2
  return hparams


@registry.register_hparams
def sparse_transformer_base_v3():
  """Base parameters for Transformer model."""
  # Update parameters here, then occasionally cut a versioned set, e.g.
  # transformer_base_v2.
  hparams = sparse_transformer_base_v2()
  hparams.optimizer_adam_beta2 = 0.997
  # New way of specifying learning rate schedule.
  # Equivalent to previous version.
  hparams.learning_rate_schedule = (
      "constant*linear_warmup*rsqrt_decay*rsqrt_hidden_size")
  hparams.learning_rate_constant = 2.0
  return hparams


@registry.register_hparams
def sparse_transformer_base():
  """Base parameters for Transformer model."""
  hparams = sparse_transformer_base_v3()
  return hparams


@registry.register_hparams
def sparse_transformer_tiny():
  hparams = sparse_transformer_base()
  hparams.num_hidden_layers = 2
  hparams.hidden_size = 128
  hparams.filter_size = 512
  hparams.num_heads = 4
  return hparams


@registry.register_hparams
def sparse_transformer_tiny_variational_dropout():
  hparams = sparse_transformer_tiny()
  hparams.sparsity_technique = "variational_dropout"
  return hparams


@registry.register_hparams
def sparse_transformer_tiny_l0_regularization():
  hparams = sparse_transformer_tiny()
  hparams.sparsity_technique = "l0_regularization"
  return hparams


@registry.register_hparams
def sparse_transformer_tiny_magnitude_pruning():
  hparams = sparse_transformer_tiny()
  hparams.sparsity_technique = "magnitude_pruning"
  return hparams


@registry.register_hparams
def sparse_transformer_tiny_shmp():
  hparams = sparse_transformer_tiny()
  hparams.sparsity_technique = "magnitude_pruning"
  hparams.split_heads = True
  return hparams


@registry.register_hparams
def sparse_transformer_tiny_random_pruning():
  hparams = sparse_transformer_tiny()
  hparams.sparsity_technique = "random_pruning"
  return hparams


def update_hparams_for_tpu(hparams):
  """Change hparams to be compatible with TPU training."""

  # Adafactor uses less memory than Adam.
  # switch to Adafactor with its recommended learning rate scheme.
  hparams.optimizer = "Adafactor"
  hparams.learning_rate_schedule = "rsqrt_decay"
  hparams.learning_rate_warmup_steps = 10000

  # Avoid an expensive concat on TPU.
  # >1 shards helps with faster parameter distribution on multi-GPU machines
  hparams.symbol_modality_num_shards = 1

  # Adaptive batch sizes and sequence lengths are not supported on TPU.
  # Instead, every batch has the same sequence length and the same batch size.
  # Longer sequences are dropped and shorter ones are padded.
  #
  # It is therefore suggested to use a problem where examples have been combined
  # to a longer length, e.g. the "_packed" problems.
  #
  # For problems with variable sequence lengths, this parameter controls the
  # maximum sequence length. Shorter sequences are dropped and longer ones
  # are padded.
  #
  # For problems with fixed sequence lengths - e.g. the "_packed" problems,
  # this hyperparameter is ignored.
  hparams.max_length = 64

  # TPUs have less memory than GPUs, so decrease the batch size
  hparams.batch_size = 2048

  # Using noise broadcast in the dropout layers saves memory during training.
  hparams.attention_dropout_broadcast_dims = "0,1"  # batch, heads
  hparams.relu_dropout_broadcast_dims = "1"  # length
  hparams.layer_prepostprocess_dropout_broadcast_dims = "1"  # length


@registry.register_hparams
def sparse_transformer_tpu():
  """HParams for Transformer model on TPU."""
  hparams = sparse_transformer_base()
  update_hparams_for_tpu(hparams)
  return hparams


@registry.register_hparams
def sparse_transformer_tiny_tpu():
  hparams = sparse_transformer_tiny()
  update_hparams_for_tpu(hparams)
  return hparams


@registry.register_hparams
def sparse_transformer_magnitude_pruning_tpu():
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 2048

  hparams.sparsity_technique = "magnitude_pruning"
  return hparams


@registry.register_hparams
def sparse_transformer_random_pruning_tpu():
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 2048

  hparams.sparsity_technique = "random_pruning"
  return hparams


@registry.register_hparams
def sparse_transformer_variational_dropout_tpu():
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 2048

  hparams.sparsity_technique = "variational_dropout"
  return hparams


@registry.register_hparams
def sparse_transformer_l0_regularization_tpu():
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 2048

  hparams.sparsity_technique = "l0_regularization"
  return hparams


@registry.register_hparams
def sparse_transformer_mpfc_tpu():
  """Magnitude pruning without embedding pruning."""
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 4096  # double the batch size

  hparams.sparsity_technique = "magnitude_pruning"

  # use the default modality, i.e. don't prune the embedding
  # or the final linear layer before the softmax.
  hparams.modality = {}
  return hparams


@registry.register_hparams
def sparse_transformer_mpfc_2k_tpu():
  hparams = sparse_transformer_mpfc_tpu()
  hparams.batch_size = 2048  # use the standard batch size
  return hparams


@registry.register_hparams
def sparse_transformer_split_head_mpfc_tpu():
  hparams = sparse_transformer_mpfc_tpu()

  # prune the weights for each attention head separately
  hparams.split_heads = True
  return hparams


@registry.register_hparams
def sparse_transformer_magnitude_pruning_4k_tpu():
  hparams = sparse_transformer_base()
  hparams.symbol_modality_num_shards = 1
  hparams.max_length = 64
  hparams.batch_size = 4096  # double the batch size

  hparams.sparsity_technique = "magnitude_pruning"
  return hparams


@registry.register_hparams
def sparse_transformer_split_head_magnitude_pruning_4k_tpu():
  hparams = sparse_transformer_magnitude_pruning_4k_tpu()
  hparams.split_heads = True
  return hparams