google-research
97 строк · 3.3 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Defines common utilities for l0-regularization layers."""
17from __future__ import absolute_import18from __future__ import division19from __future__ import print_function20
21import tensorflow.compat.v1 as tf22
23
# Tiny offset added inside logs / square roots so that an argument of
# exactly zero does not turn into NaN or -inf.
EPSILON = 1e-8

# Default hard concrete distribution parameters (https://arxiv.org/abs/1712.01312):
#   BETA  - temperature of the underlying concrete distribution (2/3).
#   GAMMA - lower end of the interval samples are stretched to (-0.1).
#   ZETA  - upper end of the interval samples are stretched to (1.1).
BETA = 2.0 / 3.0
GAMMA = -0.1
ZETA = 1.1
32
def hard_concrete_sample(
    log_alpha,
    beta=BETA,
    gamma=GAMMA,
    zeta=ZETA,
    eps=EPSILON):
  """Sample values from the hard concrete distribution.

  The hard concrete distribution is described in
  https://arxiv.org/abs/1712.01312.

  Args:
    log_alpha: The log alpha parameters that control the "location" of the
      distribution. Anything accepted by `tf.convert_to_tensor`. Python
      scalars and lists become float32; floating-point tensors keep their
      dtype, and the noise is drawn in that same dtype.
    beta: The beta parameter, which controls the "temperature" of
      the distribution. Defaults to 2/3 from the above paper.
    gamma: The gamma parameter, which controls the lower bound of the
      stretched distribution. Defaults to -0.1 from the above paper.
    zeta: The zeta parameter, which controls the upper bound of the
      stretched distribution. Defaults to 1.1 from the above paper.
    eps: A small constant added to the noise before taking its log, to avoid
      NaNs when a noise value is exactly zero. Note that the default 1e-8
      rounds to zero in float16.

  Returns:
    A tf.Tensor with the same shape and dtype as `log_alpha` (after
    conversion), with values clipped to [0, 1], representing the output of
    the sampling operation.
  """
  # tf.random_uniform defaults to float32, so noise drawn in that dtype
  # cannot be combined with float16/bfloat16/float64 log_alpha tensors.
  # Convert first, preferring float32 so Python scalars/lists behave as
  # before, then draw the noise in log_alpha's own dtype.
  log_alpha = tf.convert_to_tensor(log_alpha, preferred_dtype=tf.float32)
  random_noise = tf.random_uniform(
      tf.shape(log_alpha),
      minval=0.0,
      maxval=1.0,
      dtype=log_alpha.dtype)

  # NOTE: We add a small constant value to the noise before taking the
  # log to avoid NaNs if a noise value is exactly zero. We sample values
  # in the range [0, 1), so the right log is not at risk of NaNs.
  gate_inputs = tf.log(random_noise + eps) - tf.log(1.0 - random_noise)
  gate_inputs = tf.sigmoid((gate_inputs + log_alpha) / beta)

  # Stretch the (0, 1) concrete sample to (gamma, zeta) ...
  stretched_values = gate_inputs * (zeta - gamma) + gamma

  # ... then clip back into [0, 1]. This puts point masses at exactly 0 and 1.
  return tf.clip_by_value(
      stretched_values,
      clip_value_max=1.0,
      clip_value_min=0.0)
75
def hard_concrete_mean(log_alpha, gamma=GAMMA, zeta=ZETA):
  """Compute the deterministic gate value of the hard concrete distribution.

  See https://arxiv.org/abs/1712.01312 for the distribution. This returns
  clip(sigmoid(log_alpha) * (zeta - gamma) + gamma, 0, 1): the noise-free
  stretched-and-clipped gate. It is not the exact expectation of the clipped
  random variable.

  Args:
    log_alpha: The log alpha parameters that control the "location" of the
      distribution.
    gamma: Lower end of the stretch interval. Defaults to -0.1 from the
      above paper.
    zeta: Upper end of the stretch interval. Defaults to 1.1 from the above
      paper.

  Returns:
    A tf.Tensor of the computed gate values, clipped to [0, 1].
  """
  interval_width = zeta - gamma
  stretched = gamma + interval_width * tf.sigmoid(log_alpha)
  return tf.clip_by_value(
      stretched,
      clip_value_min=0.0,
      clip_value_max=1.0)