# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines an implementation of tensorflow core layers with variational
dropout pruning.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
from state_of_sparsity.layers.utils import layer_utils
from state_of_sparsity.layers.variational_dropout import common
from state_of_sparsity.layers.variational_dropout import nn
from tensorflow.python.layers import base  # pylint: disable=g-direct-tensorflow-import

THETA_LOGSIGMA2_COLLECTION = "theta_logsigma2"
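# Each layer registers its (kernel, log_sigma2) variable pair in this
# collection via layer_utils.add_variable_to_collection. A training script
# can then gather them, e.g. with
# tf.get_collection(THETA_LOGSIGMA2_COLLECTION), to build the KL-divergence
# penalty of the variational objective.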


class Conv2D(base.Layer):
  r"""Base implementation of a conv2d layer with variational dropout.

  Instead of deterministic parameters, parameters are drawn from a
  distribution with mean \theta and variance \sigma^2. A log-uniform prior
  for the distribution is used to encourage sparsity.

  Args:
    x: Input, float32 tensor.
    num_outputs: Int representing size of output tensor.
    kernel_size: The size of the convolutional window, an int or list of ints.
    strides: Stride of the convolution, a list or tuple of two ints.
    padding: May be either "VALID" or "SAME".
    activation: If None, a linear activation is used.
    kernel_initializer: Initializer for the convolution weights.
    bias_initializer: Initializer of the bias vector.
    kernel_regularizer: Regularization method for the convolution weights.
    bias_regularizer: Optional regularizer for the bias vector.
    log_sigma2_initializer: Initializer of the log_sigma2 term.
    data_format: Either "channels_last" ("NHWC") or "channels_first" ("NCHW").
    is_training: Boolean specifying whether it is training or eval.
    use_bias: Boolean specifying whether a bias vector should be used.
    eps: Small epsilon value to prevent math op saturation.
    threshold: Threshold for masking log alpha at test time. The relationship
      between \sigma^2, \theta, and \alpha as defined in the paper
      https://arxiv.org/abs/1701.05369 is \sigma^2 = \alpha * \theta^2.
    clip_alpha: Float that specifies the range for clipping log alpha values
      during training.
    name: String specifying the name scope of the layer in the network.

  Returns:
    Output Tensor of the conv2d operation.
  """

  def __init__(self,
               num_outputs,
               kernel_size,
               strides,
               padding,
               activation,
               kernel_initializer,
               bias_initializer,
               kernel_regularizer,
               bias_regularizer,
               log_sigma2_initializer,
               data_format,
               activity_regularizer=None,
               is_training=True,
               trainable=True,
               use_bias=False,
               eps=common.EPSILON,
               threshold=3.,
               clip_alpha=8.,
               name="",
               **kwargs):
    super(Conv2D, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.num_outputs = num_outputs
    self.kernel_size = kernel_size
    self.strides = [1, strides[0], strides[1], 1]
    self.padding = padding.upper()
    self.activation = activation
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.log_sigma2_initializer = log_sigma2_initializer
    self.data_format = layer_utils.standardize_data_format(data_format)
    self.is_training = is_training
    self.use_bias = use_bias
    self.eps = eps
    self.threshold = threshold
    self.clip_alpha = clip_alpha

  def build(self, input_shape):
    input_shape = input_shape.as_list()
    dims = input_shape[3]
    kernel_shape = [
        self.kernel_size[0], self.kernel_size[1], dims, self.num_outputs
    ]

    self.kernel = tf.get_variable(
        "kernel",
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        dtype=tf.float32,
        trainable=True)

    if not self.log_sigma2_initializer:
      # Default: initialize log(sigma^2) to a large negative value so the
      # initial weight noise is negligible.
      self.log_sigma2_initializer = tf.constant_initializer(
          value=-10, dtype=tf.float32)

    self.log_sigma2 = tf.get_variable(
        "log_sigma2",
        shape=kernel_shape,
        initializer=self.log_sigma2_initializer,
        dtype=tf.float32,
        trainable=True)

    layer_utils.add_variable_to_collection(
        (self.kernel, self.log_sigma2),
        [THETA_LOGSIGMA2_COLLECTION],
        None)

    if self.use_bias:
      # One bias entry per output channel.
      self.bias = self.add_variable(
          name="bias",
          shape=(self.num_outputs,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    if self.is_training:
      # Training: sample the layer output from the variational posterior.
      output = nn.conv2d_train(
          x=inputs,
          variational_params=(self.kernel, self.log_sigma2),
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          clip_alpha=self.clip_alpha,
          eps=self.eps)
    else:
      # Eval: use the weight means \theta, masked by thresholding log alpha.
      output = nn.conv2d_eval(
          x=inputs,
          variational_params=(self.kernel, self.log_sigma2),
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          threshold=self.threshold,
          eps=self.eps)

    if self.use_bias:
      output = tf.nn.bias_add(output, self.bias)
    if self.activation is not None:
      return self.activation(output)
    else:
      return output
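
# Example usage for Conv2D (a minimal sketch, assuming TF1 graph mode; the
# placeholder shape and initializer choices below are illustrative, not
# prescribed by this module):
#
#   images = tf.placeholder(tf.float32, [32, 28, 28, 3])
#   conv = Conv2D(
#       num_outputs=64,
#       kernel_size=[3, 3],
#       strides=[1, 1],
#       padding="SAME",
#       activation=tf.nn.relu,
#       kernel_initializer=tf.glorot_uniform_initializer(),
#       bias_initializer=tf.zeros_initializer(),
#       kernel_regularizer=None,
#       bias_regularizer=None,
#       log_sigma2_initializer=None,
#       data_format="channels_last",
#       is_training=True)
#   features = conv(images)  # -> [32, 28, 28, 64] with "SAME" padding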


class FullyConnected(base.Layer):
  r"""Base implementation of a fully connected layer with variational dropout.

  Instead of deterministic parameters, parameters are drawn from a
  distribution with mean \theta and variance \sigma^2. A log-uniform prior
  for the distribution is used to encourage sparsity.

  Args:
    x: Input, float32 tensor.
    num_outputs: Int representing size of output tensor.
    activation: If None, a linear activation is used.
    kernel_initializer: Initializer for the weight matrix.
    bias_initializer: Initializer of the bias vector.
    kernel_regularizer: Regularization method for the weight matrix.
    bias_regularizer: Optional regularizer for the bias vector.
    log_sigma2_initializer: Initializer of the log_sigma2 term.
    is_training: Boolean specifying whether it is training or eval.
    use_bias: Boolean specifying whether a bias vector should be used.
    eps: Small epsilon value to prevent math op saturation.
    threshold: Threshold for masking log alpha at test time. The relationship
      between \sigma^2, \theta, and \alpha as defined in the paper
      https://arxiv.org/abs/1701.05369 is \sigma^2 = \alpha * \theta^2.
    clip_alpha: Float that specifies the range for clipping log alpha values
      during training.
    name: String specifying the name scope of the layer in the network.

  Returns:
    Output Tensor of the fully connected operation.
  """

  def __init__(self,
               num_outputs,
               activation,
               kernel_initializer,
               bias_initializer,
               kernel_regularizer,
               bias_regularizer,
               log_sigma2_initializer,
               activity_regularizer=None,
               is_training=True,
               trainable=True,
               use_bias=True,
               eps=common.EPSILON,
               threshold=3.,
               clip_alpha=8.,
               name="FullyConnected",
               **kwargs):
    super(FullyConnected, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.num_outputs = num_outputs
    self.activation = activation
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.log_sigma2_initializer = log_sigma2_initializer
    self.is_training = is_training
    self.use_bias = use_bias
    self.eps = eps
    self.threshold = threshold
    self.clip_alpha = clip_alpha

  def build(self, input_shape):
    input_shape = input_shape.as_list()
    input_hidden_size = input_shape[1]
    kernel_shape = [input_hidden_size, self.num_outputs]

    self.kernel = tf.get_variable(
        "kernel",
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        dtype=tf.float32,
        trainable=True)

    if not self.log_sigma2_initializer:
      self.log_sigma2_initializer = tf.constant_initializer(
          value=-10, dtype=tf.float32)

    self.log_sigma2 = tf.get_variable(
        "log_sigma2",
        shape=kernel_shape,
        initializer=self.log_sigma2_initializer,
        dtype=tf.float32,
        trainable=True)

    layer_utils.add_variable_to_collection(
        (self.kernel, self.log_sigma2),
        [THETA_LOGSIGMA2_COLLECTION],
        None)

    if self.use_bias:
      self.bias = self.add_variable(
          name="bias",
          shape=(self.num_outputs,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    if self.is_training:
      # Training: sample the layer output from the variational posterior.
      x = nn.matmul_train(
          inputs, (self.kernel, self.log_sigma2), clip_alpha=self.clip_alpha)
    else:
      # Eval: use the weight means \theta, masked by thresholding log alpha.
      x = nn.matmul_eval(
          inputs, (self.kernel, self.log_sigma2), threshold=self.threshold)

    if self.use_bias:
      x = tf.nn.bias_add(x, self.bias)
    if self.activation is not None:
      return self.activation(x)
    return x
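
# Example usage for FullyConnected (a minimal sketch, assuming TF1 graph
# mode; shapes and initializers are illustrative):
#
#   x = tf.placeholder(tf.float32, [32, 784])
#   fc = FullyConnected(
#       num_outputs=10,
#       activation=None,
#       kernel_initializer=tf.glorot_uniform_initializer(),
#       bias_initializer=tf.zeros_initializer(),
#       kernel_regularizer=None,
#       bias_regularizer=None,
#       log_sigma2_initializer=None,
#       is_training=True)
#   logits = fc(x)  # -> [32, 10]
#
#   # The variational parameters registered by both layer types can then be
#   # gathered for the KL penalty:
#   variational_params = tf.get_collection(THETA_LOGSIGMA2_COLLECTION)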