# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Implements different squashing strategies before applying the Sinkhorn algorithm.
17
18The main purpose of those functions is to map the numbers we wish to sort into
19the [0, 1] segment using an increasing function, such as a logistic map.
20
21This logistic map, when applied on the output of a neural network,
22redistributes the activations into [0,1] in a smooth adaptative way, helping
23the numerical stability of the Sinkhorn algorithm while maintaining a
24well behaved back propagation.
25
26In case of a logistic sigmoid, this map is exactly the CDF of a logistic
27distribution. See https://en.wikipedia.org/wiki/Logistic_distribution for
28details, in particular the variance of the distribution. In case of an atan
29sigmoid, it is somewhat related to the CDF of a Cauchy
30distribution and presents the advantage to have better behaved gradients.
31
32In a such a logistic map, the points lying in the linear part of the map will
33be well spread out on the [0, 1] segment, which will make them easier to sort.
34Therefore, depending on which part of the distribution is of interest, we might
35want to focus on one part of another, hence leading to different translations
36before applying a squashing function.
37"""

import math
import gin
import tensorflow.compat.v2 as tf

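# Note on the scaling used in `group_rescale` below: a logistic distribution
# with scale s = sqrt(3) / pi has variance s**2 * pi**2 / 3 = 1, so after
# whitening, dividing by sqrt(3) / pi and applying tf.math.sigmoid evaluates
# the CDF of a unit-variance logistic distribution, which is the map described
# in the module docstring above. A rough equivalent sketch (illustrative only,
# not part of the original module):
#
#   z = whiten(x, axis=1)                                 # zero mean, unit variance
#   cdf = tf.math.sigmoid(z * math.pi / math.sqrt(3.0))   # values in (0, 1)
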
@gin.configurable
def reduce_softmax(x, tau, axis=-1):
  """Computes a softmax-weighted average of a tensor along a given axis.

  Args:
    x: (tf.Tensor<float>) the input tensor of any shape.
    tau: (float) the value of the inverse softmax temperature. When tau is
      very big the obtained value is close to the maximum, when 0 it
      coincides with the mean, and when very negative it converges to the
      minimum.
    axis: (int) the axis along which we want to compute the softmax.

  Returns:
    A tf.Tensor<float> that has the same shape as the input tensor, except
    for the reduction axis, which is gone.
  """
  return tf.math.reduce_sum(tf.nn.softmax(x * tau, axis=axis) * x, axis=axis)

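# A hypothetical usage sketch for `reduce_softmax` (illustrative values, not
# part of the original module): the reduction interpolates between minimum,
# mean and maximum depending on the sign and magnitude of `tau`.
#
#   values = tf.constant([[1.0, 2.0, 3.0]])
#   reduce_softmax(values, tau=1e6)   # ~[3.0]: close to the maximum.
#   reduce_softmax(values, tau=0.0)   # [2.0]: exactly the mean.
#   reduce_softmax(values, tau=-1e6)  # ~[1.0]: close to the minimum.
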
@gin.configurable
def whiten(x, axis=-1, min_std=1e-10):
  """Makes the input tensor zero mean and unit variance along the axis.

  Args:
    x: (tf.Tensor<float>) of any shape to be whitened.
    axis: (int) the axis along which to compute the statistics.
    min_std: (float) a minimum value of the standard deviation along the
      axis, to prevent degenerate cases.

  Returns:
    A tf.Tensor<float> of the same shape as the input tensor.
  """
  mu = tf.expand_dims(tf.math.reduce_mean(x, axis=axis), axis=axis)
  sigma = tf.expand_dims(
      tf.maximum(tf.math.reduce_std(x, axis=axis), min_std), axis=axis)
  return (x - mu) / sigma

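# A hypothetical usage sketch for `whiten` (illustrative values, not part of
# the original module): each row is standardized independently along the
# chosen axis.
#
#   scores = tf.constant([[10.0, 20.0, 30.0]])
#   whiten(scores, axis=-1)  # ~[[-1.22, 0.0, 1.22]]: zero mean, unit variance.
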
@gin.configurable
def soft_stretch(x, axis=-1, extreme_tau=1e12):
  """Softly rescales the values of `x` along the axis to the [0, 1] segment.

  Args:
    x: (tf.Tensor<float> of any shape) the input tensor whose values to
      rescale.
    axis: (int) the axis along which we want to rescale.
    extreme_tau: (float) the value of the inverse temperature used to compute
      the softmax and softmin. This must be big for the output values to
      really lie in the [0, 1] segment.

  Returns:
    A tf.Tensor<float> of the same shape as the input.
  """
  min_x = tf.expand_dims(
      reduce_softmax(x, tau=-extreme_tau, axis=axis), axis=axis)
  max_x = tf.expand_dims(
      reduce_softmax(x, tau=extreme_tau, axis=axis), axis=axis)
  return (x - min_x) / (max_x - min_x)

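# A hypothetical usage sketch for `soft_stretch` (illustrative values, not
# part of the original module): with a large `extreme_tau`, the soft minimum
# and maximum are essentially exact, so the extremes map to ~0 and ~1.
#
#   scores = tf.constant([[2.0, 4.0, 8.0]])
#   soft_stretch(scores)  # ~[[0.0, 0.333, 1.0]]
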
@gin.configurable
def group_rescale(x, is_logistic=True, tau=0.0, stretch=False):
  """Applies a sigmoid map on standardized inputs.

  By default, the input is centered, but it can be uncentered by playing with
  the parameter `tau`.

  Args:
    x: Tensor<float>[batch, n]
    is_logistic: (bool) uses either a logistic sigmoid or an arctan.
    tau: (float) inverse temperature parameter that, if not 0, controls how
      much deviation from the mean we want. The bigger it is, the closer to
      the maximum; the more negative, the closer to the minimum.
    stretch: (bool) if True, stretches the values to the full [0, 1] segment.

  Returns:
    A Tensor<float>[batch, n] after application of the sigmoid map.
  """
  x = whiten(x, axis=1)
  if is_logistic:
    x /= math.sqrt(3.0) / math.pi
  if tau != 0:
    center = reduce_softmax(x, tau=tau, axis=1)
    x = x - center[:, tf.newaxis]
  squashing_fn = tf.math.sigmoid if is_logistic else tf.math.atan
  squashed_x = squashing_fn(x)
  if stretch:
    return soft_stretch(squashed_x)
  return squashed_x
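
# A hypothetical end-to-end sketch for `group_rescale` (illustrative values,
# not part of the original module):
#
#   activations = tf.constant([[0.3, -1.2, 2.5, 0.7]])
#   group_rescale(activations)                # row squashed into (0, 1).
#   group_rescale(activations, tau=1e3)       # recentered: row maximum ~0.5.
#   group_rescale(activations, stretch=True)  # extremes pushed to ~0 and ~1.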