# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Bottom and top transformations of the model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import range  # pylint: disable=redefined-builtin

from tensor2tensor.layers import common_layers

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import state_of_sparsity.layers.l0_regularization as l0
import state_of_sparsity.layers.variational_dropout as vd
from state_of_sparsity.sparse_transformer.layers import common_sparse
from tensorflow.contrib.eager.python import tfe as contrib_eager
from tensorflow.contrib.model_pruning.python import pruning


# TODO(tgale): This is a hack. Find a better way to avoid collecting
# duplicate weight variables for variational dropout and l0-regularization.
COLLECTED_VARIABLES = False
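# Note (added for clarity): the flag above is flipped inside _get_weights the
# first time a (weights, aux) pair is added to a sparsity collection, so the
# variational-dropout / l0 parameters are registered only once even though
# _get_weights is called from both the bottom and top transformations.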


def _get_weights(model_hparams, vocab_size, hidden_dim=None):
  """Create or get concatenated embedding or softmax variable.

  Args:
    model_hparams: tf.HParams, model hyperparameters.
    vocab_size: int, vocabulary size.
    hidden_dim: dim of the variable. Defaults to model_hparams.hidden_size.

  Returns:
    The concatenated embedding/softmax weights as a single Tensor, or a
    (weights, auxiliary_params) tuple when the sparsity technique requires
    auxiliary parameters (variational dropout or l0 regularization).
  """
  if hidden_dim is None:
    hidden_dim = model_hparams.hidden_size
  num_shards = model_hparams.symbol_modality_num_shards
  shards = []

  sparsity_technique = model_hparams.get("sparsity_technique")
  aux_params_shards = []
  for i in range(num_shards):
    shard_size = (vocab_size // num_shards) + (
        1 if i < vocab_size % num_shards else 0)
    var_name = "weights_%d" % i

    weight_init_stddev = hidden_dim**-0.5
    if (model_hparams.get("load_masks_from") and
        model_hparams.get("initial_sparsity")):
      # If we are loading constant masks for scratch-e or scratch-b
      # experiments, we optionally rescale the variance of the weight
      # initialization.
      initial_sparsity = model_hparams.get("initial_sparsity")
      weight_init_stddev = (hidden_dim * (1 - initial_sparsity))**-0.5
      tf.logging.info("Using sparse initialization with sparsity {} for symbol "
                      .format(initial_sparsity))
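      # With initial sparsity s, each row is expected to keep roughly
      # hidden_dim * (1 - s) nonzero entries, so using that count as the
      # fan-in in the 1/sqrt(fan_in) initializer keeps the output variance
      # roughly matched to the dense 1/sqrt(hidden_dim) initialization.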

    shards.append(
        tf.get_variable(
            var_name, [shard_size, hidden_dim],
            initializer=tf.random_normal_initializer(0.0, weight_init_stddev)))
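    # The "_aux" variables below hold the technique-specific parameters:
    # per-weight log-variances for variational dropout (-10.0 means essentially
    # no noise at the start of training) and gate log-alphas for l0
    # regularization (2.197 is roughly log(0.9 / 0.1), so the hard-concrete
    # gates start mostly open).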
    if sparsity_technique == "variational_dropout":
      aux_params_shards.append(
          tf.get_variable(
              var_name + "_aux", [shard_size, hidden_dim],
              initializer=tf.constant_initializer(value=-10.0)))
    elif sparsity_technique == "l0_regularization":
      initializer = tf.random_normal_initializer(mean=2.197, stddev=0.01)
      aux_params_shards.append(
          tf.get_variable(
              var_name + "_aux", [shard_size, hidden_dim],
              initializer=initializer))

  if num_shards == 1:
    ret = shards[0]
  else:
    ret = tf.concat(shards, 0)

  if not aux_params_shards:
    # Convert ret to tensor.
    if not contrib_eager.in_eager_mode():
      ret = common_layers.convert_gradient_to_tensor(ret)
    return ret

  # Handle the auxiliary parameters
  if num_shards == 1:
    aux_ret = aux_params_shards[0]
  else:
    aux_ret = tf.concat(aux_params_shards, 0)

  global COLLECTED_VARIABLES
  if not COLLECTED_VARIABLES:
    if sparsity_technique == "variational_dropout":
      tf.add_to_collection(
          common_sparse.VARIATIONAL_DROPOUT_PARAMETERS,
          (ret, aux_ret))
    elif sparsity_technique == "l0_regularization":
      tf.add_to_collection(
          common_sparse.L0_REGULARIZATION_PARAMETERS,
          (ret, aux_ret))
    COLLECTED_VARIABLES = True

  # Convert aux ret to tensor.
  if not contrib_eager.in_eager_mode():
    ret = common_layers.convert_gradient_to_tensor(ret)
    aux_ret = common_layers.convert_gradient_to_tensor(aux_ret)
  return (ret, aux_ret)


def bottom_simple(x, model_hparams, vocab_size, name, reuse):
  """Bottom transformation."""
  with tf.variable_scope(name, reuse=reuse):
    # Ensure the inputs are 3-D
    if len(x.get_shape()) == 4:
      x = tf.squeeze(x, axis=3)
    while len(x.get_shape()) < 3:
      x = tf.expand_dims(x, axis=-1)

    var = _get_weights(model_hparams, vocab_size)
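    # Symbol dropout zeroes out whole token ids with probability
    # model_hparams.symbol_dropout; the zeroed ids then receive an all-zero
    # embedding via the padding mask applied at the end of this function.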
    x = common_layers.dropout_no_scaling(
        x, 1.0 - model_hparams.symbol_dropout)

    sparsity_technique = model_hparams.get("sparsity_technique")
    training = model_hparams.get("mode") == tf_estimator.ModeKeys.TRAIN
    if sparsity_technique == "variational_dropout":
      if training:
        ret = vd.nn.embedding_lookup_train(
            var,
            x,
            clip_alpha=model_hparams.get("clip_log_alpha"))
      else:
        threshold = model_hparams.get("log_alpha_threshold")
        ret = vd.nn.embedding_lookup_eval(
            var,
            x,
            threshold=threshold)
    elif sparsity_technique == "l0_regularization":
      if training:
        ret = l0.nn.embedding_lookup_train(var, x)
      else:
        ret = l0.nn.embedding_lookup_eval(var, x)
    elif (sparsity_technique == "magnitude_pruning" or
          sparsity_technique == "random_pruning"):
      ret = common_layers.gather(pruning.apply_mask(var), x)
    else:
      ret = common_layers.gather(var, x)

    # post-process the embedding vectors
    if model_hparams.multiply_embedding_mode == "sqrt_depth":
      ret *= model_hparams.hidden_size**0.5
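    # Zero out the embedding vectors of padding tokens (id 0), including ids
    # that were zeroed by symbol dropout above.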
    ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)
    return ret


def bottom(x, model_hparams, vocab_size):
  """Bottom transformation for symbols."""
  # Sparsity techniques only support shared weight matrices for now
  sparsity_technique = model_hparams.get("sparsity_technique")
  assert (not sparsity_technique or
          model_hparams.shared_embedding_and_softmax_weights)

  if (model_hparams.shared_embedding_and_softmax_weights or
      model_hparams.get("shared_embedding")):
    return bottom_simple(
        x, model_hparams, vocab_size, "shared", reuse=None)
  return bottom_simple(
      x, model_hparams, vocab_size, "input_emb", reuse=None)


def targets_bottom(x, model_hparams, vocab_size):
  """Bottom transformation for target symbols."""
  if (model_hparams.shared_embedding_and_softmax_weights or
      model_hparams.get("shared_embedding")):
    try:
      return bottom_simple(
          x, model_hparams, vocab_size, "shared", reuse=True)
    except ValueError:
      # perhaps there were no inputs, and this is a new variable.
      return bottom_simple(
          x, model_hparams, vocab_size, "shared", reuse=None)
  else:
    return bottom_simple(
        x, model_hparams, vocab_size, "target_emb", reuse=None)


def top(body_output, targets, model_hparams, vocab_size):
  """Generate logits.

  Args:
    body_output: A Tensor with shape [batch, p0, p1, body_input_depth].
    targets: Unused.
    model_hparams: tf.HParams, model hyperparameters.
    vocab_size: int, vocabulary size.

  Returns:
    logits: A Tensor with shape [batch, p0, p1, ?, vocab_size].
  """
  del targets  # unused arg
  # Sparsity techniques only support shared weight matrices for now
  sparsity_technique = model_hparams.get("sparsity_technique")
  assert (not sparsity_technique or
          model_hparams.shared_embedding_and_softmax_weights)
  if model_hparams.shared_embedding_and_softmax_weights:
    scope_name = "shared"
    reuse = tf.AUTO_REUSE
  else:
    scope_name = "softmax"
    reuse = False
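  # With shared embedding and softmax weights, reusing the "shared" scope with
  # AUTO_REUSE ties the softmax projection below to the embedding matrix
  # created by the bottom transformation.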

  with tf.variable_scope(scope_name, reuse=reuse):
    body_output_shape = common_layers.shape_list(body_output)
    var = _get_weights(model_hparams, vocab_size, body_output_shape[-1])
    if (model_hparams.factored_logits and
        model_hparams.mode == tf_estimator.ModeKeys.TRAIN):
      # Sparsity techniques only support non-factored logits for now
      assert not sparsity_technique

      # insert channels dimension
      body_output = tf.expand_dims(body_output, 3)
      return common_layers.FactoredTensor(body_output, var)
    else:
      body_output = tf.reshape(body_output, [-1, body_output_shape[-1]])

      training = model_hparams.get("mode") == tf_estimator.ModeKeys.TRAIN
      if sparsity_technique == "variational_dropout":
        if training:
          logits = vd.nn.matmul_train(
              body_output,
              var,
              transpose_b=True,
              clip_alpha=model_hparams.get("clip_log_alpha"))
        else:
          threshold = model_hparams.get("log_alpha_threshold")
          logits = vd.nn.matmul_eval(
              body_output,
              var,
              transpose_b=True,
              threshold=threshold)
      elif sparsity_technique == "l0_regularization":
        if training:
          logits = l0.nn.matmul_train(
              body_output,
              var,
              transpose_b=True)
        else:
          logits = l0.nn.matmul_eval(
              body_output,
              var,
              transpose_b=True)
      elif (sparsity_technique == "magnitude_pruning" or
            sparsity_technique == "random_pruning"):
        logits = tf.matmul(
            body_output,
            pruning.apply_mask(var),
            transpose_b=True)
      else:
        logits = tf.matmul(body_output, var, transpose_b=True)

      return tf.reshape(
          logits,
          body_output_shape[:-1] + [1, vocab_size])
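

# Rough usage sketch (illustrative only; the actual call sites live in the
# sparse transformer model code, and the variable names here are made up):
#
#   emb = bottom(input_ids, model_hparams, vocab_size)
#   tgt_emb = targets_bottom(target_ids, model_hparams, vocab_size)
#   ...  # run the transformer body on the embeddings
#   logits = top(decoder_output, None, model_hparams, vocab_size)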