# Scraper metadata (not part of the original file): google-research, 362 lines, 14.7 KB.
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module defines the softranks and softsort operators."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22import itertools23
24import tensorflow.compat.v2 as tf25
26from soft_sort import sinkhorn27from soft_sort import soft_quantilizer28
# Valid values for the `direction` argument of softsort / softranks.
DIRECTIONS = ('ASCENDING', 'DESCENDING')
# Name of the SoftQuantilizer kwarg that softsort's `topk` option fills in;
# used to detect a conflict when the caller passes both.
_TARGET_WEIGHTS_ARG = 'target_weights'
32
def preprocess(x, axis):
  """Flattens `x` to the rank-2 shape required by the SoftQuantilizer.

  The SoftQuantilizer operates on a tensor of rank 2, where the first
  dimension is the batch and the soft sorting is applied along the second.
  This helper swaps `axis` with the last dimension, then collapses every
  other dimension into the batch one.

  Args:
    x: Tensor<float> of any rank.
    axis: (int) the axis to be turned into the second (sorted) dimension.

  Returns:
    A tuple (Tensor<float>[batch, n], List[int], tf.Tensor) where
      - the first element is the flattened tensor (n is the size along
        `axis`, batch the product of all other dimensions);
      - the second element is the permutation that was applied, as a list
        of ints;
      - the third element is the dynamic shape right after transposing.
    All three outputs are needed to invert the transformation later via
    `postprocess`.
  """
  permutation = list(range(x.shape.rank))
  permutation[axis], permutation[-1] = permutation[-1], permutation[axis]
  transposed = tf.transpose(x, permutation)
  flattened = tf.reshape(transposed, (-1, tf.shape(x)[axis]))
  return flattened, permutation, tf.shape(transposed)
61
def postprocess(x, transposition, shape):
  """Inverts the flattening performed by `preprocess`.

  Args:
    x: Tensor<float>[batch, n].
    transposition: Tensor<int>[rank] permutation used in `preprocess`. A
      transposition is its own inverse, so applying it again restores the
      original axis order.
    shape: TensorShape of the intermediate (transposed) tensor.

  Returns:
    A Tensor<float> shaped like the tensor before preprocessing, except that
    the processed axis takes the (possibly different) last dimension of `x`.
  """
  # The last dimension may have changed (e.g. quantile extraction), so it is
  # taken from `x` rather than from the recorded shape.
  target_shape = tf.concat([shape[:-1], tf.shape(x)[-1:]], axis=0)
  unflattened = tf.reshape(x, target_shape)
  return tf.transpose(unflattened, transposition)
79
def softsort(x,
             direction='ASCENDING',
             axis=-1,
             topk=None,
             **kwargs):
  """Applies the softsort operator on input tensor x.

  This operator acts as differentiable alternative to tf.sort.

  Args:
    x: the input tensor. It can be either of shape [batch, n] or [n].
    direction: the direction 'ASCENDING' or 'DESCENDING'.
    axis: the axis on which to operate the sort.
    topk: if not None, only that many sorted values are computed, which
      solves a simpler transport problem and is therefore faster.
    **kwargs: see SoftQuantilizer for possible parameters.

  Returns:
    A tensor of sorted values of the same shape as the input tensor.

  Raises:
    ValueError: if `direction` is invalid, or if both `topk` and
      `target_weights` are supplied.
  """
  if direction not in DIRECTIONS:
    raise ValueError('`direction` should be one of {}'.format(DIRECTIONS))
  if topk is not None and _TARGET_WEIGHTS_ARG in kwargs:
    raise ValueError(
        'Conflicting arguments: both topk and target_weights are being set.')

  z, transposition, shape = preprocess(x, axis)
  if topk is not None:
    # Transport the n inputs onto topk unit-mass targets plus one extra
    # target absorbing the remaining (n - topk) mass.
    n = tf.cast(tf.shape(z)[-1], dtype=x.dtype)
    kwargs[_TARGET_WEIGHTS_ARG] = 1.0 / n * tf.concat(
        [tf.ones(topk, dtype=x.dtype), (n - topk) * tf.ones(1, dtype=x.dtype)],
        axis=0)

  sorter = soft_quantilizer.SoftQuantilizer(
      z, descending=(direction == 'DESCENDING'), **kwargs)
  if topk is None:
    values = sorter.softsort
  else:
    # topk + 1 values were computed; drop the last (absorbing) target.
    values = sorter.softsort[:, :-1]
  return postprocess(values, transposition, shape)
122
def softranks(x, direction='ASCENDING', axis=-1, zero_based=True, **kwargs):
  """A differentiable argsort-like operator that returns directly the ranks.

  It behaves as the 'inverse' of argsort: instead of indices it returns soft
  ranks, i.e. real numbers quantifying the relative standing (among all n
  entries) of each entry of x.

  Args:
    x: Tensor<float> of any shape.
    direction: (str) either 'ASCENDING' or 'DESCENDING', as in tf.sort.
    axis: (int) the axis along which to sort, as in tf.sort.
    zero_based: (bool) to return values in [0, n-1] or in [1, n].
    **kwargs: see SoftQuantilizer for possible parameters.

  Returns:
    A Tensor<float> of the same shape as the input containing the soft ranks.

  Raises:
    ValueError: if `direction` is not in DIRECTIONS.
  """
  if direction not in DIRECTIONS:
    raise ValueError('`direction` should be one of {}'.format(DIRECTIONS))

  z, transposition, shape = preprocess(x, axis)
  sorter = soft_quantilizer.SoftQuantilizer(
      z, descending=(direction == 'DESCENDING'), **kwargs)
  # The soft CDF lies in (0, 1]; scaling by n yields one-based soft ranks.
  ranks = sorter.softcdf * tf.cast(tf.shape(z)[1], dtype=x.dtype)
  if zero_based:
    ranks -= tf.cast(1.0, dtype=x.dtype)
  return postprocess(ranks, transposition, shape)
151
def softquantiles(x,
                  quantiles,
                  quantile_width=None,
                  axis=-1,
                  may_squeeze=True,
                  **kwargs):
  """Computes soft quantiles via optimal transport.

  This operator takes advantage of the fact that an exhaustive softsort is not
  required to recover a single quantile. Instead, one can transport all
  input values in x onto only 3 weighted values. Target weights are adjusted so
  that those values in x that are transported to the middle value in the target
  vector y correspond to those concentrating around the quantile of interest.

  This idea generalizes to more quantiles, interleaving small weights on the
  quantile indices and bigger weights in between, corresponding to the gap from
  one desired quantile to the next one.

  Args:
    x: Tensor<float> of any shape.
    quantiles: list<float> the quantiles to be returned. It can also be a
      single float.
    quantile_width: (float) mass given to the bucket supposed to attract points
      whose value concentrate around the desired quantile value. Bigger width
      means that we allow the soft quantile to be a mixture of more points
      further away from the quantile. If None, the width is set at 1/n where n
      is the number of values considered (the size along the 'axis').
    axis: (int) the axis along which to compute the quantile.
    may_squeeze: (bool) should we squeeze the output tensor in case of a single
      quantile.
    **kwargs: see SoftQuantilizer for possible extra parameters.

  Returns:
    A Tensor<float> similar to the input tensor, but the axis dimension is
    replaced by the number of quantiles specified in the quantiles list.
    Hence, if only a quantile is requested (quantiles is a float) only one
    value in that axis is returned. When several quantiles are requested, the
    tensor will have that many values in that axis.

  Raises:
    tf.errors.InvalidArgumentError when the quantiles and quantile width are
    not correct, namely quantiles are either not in sorted order or the
    quantile_width is too large.
  """
  if isinstance(quantiles, float):
    quantiles = [quantiles]
  quantiles = tf.constant(quantiles, tf.float32)

  # Preprocesses submitted quantiles to check that they satisfy elementary
  # constraints: only quantiles strictly inside (0, 1) are kept.
  valid_quantiles = tf.boolean_mask(
      quantiles, tf.logical_and(quantiles > 0.0, quantiles < 1.0))
  num_quantiles = tf.shape(valid_quantiles)[0]

  # Includes values on both ends of [0,1].
  extended_quantiles = tf.concat([[0.0], valid_quantiles, [1.0]], axis=0)

  # Builds filler_weights in between the target quantiles: the mass of each
  # filler bucket is the gap between consecutive quantile levels.
  filler_weights = extended_quantiles[1:] - extended_quantiles[:-1]
  if quantile_width is None:
    quantile_width = tf.reduce_min(
        tf.concat(
            [filler_weights, [1.0 / tf.cast(tf.shape(x)[axis], dtype=x.dtype)]],
            axis=0))

  # Takes into account quantile_width in the definition of weights: each
  # filler bucket gives up one quantile_width of mass to its neighboring
  # quantile bucket(s); the two boundary fillers only give up half.
  # NOTE(review): `shift` uses x.dtype while `filler_weights` derives from a
  # float32 constant — this assumes x is float32; confirm for other dtypes.
  shift = -tf.ones(tf.shape(filler_weights), dtype=x.dtype)
  shift = shift + 0.5 * (
      tf.one_hot(0, num_quantiles + 1) +
      tf.one_hot(num_quantiles, num_quantiles + 1))
  filler_weights = filler_weights + quantile_width * shift

  # Negative filler weights mean unsorted quantiles or a too-large width.
  assert_op = tf.Assert(tf.reduce_all(filler_weights >= 0.0), [filler_weights])
  with tf.control_dependencies([assert_op]):
    # Adds one more value to have tensors of the same shape to interleave them.
    quantile_weights = tf.ones(num_quantiles + 1) * quantile_width

    # Interleaves the filler_weights with the quantile weights; the trailing
    # extra quantile weight added above is dropped by the [:-1] slice.
    weights = tf.reshape(
        tf.stack([filler_weights, quantile_weights], axis=1), (-1,))[:-1]

    # Sends only the positive weights to the softsort operator (a filler
    # weight of zero would otherwise produce a degenerate target).
    positive_weights = tf.boolean_mask(weights, weights > 0.0)
    all_quantiles = softsort(
        x,
        direction='ASCENDING',
        axis=axis,
        target_weights=positive_weights,
        **kwargs)

    # Recovers the indices corresponding to the desired quantiles: odd
    # positions in `weights` are quantile buckets; cumsum over positive
    # entries maps each to its index among the kept (positive) targets.
    odds = tf.math.floormod(tf.range(weights.shape[0], dtype=tf.float32), 2)
    positives = tf.cast(weights > 0.0, tf.float32)
    indices = tf.cast(tf.math.cumsum(positives) * odds, dtype=tf.int32)
    indices = tf.boolean_mask(indices, indices > 0) - 1
    result = tf.gather(all_quantiles, indices, axis=axis)

    # In the specific case where we want a single quantile, squeezes the
    # quantile dimension.
    # NOTE(review): this Python `if` on tensor values requires eager
    # execution; in graph mode it would need tf.cond — confirm intended use.
    can_squeeze = tf.equal(tf.shape(result)[axis], 1)
    if tf.math.logical_and(can_squeeze, may_squeeze):
      result = tf.squeeze(result, axis=axis)
    return result
256
def soft_multivariate_quantiles(x,
                                quantiles,
                                quantile_width=None,
                                **kwargs):
  """Computes soft multivariate quantiles via optimal transport.

  Transport multivariate input values in x onto 2^d + 1 weighted points,
  {0,1}^d + [0.5, ..., 0.5]. Target weights are adjusted so
  that those values in x that are transported to the middle value in the target
  vector correspond to those concentrating around the quantile of interest.

  Args:
    x: Tensor<float> of shape [batch, N, d]
    quantiles: Tensor<float> of shape [r, d], r targeted quantiles of
      dimension d
    quantile_width: (float) mass given to the bucket supposed to attract points
      whose value concentrate around the desired quantile value. Bigger width
      means that we allow the soft quantile to be a mixture of more points
      further away from the quantile. If None, the width is set at 1/n where n
      is the number of values considered (the size along the 'axis').
    **kwargs: see sinkhorn.autodiff_sinkhorn for possible extra parameters.

  Returns:
    A Tensor<float> [N,r,d] of multivariate quantiles per batch.
  """
  quantiles = tf.constant(quantiles, tf.float32)
  batch_size = x.shape[0]
  n = tf.cast(x.shape[1], tf.float32)
  d = x.shape[2]
  if quantile_width is None:
    quantile_width = 2 / n
  num_quantiles = tf.shape(quantiles)[0]
  # The 2^d vertices of the {-1, 1}^d hypercube, one row per vertex.
  # NOTE(review): the docstring mentions {0,1}^d; the exponents below map
  # each +/-1 coordinate to an exponent in {0, 1}, so these are consistent.
  hypercube_vertices = tf.constant(
      list(itertools.product([-1, 1], repeat=d)), tf.float32)
  # Weights attached to vertices for each quantile: broadcasting
  # [r, 1, d] against [1, 2^d, d] yields a [r, 2^d, d] tensor where each
  # coordinate contributes q or (1 - q) depending on the vertex sign.
  weights = quantiles[:, tf.newaxis, :]**(
      0.5 * (1 - hypercube_vertices))[tf.newaxis, Ellipsis]
  weights *= (1 - quantiles)[:, tf.newaxis, :]**(
      0.5 * (1 + hypercube_vertices))[tf.newaxis, Ellipsis]

  # Product over coordinates gives [r, 2^d] vertex masses; scale down to
  # leave quantile_width of mass for the central (quantile) target.
  weights = (1 - quantile_width) * tf.reduce_prod(weights, axis=2)
  # Adding weights for quantile itself (in position 0).
  weights = tf.concat((quantile_width * tf.ones((num_quantiles, 1)), weights),
                      axis=1)
  # Augmenting and formatting as [batch_size, 2^d + 1, num_quantiles].
  weights = tf.reshape(
      tf.tile(tf.transpose(weights), [batch_size, 1]),
      [batch_size, 2**d + 1, num_quantiles])
  # Sets target locations, by adding the point at 0 that will absorb the
  # quantile; augments it with batch_size.
  y = tf.concat((tf.zeros((1, d), dtype=tf.float32), hypercube_vertices),
                axis=0)
  y = tf.reshape(tf.tile(y, [batch_size, 1]), [batch_size, 2**d + 1, d])
  # Centers x; the mean is added back after transport.
  x_mean = tf.reduce_mean(x, axis=1)
  x = x - x_mean[:, tf.newaxis, :]
  # Uniform source weights 1/n per point; one transport per quantile.
  transports = sinkhorn.autodiff_sinkhorn(
      x, y,
      tf.ones([batch_size, n, num_quantiles], dtype=tf.float32) / n, weights,
      **kwargs)

  # Recovers convex combinations resulting from transporting to the central
  # point (index 0) in all batches and quantile variations; dividing by
  # quantile_width renormalizes each column to sum to one.
  transports = 1 / quantile_width * tf.reshape(transports[:, :, 0, :],
                                               [batch_size, n, -1])
  # Applies these convex combinations to data points + recenters.
  all_soft_quantiles = tf.reduce_sum(
      transports[:, :, :, tf.newaxis] *
      x[:, :, tf.newaxis, :],
      axis=1) + x_mean[:, tf.newaxis, :]
  # Reshapes those quantiles after having applied convex combinations.
  return tf.reshape(all_soft_quantiles, [batch_size, num_quantiles, d])
330
def soft_quantile_normalization(x, f, axis=-1, **kwargs):
  """Applies a (soft) quantile normalization of x with f.

  Classical quantile normalization builds the empirical density function
  (EDF) from the values in x, assigns each value its EDF level (rank divided
  by the size of x), and replaces it by the quantile of f at that level
  (see https://en.wikipedia.org/wiki/Quantile_normalization).

  The operator below does so differentiably: it first computes a
  distribution of ranks for x (stored in an optimal transport table) and
  then averages the values stored in f accordingly.

  This only works when f is a vector of sorted values corresponding to the
  quantiles of a distribution at levels [1/m, ..., m/m], where m is the
  size of f.

  Args:
    x: Tensor<float> of any shape.
    f: Tensor<float>[m] where m can be or not the size of x along the axis.
      Usually it is. f should be sorted.
    axis: the axis along which the tensor x should be quantile normalized.
    **kwargs: extra parameters passed to the SoftQuantilizer.

  Returns:
    A tensor of the same shape of x.
  """
  flat, permutation, transposed_shape = preprocess(x, axis)
  sorter = soft_quantilizer.SoftQuantilizer(
      flat, descending=False, num_targets=f.shape[-1], **kwargs)
  # Each row of the transport plan spreads one input over the targets;
  # averaging f with those rows (renormalized by the input weights) yields
  # the soft-normalized value of that input.
  averaged = tf.linalg.matvec(sorter.transport, f)
  normalized = 1.0 / sorter.weights * averaged
  return postprocess(normalized, permutation, transposed_shape)