# Source: google-research / state_of_sparsity (298 lines, 10.3 KB in the original listing).
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Defines an implementation of tensorflow core layers with vd pruning.
17"""
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22import tensorflow.compat.v1 as tf23from state_of_sparsity.layers.utils import layer_utils24from state_of_sparsity.layers.variational_dropout import common25from state_of_sparsity.layers.variational_dropout import nn26from tensorflow.python.layers import base # pylint: disable=g-direct-tensorflow-import27
28THETA_LOGSIGMA2_COLLECTION = "theta_logsigma2"29
30
class Conv2D(base.Layer):
  r"""Base implementation of a conv2d layer with variational dropout.

  Instead of deterministic parameters, parameters are drawn from a
  distribution with mean \theta and variance \sigma^2. A log-uniform prior
  for the distribution is used to encourage sparsity.

  Args:
    x: Input, float32 tensor.
    num_outputs: Int representing size of output tensor.
    kernel_size: The size of the convolutional window, int or list of ints.
    strides: Stride of the convolution; a sequence of two ints
      (height stride, width stride) is expected.
    padding: May be populated as `"VALID"` or `"SAME"`.
    activation: If None, a linear activation is used.
    kernel_initializer: Initializer for the convolution weights.
    bias_initializer: Initializer of the bias vector.
    kernel_regularizer: Regularization method for the convolution weights.
    bias_regularizer: Optional regularizer for the bias vector.
    log_sigma2_initializer: Specified initializer of the log_sigma2 term.
    data_format: Either "channels_last", "NHWC", "NCHW", "channels_first".
    is_training: Boolean specifying whether it is training or eval.
    use_bias: Boolean specifying whether bias vector should be used.
    eps: Small epsilon value to prevent math op saturation.
    threshold: Threshold for masking log alpha at test time. The relationship
      between \sigma^2, \theta, and \alpha as defined in the
      paper https://arxiv.org/abs/1701.05369 is \sigma^2 = \alpha * \theta^2.
    clip_alpha: Int that specifies range for clipping log alpha values during
      training.
    name: String specifying name scope of layer in network.

  Returns:
    Output Tensor of the conv2d operation.
  """

  def __init__(self,
               num_outputs,
               kernel_size,
               strides,
               padding,
               activation,
               kernel_initializer,
               bias_initializer,
               kernel_regularizer,
               bias_regularizer,
               log_sigma2_initializer,
               data_format,
               activity_regularizer=None,
               is_training=True,
               trainable=True,
               use_bias=False,
               eps=common.EPSILON,
               threshold=3.,
               clip_alpha=8.,
               name="",
               **kwargs):
    super(Conv2D, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.num_outputs = num_outputs
    self.kernel_size = kernel_size
    # Expand the two spatial strides into the 4-d stride vector expected by
    # tf.nn.conv2d. NOTE(review): batch/channel strides are placed at
    # positions 0 and 3, which assumes NHWC ordering — confirm for NCHW.
    self.strides = [1, strides[0], strides[1], 1]
    self.padding = padding.upper()
    self.activation = activation
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.log_sigma2_initializer = log_sigma2_initializer
    self.data_format = layer_utils.standardize_data_format(data_format)
    self.is_training = is_training
    self.use_bias = use_bias
    self.eps = eps
    self.threshold = threshold
    self.clip_alpha = clip_alpha

  def build(self, input_shape):
    input_shape = input_shape.as_list()
    # NOTE(review): the input channel count is read from axis 3, which
    # assumes channels-last input — confirm against callers using NCHW.
    dims = input_shape[3]
    kernel_shape = [
        self.kernel_size[0], self.kernel_size[1], dims, self.num_outputs
    ]

    # \theta: the posterior mean of the weights.
    self.kernel = tf.get_variable(
        "kernel",
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        dtype=tf.float32,
        trainable=True)

    if not self.log_sigma2_initializer:
      # Default to a very small initial variance (exp(-10)) so training
      # starts close to the deterministic layer.
      self.log_sigma2_initializer = tf.constant_initializer(
          value=-10, dtype=tf.float32)

    # log(\sigma^2): the log posterior variance of the weights.
    self.log_sigma2 = tf.get_variable(
        "log_sigma2",
        shape=kernel_shape,
        initializer=self.log_sigma2_initializer,
        dtype=tf.float32,
        trainable=True)

    # Register the (theta, log_sigma2) pair so the DKL regularizer can find
    # all variational parameters in the graph.
    layer_utils.add_variable_to_collection(
        (self.kernel, self.log_sigma2),
        [THETA_LOGSIGMA2_COLLECTION],
        None)

    if self.use_bias:
      # Bug fix: the bias length must be self.num_outputs; the original
      # referenced self.filters, an attribute this class never sets, which
      # raised AttributeError whenever use_bias=True.
      self.bias = self.add_variable(
          name="bias",
          shape=(self.num_outputs,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    self.built = True

  def call(self, inputs):
    # Training uses the local reparameterization trick (sampled activations);
    # eval uses the deterministic weights with high-alpha entries masked out.
    if self.is_training:
      output = nn.conv2d_train(
          x=inputs,
          variational_params=(self.kernel, self.log_sigma2),
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          clip_alpha=self.clip_alpha,
          eps=self.eps)
    else:
      output = nn.conv2d_eval(
          x=inputs,
          variational_params=(self.kernel, self.log_sigma2),
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          threshold=self.threshold,
          eps=self.eps)

    if self.use_bias:
      output = tf.nn.bias_add(output, self.bias)
    if self.activation is not None:
      return self.activation(output)
    else:
      return output
179
class FullyConnected(base.Layer):
  r"""Fully connected layer with variational dropout.

  Rather than point-estimate weights, each weight is modeled as a Gaussian
  with mean \theta and variance \sigma^2 under a log-uniform prior, which
  drives many weights toward exact sparsity.

  Args:
    x: Input, float32 tensor.
    num_outputs: Int representing size of output tensor.
    activation: If None, a linear activation is used.
    kernel_initializer: Initializer for the weight matrix.
    bias_initializer: Initializer of the bias vector.
    kernel_regularizer: Regularization method for the weight matrix.
    bias_regularizer: Optional regularizer for the bias vector.
    log_sigma2_initializer: Specified initializer of the log_sigma2 term.
    is_training: Boolean specifying whether it is training or eval.
    use_bias: Boolean specifying whether bias vector should be used.
    eps: Small epsilon value to prevent math op saturation.
    threshold: Threshold for masking log alpha at test time. The relationship
      between \sigma^2, \theta, and \alpha as defined in the
      paper https://arxiv.org/abs/1701.05369 is \sigma^2 = \alpha * \theta^2.
    clip_alpha: Int that specifies range for clipping log alpha values during
      training.
    name: String specifying name scope of layer in network.

  Returns:
    Output Tensor of the fully connected operation.
  """

  def __init__(self,
               num_outputs,
               activation,
               kernel_initializer,
               bias_initializer,
               kernel_regularizer,
               bias_regularizer,
               log_sigma2_initializer,
               activity_regularizer=None,
               is_training=True,
               trainable=True,
               use_bias=True,
               eps=common.EPSILON,
               threshold=3.,
               clip_alpha=8.,
               name="FullyConnected",
               **kwargs):
    super(FullyConnected, self).__init__(
        trainable=trainable,
        name=name,
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.num_outputs = num_outputs
    self.activation = activation
    self.kernel_initializer = kernel_initializer
    self.bias_initializer = bias_initializer
    self.kernel_regularizer = kernel_regularizer
    self.bias_regularizer = bias_regularizer
    self.log_sigma2_initializer = log_sigma2_initializer
    self.is_training = is_training
    self.use_bias = use_bias
    self.eps = eps
    self.threshold = threshold
    self.clip_alpha = clip_alpha

  def build(self, input_shape):
    # The weight matrix maps the last input dimension to num_outputs.
    in_size = input_shape.as_list()[1]
    weight_shape = [in_size, self.num_outputs]

    # \theta: the posterior mean of the weights.
    self.kernel = tf.get_variable(
        "kernel",
        shape=weight_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        dtype=tf.float32,
        trainable=True)

    if not self.log_sigma2_initializer:
      # Start with a tiny variance (exp(-10)) if none was provided.
      self.log_sigma2_initializer = tf.constant_initializer(
          value=-10, dtype=tf.float32)

    # log(\sigma^2): the log posterior variance of the weights.
    self.log_sigma2 = tf.get_variable(
        "log_sigma2",
        shape=weight_shape,
        initializer=self.log_sigma2_initializer,
        dtype=tf.float32,
        trainable=True)

    # Expose the variational parameter pair for regularizer collection.
    layer_utils.add_variable_to_collection(
        (self.kernel, self.log_sigma2),
        [THETA_LOGSIGMA2_COLLECTION],
        None)

    self.bias = None
    if self.use_bias:
      self.bias = self.add_variable(
          name="bias",
          shape=(self.num_outputs,),
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          trainable=True,
          dtype=self.dtype)
    self.built = True

  def call(self, inputs):
    variational_params = (self.kernel, self.log_sigma2)
    # Stochastic matmul during training; deterministic, thresholded weights
    # at eval time.
    if self.is_training:
      outputs = nn.matmul_train(
          inputs, variational_params, clip_alpha=self.clip_alpha)
    else:
      outputs = nn.matmul_eval(
          inputs, variational_params, threshold=self.threshold)

    if self.use_bias:
      outputs = tf.nn.bias_add(outputs, self.bias)
    return self.activation(outputs) if self.activation is not None else outputs