# google-research
# 203 lines · 7.6 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for sparse model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.utils import t2t_model
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from tensorflow.contrib import tpu as contrib_tpu
from tensorflow.contrib import training as contrib_training

from tensorflow.contrib.model_pruning.python import pruning as magnitude_pruning

def pruning_hparams(hparams, use_tpu, random):  # pylint: disable=unused-argument
  """Helper to get hparams for the model_pruning library.

  Args:
    hparams: Model hparams. Reads `begin_pruning_step`, `end_pruning_step`,
      `threshold_decay`, `pruning_frequency`, `nbins`, `target_sparsity`,
      `sparsity_function_exponent`, and optionally `embedding_sparsity`.
    use_tpu: bool, whether the model runs on TPU.
    random: bool, whether random pruning was requested. Currently unused —
      random pruning does not work (see TODO below).

  Returns:
    A `contrib_training.HParams` object configured for the
    `tensorflow.contrib.model_pruning` library.
  """
  weight_sparsity_map = [""]
  embedding_sparsity = hparams.get("embedding_sparsity")
  # Guard against a missing hparam: hparams.get() returns None when the key
  # is absent, and `None >= 0.0` raises TypeError under Python 3.
  if embedding_sparsity is not None and embedding_sparsity >= 0.0:
    weight_sparsity_map = [
        "transformer/symbol_modality_33288_512/shared/:{}"
        .format(embedding_sparsity)
    ]
    tf.logging.info(
        "Pruning embedding matrix to {}% sparsity"
        .format(embedding_sparsity * 100))

  hparams = contrib_training.HParams(
      name="model_pruning",
      begin_pruning_step=hparams.get("begin_pruning_step"),
      end_pruning_step=hparams.get("end_pruning_step"),
      weight_sparsity_map=weight_sparsity_map,
      threshold_decay=hparams.get("threshold_decay"),
      pruning_frequency=hparams.get("pruning_frequency"),
      nbins=hparams.get("nbins"),
      block_height=1,
      block_width=1,
      block_pooling_function="AVG",
      initial_sparsity=0.0,  # always start at sparsity 0
      target_sparsity=hparams.get("target_sparsity"),
      sparsity_function_begin_step=hparams.get("begin_pruning_step"),
      sparsity_function_end_step=hparams.get("end_pruning_step"),
      sparsity_function_exponent=hparams.get("sparsity_function_exponent"),
      use_tpu=use_tpu)
  # TODO(tgale): Fix the need to keep this commented out.
  # random pruning currently does not work.
  # random=random)
  return hparams


def check_global_sparsity():
  """Add a scalar summary of the global weight sparsity.

  Computes the fraction of zeroed weights across all pruning masks returned
  by the model_pruning library and records it under the summary tag
  "global_weight_sparsity".
  """
  masks = magnitude_pruning.get_masks()
  # Each mask is a 0/1 tensor, so its reduce_sum is the nonzero weight count.
  total_nonzero = tf.add_n([tf.reduce_sum(m) for m in masks])
  total_weights = tf.add_n([tf.size(m) for m in masks])
  sparsity = 1.0 - (tf.cast(total_nonzero, tf.float32) /
                    tf.cast(total_weights, tf.float32))
  tf.summary.scalar("global_weight_sparsity", sparsity)


class SparseModel(t2t_model.T2TModel):
  """T2T model with weight sparsity.

  Adds magnitude-pruning mask updates to the training op and supports
  warm-starting masks and/or weights from separate checkpoints.
  """

  def _initialize_filtered_from_ckpt(self, checkpoint, load_masks):
    """Initialize a filtered subset of variables from `checkpoint`.

    Loads only the pruning-mask variables (names ending in "mask") when
    `load_masks` is True, otherwise only the non-mask variables. Does
    nothing when the model directory already contains a checkpoint, since
    training should resume from that checkpoint instead.

    Args:
      checkpoint: Path to the checkpoint to load from.
      load_masks: bool, whether to load mask (True) or non-mask (False)
        variables.
    """
    model_dir = self._hparams.get("model_dir", None)
    already_has_ckpt = (
        model_dir and tf.train.latest_checkpoint(model_dir) is not None)
    if already_has_ckpt:
      tf.logging.info("Checkpoint exists in model_dir, not loading variables.")
      return

    # Create the list of checkpoint variable names to load.
    kind = "mask" if load_masks else "non-mask"
    reader = tf.train.NewCheckpointReader(checkpoint)
    ckpt_names = [
        name for name in reader.get_variable_to_shape_map().keys()
        if name.endswith("mask") == load_masks
    ]

    variable_map = {}
    for var in tf.global_variables():
      var_name = var.name.split(":")[0]
      if var_name in ckpt_names:
        tf.logging.info(
            "Loading %s variable from checkpoint: %s", kind, var_name)
        variable_map[var_name] = var
      elif ("mask" in var_name) == load_masks:
        # Variable of the requested kind exists in the graph but not in the
        # checkpoint; it keeps its default initializer.
        tf.logging.info(
            "Cannot find %s variable in checkpoint, skipping: %s",
            kind, var_name)
    tf.train.init_from_checkpoint(checkpoint, variable_map)

  def initialize_masks_from_ckpt(self, checkpoint):
    """Initialize only the pruning-mask variables from `checkpoint`."""
    self._initialize_filtered_from_ckpt(checkpoint, load_masks=True)

  def initialize_non_masks_from_ckpt(self, checkpoint):
    """Initialize only the non-mask (weight) variables from `checkpoint`."""
    self._initialize_filtered_from_ckpt(checkpoint, load_masks=False)

  def estimator_spec_train(self, loss, num_async_replicas=1, use_tpu=False):
    """Constructs `tf.estimator.EstimatorSpec` for TRAIN (training) mode.

    Args:
      loss: Scalar loss tensor.
      num_async_replicas: Number of asynchronous replicas for optimization.
      use_tpu: bool, whether training runs on TPU.

    Returns:
      A `TPUEstimatorSpec` when `use_tpu` is True, otherwise an
      `EstimatorSpec`.
    """
    train_op = self.optimize(
        loss,
        num_async_replicas=num_async_replicas,
        use_tpu=use_tpu)

    sparsity_technique = self._hparams.get("sparsity_technique")
    # Guard: hparams.get() returns None when the key is absent, and
    # `"pruning" in None` would raise TypeError.
    if sparsity_technique and "pruning" in sparsity_technique:
      if not self._hparams.load_masks_from:
        # If we are loading trained masks, don't add the mask update
        # step to the training process and keep the masks static.
        with tf.control_dependencies([train_op]):
          mp_hparams = pruning_hparams(
              self._hparams,
              use_tpu,
              sparsity_technique == "random_pruning")
          p = magnitude_pruning.Pruning(
              mp_hparams,
              global_step=tf.train.get_global_step())
          mask_update_op = p.conditional_mask_update_op()
          train_op = mask_update_op
      # NOTE(review): summary emitted whenever pruning is enabled, even with
      # static loaded masks — confirm against original indentation.
      check_global_sparsity()

    if use_tpu:
      # Pick the scaffold_fn matching the requested warm-start configuration.
      if self._hparams.warm_start_from:
        def scaffold_fn():
          self.initialize_from_ckpt(
              self._hparams.warm_start_from)
          return tf.train.Scaffold()
      elif self._hparams.load_masks_from and self._hparams.load_weights_from:
        def scaffold_fn():
          self.initialize_masks_from_ckpt(
              self._hparams.load_masks_from)
          self.initialize_non_masks_from_ckpt(
              self._hparams.load_weights_from)
          return tf.train.Scaffold()
      elif self._hparams.load_masks_from:
        def scaffold_fn():
          self.initialize_masks_from_ckpt(
              self._hparams.load_masks_from)
          return tf.train.Scaffold()
      else:
        scaffold_fn = None

      # Note: important to call this before remove_summaries()
      if self.hparams.tpu_enable_host_call:
        host_call = t2t_model.create_host_call(self.hparams.model_dir)
      else:
        host_call = None

      t2t_model.remove_summaries()

      return contrib_tpu.TPUEstimatorSpec(
          tf_estimator.ModeKeys.TRAIN,
          loss=loss,
          train_op=train_op,
          host_call=host_call,
          scaffold_fn=scaffold_fn)
    else:
      if self._hparams.warm_start_from:
        self.initialize_from_ckpt(
            self._hparams.warm_start_from)
      elif self._hparams.load_masks_from:
        self.initialize_masks_from_ckpt(
            self._hparams.load_masks_from)

      return tf_estimator.EstimatorSpec(
          tf_estimator.ModeKeys.TRAIN,
          loss=loss,
          train_op=train_op)
204