# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Bottom and top transformations of the model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import range  # pylint: disable=redefined-builtin

from tensor2tensor.layers import common_layers

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import state_of_sparsity.layers.l0_regularization as l0
import state_of_sparsity.layers.variational_dropout as vd
from state_of_sparsity.sparse_transformer.layers import common_sparse
from tensorflow.contrib.eager.python import tfe as contrib_eager
from tensorflow.contrib.model_pruning.python import pruning
# TODO(tgale): This is a hack. Find a better way to avoid collecting
# duplicate weight variables for variational dropout and l0-regularization.
COLLECTED_VARIABLES = False
def _get_weights(model_hparams, vocab_size, hidden_dim=None):
  """Create or get concatenated embedding or softmax variable.

  Args:
    model_hparams: tf.HParams, model hyperparameters.
    vocab_size: int, vocabulary size.
    hidden_dim: dim of the variable. Defaults to model_hparams.hidden_size.

  Returns:
    The concatenated [vocab_size, hidden_dim] weight Tensor, or a
    (weights, auxiliary_parameters) tuple when variational dropout or
    l0-regularization is used.
  """
  if hidden_dim is None:
    hidden_dim = model_hparams.hidden_size
  num_shards = model_hparams.symbol_modality_num_shards
  shards = []

  sparsity_technique = model_hparams.get("sparsity_technique")
  aux_params_shards = []
  for i in range(num_shards):
    shard_size = (vocab_size // num_shards) + (
        1 if i < vocab_size % num_shards else 0)
    var_name = "weights_%d" % i

    weight_init_stddev = hidden_dim**-0.5
    if (model_hparams.get("load_masks_from") and
        model_hparams.get("initial_sparsity")):
      # If we are loading constant masks for scratch-e or scratch-b
      # experiments, we optionally rescale the variance of the weight
      # initialization.
      initial_sparsity = model_hparams.get("initial_sparsity")
      weight_init_stddev = (hidden_dim * (1 - initial_sparsity))**-0.5
      tf.logging.info("Using sparse initialization with sparsity {} for symbol "
                      .format(initial_sparsity))

    shards.append(
        tf.get_variable(
            var_name, [shard_size, hidden_dim],
            initializer=tf.random_normal_initializer(0.0, weight_init_stddev)))
    if sparsity_technique == "variational_dropout":
      aux_params_shards.append(
          tf.get_variable(
              var_name + "_aux", [shard_size, hidden_dim],
              initializer=tf.constant_initializer(value=-10.0)))
    elif sparsity_technique == "l0_regularization":
      initializer = tf.random_normal_initializer(mean=2.197, stddev=0.01)
      aux_params_shards.append(
          tf.get_variable(
              var_name + "_aux", [shard_size, hidden_dim],
              initializer=initializer))
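  # The aux initializers above give each technique a near-dense starting
  # point: a constant log-alpha of -10.0 corresponds to essentially no
  # variational-dropout noise at the start of training, and a mean of 2.197
  # is roughly ln(9), which likely starts the l0-regularization gates almost
  # fully open.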
  if num_shards == 1:
    ret = shards[0]
  else:
    ret = tf.concat(shards, 0)

  if not aux_params_shards:
    # Convert ret to tensor.
    if not contrib_eager.in_eager_mode():
      ret = common_layers.convert_gradient_to_tensor(ret)
    return ret
  # Handle the auxiliary parameters
  if num_shards == 1:
    aux_ret = aux_params_shards[0]
  else:
    aux_ret = tf.concat(aux_params_shards, 0)
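  # The (weights, aux) pair is registered in a graph collection exactly once,
  # presumably so the sparsity regularizers (the KL term for variational
  # dropout, the expected-L0 term for l0_regularization) can retrieve these
  # parameters when the training loss is assembled elsewhere.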
  global COLLECTED_VARIABLES
  if not COLLECTED_VARIABLES:
    if sparsity_technique == "variational_dropout":
      tf.add_to_collection(
          common_sparse.VARIATIONAL_DROPOUT_PARAMETERS,
          (ret, aux_ret))
    elif sparsity_technique == "l0_regularization":
      tf.add_to_collection(
          common_sparse.L0_REGULARIZATION_PARAMETERS,
          (ret, aux_ret))
    COLLECTED_VARIABLES = True
  # Convert aux ret to tensor.
  if not contrib_eager.in_eager_mode():
    ret = common_layers.convert_gradient_to_tensor(ret)
    aux_ret = common_layers.convert_gradient_to_tensor(aux_ret)
  return (ret, aux_ret)
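# A minimal usage sketch, assuming an hparams object carrying the fields read
# above: with no sparsity technique set, _get_weights returns a single
# [vocab_size, hidden_dim] weight tensor, while with, say,
# sparsity_technique="variational_dropout" it returns a (weights, log_alpha)
# pair of identically shaped tensors:
#
#   weights = _get_weights(dense_hparams, vocab_size=32000)
#   weights, log_alpha = _get_weights(vd_hparams, vocab_size=32000)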
def bottom_simple(x, model_hparams, vocab_size, name, reuse):
  """Bottom transformation."""
  with tf.variable_scope(name, reuse=reuse):
    # Ensure the inputs are 3-D
    if len(x.get_shape()) == 4:
      x = tf.squeeze(x, axis=3)
    while len(x.get_shape()) < 3:
      x = tf.expand_dims(x, axis=-1)

    var = _get_weights(model_hparams, vocab_size)
    x = common_layers.dropout_no_scaling(
        x, 1.0 - model_hparams.symbol_dropout)
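    # Each sparsity technique supplies its own embedding lookup below:
    # variational dropout and l0_regularization route through their train/eval
    # kernels (which consume the auxiliary parameters bundled into `var`),
    # while magnitude/random pruning simply gather rows of the masked weights.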
    sparsity_technique = model_hparams.get("sparsity_technique")
    training = model_hparams.get("mode") == tf_estimator.ModeKeys.TRAIN
    if sparsity_technique == "variational_dropout":
      if training:
        ret = vd.nn.embedding_lookup_train(
            var,
            x,
            clip_alpha=model_hparams.get("clip_log_alpha"))
      else:
        threshold = model_hparams.get("log_alpha_threshold")
        ret = vd.nn.embedding_lookup_eval(
            var,
            x,
            threshold=threshold)
    elif sparsity_technique == "l0_regularization":
      if training:
        ret = l0.nn.embedding_lookup_train(var, x)
      else:
        ret = l0.nn.embedding_lookup_eval(var, x)
    elif (sparsity_technique == "magnitude_pruning" or
          sparsity_technique == "random_pruning"):
      ret = common_layers.gather(pruning.apply_mask(var), x)
    else:
      ret = common_layers.gather(var, x)
    # post-process the embedding vectors
    if model_hparams.multiply_embedding_mode == "sqrt_depth":
      ret *= model_hparams.hidden_size**0.5
    ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)
    return ret
def bottom(x, model_hparams, vocab_size):
  """Bottom transformation for symbols."""
  # Sparsity techniques only support shared weight matrices for now
  sparsity_technique = model_hparams.get("sparsity_technique")
  assert (not sparsity_technique or
          model_hparams.shared_embedding_and_softmax_weights)

  if (model_hparams.shared_embedding_and_softmax_weights or
      model_hparams.get("shared_embedding")):
    return bottom_simple(
        x, model_hparams, vocab_size, "shared", reuse=None)
  return bottom_simple(
      x, model_hparams, vocab_size, "input_emb", reuse=None)
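# A minimal usage sketch, assuming `hparams` is a tf.HParams with the fields
# read above (hidden_size, symbol_modality_num_shards, symbol_dropout,
# multiply_embedding_mode, shared_embedding_and_softmax_weights, ...) and
# `ids` is an int32 Tensor of token ids shaped [batch, length, 1]:
#
#   embedded = bottom(ids, hparams, vocab_size=32000)
#   # embedded has shape [batch, length, 1, hidden_size]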
def targets_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for target symbols."""
  if (model_hparams.shared_embedding_and_softmax_weights or
      model_hparams.get("shared_embedding")):
    try:
      return bottom_simple(
          x, model_hparams, vocab_size, "shared", reuse=True)
    except ValueError:
      # perhaps there were no inputs, and this is a new variable.
      return bottom_simple(
          x, model_hparams, vocab_size, "shared", reuse=None)
  else:
    return bottom_simple(
        x, model_hparams, vocab_size, "target_emb", reuse=None)
def top(body_output, targets, model_hparams, vocab_size):
  """Generate logits.

  Args:
    body_output: A Tensor with shape [batch, p0, p1, body_input_depth].
    targets: Unused.
    model_hparams: tf.HParams, model hyperparameters.
    vocab_size: int, vocabulary size.

  Returns:
    logits: A Tensor with shape [batch, p0, p1, ?, vocab_size].
  """
  del targets  # unused arg
  # Sparsity techniques only support shared weight matrices for now
  sparsity_technique = model_hparams.get("sparsity_technique")
  assert (not sparsity_technique or
          model_hparams.shared_embedding_and_softmax_weights)
  if model_hparams.shared_embedding_and_softmax_weights:
    scope_name = "shared"
    reuse = tf.AUTO_REUSE
  else:
    scope_name = "softmax"
    reuse = False
  with tf.variable_scope(scope_name, reuse=reuse):
    body_output_shape = common_layers.shape_list(body_output)
    var = _get_weights(model_hparams, vocab_size, body_output_shape[-1])
    if (model_hparams.factored_logits and
        model_hparams.mode == tf_estimator.ModeKeys.TRAIN):
      # Sparsity techniques only support non-factored logits for now
      assert not sparsity_technique

      # insert channels dimension
      body_output = tf.expand_dims(body_output, 3)
      return common_layers.FactoredTensor(body_output, var)
    else:
      body_output = tf.reshape(body_output, [-1, body_output_shape[-1]])
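    # Mirrors the dispatch in bottom_simple: the flattened body_output is
    # matmul'd against the (possibly masked or stochastically gated) softmax
    # weight matrix, transposed, to produce per-position vocabulary logits.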
    training = model_hparams.get("mode") == tf_estimator.ModeKeys.TRAIN
    if sparsity_technique == "variational_dropout":
      if training:
        logits = vd.nn.matmul_train(
            body_output,
            var,
            transpose_b=True,
            clip_alpha=model_hparams.get("clip_log_alpha"))
      else:
        threshold = model_hparams.get("log_alpha_threshold")
        logits = vd.nn.matmul_eval(
            body_output,
            var,
            transpose_b=True,
            threshold=threshold)
    elif sparsity_technique == "l0_regularization":
      if training:
        logits = l0.nn.matmul_train(
            body_output,
            var,
            transpose_b=True)
      else:
        logits = l0.nn.matmul_eval(
            body_output,
            var,
            transpose_b=True)
    elif (sparsity_technique == "magnitude_pruning" or
          sparsity_technique == "random_pruning"):
      logits = tf.matmul(
          body_output,
          pruning.apply_mask(var),
          transpose_b=True)
    else:
      logits = tf.matmul(body_output, var, transpose_b=True)
    return tf.reshape(
        logits,
        body_output_shape[:-1] + [1, vocab_size])
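# End-to-end usage sketch, assuming the same hypothetical `hparams` and `ids`
# as in the note after bottom() above, and that factored_logits is disabled:
#
#   embedded = bottom(ids, hparams, vocab_size=32000)
#   body_output = embedded  # stand-in for the transformer body's output
#   logits = top(body_output, None, hparams, vocab_size=32000)
#   # logits has shape [batch, length, 1, 1, vocab_size]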