# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Defines standard networks layers that train using variational dropout."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import tensorflow.compat.v1 as tf
22
23from state_of_sparsity.layers.utils import layer_utils
24from state_of_sparsity.layers.variational_dropout import common
25
26
def _verify_variational_params(variational_params):
  """Verifies the format of the input `variational_params`.

  Checks that the input is a 2-tuple of tensors of equal shape.

  Args:
    variational_params: The parameters to check.

  Raises:
    RuntimeError: If the input is not a 2-tuple of tensors with equal shape.

  Returns:
    The input `variational_params`.
  """
  if len(variational_params) != 2:
    raise RuntimeError("Incorrect number of variational parameters.")
  if variational_params[0].shape != variational_params[1].shape:
    raise RuntimeError("Variational parameters must be the same shape.")
  return variational_params


def matmul_train(
    x,
    variational_params,
    transpose_a=False,
    transpose_b=False,
    clip_alpha=None,
    eps=common.EPSILON):
  R"""Training computation for a variational matmul.

  In variational dropout we train a Bayesian neural network where we assume a
  fully-factorized Gaussian posterior and log uniform prior over the weights.

  During training, we need to sample weights from this distribution. Rather
  than sample weights for each sample in the input batch, we can calculate the
  parameters of the distribution over the pre-activations analytically (this
  step is called the local reparameterization trick). This function calculates
  the mean and standard deviation of the distribution over the pre-activations,
  and then draws a single sample for each element in the input batch and passes
  them as output.

  Args:
    x: 2D Tensor representing the input batch.
    variational_params: 2-tuple of Tensors, where the first tensor is the
      \theta values and the second contains the log of the \sigma^2 values.
    transpose_a: If True, `x` is transposed before multiplication.
    transpose_b: If True, `theta` is transposed before multiplication.
    clip_alpha: Int or None. If integer, we clip the log \alpha values to
      [-clip_alpha, clip_alpha]. If None, don't clip the values.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the matmul operation.

  Raises:
    RuntimeError: If the variational_params argument is not a 2-tuple.
  """
  # We expect a 2D input tensor, as is standard in fully-connected layers
  x.get_shape().assert_has_rank(2)

  theta, log_sigma2 = _verify_variational_params(
      variational_params)

  if clip_alpha is not None:
    # Compute the log_alphas and then compute the
    # log_sigma2 again so that we can clip on the
    # log alpha magnitudes
    log_alpha = common.compute_log_alpha(log_sigma2, theta, eps, clip_alpha)
    log_sigma2 = common.compute_log_sigma2(log_alpha, theta, eps)

  # Compute the mean and standard deviation of the distributions over the
  # activations
  mu_activation = tf.matmul(
      x,
      theta,
      transpose_a=transpose_a,
      transpose_b=transpose_b)
  std_activation = tf.sqrt(tf.matmul(
      tf.square(x),
      tf.exp(log_sigma2),
      transpose_a=transpose_a,
      transpose_b=transpose_b) + eps)

  output_shape = tf.shape(std_activation)
  return mu_activation + std_activation * tf.random_normal(output_shape)
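
# Example usage (illustrative sketch, not part of the original API; the
# variable names, shapes, and initializer value are assumptions). The
# variational parameters are ordinary trainable variables holding \theta and
# log \sigma^2, packed into a 2-tuple:
#
#   x = tf.random_normal([32, 784])
#   theta = tf.get_variable("theta", [784, 256])
#   log_sigma2 = tf.get_variable(
#       "log_sigma2", [784, 256],
#       initializer=tf.constant_initializer(-10.0))
#   activations = matmul_train(x, (theta, log_sigma2), clip_alpha=8.0)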


def matmul_eval(
    x,
    variational_params,
    transpose_a=False,
    transpose_b=False,
    threshold=3.0,
    eps=common.EPSILON):
  R"""Evaluation computation for a variational matmul.

  In variational dropout we train a Bayesian neural network where we assume a
  fully-factorized Gaussian posterior and log uniform prior over the weights.

  The parameters of the posterior are learned during training, and at eval
  time we use the learned mean as the weight values.

  This method also supports the pruning of weights based on their log \alpha
  values. All weights with log \alpha >= `threshold` are set to zero.

  Args:
    x: 2D Tensor representing the input batch.
    variational_params: 2-tuple of Tensors, where the first tensor is the
      \theta values and the second contains the log of the \sigma^2 values.
    transpose_a: If True, `x` is transposed before multiplication.
    transpose_b: If True, `theta` is transposed before multiplication.
    threshold: Weights with a log \alpha_{ij} value greater than this will be
      set to zero.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the variational matmul operation.

  Raises:
    RuntimeError: If the variational_params argument is not a 2-tuple.
  """
  # We expect a 2D input tensor, as is standard in fully-connected layers
  x.get_shape().assert_has_rank(2)

  theta, log_sigma2 = _verify_variational_params(
      variational_params)

  # Compute the weight mask by thresholding on
  # the log-space alpha values
  log_alpha = common.compute_log_alpha(log_sigma2, theta, eps, value_limit=None)
  weight_mask = tf.cast(tf.less(log_alpha, threshold), tf.float32)

  return tf.matmul(
      x,
      theta * weight_mask,
      transpose_a=transpose_a,
      transpose_b=transpose_b)
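
# Illustrative sketch (an assumption about usage, not part of the original
# API): the fraction of weights that eval-time thresholding prunes can be
# measured directly from the log \alpha values, using the same threshold that
# `matmul_eval` applies:
#
#   log_alpha = common.compute_log_alpha(
#       log_sigma2, theta, common.EPSILON, value_limit=None)
#   sparsity = tf.reduce_mean(
#       tf.cast(tf.greater_equal(log_alpha, 3.0), tf.float32))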


def broadcast_matmul_train(
    x,
    variational_params,
    clip_alpha=None,
    eps=common.EPSILON):
  R"""Training computation for VD matrix multiplication with N input matrices.

  Multiplies a 3D tensor `x` with a set of 2D parameters. Each 2D matrix
  `x[i, :, :]` in the input tensor is multiplied independently with the
  parameters, resulting in a 3D output tensor with shape
  `x.shape[:2] + [variational_params[0].shape[1]]`.

  Args:
    x: 3D Tensor representing the input batch.
    variational_params: 2-tuple of Tensors, where the first tensor is the
      \theta values and the second contains the log of the \sigma^2 values.
    clip_alpha: Int or None. If integer, we clip the log \alpha values to
      [-clip_alpha, clip_alpha]. If None, don't clip the values.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the batched matmul operation.

  Raises:
    RuntimeError: If the variational_params argument is not a 2-tuple.
  """
  theta, log_sigma2 = _verify_variational_params(
      variational_params)
  theta.get_shape().assert_has_rank(2)
  log_sigma2.get_shape().assert_has_rank(2)

  # The input data must be rank 2 or greater
  assert x.get_shape().ndims >= 2
  input_rank = x.get_shape().ndims

  if clip_alpha is not None:
    # Compute the log_alphas and then compute the
    # log_sigma2 again so that we can clip on the
    # log alpha magnitudes
    log_alpha = common.compute_log_alpha(log_sigma2, theta, eps, clip_alpha)
    log_sigma2 = common.compute_log_sigma2(log_alpha, theta, eps)

  # Compute the mean and standard deviation of the distributions over the
  # activations
  mu_activation = tf.tensordot(x, theta, [[input_rank-1], [0]])

  var_activation = tf.tensordot(
      tf.square(x),
      tf.exp(log_sigma2),
      [[input_rank-1], [0]])
  std_activation = tf.sqrt(var_activation + eps)

  # Restore the static shape of the output, which has the rank of the input
  input_shape = x.get_shape().as_list()
  weight_shape = theta.get_shape().as_list()
  output_shape = input_shape[:-1] + [weight_shape[1]]
  mu_activation.set_shape(output_shape)
  std_activation.set_shape(output_shape)

  # NOTE: We sample noise for each weight in theta, which will be shared by
  # each matrix product that was done. This is equivalent to sampling the same
  # set of weights for all matrix products done by this op in an iteration.
  # The element-wise multiply below broadcasts.
  num_pad_dims = len(output_shape) - 2
  padding = [tf.constant(1, dtype=tf.int32) for _ in range(num_pad_dims)]

  # NOTE: On GPU, the first dim may not be defined w/ the Transformer. Create
  # a tf.Tensor from the list shape and TF should match the first dim
  # appropriately
  batch_size = tf.shape(x)[0]
  data_dim = tf.shape(theta)[-1]
  noise_shape = tf.stack([batch_size] + padding + [data_dim], axis=0)

  output = mu_activation + std_activation * tf.random_normal(noise_shape)
  return output
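
# Example usage (illustrative sketch; the shapes below are assumptions): with a
# Transformer-style [batch, sequence, features] input, the same 2D variational
# parameters are applied independently at every position:
#
#   x = tf.random_normal([8, 128, 512])
#   theta = tf.get_variable("theta", [512, 512])
#   log_sigma2 = tf.get_variable(
#       "log_sigma2", [512, 512],
#       initializer=tf.constant_initializer(-10.0))
#   y = broadcast_matmul_train(x, (theta, log_sigma2))  # shape [8, 128, 512]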


def broadcast_matmul_eval(
    x,
    variational_params,
    threshold=3.0,
    eps=common.EPSILON):
  R"""Evaluation computation for VD matrix multiplication with N input matrices.

  Multiplies a 3D tensor `x` with a set of 2D parameters. Each 2D matrix
  `x[i, :, :]` in the input tensor is multiplied independently with the
  parameters, resulting in a 3D output tensor with shape
  `x.shape[:2] + [variational_params[0].shape[1]]`.

  Args:
    x: 3D Tensor representing the input batch.
    variational_params: 2-tuple of Tensors, where the first tensor is the
      \theta values and the second contains the log of the \sigma^2 values.
    threshold: Weights with a log \alpha_{ij} value greater than this will be
      set to zero.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the batched matmul operation.

  Raises:
    RuntimeError: If the variational_params argument is not a 2-tuple.
  """
  theta, log_sigma2 = _verify_variational_params(
      variational_params)
  theta.get_shape().assert_has_rank(2)
  log_sigma2.get_shape().assert_has_rank(2)

  # The input data must be rank 2 or greater
  assert x.get_shape().ndims >= 2
  input_rank = x.get_shape().ndims

  # Compute the weight mask by thresholding on the log-space alpha values
  log_alpha = common.compute_log_alpha(log_sigma2, theta, eps, value_limit=None)
  weight_mask = tf.cast(tf.less(log_alpha, threshold), tf.float32)

  output = tf.tensordot(x, theta * weight_mask, [[input_rank-1], [0]])

  # Restore the static shape of the output, which has the rank of the input
  input_shape = x.get_shape().as_list()
  weight_shape = theta.get_shape().as_list()
  output_shape = input_shape[:-1] + [weight_shape[1]]
  output.set_shape(output_shape)
  return output


def conv2d_train(x,
                 variational_params,
                 strides,
                 padding,
                 data_format="NHWC",
                 clip_alpha=None,
                 eps=common.EPSILON):
  R"""Training computation for a variational conv2d.

  In variational dropout we train a Bayesian neural network where we assume a
  fully-factorized Gaussian posterior and log uniform prior over the weights.

  During training, we need to sample weights from this distribution. Rather
  than sample weights for each sample in the input batch, we can calculate the
  parameters of the distribution over the pre-activations analytically (this
  step is called the local reparameterization trick). This function calculates
  the mean and standard deviation of the distribution over the pre-activations,
  and then draws a single sample for each element in the input batch and passes
  them as output.

  Args:
    x: NHWC tf.Tensor representing the input batch of features.
    variational_params: 2-tuple of Tensors, where the first tensor is the
      \theta values and the second contains the log of the \sigma^2 values.
    strides: The stride of the sliding window for each dimension of `x`.
      Identical to the standard strides argument for tf.nn.conv2d.
    padding: String. One of "SAME", or "VALID". Identical to the standard
      padding argument for tf.nn.conv2d.
    data_format: 'NHWC' or 'NCHW' ordering of 4-D input Tensor.
    clip_alpha: Int or None. If integer, we clip the log \alpha values to
      [-clip_alpha, clip_alpha]. If None, don't clip the values.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the conv2d operation.

  Raises:
    RuntimeError: If the variational_params argument is not a 2-tuple.
  """
  theta, log_sigma2 = _verify_variational_params(variational_params)

  if clip_alpha:
    # Compute the log_alphas and then compute the
    # log_sigma2 again so that we can clip on the
    # log alpha magnitudes
    log_alpha = common.compute_log_alpha(log_sigma2, theta, eps, clip_alpha)
    log_sigma2 = common.compute_log_sigma2(log_alpha, theta, eps)

  # Compute the mean and standard deviation of the distribution over the
  # convolution outputs
  mu_activation = tf.nn.conv2d(
      x, theta, strides, padding, data_format=data_format)
  std_activation = tf.sqrt(
      tf.nn.conv2d(
          tf.square(x),
          tf.exp(log_sigma2),
          strides,
          padding,
          data_format=data_format) + eps)

  output_shape = tf.shape(std_activation)
  return mu_activation + std_activation * tf.random_normal(output_shape)
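
# Example usage (illustrative sketch; the filter shape, strides, and variable
# names are assumptions): the variational parameters share the
# [height, width, in_channels, out_channels] shape of an ordinary conv2d
# filter:
#
#   images = tf.random_normal([16, 32, 32, 3])
#   theta = tf.get_variable("conv_theta", [3, 3, 3, 64])
#   log_sigma2 = tf.get_variable(
#       "conv_log_sigma2", [3, 3, 3, 64],
#       initializer=tf.constant_initializer(-10.0))
#   features = conv2d_train(
#       images, (theta, log_sigma2), strides=[1, 1, 1, 1], padding="SAME")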


def conv2d_eval(x,
                variational_params,
                strides,
                padding,
                data_format="NHWC",
                threshold=3.0,
                eps=common.EPSILON):
  R"""Evaluation computation for a variational conv2d.

  In variational dropout we train a Bayesian neural network where we assume a
  fully-factorized Gaussian posterior and log uniform prior over the weights.

  The parameters of the posterior are learned during training, and at eval
  time we use the learned mean as the weight values.

  This method also supports the pruning of weights based on their log \alpha
  values. All weights with log \alpha >= `threshold` are set to zero.

  Args:
    x: Tensor representing the input batch.
    variational_params: 2-tuple of Tensors, where the first tensor is the
      \theta values and the second contains the log of the \sigma^2 values.
    strides: The stride of the sliding window for each dimension of `x`.
      Identical to the standard strides argument for tf.nn.conv2d.
    padding: String. One of "SAME", or "VALID". Identical to the standard
      padding argument for tf.nn.conv2d.
    data_format: 'NHWC' or 'NCHW' ordering of 4-D input Tensor.
    threshold: Weights with a log \alpha_{ij} value greater than this will
      be set to zero.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    Output Tensor of the conv2d operation.

  Raises:
    RuntimeError: If the variational_params argument is not a 2-tuple.
  """
  theta, log_sigma2 = _verify_variational_params(
      variational_params)

  # Compute the weight mask by thresholding on
  # the log-space alpha values
  log_alpha = common.compute_log_alpha(log_sigma2, theta, eps, value_limit=None)
  weight_mask = tf.cast(tf.less(log_alpha, threshold), tf.float32)

  return tf.nn.conv2d(
      x, theta * weight_mask, strides, padding, data_format=data_format)


# NOTE: This implementation of variational dropout on an embedding samples
# new noise for each embedding vector at all timesteps in the batch
# and across sequences in the batch. An alternative implementation would
# be to sample a noise vector for each token in the vocabulary, so that
# all instances of an embedding vector for a given token would be the
# same within a batch. Another alternative implementation would be to
# sample a noise vector for each token in the vocabulary for each element
# in the batch so that, within a sequence, all instances of an embedding
# vector for a given token would be the same, but across different elements
# in the batch they could be different.
#
# The first alternative implementation would add another embedding lookup
# to the implementation. We'd generate a noise tensor with shape
# [vocab_size, embedding_size], and for each token id in the batch we'd
# do an embedding lookup to get the appropriate noise vector. We'd then
# do two more embedding lookups, one to get the mean vector and one to
# get the log variance vector for the token. These 3 tensors with shape
# [batch_size, seq_length, embedding_size] would then be combined the
# same way they are in this implementation.
#
# The second alternative implementation may not be practical, because we
# would have to sample `vocab_size * embedding_size * batch_size` random
# values per iteration. We'd also have unique noise embeddings for each
# element in the batch, meaning we'd have to do `batch_size` + 2 embedding
# lookups.
#
# This implementation is the most efficient in terms of embedding lookups
# and noise sampling.
def embedding_lookup_train(
    variational_params,
    ids,
    name=None,
    clip_alpha=None,
    eps=common.EPSILON):
  R"""Embedding trained with variational dropout.

  In a standard embedding lookup, `ids` are looked-up in a list of embedding
  tensors. In an embedding trained with variational dropout, we lookup the
  parameters of the fully-factorized Gaussian posterior over the embedding
  tensor for each index in `ids` and draw a sample from this distribution
  that is returned.

  The `ids` argument is analogous to the one in the standard
  tf.nn.embedding_lookup.

  Args:
    variational_params: 2-tuple of Tensors, where the first tensor is the
      \theta values and the second contains the log of the \sigma^2 values.
    ids: A Tensor with type int32 or int64 containing the ids to be looked up
      in params.
    name: String. Name of the operator.
    clip_alpha: Int or None. If integer, we clip the log \alpha values
      to [-clip_alpha, clip_alpha]. If None, don't clip the values.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    The output Tensor result of the embedding lookup.

  Raises:
    RuntimeError: If the input variational_params is not a 2-tuple of Tensors
      that have the same shape.
  """
  theta, log_sigma2 = _verify_variational_params(
      variational_params)

  # Before we do anything, lookup the mean and log variances of the embedding
  # vectors we are going to output and do all our operations in this lower
  # dimensional space
  embedding_theta = layer_utils.gather(theta, ids)
  embedding_log_sigma2 = layer_utils.gather(log_sigma2, ids)

  if clip_alpha:
    # Compute the log_alphas and then compute the
    # log_sigma2 again so that we can clip on the
    # log alpha magnitudes
    embedding_log_alpha = common.compute_log_alpha(
        embedding_log_sigma2, embedding_theta, eps, clip_alpha)
    embedding_log_sigma2 = common.compute_log_sigma2(
        embedding_log_alpha, embedding_theta, eps)

  # Calculate the standard deviation from the log variance
  embedding_std = tf.sqrt(tf.exp(embedding_log_sigma2) + eps)

  # Output samples from the distribution over the embedding vectors
  output_shape = tf.shape(embedding_std)
  embedding = embedding_theta + embedding_std * tf.random_normal(output_shape)
  return tf.identity(embedding, name=name)
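
# Example usage (illustrative sketch; the vocabulary size, embedding width, and
# variable names are assumptions): the variational parameters have the same
# [vocab_size, embedding_size] shape as a standard embedding table:
#
#   ids = tf.constant([[3, 7, 7, 1]], dtype=tf.int32)
#   theta = tf.get_variable("embedding_theta", [10000, 128])
#   log_sigma2 = tf.get_variable(
#       "embedding_log_sigma2", [10000, 128],
#       initializer=tf.constant_initializer(-10.0))
#   embeddings = embedding_lookup_train((theta, log_sigma2), ids)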


def embedding_lookup_eval(
    variational_params,
    ids,
    name=None,
    threshold=3.0,
    eps=common.EPSILON):
  R"""Evaluation mode embedding trained with variational dropout.

  In a standard embedding lookup, `ids` are looked-up in a list of embedding
  tensors. In an embedding trained with variational dropout, we lookup the
  parameters of the fully-factorized Gaussian posterior over the embedding
  tensor for each index in `ids` and draw a sample from this distribution
  that is returned. At evaluation time, we use the mean of the posterior
  over each embedding tensor instead of sampling.

  The `ids` argument is analogous to the one in the standard
  tf.nn.embedding_lookup.

  Args:
    variational_params: 2-tuple of Tensors, where the first tensor is the
      \theta values and the second contains the log of the \sigma^2 values.
    ids: A Tensor with type int32 or int64 containing the ids to be looked up
      in params.
    name: String. Name of the operator.
    threshold: Weights with a log \alpha_{ij} value greater than this will be
      set to zero.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.

  Returns:
    The output Tensor result of the embedding lookup.

  Raises:
    RuntimeError: If the input variational_params is not a 2-tuple of Tensors
      that have the same shape.
  """
  theta, log_sigma2 = _verify_variational_params(
      variational_params)

  # Rather than mask the whole embedding every iteration, we can do a second
  # embedding lookup on the log \sigma^2 values, compute the log \alpha values
  # for each output embedding vector, and then mask the much lower dimensional
  # output embedding vectors
  embedding_theta = layer_utils.gather(theta, ids)
  embedding_log_sigma2 = layer_utils.gather(log_sigma2, ids)

  # Compute the weight mask by thresholding on the log-space alpha values
  embedding_log_alpha = common.compute_log_alpha(
      embedding_log_sigma2, embedding_theta, eps, value_limit=None)
  embedding_mask = tf.cast(tf.less(embedding_log_alpha, threshold), tf.float32)

  # Return the masked embedding vectors
  return tf.identity(embedding_theta * embedding_mask, name=name)


def negative_dkl(variational_params=None,
                 clip_alpha=None,
                 eps=common.EPSILON,
                 log_alpha=None):
  R"""Compute the negative kl-divergence loss term.

  Computes the negative kl-divergence between the log-uniform prior over the
  weights and the variational posterior over the weights for each element
  in the set of variational parameters. Each contribution is summed, and the
  negated sum is returned as a scalar Tensor suitable for use as a loss term.

  The true kl-divergence is intractable, so we compute the tight approximation
  from https://arxiv.org/abs/1701.05369.

  Args:
    variational_params: 2-tuple of Tensors, where the first tensor is the
      \theta values and the second contains the log of the \sigma^2 values.
    clip_alpha: Int or None. If integer, we clip the log \alpha values to
      [-clip_alpha, clip_alpha]. If None, don't clip the values.
    eps: Small constant value to use in log and sqrt operations to avoid NaNs.
    log_alpha: float32 tensor of log alpha values. If provided, these are used
      directly instead of being computed from `variational_params`.

  Returns:
    Output scalar Tensor containing the negated sum of all negative
    kl-divergence contributions for each element in the input
    variational_params.

  Raises:
    RuntimeError: If the variational_params argument is not a 2-tuple.
  """

  if variational_params is not None:
    theta, log_sigma2 = _verify_variational_params(variational_params)

  if log_alpha is None:
    log_alpha = common.compute_log_alpha(log_sigma2, theta, eps, clip_alpha)

  # Constant values for approximating the kl divergence
  k1, k2, k3 = 0.63576, 1.8732, 1.48695
  c = -k1

  # Compute each term of the KL and combine
  term_1 = k1 * tf.nn.sigmoid(k2 + k3*log_alpha)
  term_2 = -0.5 * tf.log1p(tf.exp(tf.negative(log_alpha)))
  eltwise_dkl = term_1 + term_2 + c
  return -tf.reduce_sum(eltwise_dkl)
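
# Example of adding the KL term to a training objective (illustrative sketch;
# `task_loss`, `num_train_examples`, and the 1/N scaling convention are
# assumptions made by the surrounding training code, not prescribed here):
#
#   kl_loss = negative_dkl((theta, log_sigma2)) / float(num_train_examples)
#   total_loss = task_loss + kl_loss
#   train_op = tf.train.AdamOptimizer(1e-3).minimize(total_loss)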
593