# google-research
# 121 lines · 4.1 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""A Jax version of Sinkhorn's algorithm."""
17
18import gin
19from jax import scipy
20import jax.numpy as np
21
22
def center(cost, f, g):
  """Recenters the cost matrix by the dual potentials.

  Args:
    cost: np.ndarray<float>[batch, n, m]: the pairwise cost matrices.
    f: np.ndarray<float>[batch, n]: potential over the input points.
    g: np.ndarray<float>[batch, m]: potential over the target points.

  Returns:
    np.ndarray<float>[batch, n, m]: cost with f broadcast over columns and g
    broadcast over rows subtracted.
  """
  f_cols = f[:, :, np.newaxis]
  g_rows = g[:, np.newaxis, :]
  return cost - f_cols - g_rows
25
26
def softmin(cost, f, g, eps, axis):
  """Soft (entropic) minimum of the recentered cost along one axis.

  As eps goes to 0 this converges to a hard minimum over `axis`; here it is
  the log-sum-exp smoothing used in the Sinkhorn updates.

  Args:
    cost: np.ndarray<float>[batch, n, m]: the pairwise cost matrices.
    f: np.ndarray<float>[batch, n]: potential over the input points.
    g: np.ndarray<float>[batch, m]: potential over the target points.
    eps: (float) the temperature of the soft minimum.
    axis: (int) the axis of the centered cost to reduce over.

  Returns:
    np.ndarray<float>: -eps * logsumexp(-(cost - f - g) / eps) along `axis`.
  """
  # The `center` helper is inlined here: subtract both potentials broadcast
  # against the cost matrix.
  centered = cost - f[:, :, np.newaxis] - g[:, np.newaxis, :]
  return -eps * scipy.special.logsumexp(-centered / eps, axis=axis)
29
30
def error(cost, f, g, eps, b):
  """Maximum relative mismatch between the plan's target marginal and b.

  Args:
    cost: np.ndarray<float>[batch, n, m]: the pairwise cost matrices.
    f: np.ndarray<float>[batch, n]: potential over the input points.
    g: np.ndarray<float>[batch, m]: potential over the target points.
    eps: (float) the entropic regularization level.
    b: np.ndarray<float>[batch, m]: the target weights.

  Returns:
    (float) max over all batch entries of |marginal - b| / b.
  """
  # Transport plan implied by the potentials (the `transport` helper inlined).
  plan = np.exp(-(cost - f[:, :, np.newaxis] - g[:, np.newaxis, :]) / eps)
  # Summing over the input axis yields the plan's marginal on the targets.
  target_marginal = np.sum(plan, axis=1)
  return np.max(np.abs(target_marginal - b) / b, axis=None)
34
35
def transport(cost, f, g, eps):
  """Transport plan associated with the dual potentials (f, g).

  Args:
    cost: np.ndarray<float>[batch, n, m]: the pairwise cost matrices.
    f: np.ndarray<float>[batch, n]: potential over the input points.
    g: np.ndarray<float>[batch, m]: potential over the target points.
    eps: (float) the entropic regularization level.

  Returns:
    np.ndarray<float>[batch, n, m]: exp(-(cost - f - g) / eps).
  """
  # The `center` helper is inlined: subtract both potentials from the cost.
  centered = cost - f[:, :, np.newaxis] - g[:, np.newaxis, :]
  return np.exp(-centered / eps)
38
39
def cost_fn(x, y, power):
  """Pairwise transport cost |x - y|^power and its derivative w.r.t. x.

  Args:
    x: np.ndarray<float>[batch, n]: the input point clouds.
    y: np.ndarray<float>[batch, m]: the target point clouds.
    power: (float) exponent applied to the pairwise absolute differences.

  Returns:
    A pair (cost, derivative), both np.ndarray<float>[batch, n, m]. The powers
    1 and 2 are special-cased to avoid the generic (and more expensive)
    sign/abs/pow formula.
  """
  delta = x[:, :, np.newaxis] - y[:, np.newaxis, :]
  if power == 1.0:
    return np.abs(delta), np.sign(delta)
  if power == 2.0:
    return delta ** 2.0, 2.0 * delta
  # Generic case: d/dx |x-y|^p = p * sign(x-y) * |x-y|^(p-1).
  abs_delta = np.abs(delta)
  cost_matrix = abs_delta ** power
  grad = power * np.sign(delta) * abs_delta ** (power - 1.0)
  return cost_matrix, grad
54
55
@gin.configurable
def sinkhorn_iterations(x,
                        y,
                        a,
                        b,
                        power: float = 2.0,
                        epsilon: float = 1e-2,
                        epsilon_0: float = 0.1,
                        epsilon_decay: float = 0.95,
                        threshold: float = 1e-2,
                        inner_iterations: int = 10,
                        max_iterations: int = 2000):
  """Runs Sinkhorn's algorithm from (x, a) to (y, b).

  The entropic regularization is annealed: it starts at epsilon_0 and is
  multiplied by epsilon_decay after each pair of potential updates until it
  reaches the target value epsilon. The convergence error is only evaluated
  (once every inner_iterations updates) after the annealing has finished.

  Args:
    x: np.ndarray<float>[batch, n]: the input point clouds.
    y: np.ndarray<float>[batch, m]: the target point clouds.
    a: np.ndarray<float>[batch, n]: the weight of each input point. The sum of
     all elements of b must match that of a to converge.
    b: np.ndarray<float>[batch, m]: the weight of each target point. The sum of
     all elements of b must match that of a to converge.
    power: (float) the power of the distance for the cost function.
    epsilon: (float) the level of entropic regularization wanted.
    epsilon_0: (float) the initial level of entropic regularization.
    epsilon_decay: (float) a multiplicative factor applied at each iteration
     until reaching the epsilon value.
    threshold: (float) the relative threshold on the Sinkhorn error to stop the
     Sinkhorn iterations.
    inner_iterations: (int32) the Sinkhorn error is not recomputed at each
     iteration but every inner_iterations instead to avoid computational
     overhead.
    max_iterations: (int32) the maximum number of Sinkhorn iterations.

  Returns:
    A 6-tuple containing: the dual potentials f and g, the final value of the
    entropic parameter epsilon, the cost matrix, the derivative of the cost
    with respect to x, and the number of iterations performed.
  """
  loga = np.log(a)
  logb = np.log(b)
  cost, d_cost = cost_fn(x, y, power)
  # Zero potentials: the initial implied transport plan is exp(-cost / eps).
  f = np.zeros(np.shape(a), dtype=x.dtype)
  g = np.zeros(np.shape(b), dtype=x.dtype)
  err = threshold + 1.0
  iterations = 0
  eps = epsilon_0
  # Iterate while the marginal error is above threshold OR the regularization
  # has not yet been annealed down to its target value.
  while (iterations < max_iterations) and (err >= threshold or eps > epsilon):
    for _ in range(inner_iterations):
      iterations += 1
      # Block-coordinate updates of the dual potentials in log space. Adding
      # the previous potential back folds the old value into the new estimate
      # (softmin is computed on the cost already centered by f and g).
      g = eps * logb + softmin(cost, f, g, eps, axis=1) + g
      f = eps * loga + softmin(cost, f, g, eps, axis=2) + f
      # Anneal the regularization, never going below the target epsilon.
      eps = max(eps * epsilon_decay, epsilon)

    # The error is only meaningful once eps has reached its target value.
    if eps <= epsilon:
      err = error(cost, f, g, eps, b)

  return f, g, eps, cost, d_cost, iterations
112
113
def sinkhorn(x,
             y,
             a,
             b,
             **kwargs):
  """Computes the transport between (x, a) and (y, b) via Sinkhorn algorithm."""
  # Run the Sinkhorn loop, then materialize the plan from the potentials. The
  # cost derivative and the iteration count are not needed here.
  result = sinkhorn_iterations(x, y, a, b, **kwargs)
  f, g, eps, cost = result[0], result[1], result[2], result[3]
  return transport(cost, f, g, eps)
122