# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module defines the softranks and softsort operators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import tensorflow.compat.v2 as tf

from soft_sort import sinkhorn
from soft_sort import soft_quantilizer

DIRECTIONS = ('ASCENDING', 'DESCENDING')
_TARGET_WEIGHTS_ARG = 'target_weights'


def preprocess(x, axis):
  """Reshapes the input data to the rank-2 form required by SoftQuantilizer.

  The SoftQuantilizer expects an input tensor of rank 2, where the first
  dimension is the batch dimension and the soft sorting is applied along the
  second one.

  Args:
   x: Tensor<float> of any rank.
   axis: (int) the axis to be turned into the second dimension.

  Returns:
   a Tuple(Tensor<float>[batch, n], List[int], tf.Tensor) where
    - the first element is the output tensor (n being the size along the
      sorted axis and batch the product of all other dimensions);
    - the second element is the transposition that was applied, as a list of
      integers;
    - the third element is the shape after the transposition was applied.

   Those three outputs are necessary in order to easily invert the
   transformation down the line.
  """
  dims = list(range(x.shape.rank))
  dims[-1], dims[axis] = dims[axis], dims[-1]
  x_transposed = tf.transpose(x, dims)
  x_flat = tf.reshape(x_transposed, (-1, tf.shape(x)[axis]))
  return x_flat, dims, tf.shape(x_transposed)


def postprocess(x, transposition, shape):
  """Applies the inverse transformation of preprocess.

  Args:
   x: Tensor<float>[batch, n]
   transposition: Tensor<int>[rank] 1D tensor representing the transposition
     that was used to preprocess the input tensor. Since that transposition
     swaps exactly two axes, it is an involution: applying it again restores
     the original axis order.
   shape: TensorShape of the intermediary output.

  Returns:
   A Tensor<float> that is similar in shape to the tensor before preprocessing.
  """
  shape = tf.concat([shape[:-1], tf.shape(x)[-1:]], axis=0)
  return tf.transpose(tf.reshape(x, shape), transposition)


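# Example (illustrative sketch): round-tripping a rank-3 tensor through
# preprocess and postprocess. Flattening along axis 1 of a [2, 3, 4] tensor
# yields an [8, 3] matrix (8 = 2 * 4 "batch" slices); postprocess restores
# the original [2, 3, 4] layout.
#
#   x = tf.random.uniform((2, 3, 4))
#   z, perm, shape = preprocess(x, axis=1)  # z.shape == (8, 3)
#   y = postprocess(z, perm, shape)         # y.shape == (2, 3, 4)

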
def softsort(x,
             direction='ASCENDING',
             axis=-1,
             topk=None,
             **kwargs):
  """Applies the softsort operator on input tensor x.

  This operator acts as a differentiable alternative to tf.sort.

  Args:
   x: the input tensor. It can be either of shape [batch, n] or [n].
   direction: the direction, either 'ASCENDING' or 'DESCENDING'.
   axis: the axis along which to sort.
   topk: if not None, the number of top sorted values to be computed. Using
     topk improves the speed of the algorithm, since it solves a simpler
     problem.
   **kwargs: see SoftQuantilizer for possible parameters.

  Returns:
   A tensor of sorted values of the same shape as the input tensor, except
   along `axis`, whose size becomes topk when topk is not None.
  """
  if direction not in DIRECTIONS:
    raise ValueError('`direction` should be one of {}'.format(DIRECTIONS))

  if topk is not None and _TARGET_WEIGHTS_ARG in kwargs:
    raise ValueError(
        'Conflicting arguments: both topk and target_weights are being set.')

  z, transposition, shape = preprocess(x, axis)
  descending = (direction == 'DESCENDING')

  if topk is not None:
    # Top-k target weights: topk targets of mass 1/n each, plus one target of
    # mass (n - topk)/n that absorbs all the remaining values.
    n = tf.cast(tf.shape(z)[-1], dtype=x.dtype)
    kwargs[_TARGET_WEIGHTS_ARG] = 1.0 / n * tf.concat(
        [tf.ones(topk, dtype=x.dtype), (n - topk) * tf.ones(1, dtype=x.dtype)],
        axis=0)

  sorter = soft_quantilizer.SoftQuantilizer(z, descending=descending, **kwargs)
  # When topk is set, topk + 1 values were computed: the last one aggregates
  # the n - topk remaining values and is dropped.
  values = sorter.softsort if topk is None else sorter.softsort[:, :-1]
  return postprocess(values, transposition, shape)


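# Example (illustrative sketch; outputs are approximate since the operator is
# a smoothed sort whose sharpness depends on the SoftQuantilizer
# regularization parameters passed through **kwargs):
#
#   x = tf.constant([[3.0, 1.0, 2.0]])
#   softsort(x)                                  # ~[[1.0, 2.0, 3.0]]
#   softsort(x, direction='DESCENDING', topk=2)  # ~[[3.0, 2.0]]

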
def softranks(x, direction='ASCENDING', axis=-1, zero_based=True, **kwargs):
  """A differentiable argsort-like operator that returns the ranks directly.

  Note that it behaves as the 'inverse' of the argsort operator since it
  returns soft ranks, i.e. real numbers that play the role of indices and
  quantify the relative standing (among all n entries) of each entry of x.

  Args:
   x: Tensor<float> of any shape.
   direction: (str) either 'ASCENDING' or 'DESCENDING', as in tf.sort.
   axis: (int) the axis along which to sort, as in tf.sort.
   zero_based: (bool) whether to return values in [0, n-1] or in [1, n].
   **kwargs: see SoftQuantilizer for possible parameters.

  Returns:
   A Tensor<float> of the same shape as the input containing the soft ranks.
  """
  if direction not in DIRECTIONS:
    raise ValueError('`direction` should be one of {}'.format(DIRECTIONS))

  descending = (direction == 'DESCENDING')
  z, transposition, shape = preprocess(x, axis)
  sorter = soft_quantilizer.SoftQuantilizer(z, descending=descending, **kwargs)
  # softcdf is a soft version of rank / n; scaling by n yields ranks in [1, n].
  ranks = sorter.softcdf * tf.cast(tf.shape(z)[1], dtype=x.dtype)
  if zero_based:
    ranks -= tf.cast(1.0, dtype=x.dtype)
  return postprocess(ranks, transposition, shape)


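# Example (illustrative sketch): with the default zero_based=True, the
# smallest entry gets a soft rank close to 0 and the largest one close to
# n - 1.
#
#   x = tf.constant([[3.0, 1.0, 2.0]])
#   softranks(x)  # ~[[2.0, 0.0, 1.0]]

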
def softquantiles(x,
                  quantiles,
                  quantile_width=None,
                  axis=-1,
                  may_squeeze=True,
                  **kwargs):
  """Computes soft quantiles via optimal transport.

  This operator takes advantage of the fact that an exhaustive softsort is not
  required to recover a single quantile. Instead, one can transport all
  input values in x onto only 3 weighted values. Target weights are adjusted so
  that those values in x that are transported to the middle value in the target
  vector y correspond to those concentrating around the quantile of interest.

  This idea generalizes to more quantiles: small weights on the quantile
  indices are interleaved with bigger weights in between, corresponding to the
  gap from one desired quantile to the next.

  Args:
   x: Tensor<float> of any shape.
   quantiles: list<float> of quantiles to be returned. It can also be a single
     float.
   quantile_width: (float) mass given to the bucket supposed to attract points
     whose values concentrate around the desired quantile value. A bigger width
     means that the soft quantile is allowed to be a mixture of more points
     further away from the quantile. If None, the width is set at 1/n, where n
     is the number of values considered (the size along `axis`).
   axis: (int) the axis along which to compute the quantile.
   may_squeeze: (bool) whether to squeeze the output tensor when a single
     quantile is requested.
   **kwargs: see SoftQuantilizer for possible extra parameters.

  Returns:
    A Tensor<float> similar to the input tensor, but with the axis dimension
    replaced by the number of quantiles specified in the quantiles list.
    Hence, if only one quantile is requested (quantiles is a float) only one
    value in that axis is returned. When several quantiles are requested, the
    tensor will have that many values in that axis.

  Raises:
    tf.errors.InvalidArgumentError when the quantiles and quantile width are
    not correct, namely when the quantiles are not in sorted order or the
    quantile_width is too large.
  """
  if isinstance(quantiles, float):
    quantiles = [quantiles]
  quantiles = tf.constant(quantiles, tf.float32)

  # Preprocesses submitted quantiles to check that they satisfy elementary
  # constraints.
  valid_quantiles = tf.boolean_mask(
      quantiles, tf.logical_and(quantiles > 0.0, quantiles < 1.0))
  num_quantiles = tf.shape(valid_quantiles)[0]

  # Includes values on both ends of [0,1].
  extended_quantiles = tf.concat([[0.0], valid_quantiles, [1.0]], axis=0)

  # Builds filler_weights in between the target quantiles.
  filler_weights = extended_quantiles[1:] - extended_quantiles[:-1]
  if quantile_width is None:
    quantile_width = tf.reduce_min(
        tf.concat(
            [filler_weights, [1.0 / tf.cast(tf.shape(x)[axis], dtype=x.dtype)]],
            axis=0))

  # Takes quantile_width into account in the definition of the weights: each
  # quantile bucket of mass quantile_width takes half its mass from each of
  # its two neighboring fillers, so interior fillers lose a full
  # quantile_width while the two fillers at the boundaries of [0, 1] only
  # lose half of it.
  shift = -tf.ones(tf.shape(filler_weights), dtype=x.dtype)
  shift = shift + 0.5 * (
      tf.one_hot(0, num_quantiles + 1) +
      tf.one_hot(num_quantiles, num_quantiles + 1))
  filler_weights = filler_weights + quantile_width * shift

  assert_op = tf.Assert(tf.reduce_all(filler_weights >= 0.0), [filler_weights])
  with tf.control_dependencies([assert_op]):
    # Adds one more value to have tensors of the same shape to interleave them.
    quantile_weights = tf.ones(num_quantiles + 1) * quantile_width

    # Interleaves the filler_weights with the quantile weights.
    weights = tf.reshape(
        tf.stack([filler_weights, quantile_weights], axis=1), (-1,))[:-1]

    # Sends only the positive weights to the softsort operator.
    positive_weights = tf.boolean_mask(weights, weights > 0.0)
    all_quantiles = softsort(
        x,
        direction='ASCENDING',
        axis=axis,
        target_weights=positive_weights,
        **kwargs)

    # Recovers the indices corresponding to the desired quantiles.
    odds = tf.math.floormod(tf.range(weights.shape[0], dtype=tf.float32), 2)
    positives = tf.cast(weights > 0.0, tf.float32)
    indices = tf.cast(tf.math.cumsum(positives) * odds, dtype=tf.int32)
    indices = tf.boolean_mask(indices, indices > 0) - 1
    result = tf.gather(all_quantiles, indices, axis=axis)

    # In the specific case where a single quantile is requested, squeezes the
    # quantile dimension.
    can_squeeze = tf.equal(tf.shape(result)[axis], 1)
    if tf.math.logical_and(can_squeeze, may_squeeze):
      result = tf.squeeze(result, axis=axis)
    return result


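# Example (illustrative sketch): a soft median. Since may_squeeze=True by
# default, the quantile axis is squeezed when a single float is passed.
#
#   x = tf.constant([[1.0, 5.0, 2.0, 4.0, 3.0]])
#   softquantiles(x, 0.5)           # ~[3.0]
#   softquantiles(x, [0.25, 0.75])  # ~[[2.0, 4.0]]

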
def soft_multivariate_quantiles(x,
                                quantiles,
                                quantile_width=None,
                                **kwargs):
  """Computes soft multivariate quantiles via optimal transport.

  Transports multivariate input values in x onto 2^d + 1 weighted points: the
  vertices of the hypercube {-1, 1}^d plus its center, the origin. Target
  weights are adjusted so that those values in x that are transported to the
  central point correspond to those concentrating around the quantile of
  interest.

  Args:
   x: Tensor<float> of shape [batch, N, d].
   quantiles: Tensor<float> of shape [r, d], r targeted quantiles of
     dimension d.
   quantile_width: (float) mass given to the bucket supposed to attract points
     whose values concentrate around the desired quantile value. A bigger width
     means that the soft quantile is allowed to be a mixture of more points
     further away from the quantile. If None, the width is set at 2/n, where n
     is the number of points considered (the second dimension of x).
   **kwargs: see sinkhorn.autodiff_sinkhorn for possible extra parameters.

  Returns:
    A Tensor<float> of shape [batch, r, d] of multivariate quantiles per batch.
  """
  quantiles = tf.constant(quantiles, tf.float32)
  batch_size = x.shape[0]
  n = tf.cast(x.shape[1], tf.float32)
  d = x.shape[2]
  if quantile_width is None:
    quantile_width = 2 / n
  num_quantiles = tf.shape(quantiles)[0]
  hypercube_vertices = tf.constant(
      list(itertools.product([-1, 1], repeat=d)), tf.float32)
  # Weights attached to the vertices for each quantile; after the product over
  # dimensions below, this is [num_quantiles, 2^d].
  weights = quantiles[:, tf.newaxis, :]**(
      0.5 * (1 - hypercube_vertices))[tf.newaxis, Ellipsis]
  weights *= (1 - quantiles)[:, tf.newaxis, :]**(
      0.5 * (1 + hypercube_vertices))[tf.newaxis, Ellipsis]

  weights = (1 - quantile_width) * tf.reduce_prod(weights, axis=2)
  # Adds the weight for the quantile point itself (in position 0).
  weights = tf.concat((quantile_width * tf.ones((num_quantiles, 1)), weights),
                      axis=1)
  # Tiles across the batch and reshapes to [batch_size, 2^d + 1, num_quantiles].
  weights = tf.reshape(
      tf.tile(tf.transpose(weights), [batch_size, 1]),
      [batch_size, 2**d + 1, num_quantiles])
  # Sets the target locations, adding the point at 0 that will absorb the
  # quantile, and tiles them across the batch.
  y = tf.concat((tf.zeros((1, d), dtype=tf.float32), hypercube_vertices),
                axis=0)
  y = tf.reshape(tf.tile(y, [batch_size, 1]), [batch_size, 2**d + 1, d])
  # Centers x.
  x_mean = tf.reduce_mean(x, axis=1)
  x = x - x_mean[:, tf.newaxis, :]
  transports = sinkhorn.autodiff_sinkhorn(
      x, y,
      tf.ones([batch_size, n, num_quantiles], dtype=tf.float32) / n, weights,
      **kwargs)

  # Recovers the convex combinations resulting from transporting to the
  # central point, in all batches and quantile variations.
  transports = 1 / quantile_width * tf.reshape(transports[:, :, 0, :],
                                               [batch_size, n, -1])
  # Applies these convex combinations to the data points and recenters.
  all_soft_quantiles = tf.reduce_sum(
      transports[:, :, :, tf.newaxis] *
      x[:, :, tf.newaxis, :],
      axis=1) + x_mean[:, tf.newaxis, :]
  # Reshapes those quantiles after having applied the convex combinations.
  return tf.reshape(all_soft_quantiles, [batch_size, num_quantiles, d])


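# Example (illustrative sketch, assuming sinkhorn.autodiff_sinkhorn runs with
# its default parameters): a soft multivariate median of n = 4 points in
# d = 2 dimensions, for a batch of size 1 and r = 1 target quantile.
#
#   x = tf.random.normal((1, 4, 2))
#   q = soft_multivariate_quantiles(x, [[0.5, 0.5]])  # q.shape == [1, 1, 2]

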
def soft_quantile_normalization(x, f, axis=-1, **kwargs):
  """Applies a (soft) quantile normalization of x with f.

  The usual quantile normalization operator uses the empirical values contained
  in x to construct an empirical distribution function (EDF), assigns to each
  value in x its EDF value (i.e. its rank divided by the size of x), and then
  replaces it with the corresponding quantile described in vector f
  (see https://en.wikipedia.org/wiki/Quantile_normalization).

  The operator proposed here does so in a differentiable manner, by first
  computing a distribution of ranks for x (stored in an optimal transport
  table) and then taking averages of the values stored in f.

  Note that the current function only works when f is a vector of sorted values
  corresponding to the quantiles of a distribution at levels [1/m, ..., m/m],
  where m is the size of f.

  Args:
   x: Tensor<float> of any shape.
   f: Tensor<float>[m], where m need not be the size of x along the axis,
     although it usually is. f should be sorted.
   axis: the axis along which the tensor x should be quantile normalized.
   **kwargs: extra parameters passed to the SoftQuantilizer.

  Returns:
   A tensor of the same shape as x.
  """
  z, transposition, shape = preprocess(x, axis)
  sorter = soft_quantilizer.SoftQuantilizer(
      z, descending=False, num_targets=f.shape[-1], **kwargs)
  # Averages the values of f according to each point's transport row,
  # renormalized by that point's input weight.
  y = 1.0 / sorter.weights * tf.linalg.matvec(sorter.transport, f)
  return postprocess(y, transposition, shape)
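

# Example (illustrative sketch): mapping the values of x onto the sorted
# target values in f while (softly) preserving the ranks of x.
#
#   x = tf.constant([[3.0, 1.0, 2.0]])
#   f = tf.constant([10.0, 20.0, 30.0])
#   soft_quantile_normalization(x, f)  # ~[[30.0, 10.0, 20.0]]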