# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""tf.layers-like API for l0-regularization layers."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import tensorflow.compat.v1 as tf
22
23from state_of_sparsity.layers.l0_regularization import common
24from state_of_sparsity.layers.l0_regularization import nn
25from state_of_sparsity.layers.utils import layer_utils
26from tensorflow.python.layers import base # pylint: disable=g-direct-tensorflow-import
27
28
29THETA_LOGALPHA_COLLECTION = "theta_logalpha"
30
31
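# For reference, `beta`, `gamma`, and `zeta` below parameterize the
# hard-concrete gate distribution. The actual sampling lives in
# `nn.conv2d_train` / `nn.matmul_train`; this standalone sketch follows
# Louizos et al. (2018) and is not copied from the repo's implementation:
#
#   u = tf.random_uniform(tf.shape(log_alpha), minval=eps, maxval=1.0 - eps)
#   s = tf.sigmoid((tf.log(u) - tf.log(1.0 - u) + log_alpha) / beta)
#   s_bar = s * (zeta - gamma) + gamma     # stretch to (gamma, zeta)
#   z = tf.clip_by_value(s_bar, 0.0, 1.0)  # "hard" gate in [0, 1]

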
class Conv2D(base.Layer):
  """Base implementation of a conv2d layer with l0-regularization.

  Args:
    num_outputs: Int representing size of output tensor.
    kernel_size: The size of the convolutional window; a list/tuple of
      two ints.
    strides: The stride of the convolution; a list/tuple of two ints.
    padding: May be populated as "VALID" or "SAME".
    activation: If None, a linear activation is used.
    kernel_initializer: Initializer for the convolution weights.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularization method for the convolution weights.
    bias_regularizer: Optional regularizer for the bias vector.
    log_alpha_initializer: Initializer for the log alpha parameters of the
      hard-concrete distribution.
    data_format: Either "channels_last"/"NHWC" or "channels_first"/"NCHW".
    is_training: Boolean specifying whether it is training or eval.
    use_bias: Boolean specifying whether a bias vector should be used.
    eps: Small epsilon value to prevent math op saturation.
    beta: The beta parameter, which controls the "temperature" of the
      distribution. Defaults to 2/3 from the l0-regularization paper
      (Louizos et al., 2018).
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the paper.
    zeta: The zeta parameter, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the paper.
    name: String specifying the name scope of the layer in the network.

  Returns:
    Output Tensor of the conv2d operation.
  """
62
63def __init__(self,
64num_outputs,
65kernel_size,
66strides,
67padding,
68activation,
69kernel_initializer,
70bias_initializer,
71kernel_regularizer,
72bias_regularizer,
73log_alpha_initializer,
74data_format,
75activity_regularizer=None,
76is_training=True,
77trainable=True,
78use_bias=False,
79eps=common.EPSILON,
80beta=common.BETA,
81gamma=common.GAMMA,
82zeta=common.ZETA,
83name="",
84**kwargs):
85super(Conv2D, self).__init__(
86trainable=trainable,
87name=name,
88activity_regularizer=activity_regularizer,
89**kwargs)
90self.num_outputs = num_outputs
91self.kernel_size = kernel_size
92self.strides = [1, strides[0], strides[1], 1]
93self.padding = padding.upper()
94self.activation = activation
95self.kernel_initializer = kernel_initializer
96self.bias_initializer = bias_initializer
97self.kernel_regularizer = kernel_regularizer
98self.bias_regularizer = bias_regularizer
99self.log_alpha_initializer = log_alpha_initializer
100self.data_format = layer_utils.standardize_data_format(data_format)
101self.is_training = is_training
102self.use_bias = use_bias
103self.eps = eps
104self.beta = beta
105self.gamma = gamma
106self.zeta = zeta
107
  def build(self, input_shape):
    input_shape = input_shape.as_list()
    # NOTE: Assumes a channels-last (NHWC) input when reading the channel dim.
    dims = input_shape[3]
    kernel_shape = [
        self.kernel_size[0], self.kernel_size[1], dims, self.num_outputs
    ]

    self.kernel = tf.get_variable(
        "kernel",
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        dtype=self.dtype,
        trainable=True)

    if not self.log_alpha_initializer:
      # Default log alpha set s.t. \alpha / (\alpha + 1) = .9, i.e. gates
      # start out ~90% likely to be open (exp(2.197) ~= 9).
      self.log_alpha_initializer = tf.random_normal_initializer(
          mean=2.197, stddev=0.01, dtype=self.dtype)

    self.log_alpha = tf.get_variable(
        "log_alpha",
        shape=kernel_shape,
        initializer=self.log_alpha_initializer,
        dtype=self.dtype,
        trainable=True)

    layer_utils.add_variable_to_collection(
        (self.kernel, self.log_alpha),
        [THETA_LOGALPHA_COLLECTION], None)

    if self.use_bias:
      self.bias = self.add_variable(
          name="bias",
          shape=(self.num_outputs,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    if self.is_training:
      output = nn.conv2d_train(
          x=inputs,
          weight_parameters=(self.kernel, self.log_alpha),
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          beta=self.beta,
          gamma=self.gamma,
          zeta=self.zeta,
          eps=self.eps)
    else:
      output = nn.conv2d_eval(
          x=inputs,
          weight_parameters=(self.kernel, self.log_alpha),
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          gamma=self.gamma,
          zeta=self.zeta)

    if self.use_bias:
      output = tf.nn.bias_add(output, self.bias)
    if self.activation is not None:
      return self.activation(output)
    else:
      return output

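# Example usage of Conv2D (a minimal sketch; the hyperparameters and input
# shape below are illustrative, not taken from the repo):
#
#   conv = Conv2D(
#       num_outputs=64,
#       kernel_size=[3, 3],
#       strides=[1, 1],
#       padding="SAME",
#       activation=tf.nn.relu,
#       kernel_initializer=tf.glorot_uniform_initializer(),
#       bias_initializer=tf.zeros_initializer(),
#       kernel_regularizer=None,
#       bias_regularizer=None,
#       log_alpha_initializer=None,  # falls back to the 2.197-mean default
#       data_format="channels_last",
#       is_training=True)
#   outputs = conv(tf.ones([8, 32, 32, 3]))  # -> [8, 32, 32, 64]

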
class FullyConnected(base.Layer):
  """Base implementation of a fully connected layer with l0-regularization.

  Args:
    num_outputs: Int representing size of output tensor.
    activation: If None, a linear activation is used.
    kernel_initializer: Initializer for the weight matrix.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularization method for the weight matrix.
    bias_regularizer: Optional regularizer for the bias vector.
    log_alpha_initializer: Specified initializer of the log_alpha term.
    is_training: Boolean specifying whether it is training or eval.
    use_bias: Boolean specifying whether a bias vector should be used.
    eps: Small epsilon value to prevent math op saturation.
    beta: The beta parameter, which controls the "temperature" of the
      distribution. Defaults to 2/3 from the l0-regularization paper
      (Louizos et al., 2018).
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the paper.
    zeta: The zeta parameter, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the paper.
    name: String specifying the name scope of the layer in the network.

  Returns:
    Output Tensor of the fully connected operation.
  """

  def __init__(self,
               num_outputs,
               activation,
               kernel_initializer,
               bias_initializer,
               kernel_regularizer,
               bias_regularizer,
               log_alpha_initializer,
               activity_regularizer=None,
               is_training=True,
               trainable=True,
               use_bias=True,
               eps=common.EPSILON,
               beta=common.BETA,
               gamma=common.GAMMA,
               zeta=common.ZETA,
               name="FullyConnected",
               **kwargs):
    super(FullyConnected, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.num_outputs = num_outputs
    self.activation = activation
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.log_alpha_initializer = log_alpha_initializer
    self.is_training = is_training
    self.use_bias = use_bias
    self.eps = eps
    self.beta = beta
    self.gamma = gamma
    self.zeta = zeta

  def build(self, input_shape):
    input_shape = input_shape.as_list()
    input_hidden_size = input_shape[1]
    kernel_shape = [input_hidden_size, self.num_outputs]

    self.kernel = tf.get_variable(
        "kernel",
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        dtype=self.dtype,
        trainable=True)

    if not self.log_alpha_initializer:
      # Default log alpha set s.t. \alpha / (\alpha + 1) = .9, i.e. gates
      # start out ~90% likely to be open (exp(2.197) ~= 9).
      self.log_alpha_initializer = tf.random_normal_initializer(
          mean=2.197, stddev=0.01, dtype=self.dtype)

    self.log_alpha = tf.get_variable(
        "log_alpha",
        shape=kernel_shape,
        initializer=self.log_alpha_initializer,
        dtype=self.dtype,
        trainable=True)

    layer_utils.add_variable_to_collection(
        (self.kernel, self.log_alpha),
        [THETA_LOGALPHA_COLLECTION], None)

    if self.use_bias:
      self.bias = self.add_variable(
          name="bias",
          shape=(self.num_outputs,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    if self.is_training:
      x = nn.matmul_train(
          inputs,
          (self.kernel, self.log_alpha),
          beta=self.beta,
          gamma=self.gamma,
          zeta=self.zeta,
          eps=self.eps)
    else:
      x = nn.matmul_eval(
          inputs,
          (self.kernel, self.log_alpha),
          gamma=self.gamma,
          zeta=self.zeta)

    if self.use_bias:
      x = tf.nn.bias_add(x, self.bias)
    if self.activation is not None:
      return self.activation(x)
    return x

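# Example usage of FullyConnected (a minimal sketch; hyperparameters and
# shapes are illustrative, not taken from the repo):
#
#   fc = FullyConnected(
#       num_outputs=256,
#       activation=tf.nn.relu,
#       kernel_initializer=tf.glorot_uniform_initializer(),
#       bias_initializer=tf.zeros_initializer(),
#       kernel_regularizer=None,
#       bias_regularizer=None,
#       log_alpha_initializer=None,
#       is_training=False)  # eval path uses nn.matmul_eval
#   logits = fc(tf.ones([32, 784]))  # -> [32, 256]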