# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""tf.layers-like API for l0-regularization layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf

from state_of_sparsity.layers.l0_regularization import common
from state_of_sparsity.layers.l0_regularization import nn
from state_of_sparsity.layers.utils import layer_utils
from tensorflow.python.layers import base  # pylint: disable=g-direct-tensorflow-import


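# Each layer registers its (kernel, log_alpha) pair in this collection so
# that training code can retrieve them, e.g. to compute the expected-L0
# regularization penalty.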
THETA_LOGALPHA_COLLECTION = "theta_logalpha"


class Conv2D(base.Layer):
  """Base implementation of a conv2d layer with l0-regularization.

    Args:
      num_outputs: Int representing size of output tensor.
      kernel_size: The size of the convolutional window, an int or a list
        of ints.
      strides: Stride length of the convolution, a list or tuple of two ints.
      padding: May be populated as "VALID" or "SAME".
      activation: If None, a linear activation is used.
      kernel_initializer: Initializer for the convolution weights.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularization method for the convolution weights.
      bias_regularizer: Optional regularizer for the bias vector.
      log_alpha_initializer: Initializer for the log alpha parameters of the
        hard-concrete distribution.
      data_format: Either "channels_last", "NHWC", "NCHW", or
        "channels_first".
      is_training: Boolean specifying whether it is training or eval.
      use_bias: Boolean specifying whether a bias vector should be used.
      eps: Small epsilon value to prevent math op saturation.
      beta: The beta parameter, which controls the "temperature" of the
        hard-concrete distribution. Defaults to 2/3, following Louizos et
        al. (2018), "Learning Sparse Neural Networks through L0
        Regularization".
      gamma: The gamma parameter, which controls the lower bound of the
        stretched distribution. Defaults to -0.1 from the same paper.
      zeta: The zeta parameter, which controls the upper bound of the
        stretched distribution. Defaults to 1.1 from the same paper.
      name: String specifying the name scope of the layer in the network.

    Returns:
      Output Tensor of the conv2d operation.
  """

  def __init__(self,
               num_outputs,
               kernel_size,
               strides,
               padding,
               activation,
               kernel_initializer,
               bias_initializer,
               kernel_regularizer,
               bias_regularizer,
               log_alpha_initializer,
               data_format,
               activity_regularizer=None,
               is_training=True,
               trainable=True,
               use_bias=False,
               eps=common.EPSILON,
               beta=common.BETA,
               gamma=common.GAMMA,
               zeta=common.ZETA,
               name="",
               **kwargs):
    super(Conv2D, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.num_outputs = num_outputs
    self.kernel_size = kernel_size
    self.strides = [1, strides[0], strides[1], 1]
    self.padding = padding.upper()
    self.activation = activation
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.log_alpha_initializer = log_alpha_initializer
    self.data_format = layer_utils.standardize_data_format(data_format)
    self.is_training = is_training
    self.use_bias = use_bias
    self.eps = eps
    self.beta = beta
    self.gamma = gamma
    self.zeta = zeta

  def build(self, input_shape):
    input_shape = input_shape.as_list()
    # NOTE: Kernel construction assumes channels-last (NHWC) inputs, where
    # the channel dimension is at index 3.
    dims = input_shape[3]
    kernel_shape = [
        self.kernel_size[0], self.kernel_size[1], dims, self.num_outputs
    ]

    self.kernel = tf.get_variable(
        "kernel",
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        dtype=self.dtype,
        trainable=True)

    if not self.log_alpha_initializer:
      # Default log alpha set s.t. \alpha / (\alpha + 1) = 0.9, i.e.
      # log_alpha = log(9) ~= 2.197, so each gate starts mostly open.
      self.log_alpha_initializer = tf.random_normal_initializer(
          mean=2.197, stddev=0.01, dtype=self.dtype)

    self.log_alpha = tf.get_variable(
        "log_alpha",
        shape=kernel_shape,
        initializer=self.log_alpha_initializer,
        dtype=self.dtype,
        trainable=True)

    layer_utils.add_variable_to_collection(
        (self.kernel, self.log_alpha),
        [THETA_LOGALPHA_COLLECTION], None)

    if self.use_bias:
      self.bias = self.add_variable(
          name="bias",
          shape=(self.num_outputs,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    if self.is_training:
      output = nn.conv2d_train(
          x=inputs,
          weight_parameters=(self.kernel, self.log_alpha),
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          beta=self.beta,
          gamma=self.gamma,
          zeta=self.zeta,
          eps=self.eps)
    else:
      output = nn.conv2d_eval(
          x=inputs,
          weight_parameters=(self.kernel, self.log_alpha),
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          gamma=self.gamma,
          zeta=self.zeta)

    if self.use_bias:
      output = tf.nn.bias_add(output, self.bias)
    if self.activation is not None:
      return self.activation(output)
    else:
      return output
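

# A minimal usage sketch for Conv2D (the shapes and initializers below are
# illustrative, not part of this module; graph-mode tf.compat.v1 is assumed
# throughout):
#
#   images = tf.placeholder(tf.float32, [None, 32, 32, 3])
#   conv = Conv2D(
#       num_outputs=64,
#       kernel_size=[3, 3],
#       strides=[1, 1],
#       padding="same",
#       activation=tf.nn.relu,
#       kernel_initializer=tf.glorot_uniform_initializer(),
#       bias_initializer=tf.zeros_initializer(),
#       kernel_regularizer=None,
#       bias_regularizer=None,
#       log_alpha_initializer=None,  # falls back to the log(9) default
#       data_format="channels_last",
#       is_training=True)
#   features = conv(images)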


class FullyConnected(base.Layer):
  """Base implementation of a fully connected layer with l0 regularization.

    Args:
      num_outputs: Int representing size of output tensor.
      activation: If None, a linear activation is used.
      kernel_initializer: Initializer for the weight matrix.
      bias_initializer: Initializer for the bias vector.
      kernel_regularizer: Regularization method for the weight matrix.
      bias_regularizer: Optional regularizer for the bias vector.
      log_alpha_initializer: Specified initializer of the log_alpha term.
      is_training: Boolean specifying whether it is training or eval.
      use_bias: Boolean specifying whether a bias vector should be used.
      eps: Small epsilon value to prevent math op saturation.
      beta: The beta parameter, which controls the "temperature" of the
        hard-concrete distribution. Defaults to 2/3, following Louizos et
        al. (2018).
      gamma: The gamma parameter, which controls the lower bound of the
        stretched distribution. Defaults to -0.1 from the same paper.
      zeta: The zeta parameter, which controls the upper bound of the
        stretched distribution. Defaults to 1.1 from the same paper.
      name: String specifying the name scope of the layer in the network.

    Returns:
      Output Tensor of the fully connected operation.
  """

  def __init__(self,
               num_outputs,
               activation,
               kernel_initializer,
               bias_initializer,
               kernel_regularizer,
               bias_regularizer,
               log_alpha_initializer,
               activity_regularizer=None,
               is_training=True,
               trainable=True,
               use_bias=True,
               eps=common.EPSILON,
               beta=common.BETA,
               gamma=common.GAMMA,
               zeta=common.ZETA,
               name="FullyConnected",
               **kwargs):
    super(FullyConnected, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.num_outputs = num_outputs
    self.activation = activation
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.log_alpha_initializer = log_alpha_initializer
    self.is_training = is_training
    self.use_bias = use_bias
    self.eps = eps
    self.beta = beta
    self.gamma = gamma
    self.zeta = zeta

  def build(self, input_shape):
    input_shape = input_shape.as_list()
    input_hidden_size = input_shape[1]
    kernel_shape = [input_hidden_size, self.num_outputs]

    self.kernel = tf.get_variable(
        "kernel",
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        dtype=self.dtype,
        trainable=True)

    if not self.log_alpha_initializer:
      # Default log alpha set s.t. \alpha / (\alpha + 1) = 0.9, i.e.
      # log_alpha = log(9) ~= 2.197, so each gate starts mostly open.
      self.log_alpha_initializer = tf.random_normal_initializer(
          mean=2.197, stddev=0.01, dtype=self.dtype)

    self.log_alpha = tf.get_variable(
        "log_alpha",
        shape=kernel_shape,
        initializer=self.log_alpha_initializer,
        dtype=self.dtype,
        trainable=True)

    layer_utils.add_variable_to_collection(
        (self.kernel, self.log_alpha),
        [THETA_LOGALPHA_COLLECTION], None)

    if self.use_bias:
      self.bias = self.add_variable(
          name="bias",
          shape=(self.num_outputs,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    if self.is_training:
      x = nn.matmul_train(
          inputs,
          (self.kernel, self.log_alpha),
          beta=self.beta,
          gamma=self.gamma,
          zeta=self.zeta,
          eps=self.eps)
    else:
      x = nn.matmul_eval(
          inputs,
          (self.kernel, self.log_alpha),
          gamma=self.gamma,
          zeta=self.zeta)

    if self.use_bias:
      x = tf.nn.bias_add(x, self.bias)
    if self.activation is not None:
      return self.activation(x)
    return x
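

# A minimal usage sketch for FullyConnected (illustrative shapes and
# initializers; the expected-L0 penalty itself is computed from the
# collected (kernel, log_alpha) pairs by the companion modules, not here):
#
#   x = tf.placeholder(tf.float32, [None, 784])
#   fc = FullyConnected(
#       num_outputs=10,
#       activation=None,
#       kernel_initializer=tf.glorot_uniform_initializer(),
#       bias_initializer=tf.zeros_initializer(),
#       kernel_regularizer=None,
#       bias_regularizer=None,
#       log_alpha_initializer=None,
#       is_training=True)
#   logits = fc(x)
#   regularization_vars = tf.get_collection(THETA_LOGALPHA_COLLECTION)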