# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for sparse model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from tensorflow.contrib import tpu as contrib_tpu
from tensorflow.contrib import training as contrib_training

from tensorflow.contrib.model_pruning.python import pruning as magnitude_pruning


def pruning_hparams(hparams, use_tpu, random):  # pylint: disable=unused-argument
  """Helper to get hparams for pruning library."""
  weight_sparsity_map = [""]
  if hparams.get("embedding_sparsity") >= 0.0:
    weight_sparsity_map = [
        "transformer/symbol_modality_33288_512/shared/:{}"
        .format(hparams.get("embedding_sparsity"))
    ]
    tf.logging.info(
        "Pruning embedding matrix to {}% sparsity"
        .format(hparams.get("embedding_sparsity") * 100))

  hparams = contrib_training.HParams(
      name="model_pruning",
      begin_pruning_step=hparams.get("begin_pruning_step"),
      end_pruning_step=hparams.get("end_pruning_step"),
      weight_sparsity_map=weight_sparsity_map,
      threshold_decay=hparams.get("threshold_decay"),
      pruning_frequency=hparams.get("pruning_frequency"),
      nbins=hparams.get("nbins"),
      block_height=1,
      block_width=1,
      block_pooling_function="AVG",
      initial_sparsity=0.0,  # always start at sparsity 0
      target_sparsity=hparams.get("target_sparsity"),
      sparsity_function_begin_step=hparams.get("begin_pruning_step"),
      sparsity_function_end_step=hparams.get("end_pruning_step"),
      sparsity_function_exponent=hparams.get("sparsity_function_exponent"),
      use_tpu=use_tpu)
  # TODO(tgale): Fix the need to keep this commented out.
  # random pruning currently does not work.
  # random=random)
  return hparams
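

# Example (hedged sketch, not part of the original file): the HParams built
# above are consumed by the model_pruning library in the same way as done
# later in `SparseModel.estimator_spec_train`:
#
#   mp_hparams = pruning_hparams(model_hparams, use_tpu=False, random=False)
#   p = magnitude_pruning.Pruning(
#       mp_hparams, global_step=tf.train.get_global_step())
#   mask_update_op = p.conditional_mask_update_op()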


def check_global_sparsity():
  """Add a summary for the global weight sparsity."""
  weight_masks = magnitude_pruning.get_masks()
  weights_per_layer = []
  nonzero_per_layer = []
  for mask in weight_masks:
    nonzero_per_layer.append(tf.reduce_sum(mask))
    weights_per_layer.append(tf.size(mask))
  # Sum across layers once, after the loop, rather than rebuilding the
  # add_n ops on every iteration.
  total_nonzero = tf.add_n(nonzero_per_layer)
  total_weights = tf.add_n(weights_per_layer)
  sparsity = (1.0 - (tf.cast(total_nonzero, tf.float32) /
                     tf.cast(total_weights, tf.float32)))
  tf.summary.scalar("global_weight_sparsity", sparsity)
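

# Example (sketch): called after the model graph and its pruning masks have
# been built, this adds a "global_weight_sparsity" scalar to the default
# summary collection. The surrounding calls are illustrative only:
#
#   loss = build_model(features)  # hypothetical model-building call
#   check_global_sparsity()
#   summary_op = tf.summary.merge_all()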


class SparseModel(t2t_model.T2TModel):
  """T2T model with weight sparsity."""

  def initialize_masks_from_ckpt(self, checkpoint):
    """Initialize the pruning mask variables from `checkpoint`."""
    model_dir = self._hparams.get("model_dir", None)
    already_has_ckpt = (
        model_dir and tf.train.latest_checkpoint(model_dir) is not None)
    if already_has_ckpt:
      tf.logging.info("Checkpoint exists in model_dir, not loading variables.")
      return

    # Create a list of mask variables to load.
    reader = tf.train.NewCheckpointReader(checkpoint)
    mask_names = reader.get_variable_to_shape_map().keys()
    mask_names = [x for x in mask_names if x.endswith("mask")]

    variable_map = {}
    for var in tf.global_variables():
      var_name = var.name.split(":")[0]
      if var_name in mask_names:
        tf.logging.info("Loading mask variable from checkpoint: %s", var_name)
        variable_map[var_name] = var
      elif "mask" in var_name:
        tf.logging.info(
            "Cannot find mask variable in checkpoint, skipping: %s", var_name)
    tf.train.init_from_checkpoint(checkpoint, variable_map)
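
  # Note (assumption about the pruning library): masks created by
  # tensorflow.contrib.model_pruning are variables named "mask" inside the
  # masked weight's scope (e.g. ".../weights/mask"), which is what the
  # `endswith("mask")` filter above relies on.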

  def initialize_non_masks_from_ckpt(self, checkpoint):
    """Initialize all non-mask variables from `checkpoint`."""
    model_dir = self._hparams.get("model_dir", None)
    already_has_ckpt = (
        model_dir and tf.train.latest_checkpoint(model_dir) is not None)
    if already_has_ckpt:
      tf.logging.info("Checkpoint exists in model_dir, not loading variables.")
      return

    # Create a list of non-mask variables to load.
    reader = tf.train.NewCheckpointReader(checkpoint)
    non_mask_names = reader.get_variable_to_shape_map().keys()
    non_mask_names = [x for x in non_mask_names if not x.endswith("mask")]

    variable_map = {}
    for var in tf.global_variables():
      var_name = var.name.split(":")[0]
      if var_name in non_mask_names:
        tf.logging.info(
            "Loading non-mask variable from checkpoint: %s", var_name)
        variable_map[var_name] = var
      elif "mask" not in var_name:
        tf.logging.info(
            "Cannot find non-mask variable in checkpoint, skipping: %s",
            var_name)
    tf.train.init_from_checkpoint(checkpoint, variable_map)
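
  # Example (sketch): the two loaders above are meant to be combined, as in
  # `estimator_spec_train` below -- masks from one checkpoint, weights from
  # another:
  #
  #   self.initialize_masks_from_ckpt(self._hparams.load_masks_from)
  #   self.initialize_non_masks_from_ckpt(self._hparams.load_weights_from)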

  def estimator_spec_train(self, loss, num_async_replicas=1, use_tpu=False):
    """Constructs `tf.estimator.EstimatorSpec` for TRAIN (training) mode."""
    train_op = self.optimize(
        loss,
        num_async_replicas=num_async_replicas,
        use_tpu=use_tpu)

    sparsity_technique = self._hparams.get("sparsity_technique")
    if "pruning" in sparsity_technique:
      if not self._hparams.load_masks_from:
        # If we are loading trained masks, don't add the mask update
        # step to the training process and keep the masks static.
        with tf.control_dependencies([train_op]):
          mp_hparams = pruning_hparams(
              self._hparams,
              use_tpu,
              sparsity_technique == "random_pruning")
          p = magnitude_pruning.Pruning(
              mp_hparams,
              global_step=tf.train.get_global_step())
          mask_update_op = p.conditional_mask_update_op()
          train_op = mask_update_op
      check_global_sparsity()

    if use_tpu:
      if self._hparams.warm_start_from:
        def scaffold_fn():
          self.initialize_from_ckpt(
              self._hparams.warm_start_from)
          return tf.train.Scaffold()
      elif self._hparams.load_masks_from and self._hparams.load_weights_from:
        def scaffold_fn():
          self.initialize_masks_from_ckpt(
              self._hparams.load_masks_from)
          self.initialize_non_masks_from_ckpt(
              self._hparams.load_weights_from)
          return tf.train.Scaffold()
      elif self._hparams.load_masks_from:
        def scaffold_fn():
          self.initialize_masks_from_ckpt(
              self._hparams.load_masks_from)
          return tf.train.Scaffold()
      else:
        scaffold_fn = None

      # Note: important to call this before remove_summaries()
      if self.hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(self.hparams.model_dir)
      else:
        host_call = None

      t2t_model.remove_summaries()

      return contrib_tpu.TPUEstimatorSpec(
          tf_estimator.ModeKeys.TRAIN,
          loss=loss,
          train_op=train_op,
          host_call=host_call,
          scaffold_fn=scaffold_fn)
    else:
      if self._hparams.warm_start_from:
        self.initialize_from_ckpt(
            self._hparams.warm_start_from)
      elif self._hparams.load_masks_from:
        self.initialize_masks_from_ckpt(
            self._hparams.load_masks_from)

      return tf_estimator.EstimatorSpec(
          tf_estimator.ModeKeys.TRAIN,
          loss=loss,
          train_op=train_op)
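

# Example (hedged sketch, not from the original file): a typical training
# configuration for this class. `sparsity_technique`, `load_masks_from`, and
# `load_weights_from` are the hparams read above; the value
# "magnitude_pruning" and the checkpoint paths are illustrative placeholders.
#
#   hparams.sparsity_technique = "magnitude_pruning"
#   hparams.load_masks_from = "/path/to/mask_ckpt"      # placeholder
#   hparams.load_weights_from = "/path/to/weight_ckpt"  # placeholder
#   spec = model.estimator_spec_train(loss, use_tpu=False)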