google-research

Форк
0
/
checkpoint_sparsity.py 
173 строки · 4.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Calculates weight sparsity for a model checkpoint."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import functools
22
from absl import app
23
from absl import flags
24

25
import numpy as np
26
import tensorflow.compat.v1 as tf
27

28

29
# Command-line flags.
flags.DEFINE_string(
    "checkpoint",
    None,
    "Path to checkpoint."
)
flags.DEFINE_enum(
    "sparsity_technique",
    "magnitude_pruning",
    [
        "magnitude_pruning",
        "random_pruning",
        "variational_dropout",
        "l0_regularization"
    ],
    "Technique used to produce model checkpoint."
)
flags.DEFINE_enum(
    "model",
    "transformer",
    ["transformer", "rn50"],
    "Model saved in checkpoint."
)
flags.DEFINE_float(
    "log_alpha_threshold",
    3.0,
    "log alpha threshold for variational dropout checkpoint."
)

FLAGS = flags.FLAGS
# Numerical-stability constant added inside the log when computing
# log-alpha values (see compute_log_alpha).
EPSILON = 1e-8
# Stretch-interval bounds for the l0-regularization hard-sigmoid gate
# (see l0_mask): sigmoid output is stretched to [GAMMA, ZETA], then clipped.
GAMMA = -0.1
ZETA = 1.1
61

62

63
def get_sparsity(checkpoint, suffixes, mask_fn):
  """Calculates and prints global weight sparsity from a checkpoint.

  Sparsity is reported as the percentage of zero-valued entries across all
  mask tensors found in the checkpoint.

  Args:
    checkpoint: path to checkpoint.
    suffixes: possible suffixes of mask variables in the checkpoint.
    mask_fn: helper function to calculate the weight mask from a saved
      tensor.

  Raises:
    ValueError: if no variables matching `suffixes` exist in the checkpoint
      (the original code would raise an opaque ZeroDivisionError instead).
  """
  ckpt_reader = tf.train.NewCheckpointReader(checkpoint)

  # Create a list of variable names to process.
  all_names = ckpt_reader.get_variable_to_shape_map().keys()

  # Gather all variables ending with the specified suffixes.
  tensor_names = []
  for suffix in suffixes:
    tensor_names += [x for x in all_names if x.endswith(suffix)]

  nnz = 0.0
  total = 0.0
  # Sort for a deterministic traversal order across runs.
  for name in sorted(tensor_names):
    mask = mask_fn(ckpt_reader.get_tensor(name))
    nnz += np.count_nonzero(mask)
    total += mask.size

  if not total:
    raise ValueError(
        "No variables with suffixes {} found in checkpoint {}".format(
            suffixes, checkpoint))
  print("{} global sparsity = {}%".format(checkpoint, 100 * (1 - nnz / total)))
91

92

93
def l0_mask(log_alpha, gamma=GAMMA, zeta=ZETA):
  """Computes the weight mask for an l0-regularized tensor.

  Applies a hard-sigmoid gate: the sigmoid of `log_alpha` is stretched to
  the interval [gamma, zeta] and then clipped back to [0, 1].

  Args:
    log_alpha: saved log-alpha tensor (scalar or numpy array).
    gamma: lower bound of the stretch interval.
    zeta: upper bound of the stretch interval.

  Returns:
    Mask values in [0, 1] with the same shape as `log_alpha`.
  """
  gate = 1.0 / (1.0 + np.exp(-log_alpha))
  stretched = gate * (zeta - gamma) + gamma
  return np.clip(stretched, a_min=0.0, a_max=1.0)
99

100

101
# Specialization of 'get_sparsity' for l0-regularized models: mask variables
# end in 'log_alpha' or '_aux', and the saved tensor is converted to a mask
# via the hard-sigmoid gate in l0_mask.
l0_sparsity = functools.partial(
    get_sparsity, suffixes=["log_alpha", "_aux"], mask_fn=l0_mask)
106

107

108
# Specialization of 'get_sparsity' for magnitude & random pruning models:
# these checkpoints store explicit binary 'mask' variables, so the saved
# tensor is already the mask (identity mask_fn).
pruning_sparsity = functools.partial(
    get_sparsity, suffixes=["mask"], mask_fn=lambda x: x)
113

114

115
def compute_log_alpha(log_sigma2, theta, eps=EPSILON):
  """Computes log-alpha values for a tensor trained with variational dropout.

  Uses log(alpha) = log(sigma^2) - log(theta^2); `eps` is added inside the
  log to avoid log(0) when theta is near zero.

  Args:
    log_sigma2: log of the posterior variance tensor.
    theta: weight (posterior mean) tensor.
    eps: small numerical-stability constant.

  Returns:
    The elementwise log-alpha tensor.
  """
  return log_sigma2 - np.log(theta * theta + eps)
118

119

120
def vd_sparsity(checkpoint, log_alpha_threshold, model):
  """Calculates and prints global sparsity for variational dropout checkpoint.

  A weight counts as kept (nonzero) when its log-alpha value is strictly
  below `log_alpha_threshold`.

  Args:
    checkpoint: path to checkpoint.
    log_alpha_threshold: log alpha threshold to calculate sparsity with.
    model: either 'transformer' or 'rn50'; selects the naming convention
      used for the weight variables in the checkpoint.

  Raises:
    ValueError: if no variational dropout variables are found in the
      checkpoint (the original code would raise an opaque
      ZeroDivisionError instead).
  """
  # rn50 checkpoints name weights 'weights'; transformer uses 'kernel'.
  weight_suffix = "weights" if model == "rn50" else "kernel"

  ckpt_reader = tf.train.NewCheckpointReader(checkpoint)

  # Create a list of variable names to process.
  all_names = ckpt_reader.get_variable_to_shape_map().keys()

  # Gather the log-variance ('log_sigma2') and auxiliary ('_aux') variables.
  tensor_names = [x for x in all_names if x.endswith("log_sigma2")]
  tensor_names += [x for x in all_names if x.endswith("_aux")]

  nnz = 0.0
  total = 0.0
  for name in sorted(tensor_names):
    log_sigma2 = ckpt_reader.get_tensor(name)

    # Derive the name of the matching weight (theta) variable.
    if name.endswith("log_sigma2"):
      theta_name = name.replace("log_sigma2", weight_suffix)
    else:
      theta_name = name.replace("_aux", "")
    theta = ckpt_reader.get_tensor(theta_name)

    # Weights whose log-alpha is below the threshold are kept.
    mask = np.less(compute_log_alpha(log_sigma2, theta), log_alpha_threshold)
    nnz += np.count_nonzero(mask)
    total += mask.size

  if not total:
    raise ValueError(
        "No variational dropout variables found in checkpoint {}".format(
            checkpoint))
  print("{} global sparsity = {}%".format(checkpoint, 100 * (1 - nnz / total)))
157

158

159
def main(_):
  """Dispatches to the sparsity routine for --sparsity_technique."""
  flags.mark_flag_as_required("checkpoint")

  technique = FLAGS.sparsity_technique
  # Both 'magnitude_pruning' and 'random_pruning' store explicit masks,
  # so a substring check covers them together.
  if "pruning" in technique:
    pruning_sparsity(FLAGS.checkpoint)
  elif technique == "l0_regularization":
    l0_sparsity(FLAGS.checkpoint)
  elif technique == "variational_dropout":
    vd_sparsity(FLAGS.checkpoint, FLAGS.log_alpha_threshold, FLAGS.model)
  else:
    raise ValueError("Invalid sparsity_technique argument.")
170

171

172
# absl parses the command-line flags before invoking main().
if __name__ == "__main__":
  app.run(main)
174

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.