# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines standard network layers that train using l0 regularization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf

from state_of_sparsity.layers.l0_regularization import common
from state_of_sparsity.layers.utils import layer_utils


def _verify_weight_parameters(weight_parameters):
  """Checks that `weight_parameters` is a 2-tuple of same-shape tensors.

  Args:
    weight_parameters: The parameters to check.

  Raises:
    RuntimeError: If the input is not a 2-tuple of tensors with equal shape.

  Returns:
    The input `weight_parameters`, unchanged.
  """
  num_parameters = len(weight_parameters)
  if num_parameters != 2:
    raise RuntimeError("Incorrect number of weight parameters. Expected "
                       "2 tensors, got {}".format(num_parameters))

  theta, log_alpha = weight_parameters
  if theta.shape != log_alpha.shape:
    raise RuntimeError("Expected theta and log alpha parameter tensor "
                       "to be same shape. Got shapes {} and {}"
                       .format(theta.get_shape().as_list(),
                               log_alpha.get_shape().as_list()))
  return weight_parameters


def matmul_train(
    x,
    weight_parameters,
    transpose_a=False,
    transpose_b=False,
    beta=common.BETA,
    gamma=common.GAMMA,
    zeta=common.ZETA,
    eps=common.EPSILON):
  """Training-mode computation for an l0-regularized matmul.

  Args:
    x: 2D Tensor holding the input batch.
    weight_parameters: 2-tuple (theta, log_alpha), where theta holds the
      unscaled weight values and log_alpha the log-alpha parameters of the
      hard concrete distribution.
    transpose_a: If True, a is transposed before multiplication.
    transpose_b: If True, b is transposed before multiplication.
    beta: "Temperature" of the hard-concrete distribution. Defaults to 2/3,
      the value suggested by Louizos et al.
    gamma: Lower bound of the stretched distribution. Defaults to -0.1.
    zeta: Upper bound of the stretched distribution. Defaults to 1.1.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  x.get_shape().assert_has_rank(2)
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Gate each weight with a sample from the hard-concrete distribution.
  gate_sample = common.hard_concrete_sample(log_alpha, beta, gamma, zeta, eps)
  gated_weights = theta * gate_sample
  return tf.matmul(
      x, gated_weights, transpose_a=transpose_a, transpose_b=transpose_b)


def matmul_eval(
    x,
    weight_parameters,
    transpose_a=False,
    transpose_b=False,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Evaluation-mode computation for an l0-regularized matmul.

  Args:
    x: 2D Tensor holding the input batch.
    weight_parameters: 2-tuple (theta, log_alpha), where theta holds the
      unscaled weight values and log_alpha the log-alpha parameters of the
      hard concrete distribution.
    transpose_a: If True, a is transposed before multiplication.
    transpose_b: If True, b is transposed before multiplication.
    gamma: Lower bound of the stretched distribution. Defaults to -0.1.
    zeta: Upper bound of the stretched distribution. Defaults to 1.1.

  Returns:
    Output Tensor of the matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  x.get_shape().assert_has_rank(2)
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # At evaluation time the stochastic gate is replaced by the deterministic
  # mean of the learned hard-concrete distribution.
  mean_gate = common.hard_concrete_mean(log_alpha, gamma, zeta)
  gated_weights = theta * mean_gate
  return tf.matmul(
      x, gated_weights, transpose_a=transpose_a, transpose_b=transpose_b)


def broadcast_matmul_train(
    x,
    weight_parameters,
    beta=common.BETA,
    gamma=common.GAMMA,
    zeta=common.ZETA,
    eps=common.EPSILON):
  """Training computation for l0 matrix multiplication with N input matrices.

  Multiplies a 3D tensor `x` with a set of 2D parameters. Each 2D matrix
  `x[i, :, :]` in the input tensor is multiplied independently with the
  parameters, resulting in a 3D output tensor with shape
  `x.shape[:2] + weight_parameters[0].shape[1]`.

  Args:
    x: 3D Tensor holding the input batch.
    weight_parameters: 2-tuple (theta, log_alpha), where theta holds the
      unscaled weight values and log_alpha the log-alpha parameters of the
      hard concrete distribution.
    beta: "Temperature" of the hard-concrete distribution. Defaults to 2/3,
      the value suggested by Louizos et al.
    gamma: Lower bound of the stretched distribution. Defaults to -0.1.
    zeta: Upper bound of the stretched distribution. Defaults to 1.1.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the batched matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)
  theta.get_shape().assert_has_rank(2)

  # The input data must be rank 2 or greater.
  assert x.get_shape().ndims >= 2

  # Gate each weight with a sample from the hard-concrete distribution.
  gate_sample = common.hard_concrete_sample(log_alpha, beta, gamma, zeta, eps)
  gated_weights = theta * gate_sample

  # Contract the trailing axis of `x` against the leading axis of the
  # weights, broadcasting the matmul over all leading dimensions of `x`.
  last_axis = x.get_shape().ndims - 1
  return tf.tensordot(x, gated_weights, [[last_axis], [0]])


def broadcast_matmul_eval(
    x,
    weight_parameters,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Evaluation computation for l0 matrix multiplication with N input matrices.

  Multiplies a 3D tensor `x` with a set of 2D parameters. Each 2D matrix
  `x[i, :, :]` in the input tensor is multiplied independently with the
  parameters, resulting in a 3D output tensor with shape
  `x.shape[:2] + weight_parameters[0].shape[1]`.

  Args:
    x: 3D Tensor holding the input batch.
    weight_parameters: 2-tuple (theta, log_alpha), where theta holds the
      unscaled weight values and log_alpha the log-alpha parameters of the
      hard concrete distribution.
    gamma: Lower bound of the stretched distribution. Defaults to -0.1.
    zeta: Upper bound of the stretched distribution. Defaults to 1.1.

  Returns:
    Output Tensor of the batched matmul operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)
  theta.get_shape().assert_has_rank(2)

  # The input data must be rank 2 or greater.
  assert x.get_shape().ndims >= 2

  # At evaluation time the stochastic gate is replaced by the deterministic
  # mean of the learned hard-concrete distribution.
  mean_gate = common.hard_concrete_mean(log_alpha, gamma, zeta)
  gated_weights = theta * mean_gate

  # Contract the trailing axis of `x` against the leading axis of the
  # weights, broadcasting the matmul over all leading dimensions of `x`.
  last_axis = x.get_shape().ndims - 1
  return tf.tensordot(x, gated_weights, [[last_axis], [0]])


def conv2d_train(
    x,
    weight_parameters,
    strides,
    padding,
    data_format="NHWC",
    beta=common.BETA,
    gamma=common.GAMMA,
    zeta=common.ZETA,
    eps=common.EPSILON):
  """Training-mode computation for an l0-regularized conv2d.

  Args:
    x: NHWC tf.Tensor holding the input batch of features.
    weight_parameters: 2-tuple (theta, log_alpha), where theta holds the
      unscaled filter values and log_alpha the log-alpha parameters of the
      hard concrete distribution.
    strides: The stride of the sliding window for each dimension of 'x'.
      Identical to standard strides argument for tf.conv2d.
    padding: String. One of "SAME", or "VALID". Identical to standard
      padding argument for tf.conv2d.
    data_format: 'NHWC' or 'NCHW' ordering of 4-D input Tensor.
    beta: "Temperature" of the hard-concrete distribution. Defaults to 2/3,
      the value suggested by Louizos et al.
    gamma: Lower bound of the stretched distribution. Defaults to -0.1.
    zeta: Upper bound of the stretched distribution. Defaults to 1.1.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the conv2d operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Gate each filter weight with a sample from the hard-concrete distribution.
  gate_sample = common.hard_concrete_sample(log_alpha, beta, gamma, zeta, eps)
  gated_weights = theta * gate_sample
  return tf.nn.conv2d(
      x, gated_weights, strides, padding, data_format=data_format)


def conv2d_eval(
    x,
    weight_parameters,
    strides,
    padding,
    data_format="NHWC",
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Evaluation-mode computation for an l0-regularized conv2d.

  Args:
    x: NHWC tf.Tensor holding the input batch of features.
    weight_parameters: 2-tuple (theta, log_alpha), where theta holds the
      unscaled filter values and log_alpha the log-alpha parameters of the
      hard concrete distribution.
    strides: The stride of the sliding window for each dimension of 'x'.
      Identical to standard strides argument for tf.conv2d.
    padding: String. One of "SAME", or "VALID". Identical to standard
      padding argument for tf.conv2d.
    data_format: 'NHWC' or 'NCHW' ordering of 4-D input Tensor.
    gamma: Lower bound of the stretched distribution. Defaults to -0.1.
    zeta: Upper bound of the stretched distribution. Defaults to 1.1.

  Returns:
    Output Tensor of the conv2d operation.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # At evaluation time the stochastic gate is replaced by the deterministic
  # mean of the learned hard-concrete distribution.
  mean_gate = common.hard_concrete_mean(log_alpha, gamma, zeta)
  gated_weights = theta * mean_gate
  return tf.nn.conv2d(
      x, gated_weights, strides, padding, data_format=data_format)


def embedding_lookup_train(
    weight_parameters,
    ids,
    name=None,
    beta=common.BETA,
    gamma=common.GAMMA,
    zeta=common.ZETA,
    eps=common.EPSILON):
  """Training-mode computation for an l0-regularized embedding lookup.

  Args:
    weight_parameters: 2-tuple (theta, log_alpha), where theta holds the
      unscaled embedding values and log_alpha the log-alpha parameters of
      the hard concrete distribution.
    ids: A Tensor with type int32 or int64 containing the ids to be looked up
      in params.
    name: String. Name of the operator.
    beta: "Temperature" of the hard-concrete distribution. Defaults to 2/3,
      the value suggested by Louizos et al.
    gamma: Lower bound of the stretched distribution. Defaults to -0.1.
    zeta: Upper bound of the stretched distribution. Defaults to 1.1.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the embedding lookup.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Gather the selected rows up front so that sampling and weight scaling
  # happen in the lower-dimensional output batch rather than over the whole
  # embedding table.
  gathered_theta = layer_utils.gather(theta, ids)
  gathered_log_alpha = layer_utils.gather(log_alpha, ids)

  # Gate the gathered embeddings with hard-concrete samples.
  gate_sample = common.hard_concrete_sample(
      gathered_log_alpha, beta, gamma, zeta, eps)
  return tf.identity(gathered_theta * gate_sample, name=name)


def embedding_lookup_eval(
    weight_parameters,
    ids,
    name=None,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Evaluation-mode computation for an l0-regularized embedding lookup.

  Args:
    weight_parameters: 2-tuple (theta, log_alpha), where theta holds the
      unscaled embedding values and log_alpha the log-alpha parameters of
      the hard concrete distribution.
    ids: A Tensor with type int32 or int64 containing the ids to be looked up
      in params.
    name: String. Name of the operator.
    gamma: Lower bound of the stretched distribution. Defaults to -0.1.
    zeta: Upper bound of the stretched distribution. Defaults to 1.1.

  Returns:
    Output Tensor of the embedding lookup.

  Raises:
    RuntimeError: If the weight_parameters argument is not a 2-tuple.
  """
  theta, log_alpha = _verify_weight_parameters(weight_parameters)

  # Gather the selected rows up front so that the mean computation and the
  # weight scaling happen in the lower-dimensional output batch rather than
  # over the whole embedding table.
  gathered_theta = layer_utils.gather(theta, ids)
  gathered_log_alpha = layer_utils.gather(log_alpha, ids)

  # At evaluation time the stochastic gate is replaced by the deterministic
  # mean of the learned hard-concrete distribution.
  mean_gate = common.hard_concrete_mean(gathered_log_alpha, gamma, zeta)
  return tf.identity(gathered_theta * mean_gate, name=name)


def l0_norm(
    log_alpha,
    beta=common.BETA,
    gamma=common.GAMMA,
    zeta=common.ZETA):
  """Calculate the l0-regularization contribution to the loss.

  Args:
    log_alpha: Tensor of the log alpha parameters for the hard concrete
      distribution.
    beta: "Temperature" of the hard-concrete distribution. Defaults to 2/3,
      the value suggested by Louizos et al.
    gamma: Lower bound of the stretched distribution. Defaults to -0.1.
    zeta: Upper bound of the stretched distribution. Defaults to 1.1.

  Returns:
    Scalar tensor containing the unweighted l0-regularization term
    contribution to the loss.
  """
  # Per-weight probability that the gate is non-zero: one minus the CDF of
  # the hard-concrete distribution evaluated at zero, which reduces to
  # sigmoid(log_alpha - beta * log(-gamma / zeta)).
  gate_active_prob = tf.sigmoid(log_alpha - beta * tf.log(-gamma / zeta))
  return tf.reduce_sum(gate_active_prob)
