google-research

Форк
0
121 строка · 4.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""A Jax version of Sinkhorn's algorithm."""
17

18
import gin
19
from jax import scipy
20
import jax.numpy as np
21

22

23
def center(cost, f, g):
24
  return cost - f[:, :, np.newaxis] - g[:, np.newaxis, :]
25

26

27
def softmin(cost, f, g, eps, axis):
28
  return -eps * scipy.special.logsumexp(-center(cost, f, g) / eps, axis=axis)
29

30

31
def error(cost, f, g, eps, b):
32
  b_target = np.sum(transport(cost, f, g, eps), axis=1)
33
  return np.max(np.abs(b_target - b) / b, axis=None)
34

35

36
def transport(cost, f, g, eps):
37
  return np.exp(-center(cost, f, g) / eps)
38

39

40
def cost_fn(x, y, power):
41
  """A transport cost in the form |x-y|^p and its derivative."""
42
  delta = x[:, :, np.newaxis] - y[:, np.newaxis, :]
43
  if power == 1.0:
44
    cost = np.abs(delta)
45
    derivative = np.sign(delta)
46
  elif power == 2.0:
47
    cost = delta ** 2.0
48
    derivative = 2.0 * delta
49
  else:
50
    abs_diff = np.abs(delta)
51
    cost = abs_diff ** power
52
    derivative = power * np.sign(delta) * abs_diff ** (power - 1.0)
53
  return cost, derivative
54

55

56
@gin.configurable
57
def sinkhorn_iterations(x,
58
                        y,
59
                        a,
60
                        b,
61
                        power = 2.0,
62
                        epsilon = 1e-2,
63
                        epsilon_0 = 0.1,
64
                        epsilon_decay = 0.95,
65
                        threshold = 1e-2,
66
                        inner_iterations = 10,
67
                        max_iterations = 2000):
68
  """Runs the Sinkhorn's algorithm from (x, a) to (y, b).
69

70
  Args:
71
   x: np.ndarray<float>[batch, n]: the input point clouds.
72
   y: np.ndarray<float>[batch, m]: the target point clouds.
73
   a: np.ndarray<float>[batch, n]: the weight of each input point. The sum of
74
    all elements of b must match that of a to converge.
75
   b: np.ndarray<float>[batch, m]: the weight of each target point. The sum of
76
    all elements of b must match that of a to converge.
77
   power: (float) the power of the distance for the cost function.
78
   epsilon: (float) the level of entropic regularization wanted.
79
   epsilon_0: (float) the initial level of entropic regularization.
80
   epsilon_decay: (float) a multiplicative factor applied at each iteration
81
    until reaching the epsilon value.
82
   threshold: (float) the relative threshold on the Sinkhorn error to stop the
83
    Sinkhorn iterations.
84
   inner_iterations: (int32) the Sinkhorn error is not recomputed at each
85
    iteration but every inner_num_iter instead to avoid computational overhead.
86
   max_iterations: (int32) the maximum number of Sinkhorn iterations.
87

88
  Returns:
89
   A 5-tuple containing: the values of the conjugate variables f and g, the
90
   final value of the entropic parameter epsilon, the cost matrix and the number
91
   of iterations.
92
  """
93
  loga = np.log(a)
94
  logb = np.log(b)
95
  cost, d_cost = cost_fn(x, y, power)
96
  f = np.zeros(np.shape(a), dtype=x.dtype)
97
  g = np.zeros(np.shape(b), dtype=x.dtype)
98
  err = threshold + 1.0
99
  iterations = 0
100
  eps = epsilon_0
101
  while (iterations < max_iterations) and (err >= threshold or eps > epsilon):
102
    for _ in range(inner_iterations):
103
      iterations += 1
104
      g = eps * logb + softmin(cost, f, g, eps, axis=1) + g
105
      f = eps * loga + softmin(cost, f, g, eps, axis=2) + f
106
      eps = max(eps * epsilon_decay, epsilon)
107

108
    if eps <= epsilon:
109
      err = error(cost, f, g, eps, b)
110

111
  return f, g, eps, cost, d_cost, iterations
112

113

114
def sinkhorn(x,
115
             y,
116
             a,
117
             b,
118
             **kwargs):
119
  """Computes the transport between (x, a) and (y, b) via Sinkhorn algorithm."""
120
  f, g, eps, cost, _, _ = sinkhorn_iterations(x, y, a, b, **kwargs)
121
  return transport(cost, f, g, eps)
122

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.