google-research

Форк
0
137 строк · 5.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Implements different squashing strategies before applying the Sinkhorn algorithm.
17

18
The main purpose of those functions is to map the numbers we wish to sort into
19
the [0, 1] segment using an increasing function, such as a logistic map.
20

21
This logistic map, when applied on the output of a neural network,
22
redistributes the activations into [0,1] in a smooth adaptive way, helping
23
the numerical stability of the Sinkhorn algorithm while maintaining a
24
well behaved back propagation.
25

26
In case of a logistic sigmoid, this map is exactly the CDF of a logistic
27
distribution. See https://en.wikipedia.org/wiki/Logistic_distribution for
28
details, in particular the variance of the distribution. In case of an atan
29
sigmoid, it is somewhat related to the CDF of a Cauchy
30
distribution and presents the advantage to have better behaved gradients.
31

32
In such a logistic map, the points lying in the linear part of the map will
33
be well spread out on the [0, 1] segment, which will make them easier to sort.
34
Therefore, depending on which part of the distribution is of interest, we might
35
want to focus on one part or another, hence leading to different translations
36
before applying a squashing function.
37
"""
38

39
import math
40
import gin
41
import tensorflow.compat.v2 as tf
42

43

44
@gin.configurable
def reduce_softmax(x, tau, axis = -1):
  """Computes a soft maximum of a tensor along an axis.

  The result is a softmax-weighted average of the entries of `x`, i.e. an
  interpolation between min, mean and max controlled by `tau`.

  Args:
   x: (tf.Tensor<float>) the input tensor of any shape.
   tau: (float) the inverse softmax temperature. Large positive values give
     a result close to the maximum, 0 coincides with the mean, and large
     negative values give a result close to the minimum.
   axis: (int) the axis along which the reduction is performed.

  Returns:
   a tf.Tensor<float> with the same shape as the input, except that the
    reduction axis is gone.
  """
  weights = tf.nn.softmax(tau * x, axis=axis)
  return tf.math.reduce_sum(weights * x, axis=axis)
61

62

63
@gin.configurable
def whiten(x, axis = -1, min_std=1e-10):
  """Makes the input tensor zero mean and unit variance along the axis.

  Args:
   x: (tf.Tensor<float>) of any shape to be whitened.
   axis: (int) the axis along which to compute the statistics.
   min_std: (float) a minimum value of the standard deviation along the axis
    to prevent degenerated cases.

  Returns:
   A tf.Tensor<float> of the same shape as the input tensor.
  """
  mu = tf.expand_dims(tf.math.reduce_mean(x, axis=axis), axis=axis)
  # Clamp the standard deviation from below so that constant (or nearly
  # constant) slices do not lead to a division by ~0. Note: a stray local
  # assignment `min_std = 1e-6` used to clobber the parameter here, making
  # it (and its gin-configurability) dead; it has been removed so the
  # caller-supplied value is honored.
  sigma = tf.expand_dims(
      tf.maximum(tf.math.reduce_std(x, axis=axis), min_std), axis=axis)
  return (x - mu) / sigma
81

82

83
@gin.configurable
def soft_stretch(
    x, axis = -1, extreme_tau = 1e12):
  """Softly rescales the values of `x` along the axis to the [0, 1] segment.

  Args:
   x: (tf.Tensor<float> of any shape) the input tensor to rescale the values
    of.
   axis: (int) the axis along which we want to rescale.
   extreme_tau: (float) the inverse temperature used for the soft minimum
    and soft maximum. This must be big for the output values to really lie
    in the [0, 1] segment.

  Returns:
   A tf.Tensor<float> of the same shape as the input.
  """
  # Soft min / max along the axis, kept as broadcastable tensors.
  soft_min = reduce_softmax(x, tau=-extreme_tau, axis=axis)
  soft_max = reduce_softmax(x, tau=extreme_tau, axis=axis)
  low = tf.expand_dims(soft_min, axis=axis)
  high = tf.expand_dims(soft_max, axis=axis)
  return (x - low) / (high - low)
103

104

105
@gin.configurable
def group_rescale(
    x,
    is_logistic = True,
    tau = 0.0,
    stretch = False):
  """Applies a sigmoid map on standardized inputs.

  By default the input is centered; a non-zero `tau` uncenters it by
  shifting each row towards its soft maximum (tau > 0) or soft minimum
  (tau < 0) before squashing.

  Args:
   x: Tensor<float>[batch, n]
   is_logistic: (bool) uses either a logistic sigmoid or an arctan.
   tau: (float or None) inverse temperature parameter. When None or 0.0 the
    input stays centered; otherwise it controls how much deviation we want
    from the mean: the bigger the closer to the maximum, the more negative
    the closer to the minimum.
   stretch: (bool) if True, stretches the values to the full [0, 1] segment.

  Returns:
   A Tensor<float>[batch, n] after application of the sigmoid map.
  """
  x = whiten(x, axis=1)
  if is_logistic:
    # Rescale so the whitened (unit-variance) values match a logistic
    # distribution of variance 1 (scale s = sqrt(3)/pi), whose CDF is the
    # sigmoid applied below.
    x /= math.sqrt(3.0) / math.pi
  # The docstring allows tau=None, but the previous `tau != 0` test let
  # None through and crashed inside reduce_softmax; guard it explicitly.
  if tau is not None and tau != 0:
    center = reduce_softmax(x, tau=tau, axis=1)
    x = x - center[:, tf.newaxis]
  squashing_fn = tf.math.sigmoid if is_logistic else tf.math.atan
  squashed_x = squashing_fn(x)
  if stretch:
    return soft_stretch(squashed_x)
  return squashed_x
138

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.