# Source: google-research repository.
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Calculates weight sparsity for a model checkpoint."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22from absl import app
23from absl import flags
24
25import numpy as np
26import tensorflow.compat.v1 as tf
27
28
# Command-line configuration: which checkpoint to analyze and how to
# reconstruct the weight masks from the saved tensors.
flags.DEFINE_string(
    "checkpoint",
    None,
    "Path to checkpoint."
)
flags.DEFINE_enum(
    "sparsity_technique",
    "magnitude_pruning",
    [
        "magnitude_pruning",
        "random_pruning",
        "variational_dropout",
        "l0_regularization"
    ],
    "Technique used to produce model checkpoint."
)
flags.DEFINE_enum(
    "model",
    "transformer",
    ["transformer", "rn50"],
    "Model saved in checkpoint."
)
flags.DEFINE_float(
    "log_alpha_threshold",
    3.0,
    "log alpha threshold for variational dropout checkpoint."
)

FLAGS = flags.FLAGS
# Numerical-stability constant added to theta^2 before taking its log.
EPSILON = 1e-8
# Stretch interval (GAMMA, ZETA) of the hard-concrete gate used by
# l0 regularization; presumably follows Louizos et al. (2018) -- the
# standard (-0.1, 1.1) values match that paper.
GAMMA = -0.1
ZETA = 1.1
61
62
def get_sparsity(checkpoint, suffixes, mask_fn):
  """Calculates and prints global weight sparsity from a checkpoint.

  Global sparsity is the fraction of zero entries across all masks,
  i.e. 1 - nnz / total, reported as a percentage.

  Args:
    checkpoint: path to checkpoint.
    suffixes: possible suffixes of mask variables in the checkpoint.
    mask_fn: helper function mapping a saved tensor to its weight mask.

  Raises:
    ValueError: if no variables matching `suffixes` exist in the checkpoint
      (previously this crashed with ZeroDivisionError).
  """
  ckpt_reader = tf.train.NewCheckpointReader(checkpoint)

  # Create a list of variable names to process.
  all_names = ckpt_reader.get_variable_to_shape_map().keys()

  # Gather all variables ending with the specified suffixes.
  tensor_names = []
  for suffix in suffixes:
    tensor_names += [x for x in all_names if x.endswith(suffix)]

  # Guard against an empty match set so the sparsity ratio below is defined.
  if not tensor_names:
    raise ValueError(
        "No variables with suffixes {} found in checkpoint.".format(suffixes))

  nnz = 0.0
  total = 0.0
  for name in sorted(tensor_names):
    mask = mask_fn(ckpt_reader.get_tensor(name))
    nnz += np.count_nonzero(mask)
    total += mask.size
  print("{} global sparsity = {}%".format(checkpoint, 100 * (1 - nnz / total)))
91
92
def l0_mask(log_alpha, gamma=GAMMA, zeta=ZETA):
  """Computes the expected hard-concrete gate for an l0-regularized tensor.

  The gate probability sigmoid(log_alpha) is stretched from (0, 1) onto
  (gamma, zeta) and then hard-clipped back to [0, 1]; a weight counts as
  pruned exactly when the clipped value is zero.
  """
  # Gate probability in (0, 1).
  gate = 1 / (1 + np.exp(-log_alpha))
  # Stretch onto (gamma, zeta), then clip to the unit interval.
  stretched = gate * (zeta - gamma) + gamma
  return np.clip(stretched, a_min=0.0, a_max=1.0)
99
100
# Specialization of 'get_sparsity' for l0-regularized models: masks are
# recovered from the stored log_alpha (or auxiliary "_aux") tensors via the
# hard-concrete gate in l0_mask.
l0_sparsity = functools.partial(
    get_sparsity,
    suffixes=["log_alpha", "_aux"],
    mask_fn=l0_mask)


# Specialization of 'get_sparsity' for magnitude & random pruning models:
# the checkpoint already stores explicit masks, so they are used as-is.
pruning_sparsity = functools.partial(
    get_sparsity,
    suffixes=["mask"],
    mask_fn=lambda x: x)
113
114
def compute_log_alpha(log_sigma2, theta, eps=EPSILON):
  """Computes log-alpha values for a tensor trained with variational dropout.

  Uses log(alpha) = log(sigma^2) - log(theta^2 + eps); `eps` keeps the log
  finite for exactly-zero weights.
  """
  log_theta_squared = np.log(np.square(theta) + eps)
  return log_sigma2 - log_theta_squared
118
119
def vd_sparsity(checkpoint, log_alpha_threshold, model):
  """Calculates and prints global sparsity for variational dropout checkpoint.

  A weight is considered active while its log-alpha value stays below
  `log_alpha_threshold`; all other weights are treated as pruned.

  Args:
    checkpoint: path to checkpoint.
    log_alpha_threshold: log alpha threshold to calculate sparsity with.
    model: either 'transformer' or 'rn50'.

  Raises:
    ValueError: if the checkpoint contains no variational dropout variables
      (previously this crashed with ZeroDivisionError).
  """
  # The two supported models name their weight tensors differently.
  weight_suffix = "weights" if model == "rn50" else "kernel"

  ckpt_reader = tf.train.NewCheckpointReader(checkpoint)

  # Create a list of variable names to process.
  all_names = ckpt_reader.get_variable_to_shape_map().keys()

  # Gather every variable storing log-sigma^2 values, whether saved under the
  # standard "log_sigma2" name or as an auxiliary "_aux" variable.
  tensor_names = [x for x in all_names if x.endswith("log_sigma2")]
  tensor_names += [x for x in all_names if x.endswith("_aux")]

  # Guard against an empty match set so the sparsity ratio below is defined.
  if not tensor_names:
    raise ValueError(
        "No variational dropout variables found in checkpoint.")

  nnz = 0.0
  total = 0.0
  for name in sorted(tensor_names):
    log_sigma2 = ckpt_reader.get_tensor(name)

    # Derive the name of the matching weight tensor (theta).
    if name.endswith("log_sigma2"):
      theta_name = name.replace("log_sigma2", weight_suffix)
    else:
      theta_name = name.replace("_aux", "")
    theta = ckpt_reader.get_tensor(theta_name)

    # Active weights are those whose log-alpha is under the threshold.
    mask = np.less(compute_log_alpha(log_sigma2, theta), log_alpha_threshold)

    nnz += np.count_nonzero(mask)
    total += mask.size
  print("{} global sparsity = {}%".format(checkpoint, 100 * (1 - nnz / total)))
157
158
def main(_):
  """Entry point: computes checkpoint sparsity for the selected technique."""
  flags.mark_flag_as_required("checkpoint")

  technique = FLAGS.sparsity_technique
  if technique == "variational_dropout":
    vd_sparsity(FLAGS.checkpoint, FLAGS.log_alpha_threshold, FLAGS.model)
  elif technique == "l0_regularization":
    l0_sparsity(FLAGS.checkpoint)
  elif "pruning" in technique:
    # Covers both "magnitude_pruning" and "random_pruning".
    pruning_sparsity(FLAGS.checkpoint)
  else:
    raise ValueError("Invalid sparsity_technique argument.")
170
171
if __name__ == "__main__":
  # absl parses the command-line flags before invoking main().
  app.run(main)
174