google-research

Форк
0
97 строк · 3.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Defines common utilities for l0-regularization layers."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import tensorflow.compat.v1 as tf
22

23

24
# Small constant value to add when taking logs or sqrts to avoid NaNs
25
EPSILON = 1e-8
26

27
# The default hard-concrete distribution parameters
28
BETA = 2.0 / 3.0
29
GAMMA = -0.1
30
ZETA = 1.1
31

32

33
def hard_concrete_sample(
34
    log_alpha,
35
    beta=BETA,
36
    gamma=GAMMA,
37
    zeta=ZETA,
38
    eps=EPSILON):
39
  """Sample values from the hard concrete distribution.
40

41
  The hard concrete distribution is described in
42
  https://arxiv.org/abs/1712.01312.
43

44
  Args:
45
    log_alpha: The log alpha parameters that control the "location" of the
46
      distribution.
47
    beta: The beta parameter, which controls the "temperature" of
48
      the distribution. Defaults to 2/3 from the above paper.
49
    gamma: The gamma parameter, which controls the lower bound of the
50
      stretched distribution. Defaults to -0.1 from the above paper.
51
    zeta: The zeta parameters, which controls the upper bound of the
52
      stretched distribution. Defaults to 1.1 from the above paper.
53
    eps: A small constant value to add to logs and sqrts to avoid NaNs.
54

55
  Returns:
56
    A tf.Tensor representing the output of the sampling operation.
57
  """
58
  random_noise = tf.random_uniform(
59
      tf.shape(log_alpha),
60
      minval=0.0,
61
      maxval=1.0)
62

63
  # NOTE: We add a small constant value to the noise before taking the
64
  # log to avoid NaNs if a noise value is exactly zero. We sample values
65
  # in the range [0, 1), so the right log is not at risk of NaNs.
66
  gate_inputs = tf.log(random_noise + eps) - tf.log(1.0 - random_noise)
67
  gate_inputs = tf.sigmoid((gate_inputs + log_alpha) / beta)
68
  stretched_values = gate_inputs * (zeta - gamma) + gamma
69

70
  return tf.clip_by_value(
71
      stretched_values,
72
      clip_value_max=1.0,
73
      clip_value_min=0.0)
74

75

76
def hard_concrete_mean(log_alpha, gamma=GAMMA, zeta=ZETA):
77
  """Calculate the mean of the hard concrete distribution.
78

79
  The hard concrete distribution is described in
80
  https://arxiv.org/abs/1712.01312.
81

82
  Args:
83
    log_alpha: The log alpha parameters that control the "location" of the
84
      distribution.
85
    gamma: The gamma parameter, which controls the lower bound of the
86
      stretched distribution. Defaults to -0.1 from the above paper.
87
    zeta: The zeta parameters, which controls the upper bound of the
88
      stretched distribution. Defaults to 1.1 from the above paper.
89

90
  Returns:
91
    A tf.Tensor representing the calculated means.
92
  """
93
  stretched_values = tf.sigmoid(log_alpha) * (zeta - gamma) + gamma
94
  return tf.clip_by_value(
95
      stretched_values,
96
      clip_value_max=1.0,
97
      clip_value_min=0.0)
98

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.